├── README.md
├── images
│   └── SVdecay.jpg
├── src
│   ├── README.md
│   ├── codings
│   │   ├── __init__.py
│   │   ├── coding.py
│   │   ├── qsgd.py
│   │   ├── svd.py
│   │   └── utils.py
│   ├── data
│   │   └── data_prepare.py
│   ├── data_loader_ops
│   │   ├── __init__.py
│   │   └── my_data_loader.py
│   ├── data_prepare.sh
│   ├── datasets.py
│   ├── distributed_evaluator.py
│   ├── distributed_nn.py
│   ├── distributed_worker.py
│   ├── evaluate_pytorch.sh
│   ├── model_ops
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── densenet.py
│   │   ├── fc_nn.py
│   │   ├── lenet.py
│   │   ├── resnet.py
│   │   ├── resnet_split.py
│   │   └── vgg.py
│   ├── nn_ops.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adam.py
│   │   └── sgd.py
│   ├── output
│   │   └── models
│   │       └── README.md
│   ├── run_pytorch.sh
│   ├── single_machine.py
│   ├── sync_replicas_master_nn.py
│   ├── tiny_tuning_parser.py
│   ├── tune.sh
│   └── utils.py
└── tools
    ├── config
    ├── hosts
    ├── hosts_address
    ├── hosts_alias
    ├── install.sh
    ├── local_script.sh
    ├── pre_run.sh
    ├── pytorch_ec2.py
    ├── remote_script.sh
    └── update_git_dir.sh

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Atomo: Communication-efficient Learning via Atomic Sparsification
This repository contains source code for Atomo, a general framework for atomic sparsification of stochastic gradients. Please check [the full paper](http://papers.nips.cc/paper/8191-atomo-communication-efficient-learning-via-atomic-sparsification) for detailed information about this project.

## Overview:
ATOMO is a general framework for atomic sparsification of stochastic gradients. Given a gradient, an atomic decomposition, and a sparsity budget, ATOMO gives a random unbiased sparsification of the atoms that minimizes variance. ATOMO sets up and optimally solves a meta-optimization that minimizes the variance of the sparsified gradient, subject to the constraints that it is sparse on the atomic basis and is an unbiased estimator of the input.

![Singular value decay](images/SVdecay.jpg)
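To make this concrete, here is a minimal illustrative sketch of atomic sparsification (not this repository's implementation; see `src/codings/svd.py` for the real sampling code). Each atom is kept with a probability proportional to its weight and rescaled by the inverse of that probability, which keeps the estimate unbiased; the probabilities ATOMO actually uses come from the variance-minimizing meta-optimization, and the proportional rule below matches it only when no probability gets clipped at 1.

``` python
import numpy as np

def sparsify_atoms(weights, budget):
    """Keep atom i with probability p_i = min(1, budget * w_i / sum(w)) and
    rescale kept atoms by 1 / p_i, so E[output] = weights (unbiased)."""
    weights = np.asarray(weights, dtype=float)
    probs = np.minimum(1.0, budget * weights / weights.sum())
    keep = np.random.rand(len(weights)) < probs
    out = np.zeros_like(weights)
    out[keep] = weights[keep] / probs[keep]
    return out  # on average roughly `budget` atoms survive
```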

## Dependencies:
Tested stable dependencies:
* python 2.7 (Anaconda)
* PyTorch 0.3.0 (*please note that we're moving to PyTorch 0.4.0 and 1.0.x*)
* torchvision 0.1.18
* MPI4Py 0.3.0
* python-blosc 1.5.0

We highly recommend installing an [Anaconda](https://www.continuum.io/downloads) environment. You will get a high-quality BLAS library (MKL) and a controlled compiler version regardless of your Linux distro.

We provide [this script](https://github.com/hwang595/ATOMO/blob/master/tools/pre_run.sh) to help you build all dependencies. To do that, you can run:
```
bash ./tools/pre_run.sh
```

## Cluster Setup:
To run on a distributed cluster, the first thing you need to do is launch AWS EC2 instances.
### Launching Instances:
[This script](https://github.com/hwang595/ps_pytorch/blob/master/tools/pytorch_ec2.py) helps you launch EC2 instances automatically, but before running it you should follow [the instructions](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html) to set up the AWS CLI on your local machine. After that, please edit this part in `./tools/pytorch_ec2.py`:
``` python
cfg = Cfg({
    "name" : "PS_PYTORCH",      # Unique name for this specific configuration
    "key_name": "NameOfKeyFile",          # Necessary to ssh into created instances
    # Cluster topology
    "n_masters" : 1,                      # Should always be 1
    "n_workers" : 8,
    "num_replicas_to_aggregate" : "8", # deprecated, not necessary
    "method" : "spot",
    # Region specification
    "region" : "us-west-2",
    "availability_zone" : "us-west-2b",
    # Machine type - instance type configuration.
    "master_type" : "m4.2xlarge",
    "worker_type" : "m4.2xlarge",
    # please only use this AMI for pytorch
    "image_id": "ami-xxxxxxxx",            # id of AMI
    # Launch specifications
    "spot_price" : "0.15",                 # Has to be a string
    # SSH configuration
    "ssh_username" : "ubuntu",            # For sshing. E.G: ssh ssh_username@hostname
    "path_to_keyfile" : "/dir/to/NameOfKeyFile.pem",

    # NFS configuration
    # To set up these values, go to Services > ElasticFileSystem > Create new filesystem, and follow the directions.
    #"nfs_ip_address" : "172.31.3.173",          # us-west-2c
    #"nfs_ip_address" : "172.31.35.0",           # us-west-2a
    "nfs_ip_address" : "172.31.14.225",          # us-west-2b
    "nfs_mount_point" : "/home/ubuntu/shared",       # NFS base dir
```
For setting everything up on the EC2 cluster, the easiest way is to set up one machine, create an AMI from it, and use that AMI id for `image_id` in `pytorch_ec2.py`. Then, launch EC2 instances by running
```
python ./tools/pytorch_ec2.py launch
```
After all launched instances are ready (this may take a while), get the private IPs of the instances by running
```
python ./tools/pytorch_ec2.py get_hosts
```
This will write the IPs into a file named `hosts_address`, which looks like
```
172.31.16.226 (${PS_IP})
172.31.27.245
172.31.29.131
172.31.18.108
172.31.18.174
172.31.17.228
172.31.16.25
172.31.30.61
172.31.29.30
```
After generating the `hosts_address` file for all EC2 instances, running the following command will copy your keyfile to the parameter server (PS) instance, whose address is always the first one in `hosts_address`. `local_script.sh` will also do some basic configuration, e.g. cloning this git repo:
```
bash ./tools/local_script.sh ${PS_IP}
```
### SSH related:
At this stage, you should ssh into the PS instance; all subsequent operations happen on the PS. In the PS setting, the PS should be able to ssh into any compute node. [This part](https://github.com/hwang595/ATOMO/blob/master/tools/remote_script.sh#L8-L16) does the job for you by running (after sshing to the PS)
```
bash ./tools/remote_script.sh
```

## Prepare Datasets
We currently support the [MNIST](http://yann.lecun.com/exdb/mnist/) and [Cifar10](https://www.cs.toronto.edu/~kriz/cifar.html) datasets. Download, split, and transform the datasets by running (`./tools/remote_script.sh` does this for you)
```
bash ./src/data_prepare.sh
```

## Job Launching
Since this project is built on MPI, tasks are required to be launched by the PS (or master) instance. `run_pytorch.sh` wraps the job-launching process up; a sample invocation is shown after the table. Commonly used options (arguments) are listed as follows:

| Argument                      | Comments                                 |
| ----------------------------- | ---------------------------------------- |
| `n`                           | Number of processes (size of cluster), e.g. with P compute nodes and 1 PS, n = P + 1. |
| `hostfile`                    | Path to the file that contains the private IPs of every node in the cluster; we use `hosts_address` here as [mentioned before](#launching-instances). |
| `lr`                          | Initial learning rate that will be used. |
| `momentum`                    | Momentum value that will be used. |
| `network`                     | Type of deep neural net; currently `LeNet`, `ResNet-18/32/50/110/152`, and `VGGs` are supported. |
| `dataset`                     | Dataset used for training. |
| `batch-size`                  | Batch size for the optimization algorithms. |
| `test-batch-size`             | Batch size used during model evaluation. |
| `comm-type`                   | A fake parameter; please always set it to `Bcast`. |
| `num-aggregate`               | Number of gradients required for the PS to aggregate. |
| `max-steps`                   | The maximum number of iterations to train. |
| `svd-rank`                    | The expected rank of the gradients ATOMO uses (the same as the sparsity budget `s` in our paper). |
| `epochs`                      | The maximal number of epochs to train (somewhat redundant). |
| `eval-freq`                   | Frequency (in iterations) at which to evaluate the model. |
| `enable-gpu`                  | Train on CPU/GPU; for CPU please leave this argument empty. |
| `train-dir`                   | Directory to save model checkpoints for evaluation. |
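For instance, a launch through `run_pytorch.sh` boils down to an `mpirun` invocation roughly like the sketch below (all values are hypothetical; the exact flags wired up inside `run_pytorch.sh` may differ):
```
mpirun -n 9 --hostfile hosts_address \
    python distributed_nn.py --dataset=Cifar10 --network=ResNet18 \
    --lr=0.01 --momentum=0.9 --batch-size=128 --comm-type=Bcast \
    --num-aggregate=8 --max-steps=10000 --eval-freq=50 \
    --code=svd --svd-rank=3 --train-dir=output/models/
```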

## Model Evaluation
The [distributed evaluator](https://github.com/hwang595/ATOMO/blob/master/src/distributed_evaluator.py) fetches model checkpoints from the shared directory and evaluates the model on the validation set. To evaluate a model, you can run
```
bash ./src/evaluate_pytorch.sh
```
with the arguments specified.

Evaluation arguments are listed as follows:

| Argument                      | Comments                                 |
| ----------------------------- | ---------------------------------------- |
| `eval-batch-size`             | Batch size (on the validation set) used during model evaluation. |
| `eval-freq`                   | Frequency (in iterations) at which to evaluate the model; should be set to the same value as in [run_pytorch.sh](https://github.com/hwang595/ps_pytorch/blob/master/src/run_pytorch.sh). |
| `network`                     | Type of deep neural net; should be set to the same value as in [run_pytorch.sh](https://github.com/hwang595/ps_pytorch/blob/master/src/run_pytorch.sh). |
| `dataset`                     | Dataset used for training; should be set to the same value as in [run_pytorch.sh](https://github.com/hwang595/ps_pytorch/blob/master/src/run_pytorch.sh). |
| `model-dir`                   | Directory where model checkpoints are saved for evaluation; should be set to the same value as in [run_pytorch.sh](https://github.com/hwang595/ps_pytorch/blob/master/src/run_pytorch.sh). |

## Future Work
These are potential directions we are actively working on, stay tuned!
* Explore the use of Atomo with Fourier decompositions, due to their utility and prevalence in signal processing.
* Examine how we can sparsify and compress gradients jointly to further reduce communication costs.
* Explore joint sparsification of the SVD and its singular vectors.
* Integrate ATOMO into state-of-the-art PS (or distributed) frameworks, e.g. [Ray](https://rise.cs.berkeley.edu/projects/ray/).

## Citation

```
@inproceedings{wang2018atomo,
  title={ATOMO: Communication-efficient Learning via Atomic Sparsification},
  author={Wang, Hongyi and Sievert, Scott and Liu, Shengchao and Charles, Zachary and Papailiopoulos, Dimitris and Wright, Stephen},
  booktitle={Advances in Neural Information Processing Systems},
  pages={9871--9882},
  year={2018}
}
```
--------------------------------------------------------------------------------
/images/SVdecay.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/ATOMO/8a21e54938bc4b0809293b306eade3ab9307ea70/images/SVdecay.jpg
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
Do remember to convert `numpy.ndarray` data to `numpy.float64` before pushing it to the MPI send/receive buffer; otherwise there will be data-type conversion issues between numpy and MPI.

# To use it on a single machine:
```
python single_machine.py --dataset=MNIST/Cifar10 --network=LeNet/Resnet --batch-size=${BATCH_SIZE}
```

# To use it on a distributed cluster:
```
mpirun -n ${NUM_WORKERS} --hostfile=${HOST_DIR} python distributed_nn.py --dataset=MNIST/Cifar10 --network=LeNet/Resnet --batch-size=${BATCH_SIZE}
```

# Run the whole thing automatically
The first thing you need to do is launch AWS EC2 instances; you can do that using `tools/pytorch_ec2.py` by running the following command:
```
python pytorch_ec2.py launch
```
After the launch command has executed and all instances are initialized (this may take several minutes), fetch the host addresses:
```
python pytorch_ec2.py get_hosts
```
Then, copy the essential configuration files and hosts files using the public address of the master node (the first address in the `hosts` file):
```
sh local_script.sh ${MASTER_PUB_ADDR}
```
After that, log in to the master node manually and run the remote script under the `$HOME` dir:
```
sh remote_script.sh
```
This script will do the cluster setup and data preparation work for you.
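A minimal sketch of the dtype note above (illustrative; the variable names here are not from the repo): mpi4py's fast buffer-based `Send`/`Recv` needs a concrete dtype agreed on by both ends, so cast to `float64` first.
```python
import numpy as np
from mpi4py import MPI

# run with: mpirun -n 2 python this_sketch.py
comm = MPI.COMM_WORLD
grad = np.random.randn(4).astype(np.float32)
buf = grad.astype(np.float64)  # cast before using the buffer API
if comm.Get_rank() == 0:
    comm.Send([buf, MPI.DOUBLE], dest=1, tag=0)
elif comm.Get_rank() == 1:
    recv = np.empty(4, dtype=np.float64)
    comm.Recv([recv, MPI.DOUBLE], source=0, tag=0)
```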
--------------------------------------------------------------------------------
/src/codings/__init__.py:
--------------------------------------------------------------------------------
from .coding import Coding
from .svd import SVD
from .qsgd import QSGD
# explicit relative import so this also works under Python 3,
# where a bare `import utils` would fail
from . import utils

__all__ = ['coding', 'svd', 'qsgd', 'utils']
--------------------------------------------------------------------------------
/src/codings/coding.py:
--------------------------------------------------------------------------------
import torch

class Coding:
    def __init__(self, *args, **kwargs):
        self.codes = []

    def encode(self, grad, *args, **kwargs):
        raise NotImplementedError()

    def decode(self, code, *args, **kwargs):
        raise NotImplementedError()
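
# A minimal subclass sketch (illustrative addition, not part of the original
# module): any concrete coding just has to make decode(encode(grad)) recover
# grad -- exactly here, approximately for lossy codings such as QSGD or SVD.
class IdentityCoding(Coding):
    def encode(self, grad, *args, **kwargs):
        return {'grad': grad}

    def decode(self, code, *args, **kwargs):
        return code['grad']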
--------------------------------------------------------------------------------
/src/codings/qsgd.py:
--------------------------------------------------------------------------------
from functools import reduce
import numpy as np
import numpy.linalg as LA
from scipy import stats
import torch
import time
import math
from .coding import Coding

import torch.nn.functional as F


class QSGD(Coding):
    def __init__(self, scheme='qsgd', bucket_size=512, *args, **kwargs):
        self.scheme = scheme
        self._quantization_level = kwargs['quantization_level']
        self._bucket_size = bucket_size

    def encode(self, v, **kwargs):
        if isinstance(v, (torch.Tensor, torch.cuda.FloatTensor)):
            w = v.cpu().numpy().flat[:]
        elif isinstance(v, np.ndarray):
            w = v.flat[:]
        else:
            raise ValueError("Object passed to encode not ndarray or torch.Tensor")

        if 'neo_bucket_size' in kwargs.keys():
            bucket_size = min(self._bucket_size, kwargs['neo_bucket_size'])
        else:
            bucket_size = self._bucket_size
        # Apply bucketing
        if bucket_size != 0:
            code_buckets = []
            shape = v.shape
            neo_kwargs = {'neo_bucket_size': 0}
            # array_split (unlike np.split) tolerates lengths that are not an
            # exact multiple of bucket_size; the integer division keeps this
            # working under Python 3 as well
            buckets = np.array_split(w, (w.shape[0] + bucket_size - 1) // bucket_size)
            for bucket in buckets:
                code = self.encode(bucket, **neo_kwargs)
                code_buckets.append(code)
            return {'code_buckets': code_buckets, 'shape': shape}

        if self.scheme == 'qsgd':
            # take the norm of the flattened numpy copy; v itself may live on GPU
            norm = LA.norm(w)
        elif self.scheme == 'terngrad':
            norm = np.linalg.norm(w, ord=np.inf)
            limit = grad_clip_limit(w, clip_factor=2.5)
            w = np.clip(w, -limit, limit)

        s = (1 << self._quantization_level) - 1
        shape = v.shape

        num_int_each_64_bits = int(64 / (2 + self._quantization_level))  # number of elements stored per 64 bits
        num_section = num_int_each_64_bits
        len_each_section = int((w.shape[0] + num_section - 1) / num_section)  # number of 64-bit words needed to ship w
        w = np.pad(w, (0, len_each_section * num_section - w.shape[0]), mode='constant')  # pad w to the total number of elements

        sign_array = np.sign(w)
        sign_array += 1  # -1, 0, 1 to 0, 1, 2
        sign_array = sign_array.astype('uint64')
        normalization_array = np.abs(w) / norm * s

        truncated_array = normalization_array.astype(int)  # l <= \frac{s \|w\|_i}{\|w\|_2} <= l+1
        prob_array = normalization_array - truncated_array  # \frac{s \|w\|_i}{\|w\|_2} - l, i.e. p(a, s) = as - l
        dice_array = np.random.rand(len(prob_array))
        # round up with probability equal to the fractional part so that
        # E[xi] matches the normalized magnitude (unbiased quantization)
        xi_array = truncated_array + (dice_array < prob_array)  # l+1 or l
        xi_array = xi_array.astype('uint64')

        xi_array = xi_array.reshape((num_section, len_each_section))
        sign_array = sign_array.reshape((num_section, len_each_section))

        neo_array = np.zeros(len_each_section)
        neo_array = neo_array.astype('uint64')

        for i in range(num_int_each_64_bits):
            xi = xi_array[i]
            sign = sign_array[i]
            neo_array <<= (2 + self._quantization_level)
            neo_array = neo_array | (sign << self._quantization_level | xi)

        code = {'neo': neo_array, 'norm': norm, 'quantization_level': self._quantization_level,
                'len_each_section': len_each_section, 'num_int_each_64_bits': num_int_each_64_bits,
                'shape': shape}

        if kwargs.pop('timings', False):
            data = {}
            return code, data
        return code
    def decode(self, code, cuda=False, implementation='numpy', codes=[], **kwargs):
        """Decode a code produced by `encode` back into a gradient tensor."""
        if self.scheme == 'terngrad' and len(codes) > 0:
            code['norm'] = self._get_max_norm(codes)

        if implementation == 'numpy':
            if 'neo_bucket_size' in kwargs.keys():
                bucket_size = min(self._bucket_size, kwargs['neo_bucket_size'])
            else:
                bucket_size = self._bucket_size
            # Decode from bucketing
            if bucket_size != 0:
                v_list = []
                neo_kwargs = {'neo_bucket_size': 0}
                for code_bucket in code['code_buckets']:
                    v = self.decode(code=code_bucket, cuda=cuda, implementation=implementation, codes=codes, **neo_kwargs)
                    v_list.extend(v)
                v = np.array(v_list)
                v = v.reshape(code['shape'])
            else:
                norm = code['norm']
                s = (1 << self._quantization_level) - 1

                real_size = np.prod(code['shape'])

                neo_array = code['neo'].astype('uint64')
                num_int_each_64_bits = code['num_int_each_64_bits']
                num_section = num_int_each_64_bits
                len_each_section = code['len_each_section']
                xi_array = np.ones((num_section, len_each_section))
                sign_array = np.ones((num_section, len_each_section))
                mask_for_xi = (1 << self._quantization_level) - 1
                mask_for_sign = 3 << self._quantization_level
                for i in range(num_int_each_64_bits)[::-1]:
                    sign_array[i] = (neo_array & mask_for_sign) >> self._quantization_level
                    xi_array[i] = neo_array & mask_for_xi
                    neo_array >>= (2 + self._quantization_level)

                xi_array = xi_array.reshape(-1).astype('uint64')
                sign_array = sign_array.reshape(-1).astype('int8')
                sign_array -= 1
                v = sign_array * xi_array * norm / s

                v = v[:real_size]
                v = v.reshape(code['shape'])
        else:
            raise ValueError('Whoops, implementation')
        v = torch.Tensor(v)
        if cuda:
            v = v.cuda()
        return v

    def _get_max_norm(self, codes):
        scalars = [code['norm'] for code in codes]
        return max(scalars)

    def encode_cuda(self, v, **kwargs):
        if isinstance(v, torch.cuda.FloatTensor):
            w = v.view(-1)
        else:
            raise ValueError("Object passed wasn't set on CUDA, please check CUDA availability!")

        norm = torch.norm(w)

        s = (1 << self._quantization_level) - 1
        shape = v.size()

        num_int_each_64_bits = int(64 / (2 + self._quantization_level))  # number of elements stored per 64 bits
        num_section = num_int_each_64_bits
        len_each_section = int((w.shape[0] + num_section - 1) / num_section)  # number of 64-bit words needed to ship w

        w = F.pad(w, (0, len_each_section * num_section - w.size()[0]), 'constant', 0)

        sign_array = torch.sign(w)
        sign_array += 1  # -1, 0, 1 to 0, 1, 2
        sign_array = sign_array.to(dtype=torch.int64)

        normalization_array = torch.abs(w) / norm * s

        truncated_array = normalization_array.to(dtype=torch.int)  # l <= \frac{s \|w\|_i}{\|w\|_2} <= l+1

        prob_array = normalization_array - truncated_array.float()  # \frac{s \|w\|_i}{\|w\|_2} - l, i.e. p(a, s) = as - l

        dice_array = torch.rand(len(prob_array)).to(torch.device("cuda"))

        # round up with probability equal to the fractional part so that
        # E[xi] matches the normalized magnitude (unbiased quantization)
        xi_array = truncated_array + (dice_array < prob_array).to(dtype=torch.int)  # l+1 or l
        xi_array = xi_array.to(dtype=torch.int64)
        xi_array = xi_array.view((num_section, len_each_section))

        sign_array = sign_array.view((num_section, len_each_section))

        neo_array = torch.zeros(len_each_section).to(dtype=torch.int64).to(torch.device("cuda"))

        for i in range(num_int_each_64_bits):
            xi = xi_array[i]
            sign = sign_array[i]
            neo_array *= 2**(2 + self._quantization_level)
            sign *= 2**self._quantization_level
            sign += xi
            neo_array += sign.to(dtype=torch.int64)

        code = {'neo': neo_array, 'norm': norm, 'quantization_level': self._quantization_level,
                'len_each_section': len_each_section, 'num_int_each_64_bits': num_int_each_64_bits,
                'shape': shape}
        return code


def grad_clip_limit(grad, clip_factor=2.5):
    """Get the clipping threshold."""
    if clip_factor > 1.0e-5:
        return clip_factor * np.std(grad.flat[:])
    return np.max(np.abs(grad.flat[:]))


if __name__ == "__main__":
    a_cpu = torch.randn(20)
    a_cuda = a_cpu.to(torch.device("cuda"))

    kwargs = {'quantization_level': 8}
    coder = QSGD(bucket_size=0, **kwargs)
    print(a_cuda.dtype)
    code_cpu = coder.encode(a_cpu)
    code_cuda = coder.encode_cuda(a_cuda)
    print("CPU compression: {}, Type: {}".format(code_cpu['neo'], code_cpu['neo'].dtype))
    print("")
    print("CUDA compression: {}, Type: {}".format(code_cuda['neo'], code_cuda['neo'].dtype))
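
    # Round-trip sanity check (illustrative addition, CPU path only):
    # decoding should recover the input up to quantization noise.
    decoded = coder.decode(code_cpu)
    print("Max round-trip error: {}".format((decoded - a_cpu).abs().max()))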
--------------------------------------------------------------------------------
/src/codings/svd.py:
--------------------------------------------------------------------------------
import numpy as np
import numpy.linalg as LA
import warnings
import sys
import time
import torch

from .coding import Coding
from .utils import nuclear_indicator, l1_indicator

def _resize_to_2d(x):
    """
    x.shape > 2
    If x.shape = (a, b, *c), assumed that each one of (a, b) pairs has relevant information in c.
    """
    shape = x.shape
    if x.ndim == 1:
        n = x.shape[0]
        return x.reshape((n//2, 2))
    if all([s == 1 for s in shape[2:]]):
        return x.reshape((shape[0], shape[1]))
    # each of (a, b) has related features
    x = x.reshape((shape[0], shape[1], -1))
    # stack those related features into a tall matrix
    x_tmp = x.reshape((shape[0]*shape[1], -1))
    tmp_shape = x_tmp.shape
    return x_tmp.reshape((int(tmp_shape[0]/2), int(tmp_shape[1]*2)))

def _resize_to_2d_cuda(x):
    """
    x.shape > 2
    If x.shape = (a, b, *c), assumed that each one of (a, b) pairs has relevant information in c.
    """
    shape = x.size()
    if x.dim() == 1:
        n = x.size()[0]
        return x.reshape((n//2, 2))
    if all([s == 1 for s in shape[2:]]):
        return x.reshape((shape[0], shape[1]))
    # each of (a, b) has related features
    x = x.reshape((shape[0], shape[1], -1))
    # stack those related features into a tall matrix
    x_tmp = x.reshape((shape[0]*shape[1], -1))
    tmp_shape = x_tmp.shape
    return x_tmp.reshape((int(tmp_shape[0]/2), int(tmp_shape[1]*2)))


def _sample_svd(s, rank=0):
    if s[0] < 1e-6:
        return [0], np.array([1.0])
    probs = s / s[0] if rank == 0 else rank * s / s.sum()
    for i, p in enumerate(probs):
        if p > 1:
            probs[i] = 1
    sampled_idx = []
    sample_probs = []
    for i, p in enumerate(probs):
        # random sampling from a Bernoulli distribution
        if np.random.binomial(1, p):
            sampled_idx += [i]
            sample_probs += [p]
    rank_hat = len(sampled_idx)
    if rank_hat == 0:  # or (rank != 0 and np.abs(rank_hat - rank) >= 3):
        return _sample_svd(s, rank=rank)
    return np.array(sampled_idx, dtype=int), np.array(sample_probs)

class SVD(Coding):
    def __init__(self, compress=True, rank=0, random_sample=True,
                 fetch_indicator=None, *args, **kwargs):
        self.svd_rank = rank
        self.random_sample = random_sample
        self.compress = compress
        # fetch indicator or not
        self.__fetch_indicator = fetch_indicator

    def encode(self, grad, **kwargs):
        # move to CPU; SVD is 5x faster on CPU (at least in torch)
        if not self.compress:
            shape = list(grad.shape)
            return {'grad': grad, 'encode': False}

        orig_size = list(grad.shape)
        ndims = grad.ndim
        reshaped_flag = False
        if ndims != 2:
            grad = _resize_to_2d(grad)
            shape = list(grad.shape)
            ndims = len(shape)
            reshaped_flag = True

        if ndims == 2:
            u, s, vT = LA.svd(grad, full_matrices=False)

            if self.__fetch_indicator:
                # use local names distinct from the imported helpers;
                # assigning to `nuclear_indicator` here would shadow the
                # function and raise UnboundLocalError
                nuc_ind = nuclear_indicator(grad, s)
                l1_ind = l1_indicator(grad)
                print("Step: {}, Nuclear Indicator: {}, L1 Indicator: {}".format(
                    kwargs['step'], nuc_ind, l1_ind))

            if self.random_sample:
                i, probs = _sample_svd(s, rank=self.svd_rank)
                u = u[:, i]
                s = s[i] / probs
                vT = vT[i, :]
            elif self.svd_rank > 0:
                u = u[:, :self.svd_rank]
                s = s[:self.svd_rank]
                vT = vT[:self.svd_rank, :]

            return {'u': u, 's': s, 'vT': vT, 'orig_size': orig_size,
                    'reshaped': reshaped_flag, 'encode': True,
                    'rank': self.svd_rank}
        return {'grad': grad, 'encode': False}

    def encode_cuda(self, grad, device, **kwargs):
        if not isinstance(grad, torch.cuda.FloatTensor):
            raise ValueError("Object passed wasn't set on CUDA, please check CUDA availability!")
        # taking SVD on GPU
        if not self.compress:
            shape = list(grad.shape)
            return {'grad': grad, 'encode': False}

        orig_size = list(grad.size())
        ndims = grad.dim()

        reshaped_flag = False

        if ndims != 2:
            grad = _resize_to_2d_cuda(grad)
            shape = list(grad.size())
            ndims = len(shape)
            reshaped_flag = True

        if ndims == 2:
            u, s, v = torch.svd(grad, some=True)
            # torch.svd returns V (not V^T), so transpose before slicing rows
            vT = v.t()

            if self.random_sample:
                i, probs = _sample_svd(s, rank=self.svd_rank)
                u = u[:, i]
                s = s[i] / torch.tensor(probs).float().to(device)
                vT = vT[i, :]
            elif self.svd_rank > 0:
                u = u[:, :self.svd_rank]
                s = s[:self.svd_rank]
                vT = vT[:self.svd_rank, :]

            return {'u': u, 's': s, 'vT': vT, 'orig_size': orig_size,
                    'reshaped': reshaped_flag, 'encode': True,
                    'rank': self.svd_rank}
        return {'grad': grad, 'encode': False}

    def decode(self, encode_output, cuda=False, **kwargs):
        if isinstance(encode_output, tuple) and len(encode_output) == 1:
            encode_output = encode_output[0]
        encode = encode_output.get('encode', False)
        if not encode:
            grad = encode_output['grad']
            grad = torch.Tensor(grad)
            if cuda:
                grad = grad.cuda(async=True)
            return grad

        u, s, vT = (encode_output[key] for key in ['u', 's', 'vT'])
        if isinstance(u, np.ndarray):
            grad = torch.Tensor(np.dot(np.dot(u, np.diag(s)), vT))
        else:
            # factors produced by encode_cuda stay on the tensor path
            grad = torch.mm(torch.mm(u, torch.diag(s)), vT)
        grad = grad.view(encode_output['orig_size'])
        if cuda:
            grad = grad.cuda(async=True)
        return grad
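

if __name__ == "__main__":
    # Quick unbiasedness check for the sampling scheme (illustrative
    # addition): atom i is kept with probability p_i and rescaled by 1 / p_i,
    # so the average reconstructed spectrum should approach the true one.
    s = np.array([3.0, 1.5, 0.5, 0.1])
    acc = np.zeros_like(s)
    trials = 20000
    for _ in range(trials):
        idx, probs = _sample_svd(s, rank=2)
        acc[idx] += s[idx] / probs
    print(acc / trials)  # should be close to s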
--------------------------------------------------------------------------------
/src/codings/utils.py:
--------------------------------------------------------------------------------
import numpy as np

def nuclear_indicator(grad, s):
    m, n = grad.shape
    return np.sum(s)*np.sqrt(m+n)

def l1_indicator(grad):
    return np.linalg.norm(grad.reshape(-1), 1)
--------------------------------------------------------------------------------
/src/data/data_prepare.py:
--------------------------------------------------------------------------------
"""
Since we need to initialize the dataset in parallel, pre-downloading the data is necessary.
This script will do the job for you.
"""
import torch
from torchvision import datasets, transforms


if __name__ == "__main__":
    training_set_mnist = datasets.MNIST('./mnist_data', train=True, download=True,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))]))
    train_loader_mnist = torch.utils.data.DataLoader(training_set_mnist, batch_size=128, shuffle=True)
    test_loader_mnist = torch.utils.data.DataLoader(
        datasets.MNIST('./mnist_data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])), batch_size=100, shuffle=True)
    trainset_cifar10 = datasets.CIFAR10(root='./cifar10_data', train=True,
                                        download=True, transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                        ]))
    train_loader_cifar10 = torch.utils.data.DataLoader(trainset_cifar10, batch_size=128,
                                                       shuffle=True)
    test_loader_cifar10 = torch.utils.data.DataLoader(
        datasets.CIFAR10('./cifar10_data', train=False, transform=transforms.Compose([
            transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])), batch_size=100, shuffle=True)
    # load training and test set here:
    training_set = datasets.CIFAR100(root='./cifar100_data', train=True,
                                     download=True, transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ]))
    train_loader = torch.utils.data.DataLoader(training_set, batch_size=128,
                                               shuffle=True)
    testset = datasets.CIFAR100(root='./cifar100_data', train=False,
                                download=True, transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                ]))
    test_loader = torch.utils.data.DataLoader(testset, batch_size=1000,
                                              shuffle=False)
--------------------------------------------------------------------------------
/src/data_loader_ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/ATOMO/8a21e54938bc4b0809293b306eade3ab9307ea70/src/data_loader_ops/__init__.py
--------------------------------------------------------------------------------
/src/data_loader_ops/my_data_loader.py:
--------------------------------------------------------------------------------
import torch
import torch.multiprocessing as multiprocessing
from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler
import collections
import sys
import traceback
import threading

PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3


if PY2:
    string_classes = basestring
else:
    string_classes = (str, bytes)


if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


_use_shared_memory = False
"""Whether to use shared memory in default_collate"""


class ExceptionWrapper(object):
    "Wraps an exception plus traceback to communicate across threads"

    def __init__(self, exc_info):
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


def _pin_memory_loop(in_queue, out_queue, done_event):
    while True:
        try:
            r = in_queue.get()
        except:
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))


numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}

def default_collate(batch):
    "Puts each data field into a tensor with outer dimension batch size"
    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif type(batch[0]).__module__ == 'numpy':
        elem = batch[0]
        if type(elem).__name__ == 'ndarray':
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
                     .format(type(batch[0]))))


def pin_memory_batch(batch):
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, collections.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, collections.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch


class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.index_queue = multiprocessing.SimpleQueue()
            self.data_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
                for _ in range(self.num_workers)]

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            if self.pin_memory:
                in_data = self.data_queue
                self.data_queue = queue.Queue()
                self.pin_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(in_data, self.data_queue, self.done_event))
                self.pin_thread.daemon = True
                self.pin_thread.start()

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            # TODO(hwang): try to figure out what's happening right here and fix this issue
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        if not self.shutdown:
            self.shutdown = True
            self.done_event.set()
            for _ in self.workers:
                self.index_queue.put(None)

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()


class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If False and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.data_iterator = DataLoaderIter(self)

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

    def next_batch(self):
        return next(self.data_iterator)
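

if __name__ == "__main__":
    # Minimal usage sketch (illustrative addition; assumes torchvision and
    # MNIST data under ./mnist_data). Unlike the stock PyTorch loader, this
    # one keeps a persistent iterator, so callers can pull batches on demand
    # via next_batch().
    from torchvision import datasets, transforms
    dataset = datasets.MNIST('./mnist_data', train=True, download=True,
                             transform=transforms.ToTensor())
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    images, labels = loader.next_batch()
    print(images.size(), labels.size())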
--------------------------------------------------------------------------------
/src/data_prepare.sh:
--------------------------------------------------------------------------------
#!/bin/sh
python ./data/data_prepare.py
--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch

from torch.utils.data import Dataset
import torch.utils.data as data
from PIL import Image
import os
import os.path
import hashlib
import errno
import numpy as np

class MNISTDataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.images = dataset.images
        self.labels = dataset.labels
        self.transform = transform

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        data_sample = self.images[idx]
        label_sample = self.labels[idx]
        if self.transform:
            data_sample = self.transform(data_sample)
            label_sample = self.transform(label_sample)
        return data_sample, label_sample

    def next_batch(self, batch_size):
        image_batch, label_batch = self.dataset.next_batch(batch_size=batch_size)
        return torch.from_numpy(image_batch), torch.from_numpy(label_batch)
        # TODO(hwang): figure out why `ToTensor` caused error here
        #return self.transform(image_batch), self.transform(label_batch)

class Cifar10Dataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.images = dataset.images
        self.labels = dataset.labels
        self.transform = transform

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        data_sample = self.images[idx]
        label_sample = self.labels[idx]
        if self.transform:
            data_sample = self.transform(data_sample)
            label_sample = self.transform(label_sample)
        return data_sample, label_sample

    def next_batch(self, batch_size):
        image_batch, label_batch = self.dataset.next_batch(batch_size=batch_size)
        return torch.from_numpy(image_batch), torch.from_numpy(label_batch)
        # TODO(hwang): figure out why `ToTensor` caused error here
        #return self.transform(image_batch), self.transform(label_batch)


def check_integrity(fpath, md5):
    if not os.path.isfile(fpath):
        return False
    md5o = hashlib.md5()
    with open(fpath, 'rb') as f:
        # read in 1MB chunks
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            md5o.update(chunk)
    md5c = md5o.hexdigest()
    if md5c != md5:
        return False
    return True


def download_url(url, root, filename, md5):
    from six.moves import urllib

    root = os.path.expanduser(root)
    fpath = os.path.join(root, filename)

    try:
        os.makedirs(root)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            raise

    # downloads file
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(url, fpath)
        except:
            if url[:5] == 'https':
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(url, fpath)

class SVHN(data.Dataset):
    """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
    Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
    we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
    expect the class labels to be in the range `[0, C-1]`
    Args:
        root (string): Root directory of dataset where directory
            ``SVHN`` exists.
        split (string): One of {'train', 'test', 'extra'}.
            Accordingly dataset is selected. 'extra' is Extra training set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    url = ""
    filename = ""
    file_md5 = ""

    split_list = {
        'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                  "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
        'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
                 "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
        'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
                  "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}

    def __init__(self, root, split='train',
                 transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # training set or test set or extra set

        if self.split not in self.split_list:
            raise ValueError('Wrong split entered! Please use split="train" '
                             'or split="extra" or split="test"')

        self.url = self.split_list[split][0]
        self.filename = self.split_list[split][1]
        self.file_md5 = self.split_list[split][2]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio

        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))

        self.data = loaded_mat['X']
        # loading from the .mat file gives an np array of type np.uint8
        # converting to np.int64, so that we have a LongTensor after
        # the conversion from the numpy array
        # the squeeze is needed to obtain a 1D tensor
        self.labels = loaded_mat['y'].astype(np.int64).squeeze()

        # the svhn dataset assigns the class label "10" to the digit 0
        # this makes it inconsistent with several loss functions
        # which expect the class labels to be in the range [0, C-1]
        np.place(self.labels, self.labels == 10, 0)
        self.data = np.transpose(self.data, (3, 2, 0, 1))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.labels[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        md5 = self.split_list[self.split][2]
        fpath = os.path.join(root, self.filename)
        return check_integrity(fpath, md5)

    def download(self):
        md5 = self.split_list[self.split][2]
        download_url(self.url, self.root, self.filename, md5)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Split: {}\n'.format(self.split)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
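

if __name__ == "__main__":
    # Illustrative usage sketch (added example): downloads the SVHN train
    # split into ./svhn_data and prints the dataset summary via __repr__.
    svhn_train = SVHN('./svhn_data', split='train', download=True)
    print(svhn_train)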
--------------------------------------------------------------------------------
/src/distributed_evaluator.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os.path
import time
import argparse
from datetime import datetime
import copy

from mpi4py import MPI
import numpy as np

from nn_ops import NN_Trainer

import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from model_ops.lenet import LeNet, LeNetSplit
from model_ops.resnet import *
from model_ops.resnet_split import *
#from util import build_model


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def add_fit_args(parser):
    """
    parser : argparse.ArgumentParser
    return a parser added with args required by fit
    """
    # Validation settings
    parser.add_argument('--eval-batch-size', type=int, default=10000, metavar='N',
                        help='the batch size when doing model validation, complete at once on default')
    parser.add_argument('--eval-freq', type=int, default=50, metavar='N',
                        help='it determines per how many steps the model should be evaluated')
    parser.add_argument('--model-dir', type=str, default='output/models/', metavar='N',
                        help='directory to save the temp model during the training process for evaluation')
    parser.add_argument('--dataset', type=str, default='MNIST', metavar='N',
                        help='which dataset used in training, MNIST and Cifar10 supported currently')
    parser.add_argument('--network', type=str, default='LeNet', metavar='N',
                        help='which kind of network we are going to use, support LeNet and ResNet currently')
    args = parser.parse_args()
    return args

class DistributedEvaluator(NN_Trainer):
    '''
    The DistributedEvaluator aims at providing a separate node in the distributed cluster to evaluate
    the model on the validation/test set and return the results.
    In this version, the DistributedEvaluator will only load the model from the dir where the master
    saves the model, and do the evaluation task based on a user-defined frequency.
    '''
    def __init__(self, **kwargs):
        self._cur_step = 0
        self._model_dir = kwargs['model_dir']
        self._eval_freq = int(kwargs['eval_freq'])
        self._eval_batch_size = kwargs['eval_batch_size']
        self.network_config = kwargs['network']
        # this one is going to be used to avoid fetching the weights multiple times
        self._layer_cur_step = []

    def evaluate(self, validation_loader):
        # init objective to fetch at the beginning
        self._next_step_to_fetch = self._cur_step + self._eval_freq
        self._num_batch_per_epoch = len(validation_loader) / self._eval_batch_size
        # check if the next temp model exists; if not, wait here, else continue with the model evaluation
        while True:
            model_dir_ = self._model_dir_generator(self._next_step_to_fetch)
            if os.path.isfile(model_dir_):
                self._load_model(model_dir_)
                print("Evaluator evaluating results on step {}".format(self._next_step_to_fetch))
                self._evaluate_model(validation_loader)
                self._next_step_to_fetch += self._eval_freq
            else:
                # TODO(hwang): sleep appropriate period of time, make sure to tune this parameter
                time.sleep(10)

    def _evaluate_model(self, test_loader):
        self.network.eval()
        test_loss = 0
        correct = 0
        prec1_counter_ = prec5_counter_ = batch_counter_ = 0
        for data, y_batch in test_loader:
            data, target = Variable(data, volatile=True), Variable(y_batch)

            output = self.network(data)
            test_loss += F.nll_loss(F.log_softmax(output), target, size_average=False).data[0]

            prec1_tmp, prec5_tmp = accuracy(output.data, target.data, topk=(1, 5))
            prec1_counter_ += prec1_tmp.numpy()[0]
            prec5_counter_ += prec5_tmp.numpy()[0]
            batch_counter_ += 1
        prec1 = prec1_counter_ / batch_counter_
        prec5 = prec5_counter_ / batch_counter_
        test_loss /= len(test_loader.dataset)
        print('Test set: Step: {}, Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(self._cur_step,
                                                                                       test_loss, prec1, prec5))

    def _load_model(self, file_path):
        # build network; both supported datasets (MNIST, Cifar10) have 10 classes
        num_classes = 10
        # NOTE: the FC_NN, DenseNet, vgg11_bn, and alexnet branches below also
        # require the corresponding imports from `model_ops`
        if self.network_config == "LeNet":
            self.network = LeNet()
        elif self.network_config == "ResNet18":
            self.network = ResNet18(num_classes=num_classes)
        elif self.network_config == "ResNet34":
            self.network = ResNet34(num_classes=num_classes)
        elif self.network_config == "FC":
            self.network = FC_NN()
        elif self.network_config == "DenseNet":
            self.network = DenseNet(growthRate=40, depth=190, reduction=0.5,
                                    bottleneck=True, nClasses=10)
        elif self.network_config == "VGG11":
            self.network = vgg11_bn(num_classes)
        elif self.network_config == "AlexNet":
            self.network = alexnet(num_classes=10)

        with open(file_path, "rb") as f_:
            self.network.load_state_dict(torch.load(f_))
132 | 133 | def _model_dir_generator(self, next_step_to_fetch): 134 | return self._model_dir+"model_step_"+str(next_step_to_fetch) 135 | 136 | if __name__ == "__main__": 137 | # this is only a simple test case 138 | args = add_fit_args(argparse.ArgumentParser(description='PyTorch Distributed Evaluator')) 139 | 140 | # load training and test set here: 141 | if args.dataset == "MNIST": 142 | test_loader = torch.utils.data.DataLoader( 143 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 144 | transforms.ToTensor(), 145 | transforms.Normalize((0.1307,), (0.3081,)) 146 | ])), batch_size=args.eval_batch_size, shuffle=True) 147 | elif args.dataset == "Cifar10": 148 | test_loader = torch.utils.data.DataLoader( 149 | datasets.CIFAR10('./cifar10_data', train=False, transform=transforms.Compose([ 150 | transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 151 | ])), batch_size=args.eval_batch_size, shuffle=True) 152 | 153 | kwargs_evaluator={ 154 | 'network':args.network, 155 | 'model_dir':args.model_dir, 156 | 'eval_freq':args.eval_freq, 157 | 'eval_batch_size':args.eval_batch_size} 158 | evaluator_nn = DistributedEvaluator(**kwargs_evaluator) 159 | evaluator_nn.evaluate(validation_loader=test_loader) -------------------------------------------------------------------------------- /src/distributed_nn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import math 5 | import threading 6 | import argparse 7 | import time 8 | 9 | import numpy as np 10 | from mpi4py import MPI 11 | 12 | import torch 13 | from torch.autograd import Variable 14 | from torch import nn 15 | from torch.nn.parallel.replicate import replicate 16 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 17 | from torch.nn.parallel.parallel_apply import parallel_apply 18 | import torch.nn.functional as F 19 | 20 | from torchvision import datasets, transforms 21 | 22 | from nn_ops import NN_Trainer, accuracy 23 | from data_loader_ops.my_data_loader import DataLoader 24 | 25 | from distributed_worker import * 26 | from sync_replicas_master_nn import * 27 | 28 | from datasets import SVHN 29 | 30 | 31 | def add_fit_args(parser): 32 | """ 33 | parser : argparse.ArgumentParser 34 | return a parser added with args required by fit 35 | """ 36 | # Training settings 37 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 38 | help='input batch size for training (default: 128)') 39 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 40 | help='input batch size for testing (default: 1000)') 41 | parser.add_argument('--max-steps', type=int, default=10000, metavar='N', 42 | help='the maximum number of iterations') 43 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 44 | help='number of epochs to train (default: 100)') 45 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 46 | help='learning rate (default: 0.01)') 47 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 48 | help='SGD momentum (default: 0.5)') 49 | parser.add_argument('--lr-shrinkage', type=float, default=0.95, metavar='M', 50 | help='exponential decay factor of lr schedule (default: 0.95)') 51 | parser.add_argument('--no-cuda', action='store_true', default=False, 52 | help='disables CUDA training')
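One caveat about the boolean options defined a few lines below (`--compress` and `--enable-gpu`): they are declared with `type=bool`, and argparse simply calls `bool()` on the raw string, so any non-empty value, including the literal string `False`, parses as `True`; only omitting the flag yields the default. Passing `--compress=True` therefore works, but `--compress=False` does not do what it looks like. A self-contained demonstration of the pitfall:
``` python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--compress', type=bool, default=False)

print(parser.parse_args(['--compress', 'False']).compress)  # True, because bool('False') is True
print(parser.parse_args([]).compress)                       # False (the default)
```
The usual fixes are `action='store_true'` or an explicit string-to-bool converter passed as `type`.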
53 | parser.add_argument('--seed', type=int, default=1, metavar='S', 54 | help='random seed (default: 1)') 55 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 56 | help='how many batches to wait before logging training status') 57 | parser.add_argument('--network', type=str, default='LeNet', metavar='N', 58 | help='which network to use; LeNet and ResNet are currently supported') 59 | parser.add_argument('--code', type=str, default='sgd', 60 | help='which coding method to use, e.g. sgd, qsgd, svd') 61 | parser.add_argument('--bucket-size', type=int, default=512, 62 | help='bucket size used in QSGD') 63 | parser.add_argument('--dataset', type=str, default='MNIST', metavar='N', 64 | help='which dataset to use for training; MNIST and Cifar10 are currently supported') 65 | parser.add_argument('--comm-type', type=str, default='Bcast', metavar='N', 66 | help='which communication method to use during the model fetching stage') 67 | parser.add_argument('--num-aggregate', type=int, default=5, metavar='N', 68 | help='how many gradients to gather at each iteration') 69 | parser.add_argument('--eval-freq', type=int, default=50, metavar='N', 70 | help='evaluate the model every this many training steps') 71 | parser.add_argument('--train-dir', type=str, default='output/models/', metavar='N', 72 | help='directory to save the temp model during the training process for evaluation') 73 | parser.add_argument('--compress', type=bool, default=False, 74 | help='whether to use the gradient approximation (compression) method') 75 | parser.add_argument('--enable-gpu', type=bool, default=False, 76 | help='whether to run training on the GPU') 77 | 78 | parser.add_argument('--svd-rank', default=0, help='int: target rank used by the SVD coder', 79 | type=int) 80 | parser.add_argument('--quantization-level', type=int, default=4, help='int: bits used in QSGD') 81 | args = parser.parse_args() 82 | return args 83 | 84 | if __name__ == "__main__": 85 | # this is only a simple test case 86 | comm = MPI.COMM_WORLD 87 | rank = comm.Get_rank() 88 | world_size = comm.Get_size() 89 | 90 | args = add_fit_args(argparse.ArgumentParser(description='PyTorch MNIST Single Machine Test')) 91 | 92 | # load training and test set here: 93 | if args.dataset == "MNIST": 94 | training_set = datasets.MNIST('./mnist_data', train=True, download=True, 95 | transform=transforms.Compose([ 96 | transforms.ToTensor(), 97 | transforms.Normalize((0.1307,), (0.3081,))])) 98 | train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True) 99 | test_loader = torch.utils.data.DataLoader( 100 | datasets.MNIST('./mnist_data', train=False, transform=transforms.Compose([ 101 | transforms.ToTensor(), 102 | transforms.Normalize((0.1307,), (0.3081,)) 103 | ])), batch_size=args.test_batch_size, shuffle=True) 104 | elif args.dataset == "Cifar10": 105 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 106 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 107 | transform_train = transforms.Compose([ 108 | transforms.ToTensor(), 109 | transforms.Lambda(lambda x: F.pad( 110 | Variable(x.unsqueeze(0), requires_grad=False, volatile=True), 111 | (4,4,4,4),mode='reflect').data.squeeze()), 112 | transforms.ToPILImage(), 113 | transforms.RandomCrop(32), 114 | transforms.RandomHorizontalFlip(), 115 | transforms.ToTensor(), 116 | normalize, 117 | ]) 118 | # data prep for test set 119 | transform_test = transforms.Compose([ 120 | transforms.ToTensor(), 121 | normalize]) 122 | # load training and test set here: 123 | 
training_set = datasets.CIFAR10(root='./cifar10_data', train=True, 124 | download=True, transform=transform_train) 125 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, 126 | shuffle=True) 127 | testset = datasets.CIFAR10(root='./cifar10_data', train=False, 128 | download=True, transform=transform_test) 129 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, 130 | shuffle=False) 131 | elif args.dataset == 'SVHN': 132 | training_set = SVHN('./svhn_data', split='train', transform=transforms.Compose([ 133 | transforms.RandomCrop(32, padding=4), 134 | transforms.RandomHorizontalFlip(), 135 | transforms.ToTensor(), 136 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 137 | ])) 138 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=128, 139 | shuffle=True) 140 | transform_test = transforms.Compose([ 141 | transforms.ToTensor(), 142 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 143 | ]) 144 | testset = SVHN(root='./svhn_data', split='test', 145 | download=True, transform=transform_test) 146 | test_loader = torch.utils.data.DataLoader(testset, batch_size=1000, 147 | shuffle=False) 148 | elif args.dataset == 'Cifar100': 149 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 150 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 151 | transform_train = transforms.Compose([ 152 | transforms.ToTensor(), 153 | transforms.Lambda(lambda x: F.pad( 154 | Variable(x.unsqueeze(0), requires_grad=False, volatile=True), 155 | (4,4,4,4),mode='reflect').data.squeeze()), 156 | transforms.ToPILImage(), 157 | transforms.RandomCrop(32), 158 | transforms.RandomHorizontalFlip(), 159 | transforms.ToTensor(), 160 | normalize, 161 | ]) 162 | # data prep for test set 163 | transform_test = transforms.Compose([ 164 | transforms.ToTensor(), 165 | normalize]) 166 | # load training and test set here: 167 | training_set = datasets.CIFAR100(root='./cifar100_data', train=True, 168 | download=True, transform=transform_train) 169 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, 170 | shuffle=True) 171 | testset = datasets.CIFAR100(root='./cifar100_data', train=False, 172 | download=True, transform=transform_test) 173 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, 174 | shuffle=False) 175 | elif args.dataset == 'ImageNet': 176 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 177 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 178 | # data prep for training set 179 | # note that the key point to reach convergence performance reported in this paper (https://arxiv.org/abs/1512.03385) 180 | # is to implement data augmentation 181 | transform_train = transforms.Compose([ 182 | transforms.Scale((227, 227)), 183 | transforms.ToTensor(), 184 | transforms.Lambda(lambda x: F.pad( 185 | Variable(x.unsqueeze(0), requires_grad=False, volatile=True), 186 | (4,4,4,4),mode='reflect').data.squeeze()), 187 | transforms.ToPILImage(), 188 | transforms.RandomCrop(227), 189 | transforms.RandomHorizontalFlip(), 190 | transforms.ToTensor(), 191 | normalize, 192 | ]) 193 | # data prep for test set 194 | transform_test = transforms.Compose([ 195 | transforms.ToTensor(), 196 | normalize]) 197 | # load training and test set here: 198 | training_set = datasets.CIFAR10(root='./cifar10_data', train=True, 199 | download=True, transform=transform_train) 200 | #training_set = 
datasets.CIFAR10(root='./cifar10_data', train=True, 201 | # download=True, transform=transform_test) 202 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, 203 | shuffle=True) 204 | testset = datasets.CIFAR10(root='./cifar10_data', train=False, 205 | download=True, transform=transform_test) 206 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, 207 | shuffle=False) 208 | 209 | kwargs_master = {'batch_size':args.batch_size, 210 | 'learning_rate':args.lr, 211 | 'max_epochs':args.epochs, 212 | 'max_steps':args.max_steps, 213 | 'momentum':args.momentum, 214 | 'network':args.network, 215 | 'comm_method':args.comm_type, 216 | 'eval_freq':args.eval_freq, 217 | 'train_dir':args.train_dir, 218 | 'compress':args.compress, 219 | 'num_aggregate':args.num_aggregate, 220 | 'enable_gpu':args.enable_gpu, 221 | 'lr_shrinkage':args.lr_shrinkage, 222 | 'code':args.code, 223 | 'svd_rank':args.svd_rank, 224 | 'quantization_level':args.quantization_level, 225 | 'bucket_size':args.bucket_size} 226 | 227 | kwargs_worker = {'batch_size':args.batch_size, 228 | 'learning_rate':args.lr, 229 | 'max_epochs':args.epochs, 230 | 'momentum':args.momentum, 231 | 'network':args.network, 232 | 'max_steps':args.max_steps, 233 | 'comm_method':args.comm_type, 234 | 'compress':args.compress, 235 | 'enable_gpu':args.enable_gpu, 236 | 'eval_freq':args.eval_freq, 237 | 'train_dir':args.train_dir, 238 | 'code':args.code, 239 | 'svd_rank':args.svd_rank, 240 | 'quantization_level':args.quantization_level, 241 | 'bucket_size':args.bucket_size} 242 | 243 | if rank == 0: 244 | master_fc_nn = SyncReplicasMaster_NN(comm=comm, **kwargs_master) 245 | if args.dataset == 'Cifar100': 246 | master_fc_nn.build_model(num_classes=100) 247 | else: 248 | master_fc_nn.build_model(num_classes=10) 249 | print("I am the master: the world size is {}, cur step: {}".format(master_fc_nn.world_size, master_fc_nn.cur_step)) 250 | master_fc_nn.train() 251 | print("Done sending messages to workers!") 252 | else: 253 | worker_fc_nn = DistributedWorker(comm=comm, **kwargs_worker) 254 | if args.dataset == 'Cifar100': 255 | worker_fc_nn.build_model(num_classes=100) 256 | else: 257 | worker_fc_nn.build_model(num_classes=10) 258 | print("I am worker: {} in all {} workers, next step: {}".format(worker_fc_nn.rank, worker_fc_nn.world_size-1, worker_fc_nn.next_step)) 259 | worker_fc_nn.train(train_loader=train_loader, test_loader=test_loader) 260 | print("Worker Done Jobs! 
...") -------------------------------------------------------------------------------- /src/distributed_worker.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from mpi4py import MPI 3 | import numpy as np 4 | 5 | from nn_ops import NN_Trainer 6 | 7 | from model_ops.lenet import LeNet, LeNetSplit 8 | from model_ops.resnet import * 9 | from model_ops.resnet_split import * 10 | from model_ops.vgg import * 11 | from model_ops.alexnet import * 12 | from model_ops.fc_nn import FC_NN, FC_NN_Split 13 | from model_ops.densenet import DenseNet 14 | 15 | from utils import compress 16 | import codings 17 | 18 | import torch 19 | from torch.autograd import Variable 20 | 21 | import time 22 | from datetime import datetime 23 | import copy 24 | from sys import getsizeof 25 | import pickle 26 | 27 | STEP_START_ = 1 28 | # use compression tool to make it run faster 29 | _FAKE_SGD = True 30 | TAG_LIST_ = [i*30 for i in range(50000)] 31 | 32 | def prepare_grad_list(params): 33 | grad_list = [] 34 | for param_idx, param in enumerate(params): 35 | # get gradient from layers here 36 | # in this version we fetch weights at once 37 | # remember to change type here, which is essential 38 | grads = param.grad.data.numpy().astype(np.float32) 39 | grad_list.append((param_idx, grads)) 40 | return grad_list 41 | 42 | def accuracy(output, target, topk=(1,)): 43 | """Computes the precision@k for the specified values of k""" 44 | maxk = max(topk) 45 | batch_size = target.size(0) 46 | 47 | _, pred = output.topk(maxk, 1, True, True) 48 | pred = pred.t() 49 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 50 | 51 | res = [] 52 | for k in topk: 53 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 54 | res.append(correct_k.mul_(100.0 / batch_size)) 55 | return res 56 | 57 | def _4d_to_2d(tensor): 58 | return tensor.view(tensor.size()[0], tensor.size()[1]*tensor.size()[2]*tensor.size()[3]) 59 | 60 | def _construct_grad_packet(module): 61 | ''' 62 | input: weight (\in R^{m*n}) and bias (\in R^{m*n}) 63 | output: grad packet (\in R^{m*(n+1)}) 64 | ''' 65 | ndims = len(module.weight.grad.data.size()) 66 | if ndims == 4: 67 | tmp_grad = _4d_to_2d(module.weight.grad.data) 68 | if module.bias is None: 69 | return tmp_grad 70 | else: 71 | return torch.cat((tmp_grad, module.bias.grad.data), 1) 72 | elif ndims == 2: 73 | if module.bias is None: 74 | return module.weight.grad.data 75 | else: 76 | return torch.cat((module.weight.grad.data, module.bias.grad.data), 1) 77 | elif ndims == 1: 78 | if module.bias is None: 79 | return module.weight.grad.data 80 | else: 81 | return torch.cat((module.bias.grad.data, module.weight.grad.data), 0) 82 | 83 | 84 | class ModelBuffer(object): 85 | def __init__(self, network): 86 | """ 87 | this class is used to save model weights received from parameter server 88 | current step for each layer of model will also be updated here to make sure 89 | the model is always up-to-date 90 | """ 91 | self.recv_buf = [] 92 | self.layer_cur_step = [] 93 | for param_idx, param in enumerate(network.parameters()): 94 | self.recv_buf.append(np.zeros(param.size())) 95 | self.layer_cur_step.append(0) 96 | 97 | 98 | class DistributedWorker(NN_Trainer): 99 | def __init__(self, comm, **kwargs): 100 | self.comm = comm # get MPI communicator object 101 | self.world_size = comm.Get_size() # total number of processes 102 | self.rank = comm.Get_rank() # rank of this Worker 103 | #self.status = MPI.Status() 104 | self.cur_step = 0 
105 | self.next_step = 0 # we will fetch this one from parameter server 106 | 107 | self.batch_size = kwargs['batch_size'] 108 | self.max_epochs = kwargs['max_epochs'] 109 | self.momentum = kwargs['momentum'] 110 | self.lr = kwargs['learning_rate'] 111 | self.network_config = kwargs['network'] 112 | self._max_steps = kwargs['max_steps'] 113 | self.comm_type = kwargs['comm_method'] 114 | self._compress = kwargs['compress'] 115 | self._enable_gpu = kwargs['enable_gpu'] 116 | self._eval_batch_size = 100 117 | self._eval_freq = kwargs['eval_freq'] 118 | self._train_dir = kwargs['train_dir'] 119 | # encode related 120 | self._svd_rank = kwargs['svd_rank'] 121 | self._quantization_level = kwargs['quantization_level'] 122 | self._bucket_size = kwargs['bucket_size'] 123 | 124 | # this one is going to be used to avoid fetch the weights for multiple times 125 | self._layer_cur_step = [] 126 | self._code = kwargs['code'] 127 | if kwargs['code'] == 'sgd': 128 | if not _FAKE_SGD: 129 | self._coder = codings.svd.SVD(compress=False) 130 | else: 131 | self._coder = codings.lossless_compress.LosslessCompress() 132 | elif kwargs['code'] == 'svd': 133 | print("train.py, svd_rank =", self._svd_rank) 134 | self._coder = codings.svd.SVD(random_sample=True, 135 | rank=self._svd_rank, compress=True) 136 | else: 137 | raise ValueError('args.code not recognized') 138 | 139 | def build_model(self, num_classes=10): 140 | # build network 141 | if self.network_config == "LeNet": 142 | self.network=LeNet() 143 | elif self.network_config == "ResNet18": 144 | self.network=ResNet18(num_classes=num_classes) 145 | elif self.network_config == "ResNet34": 146 | self.network=ResNet34(num_classes=num_classes) 147 | elif self.network_config == "FC": 148 | self.network=FC_NN() 149 | elif self.network_config == "DenseNet": 150 | self.network=DenseNet(growthRate=40, depth=190, reduction=0.5, 151 | bottleneck=True, nClasses=10) 152 | elif self.network_config == "VGG11": 153 | self.network=vgg11_bn(num_classes) 154 | elif self.network_config == "AlexNet": 155 | self.network=alexnet(num_classes=10) 156 | 157 | # set up optimizer 158 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum) 159 | self.criterion = nn.CrossEntropyLoss() 160 | # assign a buffer for receiving models from parameter server 161 | self.init_recv_buf() 162 | # enable GPU here 163 | if self._enable_gpu: 164 | self.network.cuda() 165 | 166 | def train(self, train_loader, test_loader): 167 | # the first step we need to do here is to sync fetch the inital worl_step from the parameter server 168 | # we still need to make sure the value we fetched from parameter server is 1 169 | self.sync_fetch_step() 170 | # do some sync check here 171 | assert(self.update_step()) 172 | assert(self.cur_step == STEP_START_) 173 | 174 | # number of batches in one epoch 175 | num_batch_per_epoch = len(train_loader.dataset) / self.batch_size 176 | batch_idx = -1 177 | epoch_idx = 0 178 | epoch_avg_loss = 0 179 | iteration_last_step=0 180 | iter_start_time=0 181 | 182 | first = True 183 | 184 | print("Worker {}: starting training".format(self.rank)) 185 | # start the training process 186 | # start the training process 187 | for num_epoch in range(self.max_epochs): 188 | for batch_idx, (train_image_batch, train_label_batch) in enumerate(train_loader): 189 | # worker exit task 190 | if self.cur_step == self._max_steps: 191 | break 192 | if self._enable_gpu: 193 | X_batch, y_batch = Variable(train_image_batch.cuda()), 
Variable(train_label_batch.cuda()) 194 | else: 195 | X_batch, y_batch = Variable(train_image_batch), Variable(train_label_batch) 196 | while True: 197 | # the worker shouldn't know the current global step 198 | # except received the message from parameter server 199 | self.async_fetch_step() 200 | 201 | # the only way every worker know which step they're currently on is to check the cur step variable 202 | updated = self.update_step() 203 | 204 | if (not updated) and (not first): 205 | # wait here unitl enter next step 206 | continue 207 | 208 | # the real start point of this iteration 209 | iteration_last_step = time.time() - iter_start_time 210 | iter_start_time = time.time() 211 | first = False 212 | print("Rank of this node: {}, Current step: {}".format(self.rank, self.cur_step)) 213 | 214 | # TODO(hwang): return layer request here and do weight before the forward step begins, rather than implement 215 | # the wait() in the fetch function 216 | fetch_weight_start_time = time.time() 217 | 218 | self.async_fetch_weights_bcast() 219 | 220 | fetch_weight_duration = time.time() - fetch_weight_start_time 221 | 222 | # switch to training mode 223 | self.network.train() 224 | # manage batch index manually 225 | self.optimizer.zero_grad() 226 | 227 | # forward step 228 | comp_start = time.time() 229 | logits = self.network(X_batch) 230 | loss = self.criterion(logits, y_batch) 231 | 232 | epoch_avg_loss += loss.data[0] 233 | 234 | # backward step 235 | backward_start_time = time.time() 236 | loss.backward() 237 | comp_dur = time.time() - comp_start 238 | 239 | # gradient encoding step 240 | encode_start = time.time() 241 | msgs,_msg_counter = self._encode() 242 | encode_dur = time.time() - encode_start 243 | 244 | # communication step 245 | comm_start = time.time() 246 | self._send_grads(msgs) 247 | comm_dur = time.time()-comm_start 248 | 249 | prec1, prec5 = accuracy(logits.data, y_batch.data, topk=(1, 5)) 250 | if self._enable_gpu: 251 | prec1, prec5 = prec1.cpu().numpy()[0], prec5.cpu().numpy()[0] 252 | else: 253 | prec1, prec5 = prec1.numpy()[0], prec5.numpy()[0] 254 | # on the end of a certain iteration 255 | print('Worker: {}, Step: {}, Epoch: {} [{}/{} ({:.0f}%)], Loss: {:.4f}, Time Cost: {:.4f}, Comp: {:.4f}, Encode: {: .4f}, Comm: {: .4f}, Msg(MB): {: .4f}, Prec@1: {: .4f}, Prec@5: {: .4f}'.format( 256 | self.rank, self.cur_step, num_epoch, batch_idx * self.batch_size, len(train_loader.dataset), 257 | (100. 
* (batch_idx * self.batch_size) / len(train_loader.dataset)), loss.data[0], time.time()-iter_start_time, 258 | comp_dur, encode_dur, comm_dur, _msg_counter/(1024.0**2), prec1, prec5)) 259 | # break here to fetch data then enter the fetching-step loop again 260 | if self.cur_step%self._eval_freq==0: 261 | self._evaluate_model(test_loader) 262 | break 263 | 264 | def init_recv_buf(self): 265 | self.model_recv_buf = ModelBuffer(self.network) 266 | 267 | def sync_fetch_step(self): 268 | '''fetch the first step from the parameter server''' 269 | self.next_step = self.comm.recv(source=0, tag=10) 270 | 271 | def async_fetch_step(self): 272 | req = self.comm.irecv(source=0, tag=10) 273 | self.next_step = req.wait() 274 | 275 | def async_fetch_weights_bcast(self): 276 | layers_to_update = [] 277 | for layer_idx, layer in enumerate(self.model_recv_buf.recv_buf): 278 | if self.model_recv_buf.layer_cur_step[layer_idx] < self.cur_step: 279 | layers_to_update.append(layer_idx) 280 | self.comm.Bcast([self.model_recv_buf.recv_buf[layer_idx], MPI.DOUBLE], root=0) 281 | weights_to_update = [] 282 | for req_idx, layer_idx in enumerate(layers_to_update): 283 | weights = self.model_recv_buf.recv_buf[layer_idx] 284 | weights_to_update.append(weights) 285 | # we also need to update the layer cur step here (indexed by layer, not by request): 286 | self.model_recv_buf.layer_cur_step[layer_idx] = self.cur_step 287 | self.model_update(weights_to_update) 288 | 289 | def update_step(self): 290 | '''update local (global) step on worker''' 291 | changed = (self.cur_step != self.next_step) 292 | self.cur_step = self.next_step 293 | return changed 294 | 295 | def model_update(self, weights_to_update): 296 | """write model fetched from parameter server to local model""" 297 | new_state_dict = {} 298 | model_counter_ = 0 299 | for param_idx,(key_name, param) in enumerate(self.network.state_dict().items()): 300 | # handle the case that `running_mean` and `running_var` contained in `BatchNorm` layer 301 | if "running_mean" in key_name or "running_var" in key_name: 302 | tmp_dict={key_name: param} 303 | else: 304 | assert param.size() == weights_to_update[model_counter_].shape 305 | if self._enable_gpu: 306 | tmp_dict = {key_name: torch.from_numpy(weights_to_update[model_counter_]).cuda()} 307 | else: 308 | tmp_dict = {key_name: torch.from_numpy(weights_to_update[model_counter_])} 309 | model_counter_ += 1 310 | new_state_dict.update(tmp_dict) 311 | self.network.load_state_dict(new_state_dict) 312 | 313 | def _encode(self): 314 | msgs = [] 315 | _msg_counter = 0 316 | def __count_msg_sizes(msg): 317 | return len(msg) 318 | for p_index, p in enumerate(self.network.parameters()): 319 | if self._enable_gpu: 320 | grad = p.grad.data.cpu().numpy().astype(np.float32) 321 | else: 322 | grad = p.grad.data.numpy().astype(np.float32) 323 | coded = self._coder.encode(grad) 324 | pickled = pickle.dumps(coded) 325 | byte_code = bytearray(pickled) 326 | _msg_counter+=__count_msg_sizes(byte_code) 327 | msgs.append(byte_code) 328 | return msgs, _msg_counter 329 | 330 | def _send_grads(self, msgs): 331 | req_send_check = [] 332 | for msg_index, m in enumerate(msgs): 333 | req_isend = self.comm.isend(m, dest=0, tag=88+msg_index) 334 | req_send_check.append(req_isend) 335 | [req_send_check[i].wait() for i in range(len(req_send_check))] 336 | 337 | def _generate_model_path(self): 338 | return self._train_dir+"model_step_"+str(self.cur_step) 339 | 340 | def _save_model(self, file_path): 341 | with open(file_path, "wb") as f_: 342 | torch.save(self.network.state_dict(), f_) 343 | 
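To make the per-layer encoding path above concrete: `_encode` pulls each layer's gradient out as a `float32` numpy array, runs it through the configured coder, pickles the result, and tallies the bytes that `_send_grads` will push over MPI. A standalone sketch of that framing with a trivial stand-in coder (`IdentityCoder` is hypothetical; the real coders live under `src/codings/`):
``` python
import pickle
import numpy as np

class IdentityCoder(object):
    """Hypothetical stand-in for the real coders: encode() may return any picklable object."""
    def encode(self, grad):
        return grad

coder = IdentityCoder()
grad = np.random.randn(500, 800).astype(np.float32)  # a toy fully-connected-layer gradient
msg = bytearray(pickle.dumps(coder.encode(grad)))
print("{:.2f} MB".format(len(msg) / (1024.0 ** 2)))  # ~1.53 MB for the raw float32 payload
```
With the SVD coder and a small rank, only the truncated factors get pickled, which is where the communication savings come from.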
344 | def _evaluate_model(self, test_loader): 345 | self.network.eval() 346 | test_loss = 0 347 | correct = 0 348 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0 349 | for data, y_batch in test_loader: 350 | if self._enable_gpu: 351 | data, target = Variable(data.cuda(), volatile=True), Variable(y_batch.cuda()) 352 | else: 353 | data, target = Variable(data, volatile=True), Variable(y_batch) 354 | 355 | output = self.network(data) 356 | test_loss += F.nll_loss(F.log_softmax(output), target, size_average=False).data[0] 357 | 358 | prec1_tmp, prec5_tmp = accuracy(output.data, target.data, topk=(1, 5)) 359 | if self._enable_gpu: 360 | prec1_counter_ += prec1_tmp.cpu().numpy()[0] 361 | prec5_counter_ += prec5_tmp.cpu().numpy()[0] 362 | else: 363 | prec1_counter_ += prec1_tmp.numpy()[0] 364 | prec5_counter_ += prec5_tmp.numpy()[0] 365 | batch_counter_ += 1 366 | prec1 = prec1_counter_ / batch_counter_ 367 | prec5 = prec5_counter_ / batch_counter_ 368 | test_loss /= len(test_loader.dataset) 369 | print('Test set: Step: {}, Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(self.cur_step, 370 | test_loss, prec1, prec5)) 371 | 372 | if __name__ == "__main__": 373 | # this is only a simple sanity check of the MPI environment 374 | comm = MPI.COMM_WORLD 375 | rank = comm.Get_rank() 376 | world_size = comm.Get_size() 377 | print("I am worker: {} in all {} workers".format(rank, world_size)) -------------------------------------------------------------------------------- /src/evaluate_pytorch.sh: -------------------------------------------------------------------------------- 1 | python distributed_evaluator.py \ 2 | --eval-batch-size=10000 \ 3 | --eval-freq=50 \ 4 | --network=ResNet18 \ 5 | --dataset=Cifar10 \ 6 | --model-dir=/home/ubuntu/MPI_shared/ -------------------------------------------------------------------------------- /src/model_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import densenet, fc_nn, lenet, resnet, vgg, alexnet 2 | 3 | __all__ = ['densenet', 'fc_nn', 'lenet', 'resnet', 'vgg', 'alexnet'] -------------------------------------------------------------------------------- /src/model_ops/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['AlexNet', 'alexnet'] 6 | 7 | 8 | model_urls = { 9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 10 | } 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=1000): 16 | super(AlexNet, self).__init__() 17 | self.features = nn.Sequential( 18 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=3, stride=2), 21 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.classifier = nn.Sequential( 33 | nn.Dropout(), 34 | nn.Linear(256 * 6 * 6, 4096), 35 | nn.ReLU(inplace=True), 36 | nn.Dropout(), 37 | nn.Linear(4096, 4096), 38 | nn.ReLU(inplace=True), 39 | nn.Linear(4096, num_classes), 40 | ) 41 | 42 | def forward(self, x): 43 | x = self.features(x) 44 | x = x.view(x.size(0), 256 * 6 * 6) 45 | x = self.classifier(x) 46 | return x 47 | 48 | 49 | def alexnet(pretrained=False, **kwargs): 50 | r"""AlexNet model architecture from the 51 | `"One weird trick..." `_ paper. 52 | Args: 53 | pretrained (bool): If True, returns a model pre-trained on ImageNet 54 | """ 55 | model = AlexNet(**kwargs) 56 | if pretrained: 57 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) 58 | return model -------------------------------------------------------------------------------- /src/model_ops/densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | import torchvision.datasets as dset 10 | import torchvision.transforms as transforms 11 | from torch.utils.data import DataLoader 12 | 13 | import torchvision.models as models 14 | 15 | import sys 16 | import math 17 | 18 | class Bottleneck(nn.Module): 19 | def __init__(self, nChannels, growthRate): 20 | super(Bottleneck, self).__init__() 21 | interChannels = 4*growthRate 22 | self.bn1 = nn.BatchNorm2d(nChannels) 23 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, 24 | bias=False) 25 | self.bn2 = nn.BatchNorm2d(interChannels) 26 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, 27 | padding=1, bias=False) 28 | 29 | def forward(self, x): 30 | out = self.conv1(F.relu(self.bn1(x))) 31 | out = self.conv2(F.relu(self.bn2(out))) 32 | out = torch.cat((x, out), 1) 33 | return out 34 | 35 | class SingleLayer(nn.Module): 36 | def __init__(self, nChannels, growthRate): 37 | super(SingleLayer, self).__init__() 38 | self.bn1 = nn.BatchNorm2d(nChannels) 39 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, 40 | padding=1, bias=False) 41 | 42 | def forward(self, x): 43 | out = self.conv1(F.relu(self.bn1(x))) 44 | out = torch.cat((x, out), 1) 45 
| return out 46 | 47 | class Transition(nn.Module): 48 | def __init__(self, nChannels, nOutChannels): 49 | super(Transition, self).__init__() 50 | self.bn1 = nn.BatchNorm2d(nChannels) 51 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, 52 | bias=False) 53 | 54 | def forward(self, x): 55 | out = self.conv1(F.relu(self.bn1(x))) 56 | out = F.avg_pool2d(out, 2) 57 | return out 58 | 59 | 60 | class DenseNet(nn.Module): 61 | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): 62 | super(DenseNet, self).__init__() 63 | 64 | nDenseBlocks = (depth-4) // 3 65 | if bottleneck: 66 | nDenseBlocks //= 2 67 | 68 | nChannels = 2*growthRate 69 | self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, 70 | bias=False) 71 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 72 | nChannels += nDenseBlocks*growthRate 73 | nOutChannels = int(math.floor(nChannels*reduction)) 74 | self.trans1 = Transition(nChannels, nOutChannels) 75 | 76 | nChannels = nOutChannels 77 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 78 | nChannels += nDenseBlocks*growthRate 79 | nOutChannels = int(math.floor(nChannels*reduction)) 80 | self.trans2 = Transition(nChannels, nOutChannels) 81 | 82 | nChannels = nOutChannels 83 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 84 | nChannels += nDenseBlocks*growthRate 85 | 86 | self.bn1 = nn.BatchNorm2d(nChannels) 87 | self.fc = nn.Linear(nChannels, nClasses) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 92 | m.weight.data.normal_(0, math.sqrt(2. / n)) 93 | elif isinstance(m, nn.BatchNorm2d): 94 | m.weight.data.fill_(1) 95 | m.bias.data.zero_() 96 | elif isinstance(m, nn.Linear): 97 | m.bias.data.zero_() 98 | 99 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): 100 | layers = [] 101 | for i in range(int(nDenseBlocks)): 102 | if bottleneck: 103 | layers.append(Bottleneck(nChannels, growthRate)) 104 | else: 105 | layers.append(SingleLayer(nChannels, growthRate)) 106 | nChannels += growthRate 107 | return nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | out = self.conv1(x) 111 | out = self.trans1(self.dense1(out)) 112 | out = self.trans2(self.dense2(out)) 113 | out = self.dense3(out) 114 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 115 | out = F.log_softmax(self.fc(out)) 116 | return out -------------------------------------------------------------------------------- /src/model_ops/fc_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | import pandas as pd 6 | import numpy as np 7 | from torch.autograd import Variable 8 | 9 | from mpi4py import MPI 10 | 11 | # we use fc nn here for our simple case 12 | class FC_NN(nn.Module): 13 | def __init__(self): 14 | super(FC_NN, self).__init__() 15 | self.fc1 = nn.Linear(784, 800) 16 | self.fc2 = nn.Linear(800, 500) 17 | self.fc3 = nn.Linear(500, 10) 18 | self.relu = nn.ReLU() 19 | self.sigmoid = nn.Sigmoid() 20 | def forward(self, x): 21 | x = x.view(-1, x.size()[1]*x.size()[2]*x.size()[3]) 22 | x = self.fc1(x) 23 | x = self.relu(x) 24 | x = self.fc2(x) 25 | x = self.relu(x) 26 | x = self.fc3(x) 27 | x = self.sigmoid(x) 28 | return x 29 | def name(self): 30 | return 'fc_nn' 31 | 32 | # we use fc nn here for our simple case 33 | class FC_NN_Split(nn.Module): 34 
| def __init__(self): 35 | super(FC_NN_Split, self).__init__() 36 | self.fc1 = nn.Linear(784, 800) 37 | self.fc2 = nn.Linear(800, 500) 38 | self.fc3 = nn.Linear(500, 10) 39 | self.relu = nn.ReLU() 40 | self.sigmoid = nn.Sigmoid() 41 | # helper 42 | self.full_modules = [self.fc1, self.fc2, self.fc3] 43 | self._init_channel_index = len(self.full_modules)*2 44 | def forward(self, x): 45 | ''' 46 | split layers 47 | ''' 48 | self.output = [] 49 | self.input = [] 50 | x = x.view(-1, x.size()[1]*x.size()[2]*x.size()[3]) 51 | 52 | x = Variable(x.data, requires_grad=True) 53 | self.input.append(x) 54 | x = self.fc1(x) 55 | self.output.append(x) 56 | 57 | x = Variable(x.data, requires_grad=True) 58 | self.input.append(x) 59 | x = self.relu(x) 60 | self.output.append(x) 61 | 62 | x = Variable(x.data, requires_grad=True) 63 | self.input.append(x) 64 | x = self.fc2(x) 65 | self.output.append(x) 66 | 67 | x = Variable(x.data, requires_grad=True) 68 | self.input.append(x) 69 | x = self.relu(x) 70 | self.output.append(x) 71 | 72 | x = Variable(x.data, requires_grad=True) 73 | self.input.append(x) 74 | x = self.fc3(x) 75 | self.output.append(x) 76 | 77 | x = Variable(x.data, requires_grad=True) 78 | self.input.append(x) 79 | x = self.sigmoid(x) 80 | self.output.append(x) 81 | return x 82 | @property 83 | def fetch_init_channel_index(self): 84 | return self._init_channel_index 85 | def backward_normal(self, g, communicator, req_send_check, cur_step): 86 | mod_avail_index = len(self.full_modules)-1 87 | #channel_index = len(self.full_modules)*2-2 88 | channel_index = self._init_channel_index - 2 89 | mod_counters_ = [0]*len(self.full_modules) 90 | for i, output in reversed(list(enumerate(self.output))): 91 | req_send_check[-1].wait() 92 | if i == (len(self.output) - 1): 93 | # for last node, use g 94 | output.backward(g) 95 | # get gradient here after some sanity checks: 96 | ''' 97 | tmp_grad = self.full_modules[mod_avail_index].weight.grad 98 | if not pd.isnull(tmp_grad): 99 | grads = tmp_grad.data.numpy().astype(np.float64) 100 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 101 | req_send_check.append(req_isend) 102 | # update counters 103 | mod_avail_index-=1 104 | channel_index-=1 105 | else: 106 | continue 107 | ''' 108 | else: 109 | output.backward(self.input[i+1].grad.data) 110 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 111 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 112 | # specific for this fc nn setting 113 | if mod_avail_index == len(self.full_modules)-1: 114 | if not pd.isnull(tmp_grad_weight): 115 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 116 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 117 | req_send_check.append(req_isend) 118 | # update counters 119 | mod_avail_index-=1 120 | channel_index-=1 121 | else: 122 | continue 123 | else: 124 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias): 125 | # we always send bias first 126 | if mod_counters_[mod_avail_index] == 0: 127 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 128 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 129 | req_send_check.append(req_isend) 130 | channel_index-=1 131 | mod_counters_[mod_avail_index]+=1 132 | elif mod_counters_[mod_avail_index] == 1: 133 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 134 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 135 | 
req_send_check.append(req_isend) 136 | channel_index-=1 137 | mod_counters_[mod_avail_index]+=1 138 | # update counters 139 | mod_avail_index-=1 140 | else: 141 | continue 142 | if mod_counters_[0] == 1: 143 | req_send_check[-1].wait() 144 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 145 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 146 | req_send_check.append(req_isend) 147 | # if cur_step >= 2: 148 | # exit() 149 | return req_send_check 150 | @property 151 | def name(self): 152 | return 'fc_nn' -------------------------------------------------------------------------------- /src/model_ops/lenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | import pandas as pd 6 | import numpy as np 7 | from torch.autograd import Variable 8 | 9 | from mpi4py import MPI 10 | 11 | # we use LeNet here for our simple case 12 | class LeNet(nn.Module): 13 | def __init__(self): 14 | super(LeNet, self).__init__() 15 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 16 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 17 | self.fc1 = nn.Linear(4*4*50, 500) 18 | self.fc2 = nn.Linear(500, 10) 19 | self.ceriation = nn.CrossEntropyLoss() 20 | self.full_modules = [self.conv1, self.conv2, self.fc1, self.fc2] 21 | 22 | def forward(self, x): 23 | x = self.conv1(x) 24 | x = F.max_pool2d(x, 2, 2) 25 | x = F.relu(x) 26 | x = self.conv2(x) 27 | x = F.max_pool2d(x, 2, 2) 28 | x = F.relu(x) 29 | x = x.view(-1, 4*4*50) 30 | x = self.fc1(x) 31 | x = self.fc2(x) 32 | #loss = self.ceriation(x, target) 33 | return x 34 | def name(self): 35 | return 'lenet' 36 | 37 | class LeNetSplit(nn.Module): 38 | ''' 39 | this is a module that we split the module and do backward process layer by layer 40 | please don't call this module for normal uses, this is a hack and run slower than 41 | the automatic chain rule version 42 | ''' 43 | def __init__(self): 44 | super(LeNetSplit, self).__init__() 45 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 46 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 47 | self.fc1 = nn.Linear(4*4*50, 500) 48 | self.fc2 = nn.Linear(500, 10) 49 | 50 | self.maxpool2d = nn.MaxPool2d(2, stride=2) 51 | self.relu = nn.ReLU() 52 | 53 | self.full_modules = [self.conv1, self.conv2, self.fc1, self.fc2] 54 | self._init_channel_index = len(self.full_modules)*2 55 | 56 | self.criterion = nn.CrossEntropyLoss() 57 | 58 | def forward(self, x): 59 | self.output = [] 60 | self.input = [] 61 | x = Variable(x.data, requires_grad=True) 62 | self.input.append(x) 63 | x = self.conv1(x) 64 | self.output.append(x) 65 | 66 | x = Variable(x.data, requires_grad=True) 67 | self.input.append(x) 68 | x = self.maxpool2d(x) 69 | self.output.append(x) 70 | 71 | x = Variable(x.data, requires_grad=True) 72 | self.input.append(x) 73 | x = self.relu(x) 74 | self.output.append(x) 75 | 76 | x = Variable(x.data, requires_grad=True) 77 | self.input.append(x) 78 | x = self.conv2(x) 79 | self.output.append(x) 80 | 81 | x = Variable(x.data, requires_grad=True) 82 | self.input.append(x) 83 | x = self.maxpool2d(x) 84 | self.output.append(x) 85 | 86 | x = Variable(x.data, requires_grad=True) 87 | self.input.append(x) 88 | x = self.relu(x) 89 | self.output.append(x) 90 | 91 | x = x.view(-1, 4*4*50) 92 | 93 | x = Variable(x.data, requires_grad=True) 94 | self.input.append(x) 95 | x = self.fc1(x) 96 | self.output.append(x) 97 | 98 | x = Variable(x.data, requires_grad=True) 99 | self.input.append(x) 100 | x = self.fc2(x) 101 | 
self.output.append(x) 102 | return x 103 | 104 | @property 105 | def fetch_init_channel_index(self): 106 | return self._init_channel_index 107 | 108 | def backward_normal(self, g, communicator, req_send_check, cur_step): 109 | mod_avail_index = len(self.full_modules)-1 110 | #channel_index = len(self.full_modules)*2-2 111 | channel_index = self._init_channel_index - 2 112 | mod_counters_ = [0]*len(self.full_modules) 113 | for i, output in reversed(list(enumerate(self.output))): 114 | req_send_check[-1].wait() 115 | if i == (len(self.output) - 1): 116 | # for last node, use g 117 | output.backward(g) 118 | # get gradient here after some sanity checks: 119 | tmp_grad = self.full_modules[mod_avail_index].weight.grad 120 | if not pd.isnull(tmp_grad): 121 | grads = tmp_grad.data.numpy().astype(np.float64) 122 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 123 | req_send_check.append(req_isend) 124 | # update counters 125 | mod_avail_index-=1 126 | channel_index-=1 127 | else: 128 | continue 129 | else: 130 | output.backward(self.input[i+1].grad.data) 131 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 132 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 133 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias): 134 | # we always send bias first 135 | if mod_counters_[mod_avail_index] == 0: 136 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 137 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 138 | req_send_check.append(req_isend) 139 | channel_index-=1 140 | mod_counters_[mod_avail_index]+=1 141 | elif mod_counters_[mod_avail_index] == 1: 142 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 143 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 144 | req_send_check.append(req_isend) 145 | channel_index-=1 146 | mod_counters_[mod_avail_index]+=1 147 | # update counters 148 | mod_avail_index-=1 149 | else: 150 | continue 151 | if mod_counters_[0] == 1: 152 | req_send_check[-1].wait() 153 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 154 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 155 | req_send_check.append(req_isend) 156 | return req_send_check 157 | 158 | def backward_signal_kill(self, g, communicator, req_send_check, cur_step): 159 | ''' 160 | This killer is triggered by signals bcasting from master, channel of 161 | signal is kept checking by each worker to determine if they're the 162 | straggler 163 | ''' 164 | mod_avail_index = len(self.full_modules)-1 165 | channel_index = self._init_channel_index - 2 166 | mod_counters_ = [0]*len(self.full_modules) 167 | 168 | # should kill flag 169 | should_kill = False 170 | 171 | for i, output in reversed(list(enumerate(self.output))): 172 | ############################ killing process on workers ##################################### 173 | for _ in range(10000): 174 | status = MPI.Status() 175 | communicator.Iprobe(0, 77, status) 176 | if status.Get_source() == 0: 177 | print("Worker {}, Cur Step: {} I'm the straggler, killing myself!".format(communicator.Get_rank(), cur_step)) 178 | tmp = communicator.recv(source=0, tag=77) 179 | should_kill = True 180 | break 181 | if should_kill: 182 | break 183 | ############################################################################################ 184 | 185 | if i == (len(self.output) - 1): 186 | # for last node, use g 187 | output.backward(g) 188 | # get gradient here after some sanity 
checks: 189 | tmp_grad = self.full_modules[mod_avail_index].weight.grad 190 | if not pd.isnull(tmp_grad): 191 | grads = tmp_grad.data.numpy().astype(np.float64) 192 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 193 | req_send_check.append(req_isend) 194 | # update counters 195 | mod_avail_index-=1 196 | channel_index-=1 197 | else: 198 | continue 199 | else: 200 | output.backward(self.input[i+1].grad.data) 201 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 202 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 203 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias): 204 | # we always send bias first 205 | if mod_counters_[mod_avail_index] == 0: 206 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 207 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 208 | req_send_check.append(req_isend) 209 | channel_index-=1 210 | mod_counters_[mod_avail_index]+=1 211 | elif mod_counters_[mod_avail_index] == 1: 212 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 213 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 214 | req_send_check.append(req_isend) 215 | channel_index-=1 216 | mod_counters_[mod_avail_index]+=1 217 | # update counters 218 | mod_avail_index-=1 219 | else: 220 | continue 221 | if mod_counters_[0] == 1: 222 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 223 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 224 | req_send_check.append(req_isend) 225 | return req_send_check 226 | 227 | def backward_timeout_kill(self, g, communicator, req_send_check): 228 | """do we even need this?""" 229 | pass -------------------------------------------------------------------------------- /src/model_ops/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from torch.autograd import Variable 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.full_modules = [] 20 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.full_modules.append(self.conv1) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.full_modules.append(self.bn1) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.full_modules.append(self.conv2) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.full_modules.append(self.bn2) 28 | 29 | self.shortcut = nn.Sequential() 30 | if stride != 1 or in_planes != self.expansion*planes: 31 | self.shortcut = nn.Sequential( 32 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(self.expansion*planes) 34 | ) 35 | self.full_modules.append(self.shortcut[0]) 36 | self.full_modules.append(self.shortcut[1]) 37 | 38 | def forward(self, x): 39 | out = F.relu(self.bn1(self.conv1(x))) 40 | out = self.bn2(self.conv2(out)) 41 | out += self.shortcut(x) 42 | out = F.relu(out) 43 | return out 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, in_planes, planes, stride=1): 50 | super(Bottleneck, self).__init__() 51 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 52 | self.bn1 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 57 | 58 | self.shortcut = nn.Sequential() 59 | if stride != 1 or in_planes != self.expansion*planes: 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 62 | nn.BatchNorm2d(self.expansion*planes) 63 | ) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = F.relu(self.bn2(self.conv2(out))) 68 | out = self.bn3(self.conv3(out)) 69 | out += self.shortcut(x) 70 | out = F.relu(out) 71 | return out 72 | 73 | 74 | class ResNet(nn.Module): 75 | def __init__(self, block, num_blocks, num_classes=10): 76 | super(ResNet, self).__init__() 77 | self.in_planes = 64 78 | self.full_modules = [] 79 | 80 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 81 | self.full_modules.append(self.conv1) 82 | self.bn1 = nn.BatchNorm2d(64) 83 | self.full_modules.append(self.bn1) 84 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 85 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 86 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 87 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 88 | self.linear = nn.Linear(512*block.expansion, num_classes) 89 | self.full_modules.append(self.linear) 90 | 91 | def _make_layer(self, block, planes, num_blocks, stride): 92 | strides = [stride] + [1]*(num_blocks-1) 93 | layers = [] 94 | for stride in strides: 95 | block_layers=block(self.in_planes, planes, stride) 96 | layers.append(block_layers) 97 | for m in block_layers.full_modules: 98 | self.full_modules.append(m) 99 | self.in_planes = planes * block.expansion 100 | 
return nn.Sequential(*layers) 101 | 102 | def forward(self, x): 103 | out = F.relu(self.bn1(self.conv1(x))) 104 | out = self.layer1(out) 105 | out = self.layer2(out) 106 | out = self.layer3(out) 107 | out = self.layer4(out) 108 | out = F.avg_pool2d(out, 4) 109 | out = out.view(out.size(0), -1) 110 | out = self.linear(out) 111 | return out 112 | 113 | 114 | def ResNet18(num_classes=10): 115 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes) 116 | 117 | def ResNet34(num_classes=10): 118 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes) 119 | 120 | def ResNet50(): 121 | return ResNet(Bottleneck, [3,4,6,3]) 122 | 123 | def ResNet101(): 124 | return ResNet(Bottleneck, [3,4,23,3]) 125 | 126 | def ResNet152(): 127 | return ResNet(Bottleneck, [3,8,36,3]) -------------------------------------------------------------------------------- /src/model_ops/resnet_split.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | 7 | 8 | Please note that this version is a hack; it is extremely hacky, so never call this one for normal use 9 | ''' 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from torch.autograd import Variable 15 | 16 | import pandas as pd 17 | import numpy as np 18 | import timeout_decorator 19 | 20 | from mpi4py import MPI 21 | 22 | LAYER_DIGITS= int(1e+3) 23 | TIMEOUT_THRESHOLD_=10 24 | 25 | def generate_tag(layer_tag, step_token): 26 | ''' 27 | Tag layout: [current-step-token (which helps to recognize stale gradients) 28 | + layer-tag] 29 | we only limit the digits for the layer tag here since the step token can be 30 | extremely large e.g. 
10k steps 31 | 32 | :param layer_tag 33 | :param step token 34 | :return: 35 | ''' 36 | tag = step_token * LAYER_DIGITS \ 37 | + layer_tag 38 | tag = int(tag) 39 | return tag 40 | 41 | class BasicBlockSplit(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, in_planes, planes, stride=1): 45 | super(BasicBlockSplit, self).__init__() 46 | self.full_modules = [] 47 | 48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.full_modules.append(self.conv1) 50 | 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.full_modules.append(self.bn1) 53 | 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 55 | self.full_modules.append(self.conv2) 56 | 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.full_modules.append(self.bn2) 59 | 60 | self.relu = nn.ReLU() 61 | 62 | self.shortcut = nn.Sequential() 63 | if stride != 1 or in_planes != self.expansion*planes: 64 | self.shortcut = nn.Sequential( 65 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 66 | nn.BatchNorm2d(self.expansion*planes) 67 | ) 68 | self.full_modules.append(self.shortcut[0]) 69 | self.full_modules.append(self.shortcut[1]) 70 | 71 | def forward(self, x, input_list, output_list): 72 | ''' 73 | the input_list and output_list here is similar to input/output in ResNet class 74 | ''' 75 | # we skip the detach and append operation on the very first x here 76 | # since that's done outside of this function 77 | out = self.conv1(x) 78 | output_list.append(out) 79 | 80 | out = Variable(out.data, requires_grad=True) 81 | input_list.append(out) 82 | out = self.bn1(out) 83 | output_list.append(out) 84 | 85 | out = Variable(out.data, requires_grad=True) 86 | input_list.append(out) 87 | out = self.relu(out) 88 | output_list.append(out) 89 | 90 | out = Variable(out.data, requires_grad=True) 91 | input_list.append(out) 92 | out = self.conv2(out) 93 | output_list.append(out) 94 | 95 | out = Variable(out.data, requires_grad=True) 96 | input_list.append(out) 97 | out = self.bn2(out) 98 | output_list.append(out) 99 | 100 | # TODO(hwang): figure out if this part also need hack 101 | out += self.shortcut(x) 102 | 103 | out = Variable(out.data, requires_grad=True) 104 | input_list.append(out) 105 | out = self.relu(out) 106 | output_list.append(out) 107 | return out, input_list, output_list 108 | 109 | 110 | class Bottleneck(nn.Module): 111 | expansion = 4 112 | 113 | def __init__(self, in_planes, planes, stride=1): 114 | super(Bottleneck, self).__init__() 115 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 116 | self.bn1 = nn.BatchNorm2d(planes) 117 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 118 | self.bn2 = nn.BatchNorm2d(planes) 119 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 120 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 121 | 122 | self.shortcut = nn.Sequential() 123 | if stride != 1 or in_planes != self.expansion*planes: 124 | self.shortcut = nn.Sequential( 125 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(self.expansion*planes) 127 | ) 128 | 129 | def forward(self, x): 130 | # we skip the detach operation on the very first x here since that's done outside of this function 131 | out = F.relu(self.bn1(self.conv1(x))) 132 | out = F.relu(self.bn2(self.conv2(out))) 133 | out = self.bn3(self.conv3(out)) 134 | out += self.shortcut(x) 135 | out 
139 | class ResNetSplit(nn.Module): 140 | def __init__(self, block, num_blocks, kill_threshold, num_classes=10): 141 | super(ResNetSplit, self).__init__() 142 | global TIMEOUT_THRESHOLD_ 143 | self.in_planes = 64 144 | self.full_modules = [] 145 | 146 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 147 | self.full_modules.append(self.conv1) 148 | 149 | self.bn1 = nn.BatchNorm2d(64) 150 | self.full_modules.append(self.bn1) 151 | 152 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 153 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 154 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 155 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 156 | 157 | self.linear = nn.Linear(512*block.expansion, num_classes) 158 | self.full_modules.append(self.linear) 159 | 160 | self.relu = nn.ReLU() 161 | self.avg_pool2d = nn.AvgPool2d(kernel_size=4) 162 | self._init_channel_index = self.count_channel_index() 163 | self._timeout_threshold = kill_threshold 164 | TIMEOUT_THRESHOLD_ = self._timeout_threshold 165 | self.killed_request_list = [] 166 | 167 | @property 168 | def fetch_init_channel_index(self): 169 | return self._init_channel_index 170 | 171 | def _make_layer(self, block, planes, num_blocks, stride): 172 | strides = [stride] + [1]*(num_blocks-1) 173 | layers = [] 174 | for stride in strides: 175 | block_layers = block(self.in_planes, planes, stride) 176 | layers.append(block_layers) 177 | for m in block_layers.full_modules: 178 | self.full_modules.append(m) 179 | 180 | self.in_planes = planes * block.expansion 181 | layers_split = nn.ModuleList(layers) 182 | 183 | return layers_split 184 | 185 | def forward(self, x): 186 | # use these containers to save intermediate variables 187 | self.output = [] 188 | self.input = [] 189 | 190 | # start the forward pass right here; apply the following logic to every intermediate variable: 191 | # detach from previous history 192 | x = Variable(x.data, requires_grad=True) 193 | self.input.append(x) 194 | x = self.conv1(x) 195 | # add to list of outputs 196 | self.output.append(x) 197 | 198 | x = Variable(x.data, requires_grad=True) 199 | self.input.append(x) 200 | x = self.bn1(x) 201 | self.output.append(x) 202 | 203 | x = Variable(x.data, requires_grad=True) 204 | self.input.append(x) 205 | x = self.relu(x) 206 | self.output.append(x) 207 | 208 | # start to handle blocks 209 | for layer in self.layer1: 210 | # each `layer` here is either a `BasicBlockSplit` or `BottleneckSplit` 211 | x = Variable(x.data, requires_grad=True) 212 | self.input.append(x) 213 | # call the `.forward()` func in `BasicBlockSplit` or `BottleneckSplit` here 214 | x, self.input, self.output = layer(x, self.input, self.output) 215 | 216 | for layer in self.layer2: 217 | # each `layer` here is either a `BasicBlockSplit` or `BottleneckSplit` 218 | x = Variable(x.data, requires_grad=True) 219 | self.input.append(x) 220 | # call the `.forward()` func in `BasicBlockSplit` or `BottleneckSplit` here 221 | x, self.input, self.output = layer(x, self.input, self.output) 222 | 223 | for layer in self.layer3: 224 | # each `layer` here is either a `BasicBlockSplit` or `BottleneckSplit` 225 | x = Variable(x.data, requires_grad=True) 226 | self.input.append(x) 227 | # call the `.forward()` func in `BasicBlockSplit` or `BottleneckSplit` here 228 | x, self.input, self.output = layer(x, self.input, self.output) 229 | 230 | for layer in self.layer4: 231 
| # each `layer` here is either a `BasicBlockSplit` or `BottleneckSplit` 232 | x = Variable(x.data, requires_grad=True) 233 | self.input.append(x) 234 | # call the `.forward()` func in `BasicBlockSplit` or `BottleneckSplit` here 235 | x, self.input, self.output = layer(x, self.input, self.output) 236 | 237 | x = Variable(x.data, requires_grad=True) 238 | self.input.append(x) 239 | x = self.avg_pool2d(x) 240 | self.output.append(x) 241 | 242 | x = x.view(x.size(0), -1) 243 | 244 | x = Variable(x.data, requires_grad=True) 245 | self.input.append(x) 246 | x = self.linear(x) 247 | self.output.append(x) 248 | return x 249 | 250 | def count_channel_index(self): 251 | channel_index_ = 0 252 | for k, v in self.state_dict().items(): 253 | if "running_mean" in k or "running_var" in k: 254 | continue 255 | else: 256 | channel_index_ += 1 257 | return channel_index_ 258 | 259 | def backward(self, g, communicator, req_send_check, cur_step): 260 | mod_avail_index = len(self.full_modules)-1 261 | channel_index = self._init_channel_index-2 262 | mod_counters_ = [0]*len(self.full_modules) 263 | for i, output in reversed(list(enumerate(self.output))): 264 | # send layer only after the last layer is received 265 | req_send_check[-1].wait() 266 | if i == (len(self.output) - 1): 267 | # for last node, use g 268 | output.backward(g) 269 | # get gradient here after some sanity checks: 270 | tmp_grad = self.full_modules[mod_avail_index].weight.grad 271 | if not pd.isnull(tmp_grad): 272 | grads = tmp_grad.data.numpy().astype(np.float64) 273 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 274 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 275 | req_send_check.append(req_isend) 276 | # update counters 277 | mod_avail_index-=1 278 | channel_index-=1 279 | else: 280 | continue 281 | else: 282 | if output.size() == self.input[i+1].grad.size(): 283 | output.backward(self.input[i+1].grad.data) 284 | else: 285 | tmp_grad_output = self.input[i+1].grad.view(output.size()) 286 | output.backward(tmp_grad_output) 287 | 288 | # since in resnet we do not use bias weight for conv layer 289 | if pd.isnull(self.full_modules[mod_avail_index].bias): 290 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 291 | 292 | if not pd.isnull(tmp_grad_weight): 293 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 294 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 295 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 296 | req_send_check.append(req_isend) 297 | channel_index-=1 298 | mod_counters_[mod_avail_index]=2 299 | # update counters 300 | mod_avail_index-=1 301 | else: 302 | continue 303 | else: 304 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 305 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 306 | 307 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias): 308 | # we always send bias first 309 | if mod_counters_[mod_avail_index] == 0: 310 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 311 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 312 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 313 | req_send_check.append(req_isend) 314 | channel_index-=1 315 | mod_counters_[mod_avail_index]+=1 316 | elif 
mod_counters_[mod_avail_index] == 1: 317 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 318 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 319 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 320 | req_send_check.append(req_isend) 321 | channel_index-=1 322 | mod_counters_[mod_avail_index]+=1 323 | # update counters 324 | mod_avail_index-=1 325 | else: 326 | continue 327 | # handle the remaining gradients here to send to parameter server 328 | while channel_index >= 0: 329 | req_send_check[-1].wait() 330 | if pd.isnull(self.full_modules[mod_avail_index].bias): 331 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 332 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 333 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 334 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 335 | req_send_check.append(req_isend) 336 | channel_index-=1 337 | mod_counters_[mod_avail_index]=2 338 | # update counters 339 | mod_avail_index-=1 340 | else: 341 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 342 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 343 | # we always send bias first 344 | if mod_counters_[mod_avail_index] == 0: 345 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 346 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 347 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 348 | req_send_check.append(req_isend) 349 | channel_index-=1 350 | mod_counters_[mod_avail_index]+=1 351 | elif mod_counters_[mod_avail_index] == 1: 352 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 353 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 354 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 355 | req_send_check.append(req_isend) 356 | channel_index-=1 357 | mod_counters_[mod_avail_index]+=1 358 | # update counters 359 | mod_avail_index-=1 360 | return req_send_check 361 | 362 | def backward_normal(self, g, communicator, req_send_check, cur_step): 363 | mod_avail_index = len(self.full_modules)-1 364 | channel_index = self._init_channel_index-2 365 | mod_counters_ = [0]*len(self.full_modules) 366 | for i, output in reversed(list(enumerate(self.output))): 367 | # send layer only after the last layer is received 368 | req_send_check[-1].wait() 369 | if i == (len(self.output) - 1): 370 | # for last node, use g 371 | output.backward(g) 372 | # get gradient here after some sanity checks: 373 | tmp_grad = self.full_modules[mod_avail_index].weight.grad 374 | if not pd.isnull(tmp_grad): 375 | grads = tmp_grad.data.numpy().astype(np.float64) 376 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 377 | req_send_check.append(req_isend) 378 | # update counters 379 | mod_avail_index-=1 380 | channel_index-=1 381 | else: 382 | continue 383 | else: 384 | if output.size() == self.input[i+1].grad.size(): 385 | output.backward(self.input[i+1].grad.data) 386 | else: 387 | tmp_grad_output = self.input[i+1].grad.view(output.size()) 388 | output.backward(tmp_grad_output) 389 | 390 | # since in resnet we do not use bias weight for conv layer 391 | if 
pd.isnull(self.full_modules[mod_avail_index].bias): 392 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 393 | 394 | if not pd.isnull(tmp_grad_weight): 395 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 396 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 397 | req_send_check.append(req_isend) 398 | channel_index-=1 399 | mod_counters_[mod_avail_index]=2 400 | # update counters 401 | mod_avail_index-=1 402 | else: 403 | continue 404 | else: 405 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 406 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 407 | 408 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias): 409 | # we always send bias first 410 | if mod_counters_[mod_avail_index] == 0: 411 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 412 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 413 | req_send_check.append(req_isend) 414 | channel_index-=1 415 | mod_counters_[mod_avail_index]+=1 416 | elif mod_counters_[mod_avail_index] == 1: 417 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 418 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 419 | req_send_check.append(req_isend) 420 | channel_index-=1 421 | mod_counters_[mod_avail_index]+=1 422 | # update counters 423 | mod_avail_index-=1 424 | else: 425 | continue 426 | # handle the remaining gradients here to send to parameter server 427 | while channel_index >= 0: 428 | req_send_check[-1].wait() 429 | if pd.isnull(self.full_modules[mod_avail_index].bias): 430 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 431 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 432 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 433 | req_send_check.append(req_isend) 434 | channel_index-=1 435 | mod_counters_[mod_avail_index]=2 436 | # update counters 437 | mod_avail_index-=1 438 | else: 439 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 440 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 441 | # we always send bias first 442 | if mod_counters_[mod_avail_index] == 0: 443 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 444 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 445 | req_send_check.append(req_isend) 446 | channel_index-=1 447 | mod_counters_[mod_avail_index]+=1 448 | elif mod_counters_[mod_avail_index] == 1: 449 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 450 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 451 | req_send_check.append(req_isend) 452 | channel_index-=1 453 | mod_counters_[mod_avail_index]+=1 454 | # update counters 455 | mod_avail_index-=1 456 | return req_send_check 457 | 458 | def backward_signal_kill(self, g, communicator, req_send_check, cur_step): 459 | mod_avail_index = len(self.full_modules)-1 460 | channel_index = self._init_channel_index-2 461 | mod_counters_ = [0]*len(self.full_modules) 462 | 463 | # should kill flag 464 | should_kill = False 465 | 466 | for i, output in reversed(list(enumerate(self.output))): 467 | ############################ killing process on workers ##################################### 468 | for _ in range(100): 469 | status = MPI.Status() 470 | communicator.Iprobe(0, 77, status) 471 | if status.Get_source() == 0: 472 | print("Worker {}, Cur Step: {} I'm the straggler, killing myself!".format(communicator.Get_rank(), 
cur_step)) 473 | tmp = communicator.recv(source=0, tag=77) 474 | should_kill = True 475 | break 476 | if should_kill: 477 | channel_index=-5 478 | break 479 | ############################################################################################ 480 | if i == (len(self.output) - 1): 481 | # for last node, use g 482 | output.backward(g) 483 | # get gradient here after some sanity checks: 484 | tmp_grad = self.full_modules[mod_avail_index].weight.grad 485 | if not pd.isnull(tmp_grad): 486 | grads = tmp_grad.data.numpy().astype(np.float64) 487 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 488 | req_send_check.append(req_isend) 489 | # update counters 490 | mod_avail_index-=1 491 | channel_index-=1 492 | else: 493 | continue 494 | else: 495 | if output.size() == self.input[i+1].grad.size(): 496 | output.backward(self.input[i+1].grad.data) 497 | else: 498 | tmp_grad_output = self.input[i+1].grad.view(output.size()) 499 | output.backward(tmp_grad_output) 500 | 501 | # since in resnet we do not use bias weight for conv layer 502 | if pd.isnull(self.full_modules[mod_avail_index].bias): 503 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 504 | 505 | if not pd.isnull(tmp_grad_weight): 506 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 507 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 508 | req_send_check.append(req_isend) 509 | channel_index-=1 510 | mod_counters_[mod_avail_index]=2 511 | # update counters 512 | mod_avail_index-=1 513 | else: 514 | continue 515 | else: 516 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 517 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 518 | 519 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias): 520 | # we always send bias first 521 | if mod_counters_[mod_avail_index] == 0: 522 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 523 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 524 | req_send_check.append(req_isend) 525 | channel_index-=1 526 | mod_counters_[mod_avail_index]+=1 527 | elif mod_counters_[mod_avail_index] == 1: 528 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 529 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 530 | req_send_check.append(req_isend) 531 | channel_index-=1 532 | mod_counters_[mod_avail_index]+=1 533 | # update counters 534 | mod_avail_index-=1 535 | else: 536 | continue 537 | # handle the remaining gradients here to send to parameter server 538 | while channel_index >= 0: 539 | if pd.isnull(self.full_modules[mod_avail_index].bias): 540 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 541 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 542 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 543 | req_send_check.append(req_isend) 544 | channel_index-=1 545 | mod_counters_[mod_avail_index]=2 546 | # update counters 547 | mod_avail_index-=1 548 | else: 549 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 550 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 551 | # we always send bias first 552 | if mod_counters_[mod_avail_index] == 0: 553 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 554 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 555 | req_send_check.append(req_isend) 556 | channel_index-=1 557 | mod_counters_[mod_avail_index]+=1 558 | elif 
mod_counters_[mod_avail_index] == 1: 559 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 560 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 561 | req_send_check.append(req_isend) 562 | channel_index-=1 563 | mod_counters_[mod_avail_index]+=1 564 | # update counters 565 | mod_avail_index-=1 566 | # normal completion leaves channel_index at -1; an early kill sets it to -5 567 | killed = (channel_index == -5) 568 | 569 | 570 | return req_send_check, killed 571 | 572 | @timeout_decorator.timeout(10.5, timeout_exception=StopIteration) 573 | def backward_timeout_kill(self, g, communicator, req_send_check, cur_step): 574 | mod_avail_index = len(self.full_modules)-1 575 | channel_index = self._init_channel_index-2 576 | mod_counters_ = [0]*len(self.full_modules) 577 | 578 | # reset the request list of killed workers 579 | self.killed_request_list = [] 580 | for i, output in reversed(list(enumerate(self.output))): 581 | # send layer only after the last layer is received 582 | req_send_check[-1].wait() 583 | if i == (len(self.output) - 1): 584 | # for last node, use g 585 | output.backward(g) 586 | # get gradient here after some sanity checks: 587 | tmp_grad = self.full_modules[mod_avail_index].weight.grad 588 | if not pd.isnull(tmp_grad): 589 | grads = tmp_grad.data.numpy().astype(np.float64) 590 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 591 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 592 | req_send_check.append(req_isend) 593 | #self.killed_request_list.append(req_isend) 594 | # update counters 595 | mod_avail_index-=1 596 | channel_index-=1 597 | else: 598 | continue 599 | else: 600 | if output.size() == self.input[i+1].grad.size(): 601 | output.backward(self.input[i+1].grad.data) 602 | else: 603 | tmp_grad_output = self.input[i+1].grad.view(output.size()) 604 | output.backward(tmp_grad_output) 605 | 606 | # since in ResNet we do not use a bias term for conv layers 607 | if pd.isnull(self.full_modules[mod_avail_index].bias): 608 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 609 | 610 | if not pd.isnull(tmp_grad_weight): 611 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 612 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 613 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 614 | req_send_check.append(req_isend) 615 | #self.killed_request_list.append(req_isend) 616 | channel_index-=1 617 | mod_counters_[mod_avail_index]=2 618 | # update counters 619 | mod_avail_index-=1 620 | else: 621 | continue 622 | else: 623 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 624 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 625 | 626 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias): 627 | # we always send bias first 628 | if mod_counters_[mod_avail_index] == 0: 629 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 630 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 631 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 632 | req_send_check.append(req_isend) 633 | #self.killed_request_list.append(req_isend) 634 | channel_index-=1 635 | mod_counters_[mod_avail_index]+=1 636 | elif mod_counters_[mod_avail_index] == 1: 637 | grads = 
tmp_grad_weight.data.numpy().astype(np.float64) 638 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 639 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 640 | req_send_check.append(req_isend) 641 | #self.killed_request_list.append(req_isend) 642 | channel_index-=1 643 | mod_counters_[mod_avail_index]+=1 644 | # update counters 645 | mod_avail_index-=1 646 | else: 647 | continue 648 | # handle the remaining gradients here to send to parameter server 649 | while channel_index >= 0: 650 | req_send_check[-1].wait() 651 | if pd.isnull(self.full_modules[mod_avail_index].bias): 652 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 653 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 654 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 655 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 656 | req_send_check.append(req_isend) 657 | #self.killed_request_list.append(req_isend) 658 | channel_index-=1 659 | mod_counters_[mod_avail_index]=2 660 | # update counters 661 | mod_avail_index-=1 662 | else: 663 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 664 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 665 | # we always send bias first 666 | if mod_counters_[mod_avail_index] == 0: 667 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 668 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 669 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 670 | req_send_check.append(req_isend) 671 | #self.killed_request_list.append(req_isend) 672 | channel_index-=1 673 | mod_counters_[mod_avail_index]+=1 674 | elif mod_counters_[mod_avail_index] == 1: 675 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 676 | #req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 677 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=generate_tag(layer_tag=88+channel_index, step_token=cur_step)) 678 | req_send_check.append(req_isend) 679 | #self.killed_request_list.append(req_isend) 680 | channel_index-=1 681 | mod_counters_[mod_avail_index]+=1 682 | # update counters 683 | mod_avail_index-=1 684 | return req_send_check 685 | 686 | def backward_single(self, g): 687 | for i, output in reversed(list(enumerate(self.output))): 688 | #print("Backward processing, step {}".format(i)) 689 | #print("--------------------------------------------------------") 690 | if i == (len(self.output) - 1): 691 | # for last node, use g 692 | output.backward(g) 693 | else: 694 | 695 | #print(output.size()) 696 | #print(self.input[i+1].grad.size()) 697 | #tmp = self.input[i+1].grad.view(output.size()) 698 | #print(tmp.size()) 699 | #print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") 700 | 701 | if output.size() == self.input[i+1].grad.size(): 702 | output.backward(self.input[i+1].grad.data) 703 | else: 704 | tmp_grad_output = self.input[i+1].grad.view(output.size()) 705 | output.backward(tmp_grad_output) 706 | 707 | def ResNetSplit18(kill_threshold): 708 | return ResNetSplit(BasicBlockSplit, [2,2,2,2], kill_threshold=kill_threshold) 709 | 710 | def ResNetSplit34(): 711 | return ResNetSplit(BasicBlockSplit, [3,4,6,3]) 712 | 713 | def ResNetSplit50(): 714 | return 
ResNetSplit(Bottleneck, [3,4,6,3]) 715 | 716 | def ResNetSplit101(): 717 | return ResNetSplit(Bottleneck, [3,4,23,3]) 718 | 719 | def ResNetSplit152(): 720 | return ResNetSplit(Bottleneck, [3,8,36,3]) 721 | 722 | if __name__ == "__main__": 723 | a = ResNetSplit18(1) 724 | print("Done!") -------------------------------------------------------------------------------- /src/model_ops/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | class VGG(nn.Module): 16 | ''' 17 | VGG model 18 | ''' 19 | def __init__(self, features, num_classes=10): 20 | super(VGG, self).__init__() 21 | self.features = features 22 | self.classifier = nn.Sequential( 23 | nn.Dropout(), 24 | nn.Linear(512, 512), 25 | nn.ReLU(True), 26 | nn.Dropout(), 27 | nn.Linear(512, 512), 28 | nn.ReLU(True), 29 | nn.Linear(512, num_classes), 30 | ) 31 | # Initialize weights 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 35 | m.weight.data.normal_(0, math.sqrt(2. / n)) 36 | m.bias.data.zero_() 37 | 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | x = x.view(x.size(0), -1) 42 | x = self.classifier(x) 43 | return x 44 | 45 | 46 | def make_layers(cfg, batch_norm=False): 47 | layers = [] 48 | in_channels = 3 49 | for v in cfg: 50 | if v == 'M': 51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 52 | else: 53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 54 | if batch_norm: 55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 56 | else: 57 | layers += [conv2d, nn.ReLU(inplace=True)] 58 | in_channels = v 59 | return nn.Sequential(*layers) 60 | 61 | 62 | cfg = { 63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 67 | 512, 512, 512, 512, 'M'], 68 | } 69 | 70 | 71 | def vgg11(): 72 | """VGG 11-layer model (configuration "A")""" 73 | return VGG(make_layers(cfg['A'])) 74 | 75 | 76 | def vgg11_bn(num_classes=10): 77 | """VGG 11-layer model (configuration "A") with batch normalization""" 78 | return VGG(make_layers(cfg['A'], batch_norm=True), num_classes=num_classes) 79 | 80 | 81 | def vgg13(): 82 | """VGG 13-layer model (configuration "B")""" 83 | return VGG(make_layers(cfg['B'])) 84 | 85 | 86 | def vgg13_bn(): 87 | """VGG 13-layer model (configuration "B") with batch normalization""" 88 | return VGG(make_layers(cfg['B'], batch_norm=True)) 89 | 90 | 91 | def vgg16(): 92 | """VGG 16-layer model (configuration "D")""" 93 | return VGG(make_layers(cfg['D'])) 94 | 95 | 96 | def vgg16_bn(): 97 | """VGG 16-layer model (configuration "D") with batch normalization""" 98 | return VGG(make_layers(cfg['D'], batch_norm=True)) 99 | 100 | 101 | def vgg19(): 102 | """VGG 19-layer model (configuration "E")""" 103 | return VGG(make_layers(cfg['E'])) 104 | 105 | 106 | def vgg19_bn(): 107 | """VGG 19-layer model (configuration 'E') with batch normalization""" 108 | return VGG(make_layers(cfg['E'], batch_norm=True)) 
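# A hedged usage sketch for the factory functions above (illustration only;
# it assumes CIFAR-10-shaped 3x32x32 inputs, which is what the hard-coded
# 512-dim classifier and the num_classes=10 default expect):
def _vgg_smoke_test():
    import torch
    from torch.autograd import Variable

    net = vgg11_bn(num_classes=10)
    x = Variable(torch.randn(2, 3, 32, 32))  # a batch of two CIFAR-sized images
    logits = net(x)
    print(logits.size())  # expected: torch.Size([2, 10])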
-------------------------------------------------------------------------------- /src/nn_ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | from model_ops.lenet import LeNet 10 | from model_ops.resnet import * 11 | from model_ops.resnet_split import * 12 | from model_ops.alexnet import * 13 | 14 | import numpy.linalg as LA 15 | 16 | 17 | def nuclear_indicator(grad, s): 18 | m, n = grad.shape 19 | return np.sum(s)*np.sqrt(m+n) 20 | 21 | 22 | def l1_indicator(grad): 23 | return np.linalg.norm(grad.reshape(-1), 1) 24 | 25 | 26 | def _resize_to_2d(x): 27 | """ 28 | reshape x into a 2-D matrix when x.ndim != 2 29 | If x.shape = (a, b, *c), it is assumed that each of the (a, b) pairs has its relevant information in c. 30 | """ 31 | shape = x.shape 32 | if x.ndim == 1: 33 | n = x.shape[0] 34 | return x.reshape((n//2, 2)) 35 | if all([s == 1 for s in shape[2:]]): 36 | return x.reshape((shape[0], shape[1])) 37 | # each of (a, b) has related features 38 | x = x.reshape((shape[0], shape[1], -1)) 39 | # stack those related features into a tall matrix 40 | x_tmp = x.reshape((shape[0]*shape[1], -1)) 41 | tmp_shape = x_tmp.shape 42 | return x_tmp.reshape((int(tmp_shape[0]/2), int(tmp_shape[1]*2))) 43 | 44 | 45 | def _sample_svd(s, rank=0): 46 | if s[0] < 1e-6: 47 | return [0], np.array([1.0]) 48 | probs = s / s[0] if rank == 0 else rank * s / s.sum() 49 | for i, p in enumerate(probs): 50 | if p > 1: 51 | probs[i]=1 52 | sampled_idx = [] 53 | sample_probs = [] 54 | for i, p in enumerate(probs): 55 | #if np.random.rand() < p: 56 | # random sampling from a Bernoulli distribution 57 | if np.random.binomial(1, p): 58 | sampled_idx += [i] 59 | sample_probs += [p] 60 | rank_hat = len(sampled_idx) 61 | if rank_hat == 0: # or (rank != 0 and np.abs(rank_hat - rank) >= 3): 62 | return _sample_svd(s, rank=rank) 63 | return np.array(sampled_idx, dtype=int), np.array(sample_probs) 64 | 65 | 66 | def svd_encode(grad, **kwargs): 67 | orig_size = list(grad.shape) 68 | ndims = grad.ndim 69 | reshaped_flag = False 70 | if ndims != 2: 71 | grad = _resize_to_2d(grad) 72 | shape = list(grad.shape) 73 | ndims = len(shape) 74 | reshaped_flag = True 75 | 76 | if ndims == 2: 77 | u, s, vT = LA.svd(grad, full_matrices=False) 78 | 79 | nuclear_ind = nuclear_indicator(grad, s) 80 | l1_ind = l1_indicator(grad) 81 | print("Step: {}, Nuclear Indicator: {}, L1 Indicator: {}".format( 82 | kwargs['step'], nuclear_ind, l1_ind)) 83 | 84 |
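# A hedged illustration (not called anywhere in the training path) of how
# `_sample_svd` is meant to be used: each singular triple i is kept with
# probability p_i, and rescaling the kept ones by 1/p_i keeps the sparsified
# reconstruction an unbiased estimate of the original matrix:
def _sampled_svd_reconstruction(grad, rank=3):
    u, s, vT = LA.svd(grad, full_matrices=False)
    sampled_idx, sample_probs = _sample_svd(s, rank=rank)
    # E[ sum_i 1{i sampled} * (s_i / p_i) * u_i v_i^T ] = sum_i s_i u_i v_i^T
    return (u[:, sampled_idx] * (s[sampled_idx] / sample_probs)).dot(vT[sampled_idx, :])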
85 | '''this is a trial example; we use MNIST on LeNet as a simple test here''' 86 | def accuracy(output, target, topk=(1,)): 87 | """Computes the precision@k for the specified values of k""" 88 | maxk = max(topk) 89 | batch_size = target.size(0) 90 | 91 | _, pred = output.topk(maxk, 1, True, True) 92 | pred = pred.t() 93 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 94 | 95 | res = [] 96 | for k in topk: 97 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 98 | res.append(correct_k.mul_(100.0 / batch_size)) 99 | return res 100 | 101 | class NN_Trainer(object): 102 | def __init__(self, **kwargs): 103 | self.batch_size = kwargs['batch_size'] 104 | self.lr = kwargs['learning_rate'] 105 | self.max_epochs = kwargs['max_epochs'] 106 | self.momentum = kwargs['momentum'] 107 | self.network_config = kwargs['network'] 108 | 109 | def build_model(self): 110 | # build network 111 | if self.network_config == "LeNet": 112 | self.network=LeNet() 113 | elif self.network_config == "ResNet18": 114 | self.network=ResNet18() 115 | elif self.network_config == "ResNet34": 116 | self.network=ResNet34() 117 | elif self.network_config == "AlexNet": 118 | self.network=alexnet() 119 | # set up optimizer 120 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum) 121 | self.criterion = torch.nn.CrossEntropyLoss() 122 | 123 | def train_and_validate(self, train_loader, test_loader): 124 | # iterate over epochs 125 | for i in range(self.max_epochs): 126 | # change back to training mode 127 | self.network.train() 128 | for batch_idx, (data, y_batch) in enumerate(train_loader): 129 | iter_start_time = time.time() 130 | data, target = Variable(data), Variable(y_batch) 131 | self.optimizer.zero_grad() 132 | ################# backward on normal model ############################ 133 | 134 | logits = self.network(data) 135 | loss = self.criterion(logits, target) 136 | loss.backward() 137 | ####################################################################### 138 | 139 | ################ backward on split model ############################## 140 | ''' 141 | logits = self.network(data) 142 | logits_1 = Variable(logits.data, requires_grad=True) 143 | loss = self.criterion(logits_1, target) 144 | loss.backward() 145 | self.network.backward_single(logits_1.grad) 146 | ''' 147 | ####################################################################### 148 | tmp_time_0 = time.time() 149 | 150 | #for param in self.network.parameters(): 151 | # grads = param.grad.data.numpy().astype(np.float64) 152 | # svd_encode(grads, step=batch_idx) 153 | 154 | duration_backward = time.time()-tmp_time_0 155 | 156 | tmp_time_1 = time.time() 157 | self.optimizer.step() 158 | duration_update = time.time()-tmp_time_1 159 | 160 | # calculate training accuracy 161 | prec1, prec5 = accuracy(logits.data, y_batch, topk=(1, 5)) 162 | # load the training info 163 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Prec@1: {} Prec@5: {} Time Cost: {}'.format( 164 | i, batch_idx * len(data), len(train_loader.dataset), 165 | 100. * batch_idx / len(train_loader), loss.data[0], 166 | prec1.numpy()[0], 167 | prec5.numpy()[0], time.time()-iter_start_time)) 168 | # we evaluate the model performance at the end of each epoch 169 | self.validate(test_loader) 170 | 171 | def validate(self, test_loader): 172 | self.network.eval() 173 | test_loss = 0 174 | correct = 0 175 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0 176 | for data, y_batch in test_loader: 177 | data, target = Variable(data, volatile=True), Variable(y_batch) 178 | output = self.network(data) 179 | test_loss += F.cross_entropy(output, target, size_average=False).data[0] # sum up batch loss (the network emits raw logits, so cross_entropy rather than nll_loss) 180 | #pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 181 | #correct += pred.eq(target.data.view_as(pred)).cpu().sum() 182 | prec1_tmp, prec5_tmp = accuracy(output.data, y_batch, topk=(1, 5)) 183 | prec1_counter_ += prec1_tmp.numpy()[0] 184 | prec5_counter_ += prec5_tmp.numpy()[0] 185 | batch_counter_ += 1 186 | prec1 = prec1_counter_ / batch_counter_ 187 | prec5 = prec5_counter_ / batch_counter_ 188 | test_loss /= len(test_loader.dataset) 189 | print('Test set: Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(test_loss, prec1, prec5)) 190 |
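# A hedged usage sketch for NN_Trainer (illustration only; `train_loader` and
# `test_loader` are assumed to be standard torchvision MNIST loaders, built
# the same way as in single_machine.py later in this repository):
def _example_trainer_run(train_loader, test_loader):
    trainer = NN_Trainer(batch_size=128, learning_rate=0.01,
                         max_epochs=10, momentum=0.9, network="LeNet")
    trainer.build_model()
    trainer.train_and_validate(train_loader, test_loader)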
-------------------------------------------------------------------------------- /src/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from . import adam, sgd 2 | 3 | __all__ = ['adam', 'sgd'] -------------------------------------------------------------------------------- /src/optim/adam.py: -------------------------------------------------------------------------------- 1 | ''' 2 | modified version of Adam optimizer 3 | by Hongyi Wang 4 | ''' 5 | 6 | import math 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | 11 | class Adam(Optimizer): 12 | """Implements Adam algorithm. 13 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 14 | Arguments: 15 | params (iterable): iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr (float, optional): learning rate (default: 1e-3) 18 | betas (Tuple[float, float], optional): coefficients used for computing 19 | running averages of gradient and its square (default: (0.9, 0.999)) 20 | eps (float, optional): term added to the denominator to improve 21 | numerical stability (default: 1e-8) 22 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 23 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 24 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 25 | .. _Adam\: A Method for Stochastic Optimization: 26 | https://arxiv.org/abs/1412.6980 27 | .. _On the Convergence of Adam and Beyond: 28 | https://openreview.net/forum?id=ryQu7f-RZ 29 | """ 30 | 31 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 32 | weight_decay=0, amsgrad=False): 33 | defaults = dict(lr=lr, betas=betas, eps=eps, 34 | weight_decay=weight_decay, amsgrad=amsgrad) 35 | super(Adam, self).__init__(params, defaults) 36 | 37 | def step(self, grads, closure=None): 38 | """Performs a single optimization step. 39 | Arguments: 40 | closure (callable, optional): A closure that reevaluates the model 41 | and returns the loss. 42 | """ 43 | loss = None 44 | if closure is not None: 45 | loss = closure() 46 | 47 | for group in self.param_groups: 48 | for i,p in enumerate(group['params']): 49 | grad = torch.from_numpy(grads[i]).float() 50 | if grad.is_sparse: 51 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 52 | amsgrad = group['amsgrad'] 53 | 54 | state = self.state[p] 55 | 56 | # State initialization 57 | if len(state) == 0: 58 | state['step'] = 0 59 | # Exponential moving average of gradient values 60 | state['exp_avg'] = torch.zeros_like(p.data) 61 | # Exponential moving average of squared gradient values 62 | state['exp_avg_sq'] = torch.zeros_like(p.data) 63 | if amsgrad: 64 | # Maintains max of all exp. moving avg. of sq. grad. values 65 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 66 | 67 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 68 | if amsgrad: 69 | max_exp_avg_sq = state['max_exp_avg_sq'] 70 | beta1, beta2 = group['betas'] 71 | 72 | state['step'] += 1 73 | 74 | if group['weight_decay'] != 0: 75 | grad = grad.add(group['weight_decay'], p.data) 76 | 77 | # Decay the first and second moment running average coefficient 78 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 79 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 80 | if amsgrad: 81 | # Maintains the maximum of all 2nd moment running avg. till now 82 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 83 | # Use the max. for normalizing running avg. 
of gradient 84 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 85 | else: 86 | denom = exp_avg_sq.sqrt().add_(group['eps']) 87 | 88 | bias_correction1 = 1 - beta1 ** state['step'] 89 | bias_correction2 = 1 - beta2 ** state['step'] 90 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 91 | 92 | p.data.addcdiv_(-step_size, exp_avg, denom) 93 | 94 | return loss -------------------------------------------------------------------------------- /src/optim/sgd.py: -------------------------------------------------------------------------------- 1 | ''' 2 | modified version of SGD optimizer 3 | by Hongyi Wang 4 | ''' 5 | import torch 6 | from torch.optim import Optimizer 7 | 8 | 9 | class SGD(Optimizer): 10 | r"""Implements stochastic gradient descent (optionally with momentum). 11 | Nesterov momentum is based on the formula from 12 | `On the importance of initialization and momentum in deep learning`__. 13 | Args: 14 | params (iterable): iterable of parameters to optimize or dicts defining 15 | parameter groups 16 | lr (float): learning rate 17 | momentum (float, optional): momentum factor (default: 0) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | dampening (float, optional): dampening for momentum (default: 0) 20 | nesterov (bool, optional): enables Nesterov momentum (default: False) 21 | Example: 22 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 23 | >>> optimizer.zero_grad() 24 | >>> loss_fn(model(input), target).backward() 25 | >>> optimizer.step() 26 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 27 | .. note:: 28 | The implementation of SGD with Momentum/Nesterov subtly differs from 29 | Sutskever et. al. and implementations in some other frameworks. 30 | Considering the specific case of Momentum, the update can be written as 31 | .. math:: 32 | v = \rho * v + g \\ 33 | p = p - lr * v 34 | where p, g, v and :math:`\rho` denote the parameters, gradient, 35 | velocity, and momentum respectively. 36 | This is in contrast to Sutskever et. al. and 37 | other frameworks which employ an update of the form 38 | .. math:: 39 | v = \rho * v + lr * g \\ 40 | p = p - v 41 | The Nesterov version is analogously modified. 42 | """ 43 | 44 | def __init__(self, params, lr=0.1, momentum=0, dampening=0, 45 | weight_decay=0, nesterov=False): 46 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 47 | weight_decay=weight_decay, nesterov=nesterov) 48 | if nesterov and (momentum <= 0 or dampening != 0): 49 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 50 | super(SGD, self).__init__(params, defaults) 51 | 52 | def __setstate__(self, state): 53 | super(SGD, self).__setstate__(state) 54 | for group in self.param_groups: 55 | group.setdefault('nesterov', False) 56 | 57 | def step(self, grads, closure=None, cuda=False): 58 | """Performs a single optimization step. 59 | Arguments: 60 | closure (callable, optional): A closure that reevaluates the model 61 | and returns the loss. 
62 | """ 63 | loss = None 64 | if closure is not None: 65 | loss = closure() 66 | 67 | for group in self.param_groups: 68 | weight_decay = group['weight_decay'] 69 | momentum = group['momentum'] 70 | dampening = group['dampening'] 71 | nesterov = group['nesterov'] 72 | 73 | for i,p in enumerate(group['params']): 74 | d_p = torch.from_numpy(grads[i]).float().cuda() if cuda else torch.from_numpy(grads[i]).float() 75 | if weight_decay != 0: 76 | d_p.add_(weight_decay, p.data) 77 | if momentum != 0: 78 | param_state = self.state[p] 79 | if 'momentum_buffer' not in param_state: 80 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 81 | buf.mul_(momentum).add_(d_p) 82 | else: 83 | buf = param_state['momentum_buffer'] 84 | buf.mul_(momentum).add_(1 - dampening, d_p) 85 | if nesterov: 86 | d_p = d_p.add(momentum, buf) 87 | else: 88 | d_p = buf 89 | p.data.add_(-group['lr'], d_p) 90 | return loss -------------------------------------------------------------------------------- /src/output/models/README.md: -------------------------------------------------------------------------------- 1 | # This directory is used to save the intermediate model during the training process for evaluation -------------------------------------------------------------------------------- /src/run_pytorch.sh: -------------------------------------------------------------------------------- 1 | mpirun -n 3 --hostfile hosts_address \ 2 | python distributed_nn.py \ 3 | --lr=0.01 \ 4 | --lr-shrinkage=0.95 \ 5 | --momentum=0.0 \ 6 | --network=ResNet18 \ 7 | --dataset=Cifar10 \ 8 | --batch-size=128 \ 9 | --test-batch-size=200 \ 10 | --comm-type=Bcast \ 11 | --num-aggregate=2 \ 12 | --eval-freq=200 \ 13 | --epochs=10 \ 14 | --max-steps=1000000 \ 15 | --svd-rank=3 \ 16 | --quantization-level=4 \ 17 | --bucket-size=512 \ 18 | --code=sgd \ 19 | --enable-gpu= \ 20 | --train-dir=/home/ubuntu/ -------------------------------------------------------------------------------- /src/single_machine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import threading 5 | import argparse 6 | import time 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | from torch._utils import _flatten_tensors, _unflatten_tensors 11 | from torch.cuda.comm import broadcast_coalesced 12 | from torch.cuda import nccl 13 | 14 | import torch.nn as nn 15 | from torch.nn.parallel.replicate import replicate 16 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 17 | from torch.nn.parallel.parallel_apply import parallel_apply 18 | import torch.nn.functional as F 19 | 20 | from torchvision import datasets, transforms 21 | 22 | from nn_ops import NN_Trainer, accuracy 23 | from data_loader_ops.my_data_loader import DataLoader 24 | 25 | from cifar10 import cifar10 26 | from datasets import Cifar10Dataset 27 | 28 | 29 | def add_fit_args(parser): 30 | """ 31 | parser : argparse.ArgumentParser 32 | return a parser added with args required by fit 33 | """ 34 | # Training settings 35 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 36 | help='input batch size for training (default: 64)') 37 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 38 | help='input batch size for testing (default: 1000)') 39 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 40 | help='number of epochs to train (default: 10)') 41 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 42 
| help='learning rate (default: 0.01)') 43 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 44 | help='SGD momentum (default: 0.5)') 45 | parser.add_argument('--no-cuda', action='store_true', default=False, 46 | help='disables CUDA training') 47 | parser.add_argument('--seed', type=int, default=1, metavar='S', 48 | help='random seed (default: 1)') 49 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 50 | help='how many batches to wait before logging training status') 51 | parser.add_argument('--network', type=str, default='LeNet', metavar='N', 52 | help='which kind of network to use; LeNet and ResNet are currently supported') 53 | parser.add_argument('--dataset', type=str, default='MNIST', metavar='N', 54 | help='which dataset to use for training; MNIST and Cifar10 are currently supported') 55 | args = parser.parse_args() 56 | return args 57 | 58 | # we use LeNet here for our simple case 59 | class LeNet(nn.Module): 60 | def __init__(self): 61 | super(LeNet, self).__init__() 62 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 63 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 64 | self.fc1 = nn.Linear(4*4*50, 500) 65 | self.fc2 = nn.Linear(500, 10) 66 | self.criterion = nn.CrossEntropyLoss() 67 | def forward(self, x, target): 68 | x = self.conv1(x) 69 | x = F.max_pool2d(x, 2, 2) 70 | x = F.relu(x) 71 | x = self.conv2(x) 72 | x = F.max_pool2d(x, 2, 2) 73 | x = F.relu(x) 74 | x = x.view(-1, 4*4*50) 75 | x = self.fc1(x) 76 | x = self.fc2(x) 77 | loss = self.criterion(x, target) 78 | return x, loss 79 | def name(self): 80 | return 'lenet' 81 | 82 | class LeNetLearner: 83 | """a deprecated class, please don't call this one at any time""" 84 | def __init__(self, rank, world_size, args): 85 | self._step_changed = False 86 | self._update_step = False 87 | self._new_step_queued = 0 88 | self._rank = rank 89 | self._world_size = world_size 90 | self._cur_step = 0 91 | self._next_step = self._cur_step + 1 92 | self._step_fetch_request = False 93 | self.max_num_epochs = args.epochs 94 | self.lr = args.lr 95 | self.momentum = args.momentum 96 | 97 | def build_model(self): 98 | self.network = LeNet() 99 | 100 | # only for test use 101 | self.module = self.network 102 | 103 | # this is only used for test 104 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum) 105 | 106 | def test_model(self): 107 | '''this is only for testing, please don't call this function''' 108 | from copy import deepcopy 109 | self._module_copies = [deepcopy(self.module)] 110 | self.device_ids = [] 111 | 112 | t = None 113 | for p in self.module.parameters(): 114 | tp = type(p.data) 115 | if t is not None and t is not tp: 116 | raise ValueError("DistributedDataParallel requires all parameters' data to be of the same type") 117 | t = tp 118 | 119 | self.bucket_sizes = [] 120 | self.bucket_map = {} 121 | MB = 1024 * 1024 122 | self.broadcast_bucket_size = 10 * MB # used for param sync before forward 123 | bucket_bytes_cap = 1 * MB 124 | bucket_bytes = bucket_bytes_cap # to init the first bucket immediately 125 | for param_tuple in zip(*map(lambda m: m.parameters(), self._module_copies)): 126 | if bucket_bytes >= bucket_bytes_cap: 127 | self.bucket_sizes.append(0) 128 | bucket_bytes = 0 129 | self.bucket_sizes[-1] += 1 130 | for p in param_tuple: 131 | self.bucket_map[p] = len(self.bucket_sizes) - 1 132 | bucket_bytes += p.numel() * p.element_size() 133 | 134 | self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))] 135 |
self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))] 136 | self.reduced = [False] * len(self.bucket_sizes) 137 | 138 | def train(self, train_loader=None): 139 | self.network.train() 140 | 141 | # iterate of epochs 142 | for i in range(self.max_num_epochs): 143 | for batch_idx, (data, y_batch) in enumerate(train_loader): 144 | iter_start_time = time.time() 145 | data, target = Variable(data), Variable(y_batch) 146 | self.optimizer.zero_grad() 147 | logits, loss = self.network(data, target) 148 | tmp_time_0 = time.time() 149 | loss.backward() 150 | 151 | for params in self.network.parameters(): 152 | print(params.grad.data.numpy()) 153 | print('**********************************************************') 154 | 155 | if batch_idx == 5: 156 | self.update_state_dict() 157 | 158 | duration_backward = time.time()-tmp_time_0 159 | 160 | tmp_time_1 = time.time() 161 | self.optimizer.step() 162 | duration_update = time.time()-tmp_time_1 163 | 164 | print("backward duration: {}".format(duration_backward)) 165 | print("update duration: {}".format(duration_update)) 166 | # calculate training accuracy 167 | prec1, prec5 = accuracy(logits.data, y_batch, topk=(1, 5)) 168 | # load the training info 169 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Prec@1: {} Prec@5: {} Time Cost: {}'.format( 170 | i, batch_idx * len(data), len(train_loader.dataset), 171 | 100. * batch_idx / len(train_loader), loss.data[0], 172 | prec1.numpy()[0], 173 | prec5.numpy()[0], time.time()-iter_start_time)) 174 | 175 | def update_state_dict(self): 176 | """for this test version, we set all params to zeros here""" 177 | # we need to build a state dict first 178 | new_state_dict = {} 179 | for key_name, param in self.network.state_dict().items(): 180 | tmp_dict = {key_name: torch.FloatTensor(param.size()).zero_()} 181 | new_state_dict.update(tmp_dict) 182 | self.network.load_state_dict(new_state_dict) 183 | 184 | 185 | if __name__ == "__main__": 186 | args = add_fit_args(argparse.ArgumentParser(description='PyTorch MNIST Single Machine Test')) 187 | 188 | kwargs = {'batch_size':args.batch_size, 'learning_rate':args.lr, 'max_epochs':args.epochs, 'momentum':args.momentum, 'network':args.network} 189 | 190 | # load training and test set here: 191 | if args.dataset == "MNIST": 192 | training_set = datasets.MNIST('../data', train=True, download=True, 193 | transform=transforms.Compose([ 194 | transforms.ToTensor(), 195 | transforms.Normalize((0.1307,), (0.3081,))])) 196 | train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True) 197 | test_loader = torch.utils.data.DataLoader( 198 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 199 | transforms.ToTensor(), 200 | transforms.Normalize((0.1307,), (0.3081,)) 201 | ])), batch_size=args.test_batch_size, shuffle=True) 202 | elif args.dataset == "Cifar10": 203 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 204 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 205 | transform_train = transforms.Compose([ 206 | transforms.ToTensor(), 207 | transforms.Lambda(lambda x: F.pad( 208 | Variable(x.unsqueeze(0), requires_grad=False, volatile=True), 209 | (4,4,4,4),mode='reflect').data.squeeze()), 210 | transforms.ToPILImage(), 211 | transforms.RandomCrop(32), 212 | transforms.RandomHorizontalFlip(), 213 | transforms.ToTensor(), 214 | normalize, 215 | ]) 216 | # data prep for test set 217 | transform_test = transforms.Compose([ 218 | transforms.ToTensor(), 219 | normalize]) 220 | 
# load training and test set here: 221 | training_set = datasets.CIFAR10(root='./cifar10_data', train=True, 222 | download=True, transform=transform_train) 223 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, 224 | shuffle=True) 225 | testset = datasets.CIFAR10(root='./cifar10_data', train=False, 226 | download=True, transform=transform_test) 227 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, 228 | shuffle=False) 229 | elif args.dataset == "ImageNet": 230 | # Data loading code 231 | traindir = os.path.join('/home/ubuntu/data/' 'train') 232 | valdir = os.path.join('/home/ubuntu/data/', 'val') 233 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 234 | std=[0.229, 0.224, 0.225]) 235 | 236 | train_dataset = datasets.ImageFolder( 237 | traindir, 238 | transforms.Compose([ 239 | transforms.RandomResizedCrop(224), 240 | transforms.RandomHorizontalFlip(), 241 | transforms.ToTensor(), 242 | normalize, 243 | ])) 244 | test_dataset = datasets.ImageFolder( 245 | valdir, 246 | transforms.Compose([ 247 | transforms.ToTensor(), 248 | normalize, 249 | ])) 250 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, 251 | shuffle=True) 252 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, 253 | shuffle=False) 254 | 255 | nn_learner = NN_Trainer(**kwargs) 256 | nn_learner.build_model() 257 | nn_learner.train_and_validate(train_loader=train_loader, test_loader=test_loader) -------------------------------------------------------------------------------- /src/sync_replicas_master_nn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | import copy 4 | from sys import getsizeof 5 | import warnings 6 | import pickle 7 | 8 | from mpi4py import MPI 9 | import numpy as np 10 | 11 | from nn_ops import NN_Trainer 12 | 13 | from model_ops.lenet import LeNet, LeNetSplit 14 | from model_ops.resnet import * 15 | from model_ops.resnet_split import * 16 | from model_ops.vgg import * 17 | from model_ops.alexnet import * 18 | from model_ops.fc_nn import FC_NN, FC_NN_Split 19 | from model_ops.densenet import DenseNet 20 | 21 | from optim.adam import Adam 22 | from optim.sgd import SGD 23 | from utils import decompress 24 | 25 | import torch 26 | import codings 27 | 28 | STEP_START_ = 1 29 | # use compression tool to make it run faster 30 | _FAKE_SGD = True 31 | 32 | def update_params_dist_version(param, avg_grad, learning_rate): 33 | ''' 34 | update the network layer by layer 35 | ''' 36 | assert param.shape == avg_grad.shape 37 | param -= learning_rate * avg_grad 38 | return param 39 | 40 | 41 | def accuracy(output, target, topk=(1,)): 42 | """Computes the precision@k for the specified values of k""" 43 | maxk = max(topk) 44 | batch_size = target.size(0) 45 | 46 | _, pred = output.topk(maxk, 1, True, True) 47 | pred = pred.t() 48 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 49 | 50 | res = [] 51 | for k in topk: 52 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 53 | res.append(correct_k.mul_(100.0 / batch_size)) 54 | return res 55 | 56 | 57 | class GradientAccumulator(object): 58 | ''' 59 | a simple class to implement gradient aggregator like the `Conditional Accumulators` in tensorflow 60 | ''' 61 | def __init__(self, module, num_worker, mode='None'): 62 | super(GradientAccumulator, self).__init__() 63 | # we will update this counter dynamically 
during the training process 64 | # the length of this counter should be the number of fc layers in the network 65 | # we use a list to contain the gradients of the layers 66 | self.gradient_aggregate_counter = [] 67 | self.model_index_range = [] 68 | self.gradient_aggregator = [] 69 | 70 | for param_idx, param in enumerate(module.parameters()): 71 | tmp_aggregator = [] 72 | for worker_idx in range(num_worker): 73 | _shape = param.size() 74 | if len(_shape) == 1: 75 | tmp_aggregator.append(bytearray(getsizeof(np.zeros((_shape[0]*2,)))*4)) 76 | else: 77 | tmp_aggregator.append(bytearray(getsizeof(np.zeros(_shape))*4)) 78 | # initialize the gradient aggregator 79 | self.gradient_aggregator.append(tmp_aggregator) 80 | self.gradient_aggregate_counter.append(0) 81 | self.model_index_range.append(param_idx) 82 | 83 | def meset_everything(self): 84 | self._meset_grad_counter() 85 | self._meset_grad_aggregator() 86 | 87 | def _meset_grad_counter(self): 88 | self.gradient_aggregate_counter = [0 for _ in self.gradient_aggregate_counter] 89 | 90 | def _meset_grad_aggregator(self): 91 | pass 92 | 93 | 94 | class SyncReplicasMaster_NN(NN_Trainer): 95 | def __init__(self, comm, **kwargs): 96 | '''master node here, no rank needed since the rank will always be 0 for the master node''' 97 | self.comm = comm # get MPI communicator object 98 | self.world_size = comm.Get_size() # total number of processes 99 | self.cur_step = STEP_START_ 100 | # initial learning rate: 101 | self.lr = kwargs['learning_rate'] 102 | # we use this parameter to shrink the step size per epoch 103 | self._lr_shrinkage = kwargs['lr_shrinkage'] 104 | self._base_lr = kwargs['learning_rate'] 105 | # TODO(hwang): change this to a more sophisticated version later 106 | self.shrinkage_freq = 50 107 | self.shrink_counter = 0 108 | 109 | self.momentum = kwargs['momentum'] 110 | self.network_config = kwargs['network'] 111 | self.comm_type = kwargs['comm_method'] 112 | 113 | self._num_grad_to_collect = self.world_size - 1 114 | # used to aggregate tmp gradients, the length is the same as # of fc layers 115 | self._grad_aggregate_buffer = [] 116 | self._model_shapes = [] 117 | self._first_grad_received = False 118 | self._eval_freq = kwargs['eval_freq'] 119 | self._train_dir = kwargs['train_dir'] 120 | self._max_steps = kwargs['max_steps'] 121 | self._compress = kwargs['compress'] 122 | 123 | self._enable_gpu = kwargs['enable_gpu'] 124 | self._num_aggregate = kwargs['num_aggregate'] 125 | 126 | self._svd_rank = kwargs['svd_rank'] 127 | self._quantization_level = kwargs['quantization_level'] 128 | self._bucket_size = kwargs['bucket_size'] 129 | self._r = self._svd_rank 130 | 131 | ############ will be deprecated soon ############################# 132 | self._eval_batch_size = 1000 133 | 134 | if kwargs['code'] == 'sgd': 135 | if not _FAKE_SGD: 136 | self._coder = codings.svd.SVD(compress=False) 137 | else: 138 | self._coder = codings.lossless_compress.LosslessCompress() 139 | elif kwargs['code'] == 'svd': 140 | print("train.py, svd_rank =", self._svd_rank) 141 | self._coder = codings.svd.SVD(random_sample=False, rank=self._svd_rank, 142 | compress=True) 143 | else: 144 | raise ValueError('args.code not recognized') 145 | 146 | def build_model(self, num_classes=10): 147 | # build network 148 | if self.network_config == "LeNet": 149 | self.network=LeNet() 150 | elif self.network_config == "ResNet18": 151 | self.network=ResNet18(num_classes=num_classes) 152 | elif self.network_config == "ResNet34": 153 | self.network=ResNet34(num_classes=num_classes) 154 | elif self.network_config == "FC": 155 | self.network=FC_NN() 156 | elif self.network_config == "DenseNet": 157 | self.network=DenseNet(growthRate=40, depth=190, reduction=0.5, 158 | bottleneck=True, nClasses=10) 159 | elif self.network_config == "VGG11": 160 | self.network=vgg11_bn(num_classes) 161 | elif self.network_config == "AlexNet": 162 | self.network=alexnet(num_classes=10) 163 | 164 | # TODO(hwang): make sure this is useful 165 | self.optimizer = SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum) 166 | # assign a gradient accumulator to collect gradients from workers 167 | self.grad_accumulator = GradientAccumulator(self.network, self.world_size-1, self._compress) 168 | self.init_model_shapes() 169 | # enable GPU here 170 | if self._enable_gpu: 171 | self.network.cuda() 172 |
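    # A hedged sketch (illustration only; never called in this class) of how the
    # composite MPI tags produced by `generate_tag(...)` in
    # model_ops/resnet_split.py could be unpacked on the master side.
    # LAYER_DIGITS is defined in resnet_split.py and is not shown in this
    # excerpt; the sketch assumes the layer field (88 + layer index) always
    # stays below it:
    def _decode_composite_tag(self, tag, layer_digits):
        step_token = tag // layer_digits   # high-order part: the training step
        layer_tag = tag % layer_digits     # low-order part: 88 + layer index
        return step_token, layer_tag - 88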
    def build_model(self, num_classes=10):
        # build network
        if self.network_config == "LeNet":
            self.network=LeNet()
        elif self.network_config == "ResNet18":
            self.network=ResNet18(num_classes=num_classes)
        elif self.network_config == "ResNet34":
            self.network=ResNet34(num_classes=num_classes)
        elif self.network_config == "FC":
            self.network=FC_NN()
        elif self.network_config == "DenseNet":
            self.network=DenseNet(growthRate=40, depth=190, reduction=0.5,
                                  bottleneck=True, nClasses=10)
        elif self.network_config == "VGG11":
            self.network=vgg11_bn(num_classes)
        elif self.network_config == "AlexNet":
            self.network=alexnet(num_classes=10)

        # TODO(hwang): make sure this is useful
        self.optimizer = SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)
        # assign a gradient accumulator to collect gradients from workers
        self.grad_accumulator = GradientAccumulator(self.network, self.world_size-1, self._compress)
        self.init_model_shapes()
        # enable GPU here
        if self._enable_gpu:
            self.network.cuda()

    def train(self):
        # the first thing we do here is to sync-broadcast the initial step to the workers;
        # we still need to make sure the value the workers fetch from the parameter server is 1
        # please note that the step starts from one here
        self.async_bcast_step()

        for i in range(1, self._max_steps+1):
            # switch back to training mode
            self.network.train()
            self._first_grad_received = False
            enough_gradients_received = False

            print("Master node is entering step: {}".format(i))

            self.async_bcast_step()

            self.async_bcast_layer_weights_bcast()

            # set the gradient fetch step and gather the requests
            gradient_fetch_requests=self.async_fetch_gradient_start()

            coded_msgs = {}
            # wait for enough gradients to be aggregated:
            gather_start_time = time.time()
            while not enough_gradients_received:
                status = MPI.Status()
                # `waitany` returns the index of the completed request together
                # with the received (pickled) message
                source, code = MPI.Request.waitany(requests=gradient_fetch_requests, status=status)
                # tags are offset by 88 in `async_fetch_gradient_start`
                layer_index = status.tag-88
                if layer_index not in coded_msgs:
                    coded_msgs[layer_index] = [code]
                else:
                    coded_msgs[layer_index].append(code)

                self.grad_accumulator.gradient_aggregate_counter[layer_index] += 1

                #print(self.grad_accumulator.gradient_aggregate_counter)
                #print('---------------------------------------------------------------------')

                enough_gradients_received = True
                for j in self.grad_accumulator.gradient_aggregate_counter:
                    enough_gradients_received = enough_gradients_received and (j >= self._num_grad_to_collect)
            gather_duration = time.time() - gather_start_time

            decode_start = time.time()
            self._decode(coded_msgs)
            decode_dur = time.time() - decode_start
            print("Master: Step: {}, Decode Cost: {}, Cur lr {}, Gather: {}".format(self.cur_step, decode_dur, self.lr, gather_duration))
            # update the model with the averaged, decoded gradients
            self._model_update()

            # reset essential elements
            self.reset_grad_buffer()
            self.grad_accumulator.reset_everything()
            # save model for validation at a pre-specified frequency
            #if self.cur_step%self._eval_freq == 0:
            #    if "ResNet" not in self.network_config:
            #        self._save_model(file_path=self._generate_model_path())
            self.cur_step += 1
            if self.cur_step % self.shrinkage_freq == 0:
                self.shrink_counter += 1
                self.lr = self._base_lr * self._lr_shrinkage ** self.shrink_counter
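
    # Note on the schedule implemented at the bottom of `train()`:
    #   lr(step) = base_lr * lr_shrinkage ** floor(step / shrinkage_freq)
    # e.g. with (hypothetical) base_lr=0.1, lr_shrinkage=0.5, shrinkage_freq=50,
    # steps 1-49 run at 0.1, steps 50-99 at 0.05, steps 100-149 at 0.025, and so on.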
    def _model_update(self):
        # gradients shipped from workers are averaged, then used to update the model
        self._grad_aggregate_buffer = [x / float(self._num_grad_to_collect) for x in self._grad_aggregate_buffer]
        self.optimizer.step(grads=self._grad_aggregate_buffer, cuda=self._enable_gpu)

    def init_model_shapes(self):
        for p_index, p in enumerate(self.network.parameters()):
            self._model_shapes.append(p.size())
            self._grad_aggregate_buffer.append(np.zeros(p.size()))

    def async_bcast_step(self):
        req_list = []
        for i in range(self.world_size):
            if i != 0:
                req_list.append(self.comm.isend(self.cur_step, dest=i, tag=10))
        for i in range(len(req_list)):
            req_list[i].wait()

    def async_bcast_layer_weights_async(self):
        request_layers = []
        for layer_idx, layer in enumerate(self.network.parameters()):
            request_workers = []
            layer_to_send = layer.data.numpy().astype(np.float64)
            for i in range(self.world_size):
                if i != 0:
                    req = self.comm.Isend([layer_to_send, MPI.DOUBLE], dest=i, tag=11+layer_idx)
                    request_workers.append(req)

            request_layers.append(request_workers)
        # TODO(hwang): check to see if these `wait` calls are necessary here
        for req_l in request_layers:
            for req_worker in req_l:
                req_worker.wait()

    def async_bcast_layer_weights_bcast(self):
        # despite the module name, this path uses a blocking MPI collective broadcast
        for layer_idx, layer in enumerate(self.network.parameters()):
            if self._enable_gpu:
                # copy data to CPU, then do the communication work there
                layer_to_send = layer.data.cpu().numpy().astype(np.float64)
            else:
                layer_to_send = layer.data.numpy().astype(np.float64)
            self.comm.Bcast([layer_to_send, MPI.DOUBLE], root=0)

    def async_fetch_gradient_start(self):
        '''
        make gradient fetch requests and return the request list
        '''
        gradient_fetch_requests = []
        for module_idx, module in enumerate(self.network.parameters()):
            for k in range(self._num_grad_to_collect):
                req = self.comm.irecv(self.grad_accumulator.gradient_aggregator[module_idx][k], source=k+1, tag=88+module_idx)
                gradient_fetch_requests.append(req)
        return gradient_fetch_requests

    def aggregate_gradient(self, gradient, layer_idx):
        '''
        accumulate a decoded gradient into the per-layer aggregation buffer
        '''
        self._grad_aggregate_buffer[layer_idx] += gradient.numpy().astype(np.float64)

    def model_update(self, tmp_module):
        """write model fetched from the parameter server to the local model"""
        new_state_dict = {}
        model_counter_ = 0
        for param_idx,(key_name, param) in enumerate(self.network.state_dict().items()):
            # handle the `running_mean` and `running_var` buffers contained in `BatchNorm` layers
            if "running_mean" in key_name or "running_var" in key_name:
                tmp_dict = {key_name : param}
            else:
                assert param.size() == tmp_module[model_counter_].shape
                tmp_dict = {key_name: torch.from_numpy(tmp_module[model_counter_])}
                model_counter_+=1
            new_state_dict.update(tmp_dict)
        self.network.load_state_dict(new_state_dict)

    def reset_grad_buffer(self):
        # zero the whole per-layer aggregation buffer for the next step
        for i in range(len(self._grad_aggregate_buffer)):
            self._grad_aggregate_buffer[i] = np.zeros(self._model_shapes[i])
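
    # `coded_msgs`, gathered in `train()`, maps layer_index -> a list of pickled codes,
    # one entry per worker, e.g. with 2 workers (shapes hypothetical):
    #   {0: [code_from_w1, code_from_w2], 1: [...], ...}
    # `_decode` below unpickles each code, reconstructs the dense gradient with the
    # coder, and accumulates it into `_grad_aggregate_buffer` via `aggregate_gradient`.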
    def _decode(self, coded_msgs):
        # k: `layer_index`, v: coded gradients
        for index, (k, v) in enumerate(coded_msgs.items()):
            for code in v:
                code = pickle.loads(code)
                grad=self._coder.decode(code)
                try:
                    assert (grad.shape == self._model_shapes[k])
                except AssertionError:
                    warnings.warn("shape doesn't match, should really be careful")
                self.aggregate_gradient(gradient=grad, layer_idx=k)

    def _generate_model_path(self):
        return self._train_dir+"model_step_"+str(self.cur_step)

    def _save_model(self, file_path):
        with open(file_path, "wb") as f_:
            torch.save(self.network.state_dict(), f_)

    def _evaluate_model(self, validation_loader):
        self.network.eval()
        prec1_counter_ = prec5_counter_ = batch_counter_ = 0
        # loop until an epoch-based validation pass has completed
        while validation_loader.dataset.epochs_completed <= self._epoch_counter:
            eval_image_batch, eval_label_batch = validation_loader.next_batch(batch_size=self._eval_batch_size)
            X_batch, y_batch = Variable(eval_image_batch.float()), Variable(eval_label_batch.long())
            output = self.network(X_batch)
            prec1_tmp, prec5_tmp = accuracy(output.data, eval_label_batch.long(), topk=(1, 5))
            prec1_counter_ += prec1_tmp
            prec5_counter_ += prec5_tmp
            batch_counter_ += 1
        prec1 = prec1_counter_ / batch_counter_
        prec5 = prec5_counter_ / batch_counter_
        self._epoch_counter = validation_loader.dataset.epochs_completed
        if self._enable_gpu:
            prec1 = prec1.cpu().numpy()[0]
            prec5 = prec5.cpu().numpy()[0]
        else:
            prec1 = prec1.numpy()[0]
            prec5 = prec5.numpy()[0]
        print('Testset Performance: Cur Step:{} Prec@1: {} Prec@5: {}'.format(self.cur_step, prec1, prec5))
--------------------------------------------------------------------------------
/src/tiny_tuning_parser.py:
--------------------------------------------------------------------------------
import re
import argparse

parser = argparse.ArgumentParser(description='Distributed Tuning')
parser.add_argument('--tuning-dir', type=str, default='./tune/', metavar='N',
                    help='path of the temp tuning log file to parse')
parser.add_argument('--tuning-lr', type=float, default=0.125, metavar='N',
                    help='candidate learning rate used during tuning')
parser.add_argument('--num-workers', type=int, default=16,
                    help='number of workers in the cluster')
args = parser.parse_args()

loss_stat = []
with open(args.tuning_dir, 'r') as file:
    for line in file.readlines():
        line_content = line.rstrip('\n')
        search = re.search(
            'Worker: .*, Step: .*, Epoch: .* \[.* \(.*\)\], Loss: (.*), Time Cost: .*, Comp: .*, Encode: .*, Comm: .*, Msg\(MB\): .*',
            line_content)
        if search:
            loss = float(search.group(1))
            loss_stat.append(loss)
try:
    assert len(loss_stat) == args.num_workers
except AssertionError:
    print("Illegal number of workers!")
print("Avged loss for lr candidate: {}=========>{}".format(args.tuning_lr, sum(loss_stat)/float(len(loss_stat))))
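
# Example of a worker log line the regex above is meant to match (all values hypothetical):
#   Worker: 3, Step: 100, Epoch: 1 [800 (1%)], Loss: 1.834261, Time Cost: 0.212, Comp: 0.051, Encode: 0.023, Comm: 0.114, Msg(MB): 0.402
# which would contribute 1.834261 to `loss_stat`.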
--------------------------------------------------------------------------------
/src/tune.sh:
--------------------------------------------------------------------------------
tune_dir=~/grad_lossy_compression/src/tune/
max_tuning_step=100
method=svd
mkdir -p ${tune_dir}

echo "Start parameter tuning ..."
for lr in 0.0078125 0.015625 0.03125 0.0625 0.125 0.25 0.5
do
  echo "Trial running for learning rate: ${lr}"
  mpirun -n 17 --hostfile hosts_address \
  python distributed_nn.py \
  --lr=${lr} \
  --lr-shrinkage=1.0 \
  --momentum=0.0 \
  --network=ResNet18 \
  --dataset=Cifar10 \
  --batch-size=8 \
  --test-batch-size=200 \
  --comm-type=Bcast \
  --num-aggregate=16 \
  --eval-freq=200 \
  --epochs=10 \
  --max-steps=${max_tuning_step} \
  --svd-rank=3 \
  --quantization-level=4 \
  --bucket-size=512 \
  --code=${method} \
  --enable-gpu= \
  --train-dir=/home/ubuntu/ > ${tune_dir}${method}_lr_${lr}

  # keep only the log lines from the last tuning step for the parser below
  cat ${tune_dir}${method}_lr_${lr} | grep "Step: ${max_tuning_step}" > ${tune_dir}${method}_lr_${lr}_processing
  bash ~/killall.sh
done

for lr in 0.0078125 0.015625 0.03125 0.0625 0.125 0.25 0.5
do
  echo "Logging tuning results"
  python tiny_tuning_parser.py \
  --tuning-dir=${tune_dir}${method}_lr_${lr}_processing \
  --tuning-lr=${lr} \
  --num-workers=16
done
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import blosc

def compress(msg, level=0, name='blosclz'):
    """
    Compress a message with blosc. Note: `level` is currently unused; we always
    rely on blosc's default compression level for now.
    """
    if name in {'lz4', 'snappy'}:
        raise ValueError('Do not specify lz4 or snappy. I ran into hard to '
                         'debug issues when I did this. blosclz seems to work')
    # we always use the default level for now
    code = blosc.compress(msg, cname=name)
    return bytearray(code)

def decompress(code):
    msg = blosc.decompress(code)
    return msg
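
# Round-trip sketch (hypothetical usage; any bytes-like `msg` works):
#   import pickle
#   code = compress(pickle.dumps(some_array))
#   restored = pickle.loads(decompress(bytes(code)))
# i.e. `decompress(compress(msg))` recovers `msg` exactly, since blosc is lossless.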
--------------------------------------------------------------------------------
/tools/config:
--------------------------------------------------------------------------------
Host *
    StrictHostKeyChecking no
--------------------------------------------------------------------------------
/tools/hosts:
--------------------------------------------------------------------------------
172.31.28.130 deeplearning-worker1
172.31.29.102 deeplearning-worker2
172.31.27.116 deeplearning-worker3
172.31.28.111 deeplearning-worker4
172.31.27.53 deeplearning-worker5
172.31.22.220 deeplearning-worker6
172.31.19.79 deeplearning-worker7
172.31.18.28 deeplearning-worker8
172.31.16.213 deeplearning-worker9
--------------------------------------------------------------------------------
/tools/hosts_address:
--------------------------------------------------------------------------------
172.31.28.130
172.31.29.102
172.31.27.116
172.31.28.111
172.31.27.53
172.31.22.220
172.31.19.79
172.31.18.28
172.31.16.213
--------------------------------------------------------------------------------
/tools/hosts_alias:
--------------------------------------------------------------------------------
deeplearning-worker1
deeplearning-worker2
deeplearning-worker3
deeplearning-worker4
deeplearning-worker5
deeplearning-worker6
deeplearning-worker7
deeplearning-worker8
deeplearning-worker9
--------------------------------------------------------------------------------
/tools/install.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# append hosts
sudo bash -c "cat hosts >> /etc/hosts"
cp config ~/.ssh/

export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`

git config --global user.name hwang595
git config --global user.email hongyiwang@cs.wisc.edu

sudo apt-get update
sudo apt-get install pdsh -y
pdsh -R ssh -w deeplearning-worker[1-$DEEPLEARNING_WORKERS_COUNT] "sudo apt-get update; sudo apt-get install pdsh -y"
--------------------------------------------------------------------------------
/tools/local_script.sh:
--------------------------------------------------------------------------------
KEY_PEM_DIR=/home/hwang/My_Code/AWS/HongyiWKeyPair.pem
KEY_PEM_NAME=HongyiWKeyPair.pem
PUB_IP_ADDR="$1"
echo "Public address of master node: ${PUB_IP_ADDR}"

# accept the host key once so that the scp/ssh calls below run non-interactively
ssh -o "StrictHostKeyChecking no" ubuntu@${PUB_IP_ADDR} "exit"
scp -i ${KEY_PEM_DIR} ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR}:~/.ssh
scp -i ${KEY_PEM_DIR} hosts hosts_address config ubuntu@${PUB_IP_ADDR}:~/
scp -i ${KEY_PEM_DIR} -r ~/My_Code/grad_lossy_compression ubuntu@${PUB_IP_ADDR}:~/
ssh -i ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR} 'sudo apt-get update; cp /home/ubuntu/grad_lossy_compression/tools/remote_script.sh ~/'
--------------------------------------------------------------------------------
/tools/pre_run.sh:
--------------------------------------------------------------------------------
sudo apt-get update
conda update -y -n base conda
#conda install pytorch torchvision -y -c pytorch
conda install -y pytorch=0.3.0 -c soumith
conda install -y torchvision=0.1.8
conda install -y -c anaconda python-blosc
conda install -y -c anaconda mpi4py
conda install -y libgcc

# install and configure hdmedians
cd ~
sudo apt-get install -y gcc
source /home/ubuntu/anaconda2/bin/activate ~/anaconda2
git clone https://github.com/daleroberts/hdmedians.git
cd hdmedians
python setup.py install
--------------------------------------------------------------------------------
/tools/remote_script.sh:
--------------------------------------------------------------------------------
KEY_PEM_NAME=HongyiWKeyPair.pem
export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`

sudo bash -c "cat hosts >> /etc/hosts"
cp config ~/.ssh/

cd ~/.ssh
eval `ssh-agent -s`
ssh-add ${KEY_PEM_NAME}
ssh-keygen -t rsa -b 4096 -C "hongyiwang.hdu@gmail.com"

for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT);
do
    scp -i ${KEY_PEM_NAME} id_rsa.pub deeplearning-worker${i}:~/.ssh
    # copy the project into the worker's home directory (not into ~/.ssh)
    scp -i ${KEY_PEM_NAME} -r /home/ubuntu/grad_lossy_compression deeplearning-worker${i}:~/
    ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'cd ~/.ssh; cat id_rsa.pub >> authorized_keys'
    echo "Done writing public key to worker: deeplearning-worker${i}"
done
--------------------------------------------------------------------------------
/tools/update_git_dir.sh:
--------------------------------------------------------------------------------
cd ~
KEY_PEM_NAME=HongyiScript.pem
export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`

sudo bash -c "cat hosts >> /etc/hosts"

for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT);
do
    ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'git config --global user.name hwang595; git config --global user.email hongyiwang@cs.wisc.edu; cd ~/pytorch_distributed_nn; git pull'
    echo "Done pulling the git repo on worker: deeplearning-worker${i}"
done
--------------------------------------------------------------------------------