├── README.md
├── images
│   └── SVdecay.jpg
├── src
│   ├── README.md
│   ├── codings
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── coding.cpython-36.pyc
│   │   │   ├── qsgd.cpython-36.pyc
│   │   │   ├── qsvd.cpython-36.pyc
│   │   │   └── svd.cpython-36.pyc
│   │   ├── coding.py
│   │   ├── qsgd.py
│   │   ├── svd.py
│   │   └── utils.py
│   ├── data
│   │   └── data_prepare.py
│   ├── data_loader_ops
│   │   ├── __init__.py
│   │   └── my_data_loader.py
│   ├── data_prepare.sh
│   ├── datasets.py
│   ├── distributed_evaluator.py
│   ├── distributed_nn.py
│   ├── distributed_worker.py
│   ├── evaluate_pytorch.sh
│   ├── model_ops
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── densenet.py
│   │   ├── fc_nn.py
│   │   ├── lenet.py
│   │   ├── resnet.py
│   │   ├── resnet_split.py
│   │   └── vgg.py
│   ├── nn_ops.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adam.py
│   │   └── sgd.py
│   ├── output
│   │   └── models
│   │       └── README.md
│   ├── run_pytorch.sh
│   ├── single_machine.py
│   ├── sync_replicas_master_nn.py
│   ├── tiny_tuning_parser.py
│   ├── tune.sh
│   └── utils.py
└── tools
    ├── config
    ├── hosts
    ├── hosts_address
    ├── hosts_alias
    ├── install.sh
    ├── local_script.sh
    ├── pre_run.sh
    ├── pytorch_ec2.py
    ├── remote_script.sh
    └── update_git_dir.sh
/README.md:
--------------------------------------------------------------------------------
1 | # Atomo: Communication-efficient Learning via Atomic Sparsification
2 | This repository contains source code for Atomo, a general framework for atomic sparsification of stochastic gradients. Please check [the full paper](http://papers.nips.cc/paper/8191-atomo-communication-efficient-learning-via-atomic-sparsification) for detailed information about this project.
3 |
4 | ## Overview:
5 | ATOMO is a general framework for atomic sparsification of stochastic gradients. Given a gradient, an atomic decomposition,
6 | and a sparsity budget, ATOMO gives a random unbiased sparsification of the atoms that minimizes variance. ATOMO sets up and optimally solves a meta-optimization that minimizes the variance of the sparsified gradient, subject to the constraints
7 | that it is sparse in the atomic basis and is an unbiased estimator of the input.
8 |
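For intuition only, here is a minimal single-pass sketch of unbiased atomic sparsification in Python (the names are illustrative and this is not the repository's implementation, which lives in `src/codings/`; the capped probabilities below simplify the paper's optimal solution, which redistributes the leftover budget among the uncapped atoms):
``` python
import numpy as np

def sparsify_atoms(coeffs, s):
    # keep atom i with probability p_i = min(1, s * |c_i| / sum_j |c_j|),
    # then rescale each kept coefficient by 1 / p_i so E[output] = coeffs
    c = np.abs(coeffs)
    p = np.minimum(1.0, s * c / c.sum())
    keep = np.random.rand(len(p)) < p
    out = np.zeros_like(coeffs, dtype=float)
    out[keep] = coeffs[keep] / p[keep]
    return out  # expected number of nonzeros is at most s
```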
9 |
10 |
11 | ## Dependencies:
12 | Tested stable dependencies:
13 | * python 2.7 (Anaconda)
14 | * PyTorch 0.3.0 (*please note that we're moving to PyTorch 0.4.0 and 1.0.x*)
15 | * torchvision 0.1.18
16 | * MPI4Py 0.3.0
17 | * python-blosc 1.5.0
18 |
19 | We highly recommend installing an [Anaconda](https://www.continuum.io/downloads) environment.
20 | You will get a high-quality BLAS library (MKL) and a controlled compiler version regardless of your Linux distro.
21 |
22 | We provide [this script](https://github.com/hwang595/ATOMO/blob/master/tools/pre_run.sh) to help you build all the dependencies. To do so, run:
23 | ```
24 | bash ./tools/pre_run.sh
25 | ```
26 |
27 | ## Cluster Setup:
28 | To run on a distributed cluster, the first thing you need to do is launch AWS EC2 instances.
29 | ### Launching Instances:
30 | [This script](https://github.com/hwang595/ps_pytorch/blob/master/tools/pytorch_ec2.py) helps you launch EC2 instances automatically. Before running it, you should follow [the instructions](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html) to set up the AWS CLI on your local machine.
31 | After that, please edit this part of `./tools/pytorch_ec2.py`:
32 | ``` python
33 | cfg = Cfg({
34 | "name" : "PS_PYTORCH", # Unique name for this specific configuration
35 | "key_name": "NameOfKeyFile", # Necessary to ssh into created instances
36 | # Cluster topology
37 | "n_masters" : 1, # Should always be 1
38 | "n_workers" : 8,
39 | "num_replicas_to_aggregate" : "8", # deprecated, not necessary
40 | "method" : "spot",
41 | # Region specification
42 | "region" : "us-west-2",
43 | "availability_zone" : "us-west-2b",
44 | # Machine type - instance type configuration.
45 | "master_type" : "m4.2xlarge",
46 | "worker_type" : "m4.2xlarge",
47 | # please only use this AMI for pytorch
48 | "image_id": "ami-xxxxxxxx", # id of AMI
49 | # Launch specifications
50 | "spot_price" : "0.15", # Has to be a string
51 | # SSH configuration
52 | "ssh_username" : "ubuntu", # For sshing. E.G: ssh ssh_username@hostname
53 | "path_to_keyfile" : "/dir/to/NameOfKeyFile.pem",
54 |
55 | # NFS configuration
56 | # To set up these values, go to Services > ElasticFileSystem > Create new filesystem, and follow the directions.
57 | #"nfs_ip_address" : "172.31.3.173", # us-west-2c
58 | #"nfs_ip_address" : "172.31.35.0", # us-west-2a
59 | "nfs_ip_address" : "172.31.14.225", # us-west-2b
60 | "nfs_mount_point" : "/home/ubuntu/shared", # NFS base dir
61 | ```
62 | To set everything up on an EC2 cluster, the easiest way is to set up one machine and create an AMI, then use that AMI's id for `image_id` in `pytorch_ec2.py`. After that, launch EC2 instances by running
63 | ```
64 | python ./tools/pytorch_ec2.py launch
65 | ```
66 | After all launched instances are ready (this may take a while), get the private IPs of the instances by running
67 | ```
68 | python ./tools/pytorch_ec2.py get_hosts
69 | ```
70 | This will write the IPs to a file named `hosts_address`, which looks like
71 | ```
72 | 172.31.16.226 (${PS_IP})
73 | 172.31.27.245
74 | 172.31.29.131
75 | 172.31.18.108
76 | 172.31.18.174
77 | 172.31.17.228
78 | 172.31.16.25
79 | 172.31.30.61
80 | 172.31.29.30
81 | ```
82 | After generating `hosts_address` for all EC2 instances, run the following command to copy your keyfile to the parameter server (PS) instance, whose address is always the first one in `hosts_address`. `local_script.sh` will also do some basic configuration, e.g. cloning this git repo:
83 | ```
84 | bash ./tools/local_script.sh ${PS_IP}
85 | ```
86 | ### SSH related:
87 | At this stage, you should ssh to the PS instance; all subsequent operations happen on the PS. In the PS setting, the PS should be able to ssh to any compute node. [This part](https://github.com/hwang595/ATOMO/blob/master/tools/remote_script.sh#L8-L16) does the job for you; after sshing to the PS, run
88 | ```
89 | bash ./tools/remote_script.sh
90 | ```
91 |
92 | ## Prepare Datasets
93 | We currently support the [MNIST](http://yann.lecun.com/exdb/mnist/) and [Cifar10](https://www.cs.toronto.edu/~kriz/cifar.html) datasets. Download, split, and transform the datasets by running (`./tools/remote_script.sh` does this for you)
94 | ```
95 | bash ./src/data_prepare.sh
96 | ```
97 |
98 | ## Job Launching
99 | Since this project is built on MPI, tasks must be launched from the PS (or master) instance. `run_pytorch.sh` wraps up the job-launching process. Commonly used options (arguments) are listed below; a sample invocation follows the table:
100 |
101 | | Argument | Comments |
102 | | ----------------------------- | ---------------------------------------- |
103 | | `n` | Number of processes (size of the cluster), e.g. with P compute nodes and 1 PS, n=P+1. |
104 | | `hostfile` | Path to the file that contains the private IPs of every node in the cluster; we use `hosts_address` here as [mentioned before](#launching-instances). |
105 | | `lr` | Initial learning rate. |
106 | | `momentum` | Momentum value. |
107 | | `network` | Type of deep neural net; currently `LeNet`, `ResNet-18/32/50/110/152`, and `VGGs` are supported. |
108 | | `dataset` | Dataset used for training. |
109 | | `batch-size` | Batch size for the optimization algorithms. |
110 | | `test-batch-size` | Batch size used during model evaluation. |
111 | | `comm-type` | A placeholder parameter; please always set it to `Bcast`. |
112 | | `num-aggregate` | Number of gradients the PS aggregates per step. |
113 | | `max-steps` | The maximum number of iterations to train. |
114 | | `svd-rank` | The expected rank of the gradients ATOMO uses (the same as the sparsity budget `s` in our paper). |
115 | | `epochs` | The maximal number of epochs to train (somewhat redundant with `max-steps`). |
116 | | `eval-freq` | Frequency (in iterations) at which the model is evaluated. |
117 | | `enable-gpu`| Train on CPU/GPU; for CPU please leave this argument empty. |
118 | |`train-dir` | Directory to save model checkpoints for evaluation. |
119 |
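For illustration, a hypothetical invocation of the kind `run_pytorch.sh` assembles might look like the following (the exact flag spellings live in `run_pytorch.sh`; the values shown, 9 processes = 8 workers + 1 PS running ResNet-18 on Cifar10, are placeholders):
```
mpirun -n 9 --hostfile hosts_address \
    python distributed_nn.py --lr=0.01 --momentum=0.9 --network=ResNet18 \
    --dataset=Cifar10 --batch-size=64 --comm-type=Bcast --num-aggregate=8 \
    --max-steps=10000 --svd-rank=3 --eval-freq=200 --train-dir=output/models/
```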
120 | ## Model Evaluation
121 | The [distributed evaluator](https://github.com/hwang595/ATOMO/blob/master/src/distributed_evaluator.py) fetches model checkpoints from the shared directory and evaluates the model on the validation set.
122 | To evaluate the model, you can run
123 | ```
124 | bash ./src/evaluate_pytorch.sh
125 | ```
126 | with specified arguments.
127 |
128 | Evaluation arguments are listed below:
129 |
130 | | Argument | Comments |
131 | | ----------------------------- | ---------------------------------------- |
132 | | `eval-batch-size` | Batch size (on the validation set) used during model evaluation. |
133 | | `eval-freq` | Frequency (in iterations) at which the model is evaluated; should be set to the same value as in [run_pytorch.sh](https://github.com/hwang595/ps_pytorch/blob/master/src/run_pytorch.sh). |
134 | | `network` | Type of deep neural net; should be set to the same value as in [run_pytorch.sh](https://github.com/hwang595/ps_pytorch/blob/master/src/run_pytorch.sh). |
135 | | `dataset` | Dataset used for training; should be set to the same value as in [run_pytorch.sh](https://github.com/hwang595/ps_pytorch/blob/master/src/run_pytorch.sh). |
136 | | `model-dir` | Directory where model checkpoints are saved for evaluation; should be set to the same value as in [run_pytorch.sh](https://github.com/hwang595/ps_pytorch/blob/master/src/run_pytorch.sh). |
137 |
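For reference, `evaluate_pytorch.sh` ultimately calls `distributed_evaluator.py`; a hypothetical direct invocation with the arguments above (all values are placeholders) could look like:
```
python distributed_evaluator.py --eval-batch-size=10000 --eval-freq=200 \
    --network=ResNet18 --dataset=Cifar10 --model-dir=output/models/
```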
138 | ## Future Work
139 | These are potential directions we are actively working on, stay tuned!
140 | * Explore the use of ATOMO with Fourier decompositions, given their utility and prevalence in signal processing.
141 | * Examine how we can sparsify and compress gradients in a joint fashion to further reduce communication costs.
142 | * Explore joint sparsification of the singular values and singular vectors of the SVD.
143 | * Integrate ATOMO into state-of-the-art PS (or distributed) frameworks, e.g. [Ray](https://rise.cs.berkeley.edu/projects/ray/).
144 |
145 | ## Citation
146 |
147 | ```
148 | @inproceedings{wang2018atomo,
149 | title={ATOMO: Communication-efficient Learning via Atomic Sparsification},
150 | author={Wang, Hongyi and Sievert, Scott and Liu, Shengchao and Charles, Zachary and Papailiopoulos, Dimitris and Wright, Stephen},
151 | booktitle={Advances in Neural Information Processing Systems},
152 | pages={9871--9882},
153 | year={2018}
154 | }
155 | ```
156 |
--------------------------------------------------------------------------------
/images/SVdecay.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/ATOMO/8a21e54938bc4b0809293b306eade3ab9307ea70/images/SVdecay.jpg
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | Do remember to convert `numpy.ndarray` data to `numpy.float64` before pushing it to the MPI send/receive buffer; otherwise there will be data type conversion issues between numpy and MPI.
2 |
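As a minimal sketch of that point (assuming a two-process mpi4py setup; the array names here are illustrative, not from this repo):
```
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
if comm.Get_rank() == 0:
    grad = np.random.randn(4, 4).astype(np.float32)
    # cast to float64 before handing the array to the buffer-based Send
    comm.Send([grad.astype(np.float64), MPI.DOUBLE], dest=1, tag=11)
elif comm.Get_rank() == 1:
    buf = np.empty((4, 4), dtype=np.float64)
    comm.Recv([buf, MPI.DOUBLE], source=0, tag=11)
```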
3 | # To use it on single machine:
4 | ```
5 | python single_machine.py --dataset=MNIST/Cifar10 --network=LeNet/Resnet --batch-size=${BATCH_SIZE}
6 | ```
7 |
8 | # To use it on distributed cluster:
9 | ```
10 | mpirun -n ${NUM_WORKERS} --hostfile=${HOST_DIR} python distributed_nn.py --dataset=MNIST/Cifar10 --network=LeNet/Resnet --batch-size=${BATCH_SIZE}
11 | ```
12 |
13 | # Run the whole thing automatically
14 | The first thing you need to do is launch AWS EC2 instances; you can do that using `tools/pytorch_ec2.py` by running the following command:
15 | ```
16 | python pytorch_ec2.py launch
17 | ```
18 | After the launch command is executed and all instances are initialized (this may take several minutes), you need to fetch the host addresses:
19 | ```
20 | python pytorch_ec2.py get_hosts
21 | ```
22 | Then copy the essential configuration files and hosts files using the public address of the master node (the first address in the `hosts` file):
23 | ```
24 | sh local_script.sh ${MASTER_PUB_ADDR}
25 | ```
26 | After that, log in to the master node manually and run the remote script under the `$HOME` dir:
27 | ```
28 | sh remote_script.sh
29 | ```
30 | This script will do the cluster setup and data preparation work for you.
31 |
--------------------------------------------------------------------------------
/src/codings/__init__.py:
--------------------------------------------------------------------------------
1 | from .coding import Coding
2 | from .svd import SVD
3 | from .qsgd import QSGD
4 | from . import utils  # explicit relative import; the bare `import utils` breaks under Python 3
5 |
6 | __all__ = ['coding', 'svd', 'qsgd', 'utils']
--------------------------------------------------------------------------------
/src/codings/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/ATOMO/8a21e54938bc4b0809293b306eade3ab9307ea70/src/codings/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/codings/__pycache__/coding.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/ATOMO/8a21e54938bc4b0809293b306eade3ab9307ea70/src/codings/__pycache__/coding.cpython-36.pyc
--------------------------------------------------------------------------------
/src/codings/__pycache__/qsgd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/ATOMO/8a21e54938bc4b0809293b306eade3ab9307ea70/src/codings/__pycache__/qsgd.cpython-36.pyc
--------------------------------------------------------------------------------
/src/codings/__pycache__/qsvd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/ATOMO/8a21e54938bc4b0809293b306eade3ab9307ea70/src/codings/__pycache__/qsvd.cpython-36.pyc
--------------------------------------------------------------------------------
/src/codings/__pycache__/svd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/ATOMO/8a21e54938bc4b0809293b306eade3ab9307ea70/src/codings/__pycache__/svd.cpython-36.pyc
--------------------------------------------------------------------------------
/src/codings/coding.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class Coding:
4 | def __init__(self, *args, **kwargs):
5 | self.codes = []
6 |
7 | def encode(self, grad, *args, **kwargs):
8 | raise NotImplementedError()
9 |
10 | def decode(self, code, *args, **kwargs):
11 | raise NotImplementedError()
--------------------------------------------------------------------------------
/src/codings/qsgd.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import numpy as np
3 | import numpy.linalg as LA
4 | from scipy import stats
5 | import torch
6 | import time
7 | import math
8 | from .coding import Coding
9 |
10 | import torch.nn.functional as F
11 |
12 |
13 | class QSGD(Coding):
14 | def __init__(self, scheme='qsgd', bucket_size=512, *args, **kwargs):
15 | self.scheme = scheme
16 | self._quantization_level = kwargs['quantization_level']
17 | self._bucket_size=bucket_size
18 |
19 | def encode(self, v, **kwargs):
20 | if isinstance(v, (torch.Tensor, torch.cuda.FloatTensor)):
21 | w = v.cpu().numpy().flat[:]
22 | elif isinstance(v, np.ndarray):
23 | w = v.flat[:]
24 | else:
25 | raise ValueError("Object passed to encode not ndarray or torch.Tensor")
26 |
27 | if 'neo_bucket_size' in kwargs.keys():
28 | bucket_size = min(self._bucket_size, kwargs['neo_bucket_size'])
29 | else:
30 | bucket_size = self._bucket_size
31 | # Apply bucketing
32 | if bucket_size != 0:
33 | code_buckets = []
34 | shape = v.shape
35 | neo_kwargs = {'neo_bucket_size': 0}
36 |             buckets = np.split(w, np.arange(bucket_size, w.shape[0], bucket_size))  # chunks of `bucket_size` (the last may be shorter); the old integer-count split broke on non-divisible lengths
37 | for bucket in buckets:
38 | code = self.encode(bucket, **neo_kwargs)
39 | code_buckets.append(code)
40 | return {'code_buckets': code_buckets, 'shape': shape}
41 |
42 | if self.scheme == 'qsgd':
43 | norm = LA.norm(v)
44 | elif self.scheme == 'terngrad':
45 | norm = np.linalg.norm(w, ord=np.inf)
46 | limit = grad_clip_limit(w, clip_factor=2.5)
47 | w = np.clip(w, -limit, limit)
48 |
49 | s = (1 << self._quantization_level) - 1
50 | shape = v.shape
51 |
52 |         num_int_each_64_bits = int(64 / (2 + self._quantization_level))  # quantized elements packed into each 64-bit word
53 |         num_section = num_int_each_64_bits
54 |         len_each_section = int((w.shape[0] + num_section - 1) / num_section)  # 64-bit words needed to ship the whole of w
55 |         w = np.pad(w, (0, len_each_section * num_section - w.shape[0]), mode='constant')  # pad w so it fills the last word
56 |
57 | sign_array = np.sign(w)
58 | sign_array += 1 # -1, 0, 1 to 0, 1, 2
59 | sign_array = sign_array.astype('uint64')
60 | normalization_array = np.abs(w) / norm * s
61 |
62 | truncated_array = normalization_array.astype(int) # l <= \frac{s \|w\|_i}{\|w\|_2} <= l+1
63 | prob_array = normalization_array - truncated_array # \frac{s \|w\|_i}{\|w\|_2} - l i.e. p function p(a, s) = as - l
64 | dice_array = np.random.rand(len(prob_array))
65 |         xi_array = truncated_array + (dice_array < prob_array)  # round up to l+1 with probability p, else keep l (unbiased); the old `>` comparison inverted the probability
66 | xi_array = xi_array.astype('uint64')
67 |
68 | xi_array = xi_array.reshape((num_section, len_each_section))
69 | sign_array = sign_array.reshape((num_section, len_each_section))
70 |
71 | neo_array = np.zeros(len_each_section)
72 | neo_array = neo_array.astype('uint64')
73 |
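        # Bit-packing: each uint64 in `neo_array` holds `num_int_each_64_bits`
        # (sign, xi) pairs; each pair occupies (2 + quantization_level) bits,
        # with the 2-bit sign stored above the quantization bits.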
74 | for i in range(num_int_each_64_bits):
75 | xi = xi_array[i]
76 | sign = sign_array[i]
77 | neo_array <<= (2 + self._quantization_level)
78 | neo_array = neo_array | (sign << self._quantization_level | xi)
79 |
80 | code = {'neo': neo_array, 'norm': norm, 'quantization_level': self._quantization_level,
81 | 'len_each_section': len_each_section, 'num_int_each_64_bits': num_int_each_64_bits,
82 | 'shape': shape}
83 |
84 | if kwargs.pop('timings', False):
85 | data = {}
86 | return code, data
87 | return code
88 |
89 | def decode(self, code, cuda=False, implementation='numpy', codes=[], **kwargs):
90 | """
91 | Decode the coding.
92 | ## NumPy
93 | 'comm_wait': 0.0728750228881836,
94 | 'decode_time': 0.1349341869354248,
95 | 'example_to_gpu': 0.0006515979766845703,
96 | 'grad_compute_time': 0.5815503597259521,
97 | 'grad_forward_pass': 0.23496603965759277,
98 | 'grad_variance_increase': 31.754316389320049,
99 | 'iallgather_prepare_time': 0.017401456832885742,
100 | 'isend_time': 0.029105424880981445,
101 | ## PT GPU
102 | """
103 | if self.scheme == 'terngrad' and len(codes) > 0:
104 | code['norm'] = self._get_max_norm(codes)
105 |
106 | if implementation == 'numpy':
107 | if 'neo_bucket_size' in kwargs.keys():
108 | bucket_size = min(self._bucket_size, kwargs['neo_bucket_size'])
109 | else:
110 | bucket_size = self._bucket_size
111 | # Decode from bucketing
112 | if bucket_size != 0:
113 | v_list = []
114 | neo_kwargs = {'neo_bucket_size': 0}
115 | for code_bucket in code['code_buckets']:
116 | v = self.decode(code=code_bucket, cuda=cuda, implementation=implementation, codes=codes, **neo_kwargs)
117 | v_list.extend(v)
118 | v = np.array(v_list)
119 | v = v.reshape(code['shape'])
120 | else:
121 | norm = code['norm']
122 | s = (1 << self._quantization_level) - 1
123 |
124 | real_size = np.prod(code['shape'])
125 |
126 | neo_array = code['neo'].astype('uint64')
127 | num_int_each_64_bits = code['num_int_each_64_bits']
128 | num_section = num_int_each_64_bits
129 | len_each_section = code['len_each_section']
130 | xi_array = np.ones((num_section, len_each_section))
131 | sign_array = np.ones((num_section, len_each_section))
132 | mask_for_xi = (1 << self._quantization_level) - 1
133 | mask_for_sign = 3 << self._quantization_level
134 | for i in range(num_int_each_64_bits)[::-1]:
135 | sign_array[i] = (neo_array & mask_for_sign) >> self._quantization_level
136 | xi_array[i] = neo_array & mask_for_xi
137 | neo_array >>= (2 + self._quantization_level)
138 |
139 | xi_array = xi_array.reshape(-1).astype('uint64')
140 | sign_array = sign_array.reshape(-1).astype('int8')
141 | sign_array -= 1
142 | v = sign_array * xi_array * norm / s
143 |
144 | v = v[:real_size]
145 | v = v.reshape(code['shape'])
146 | else:
147 | raise ValueError('Whoops, implementation')
148 | v = torch.Tensor(v)
149 | if cuda:
150 | v = v.cuda()
151 | return v
152 |
153 | def _get_max_norm(self, codes):
154 | scalars = [code['norm'] for code in codes]
155 | return max(scalars)
156 |
157 | def encode_cuda(self, v, **kwargs):
158 | if isinstance(v, torch.cuda.FloatTensor):
159 | w = v.view(-1)
160 | else:
161 |             raise ValueError("Object passed wasn't on CUDA, please check CUDA availability!")
162 |
163 | #norm = LA.norm(v)
164 | norm = torch.norm(w)
165 |
166 | s = (1 << self._quantization_level) - 1
167 | shape = v.size()
168 |
169 | num_int_each_64_bits = int(64 / (2 + self._quantization_level)) # number of element stored / 64-bits
170 | num_section = num_int_each_64_bits
171 | len_each_section = int((w.shape[0] + num_section - 1) / num_section) # number of 64-bits to ship w vector
172 |
173 | w = F.pad(w, (0, len_each_section * num_section - w.size()[0]), 'constant', 0)
174 |
175 | sign_array = torch.sign(w)
176 |
177 | sign_array += 1 # -1, 0, 1 to 0, 1, 2
178 | #sign_array = sign_array.astype('uint64')
179 | sign_array = sign_array.to(dtype=torch.int64)
180 |
181 | normalization_array = torch.abs(w) / norm * s
182 |
183 | #truncated_array = normalization_array.astype(int)
184 | truncated_array = normalization_array.to(dtype=torch.int) # l <= \frac{s \|w\|_i}{\|w\|_2} <= l+1
185 |
186 | prob_array = normalization_array - truncated_array.float() # \frac{s \|w\|_i}{\|w\|_2} - l i.e. p function p(a, s) = as - l
187 |
188 | dice_array = torch.rand(len(prob_array)).to(torch.device("cuda"))
189 |
190 |         xi_array = truncated_array + (dice_array < prob_array).to(dtype=torch.int)  # round up to l+1 with probability p (unbiased), matching the CPU path
191 | xi_array = xi_array.to(dtype=torch.int64)
192 | xi_array = xi_array.view((num_section, len_each_section))
193 |
194 | sign_array = sign_array.view((num_section, len_each_section))
195 |
196 | neo_array = torch.zeros(len_each_section).to(dtype=torch.int64).to(torch.device("cuda"))
197 |
198 | for i in range(num_int_each_64_bits):
199 | xi = xi_array[i]
200 | sign = sign_array[i]
201 | neo_array *= 2**(2 + self._quantization_level)
202 | sign *= 2**self._quantization_level
203 | sign += xi
204 | neo_array += sign.to(dtype=torch.int64)
205 |
206 | code = {'neo': neo_array, 'norm': norm, 'quantization_level': self._quantization_level,
207 | 'len_each_section': len_each_section, 'num_int_each_64_bits': num_int_each_64_bits,
208 | 'shape': shape}
209 | return code
210 |
211 |
212 | def grad_clip_limit(grad, clip_factor=2.5):
213 | """ Get the scalers."""
214 | if clip_factor > 1.0e-5:
215 | return clip_factor * np.std(grad.flat[:])
216 | return np.max(np.abs(grad.flat[:]))
217 |
218 |
219 | if __name__ == "__main__":
220 | a_cpu = torch.randn(20)
221 | a_cuda = a_cpu.to(torch.device("cuda"))
222 |
223 | kwargs = {'quantization_level':8}
224 | coder = QSGD(bucket_size=0, **kwargs)
225 | print(a_cuda.dtype)
226 | code_cpu = coder.encode(a_cpu)
227 | code_cuda = coder.encode_cuda(a_cuda)
228 | print("CPU compression: {}, Type: {}".format(code_cpu['neo'], code_cpu['neo'].dtype))
229 | print("")
230 | print("CUDA compression: {}, Type: {}".format(code_cuda['neo'], code_cuda['neo'].dtype))
--------------------------------------------------------------------------------
/src/codings/svd.py:
--------------------------------------------------------------------------------
1 | # import torch
2 | import numpy as np
3 | import numpy.linalg as LA
4 | import warnings
5 | import sys
6 | import time
7 | import torch
8 |
9 | from .coding import Coding
10 | from .utils import nuclear_indicator, l1_indicator
11 |
12 | def _resize_to_2d(x):
13 | """
14 | x.shape > 2
15 | If x.shape = (a, b, *c), assumed that each one of (a, b) pairs has relevant information in c.
16 | """
17 | shape = x.shape
18 | if x.ndim == 1:
19 | n = x.shape[0]
20 | return x.reshape((n//2, 2))
21 | if all([s == 1 for s in shape[2:]]):
22 | return x.reshape((shape[0], shape[1]))
23 | # each of (a, b) has related features
24 | x = x.reshape((shape[0], shape[1], -1))
25 | # stack those related features into a tall matrix
26 | x_tmp = x.reshape((shape[0]*shape[1], -1))
27 | tmp_shape = x_tmp.shape
28 | return x_tmp.reshape((int(tmp_shape[0]/2), int(tmp_shape[1]*2)))
29 |
30 | def _resize_to_2d_cuda(x):
31 | """
32 | x.shape > 2
33 | If x.shape = (a, b, *c), assumed that each one of (a, b) pairs has relevant information in c.
34 | """
35 | shape = x.size()
36 | if x.dim() == 1:
37 | n = x.size()[0]
38 | return x.reshape((n//2, 2))
39 | if all([s == 1 for s in shape[2:]]):
40 | return x.reshape((shape[0], shape[1]))
41 | # each of (a, b) has related features
42 | x = x.reshape((shape[0], shape[1], -1))
43 | # stack those related features into a tall matrix
44 | x_tmp = x.reshape((shape[0]*shape[1], -1))
45 | tmp_shape = x_tmp.shape
46 | return x_tmp.reshape((int(tmp_shape[0]/2), int(tmp_shape[1]*2)))
47 |
48 |
49 | def _sample_svd(s, rank=0):
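    # ATOMO importance sampling over the singular values: atom i is kept with
    # probability p_i = min(1, rank * s_i / sum(s)) (or s_i / s_max when rank
    # is 0); the caller divides each kept value by its p_i, so the expected
    # reconstruction stays unbiased.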
50 | if s[0] < 1e-6:
51 | return [0], np.array([1.0])
52 | probs = s / s[0] if rank == 0 else rank * s / s.sum()
53 | for i, p in enumerate(probs):
54 | if p > 1:
55 | probs[i]=1
56 | sampled_idx = []
57 | sample_probs = []
58 | for i, p in enumerate(probs):
59 | #if np.random.rand() < p:
60 | # random sampling from bernulli distribution
61 | if np.random.binomial(1, p):
62 | sampled_idx += [i]
63 | sample_probs += [p]
64 | rank_hat = len(sampled_idx)
65 | if rank_hat == 0: # or (rank != 0 and np.abs(rank_hat - rank) >= 3):
66 | return _sample_svd(s, rank=rank)
67 | return np.array(sampled_idx, dtype=int), np.array(sample_probs)
68 |
69 |
70 | class SVD(Coding):
71 | def __init__(self, compress=True, rank=0, random_sample=True,
72 | fetch_indicator=None, *args, **kwargs):
73 | self.svd_rank = rank
74 | self.random_sample = random_sample
75 | self.compress = compress
76 | # fetch indicator or not
77 | self.__fetch_indicator = fetch_indicator
78 |
79 | def encode(self, grad, **kwargs):
80 | # move to CPU; SVD is 5x faster on CPU (at least in torch)
81 | if not self.compress:
82 | shape = list(grad.shape)
83 | return {'grad': grad, 'encode': False}#, {}
84 |
85 | orig_size = list(grad.shape)
86 | ndims = grad.ndim
87 | reshaped_flag = False
88 | if ndims != 2:
89 | grad = _resize_to_2d(grad)
90 | shape = list(grad.shape)
91 | ndims = len(shape)
92 | reshaped_flag = True
93 |
94 | if ndims == 2:
95 | u, s, vT = LA.svd(grad, full_matrices=False)
96 |
97 | if self.__fetch_indicator:
98 |                 nuc_ind = nuclear_indicator(grad, s)  # don't shadow the imported helper (it would raise UnboundLocalError)
99 |                 l1_ind = l1_indicator(grad)
100 |                 print("Step: {}, Nuclear Indicator: {}, L1 Indicator: {}".format(
101 |                     kwargs['step'], nuc_ind, l1_ind))
102 |
103 | if self.random_sample:
104 | i, probs = _sample_svd(s, rank=self.svd_rank)
105 | u = u[:, i]
106 | s = s[i] / probs
107 | # v = v[:, i]
108 | vT = vT[i, :]
109 | elif self.svd_rank > 0:
110 | u = u[:, :self.svd_rank]
111 | s = s[:self.svd_rank]
112 | # v = v[:, :self.svd_rank]
113 | vT = vT[:self.svd_rank, :]
114 |
115 | return {'u': u, 's': s, 'vT': vT, 'orig_size': orig_size,
116 | 'reshaped': reshaped_flag, 'encode': True,
117 | 'rank': self.svd_rank}
118 | return {'grad': grad, 'encode': False}
119 |
120 | def encode_cuda(self, grad, device, **kwargs):
121 | if not isinstance(grad, torch.cuda.FloatTensor):
122 |             raise ValueError("Object passed wasn't on CUDA, please check CUDA availability!")
123 | # taking SVD on GPU
124 | if not self.compress:
125 | shape = list(grad.shape)
126 | return {'grad': grad, 'encode': False}#, {}
127 |
128 | orig_size = list(grad.size())
129 | ndims = grad.dim()
130 |
131 | reshaped_flag = False
132 |
133 | if ndims != 2:
134 | grad = _resize_to_2d_cuda(grad)
135 | shape = list(grad.size())
136 | ndims = len(shape)
137 | reshaped_flag = True
138 |
139 | if ndims == 2:
140 | #u, s, vT = LA.svd(grad, full_matrices=False)
141 | u, s, vT = torch.svd(grad, some=True)
142 |
143 | if self.random_sample:
144 | i, probs = _sample_svd(s, rank=self.svd_rank)
145 | u = u[:, i]
146 | s = s[i] / torch.tensor(probs).float().to(device)
147 | # v = v[:, i]
148 | vT = vT[i, :]
149 | elif self.svd_rank > 0:
150 | u = u[:, :self.svd_rank]
151 | s = s[:self.svd_rank]
152 | # v = v[:, :self.svd_rank]
153 | vT = vT[:self.svd_rank, :]
154 |
155 | return {'u': u, 's': s, 'vT': vT, 'orig_size': orig_size,
156 | 'reshaped': reshaped_flag, 'encode': True,
157 | 'rank': self.svd_rank}
158 | return {'grad': grad, 'encode': False}
159 |
160 | def decode(self, encode_output, cuda=False, **kwargs):
161 | if isinstance(encode_output, tuple) and len(encode_output) == 1:
162 | encode_output = encode_output[0]
163 | encode = encode_output.get('encode', False)
164 | if not encode:
165 | grad = encode_output['grad']
166 | grad = torch.Tensor(grad)
167 | if cuda:
168 | grad = grad.cuda(async=True)
169 | return grad
170 |
171 |         u, s, vT = (encode_output[key] for key in ['u', 's', 'vT'])
172 |         if isinstance(u, np.ndarray):  # fast path for numpy factors; torch factors fall through to the code below (previously unreachable after an unconditional return)
173 |             grad = np.dot(np.dot(u, np.diag(s)), vT)
174 |             grad = torch.Tensor(grad)
175 |             grad = grad.view(encode_output['orig_size'])
176 |             if cuda:
177 |                 grad = grad.cuda(async=True)
178 |             return grad
179 |
180 | if isinstance(u, np.ndarray):
181 | u = torch.Tensor(u)
182 | if isinstance(s, np.ndarray):
183 | s = torch.Tensor(s)
184 | if isinstance(vT, np.ndarray):
185 | vT = torch.Tensor(vT)
186 | if cuda:
187 | u = u.contiguous().cuda(async=True)
188 | s = s.contiguous().cuda(async=True)
189 | vT = vT.contiguous().cuda(async=True)
190 | # u = u.cuda(async=True)
191 | # s = s.cuda(async=True)
192 | # vT = vT.cuda(async=True)
193 | #grad_approx = u @ torch.diag(s) @ vT
194 |         grad_approx = torch.mm(torch.mm(u, torch.diag(s)), vT)  # u @ diag(s) @ vT on torch tensors (np.dot fails on CUDA tensors)
195 | if encode_output.get('reshaped', False):
196 | grad_approx = grad_approx.view(encode_output['orig_size'])
197 | return grad_approx
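
if __name__ == "__main__":
    # Hypothetical round-trip demo added for illustration; it is not part of
    # the original file. Run as `python -m codings.svd` from `src/` so the
    # package-relative imports above resolve (mirrors the demo in qsgd.py).
    coder = SVD(compress=True, rank=3, random_sample=True)
    grad = np.random.randn(64, 3, 3, 3).astype(np.float32)
    code = coder.encode(grad)
    approx = coder.decode(code)
    print(approx.shape)  # expected: torch.Size([64, 3, 3, 3])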
--------------------------------------------------------------------------------
/src/codings/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def nuclear_indicator(grad, s):
4 | m, n = grad.shape
5 | return np.sum(s)*np.sqrt(m+n)
6 |
7 | def l1_indicator(grad):
8 | return np.linalg.norm(grad.reshape(-1), 1)
--------------------------------------------------------------------------------
/src/data/data_prepare.py:
--------------------------------------------------------------------------------
1 | """
2 | Since we need to initialize the dataset in parallel, pre-downloading the data is necessary.
3 | This script does the job for you.
4 | """
5 | import torch
6 | from torchvision import datasets, transforms
7 |
8 |
9 | if __name__ == "__main__":
10 | training_set_mnist = datasets.MNIST('./mnist_data', train=True, download=True,
11 | transform=transforms.Compose([
12 | transforms.ToTensor(),
13 | transforms.Normalize((0.1307,), (0.3081,))]))
14 | train_loader_mnist = torch.utils.data.DataLoader(training_set_mnist, batch_size=128, shuffle=True)
15 | test_loader_mnist = torch.utils.data.DataLoader(
16 | datasets.MNIST('./mnist_data', train=False, transform=transforms.Compose([
17 | transforms.ToTensor(),
18 | transforms.Normalize((0.1307,), (0.3081,))
19 | ])), batch_size=100, shuffle=True)
20 | trainset_cifar10 = datasets.CIFAR10(root='./cifar10_data', train=True,
21 | download=True, transform=transforms.Compose([
22 | transforms.ToTensor(),
23 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
24 | ]))
25 | train_loader_cifar10 = torch.utils.data.DataLoader(trainset_cifar10, batch_size=128,
26 | shuffle=True)
27 | test_loader_cifar10 = torch.utils.data.DataLoader(
28 | datasets.CIFAR10('./cifar10_data', train=False, transform=transforms.Compose([
29 | transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
30 | ])), batch_size=100, shuffle=True)
31 | # load training and test set here:
32 | training_set = datasets.CIFAR100(root='./cifar100_data', train=True,
33 | download=True, transform=transforms.Compose([
34 | transforms.ToTensor(),
35 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
36 | ]))
37 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=128,
38 | shuffle=True)
39 | testset = datasets.CIFAR100(root='./cifar100_data', train=False,
40 | download=True, transform=transforms.Compose([
41 | transforms.ToTensor(),
42 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
43 | ]))
44 | test_loader = torch.utils.data.DataLoader(testset, batch_size=1000,
45 | shuffle=False)
--------------------------------------------------------------------------------
/src/data_loader_ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/ATOMO/8a21e54938bc4b0809293b306eade3ab9307ea70/src/data_loader_ops/__init__.py
--------------------------------------------------------------------------------
/src/data_loader_ops/my_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.multiprocessing as multiprocessing
3 | from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler
4 | import collections
5 | import sys
6 | import traceback
7 | import threading
8 |
9 | PY2 = sys.version_info[0] == 2
10 | PY3 = sys.version_info[0] == 3
11 |
12 |
13 | if PY2:
14 | string_classes = basestring
15 | else:
16 | string_classes = (str, bytes)
17 |
18 |
19 | if sys.version_info[0] == 2:
20 | import Queue as queue
21 | else:
22 | import queue
23 |
24 |
25 | _use_shared_memory = False
26 | """Whether to use shared memory in default_collate"""
27 |
28 |
29 | class ExceptionWrapper(object):
30 | "Wraps an exception plus traceback to communicate across threads"
31 |
32 | def __init__(self, exc_info):
33 | self.exc_type = exc_info[0]
34 | self.exc_msg = "".join(traceback.format_exception(*exc_info))
35 |
36 |
37 | def _worker_loop(dataset, index_queue, data_queue, collate_fn):
38 | global _use_shared_memory
39 | _use_shared_memory = True
40 |
41 | torch.set_num_threads(1)
42 | while True:
43 | r = index_queue.get()
44 | if r is None:
45 | data_queue.put(None)
46 | break
47 | idx, batch_indices = r
48 | try:
49 | samples = collate_fn([dataset[i] for i in batch_indices])
50 | except Exception:
51 | data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
52 | else:
53 | data_queue.put((idx, samples))
54 |
55 |
56 | def _pin_memory_loop(in_queue, out_queue, done_event):
57 | while True:
58 | try:
59 | r = in_queue.get()
60 | except:
61 | if done_event.is_set():
62 | return
63 | raise
64 | if r is None:
65 | break
66 | if isinstance(r[1], ExceptionWrapper):
67 | out_queue.put(r)
68 | continue
69 | idx, batch = r
70 | try:
71 | batch = pin_memory_batch(batch)
72 | except Exception:
73 | out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
74 | else:
75 | out_queue.put((idx, batch))
76 |
77 |
78 | numpy_type_map = {
79 | 'float64': torch.DoubleTensor,
80 | 'float32': torch.FloatTensor,
81 | 'float16': torch.HalfTensor,
82 | 'int64': torch.LongTensor,
83 | 'int32': torch.IntTensor,
84 | 'int16': torch.ShortTensor,
85 | 'int8': torch.CharTensor,
86 | 'uint8': torch.ByteTensor,
87 | }
88 |
89 |
90 | def default_collate(batch):
91 | "Puts each data field into a tensor with outer dimension batch size"
92 | if torch.is_tensor(batch[0]):
93 | out = None
94 | if _use_shared_memory:
95 | # If we're in a background process, concatenate directly into a
96 | # shared memory tensor to avoid an extra copy
97 | numel = sum([x.numel() for x in batch])
98 | storage = batch[0].storage()._new_shared(numel)
99 | out = batch[0].new(storage)
100 | return torch.stack(batch, 0, out=out)
101 | elif type(batch[0]).__module__ == 'numpy':
102 | elem = batch[0]
103 | if type(elem).__name__ == 'ndarray':
104 | return torch.stack([torch.from_numpy(b) for b in batch], 0)
105 | if elem.shape == (): # scalars
106 | py_type = float if elem.dtype.name.startswith('float') else int
107 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
108 | elif isinstance(batch[0], int):
109 | return torch.LongTensor(batch)
110 | elif isinstance(batch[0], float):
111 | return torch.DoubleTensor(batch)
112 | elif isinstance(batch[0], string_classes):
113 | return batch
114 | elif isinstance(batch[0], collections.Mapping):
115 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
116 | elif isinstance(batch[0], collections.Sequence):
117 | transposed = zip(*batch)
118 | return [default_collate(samples) for samples in transposed]
119 |
120 | raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
121 | .format(type(batch[0]))))
122 |
123 |
124 | def pin_memory_batch(batch):
125 | if torch.is_tensor(batch):
126 | return batch.pin_memory()
127 | elif isinstance(batch, string_classes):
128 | return batch
129 | elif isinstance(batch, collections.Mapping):
130 | return {k: pin_memory_batch(sample) for k, sample in batch.items()}
131 | elif isinstance(batch, collections.Sequence):
132 | return [pin_memory_batch(sample) for sample in batch]
133 | else:
134 | return batch
135 |
136 |
137 | class DataLoaderIter(object):
138 | "Iterates once over the DataLoader's dataset, as specified by the sampler"
139 |
140 | def __init__(self, loader):
141 | self.dataset = loader.dataset
142 | self.collate_fn = loader.collate_fn
143 | self.batch_sampler = loader.batch_sampler
144 | self.num_workers = loader.num_workers
145 | self.pin_memory = loader.pin_memory
146 | self.done_event = threading.Event()
147 |
148 | self.sample_iter = iter(self.batch_sampler)
149 |
150 | if self.num_workers > 0:
151 | self.index_queue = multiprocessing.SimpleQueue()
152 | self.data_queue = multiprocessing.SimpleQueue()
153 | self.batches_outstanding = 0
154 | self.shutdown = False
155 | self.send_idx = 0
156 | self.rcvd_idx = 0
157 | self.reorder_dict = {}
158 |
159 | self.workers = [
160 | multiprocessing.Process(
161 | target=_worker_loop,
162 | args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
163 | for _ in range(self.num_workers)]
164 |
165 | for w in self.workers:
166 | w.daemon = True # ensure that the worker exits on process exit
167 | w.start()
168 |
169 | if self.pin_memory:
170 | in_data = self.data_queue
171 | self.data_queue = queue.Queue()
172 | self.pin_thread = threading.Thread(
173 | target=_pin_memory_loop,
174 | args=(in_data, self.data_queue, self.done_event))
175 | self.pin_thread.daemon = True
176 | self.pin_thread.start()
177 |
178 | # prime the prefetch loop
179 | for _ in range(2 * self.num_workers):
180 | self._put_indices()
181 |
182 | def __len__(self):
183 | return len(self.batch_sampler)
184 |
185 | def __next__(self):
186 | if self.num_workers == 0: # same-process loading
187 | #TODO(hwang): try to figure out what's happening right here and fix this issue
188 | indices = next(self.sample_iter) # may raise StopIteration
189 | batch = self.collate_fn([self.dataset[i] for i in indices])
190 | if self.pin_memory:
191 | batch = pin_memory_batch(batch)
192 | return batch
193 |
194 | # check if the next sample has already been generated
195 | if self.rcvd_idx in self.reorder_dict:
196 | batch = self.reorder_dict.pop(self.rcvd_idx)
197 | return self._process_next_batch(batch)
198 |
199 | if self.batches_outstanding == 0:
200 | self._shutdown_workers()
201 | raise StopIteration
202 |
203 | while True:
204 | assert (not self.shutdown and self.batches_outstanding > 0)
205 | idx, batch = self.data_queue.get()
206 | self.batches_outstanding -= 1
207 | if idx != self.rcvd_idx:
208 | # store out-of-order samples
209 | self.reorder_dict[idx] = batch
210 | continue
211 | return self._process_next_batch(batch)
212 |
213 | next = __next__ # Python 2 compatibility
214 |
215 | def __iter__(self):
216 | return self
217 |
218 | def _put_indices(self):
219 | assert self.batches_outstanding < 2 * self.num_workers
220 | indices = next(self.sample_iter, None)
221 | if indices is None:
222 | return
223 | self.index_queue.put((self.send_idx, indices))
224 | self.batches_outstanding += 1
225 | self.send_idx += 1
226 |
227 | def _process_next_batch(self, batch):
228 | self.rcvd_idx += 1
229 | self._put_indices()
230 | if isinstance(batch, ExceptionWrapper):
231 | raise batch.exc_type(batch.exc_msg)
232 | return batch
233 |
234 | def __getstate__(self):
235 | # TODO: add limited pickling support for sharing an iterator
236 | # across multiple threads for HOGWILD.
237 | # Probably the best way to do this is by moving the sample pushing
238 | # to a separate thread and then just sharing the data queue
239 | # but signalling the end is tricky without a non-blocking API
240 | raise NotImplementedError("DataLoaderIterator cannot be pickled")
241 |
242 | def _shutdown_workers(self):
243 | if not self.shutdown:
244 | self.shutdown = True
245 | self.done_event.set()
246 | for _ in self.workers:
247 | self.index_queue.put(None)
248 |
249 | def __del__(self):
250 | if self.num_workers > 0:
251 | self._shutdown_workers()
252 |
253 |
254 | class DataLoader(object):
255 | """
256 | Data loader. Combines a dataset and a sampler, and provides
257 | single- or multi-process iterators over the dataset.
258 | Arguments:
259 | dataset (Dataset): dataset from which to load the data.
260 | batch_size (int, optional): how many samples per batch to load
261 | (default: 1).
262 | shuffle (bool, optional): set to ``True`` to have the data reshuffled
263 | at every epoch (default: False).
264 | sampler (Sampler, optional): defines the strategy to draw samples from
265 | the dataset. If specified, ``shuffle`` must be False.
266 | batch_sampler (Sampler, optional): like sampler, but returns a batch of
267 | indices at a time. Mutually exclusive with batch_size, shuffle,
268 | sampler, and drop_last.
269 | num_workers (int, optional): how many subprocesses to use for data
270 | loading. 0 means that the data will be loaded in the main process
271 | (default: 0)
272 | collate_fn (callable, optional): merges a list of samples to form a mini-batch.
273 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors
274 | into CUDA pinned memory before returning them.
275 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
276 | if the dataset size is not divisible by the batch size. If False and
277 | the size of dataset is not divisible by the batch size, then the last batch
278 | will be smaller. (default: False)
279 | """
280 |
281 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
282 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
283 | self.dataset = dataset
284 | self.batch_size = batch_size
285 | self.num_workers = num_workers
286 | self.collate_fn = collate_fn
287 | self.pin_memory = pin_memory
288 | self.drop_last = drop_last
289 | #self._index_in_epoch = 0
290 |
291 |
292 | if batch_sampler is not None:
293 | if batch_size > 1 or shuffle or sampler is not None or drop_last:
294 | raise ValueError('batch_sampler is mutually exclusive with '
295 | 'batch_size, shuffle, sampler, and drop_last')
296 |
297 | if sampler is not None and shuffle:
298 | raise ValueError('sampler is mutually exclusive with shuffle')
299 |
300 | if batch_sampler is None:
301 | if sampler is None:
302 | if shuffle:
303 | sampler = RandomSampler(dataset)
304 | else:
305 | sampler = SequentialSampler(dataset)
306 | batch_sampler = BatchSampler(sampler, batch_size, drop_last)
307 |
308 | self.sampler = sampler
309 | self.batch_sampler = batch_sampler
310 | self.data_iterator = DataLoaderIter(self)
311 |
312 | def __iter__(self):
313 | return DataLoaderIter(self)
314 |
315 | def __len__(self):
316 | return len(self.batch_sampler)
317 |
318 | def next_batch(self):
319 | return next(self.data_iterator)
--------------------------------------------------------------------------------
/src/data_prepare.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | python ./data/data_prepare.py
--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 |
4 | from torch.utils.data import Dataset
5 | import torch.utils.data as data
6 | from PIL import Image
7 | import os
8 | import os.path
9 | import numpy as np
10 |
11 | class MNISTDataset(Dataset):
12 | def __init__(self, dataset, transform):
13 | self.dataset = dataset
14 | self.images = dataset.images
15 | self.labels = dataset.labels
16 | self.transform = transform
17 |
18 | def __len__(self):
19 | return self.images.shape[0]
20 |
21 | def __getitem__(self, idx):
22 | data_sample = self.images[idx]
23 | label_sample = self.labels[idx]
24 | if self.transform:
25 | data_sample = self.transform(data_sample)
26 | label_sample = self.transform(label_sample)
27 | return data_sample, label_sample
28 |
29 | def next_batch(self, batch_size):
30 | image_batch, label_batch = self.dataset.next_batch(batch_size=batch_size)
31 | return torch.from_numpy(image_batch), torch.from_numpy(label_batch)
32 | # TODO(hwang): figure out why `ToTensor` caused error here
33 | #return self.transform(image_batch), self.transform(label_batch)
34 |
35 | class Cifar10Dataset(Dataset):
36 | def __init__(self, dataset, transform):
37 | self.dataset = dataset
38 | self.images = dataset.images
39 | self.labels = dataset.labels
40 | self.transform = transform
41 |
42 | def __len__(self):
43 | return self.images.shape[0]
44 |
45 | def __getitem__(self, idx):
46 | data_sample = self.images[idx]
47 | label_sample = self.labels[idx]
48 | if self.transform:
49 | data_sample = self.transform(data_sample)
50 | label_sample = self.transform(label_sample)
51 | return data_sample, label_sample
52 |
53 | def next_batch(self, batch_size):
54 | image_batch, label_batch = self.dataset.next_batch(batch_size=batch_size)
55 | return torch.from_numpy(image_batch), torch.from_numpy(label_batch)
56 | # TODO(hwang): figure out why `ToTensor` caused error here
57 | #return self.transform(image_batch), self.transform(label_batch)
58 |
59 | import hashlib
60 | import errno
69 |
70 | def check_integrity(fpath, md5):
71 | if not os.path.isfile(fpath):
72 | return False
73 | md5o = hashlib.md5()
74 | with open(fpath, 'rb') as f:
75 | # read in 1MB chunks
76 | for chunk in iter(lambda: f.read(1024 * 1024), b''):
77 | md5o.update(chunk)
78 | md5c = md5o.hexdigest()
79 | if md5c != md5:
80 | return False
81 | return True
82 |
83 |
84 | def download_url(url, root, filename, md5):
85 | from six.moves import urllib
86 |
87 | root = os.path.expanduser(root)
88 | fpath = os.path.join(root, filename)
89 |
90 | try:
91 | os.makedirs(root)
92 | except OSError as e:
93 | if e.errno == errno.EEXIST:
94 | pass
95 | else:
96 | raise
97 |
98 | # downloads file
99 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
100 | print('Using downloaded and verified file: ' + fpath)
101 | else:
102 | try:
103 | print('Downloading ' + url + ' to ' + fpath)
104 | urllib.request.urlretrieve(url, fpath)
105 | except:
106 | if url[:5] == 'https':
107 | url = url.replace('https:', 'http:')
108 | print('Failed download. Trying https -> http instead.'
109 | ' Downloading ' + url + ' to ' + fpath)
110 | urllib.request.urlretrieve(url, fpath)
111 |
112 |
113 | class SVHN(data.Dataset):
114 | """`SVHN `_ Dataset.
115 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
116 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
117 | expect the class labels to be in the range `[0, C-1]`
118 | Args:
119 | root (string): Root directory of dataset where directory
120 | ``SVHN`` exists.
121 | split (string): One of {'train', 'test', 'extra'}.
122 | Accordingly dataset is selected. 'extra' is Extra training set.
123 | transform (callable, optional): A function/transform that takes in an PIL image
124 | and returns a transformed version. E.g, ``transforms.RandomCrop``
125 | target_transform (callable, optional): A function/transform that takes in the
126 | target and transforms it.
127 | download (bool, optional): If true, downloads the dataset from the internet and
128 | puts it in root directory. If dataset is already downloaded, it is not
129 | downloaded again.
130 | """
131 | url = ""
132 | filename = ""
133 | file_md5 = ""
134 |
135 | split_list = {
136 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
137 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
138 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
139 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
140 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
141 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
142 |
143 | def __init__(self, root, split='train',
144 | transform=None, target_transform=None, download=False):
145 | self.root = os.path.expanduser(root)
146 | self.transform = transform
147 | self.target_transform = target_transform
148 | self.split = split # training set or test set or extra set
149 |
150 | if self.split not in self.split_list:
151 | raise ValueError('Wrong split entered! Please use split="train" '
152 | 'or split="extra" or split="test"')
153 |
154 | self.url = self.split_list[split][0]
155 | self.filename = self.split_list[split][1]
156 | self.file_md5 = self.split_list[split][2]
157 |
158 | if download:
159 | self.download()
160 |
161 | if not self._check_integrity():
162 | raise RuntimeError('Dataset not found or corrupted.' +
163 | ' You can use download=True to download it')
164 |
165 | # import here rather than at top of file because this is
166 | # an optional dependency for torchvision
167 | import scipy.io as sio
168 |
169 | # reading(loading) mat file as array
170 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
171 |
172 | self.data = loaded_mat['X']
173 | # loading from the .mat file gives an np array of type np.uint8
174 | # converting to np.int64, so that we have a LongTensor after
175 | # the conversion from the numpy array
176 | # the squeeze is needed to obtain a 1D tensor
177 | self.labels = loaded_mat['y'].astype(np.int64).squeeze()
178 |
179 | # the svhn dataset assigns the class label "10" to the digit 0
180 | # this makes it inconsistent with several loss functions
181 | # which expect the class labels to be in the range [0, C-1]
182 | np.place(self.labels, self.labels == 10, 0)
183 | self.data = np.transpose(self.data, (3, 2, 0, 1))
184 |
185 | def __getitem__(self, index):
186 | """
187 | Args:
188 | index (int): Index
189 | Returns:
190 | tuple: (image, target) where target is index of the target class.
191 | """
192 | img, target = self.data[index], int(self.labels[index])
193 |
194 | # doing this so that it is consistent with all other datasets
195 | # to return a PIL Image
196 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
197 |
198 | if self.transform is not None:
199 | img = self.transform(img)
200 |
201 | if self.target_transform is not None:
202 | target = self.target_transform(target)
203 |
204 | return img, target
205 |
206 | def __len__(self):
207 | return len(self.data)
208 |
209 | def _check_integrity(self):
210 | root = self.root
211 | md5 = self.split_list[self.split][2]
212 | fpath = os.path.join(root, self.filename)
213 | return check_integrity(fpath, md5)
214 |
215 | def download(self):
216 | md5 = self.split_list[self.split][2]
217 | download_url(self.url, self.root, self.filename, md5)
218 |
219 | def __repr__(self):
220 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
221 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
222 | fmt_str += ' Split: {}\n'.format(self.split)
223 | fmt_str += ' Root Location: {}\n'.format(self.root)
224 | tmp = ' Transforms (if any): '
225 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
226 | tmp = ' Target Transforms (if any): '
227 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
228 | return fmt_str
--------------------------------------------------------------------------------
/src/distributed_evaluator.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os.path
3 | import time
4 | import argparse
5 | from datetime import datetime
6 | import copy
7 |
8 | from mpi4py import MPI
9 | import numpy as np
10 |
11 | from nn_ops import NN_Trainer
12 |
13 | import torch
14 | from torch.autograd import Variable
15 | import torch.nn.functional as F
16 | from torchvision import datasets, transforms
17 | from torch.utils.data import DataLoader
18 |
19 | from model_ops.lenet import LeNet, LeNetSplit
20 | from model_ops.resnet import *
21 | from model_ops.resnet_split import *
from model_ops.fc_nn import FC_NN        # needed by `_load_model` below
from model_ops.densenet import DenseNet  # (symbol names assumed to match
from model_ops.vgg import vgg11_bn       # their usage in `_load_model`)
from model_ops.alexnet import alexnet
22 | #from util import build_model
23 |
24 |
25 | def accuracy(output, target, topk=(1,)):
26 | """Computes the precision@k for the specified values of k"""
27 | maxk = max(topk)
28 | batch_size = target.size(0)
29 |
30 | _, pred = output.topk(maxk, 1, True, True)
31 | pred = pred.t()
32 | correct = pred.eq(target.view(1, -1).expand_as(pred))
33 | res = []
34 | for k in topk:
35 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
36 | res.append(correct_k.mul_(100.0 / batch_size))
37 | return res
38 |
39 | def add_fit_args(parser):
40 | """
41 | parser : argparse.ArgumentParser
42 | return a parser added with args required by fit
43 | """
44 | # Validation settings
45 | parser.add_argument('--eval-batch-size', type=int, default=10000, metavar='N',
46 | help='the batch size when doing model validation, complete at once on default')
47 | parser.add_argument('--eval-freq', type=int, default=50, metavar='N',
48 | help='it determines per how many step the model should be evaluated')
49 | parser.add_argument('--model-dir', type=str, default='output/models/', metavar='N',
50 | help='directory to save the temp model during the training process for evaluation')
51 | parser.add_argument('--dataset', type=str, default='MNIST', metavar='N',
52 | help='which dataset used in training, MNIST and Cifar10 supported currently')
53 | parser.add_argument('--network', type=str, default='LeNet', metavar='N',
54 | help='which kind of network we are going to use, support LeNet and ResNet currently')
55 | args = parser.parse_args()
56 | return args
57 |
58 | class DistributedEvaluator(NN_Trainer):
59 | '''
60 | The DistributedEvaluator aims at providing a seperate node in the distributed cluster to evaluate
61 | the model on validation/test set and return the results
62 | In this version, the DistributedEvaluator will only load the model from the dir where the master
63 | save the model and do the evaluation task based on a user defined frequency
64 | '''
65 | def __init__(self, **kwargs):
66 | self._cur_step = 0
67 | self._model_dir = kwargs['model_dir']
68 | self._eval_freq = int(kwargs['eval_freq'])
69 | self._eval_batch_size = kwargs['eval_batch_size']
70 | self.network_config = kwargs['network']
71 | # this one is going to be used to avoid fetch the weights for multiple times
72 | self._layer_cur_step = []
73 |
74 | def evaluate(self, validation_loader):
75 | # init objective to fetch at the begining
76 | self._next_step_to_fetch = self._cur_step + self._eval_freq
77 |         self._num_batch_per_epoch = len(validation_loader)  # len(loader) is already the number of batches
78 | # check if next temp model exsits, if not we wait here else we continue to do the model evaluation
79 | while True:
80 | model_dir_=self._model_dir_generator(self._next_step_to_fetch)
81 | if os.path.isfile(model_dir_):
82 | self._load_model(model_dir_)
83 | print("Evaluator evaluating results on step {}".format(self._next_step_to_fetch))
84 | self._evaluate_model(validation_loader)
85 | self._next_step_to_fetch += self._eval_freq
86 | else:
87 | # TODO(hwang): sleep appropriate period of time make sure to tune this parameter
88 | time.sleep(10)
89 |
90 | def _evaluate_model(self, test_loader):
91 | self.network.eval()
92 | test_loss = 0
93 | correct = 0
94 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0
95 | for data, y_batch in test_loader:
96 | data, target = Variable(data, volatile=True), Variable(y_batch)
97 |
98 | output = self.network(data)
99 | test_loss += F.nll_loss(F.log_softmax(output), target, size_average=False).data[0]
100 |
101 | prec1_tmp, prec5_tmp = accuracy(output.data, target.data, topk=(1, 5))
102 | prec1_counter_ += prec1_tmp.numpy()[0]
103 | prec5_counter_ += prec5_tmp.numpy()[0]
104 | batch_counter_ += 1
105 | prec1 = prec1_counter_ / batch_counter_
106 | prec5 = prec5_counter_ / batch_counter_
107 | test_loss /= len(test_loader.dataset)
108 | print('Test set: Step: {}, Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(self._cur_step,
109 | test_loss, prec1, prec5))
110 |
111 |     def _load_model(self, file_path, num_classes=10):
112 |         #self.network = build_model(self.network_config, num_classes=num_classes)
113 |         # build network
114 |         if self.network_config == "LeNet":
115 |             self.network=LeNet()
116 |         elif self.network_config == "ResNet18":
117 |             self.network=ResNet18(num_classes=num_classes)
118 |         elif self.network_config == "ResNet34":
119 |             self.network=ResNet34(num_classes=num_classes)
120 |         elif self.network_config == "FC":
121 |             self.network=FC_NN()
122 |         elif self.network_config == "DenseNet":
123 |             self.network=DenseNet(growthRate=40, depth=190, reduction=0.5,
124 |                                   bottleneck=True, nClasses=num_classes)
125 |         elif self.network_config == "VGG11":
126 |             self.network=vgg11_bn(num_classes)
127 |         elif self.network_config == "AlexNet":
128 |             self.network=alexnet(num_classes=num_classes)
129 |
130 | with open(file_path, "rb") as f_:
131 | self.network.load_state_dict(torch.load(f_))
132 |
133 | def _model_dir_generator(self, next_step_to_fetch):
134 | return self._model_dir+"model_step_"+str(next_step_to_fetch)
135 |
136 | if __name__ == "__main__":
137 | # this is only a simple test case
138 | args = add_fit_args(argparse.ArgumentParser(description='PyTorch Distributed Evaluator'))
139 |
140 | # load training and test set here:
141 | if args.dataset == "MNIST":
142 | test_loader = torch.utils.data.DataLoader(
143 | datasets.MNIST('../data', train=False, transform=transforms.Compose([
144 | transforms.ToTensor(),
145 | transforms.Normalize((0.1307,), (0.3081,))
146 | ])), batch_size=args.eval_batch_size, shuffle=True)
147 | elif args.dataset == "Cifar10":
148 | test_loader = torch.utils.data.DataLoader(
149 | datasets.CIFAR10('./cifar10_data', train=False, transform=transforms.Compose([
150 | transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
151 | ])), batch_size=args.eval_batch_size, shuffle=True)
152 |
153 | kwargs_evaluator={
154 | 'network':args.network,
155 | 'model_dir':args.model_dir,
156 | 'eval_freq':args.eval_freq,
157 | 'eval_batch_size':args.eval_batch_size}
158 | evaluator_nn = DistributedEvaluator(**kwargs_evaluator)
159 | evaluator_nn.evaluate(validation_loader=test_loader)
--------------------------------------------------------------------------------
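
The `accuracy` helper in `distributed_evaluator.py` computes precision@k: it takes the indices of the k largest logits per sample and counts how often the true label appears among them. A minimal self-contained sketch of the same computation (illustrative names, PyTorch >= 0.4 API, not part of the repo):

``` python
import torch

def precision_at_k(output, target, k=1):
    # output: (batch, num_classes) logits; target: (batch,) integer labels
    _, pred = output.topk(k, dim=1)        # indices of the k largest logits
    correct = pred.eq(target.view(-1, 1))  # compare each top-k index to the label
    return correct.float().sum().item() * 100.0 / target.size(0)

logits = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.15, 0.05]])
labels = torch.tensor([1, 1])
print(precision_at_k(logits, labels, k=1))  # 50.0: only the first sample's top-1 is correct
print(precision_at_k(logits, labels, k=2))  # 100.0: label 1 is in both samples' top-2
```
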
/src/distributed_nn.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import sys
4 | import math
5 | import threading
6 | import argparse
7 | import time
8 |
9 | import numpy as np
10 | from mpi4py import MPI
11 |
12 | import torch
13 | from torch.autograd import Variable
14 | from torch import nn
15 | from torch.nn.parallel.replicate import replicate
16 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
17 | from torch.nn.parallel.parallel_apply import parallel_apply
18 | import torch.nn.functional as F
19 |
20 | from torchvision import datasets, transforms
21 |
22 | from nn_ops import NN_Trainer, accuracy
23 | from data_loader_ops.my_data_loader import DataLoader
24 |
25 | from distributed_worker import *
26 | from sync_replicas_master_nn import *
27 |
28 | from datasets import SVHN
29 |
30 |
31 | def add_fit_args(parser):
32 | """
33 | parser : argparse.ArgumentParser
34 | return a parser added with args required by fit
35 | """
36 | # Training settings
37 | parser.add_argument('--batch-size', type=int, default=128, metavar='N',
38 |                         help='input batch size for training (default: 128)')
39 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
40 | help='input batch size for testing (default: 1000)')
41 | parser.add_argument('--max-steps', type=int, default=10000, metavar='N',
42 | help='the maximum number of iterations')
43 | parser.add_argument('--epochs', type=int, default=100, metavar='N',
44 |                         help='number of epochs to train (default: 100)')
45 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
46 | help='learning rate (default: 0.01)')
47 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
48 | help='SGD momentum (default: 0.5)')
49 | parser.add_argument('--lr-shrinkage', type=float, default=0.95, metavar='M',
50 | help='exponential decay factor of lr schedule (default: 0.95)')
51 | parser.add_argument('--no-cuda', action='store_true', default=False,
52 | help='disables CUDA training')
53 | parser.add_argument('--seed', type=int, default=1, metavar='S',
54 | help='random seed (default: 1)')
55 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
56 | help='how many batches to wait before logging training status')
57 |     parser.add_argument('--network', type=str, default='LeNet', metavar='N',
58 |                         help='network architecture to use; currently LeNet and ResNet are supported')
59 |     parser.add_argument('--code', type=str, default='sgd',
60 |                         help='which coding method to use, e.g. sgd, qsgd, svd')
61 |     parser.add_argument('--bucket-size', type=int, default=512,
62 |                         help='bucket size used in QSGD')
63 |     parser.add_argument('--dataset', type=str, default='MNIST', metavar='N',
64 |                         help='dataset to use; currently MNIST and Cifar10 are supported')
65 |     parser.add_argument('--comm-type', type=str, default='Bcast', metavar='N',
66 |                         help='communication method used during the model fetching stage')
67 |     parser.add_argument('--num-aggregate', type=int, default=5, metavar='N',
68 |                         help='number of gradients to gather at each iteration')
69 |     parser.add_argument('--eval-freq', type=int, default=50, metavar='N',
70 |                         help='evaluate the model every this many training steps')
71 |     parser.add_argument('--train-dir', type=str, default='output/models/', metavar='N',
72 |                         help='directory to save temporary models to during training, for evaluation')
73 |     parser.add_argument('--compress', action='store_true', default=False,
74 |                         help='whether to compress the messages sent to the parameter server')
75 |     parser.add_argument('--enable-gpu', action='store_true', default=False,
76 |                         help='whether to run training on the GPU')
77 |
78 |     parser.add_argument('--svd-rank', default=0, type=int,
79 |                         help='int: target rank used by the SVD coder')
80 | parser.add_argument('--quantization-level', type=int, default=4, help='int: bits used in QSGD')
81 | args = parser.parse_args()
82 | return args
83 |
84 | if __name__ == "__main__":
85 | # this is only a simple test case
86 | comm = MPI.COMM_WORLD
87 | rank = comm.Get_rank()
88 | world_size = comm.Get_size()
89 |
90 |     args = add_fit_args(argparse.ArgumentParser(description='PyTorch Distributed Training'))
91 |
92 | # load training and test set here:
93 | if args.dataset == "MNIST":
94 | training_set = datasets.MNIST('./mnist_data', train=True, download=True,
95 | transform=transforms.Compose([
96 | transforms.ToTensor(),
97 | transforms.Normalize((0.1307,), (0.3081,))]))
98 | train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True)
99 | test_loader = torch.utils.data.DataLoader(
100 | datasets.MNIST('./mnist_data', train=False, transform=transforms.Compose([
101 | transforms.ToTensor(),
102 | transforms.Normalize((0.1307,), (0.3081,))
103 | ])), batch_size=args.test_batch_size, shuffle=True)
104 | elif args.dataset == "Cifar10":
105 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
106 | std=[x/255.0 for x in [63.0, 62.1, 66.7]])
107 | transform_train = transforms.Compose([
108 | transforms.ToTensor(),
109 | transforms.Lambda(lambda x: F.pad(
110 | Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
111 | (4,4,4,4),mode='reflect').data.squeeze()),
112 | transforms.ToPILImage(),
113 | transforms.RandomCrop(32),
114 | transforms.RandomHorizontalFlip(),
115 | transforms.ToTensor(),
116 | normalize,
117 | ])
118 | # data prep for test set
119 | transform_test = transforms.Compose([
120 | transforms.ToTensor(),
121 | normalize])
122 | # load training and test set here:
123 | training_set = datasets.CIFAR10(root='./cifar10_data', train=True,
124 | download=True, transform=transform_train)
125 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size,
126 | shuffle=True)
127 | testset = datasets.CIFAR10(root='./cifar10_data', train=False,
128 | download=True, transform=transform_test)
129 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
130 | shuffle=False)
131 | elif args.dataset == 'SVHN':
132 | training_set = SVHN('./svhn_data', split='train', transform=transforms.Compose([
133 | transforms.RandomCrop(32, padding=4),
134 | transforms.RandomHorizontalFlip(),
135 | transforms.ToTensor(),
136 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
137 | ]))
138 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=128,
139 | shuffle=True)
140 | transform_test = transforms.Compose([
141 | transforms.ToTensor(),
142 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
143 | ])
144 | testset = SVHN(root='./svhn_data', split='test',
145 | download=True, transform=transform_test)
146 | test_loader = torch.utils.data.DataLoader(testset, batch_size=1000,
147 | shuffle=False)
148 | elif args.dataset == 'Cifar100':
149 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
150 | std=[x/255.0 for x in [63.0, 62.1, 66.7]])
151 | transform_train = transforms.Compose([
152 | transforms.ToTensor(),
153 | transforms.Lambda(lambda x: F.pad(
154 | Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
155 | (4,4,4,4),mode='reflect').data.squeeze()),
156 | transforms.ToPILImage(),
157 | transforms.RandomCrop(32),
158 | transforms.RandomHorizontalFlip(),
159 | transforms.ToTensor(),
160 | normalize,
161 | ])
162 | # data prep for test set
163 | transform_test = transforms.Compose([
164 | transforms.ToTensor(),
165 | normalize])
166 | # load training and test set here:
167 | training_set = datasets.CIFAR100(root='./cifar100_data', train=True,
168 | download=True, transform=transform_train)
169 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size,
170 | shuffle=True)
171 | testset = datasets.CIFAR100(root='./cifar100_data', train=False,
172 | download=True, transform=transform_test)
173 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
174 | shuffle=False)
175 | elif args.dataset == 'ImageNet':
176 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
177 | std=[x/255.0 for x in [63.0, 62.1, 66.7]])
178 | # data prep for training set
179 |         # note that the key to reaching the convergence performance reported in this paper (https://arxiv.org/abs/1512.03385)
180 |         # is to implement data augmentation
181 | transform_train = transforms.Compose([
182 | transforms.Scale((227, 227)),
183 | transforms.ToTensor(),
184 | transforms.Lambda(lambda x: F.pad(
185 | Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
186 | (4,4,4,4),mode='reflect').data.squeeze()),
187 | transforms.ToPILImage(),
188 | transforms.RandomCrop(227),
189 | transforms.RandomHorizontalFlip(),
190 | transforms.ToTensor(),
191 | normalize,
192 | ])
193 | # data prep for test set
194 | transform_test = transforms.Compose([
195 | transforms.ToTensor(),
196 | normalize])
197 |         # load training and test set here:
198 |         # NOTE: despite the 'ImageNet' flag, this branch currently loads CIFAR10 resized/cropped to 227x227
199 |         training_set = datasets.CIFAR10(root='./cifar10_data', train=True,
200 |                                         download=True, transform=transform_train)
202 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size,
203 | shuffle=True)
204 | testset = datasets.CIFAR10(root='./cifar10_data', train=False,
205 | download=True, transform=transform_test)
206 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
207 | shuffle=False)
208 |
209 | kwargs_master = {'batch_size':args.batch_size,
210 | 'learning_rate':args.lr,
211 | 'max_epochs':args.epochs,
212 | 'max_steps':args.max_steps,
213 | 'momentum':args.momentum,
214 | 'network':args.network,
215 | 'comm_method':args.comm_type,
216 | 'eval_freq':args.eval_freq,
217 | 'train_dir':args.train_dir,
218 | 'compress':args.compress,
219 | 'num_aggregate':args.num_aggregate,
220 | 'enable_gpu':args.enable_gpu,
221 | 'lr_shrinkage':args.lr_shrinkage,
222 | 'code':args.code,
223 | 'svd_rank':args.svd_rank,
224 | 'quantization_level':args.quantization_level,
225 | 'bucket_size':args.bucket_size}
226 |
227 | kwargs_worker = {'batch_size':args.batch_size,
228 | 'learning_rate':args.lr,
229 | 'max_epochs':args.epochs,
230 | 'momentum':args.momentum,
231 | 'network':args.network,
232 | 'max_steps':args.max_steps,
233 | 'comm_method':args.comm_type,
234 | 'compress':args.compress,
235 | 'enable_gpu':args.enable_gpu,
236 | 'eval_freq':args.eval_freq,
237 | 'train_dir':args.train_dir,
238 | 'code':args.code,
239 | 'svd_rank':args.svd_rank,
240 | 'quantization_level':args.quantization_level,
241 | 'bucket_size':args.bucket_size}
242 |
243 | if rank == 0:
244 | master_fc_nn = SyncReplicasMaster_NN(comm=comm, **kwargs_master)
245 | if args.dataset == 'Cifar100':
246 | master_fc_nn.build_model(num_classes=100)
247 | else:
248 | master_fc_nn.build_model(num_classes=10)
249 | print("I am the master: the world size is {}, cur step: {}".format(master_fc_nn.world_size, master_fc_nn.cur_step))
250 | master_fc_nn.train()
251 | print("Done sending messages to workers!")
252 | else:
253 | worker_fc_nn = DistributedWorker(comm=comm, **kwargs_worker)
254 | if args.dataset == 'Cifar100':
255 | worker_fc_nn.build_model(num_classes=100)
256 | else:
257 | worker_fc_nn.build_model(num_classes=10)
258 | print("I am worker: {} in all {} workers, next step: {}".format(worker_fc_nn.rank, worker_fc_nn.world_size-1, worker_fc_nn.next_step))
259 | worker_fc_nn.train(train_loader=train_loader, test_loader=test_loader)
260 | print("Worker Done Jobs! ...")
--------------------------------------------------------------------------------
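
`distributed_nn.py` is launched as one MPI program: every process runs the same script, and the branch on `rank == 0` decides whether a process becomes the parameter-server master or a worker. A minimal sketch of that rank-based role split (a hypothetical standalone script, not part of the repo):

``` python
# rank_roles.py -- run with e.g.: mpirun -n 3 python rank_roles.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
if comm.Get_rank() == 0:
    # rank 0 plays the master: it owns the state the workers must see
    state = {"step": 1}
else:
    state = None
# collective call: every process participates, the root provides the object
state = comm.bcast(state, root=0)
print("rank {} of {} sees step {}".format(comm.Get_rank(), comm.Get_size(), state["step"]))
```
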
/src/distributed_worker.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from mpi4py import MPI
3 | import numpy as np
4 |
5 | from nn_ops import NN_Trainer
6 |
7 | from model_ops.lenet import LeNet, LeNetSplit
8 | from model_ops.resnet import *
9 | from model_ops.resnet_split import *
10 | from model_ops.vgg import *
11 | from model_ops.alexnet import *
12 | from model_ops.fc_nn import FC_NN, FC_NN_Split
13 | from model_ops.densenet import DenseNet
14 |
15 | from utils import compress
16 | import codings
17 |
18 | import torch
19 | from torch import nn
20 | import torch.nn.functional as F
21 | from torch.autograd import Variable
20 |
21 | import time
22 | from datetime import datetime
23 | import copy
24 | from sys import getsizeof
25 | import pickle
26 |
27 | STEP_START_ = 1
28 | # when running plain SGD, use the lossless compression tool to make it run faster
29 | _FAKE_SGD = True
30 | TAG_LIST_ = [i*30 for i in range(50000)]
31 |
32 | def prepare_grad_list(params):
33 | grad_list = []
34 | for param_idx, param in enumerate(params):
35 |         # fetch the gradient of each layer here
36 |         # in this version we fetch all the weights at once
37 |         # remember to cast the type here; this is essential
38 | grads = param.grad.data.numpy().astype(np.float32)
39 | grad_list.append((param_idx, grads))
40 | return grad_list
41 |
42 | def accuracy(output, target, topk=(1,)):
43 | """Computes the precision@k for the specified values of k"""
44 | maxk = max(topk)
45 | batch_size = target.size(0)
46 |
47 | _, pred = output.topk(maxk, 1, True, True)
48 | pred = pred.t()
49 | correct = pred.eq(target.view(1, -1).expand_as(pred))
50 |
51 | res = []
52 | for k in topk:
53 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
54 | res.append(correct_k.mul_(100.0 / batch_size))
55 | return res
56 |
57 | def _4d_to_2d(tensor):
58 | return tensor.view(tensor.size()[0], tensor.size()[1]*tensor.size()[2]*tensor.size()[3])
59 |
60 | def _construct_grad_packet(module):
61 |     '''
62 |     input: weight (in R^{m x n}) and bias (in R^m)
63 |     output: grad packet (in R^{m x (n+1)})
64 |     '''
65 | ndims = len(module.weight.grad.data.size())
66 | if ndims == 4:
67 | tmp_grad = _4d_to_2d(module.weight.grad.data)
68 | if module.bias is None:
69 | return tmp_grad
70 | else:
71 | return torch.cat((tmp_grad, module.bias.grad.data), 1)
72 | elif ndims == 2:
73 | if module.bias is None:
74 | return module.weight.grad.data
75 | else:
76 | return torch.cat((module.weight.grad.data, module.bias.grad.data), 1)
77 | elif ndims == 1:
78 | if module.bias is None:
79 | return module.weight.grad.data
80 | else:
81 | return torch.cat((module.bias.grad.data, module.weight.grad.data), 0)
82 |
83 |
84 | class ModelBuffer(object):
85 | def __init__(self, network):
86 |         """
87 |         This class is used to store model weights received from the parameter server.
88 |         The current step of each layer of the model is also tracked here, to make sure
89 |         the model is always up-to-date.
90 |         """
91 | self.recv_buf = []
92 | self.layer_cur_step = []
93 | for param_idx, param in enumerate(network.parameters()):
94 | self.recv_buf.append(np.zeros(param.size()))
95 | self.layer_cur_step.append(0)
96 |
97 |
98 | class DistributedWorker(NN_Trainer):
99 | def __init__(self, comm, **kwargs):
100 | self.comm = comm # get MPI communicator object
101 | self.world_size = comm.Get_size() # total number of processes
102 | self.rank = comm.Get_rank() # rank of this Worker
103 | #self.status = MPI.Status()
104 | self.cur_step = 0
105 | self.next_step = 0 # we will fetch this one from parameter server
106 |
107 | self.batch_size = kwargs['batch_size']
108 | self.max_epochs = kwargs['max_epochs']
109 | self.momentum = kwargs['momentum']
110 | self.lr = kwargs['learning_rate']
111 | self.network_config = kwargs['network']
112 | self._max_steps = kwargs['max_steps']
113 | self.comm_type = kwargs['comm_method']
114 | self._compress = kwargs['compress']
115 | self._enable_gpu = kwargs['enable_gpu']
116 | self._eval_batch_size = 100
117 | self._eval_freq = kwargs['eval_freq']
118 | self._train_dir = kwargs['train_dir']
119 | # encode related
120 | self._svd_rank = kwargs['svd_rank']
121 | self._quantization_level = kwargs['quantization_level']
122 | self._bucket_size = kwargs['bucket_size']
123 |
124 |         # used to avoid fetching the weights multiple times
125 | self._layer_cur_step = []
126 | self._code = kwargs['code']
127 | if kwargs['code'] == 'sgd':
128 | if not _FAKE_SGD:
129 | self._coder = codings.svd.SVD(compress=False)
130 | else:
131 | self._coder = codings.lossless_compress.LosslessCompress()
132 | elif kwargs['code'] == 'svd':
133 |             print("distributed_worker.py, svd_rank =", self._svd_rank)
134 | self._coder = codings.svd.SVD(random_sample=True,
135 | rank=self._svd_rank, compress=True)
136 | else:
137 | raise ValueError('args.code not recognized')
138 |
139 | def build_model(self, num_classes=10):
140 | # build network
141 | if self.network_config == "LeNet":
142 | self.network=LeNet()
143 | elif self.network_config == "ResNet18":
144 | self.network=ResNet18(num_classes=num_classes)
145 | elif self.network_config == "ResNet34":
146 | self.network=ResNet34(num_classes=num_classes)
147 | elif self.network_config == "FC":
148 | self.network=FC_NN()
149 | elif self.network_config == "DenseNet":
150 | self.network=DenseNet(growthRate=40, depth=190, reduction=0.5,
151 | bottleneck=True, nClasses=10)
152 | elif self.network_config == "VGG11":
153 | self.network=vgg11_bn(num_classes)
154 | elif self.network_config == "AlexNet":
155 | self.network=alexnet(num_classes=10)
156 |
157 | # set up optimizer
158 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)
159 | self.criterion = nn.CrossEntropyLoss()
160 | # assign a buffer for receiving models from parameter server
161 | self.init_recv_buf()
162 | # enable GPU here
163 | if self._enable_gpu:
164 | self.network.cuda()
165 |
166 | def train(self, train_loader, test_loader):
167 |         # the first step here is to synchronously fetch the initial global step from the parameter server
168 |         # we still need to make sure that the value fetched from the parameter server is 1
169 | self.sync_fetch_step()
170 | # do some sync check here
171 | assert(self.update_step())
172 | assert(self.cur_step == STEP_START_)
173 |
174 | # number of batches in one epoch
175 | num_batch_per_epoch = len(train_loader.dataset) / self.batch_size
176 | batch_idx = -1
177 | epoch_idx = 0
178 | epoch_avg_loss = 0
179 | iteration_last_step=0
180 | iter_start_time=0
181 |
182 | first = True
183 |
184 | print("Worker {}: starting training".format(self.rank))
185 |         # start the training process
187 | for num_epoch in range(self.max_epochs):
188 | for batch_idx, (train_image_batch, train_label_batch) in enumerate(train_loader):
189 | # worker exit task
190 | if self.cur_step == self._max_steps:
191 | break
192 | if self._enable_gpu:
193 | X_batch, y_batch = Variable(train_image_batch.cuda()), Variable(train_label_batch.cuda())
194 | else:
195 | X_batch, y_batch = Variable(train_image_batch), Variable(train_label_batch)
196 | while True:
197 |                 # the worker shouldn't know the current global step
198 |                 # except by receiving a message from the parameter server
199 | self.async_fetch_step()
200 |
201 |                 # the only way each worker knows which step it is currently on is to check the cur step variable
202 | updated = self.update_step()
203 |
204 | if (not updated) and (not first):
205 |                     # wait here until we enter the next step
206 | continue
207 |
208 | # the real start point of this iteration
209 | iteration_last_step = time.time() - iter_start_time
210 | iter_start_time = time.time()
211 | first = False
212 | print("Rank of this node: {}, Current step: {}".format(self.rank, self.cur_step))
213 |
214 |                 # TODO(hwang): issue the layer requests here and fetch the weights before the forward step begins,
215 |                 # rather than implementing the wait() in the fetch function
216 | fetch_weight_start_time = time.time()
217 |
218 | self.async_fetch_weights_bcast()
219 |
220 | fetch_weight_duration = time.time() - fetch_weight_start_time
221 |
222 | # switch to training mode
223 | self.network.train()
224 | # manage batch index manually
225 | self.optimizer.zero_grad()
226 |
227 | # forward step
228 | comp_start = time.time()
229 | logits = self.network(X_batch)
230 | loss = self.criterion(logits, y_batch)
231 |
232 | epoch_avg_loss += loss.data[0]
233 |
234 | # backward step
235 | backward_start_time = time.time()
236 | loss.backward()
237 | comp_dur = time.time() - comp_start
238 |
239 | # gradient encoding step
240 | encode_start = time.time()
241 | msgs,_msg_counter = self._encode()
242 | encode_dur = time.time() - encode_start
243 |
244 | # communication step
245 | comm_start = time.time()
246 | self._send_grads(msgs)
247 | comm_dur = time.time()-comm_start
248 |
249 | prec1, prec5 = accuracy(logits.data, y_batch.data, topk=(1, 5))
250 | if self._enable_gpu:
251 | prec1, prec5 = prec1.cpu().numpy()[0], prec5.cpu().numpy()[0]
252 | else:
253 | prec1, prec5 = prec1.numpy()[0], prec5.numpy()[0]
254 | # on the end of a certain iteration
255 | print('Worker: {}, Step: {}, Epoch: {} [{}/{} ({:.0f}%)], Loss: {:.4f}, Time Cost: {:.4f}, Comp: {:.4f}, Encode: {: .4f}, Comm: {: .4f}, Msg(MB): {: .4f}, Prec@1: {: .4f}, Prec@5: {: .4f}'.format(
256 | self.rank, self.cur_step, num_epoch, batch_idx * self.batch_size, len(train_loader.dataset),
257 | (100. * (batch_idx * self.batch_size) / len(train_loader.dataset)), loss.data[0], time.time()-iter_start_time,
258 | comp_dur, encode_dur, comm_dur, _msg_counter/(1024.0**2), prec1, prec5))
259 | # break here to fetch data then enter fetching step loop again
260 | if self.cur_step%self._eval_freq==0:
261 | self._evaluate_model(test_loader)
262 | break
263 |
264 | def init_recv_buf(self):
265 | self.model_recv_buf = ModelBuffer(self.network)
266 |
267 | def sync_fetch_step(self):
268 | '''fetch the first step from the parameter server'''
269 | self.next_step = self.comm.recv(source=0, tag=10)
270 |
271 | def async_fetch_step(self):
272 | req = self.comm.irecv(source=0, tag=10)
273 | self.next_step = req.wait()
274 |
275 | def async_fetch_weights_bcast(self):
276 | layers_to_update = []
277 | for layer_idx, layer in enumerate(self.model_recv_buf.recv_buf):
278 | if self.model_recv_buf.layer_cur_step[layer_idx] < self.cur_step:
279 | layers_to_update.append(layer_idx)
280 | self.comm.Bcast([self.model_recv_buf.recv_buf[layer_idx], MPI.DOUBLE], root=0)
281 |         weights_to_update = []
282 |         for req_idx, layer_idx in enumerate(layers_to_update):
283 |             weights = self.model_recv_buf.recv_buf[layer_idx]
284 |             weights_to_update.append(weights)
285 |             # we also need to update the layer cur step here:
286 |             self.model_recv_buf.layer_cur_step[layer_idx] = self.cur_step
287 | self.model_update(weights_to_update)
288 |
289 | def update_step(self):
290 | '''update local (global) step on worker'''
291 | changed = (self.cur_step != self.next_step)
292 | self.cur_step = self.next_step
293 | return changed
294 |
295 | def model_update(self, weights_to_update):
296 | """write model fetched from parameter server to local model"""
297 | new_state_dict = {}
298 | model_counter_ = 0
299 | for param_idx,(key_name, param) in enumerate(self.network.state_dict().items()):
300 |             # handle the `running_mean` and `running_var` buffers contained in `BatchNorm` layers
301 | if "running_mean" in key_name or "running_var" in key_name:
302 | tmp_dict={key_name: param}
303 | else:
304 | assert param.size() == weights_to_update[model_counter_].shape
305 | if self._enable_gpu:
306 | tmp_dict = {key_name: torch.from_numpy(weights_to_update[model_counter_]).cuda()}
307 | else:
308 | tmp_dict = {key_name: torch.from_numpy(weights_to_update[model_counter_])}
309 | model_counter_ += 1
310 | new_state_dict.update(tmp_dict)
311 | self.network.load_state_dict(new_state_dict)
312 |
313 | def _encode(self):
314 | msgs = []
315 | _msg_counter = 0
316 | def __count_msg_sizes(msg):
317 | return len(msg)
318 | for p_index, p in enumerate(self.network.parameters()):
319 |             if self._enable_gpu:
320 |                 grad = p.grad.data.cpu().numpy().astype(np.float32)
321 |             else:
322 |                 grad = p.grad.data.numpy().astype(np.float32)
323 | coded = self._coder.encode(grad)
324 | pickled = pickle.dumps(coded)
325 | byte_code = bytearray(pickled)
326 | _msg_counter+=__count_msg_sizes(byte_code)
327 | msgs.append(byte_code)
328 | return msgs, _msg_counter
329 |
330 |     def _send_grads(self, msgs):
331 |         req_send_check = []
332 |         for msg_index, m in enumerate(msgs):
333 |             req_isend = self.comm.isend(m, dest=0, tag=88+msg_index)
334 |             req_send_check.append(req_isend)
335 |         [req_send_check[i].wait() for i in range(len(req_send_check))]
336 |
337 | def _generate_model_path(self):
338 | return self._train_dir+"model_step_"+str(self.cur_step)
339 |
340 | def _save_model(self, file_path):
341 | with open(file_path, "wb") as f_:
342 | torch.save(self.network.state_dict(), f_)
343 |
344 | def _evaluate_model(self, test_loader):
345 | self.network.eval()
346 | test_loss = 0
347 | correct = 0
348 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0
349 | for data, y_batch in test_loader:
350 | if self._enable_gpu:
351 | data, target = Variable(data.cuda(), volatile=True), Variable(y_batch.cuda())
352 | else:
353 | data, target = Variable(data, volatile=True), Variable(y_batch)
354 |
355 | output = self.network(data)
356 | test_loss += F.nll_loss(F.log_softmax(output), target, size_average=False).data[0]
357 |
358 | prec1_tmp, prec5_tmp = accuracy(output.data, target.data, topk=(1, 5))
359 | if self._enable_gpu:
360 | prec1_counter_ += prec1_tmp.cpu().numpy()[0]
361 | prec5_counter_ += prec5_tmp.cpu().numpy()[0]
362 | else:
363 | prec1_counter_ += prec1_tmp.numpy()[0]
364 | prec5_counter_ += prec5_tmp.numpy()[0]
365 | batch_counter_ += 1
366 | prec1 = prec1_counter_ / batch_counter_
367 | prec5 = prec5_counter_ / batch_counter_
368 | test_loss /= len(test_loader.dataset)
369 | print('Test set: Step: {}, Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(self.cur_step,
370 | test_loss, prec1, prec5))
371 |
372 | if __name__ == "__main__":
373 |     # this is only a simple test case
374 |     comm = MPI.COMM_WORLD
375 |     rank = comm.Get_rank()
376 |     world_size = comm.Get_size()
377 |     print("I am worker: {} in all {} workers".format(rank, world_size))
--------------------------------------------------------------------------------
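
In `DistributedWorker._encode`, each layer's gradient is passed through the configured coder, pickled into a bytearray, and the byte count is accumulated so the per-step log can report the message size in MB. A sketch of that round trip with a stand-in coder (`IdentityCoder` is hypothetical; the repo's real coders live in `codings`):

``` python
import pickle
import numpy as np

class IdentityCoder(object):
    """Stand-in for codings.svd.SVD etc.; a real coder sparsifies/quantizes."""
    def encode(self, grad):
        return grad
    def decode(self, coded):
        return coded

coder = IdentityCoder()
grad = np.random.randn(4, 3).astype(np.float32)
payload = bytearray(pickle.dumps(coder.encode(grad)))  # what the worker isends
restored = coder.decode(pickle.loads(bytes(payload)))  # what the master recovers
assert np.allclose(grad, restored)
print("message size: {} bytes".format(len(payload)))
```
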
/src/evaluate_pytorch.sh:
--------------------------------------------------------------------------------
1 | python distributed_evaluator.py \
2 | --eval-batch-size=10000 \
3 | --eval-freq=50 \
4 | --network=ResNet18 \
5 | --dataset=Cifar10 \
6 | --model-dir=/home/ubuntu/MPI_shared/
--------------------------------------------------------------------------------
/src/model_ops/__init__.py:
--------------------------------------------------------------------------------
1 | from . import densenet, fc_nn, lenet, resnet, vgg, alexnet
2 |
3 | __all__ = ['densenet', 'fc_nn', 'lenet', 'resnet', 'vgg', 'alexnet']
--------------------------------------------------------------------------------
/src/model_ops/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 | __all__ = ['AlexNet', 'alexnet']
6 |
7 |
8 | model_urls = {
9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
10 | }
11 |
12 |
13 | class AlexNet(nn.Module):
14 |
15 | def __init__(self, num_classes=1000):
16 | super(AlexNet, self).__init__()
17 | self.features = nn.Sequential(
18 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
19 | nn.ReLU(inplace=True),
20 | nn.MaxPool2d(kernel_size=3, stride=2),
21 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
22 | nn.ReLU(inplace=True),
23 | nn.MaxPool2d(kernel_size=3, stride=2),
24 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
25 | nn.ReLU(inplace=True),
26 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
27 | nn.ReLU(inplace=True),
28 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
29 | nn.ReLU(inplace=True),
30 | nn.MaxPool2d(kernel_size=3, stride=2),
31 | )
32 | self.classifier = nn.Sequential(
33 | nn.Dropout(),
34 | nn.Linear(256 * 6 * 6, 4096),
35 | nn.ReLU(inplace=True),
36 | nn.Dropout(),
37 | nn.Linear(4096, 4096),
38 | nn.ReLU(inplace=True),
39 | nn.Linear(4096, num_classes),
40 | )
41 |
42 | def forward(self, x):
43 | x = self.features(x)
44 | x = x.view(x.size(0), 256 * 6 * 6)
45 | x = self.classifier(x)
46 | return x
47 |
48 |
49 | def alexnet(pretrained=False, **kwargs):
50 |     r"""AlexNet model architecture from the
51 |     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
52 | Args:
53 | pretrained (bool): If True, returns a model pre-trained on ImageNet
54 | """
55 | model = AlexNet(**kwargs)
56 | if pretrained:
57 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
58 | return model
--------------------------------------------------------------------------------
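
This is the standard torchvision AlexNet: its classifier assumes a `256 * 6 * 6` feature map, which corresponds to 224x224 inputs (the 227x227 crops used in the 'ImageNet' branch of `distributed_nn.py` work as well), so smaller images will fail at the `view()` in `forward()`. A quick usage check (sketch):

``` python
import torch
from torch.autograd import Variable

model = AlexNet(num_classes=10)
x = Variable(torch.randn(2, 3, 224, 224))  # 227x227 inputs also yield a 6x6 feature map
print(model(x).size())  # torch.Size([2, 10])
```
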
/src/model_ops/densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.nn as nn
4 | import torch.optim as optim
5 |
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 | import torchvision.datasets as dset
10 | import torchvision.transforms as transforms
11 | from torch.utils.data import DataLoader
12 |
13 | import torchvision.models as models
14 |
15 | import sys
16 | import math
17 |
18 | class Bottleneck(nn.Module):
19 | def __init__(self, nChannels, growthRate):
20 | super(Bottleneck, self).__init__()
21 | interChannels = 4*growthRate
22 | self.bn1 = nn.BatchNorm2d(nChannels)
23 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
24 | bias=False)
25 | self.bn2 = nn.BatchNorm2d(interChannels)
26 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
27 | padding=1, bias=False)
28 |
29 | def forward(self, x):
30 | out = self.conv1(F.relu(self.bn1(x)))
31 | out = self.conv2(F.relu(self.bn2(out)))
32 | out = torch.cat((x, out), 1)
33 | return out
34 |
35 | class SingleLayer(nn.Module):
36 | def __init__(self, nChannels, growthRate):
37 | super(SingleLayer, self).__init__()
38 | self.bn1 = nn.BatchNorm2d(nChannels)
39 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,
40 | padding=1, bias=False)
41 |
42 | def forward(self, x):
43 | out = self.conv1(F.relu(self.bn1(x)))
44 | out = torch.cat((x, out), 1)
45 | return out
46 |
47 | class Transition(nn.Module):
48 | def __init__(self, nChannels, nOutChannels):
49 | super(Transition, self).__init__()
50 | self.bn1 = nn.BatchNorm2d(nChannels)
51 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1,
52 | bias=False)
53 |
54 | def forward(self, x):
55 | out = self.conv1(F.relu(self.bn1(x)))
56 | out = F.avg_pool2d(out, 2)
57 | return out
58 |
59 |
60 | class DenseNet(nn.Module):
61 | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
62 | super(DenseNet, self).__init__()
63 |
64 | nDenseBlocks = (depth-4) // 3
65 | if bottleneck:
66 | nDenseBlocks //= 2
67 |
68 | nChannels = 2*growthRate
69 | self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1,
70 | bias=False)
71 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
72 | nChannels += nDenseBlocks*growthRate
73 | nOutChannels = int(math.floor(nChannels*reduction))
74 | self.trans1 = Transition(nChannels, nOutChannels)
75 |
76 | nChannels = nOutChannels
77 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
78 | nChannels += nDenseBlocks*growthRate
79 | nOutChannels = int(math.floor(nChannels*reduction))
80 | self.trans2 = Transition(nChannels, nOutChannels)
81 |
82 | nChannels = nOutChannels
83 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
84 | nChannels += nDenseBlocks*growthRate
85 |
86 | self.bn1 = nn.BatchNorm2d(nChannels)
87 | self.fc = nn.Linear(nChannels, nClasses)
88 |
89 | for m in self.modules():
90 | if isinstance(m, nn.Conv2d):
91 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
92 | m.weight.data.normal_(0, math.sqrt(2. / n))
93 | elif isinstance(m, nn.BatchNorm2d):
94 | m.weight.data.fill_(1)
95 | m.bias.data.zero_()
96 | elif isinstance(m, nn.Linear):
97 | m.bias.data.zero_()
98 |
99 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
100 | layers = []
101 | for i in range(int(nDenseBlocks)):
102 | if bottleneck:
103 | layers.append(Bottleneck(nChannels, growthRate))
104 | else:
105 | layers.append(SingleLayer(nChannels, growthRate))
106 | nChannels += growthRate
107 | return nn.Sequential(*layers)
108 |
109 | def forward(self, x):
110 | out = self.conv1(x)
111 | out = self.trans1(self.dense1(out))
112 | out = self.trans2(self.dense2(out))
113 | out = self.dense3(out)
114 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8))
115 | out = F.log_softmax(self.fc(out))
116 | return out
--------------------------------------------------------------------------------
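
The constructor's block arithmetic is easy to sanity-check against the configuration used elsewhere in this repo (`growthRate=40, depth=190, reduction=0.5, bottleneck=True`): each of the three dense blocks then holds 31 bottleneck layers. A sketch of the computation:

``` python
# Block arithmetic for DenseNet(growthRate=40, depth=190, reduction=0.5, bottleneck=True)
depth, growthRate, reduction = 190, 40, 0.5

nDenseBlocks = (depth - 4) // 3   # 62 conv layers per block before bottlenecking
nDenseBlocks //= 2                # each Bottleneck holds two convs -> 31 layers per block
print(nDenseBlocks)               # 31

nChannels = 2 * growthRate                  # 80 channels after the stem conv
nChannels += nDenseBlocks * growthRate      # 80 + 31*40 = 1320 after dense block 1
nOutChannels = int(nChannels * reduction)   # 660 after the first Transition
print(nChannels, nOutChannels)
```
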
/src/model_ops/fc_nn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | import pandas as pd
6 | import numpy as np
7 | from torch.autograd import Variable
8 |
9 | from mpi4py import MPI
10 |
11 | # we use fc nn here for our simple case
12 | class FC_NN(nn.Module):
13 | def __init__(self):
14 | super(FC_NN, self).__init__()
15 | self.fc1 = nn.Linear(784, 800)
16 | self.fc2 = nn.Linear(800, 500)
17 | self.fc3 = nn.Linear(500, 10)
18 | self.relu = nn.ReLU()
19 | self.sigmoid = nn.Sigmoid()
20 | def forward(self, x):
21 | x = x.view(-1, x.size()[1]*x.size()[2]*x.size()[3])
22 | x = self.fc1(x)
23 | x = self.relu(x)
24 | x = self.fc2(x)
25 | x = self.relu(x)
26 | x = self.fc3(x)
27 | x = self.sigmoid(x)
28 | return x
29 | def name(self):
30 | return 'fc_nn'
31 |
32 | # we use fc nn here for our simple case
33 | class FC_NN_Split(nn.Module):
34 | def __init__(self):
35 | super(FC_NN_Split, self).__init__()
36 | self.fc1 = nn.Linear(784, 800)
37 | self.fc2 = nn.Linear(800, 500)
38 | self.fc3 = nn.Linear(500, 10)
39 | self.relu = nn.ReLU()
40 | self.sigmoid = nn.Sigmoid()
41 | # helper
42 | self.full_modules = [self.fc1, self.fc2, self.fc3]
43 | self._init_channel_index = len(self.full_modules)*2
44 | def forward(self, x):
45 | '''
46 | split layers
47 | '''
48 | self.output = []
49 | self.input = []
50 | x = x.view(-1, x.size()[1]*x.size()[2]*x.size()[3])
51 |
52 | x = Variable(x.data, requires_grad=True)
53 | self.input.append(x)
54 | x = self.fc1(x)
55 | self.output.append(x)
56 |
57 | x = Variable(x.data, requires_grad=True)
58 | self.input.append(x)
59 | x = self.relu(x)
60 | self.output.append(x)
61 |
62 | x = Variable(x.data, requires_grad=True)
63 | self.input.append(x)
64 | x = self.fc2(x)
65 | self.output.append(x)
66 |
67 | x = Variable(x.data, requires_grad=True)
68 | self.input.append(x)
69 | x = self.relu(x)
70 | self.output.append(x)
71 |
72 | x = Variable(x.data, requires_grad=True)
73 | self.input.append(x)
74 | x = self.fc3(x)
75 | self.output.append(x)
76 |
77 | x = Variable(x.data, requires_grad=True)
78 | self.input.append(x)
79 | x = self.sigmoid(x)
80 | self.output.append(x)
81 | return x
82 | @property
83 | def fetch_init_channel_index(self):
84 | return self._init_channel_index
85 | def backward_normal(self, g, communicator, req_send_check, cur_step):
86 | mod_avail_index = len(self.full_modules)-1
87 | #channel_index = len(self.full_modules)*2-2
88 | channel_index = self._init_channel_index - 2
89 | mod_counters_ = [0]*len(self.full_modules)
90 | for i, output in reversed(list(enumerate(self.output))):
91 | req_send_check[-1].wait()
92 | if i == (len(self.output) - 1):
93 | # for last node, use g
94 | output.backward(g)
95 | # get gradient here after some sanity checks:
96 | '''
97 | tmp_grad = self.full_modules[mod_avail_index].weight.grad
98 | if not pd.isnull(tmp_grad):
99 | grads = tmp_grad.data.numpy().astype(np.float64)
100 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
101 | req_send_check.append(req_isend)
102 | # update counters
103 | mod_avail_index-=1
104 | channel_index-=1
105 | else:
106 | continue
107 | '''
108 | else:
109 | output.backward(self.input[i+1].grad.data)
110 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
111 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
112 | # specific for this fc nn setting
113 | if mod_avail_index == len(self.full_modules)-1:
114 | if not pd.isnull(tmp_grad_weight):
115 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
116 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
117 | req_send_check.append(req_isend)
118 | # update counters
119 | mod_avail_index-=1
120 | channel_index-=1
121 | else:
122 | continue
123 | else:
124 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias):
125 | # we always send bias first
126 | if mod_counters_[mod_avail_index] == 0:
127 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
128 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
129 | req_send_check.append(req_isend)
130 | channel_index-=1
131 | mod_counters_[mod_avail_index]+=1
132 | elif mod_counters_[mod_avail_index] == 1:
133 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
134 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
135 | req_send_check.append(req_isend)
136 | channel_index-=1
137 | mod_counters_[mod_avail_index]+=1
138 | # update counters
139 | mod_avail_index-=1
140 | else:
141 | continue
142 | if mod_counters_[0] == 1:
143 | req_send_check[-1].wait()
144 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
145 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
146 | req_send_check.append(req_isend)
149 | return req_send_check
150 | @property
151 | def name(self):
152 | return 'fc_nn'
--------------------------------------------------------------------------------
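
`FC_NN_Split.forward` re-wraps each layer's input as a fresh `Variable` with `requires_grad=True`, which cuts the autograd graph into per-layer pieces; `backward_normal` then replays the chain rule manually, feeding each layer the `.grad` of the next layer's input, so every gradient can be shipped to the master as soon as it is ready. The core trick, shown in modern PyTorch syntax (a sketch, not the repo's code):

``` python
import torch

fc = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)

inp = x.detach().requires_grad_(True)  # analogous to Variable(x.data, requires_grad=True)
out = fc(inp)
# drive backward with a gradient arriving from the "downstream" layers
out.backward(torch.ones_like(out))
print(inp.grad)        # gradient to hand to the previous layer
print(fc.weight.grad)  # layer gradient, ready to send to the master
```
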
/src/model_ops/lenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | import pandas as pd
6 | import numpy as np
7 | from torch.autograd import Variable
8 |
9 | from mpi4py import MPI
10 |
11 | # we use LeNet here for our simple case
12 | class LeNet(nn.Module):
13 | def __init__(self):
14 | super(LeNet, self).__init__()
15 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
16 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
17 | self.fc1 = nn.Linear(4*4*50, 500)
18 | self.fc2 = nn.Linear(500, 10)
19 | self.ceriation = nn.CrossEntropyLoss()
20 | self.full_modules = [self.conv1, self.conv2, self.fc1, self.fc2]
21 |
22 | def forward(self, x):
23 | x = self.conv1(x)
24 | x = F.max_pool2d(x, 2, 2)
25 | x = F.relu(x)
26 | x = self.conv2(x)
27 | x = F.max_pool2d(x, 2, 2)
28 | x = F.relu(x)
29 | x = x.view(-1, 4*4*50)
30 | x = self.fc1(x)
31 | x = self.fc2(x)
32 | #loss = self.ceriation(x, target)
33 | return x
34 | def name(self):
35 | return 'lenet'
36 |
37 | class LeNetSplit(nn.Module):
38 |     '''
39 |     This module splits the model and runs the backward pass layer by layer.
40 |     Please don't use this module for normal purposes; it is a hack and runs slower than
41 |     the automatic chain-rule version.
42 |     '''
43 | def __init__(self):
44 | super(LeNetSplit, self).__init__()
45 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
46 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
47 | self.fc1 = nn.Linear(4*4*50, 500)
48 | self.fc2 = nn.Linear(500, 10)
49 |
50 | self.maxpool2d = nn.MaxPool2d(2, stride=2)
51 | self.relu = nn.ReLU()
52 |
53 | self.full_modules = [self.conv1, self.conv2, self.fc1, self.fc2]
54 | self._init_channel_index = len(self.full_modules)*2
55 |
56 | self.criterion = nn.CrossEntropyLoss()
57 |
58 | def forward(self, x):
59 | self.output = []
60 | self.input = []
61 | x = Variable(x.data, requires_grad=True)
62 | self.input.append(x)
63 | x = self.conv1(x)
64 | self.output.append(x)
65 |
66 | x = Variable(x.data, requires_grad=True)
67 | self.input.append(x)
68 | x = self.maxpool2d(x)
69 | self.output.append(x)
70 |
71 | x = Variable(x.data, requires_grad=True)
72 | self.input.append(x)
73 | x = self.relu(x)
74 | self.output.append(x)
75 |
76 | x = Variable(x.data, requires_grad=True)
77 | self.input.append(x)
78 | x = self.conv2(x)
79 | self.output.append(x)
80 |
81 | x = Variable(x.data, requires_grad=True)
82 | self.input.append(x)
83 | x = self.maxpool2d(x)
84 | self.output.append(x)
85 |
86 | x = Variable(x.data, requires_grad=True)
87 | self.input.append(x)
88 | x = self.relu(x)
89 | self.output.append(x)
90 |
91 | x = x.view(-1, 4*4*50)
92 |
93 | x = Variable(x.data, requires_grad=True)
94 | self.input.append(x)
95 | x = self.fc1(x)
96 | self.output.append(x)
97 |
98 | x = Variable(x.data, requires_grad=True)
99 | self.input.append(x)
100 | x = self.fc2(x)
101 | self.output.append(x)
102 | return x
103 |
104 | @property
105 | def fetch_init_channel_index(self):
106 | return self._init_channel_index
107 |
108 | def backward_normal(self, g, communicator, req_send_check, cur_step):
109 | mod_avail_index = len(self.full_modules)-1
110 | #channel_index = len(self.full_modules)*2-2
111 | channel_index = self._init_channel_index - 2
112 | mod_counters_ = [0]*len(self.full_modules)
113 | for i, output in reversed(list(enumerate(self.output))):
114 | req_send_check[-1].wait()
115 | if i == (len(self.output) - 1):
116 | # for last node, use g
117 | output.backward(g)
118 | # get gradient here after some sanity checks:
119 | tmp_grad = self.full_modules[mod_avail_index].weight.grad
120 | if not pd.isnull(tmp_grad):
121 | grads = tmp_grad.data.numpy().astype(np.float64)
122 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
123 | req_send_check.append(req_isend)
124 | # update counters
125 | mod_avail_index-=1
126 | channel_index-=1
127 | else:
128 | continue
129 | else:
130 | output.backward(self.input[i+1].grad.data)
131 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
132 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
133 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias):
134 | # we always send bias first
135 | if mod_counters_[mod_avail_index] == 0:
136 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
137 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
138 | req_send_check.append(req_isend)
139 | channel_index-=1
140 | mod_counters_[mod_avail_index]+=1
141 | elif mod_counters_[mod_avail_index] == 1:
142 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
143 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
144 | req_send_check.append(req_isend)
145 | channel_index-=1
146 | mod_counters_[mod_avail_index]+=1
147 | # update counters
148 | mod_avail_index-=1
149 | else:
150 | continue
151 | if mod_counters_[0] == 1:
152 | req_send_check[-1].wait()
153 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
154 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
155 | req_send_check.append(req_isend)
156 | return req_send_check
157 |
158 | def backward_signal_kill(self, g, communicator, req_send_check, cur_step):
159 |         '''
160 |         This killer is triggered by signals broadcast from the master; each worker
161 |         keeps checking the signal channel to determine whether it is the
162 |         straggler
163 |         '''
164 | mod_avail_index = len(self.full_modules)-1
165 | channel_index = self._init_channel_index - 2
166 | mod_counters_ = [0]*len(self.full_modules)
167 |
168 | # should kill flag
169 | should_kill = False
170 |
171 | for i, output in reversed(list(enumerate(self.output))):
172 | ############################ killing process on workers #####################################
173 | for _ in range(10000):
174 | status = MPI.Status()
175 | communicator.Iprobe(0, 77, status)
176 | if status.Get_source() == 0:
177 | print("Worker {}, Cur Step: {} I'm the straggler, killing myself!".format(communicator.Get_rank(), cur_step))
178 | tmp = communicator.recv(source=0, tag=77)
179 | should_kill = True
180 | break
181 | if should_kill:
182 | break
183 | ############################################################################################
184 |
185 | if i == (len(self.output) - 1):
186 | # for last node, use g
187 | output.backward(g)
188 | # get gradient here after some sanity checks:
189 | tmp_grad = self.full_modules[mod_avail_index].weight.grad
190 | if not pd.isnull(tmp_grad):
191 | grads = tmp_grad.data.numpy().astype(np.float64)
192 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
193 | req_send_check.append(req_isend)
194 | # update counters
195 | mod_avail_index-=1
196 | channel_index-=1
197 | else:
198 | continue
199 | else:
200 | output.backward(self.input[i+1].grad.data)
201 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
202 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
203 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias):
204 | # we always send bias first
205 | if mod_counters_[mod_avail_index] == 0:
206 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
207 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
208 | req_send_check.append(req_isend)
209 | channel_index-=1
210 | mod_counters_[mod_avail_index]+=1
211 | elif mod_counters_[mod_avail_index] == 1:
212 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
213 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
214 | req_send_check.append(req_isend)
215 | channel_index-=1
216 | mod_counters_[mod_avail_index]+=1
217 | # update counters
218 | mod_avail_index-=1
219 | else:
220 | continue
221 | if mod_counters_[0] == 1:
222 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
223 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
224 | req_send_check.append(req_isend)
225 | return req_send_check
226 |
227 | def backward_timeout_kill(self, g, communicator, req_send_check):
228 | """do we even need this?"""
229 | pass
--------------------------------------------------------------------------------
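
The `4*4*50` flatten in `LeNet.forward` follows from MNIST's 28x28 inputs: each 5x5 valid convolution shrinks the map by 4, and each 2x2 max-pool halves it. A quick shape check (sketch):

``` python
def conv_out(size, kernel, stride=1):
    return (size - kernel) // stride + 1

s = 28
s = conv_out(s, 5) // 2  # conv1 (5x5) then 2x2 max-pool: 28 -> 24 -> 12
s = conv_out(s, 5) // 2  # conv2 (5x5) then 2x2 max-pool: 12 -> 8 -> 4
print(s * s * 50)        # 800 features feeding fc1 = nn.Linear(4*4*50, 500)
```
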
/src/model_ops/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | For Pre-activation ResNet, see 'preact_resnet.py'.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from torch.autograd import Variable
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, in_planes, planes, stride=1):
18 | super(BasicBlock, self).__init__()
19 | self.full_modules = []
20 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
21 | self.full_modules.append(self.conv1)
22 | self.bn1 = nn.BatchNorm2d(planes)
23 | self.full_modules.append(self.bn1)
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
25 | self.full_modules.append(self.conv2)
26 | self.bn2 = nn.BatchNorm2d(planes)
27 | self.full_modules.append(self.bn2)
28 |
29 | self.shortcut = nn.Sequential()
30 | if stride != 1 or in_planes != self.expansion*planes:
31 | self.shortcut = nn.Sequential(
32 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(self.expansion*planes)
34 | )
35 | self.full_modules.append(self.shortcut[0])
36 | self.full_modules.append(self.shortcut[1])
37 |
38 | def forward(self, x):
39 | out = F.relu(self.bn1(self.conv1(x)))
40 | out = self.bn2(self.conv2(out))
41 | out += self.shortcut(x)
42 | out = F.relu(out)
43 | return out
44 |
45 |
46 | class Bottleneck(nn.Module):
47 | expansion = 4
48 |
49 | def __init__(self, in_planes, planes, stride=1):
50 | super(Bottleneck, self).__init__()
51 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
52 | self.bn1 = nn.BatchNorm2d(planes)
53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
54 | self.bn2 = nn.BatchNorm2d(planes)
55 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
56 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
57 |
58 | self.shortcut = nn.Sequential()
59 | if stride != 1 or in_planes != self.expansion*planes:
60 | self.shortcut = nn.Sequential(
61 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
62 | nn.BatchNorm2d(self.expansion*planes)
63 | )
64 |
65 | def forward(self, x):
66 | out = F.relu(self.bn1(self.conv1(x)))
67 | out = F.relu(self.bn2(self.conv2(out)))
68 | out = self.bn3(self.conv3(out))
69 | out += self.shortcut(x)
70 | out = F.relu(out)
71 | return out
72 |
73 |
74 | class ResNet(nn.Module):
75 | def __init__(self, block, num_blocks, num_classes=10):
76 | super(ResNet, self).__init__()
77 | self.in_planes = 64
78 | self.full_modules = []
79 |
80 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
81 | self.full_modules.append(self.conv1)
82 | self.bn1 = nn.BatchNorm2d(64)
83 | self.full_modules.append(self.bn1)
84 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
85 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
86 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
87 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
88 | self.linear = nn.Linear(512*block.expansion, num_classes)
89 | self.full_modules.append(self.linear)
90 |
91 | def _make_layer(self, block, planes, num_blocks, stride):
92 | strides = [stride] + [1]*(num_blocks-1)
93 | layers = []
94 | for stride in strides:
95 | block_layers=block(self.in_planes, planes, stride)
96 | layers.append(block_layers)
97 | for m in block_layers.full_modules:
98 | self.full_modules.append(m)
99 | self.in_planes = planes * block.expansion
100 | return nn.Sequential(*layers)
101 |
102 | def forward(self, x):
103 | out = F.relu(self.bn1(self.conv1(x)))
104 | out = self.layer1(out)
105 | out = self.layer2(out)
106 | out = self.layer3(out)
107 | out = self.layer4(out)
108 | out = F.avg_pool2d(out, 4)
109 | out = out.view(out.size(0), -1)
110 | out = self.linear(out)
111 | return out
112 |
113 |
114 | def ResNet18(num_classes):
115 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes)
116 |
117 | def ResNet34(num_classes=10):
118 |     return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)
119 |
120 | def ResNet50():
121 | return ResNet(Bottleneck, [3,4,6,3])
122 |
123 | def ResNet101():
124 | return ResNet(Bottleneck, [3,4,23,3])
125 |
126 | def ResNet152():
127 | return ResNet(Bottleneck, [3,8,36,3])
--------------------------------------------------------------------------------
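
The layer count behind the name: ResNet18 is the stem convolution, plus (2+2+2+2) `BasicBlock`s of two convolutions each, plus the final linear layer, i.e. 1 + 16 + 1 = 18 weighted layers. A quick forward-shape check on CIFAR-sized input (sketch):

``` python
import torch
from torch.autograd import Variable

net = ResNet18(num_classes=10)
x = Variable(torch.randn(2, 3, 32, 32))  # strides 1,2,2,2 take 32 -> 4 before avg_pool2d(out, 4)
print(net(x).size())  # torch.Size([2, 10])
```
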
/src/model_ops/resnet_split.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | For Pre-activation ResNet, see 'preact_resnet.py'.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 |
7 |
8 | Please note: this version is a hack (it is extremely hacky); never use it for normal purposes
9 | '''
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from torch.autograd import Variable
15 |
16 | import pandas as pd
17 | import numpy as np
18 | import timeout_decorator
19 |
20 | from mpi4py import MPI
21 |
22 | LAYER_DIGITS= int(1e+3)
23 | TIMEOUT_THRESHOLD_=10
24 |
25 | def generate_tag(layer_tag, step_token):
26 |     '''
27 |     Tag layout: [current-step-token (which helps to recognize stale gradients)
28 |     + layer-tag]
29 |     we only limit the number of digits for the layer tag here, since the step token
30 |     can be extremely large, e.g. 10k steps
31 |
32 |     :param layer_tag:
33 |     :param step_token:
34 |     :return:
35 |     '''
36 | tag = step_token * LAYER_DIGITS \
37 | + layer_tag
38 | tag = int(tag)
39 | return tag
40 |
41 | class BasicBlockSplit(nn.Module):
42 | expansion = 1
43 |
44 | def __init__(self, in_planes, planes, stride=1):
45 | super(BasicBlockSplit, self).__init__()
46 | self.full_modules = []
47 |
48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
49 | self.full_modules.append(self.conv1)
50 |
51 | self.bn1 = nn.BatchNorm2d(planes)
52 | self.full_modules.append(self.bn1)
53 |
54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
55 | self.full_modules.append(self.conv2)
56 |
57 | self.bn2 = nn.BatchNorm2d(planes)
58 | self.full_modules.append(self.bn2)
59 |
60 | self.relu = nn.ReLU()
61 |
62 | self.shortcut = nn.Sequential()
63 | if stride != 1 or in_planes != self.expansion*planes:
64 | self.shortcut = nn.Sequential(
65 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
66 | nn.BatchNorm2d(self.expansion*planes)
67 | )
68 | self.full_modules.append(self.shortcut[0])
69 | self.full_modules.append(self.shortcut[1])
70 |
71 | def forward(self, x, input_list, output_list):
72 | '''
73 | the input_list and output_list here is similar to input/output in ResNet class
74 | '''
75 | # we skip the detach and append operation on the very first x here
76 | # since that's done outside of this function
77 | out = self.conv1(x)
78 | output_list.append(out)
79 |
80 | out = Variable(out.data, requires_grad=True)
81 | input_list.append(out)
82 | out = self.bn1(out)
83 | output_list.append(out)
84 |
85 | out = Variable(out.data, requires_grad=True)
86 | input_list.append(out)
87 | out = self.relu(out)
88 | output_list.append(out)
89 |
90 | out = Variable(out.data, requires_grad=True)
91 | input_list.append(out)
92 | out = self.conv2(out)
93 | output_list.append(out)
94 |
95 | out = Variable(out.data, requires_grad=True)
96 | input_list.append(out)
97 | out = self.bn2(out)
98 | output_list.append(out)
99 |
100 | # TODO(hwang): figure out if this part also need hack
101 | out += self.shortcut(x)
102 |
103 | out = Variable(out.data, requires_grad=True)
104 | input_list.append(out)
105 | out = self.relu(out)
106 | output_list.append(out)
107 | return out, input_list, output_list
108 |
109 |
110 | class Bottleneck(nn.Module):
111 | expansion = 4
112 |
113 | def __init__(self, in_planes, planes, stride=1):
114 | super(Bottleneck, self).__init__()
115 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
116 | self.bn1 = nn.BatchNorm2d(planes)
117 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
118 | self.bn2 = nn.BatchNorm2d(planes)
119 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
120 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
121 |
122 | self.shortcut = nn.Sequential()
123 | if stride != 1 or in_planes != self.expansion*planes:
124 | self.shortcut = nn.Sequential(
125 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
126 | nn.BatchNorm2d(self.expansion*planes)
127 | )
128 |
129 | def forward(self, x):
130 | # we skip the detach operation on the very first x here since that's done outside of this function
131 | out = F.relu(self.bn1(self.conv1(x)))
132 | out = F.relu(self.bn2(self.conv2(out)))
133 | out = self.bn3(self.conv3(out))
134 | out += self.shortcut(x)
135 | out = F.relu(out)
136 | return out
137 |
138 |
139 | class ResNetSplit(nn.Module):
140 | def __init__(self, block, num_blocks, kill_threshold, num_classes=10):
141 | super(ResNetSplit, self).__init__()
142 | global TIMEOUT_THRESHOLD_
143 | self.in_planes = 64
144 | self.full_modules = []
145 |
146 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
147 | self.full_modules.append(self.conv1)
148 |
149 | self.bn1 = nn.BatchNorm2d(64)
150 | self.full_modules.append(self.bn1)
151 |
152 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
153 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
154 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
155 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
156 |
157 | self.linear = nn.Linear(512*block.expansion, num_classes)
158 | self.full_modules.append(self.linear)
159 |
160 | self.relu = nn.ReLU()
161 | self.avg_pool2d = nn.AvgPool2d(kernel_size=4)
162 | self._init_channel_index = self.count_channel_index()
163 | self._timeout_threshold = kill_threshold
164 | TIMEOUT_THRESHOLD_ = self._timeout_threshold
165 | self.killed_request_list = []
166 |
167 | @property
168 | def fetch_init_channel_index(self):
169 | return self._init_channel_index
170 |
171 | def _make_layer(self, block, planes, num_blocks, stride):
172 | strides = [stride] + [1]*(num_blocks-1)
173 | layers = []
174 | for stride in strides:
175 | block_layers = block(self.in_planes, planes, stride)
176 | layers.append(block_layers)
177 | for m in block_layers.full_modules:
178 | self.full_modules.append(m)
179 |
180 | self.in_planes = planes * block.expansion
181 | layers_split = nn.ModuleList(layers)
182 |
183 | return layers_split
184 |
185 | def forward(self, x):
186 | # use these containers to save intermediate variables
187 | self.output = []
188 | self.input = []
189 |
190 |         # start the forward process right here; apply the following logic to every intermediate var:
191 |         # detach it from the previous history
192 | x = Variable(x.data, requires_grad=True)
193 | self.input.append(x)
194 | x = self.conv1(x)
195 | # add to list of outputs
196 | self.output.append(x)
197 |
198 | x = Variable(x.data, requires_grad=True)
199 | self.input.append(x)
200 | x = self.bn1(x)
201 | self.output.append(x)
202 |
203 | x = Variable(x.data, requires_grad=True)
204 | self.input.append(x)
205 | x = self.relu(x)
206 | self.output.append(x)
207 |
208 | # start to handle blocks
209 | for layer in self.layer1:
210 | # each `layer` here is either a `BasicBlockSplit` or `BottleneckSplit`
211 | x = Variable(x.data, requires_grad=True)
212 | self.input.append(x)
213 | # call the `.forward()` func in `BasicBlockSplit` or `BottleneckSplit` here
214 | x, self.input, self.output = layer(x, self.input, self.output)
215 |
216 | for layer in self.layer2:
217 | # each `layer` here is either a `BasicBlockSplit` or `BottleneckSplit`
218 | x = Variable(x.data, requires_grad=True)
219 | self.input.append(x)
220 | # call the `.forward()` func in `BasicBlockSplit` or `BottleneckSplit` here
221 | x, self.input, self.output = layer(x, self.input, self.output)
222 |
223 | for layer in self.layer3:
224 | # each `layer` here is either a `BasicBlockSplit` or `BottleneckSplit`
225 | x = Variable(x.data, requires_grad=True)
226 | self.input.append(x)
227 | # call the `.forward()` func in `BasicBlockSplit` or `BottleneckSplit` here
228 | x, self.input, self.output = layer(x, self.input, self.output)
229 |
230 | for layer in self.layer4:
231 | # each `layer` here is either a `BasicBlockSplit` or `BottleneckSplit`
232 | x = Variable(x.data, requires_grad=True)
233 | self.input.append(x)
234 | # call the `.forward()` func in `BasicBlockSplit` or `BottleneckSplit` here
235 | x, self.input, self.output = layer(x, self.input, self.output)
236 |
237 | x = Variable(x.data, requires_grad=True)
238 | self.input.append(x)
239 | x = self.avg_pool2d(x)
240 | self.output.append(x)
241 |
242 | x = x.view(x.size(0), -1)
243 |
244 | x = Variable(x.data, requires_grad=True)
245 | self.input.append(x)
246 | x = self.linear(x)
247 | self.output.append(x)
248 | return x
249 |
250 | def count_channel_index(self):
251 | channel_index_ = 0
252 | for k, v in self.state_dict().items():
253 | if "running_mean" in k or "running_var" in k:
254 | continue
255 | else:
256 | channel_index_ += 1
257 | return channel_index_
258 |
259 | def backward(self, g, communicator, req_send_check, cur_step):
260 | mod_avail_index = len(self.full_modules)-1
261 | channel_index = self._init_channel_index-2
262 | mod_counters_ = [0]*len(self.full_modules)
263 | for i, output in reversed(list(enumerate(self.output))):
264 | # send layer only after the last layer is received
265 | req_send_check[-1].wait()
266 | if i == (len(self.output) - 1):
267 | # for last node, use g
268 | output.backward(g)
269 | # get gradient here after some sanity checks:
270 | tmp_grad = self.full_modules[mod_avail_index].weight.grad
271 | if not pd.isnull(tmp_grad):
272 | grads = tmp_grad.data.numpy().astype(np.float64)
273 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
274 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
275 | req_send_check.append(req_isend)
276 | # update counters
277 | mod_avail_index-=1
278 | channel_index-=1
279 | else:
280 | continue
281 | else:
282 | if output.size() == self.input[i+1].grad.size():
283 | output.backward(self.input[i+1].grad.data)
284 | else:
285 | tmp_grad_output = self.input[i+1].grad.view(output.size())
286 | output.backward(tmp_grad_output)
287 |
288 | # since in resnet we do not use bias weight for conv layer
289 | if pd.isnull(self.full_modules[mod_avail_index].bias):
290 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
291 |
292 | if not pd.isnull(tmp_grad_weight):
293 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
294 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
295 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
296 | req_send_check.append(req_isend)
297 | channel_index-=1
298 | mod_counters_[mod_avail_index]=2
299 | # update counters
300 | mod_avail_index-=1
301 | else:
302 | continue
303 | else:
304 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
305 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
306 |
307 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias):
308 | # we always send bias first
309 | if mod_counters_[mod_avail_index] == 0:
310 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
311 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
312 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
313 | req_send_check.append(req_isend)
314 | channel_index-=1
315 | mod_counters_[mod_avail_index]+=1
316 | elif mod_counters_[mod_avail_index] == 1:
317 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
318 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
319 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
320 | req_send_check.append(req_isend)
321 | channel_index-=1
322 | mod_counters_[mod_avail_index]+=1
323 | # update counters
324 | mod_avail_index-=1
325 | else:
326 | continue
327 | # handle the remaining gradients here to send to parameter server
328 | while channel_index >= 0:
329 | req_send_check[-1].wait()
330 | if pd.isnull(self.full_modules[mod_avail_index].bias):
331 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
332 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
333 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
334 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
335 | req_send_check.append(req_isend)
336 | channel_index-=1
337 | mod_counters_[mod_avail_index]=2
338 | # update counters
339 | mod_avail_index-=1
340 | else:
341 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
342 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
343 | # we always send bias first
344 | if mod_counters_[mod_avail_index] == 0:
345 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
346 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
347 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
348 | req_send_check.append(req_isend)
349 | channel_index-=1
350 | mod_counters_[mod_avail_index]+=1
351 | elif mod_counters_[mod_avail_index] == 1:
352 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
353 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
354 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
355 | req_send_check.append(req_isend)
356 | channel_index-=1
357 | mod_counters_[mod_avail_index]+=1
358 | # update counters
359 | mod_avail_index-=1
360 | return req_send_check
361 |
362 | def backward_normal(self, g, communicator, req_send_check, cur_step):
363 | mod_avail_index = len(self.full_modules)-1
364 | channel_index = self._init_channel_index-2
365 | mod_counters_ = [0]*len(self.full_modules)
366 | for i, output in reversed(list(enumerate(self.output))):
367 | # send layer only after the last layer is received
368 | req_send_check[-1].wait()
369 | if i == (len(self.output) - 1):
370 | # for last node, use g
371 | output.backward(g)
372 | # get gradient here after some sanity checks:
373 | tmp_grad = self.full_modules[mod_avail_index].weight.grad
374 | if not pd.isnull(tmp_grad):
375 | grads = tmp_grad.data.numpy().astype(np.float64)
376 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
377 | req_send_check.append(req_isend)
378 | # update counters
379 | mod_avail_index-=1
380 | channel_index-=1
381 | else:
382 | continue
383 | else:
384 | if output.size() == self.input[i+1].grad.size():
385 | output.backward(self.input[i+1].grad.data)
386 | else:
387 | tmp_grad_output = self.input[i+1].grad.view(output.size())
388 | output.backward(tmp_grad_output)
389 |
390 | # since in resnet we do not use bias weight for conv layer
391 | if pd.isnull(self.full_modules[mod_avail_index].bias):
392 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
393 |
394 | if not pd.isnull(tmp_grad_weight):
395 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
396 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
397 | req_send_check.append(req_isend)
398 | channel_index-=1
399 | mod_counters_[mod_avail_index]=2
400 | # update counters
401 | mod_avail_index-=1
402 | else:
403 | continue
404 | else:
405 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
406 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
407 |
408 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias):
409 | # we always send bias first
410 | if mod_counters_[mod_avail_index] == 0:
411 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
412 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
413 | req_send_check.append(req_isend)
414 | channel_index-=1
415 | mod_counters_[mod_avail_index]+=1
416 | elif mod_counters_[mod_avail_index] == 1:
417 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
418 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
419 | req_send_check.append(req_isend)
420 | channel_index-=1
421 | mod_counters_[mod_avail_index]+=1
422 | # update counters
423 | mod_avail_index-=1
424 | else:
425 | continue
426 | # handle the remaining gradients here to send to parameter server
427 | while channel_index >= 0:
428 | req_send_check[-1].wait()
429 | if pd.isnull(self.full_modules[mod_avail_index].bias):
430 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
431 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
432 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
433 | req_send_check.append(req_isend)
434 | channel_index-=1
435 | mod_counters_[mod_avail_index]=2
436 | # update counters
437 | mod_avail_index-=1
438 | else:
439 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
440 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
441 | # we always send bias first
442 | if mod_counters_[mod_avail_index] == 0:
443 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
444 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
445 | req_send_check.append(req_isend)
446 | channel_index-=1
447 | mod_counters_[mod_avail_index]+=1
448 | elif mod_counters_[mod_avail_index] == 1:
449 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
450 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
451 | req_send_check.append(req_isend)
452 | channel_index-=1
453 | mod_counters_[mod_avail_index]+=1
454 | # update counters
455 | mod_avail_index-=1
456 | return req_send_check
457 |
458 | def backward_signal_kill(self, g, communicator, req_send_check, cur_step):
459 | mod_avail_index = len(self.full_modules)-1
460 | channel_index = self._init_channel_index-2
461 | mod_counters_ = [0]*len(self.full_modules)
462 |
463 | # should kill flag
464 | should_kill = False
465 |
466 | for i, output in reversed(list(enumerate(self.output))):
467 | ############################ killing process on workers #####################################
468 | for _ in range(100):
469 | status = MPI.Status()
470 | communicator.Iprobe(0, 77, status)
471 | if status.Get_source() == 0:
472 | print("Worker {}, Cur Step: {} I'm the straggler, killing myself!".format(communicator.Get_rank(), cur_step))
473 | tmp = communicator.recv(source=0, tag=77)
474 | should_kill = True
475 | break
476 | if should_kill:
477 | channel_index=-5
478 | break
479 | ############################################################################################
480 | if i == (len(self.output) - 1):
481 | # for last node, use g
482 | output.backward(g)
483 | # get gradient here after some sanity checks:
484 | tmp_grad = self.full_modules[mod_avail_index].weight.grad
485 | if not pd.isnull(tmp_grad):
486 | grads = tmp_grad.data.numpy().astype(np.float64)
487 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
488 | req_send_check.append(req_isend)
489 | # update counters
490 | mod_avail_index-=1
491 | channel_index-=1
492 | else:
493 | continue
494 | else:
495 | if output.size() == self.input[i+1].grad.size():
496 | output.backward(self.input[i+1].grad.data)
497 | else:
498 | tmp_grad_output = self.input[i+1].grad.view(output.size())
499 | output.backward(tmp_grad_output)
500 |
501 | # since in resnet we do not use bias weight for conv layer
502 | if pd.isnull(self.full_modules[mod_avail_index].bias):
503 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
504 |
505 | if not pd.isnull(tmp_grad_weight):
506 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
507 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
508 | req_send_check.append(req_isend)
509 | channel_index-=1
510 | mod_counters_[mod_avail_index]=2
511 | # update counters
512 | mod_avail_index-=1
513 | else:
514 | continue
515 | else:
516 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
517 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
518 |
519 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias):
520 | # we always send bias first
521 | if mod_counters_[mod_avail_index] == 0:
522 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
523 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
524 | req_send_check.append(req_isend)
525 | channel_index-=1
526 | mod_counters_[mod_avail_index]+=1
527 | elif mod_counters_[mod_avail_index] == 1:
528 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
529 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
530 | req_send_check.append(req_isend)
531 | channel_index-=1
532 | mod_counters_[mod_avail_index]+=1
533 | # update counters
534 | mod_avail_index-=1
535 | else:
536 | continue
537 | # handle the remaining gradients here to send to parameter server
538 | while channel_index >= 0:
539 | if pd.isnull(self.full_modules[mod_avail_index].bias):
540 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
541 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
542 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
543 | req_send_check.append(req_isend)
544 | channel_index-=1
545 | mod_counters_[mod_avail_index]=2
546 | # update counters
547 | mod_avail_index-=1
548 | else:
549 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
550 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
551 | # we always send bias first
552 | if mod_counters_[mod_avail_index] == 0:
553 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
554 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
555 | req_send_check.append(req_isend)
556 | channel_index-=1
557 | mod_counters_[mod_avail_index]+=1
558 | elif mod_counters_[mod_avail_index] == 1:
559 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
560 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
561 | req_send_check.append(req_isend)
562 | channel_index-=1
563 | mod_counters_[mod_avail_index]+=1
564 | # update counters
565 | mod_avail_index-=1
566 |         # channel_index ends at -1 after a full backward pass and is set
567 |         # to -5 only when a kill signal was received, so `killed` is
568 |         # well-defined in both cases
569 |         killed = (channel_index == -5)
570 | return req_send_check, killed
571 |
572 |     @timeout_decorator.timeout(10.5, timeout_exception=StopIteration)  # note: decorator args are fixed at class-definition time, so the runtime TIMEOUT_THRESHOLD_ cannot be used here
573 | def backward_timeout_kill(self, g, communicator, req_send_check, cur_step):
574 | mod_avail_index = len(self.full_modules)-1
575 | channel_index = self._init_channel_index-2
576 | mod_counters_ = [0]*len(self.full_modules)
577 |
578 |         # reset request list of killed workers
579 | self.killed_request_list = []
580 | for i, output in reversed(list(enumerate(self.output))):
581 | # send layer only after the last layer is received
582 | req_send_check[-1].wait()
583 | if i == (len(self.output) - 1):
584 | # for last node, use g
585 | output.backward(g)
586 | # get gradient here after some sanity checks:
587 | tmp_grad = self.full_modules[mod_avail_index].weight.grad
588 | if not pd.isnull(tmp_grad):
589 | grads = tmp_grad.data.numpy().astype(np.float64)
590 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
591 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
592 | req_send_check.append(req_isend)
593 | #self.killed_request_list.append(req_isend)
594 | # update counters
595 | mod_avail_index-=1
596 | channel_index-=1
597 | else:
598 | continue
599 | else:
600 | if output.size() == self.input[i+1].grad.size():
601 | output.backward(self.input[i+1].grad.data)
602 | else:
603 | tmp_grad_output = self.input[i+1].grad.view(output.size())
604 | output.backward(tmp_grad_output)
605 |
606 | # since in resnet we do not use bias weight for conv layer
607 | if pd.isnull(self.full_modules[mod_avail_index].bias):
608 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
609 |
610 | if not pd.isnull(tmp_grad_weight):
611 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
612 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
613 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
614 | req_send_check.append(req_isend)
615 | #self.killed_request_list.append(req_isend)
616 | channel_index-=1
617 | mod_counters_[mod_avail_index]=2
618 | # update counters
619 | mod_avail_index-=1
620 | else:
621 | continue
622 | else:
623 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
624 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
625 |
626 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias):
627 | # we always send bias first
628 | if mod_counters_[mod_avail_index] == 0:
629 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
630 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
631 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
632 | req_send_check.append(req_isend)
633 | #self.killed_request_list.append(req_isend)
634 | channel_index-=1
635 | mod_counters_[mod_avail_index]+=1
636 | elif mod_counters_[mod_avail_index] == 1:
637 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
638 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
639 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
640 | req_send_check.append(req_isend)
641 | #self.killed_request_list.append(req_isend)
642 | channel_index-=1
643 | mod_counters_[mod_avail_index]+=1
644 | # update counters
645 | mod_avail_index-=1
646 | else:
647 | continue
648 | # handle the remaining gradients here to send to parameter server
649 | while channel_index >= 0:
650 | req_send_check[-1].wait()
651 | if pd.isnull(self.full_modules[mod_avail_index].bias):
652 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
653 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
654 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
655 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
656 | req_send_check.append(req_isend)
657 | #self.killed_request_list.append(req_isend)
658 | channel_index-=1
659 | mod_counters_[mod_avail_index]=2
660 | # update counters
661 | mod_avail_index-=1
662 | else:
663 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
664 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
665 | # we always send bias first
666 | if mod_counters_[mod_avail_index] == 0:
667 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
668 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
669 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
670 | req_send_check.append(req_isend)
671 | #self.killed_request_list.append(req_isend)
672 | channel_index-=1
673 | mod_counters_[mod_avail_index]+=1
674 | elif mod_counters_[mod_avail_index] == 1:
675 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
676 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
677 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step))
678 | req_send_check.append(req_isend)
679 | #self.killed_request_list.append(req_isend)
680 | channel_index-=1
681 | mod_counters_[mod_avail_index]+=1
682 | # update counters
683 | mod_avail_index-=1
684 | return req_send_check
685 |
686 | def backward_single(self, g):
687 | for i, output in reversed(list(enumerate(self.output))):
688 | #print("Backward processing, step {}".format(i))
689 | #print("--------------------------------------------------------")
690 | if i == (len(self.output) - 1):
691 | # for last node, use g
692 | output.backward(g)
693 | else:
694 |
695 | #print(output.size())
696 | #print(self.input[i+1].grad.size())
697 | #tmp = self.input[i+1].grad.view(output.size())
698 | #print(tmp.size())
699 | #print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
700 |
701 | if output.size() == self.input[i+1].grad.size():
702 | output.backward(self.input[i+1].grad.data)
703 | else:
704 | tmp_grad_output = self.input[i+1].grad.view(output.size())
705 | output.backward(tmp_grad_output)
706 |
707 | def ResNetSplit18(kill_threshold):
708 | return ResNetSplit(BasicBlockSplit, [2,2,2,2], kill_threshold=kill_threshold)
709 |
710 | def ResNetSplit34(kill_threshold):
711 |     return ResNetSplit(BasicBlockSplit, [3,4,6,3], kill_threshold=kill_threshold)
712 | 
713 | def ResNetSplit50(kill_threshold):
714 |     return ResNetSplit(Bottleneck, [3,4,6,3], kill_threshold=kill_threshold)
715 | 
716 | def ResNetSplit101(kill_threshold):
717 |     return ResNetSplit(Bottleneck, [3,4,23,3], kill_threshold=kill_threshold)
718 | 
719 | def ResNetSplit152(kill_threshold):
720 |     return ResNetSplit(Bottleneck, [3,8,36,3], kill_threshold=kill_threshold)
721 |
722 | if __name__ == "__main__":
723 | a = ResNetSplit18(1)
724 | print("Done!")
--------------------------------------------------------------------------------
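
The split model above hinges on one trick: every intermediate activation is detached from its history and re-marked as a graph leaf, so each stage can be backpropagated on its own and its gradient shipped as soon as it is ready (with the MPI tag packed as `step_token * LAYER_DIGITS + layer_tag`, e.g. step 42 with layer tag 93 gives 42093). Below is a minimal, self-contained sketch of that pattern, with illustrative stage names and written against a recent PyTorch:

``` python
import torch
import torch.nn as nn

# Two-stage "split" forward: record an (input, output) pair per stage.
stage1, stage2 = nn.Linear(4, 4), nn.Linear(4, 2)
inputs, outputs = [], []

x = torch.randn(3, 4, requires_grad=True)
inputs.append(x)
out = stage1(x)
outputs.append(out)

out = out.detach().requires_grad_()  # detach from history, new graph leaf
inputs.append(out)
out = stage2(out)
outputs.append(out)

# Stage-by-stage backward, newest stage first (as in the `backward_*` methods):
g = torch.ones_like(out)  # seed gradient for the final output
for i in reversed(range(len(outputs))):
    if i == len(outputs) - 1:
        outputs[i].backward(g)
    else:
        # seed with the gradient accumulated on the detached input
        # of the stage that followed this one
        outputs[i].backward(inputs[i + 1].grad)

# stage1.weight.grad and stage2.weight.grad are now populated one stage at
# a time, which is what lets a worker Isend gradients layer by layer.
```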
/src/model_ops/vgg.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://github.com/pytorch/vision.git
3 | '''
4 | import math
5 |
6 | import torch.nn as nn
7 | import torch.nn.init as init
8 |
9 | __all__ = [
10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
11 | 'vgg19_bn', 'vgg19',
12 | ]
13 |
14 |
15 | class VGG(nn.Module):
16 | '''
17 | VGG model
18 | '''
19 | def __init__(self, features, num_classes=10):
20 | super(VGG, self).__init__()
21 | self.features = features
22 | self.classifier = nn.Sequential(
23 | nn.Dropout(),
24 | nn.Linear(512, 512),
25 | nn.ReLU(True),
26 | nn.Dropout(),
27 | nn.Linear(512, 512),
28 | nn.ReLU(True),
29 | nn.Linear(512, num_classes),
30 | )
31 | # Initialize weights
32 | for m in self.modules():
33 | if isinstance(m, nn.Conv2d):
34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
35 | m.weight.data.normal_(0, math.sqrt(2. / n))
36 | m.bias.data.zero_()
37 |
38 |
39 | def forward(self, x):
40 | x = self.features(x)
41 | x = x.view(x.size(0), -1)
42 | x = self.classifier(x)
43 | return x
44 |
45 |
46 | def make_layers(cfg, batch_norm=False):
47 | layers = []
48 | in_channels = 3
49 | for v in cfg:
50 | if v == 'M':
51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
52 | else:
53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
54 | if batch_norm:
55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
56 | else:
57 | layers += [conv2d, nn.ReLU(inplace=True)]
58 | in_channels = v
59 | return nn.Sequential(*layers)
60 |
61 |
62 | cfg = {
63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
67 | 512, 512, 512, 512, 'M'],
68 | }
69 |
70 |
71 | def vgg11():
72 | """VGG 11-layer model (configuration "A")"""
73 | return VGG(make_layers(cfg['A']))
74 |
75 |
76 | def vgg11_bn(num_classes=10):
77 | """VGG 11-layer model (configuration "A") with batch normalization"""
78 | return VGG(make_layers(cfg['A'], batch_norm=True), num_classes=num_classes)
79 |
80 |
81 | def vgg13():
82 | """VGG 13-layer model (configuration "B")"""
83 | return VGG(make_layers(cfg['B']))
84 |
85 |
86 | def vgg13_bn():
87 | """VGG 13-layer model (configuration "B") with batch normalization"""
88 | return VGG(make_layers(cfg['B'], batch_norm=True))
89 |
90 |
91 | def vgg16():
92 | """VGG 16-layer model (configuration "D")"""
93 | return VGG(make_layers(cfg['D']))
94 |
95 |
96 | def vgg16_bn():
97 | """VGG 16-layer model (configuration "D") with batch normalization"""
98 | return VGG(make_layers(cfg['D'], batch_norm=True))
99 |
100 |
101 | def vgg19():
102 | """VGG 19-layer model (configuration "E")"""
103 | return VGG(make_layers(cfg['E']))
104 |
105 |
106 | def vgg19_bn():
107 | """VGG 19-layer model (configuration 'E') with batch normalization"""
108 | return VGG(make_layers(cfg['E'], batch_norm=True))
--------------------------------------------------------------------------------
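
An illustrative shape check for `make_layers` (not repo code; assumes a recent PyTorch, run from `src/`): configuration 'A' applies five 2x2 max-pools, so a 32x32 CIFAR image leaves the feature extractor as a 512-dimensional vector, matching the classifier's first `nn.Linear(512, 512)`:

``` python
import torch
from model_ops.vgg import vgg11_bn

model = vgg11_bn(num_classes=10)
out = model(torch.randn(2, 3, 32, 32))
print(out.size())  # torch.Size([2, 10])
```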
/src/nn_ops.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | from model_ops.lenet import LeNet
10 | from model_ops.resnet import *
11 | from model_ops.resnet_split import *
12 | from model_ops.alexnet import *
13 |
14 | import numpy.linalg as LA
15 |
16 |
17 | def nuclear_indicator(grad, s):
18 | m, n = grad.shape
19 | return np.sum(s)*np.sqrt(m+n)
20 |
21 |
22 | def l1_indicator(grad):
23 | return np.linalg.norm(grad.reshape(-1), 1)
24 |
25 |
26 | def _resize_to_2d(x):
27 | """
28 |     Resize a tensor with more than two dimensions down to 2-D.
29 |     If x.shape = (a, b, *c), we assume each (a, b) pair carries its relevant information in c.
30 | """
31 | shape = x.shape
32 | if x.ndim == 1:
33 | n = x.shape[0]
34 | return x.reshape((n//2, 2))
35 | if all([s == 1 for s in shape[2:]]):
36 | return x.reshape((shape[0], shape[1]))
37 | # each of (a, b) has related features
38 | x = x.reshape((shape[0], shape[1], -1))
39 | # stack those related features into a tall matrix
40 | x_tmp = x.reshape((shape[0]*shape[1], -1))
41 | tmp_shape = x_tmp.shape
42 | return x_tmp.reshape((int(tmp_shape[0]/2), int(tmp_shape[1]*2)))
43 |
44 |
45 | def _sample_svd(s, rank=0):
46 | if s[0] < 1e-6:
47 |         return np.array([0], dtype=int), np.array([1.0])
48 | probs = s / s[0] if rank == 0 else rank * s / s.sum()
49 | for i, p in enumerate(probs):
50 | if p > 1:
51 | probs[i]=1
52 | sampled_idx = []
53 | sample_probs = []
54 | for i, p in enumerate(probs):
55 | #if np.random.rand() < p:
56 |         # random sampling from a Bernoulli distribution
57 | if np.random.binomial(1, p):
58 | sampled_idx += [i]
59 | sample_probs += [p]
60 | rank_hat = len(sampled_idx)
61 | if rank_hat == 0: # or (rank != 0 and np.abs(rank_hat - rank) >= 3):
62 | return _sample_svd(s, rank=rank)
63 | return np.array(sampled_idx, dtype=int), np.array(sample_probs)
64 |
65 |
66 | def svd_encode(grad, **kwargs):
67 | orig_size = list(grad.shape)
68 | ndims = grad.ndim
69 | reshaped_flag = False
70 | if ndims != 2:
71 | grad = _resize_to_2d(grad)
72 | shape = list(grad.shape)
73 | ndims = len(shape)
74 | reshaped_flag = True
75 |
76 | if ndims == 2:
77 | u, s, vT = LA.svd(grad, full_matrices=False)
78 |
79 | nuclear_ind = nuclear_indicator(grad, s)
80 | l1_ind = l1_indicator(grad)
81 | print("Step: {}, Nuclear Indicator: {}, L1 Indicator: {}".format(
82 | kwargs['step'], nuclear_ind, l1_ind))
83 |
84 |
85 | '''this is a trial example; we use MNIST on LeNet as a simple test here'''
86 | def accuracy(output, target, topk=(1,)):
87 | """Computes the precision@k for the specified values of k"""
88 | maxk = max(topk)
89 | batch_size = target.size(0)
90 |
91 | _, pred = output.topk(maxk, 1, True, True)
92 | pred = pred.t()
93 | correct = pred.eq(target.view(1, -1).expand_as(pred))
94 |
95 | res = []
96 | for k in topk:
97 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
98 | res.append(correct_k.mul_(100.0 / batch_size))
99 | return res
100 |
101 | class NN_Trainer(object):
102 | def __init__(self, **kwargs):
103 | self.batch_size = kwargs['batch_size']
104 | self.lr = kwargs['learning_rate']
105 | self.max_epochs = kwargs['max_epochs']
106 | self.momentum = kwargs['momentum']
107 | self.network_config = kwargs['network']
108 |
109 | def build_model(self):
110 | # build network
111 | if self.network_config == "LeNet":
112 | self.network=LeNet()
113 | elif self.network_config == "ResNet18":
114 |             self.network=ResNet18(num_classes=10)
115 | elif self.network_config == "ResNet34":
116 | self.network=ResNet34()
117 | elif self.network_config == "AlexNet":
118 | self.network=alexnet()
119 | # set up optimizer
120 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)
121 | self.criterion = torch.nn.CrossEntropyLoss()
122 |
123 | def train_and_validate(self, train_loader, test_loader):
124 | # iterate of epochs
125 | for i in range(self.max_epochs):
126 | # change back to training mode
127 | self.network.train()
128 | for batch_idx, (data, y_batch) in enumerate(train_loader):
129 | iter_start_time = time.time()
130 | data, target = Variable(data), Variable(y_batch)
131 | self.optimizer.zero_grad()
132 | ################# backward on normal model ############################
133 |
134 | logits = self.network(data)
135 | loss = self.criterion(logits, target)
136 | loss.backward()
137 | #######################################################################
138 |
139 | ################ backward on splitted model ###########################
140 | '''
141 | logits = self.network(data)
142 | logits_1 = Variable(logits.data, requires_grad=True)
143 | loss = self.criterion(logits_1, target)
144 | loss.backward()
145 | self.network.backward_single(logits_1.grad)
146 | '''
147 | #######################################################################
148 | tmp_time_0 = time.time()
149 |
150 | #for param in self.network.parameters():
151 | # grads = param.grad.data.numpy().astype(np.float64)
152 | # svd_encode(grads, step=batch_idx)
153 |
154 | duration_backward = time.time()-tmp_time_0
155 |
156 | tmp_time_1 = time.time()
157 | self.optimizer.step()
158 | duration_update = time.time()-tmp_time_1
159 |
160 | # calculate training accuracy
161 | prec1, prec5 = accuracy(logits.data, y_batch, topk=(1, 5))
162 | # load the training info
163 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Prec@1: {} Prec@5: {} Time Cost: {}'.format(
164 | i, batch_idx * len(data), len(train_loader.dataset),
165 | 100. * batch_idx / len(train_loader), loss.data[0],
166 | prec1.numpy()[0],
167 | prec5.numpy()[0], time.time()-iter_start_time))
168 | # we evaluate the model performance on end of each epoch
169 | self.validate(test_loader)
170 |
171 | def validate(self, test_loader):
172 | self.network.eval()
173 | test_loss = 0
174 | correct = 0
175 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0
176 | for data, y_batch in test_loader:
177 | data, target = Variable(data, volatile=True), Variable(y_batch)
178 | output = self.network(data)
179 |             test_loss += F.cross_entropy(output, target, size_average=False).data[0] # sum up batch loss (cross-entropy on raw logits)
180 | #pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
181 | #correct += pred.eq(target.data.view_as(pred)).cpu().sum()
182 | prec1_tmp, prec5_tmp = accuracy(output.data, y_batch, topk=(1, 5))
183 | prec1_counter_ += prec1_tmp.numpy()[0]
184 | prec5_counter_ += prec5_tmp.numpy()[0]
185 | batch_counter_ += 1
186 | prec1 = prec1_counter_ / batch_counter_
187 | prec5 = prec5_counter_ / batch_counter_
188 | test_loss /= len(test_loader.dataset)
189 | print('Test set: Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(test_loss, prec1, prec5))
190 |
--------------------------------------------------------------------------------
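
A minimal sketch of how `_sample_svd` above yields an unbiased sparsification (the rank budget of 3 is a made-up example; run from `src/` so the import resolves): each sampled singular triplet is rescaled by the inverse of its inclusion probability, so the reconstruction matches the original gradient in expectation while only the sampled components would need to be communicated:

``` python
import numpy as np
import numpy.linalg as LA
from nn_ops import _sample_svd

grad = np.random.randn(8, 8)
u, s, vT = LA.svd(grad, full_matrices=False)

idx, probs = _sample_svd(s, rank=3)  # keep roughly 3 singular components
recon = np.zeros_like(grad)
for i, p in zip(idx, probs):
    # rescaling by 1/p makes the estimator unbiased: E[recon] == grad
    recon += (s[i] / p) * np.outer(u[:, i], vT[i, :])
```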
/src/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from . import adam, sgd
2 |
3 | __all__ = ['adam', 'sgd']
--------------------------------------------------------------------------------
/src/optim/adam.py:
--------------------------------------------------------------------------------
1 | '''
2 | modified version of Adam optimizer
3 | by Hongyi Wang
4 | '''
5 |
6 | import math
7 | import torch
8 | from torch.optim import Optimizer
9 |
10 |
11 | class Adam(Optimizer):
12 | """Implements Adam algorithm.
13 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
14 | Arguments:
15 | params (iterable): iterable of parameters to optimize or dicts defining
16 | parameter groups
17 | lr (float, optional): learning rate (default: 1e-3)
18 | betas (Tuple[float, float], optional): coefficients used for computing
19 | running averages of gradient and its square (default: (0.9, 0.999))
20 | eps (float, optional): term added to the denominator to improve
21 | numerical stability (default: 1e-8)
22 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
23 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
24 | algorithm from the paper `On the Convergence of Adam and Beyond`_
25 | .. _Adam\: A Method for Stochastic Optimization:
26 | https://arxiv.org/abs/1412.6980
27 | .. _On the Convergence of Adam and Beyond:
28 | https://openreview.net/forum?id=ryQu7f-RZ
29 | """
30 |
31 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
32 | weight_decay=0, amsgrad=False):
33 | defaults = dict(lr=lr, betas=betas, eps=eps,
34 | weight_decay=weight_decay, amsgrad=amsgrad)
35 | super(Adam, self).__init__(params, defaults)
36 |
37 | def step(self, grads, closure=None):
38 | """Performs a single optimization step.
39 | Arguments:
40 | closure (callable, optional): A closure that reevaluates the model
41 | and returns the loss.
42 | """
43 | loss = None
44 | if closure is not None:
45 | loss = closure()
46 |
47 | for group in self.param_groups:
48 | for i,p in enumerate(group['params']):
49 | grad = torch.from_numpy(grads[i]).float()
50 | if grad.is_sparse:
51 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
52 | amsgrad = group['amsgrad']
53 |
54 | state = self.state[p]
55 |
56 | # State initialization
57 | if len(state) == 0:
58 | state['step'] = 0
59 | # Exponential moving average of gradient values
60 | state['exp_avg'] = torch.zeros_like(p.data)
61 | # Exponential moving average of squared gradient values
62 | state['exp_avg_sq'] = torch.zeros_like(p.data)
63 | if amsgrad:
64 | # Maintains max of all exp. moving avg. of sq. grad. values
65 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
66 |
67 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
68 | if amsgrad:
69 | max_exp_avg_sq = state['max_exp_avg_sq']
70 | beta1, beta2 = group['betas']
71 |
72 | state['step'] += 1
73 |
74 | if group['weight_decay'] != 0:
75 | grad = grad.add(group['weight_decay'], p.data)
76 |
77 | # Decay the first and second moment running average coefficient
78 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
79 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
80 | if amsgrad:
81 | # Maintains the maximum of all 2nd moment running avg. till now
82 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
83 | # Use the max. for normalizing running avg. of gradient
84 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
85 | else:
86 | denom = exp_avg_sq.sqrt().add_(group['eps'])
87 |
88 | bias_correction1 = 1 - beta1 ** state['step']
89 | bias_correction2 = 1 - beta2 ** state['step']
90 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
91 |
92 | p.data.addcdiv_(-step_size, exp_avg, denom)
93 |
94 | return loss
--------------------------------------------------------------------------------
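
A small worked check of the bias-corrected step size computed in `step` above (standard Adam algebra, not repo code): because both moment estimates start at zero, the raw step is rescaled by sqrt(1 - beta2^t) / (1 - beta1^t), which matters most in the first few steps:

``` python
import math

lr, beta1, beta2 = 1e-3, 0.9, 0.999  # the optimizer's defaults
for step in (1, 10, 100, 1000):
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    print(step, lr * math.sqrt(bias_correction2) / bias_correction1)
```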
/src/optim/sgd.py:
--------------------------------------------------------------------------------
1 | '''
2 | modified version of SGD optimizer
3 | by Hongyi Wang
4 | '''
5 | import torch
6 | from torch.optim import Optimizer
7 |
8 |
9 | class SGD(Optimizer):
10 | r"""Implements stochastic gradient descent (optionally with momentum).
11 | Nesterov momentum is based on the formula from
12 | `On the importance of initialization and momentum in deep learning`__.
13 | Args:
14 | params (iterable): iterable of parameters to optimize or dicts defining
15 | parameter groups
16 | lr (float): learning rate
17 | momentum (float, optional): momentum factor (default: 0)
18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19 | dampening (float, optional): dampening for momentum (default: 0)
20 | nesterov (bool, optional): enables Nesterov momentum (default: False)
21 | Example:
22 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
23 | >>> optimizer.zero_grad()
24 | >>> loss_fn(model(input), target).backward()
25 | >>> optimizer.step()
26 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
27 | .. note::
28 | The implementation of SGD with Momentum/Nesterov subtly differs from
29 | Sutskever et. al. and implementations in some other frameworks.
30 | Considering the specific case of Momentum, the update can be written as
31 | .. math::
32 | v = \rho * v + g \\
33 | p = p - lr * v
34 | where p, g, v and :math:`\rho` denote the parameters, gradient,
35 | velocity, and momentum respectively.
36 | This is in contrast to Sutskever et. al. and
37 | other frameworks which employ an update of the form
38 | .. math::
39 | v = \rho * v + lr * g \\
40 | p = p - v
41 | The Nesterov version is analogously modified.
42 | """
43 |
44 | def __init__(self, params, lr=0.1, momentum=0, dampening=0,
45 | weight_decay=0, nesterov=False):
46 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
47 | weight_decay=weight_decay, nesterov=nesterov)
48 | if nesterov and (momentum <= 0 or dampening != 0):
49 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
50 | super(SGD, self).__init__(params, defaults)
51 |
52 | def __setstate__(self, state):
53 | super(SGD, self).__setstate__(state)
54 | for group in self.param_groups:
55 | group.setdefault('nesterov', False)
56 |
57 | def step(self, grads, closure=None, cuda=False):
58 | """Performs a single optimization step.
59 | Arguments:
60 | closure (callable, optional): A closure that reevaluates the model
61 | and returns the loss.
62 | """
63 | loss = None
64 | if closure is not None:
65 | loss = closure()
66 |
67 | for group in self.param_groups:
68 | weight_decay = group['weight_decay']
69 | momentum = group['momentum']
70 | dampening = group['dampening']
71 | nesterov = group['nesterov']
72 |
73 | for i,p in enumerate(group['params']):
74 | d_p = torch.from_numpy(grads[i]).float().cuda() if cuda else torch.from_numpy(grads[i]).float()
75 | if weight_decay != 0:
76 | d_p.add_(weight_decay, p.data)
77 | if momentum != 0:
78 | param_state = self.state[p]
79 | if 'momentum_buffer' not in param_state:
80 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
81 | buf.mul_(momentum).add_(d_p)
82 | else:
83 | buf = param_state['momentum_buffer']
84 | buf.mul_(momentum).add_(1 - dampening, d_p)
85 | if nesterov:
86 | d_p = d_p.add(momentum, buf)
87 | else:
88 | d_p = buf
89 | p.data.add_(-group['lr'], d_p)
90 | return loss
--------------------------------------------------------------------------------
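
What differs from `torch.optim.SGD` is the `step(grads, ...)` signature: gradients are passed in explicitly as a list of numpy arrays, one per parameter and in parameter order, which is how the master applies MPI-aggregated gradients. A hedged usage sketch (toy model, all-ones gradients; assumes the PyTorch versions this repository targets):

``` python
import numpy as np
import torch
from optim.sgd import SGD

model = torch.nn.Linear(4, 2)
opt = SGD(model.parameters(), lr=0.1, momentum=0.9)

# one numpy gradient per parameter, in the same order as model.parameters()
grads = [np.ones(tuple(p.size()), dtype=np.float64) for p in model.parameters()]
opt.step(grads)
```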
/src/output/models/README.md:
--------------------------------------------------------------------------------
1 | # This directory is used to save the intermediate model during the training process for evaluation
--------------------------------------------------------------------------------
/src/run_pytorch.sh:
--------------------------------------------------------------------------------
1 | mpirun -n 3 --hostfile hosts_address \
2 | python distributed_nn.py \
3 | --lr=0.01 \
4 | --lr-shrinkage=0.95 \
5 | --momentum=0.0 \
6 | --network=ResNet18 \
7 | --dataset=Cifar10 \
8 | --batch-size=128 \
9 | --test-batch-size=200 \
10 | --comm-type=Bcast \
11 | --num-aggregate=2 \
12 | --eval-freq=200 \
13 | --epochs=10 \
14 | --max-steps=1000000 \
15 | --svd-rank=3 \
16 | --quantization-level=4 \
17 | --bucket-size=512 \
18 | --code=sgd \
19 | --enable-gpu= \
20 | --train-dir=/home/ubuntu/
--------------------------------------------------------------------------------
/src/single_machine.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import threading
5 | import argparse
6 | import time
7 |
8 | import torch
9 | from torch.autograd import Variable
10 | from torch._utils import _flatten_tensors, _unflatten_tensors
11 | from torch.cuda.comm import broadcast_coalesced
12 | from torch.cuda import nccl
13 |
14 | import torch.nn as nn
15 | from torch.nn.parallel.replicate import replicate
16 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
17 | from torch.nn.parallel.parallel_apply import parallel_apply
18 | import torch.nn.functional as F
19 |
20 | from torchvision import datasets, transforms
21 |
22 | from nn_ops import NN_Trainer, accuracy
23 | from data_loader_ops.my_data_loader import DataLoader
24 |
25 | from cifar10 import cifar10
26 | from datasets import Cifar10Dataset
27 |
28 |
29 | def add_fit_args(parser):
30 | """
31 | parser : argparse.ArgumentParser
32 | return a parser added with args required by fit
33 | """
34 | # Training settings
35 | parser.add_argument('--batch-size', type=int, default=128, metavar='N',
36 |                         help='input batch size for training (default: 128)')
37 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
38 | help='input batch size for testing (default: 1000)')
39 | parser.add_argument('--epochs', type=int, default=100, metavar='N',
40 |                         help='number of epochs to train (default: 100)')
41 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
42 | help='learning rate (default: 0.01)')
43 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
44 | help='SGD momentum (default: 0.5)')
45 | parser.add_argument('--no-cuda', action='store_true', default=False,
46 | help='disables CUDA training')
47 | parser.add_argument('--seed', type=int, default=1, metavar='S',
48 | help='random seed (default: 1)')
49 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
50 | help='how many batches to wait before logging training status')
51 | parser.add_argument('--network', type=str, default='LeNet', metavar='N',
52 | help='which kind of network we are going to use, support LeNet and ResNet currently')
53 | parser.add_argument('--dataset', type=str, default='MNIST', metavar='N',
54 | help='which dataset used in training, MNIST and Cifar10 supported currently')
55 | args = parser.parse_args()
56 | return args
57 |
58 | # we use LeNet here for our simple case
59 | class LeNet(nn.Module):
60 | def __init__(self):
61 | super(LeNet, self).__init__()
62 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
63 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
64 | self.fc1 = nn.Linear(4*4*50, 500)
65 | self.fc2 = nn.Linear(500, 10)
66 |         self.criterion = nn.CrossEntropyLoss()
67 | def forward(self, x, target):
68 | x = self.conv1(x)
69 | x = F.max_pool2d(x, 2, 2)
70 | x = F.relu(x)
71 | x = self.conv2(x)
72 | x = F.max_pool2d(x, 2, 2)
73 | x = F.relu(x)
74 | x = x.view(-1, 4*4*50)
75 | x = self.fc1(x)
76 | x = self.fc2(x)
77 |         loss = self.criterion(x, target)
78 | return x, loss
79 | def name(self):
80 | return 'lenet'
81 |
82 | class LeNetLearner:
83 | """a deprecated class, please don't call this one in any time"""
84 | def __init__(self, rank, world_size, args):
85 | self._step_changed = False
86 | self._update_step = False
87 | self._new_step_queued = 0
88 | self._rank = rank
89 | self._world_size = world_size
90 | self._cur_step = 0
91 | self._next_step = self._cur_step + 1
92 | self._step_fetch_request = False
93 | self.max_num_epochs = args.epochs
94 | self.lr = args.lr
95 | self.momentum = args.momentum
96 |
97 | def build_model(self):
98 | self.network = LeNet()
99 |
100 | # only for test use
101 | self.module = self.network
102 |
103 | # this is only used for test
104 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)
105 |
106 | def test_model(self):
107 |         '''for testing only; please don't call this function'''
108 | from copy import deepcopy
109 | self._module_copies = [deepcopy(self.module)]
110 | self.device_ids = []
111 |
112 | t = None
113 | for p in self.module.parameters():
114 | tp = type(p.data)
115 | if t is not None and t is not tp:
116 | raise ValueError("DistributedDataParallel requires all parameters' data to be of the same type")
117 | t = tp
118 |
119 | self.bucket_sizes = []
120 | self.bucket_map = {}
121 | MB = 1024 * 1024
122 | self.broadcast_bucket_size = 10 * MB # used for param sync before forward
123 | bucket_bytes_cap = 1 * MB
124 | bucket_bytes = bucket_bytes_cap # to init the first bucket immediately
125 | for param_tuple in zip(*map(lambda m: m.parameters(), self._module_copies)):
126 | if bucket_bytes >= bucket_bytes_cap:
127 | self.bucket_sizes.append(0)
128 | bucket_bytes = 0
129 | self.bucket_sizes[-1] += 1
130 | for p in param_tuple:
131 | self.bucket_map[p] = len(self.bucket_sizes) - 1
132 | bucket_bytes += p.numel() * p.element_size()
133 |
134 | self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
135 | self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
136 | self.reduced = [False] * len(self.bucket_sizes)
137 |
138 | def train(self, train_loader=None):
139 | self.network.train()
140 |
141 | # iterate of epochs
142 | for i in range(self.max_num_epochs):
143 | for batch_idx, (data, y_batch) in enumerate(train_loader):
144 | iter_start_time = time.time()
145 | data, target = Variable(data), Variable(y_batch)
146 | self.optimizer.zero_grad()
147 | logits, loss = self.network(data, target)
148 | tmp_time_0 = time.time()
149 | loss.backward()
150 |
151 | for params in self.network.parameters():
152 | print(params.grad.data.numpy())
153 | print('**********************************************************')
154 |
155 | if batch_idx == 5:
156 | self.update_state_dict()
157 |
158 | duration_backward = time.time()-tmp_time_0
159 |
160 | tmp_time_1 = time.time()
161 | self.optimizer.step()
162 | duration_update = time.time()-tmp_time_1
163 |
164 | print("backward duration: {}".format(duration_backward))
165 | print("update duration: {}".format(duration_update))
166 | # calculate training accuracy
167 | prec1, prec5 = accuracy(logits.data, y_batch, topk=(1, 5))
168 | # load the training info
169 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Prec@1: {} Prec@5: {} Time Cost: {}'.format(
170 | i, batch_idx * len(data), len(train_loader.dataset),
171 | 100. * batch_idx / len(train_loader), loss.data[0],
172 | prec1.numpy()[0],
173 | prec5.numpy()[0], time.time()-iter_start_time))
174 |
175 | def update_state_dict(self):
176 | """for this test version, we set all params to zeros here"""
177 | # we need to build a state dict first
178 | new_state_dict = {}
179 | for key_name, param in self.network.state_dict().items():
180 | tmp_dict = {key_name: torch.FloatTensor(param.size()).zero_()}
181 | new_state_dict.update(tmp_dict)
182 | self.network.load_state_dict(new_state_dict)
183 |
184 |
185 | if __name__ == "__main__":
186 | args = add_fit_args(argparse.ArgumentParser(description='PyTorch MNIST Single Machine Test'))
187 |
188 | kwargs = {'batch_size':args.batch_size, 'learning_rate':args.lr, 'max_epochs':args.epochs, 'momentum':args.momentum, 'network':args.network}
189 |
190 | # load training and test set here:
191 | if args.dataset == "MNIST":
192 | training_set = datasets.MNIST('../data', train=True, download=True,
193 | transform=transforms.Compose([
194 | transforms.ToTensor(),
195 | transforms.Normalize((0.1307,), (0.3081,))]))
196 | train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True)
197 | test_loader = torch.utils.data.DataLoader(
198 | datasets.MNIST('../data', train=False, transform=transforms.Compose([
199 | transforms.ToTensor(),
200 | transforms.Normalize((0.1307,), (0.3081,))
201 | ])), batch_size=args.test_batch_size, shuffle=True)
202 | elif args.dataset == "Cifar10":
203 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
204 | std=[x/255.0 for x in [63.0, 62.1, 66.7]])
205 | transform_train = transforms.Compose([
206 | transforms.ToTensor(),
207 | transforms.Lambda(lambda x: F.pad(
208 | Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
209 | (4,4,4,4),mode='reflect').data.squeeze()),
210 | transforms.ToPILImage(),
211 | transforms.RandomCrop(32),
212 | transforms.RandomHorizontalFlip(),
213 | transforms.ToTensor(),
214 | normalize,
215 | ])
216 | # data prep for test set
217 | transform_test = transforms.Compose([
218 | transforms.ToTensor(),
219 | normalize])
220 | # load training and test set here:
221 | training_set = datasets.CIFAR10(root='./cifar10_data', train=True,
222 | download=True, transform=transform_train)
223 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size,
224 | shuffle=True)
225 | testset = datasets.CIFAR10(root='./cifar10_data', train=False,
226 | download=True, transform=transform_test)
227 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
228 | shuffle=False)
229 | elif args.dataset == "ImageNet":
230 | # Data loading code
231 |         traindir = os.path.join('/home/ubuntu/data/', 'train')
232 | valdir = os.path.join('/home/ubuntu/data/', 'val')
233 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
234 | std=[0.229, 0.224, 0.225])
235 |
236 | train_dataset = datasets.ImageFolder(
237 | traindir,
238 | transforms.Compose([
239 | transforms.RandomResizedCrop(224),
240 | transforms.RandomHorizontalFlip(),
241 | transforms.ToTensor(),
242 | normalize,
243 | ]))
244 | test_dataset = datasets.ImageFolder(
245 | valdir,
246 | transforms.Compose([
247 | transforms.ToTensor(),
248 | normalize,
249 | ]))
250 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
251 | shuffle=True)
252 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size,
253 | shuffle=False)
254 |
255 | nn_learner = NN_Trainer(**kwargs)
256 | nn_learner.build_model()
257 | nn_learner.train_and_validate(train_loader=train_loader, test_loader=test_loader)
--------------------------------------------------------------------------------
/src/sync_replicas_master_nn.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import time
3 | import copy
4 | from sys import getsizeof
5 | import warnings
6 | import pickle
7 |
8 | from mpi4py import MPI
9 | import numpy as np
10 |
11 | from nn_ops import NN_Trainer
12 |
13 | from model_ops.lenet import LeNet, LeNetSplit
14 | from model_ops.resnet import *
15 | from model_ops.resnet_split import *
16 | from model_ops.vgg import *
17 | from model_ops.alexnet import *
18 | from model_ops.fc_nn import FC_NN, FC_NN_Split
19 | from model_ops.densenet import DenseNet
20 |
21 | from optim.adam import Adam
22 | from optim.sgd import SGD
23 | from utils import decompress
24 |
25 | import torch
26 | import codings
27 |
28 | STEP_START_ = 1
29 | # use compression tool to make it run faster
30 | _FAKE_SGD = True
31 |
32 | def update_params_dist_version(param, avg_grad, learning_rate):
33 | '''
34 | update the network layer by layer
35 | '''
36 | assert param.shape == avg_grad.shape
37 | param -= learning_rate * avg_grad
38 | return param
39 |
40 |
41 | def accuracy(output, target, topk=(1,)):
42 | """Computes the precision@k for the specified values of k"""
43 | maxk = max(topk)
44 | batch_size = target.size(0)
45 |
46 | _, pred = output.topk(maxk, 1, True, True)
47 | pred = pred.t()
48 | correct = pred.eq(target.view(1, -1).expand_as(pred))
49 |
50 | res = []
51 | for k in topk:
52 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
53 | res.append(correct_k.mul_(100.0 / batch_size))
54 | return res
55 |
56 |
57 | class GradientAccumulator(object):
58 | '''
59 | a simple class implementing a gradient aggregator, similar to the `ConditionalAccumulator` in TensorFlow
60 | '''
61 | def __init__(self, module, num_worker, mode='None'):
62 | super(GradientAccumulator, self).__init__()
63 | # this counter is updated dynamically during the training process;
64 | # its length equals the number of parameter tensors in the network
65 | # we use lists to hold the per-layer gradients
66 | self.gradient_aggregate_counter = []
67 | self.model_index_range = []
68 | self.gradient_aggregator = []
69 |
70 | for param_idx, param in enumerate(module.parameters()):
71 | tmp_aggregator = []
72 | for worker_idx in range(num_worker):
73 | _shape = param.size()
74 | if len(_shape) == 1:
75 | tmp_aggregator.append(bytearray(getsizeof(np.zeros((_shape[0]*2,)))*4))
76 | else:
77 | tmp_aggregator.append(bytearray(getsizeof(np.zeros(_shape))*4))
78 | # initialize the gradient aggregator
79 | self.gradient_aggregator.append(tmp_aggregator)
80 | self.gradient_aggregate_counter.append(0)
81 | self.model_index_range.append(param_idx)
82 |
83 | def reset_everything(self):
84 | self._reset_grad_counter()
85 | self._reset_grad_aggregator()
86 |
87 | def _reset_grad_counter(self):
88 | self.gradient_aggregate_counter = [0 for _ in self.gradient_aggregate_counter]
89 |
90 | def _reset_grad_aggregator(self):
91 | pass
92 |
93 |
94 | class SyncReplicasMaster_NN(NN_Trainer):
95 | def __init__(self, comm, **kwargs):
96 | '''master node; no rank needed, since the master's rank is always 0'''
97 | self.comm = comm # get MPI communicator object
98 | self.world_size = comm.Get_size() # total number of processes
99 | self.cur_step = STEP_START_
100 | # initial learning rate:
101 | self.lr = kwargs['learning_rate']
102 | # we use this parameter to shrink the step size per epoch
103 | self._lr_shrinkage = kwargs['lr_shrinkage']
104 | self._base_lr = kwargs['learning_rate']
105 | # TODO(hwang): change this to a more sophisticated version later
106 | self.shrinkage_freq = 50
107 | self.shrink_counter = 0
108 |
109 | self.momentum = kwargs['momentum']
110 | self.network_config = kwargs['network']
111 | self.comm_type = kwargs['comm_method']
112 |
113 | self._num_grad_to_collect = self.world_size - 1
114 | # used to aggregate tmp gradients; its length equals the number of parameter tensors
115 | self._grad_aggregate_buffer = []
116 | self._model_shapes = []
117 | self._first_grad_received = False
118 | self._eval_freq = kwargs['eval_freq']
119 | self._train_dir = kwargs['train_dir']
120 | self._max_steps = kwargs['max_steps']
121 | self._compress = kwargs['compress']
122 |
123 | self._enable_gpu = kwargs['enable_gpu']
124 | self._num_aggregate = kwargs['num_aggregate']
125 |
126 | self._svd_rank = kwargs['svd_rank']
127 | self._quantization_level = kwargs['quantization_level']
128 | self._bucket_size = kwargs['bucket_size']
129 | self._r = self._svd_rank
130 |
131 | ############ will be deprecated soon #############################
132 | self._eval_batch_size = 1000
133 |
134 | if kwargs['code'] == 'sgd':
135 | if not _FAKE_SGD:
136 | self._coder = codings.svd.SVD(compress=False)
137 | else:
138 | self._coder = codings.lossless_compress.LosslessCompress()
139 | elif kwargs['code'] == 'svd':
140 | print("train.py, svd_rank =", self._svd_rank)
141 | self._coder = codings.svd.SVD(random_sample=False, rank=self._svd_rank,
142 | compress=True)
143 | else:
144 | raise ValueError('args.code not recognized')
145 |
146 | def build_model(self, num_classes=10):
147 | # build network
148 | if self.network_config == "LeNet":
149 | self.network=LeNet()
150 | elif self.network_config == "ResNet18":
151 | self.network=ResNet18(num_classes=num_classes)
152 | elif self.network_config == "ResNet34":
153 | self.network=ResNet34(num_classes=num_classes)
154 | elif self.network_config == "FC":
155 | self.network=FC_NN()
156 | elif self.network_config == "DenseNet":
157 | self.network=DenseNet(growthRate=40, depth=190, reduction=0.5,
158 | bottleneck=True, nClasses=num_classes)
159 | elif self.network_config == "VGG11":
160 | self.network=vgg11_bn(num_classes)
161 | elif self.network_config == "AlexNet":
162 | self.network=alexnet(num_classes=num_classes)
163 |
164 | # TODO(hwang): make sure this is useful
165 | self.optimizer = SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)
166 | # assign a gradient accumulator to collect gradients from workers
167 | self.grad_accumulator = GradientAccumulator(self.network, self.world_size-1, self._compress)
168 | self.init_model_shapes()
169 | # enable GPU here
170 | if self._enable_gpu:
171 | self.network.cuda()
172 |
173 | def train(self):
174 | # the first step is to broadcast the initial global step to all workers
175 | # we still need to make sure the value each worker fetches from the parameter server is 1
176 | # please note that steps start from one here
177 | self.async_bcast_step()
178 |
179 | # fake test here:
180 | for i in range(1, self._max_steps+1):
181 | # switch back to training mode
182 | self.network.train()
183 | self._first_grad_received = False
184 | enough_gradients_received = False
185 |
186 | print("Master node is entering step: {}".format(i))
187 |
188 | self.async_bcast_step()
189 |
190 | self.async_bcast_layer_weights_bcast()
191 |
192 | # set the gradient fetch step and gather the request
193 | gradient_fetch_requests=self.async_fetch_gradient_start()
194 |
195 | coded_msgs = {}
196 | # wait for enough gradients to be aggregated:
197 | gather_start_time = time.time()
198 | while not enough_gradients_received:
199 | status = MPI.Status()
200 | source, code = MPI.Request.waitany(requests=gradient_fetch_requests, status=status)
201 | layer_index = status.tag-88  # tags start at 88 (see async_fetch_gradient_start)
202 | if layer_index not in coded_msgs.keys():
203 | coded_msgs[layer_index] = [code]
204 | else:
205 | coded_msgs[layer_index].append(code)
206 |
207 | self.grad_accumulator.gradient_aggregate_counter[layer_index] += 1
208 |
209 | #print(self.grad_accumulator.gradient_aggregate_counter)
210 | #print('---------------------------------------------------------------------')
211 |
212 | enough_gradients_received = True
213 | for j in self.grad_accumulator.gradient_aggregate_counter:
214 | enough_gradients_received = enough_gradients_received and (j >= self._num_grad_to_collect)
215 | gather_duration = time.time() - gather_start_time
216 |
217 | decode_start = time.time()
218 | self._decode(coded_msgs)
219 | decode_dur = time.time() - decode_start
220 | # update `state_dict` in pytorch modules
221 | print("Master: Step: {}, Decode Cost: {}, Cur lr {}, Gather: {}".format(self.cur_step, decode_dur, self.lr, gather_duration))
222 | self._model_update()
223 |
224 | # reset essential elements
225 | self.reset_grad_buffer()
226 | self.grad_accumulator.reset_everything()
227 | # save model for validation in a pre-specified frequency
228 | #if self.cur_step%self._eval_freq == 0:
229 | # if "ResNet" not in self.network_config:
230 | # self._save_model(file_path=self._generate_model_path())
231 | self.cur_step += 1
232 | if self.cur_step % self.shrinkage_freq == 0:
233 | self.shrink_counter += 1
234 | self.lr = self._base_lr * self._lr_shrinkage ** self.shrink_counter
235 |
236 | def _model_update(self):
237 | # gradients shipped from workers are averaged and used to update the model
238 | self._grad_aggregate_buffer = map(lambda x: x / float(self._num_grad_to_collect), self._grad_aggregate_buffer)
239 | self.optimizer.step(grads=self._grad_aggregate_buffer, cuda=self._enable_gpu)
240 |
241 | def init_model_shapes(self):
242 | for p_index, p in enumerate(self.network.parameters()):
243 | self._model_shapes.append(p.size())
244 | self._grad_aggregate_buffer.append(np.zeros(p.size()))
245 |
246 | def async_bcast_step(self):
247 | req_list = []
248 | for i in range(self.world_size):
249 | if i != 0:
250 | req_list.append(self.comm.isend(self.cur_step, dest=i, tag=10))
251 | for i in range(len(req_list)):
252 | req_list[i].wait()
253 |
254 | def async_bcast_layer_weights_async(self):
255 | request_layers = []
256 | for layer_idx, layer in enumerate(self.network.parameters()):
257 | request_workers = []
258 | layer_to_send = layer.data.numpy().astype(np.float64)
259 | for i in range(self.world_size):
260 | if i != 0:
261 | req = self.comm.Isend([layer_to_send, MPI.DOUBLE], dest=i, tag=11+layer_idx)
262 | request_workers.append(req)
263 |
264 | request_layers.append(request_workers)
265 | # TODO(hwang): check to see if these `wait` calls are necessary here
266 | for req_l in request_layers:
267 | for req_worker in req_l:
268 | req_worker.wait()
269 |
270 | def async_bcast_layer_weights_bcast(self):
271 | request_layers = []
272 | for layer_idx, layer in enumerate(self.network.parameters()):
273 | request_workers = []
274 | if self._enable_gpu:
275 | # copy data to CPU, then do the communication
276 | layer_to_send = layer.data.cpu().numpy().astype(np.float64)
277 | else:
278 | layer_to_send = layer.data.numpy().astype(np.float64)
279 | self.comm.Bcast([layer_to_send, MPI.DOUBLE], root=0)
280 |
281 | def async_fetch_gradient_start(self):
282 | '''
283 | make gradient fetch requests and return the request list
284 | '''
285 | gradient_fetch_requests = []
286 | for module_idx, module in enumerate(self.network.parameters()):
287 | for k in range(self._num_grad_to_collect):
288 | req = self.comm.irecv(self.grad_accumulator.gradient_aggregator[module_idx][k], source=k+1, tag=88+module_idx)
289 | gradient_fetch_requests.append(req)
290 | return gradient_fetch_requests
291 |
292 | def aggregate_gradient(self, gradient, layer_idx):
293 | '''
294 | keep in mind that the gradient here is a wrapped gradient, i.e. it contains both `W` and `b`
295 | '''
296 | self._grad_aggregate_buffer[layer_idx] += gradient.numpy().astype(np.float64)
297 |
298 | def model_update(self, tmp_module):
299 | """write model fetched from parameter server to local model"""
300 | new_state_dict = {}
301 | model_counter_ = 0
302 | for param_idx,(key_name, param) in enumerate(self.network.state_dict().items()):
303 | # handle the `running_mean` and `running_var` buffers contained in `BatchNorm` layers
304 | if "running_mean" in key_name or "running_var" in key_name:
305 | tmp_dict = {key_name : param}
306 | else:
307 | assert param.size() == tmp_module[model_counter_].shape
308 | tmp_dict = {key_name: torch.from_numpy(tmp_module[model_counter_])}
309 | model_counter_+=1
310 | new_state_dict.update(tmp_dict)
311 | self.network.load_state_dict(new_state_dict)
312 |
313 | def reset_grad_buffer(self):
314 | for i in range(len(self._grad_aggregate_buffer)):
315 | self._grad_aggregate_buffer[i][0] = np.zeros(self._grad_aggregate_buffer[i][0].shape)
316 | if len(self._grad_aggregate_buffer[i])==2:
317 | self._grad_aggregate_buffer[i][1] = np.zeros(self._grad_aggregate_buffer[i][1].shape)
318 |
319 | def _decode(self, coded_msgs):
320 | # k: `layer_index` v: coded gradients
321 | for index, (k, v) in enumerate(coded_msgs.iteritems()):
322 | for code in v:
323 | code = pickle.loads(code)
324 | grad=self._coder.decode(code)
325 | try:
326 | assert (grad.shape == self._model_shapes[k])
327 | except AssertionError:
328 | warnings.warn("shape doesn't match; proceed with caution")
329 | self.aggregate_gradient(gradient=grad, layer_idx=k)
330 |
331 | def _generate_model_path(self):
332 | return self._train_dir+"model_step_"+str(self.cur_step)
333 |
334 | def _save_model(self, file_path):
335 | with open(file_path, "wb") as f_:
336 | torch.save(self.network.state_dict(), f_)
337 |
338 | def _evaluate_model(self, validation_loader):
339 | self.network.eval()
340 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0
341 | # loop until an epoch-based validation pass is done
342 | while validation_loader.dataset.epochs_completed <= self._epoch_counter:
343 | eval_image_batch, eval_label_batch = validation_loader.next_batch(batch_size=self._eval_batch_size)
344 | X_batch, y_batch = Variable(eval_image_batch.float()), Variable(eval_label_batch.long())
345 | output = self.network(X_batch)
346 | prec1_tmp, prec5_tmp = accuracy(output.data, eval_label_batch.long(), topk=(1, 5))
347 | prec1_counter_ += prec1_tmp
348 | prec5_counter_ += prec5_tmp
349 | batch_counter_ += 1
350 | prec1 = prec1_counter_ / batch_counter_
351 | prec5 = prec5_counter_ / batch_counter_
352 | self._epoch_counter = validation_loader.dataset.epochs_completed
353 | if self._enable_gpu:
354 | prec1 = prec1.cpu().numpy()[0]
355 | prec5 = prec5.cpu().numpy()[0]
356 | else:
357 | prec1 = prec1.numpy()[0]
358 | prec5 = prec5.numpy()[0]
359 | print('Testset Performance: Cur Step:{} Prec@1: {} Prec@5: {}'.format(self.cur_step, prec1, prec5))
--------------------------------------------------------------------------------
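In `_decode` above, the master unpickles each worker message and hands it to the coder. A minimal round-trip sketch of that path, assuming `codings.svd.SVD` (from `src/codings/svd.py`) exposes an `encode` counterpart to the `decode` call used by the master:

```python
import pickle
import numpy as np
import codings  # src/codings

# hypothetical worker side: sparsify and serialize one gradient tensor
coder = codings.svd.SVD(random_sample=False, rank=3, compress=True)
grad = np.random.randn(128, 64).astype(np.float64)
msg = pickle.dumps(coder.encode(grad))        # encode() assumed, mirroring decode()

# master side, mirroring _decode(): unpickle, decode, sanity-check the shape
decoded = coder.decode(pickle.loads(msg))
assert decoded.shape == grad.shape
```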
/src/tiny_tuning_parser.py:
--------------------------------------------------------------------------------
1 | import re
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser(description='Distributed Tuning')
5 | parser.add_argument('--tuning-dir', type=str, default='./tune/', metavar='N',
6 | help='directory to save temp tuning logs')
7 | parser.add_argument('--tuning-lr', type=float, default=0.125, metavar='N',
8 | help='candidate learning rate used during tuning')
9 | parser.add_argument('--num-workers', type=int, default=16,
10 | help='number of workers in the cluster')
11 | args = parser.parse_args()
12 |
13 | loss_stat = []
14 | with open(args.tuning_dir, 'rb') as f:
15 |     for line in f.readlines():
16 | line_content = line.rstrip('\n')
17 | search = re.search(
18 | 'Worker: .*, Step: .*, Epoch: .* \[.* \(.*\)\], Loss: (.*), Time Cost: .*, Comp: .*, Encode: .*, Comm: .*, Msg\(MB\): .*',
19 | line_content)
20 | if search:
21 | loss = float(search.group(1))
22 | loss_stat.append(loss)
23 | try:
24 | assert len(loss_stat) == args.num_workers
25 | except AssertionError:
26 | print("Illeagel Number of Workers! ")
27 | print("Avged loss for lr candidate: {}=========>{}".format(args.tuning_lr, sum(loss_stat)/float(len(loss_stat))))
--------------------------------------------------------------------------------
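The parser above keys off the per-step line each worker prints. A quick self-contained check of the regex against a fabricated (purely illustrative) log line:

```python
import re

line = ("Worker: 1, Step: 100, Epoch: 1 [800 (2%)], Loss: 1.8342, "
        "Time Cost: 0.41, Comp: 0.12, Encode: 0.05, Comm: 0.21, Msg(MB): 0.33")
pattern = (r'Worker: .*, Step: .*, Epoch: .* \[.* \(.*\)\], Loss: (.*), '
           r'Time Cost: .*, Comp: .*, Encode: .*, Comm: .*, Msg\(MB\): .*')
match = re.search(pattern, line)
assert match is not None and float(match.group(1)) == 1.8342
```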
/src/tune.sh:
--------------------------------------------------------------------------------
1 | tune_dir=~/grad_lossy_compression/src/tune/
2 | max_tuning_step=100
3 | method=svd
4 | mkdir ${tune_dir}
5 |
6 | echo "Start parameter tuning ..."
7 | for lr in 0.0078125 0.015625 0.03125 0.0625 0.125 0.25 0.5
8 | do
9 | echo "Trial running for learning rate: ${lr}"
10 | mpirun -n 17 --hostfile hosts_address \
11 | python distributed_nn.py \
12 | --lr=${lr} \
13 | --lr-shrinkage=1.0 \
14 | --momentum=0.0 \
15 | --network=ResNet18 \
16 | --dataset=Cifar10 \
17 | --batch-size=8 \
18 | --test-batch-size=200 \
19 | --comm-type=Bcast \
20 | --num-aggregate=16 \
21 | --eval-freq=200 \
22 | --epochs=10 \
23 | --max-steps=${max_tuning_step} \
24 | --svd-rank=3 \
25 | --quantization-level=4 \
26 | --bucket-size=512 \
27 | --code=${method} \
28 | --enable-gpu= \
29 | --train-dir=/home/ubuntu/ > ${tune_dir}${method}_lr_${lr}
30 |
31 | cat ${tune_dir}${method}_lr_${lr} | grep Step:\ ${max_tuning_step} > ${tune_dir}${method}_lr_${lr}_processing
32 | bash ~/killall.sh
33 | done
34 |
35 | for lr in 0.0078125 0.015625 0.03125 0.0625 0.125 0.25 0.5
36 | do
37 | echo "Logging out tunning results"
38 | python tiny_tuning_parser.py \
39 | --tuning-dir=${tune_dir}${method}_lr_${lr}_processing \
40 | --tuning-lr=${lr} \
41 | --num-workers=16
42 | done
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import blosc
2 |
3 | def compress(msg, level=0, name='blosclz'):
4 | """
5 | Compress a message.
6 | """
7 | if name in {'lz4', 'snappy'}:
8 | raise ValueError('Do not specify lz4 or snappy. I ran into hard to '
9 | 'debug issues when I did this. blosclz seems to work')
10 | # we always use default level for now
11 | code = blosc.compress(msg, cname=name)
12 | return bytearray(code)
13 |
14 | def decompress(code):
15 | msg = blosc.decompress(code)
16 | return msg
--------------------------------------------------------------------------------
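These two helpers are the (de)compression endpoints for every coded message shipped over MPI. A minimal round-trip sketch using them, assuming `python-blosc` is installed as listed in the dependencies:

```python
import numpy as np
from utils import compress, decompress  # the helpers above (src/utils.py)

arr = np.arange(1024, dtype=np.float64)
code = compress(arr.tobytes())          # bytearray of blosclz-compressed bytes
restored = np.frombuffer(decompress(bytes(code)), dtype=np.float64)
assert np.array_equal(arr, restored)
```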
/tools/config:
--------------------------------------------------------------------------------
1 | Host *
2 | StrictHostKeyChecking no
3 |
--------------------------------------------------------------------------------
/tools/hosts:
--------------------------------------------------------------------------------
1 | 172.31.28.130 deeplearning-worker1
2 | 172.31.29.102 deeplearning-worker2
3 | 172.31.27.116 deeplearning-worker3
4 | 172.31.28.111 deeplearning-worker4
5 | 172.31.27.53 deeplearning-worker5
6 | 172.31.22.220 deeplearning-worker6
7 | 172.31.19.79 deeplearning-worker7
8 | 172.31.18.28 deeplearning-worker8
9 | 172.31.16.213 deeplearning-worker9
10 |
--------------------------------------------------------------------------------
/tools/hosts_address:
--------------------------------------------------------------------------------
1 | 172.31.28.130
2 | 172.31.29.102
3 | 172.31.27.116
4 | 172.31.28.111
5 | 172.31.27.53
6 | 172.31.22.220
7 | 172.31.19.79
8 | 172.31.18.28
9 | 172.31.16.213
10 |
--------------------------------------------------------------------------------
/tools/hosts_alias:
--------------------------------------------------------------------------------
1 | deeplearning-worker1
2 | deeplearning-worker2
3 | deeplearning-worker3
4 | deeplearning-worker4
5 | deeplearning-worker5
6 | deeplearning-worker6
7 | deeplearning-worker7
8 | deeplearning-worker8
9 | deeplearning-worker9
10 |
--------------------------------------------------------------------------------
/tools/install.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # append hosts
4 | sudo bash -c "cat hosts >> /etc/hosts"
5 | cp config ~/.ssh/
6 |
7 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
8 |
9 | git config --global user.name hwang595
10 | git config --global user.email hongyiwang@cs.wisc.edu
11 |
12 | sudo apt-get update
13 | sudo apt-get install pdsh -y
14 | pdsh -R ssh -w deeplearning-worker[1-$DEEPLEARNING_WORKERS_COUNT] "sudo apt-get update; sudo apt-get install pdsh -y"
15 |
--------------------------------------------------------------------------------
/tools/local_script.sh:
--------------------------------------------------------------------------------
1 | KEY_PEM_DIR=/home/hwang/My_Code/AWS/HongyiWKeyPair.pem
2 | KEY_PEM_NAME=HongyiWKeyPair.pem
3 | PUB_IP_ADDR="$1"
4 | echo "Public address of master node: ${PUB_IP_ADDR}"
5 |
6 | ssh -o "StrictHostKeyChecking no" ubuntu@${PUB_IP_ADDR}
7 | scp -i ${KEY_PEM_DIR} ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR}:~/.ssh
8 | scp -i ${KEY_PEM_DIR} hosts hosts_address config ubuntu@${PUB_IP_ADDR}:~/
9 | scp -i ${KEY_PEM_DIR} -r ~/My_Code/grad_lossy_compression ubuntu@${PUB_IP_ADDR}:~/
10 | ssh -i ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR} 'sudo apt-get update; cp /home/ubuntu/grad_lossy_compression/tools/remote_script.sh ~/'
--------------------------------------------------------------------------------
/tools/pre_run.sh:
--------------------------------------------------------------------------------
1 | sudo apt-get update
2 | conda update -y -n base conda
3 | #conda install pytorch torchvision -y -c pytorch
4 | conda install -y pytorch=0.3.0 -c soumith
5 | conda install -y torchvision=0.1.8
6 | conda install -y -c anaconda python-blosc
7 | conda install -y -c anaconda mpi4py
8 | conda install -y libgcc
9 |
10 | # install and configure hdmedians
11 | cd ~
12 | sudo apt-get install -y gcc
13 | source /home/ubuntu/anaconda2/bin/activate ~/anaconda2
14 | git clone https://github.com/daleroberts/hdmedians.git
15 | cd hdmedians
16 | python setup.py install
--------------------------------------------------------------------------------
/tools/remote_script.sh:
--------------------------------------------------------------------------------
1 | KEY_PEM_NAME=HongyiWKeyPair.pem
2 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
3 |
4 | sudo bash -c "cat hosts >> /etc/hosts"
5 | cp config ~/.ssh/
6 |
7 | cd ~/.ssh
8 | eval `ssh-agent -s`
9 | ssh-add ${KEY_PEM_NAME}
10 | ssh-keygen -t rsa -b 4096 -C "hongyiwang.hdu@gmail.com"
11 |
12 | for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT);
13 | do
14 | scp -i ${KEY_PEM_NAME} id_rsa.pub deeplearning-worker${i}:~/.ssh
15 |     scp -i ${KEY_PEM_NAME} -r /home/ubuntu/grad_lossy_compression deeplearning-worker${i}:~/
16 | ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'cd ~/.ssh; cat id_rsa.pub >> authorized_keys'
17 | echo "Done writing public key to worker: deeplearning-worker${i}"
18 | done
--------------------------------------------------------------------------------
/tools/update_git_dir.sh:
--------------------------------------------------------------------------------
1 | cd ~
2 | KEY_PEM_NAME=HongyiScript.pem
3 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
4 |
5 | sudo bash -c "cat hosts >> /etc/hosts"
6 |
7 | for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT);
8 | do
9 |     ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'git config --global user.name hwang595; git config --global user.email hongyiwang@cs.wisc.edu; cd ~/grad_lossy_compression; git pull'
10 | echo "Done pull git repo on worker: deeplearning-worker${i}"
11 | done
--------------------------------------------------------------------------------