├── README.md ├── example_notebook.ipynb ├── optim ├── gen_sgd.py ├── models.py ├── prep_data.py ├── train.py └── utils.py ├── quant └── quant.py ├── requirements.txt └── utils └── plotting.py /README.md: -------------------------------------------------------------------------------- 1 | # Code guidelines 2 | 3 | This implementation is based on PyTorch (1.5.0) in Python (3.8). 4 | 5 | It enables running simulated distributed optimization with a master node and any number of workers, based on the [PyTorch SGD Optimizer](https://pytorch.org/docs/stable/optim.html#torch.optim.SGD) with gradient compression. Communication can be compressed at both the worker and the master level. Error-Feedback is also supported. For more details, please see our [manuscript](https://arxiv.org/pdf/2006.11077.pdf). 6 | 7 | ### Installation 8 | 9 | To install the requirements, run 10 | ```sh 11 | $ pip install -r requirements.txt 12 | ``` 13 | 14 | ### Example Notebook 15 | To run our code, see the [example notebook](example_notebook.ipynb). 16 | 17 | ### Citing 18 | If you find this code useful, please consider citing 19 | 20 | ``` 21 | @article{horvath2020better, 22 | title={A Better Alternative to Error Feedback for Communication-Efficient Distributed Learning}, 23 | author={Horv\'{a}th, Samuel and Richt\'{a}rik, Peter}, 24 | journal={arXiv preprint arXiv:2006.11077}, 25 | year={2020} 26 | } 27 | ``` 28 | 29 | ### License 30 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 31 | -------------------------------------------------------------------------------- /example_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Example notebook\n", 8 | "\n", 9 | "This notebook serves as an example of how to run and replicate experiments from our code base.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "### Libraries" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from torch import nn\n", 26 | "import numpy as np\n", 27 | "\n", 28 | "from optim.train import tune_step_size, run_tuned_exp\n", 29 | "from optim.models import MNISTNet, MNISTLogReg, resnet18, vgg11\n", 30 | "\n", 31 | "from optim.utils import save_exp, load_exp, read_all_runs, create_exp\n", 32 | "from utils.plotting import plot\n", 33 | "\n", 34 | "from quant.quant import c_nat, random_dithering_wrap, rand_spars_wrap, \\\n", 35 | " top_k_wrap, grad_spars_wrap, biased_unbiased_wrap, combine_two_wrap" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Example run" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "### Choose parameters " 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "dataset = 'cifar10' # datasets, current options: 'mnist', 'cifar10', 'cifar100'\n", 59 | "model = 'resnet18' # for saving purposes\n", 60 | "net = resnet18 # for the list of all models, see optim/models.py\n", 61 | "criterion = nn.CrossEntropyLoss() # loss function used for training\n", 62 | "epochs = 50 # number of epochs \n", 63 | "n_workers = 8 # number of workers\n", 64 | "batch_size = 32 # local batch size on each worker\n", 65 | "seed = 40 # fixed seed, which allows 
experiment replication\n", 66 | "lrs = np.array([0.1, 0.05, 0.01]) # learning rates considered during the tuning stage\n", 67 | "momentum = 0.9 # momentum for optimizer, default 0\n", 68 | "weight_decay = 0 # weight_decay for optimizer, default 0\n", 69 | "\n", 70 | "exp_name = dataset + '_' + model # experiment name" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "### Choose compression operator\n", 78 | "\n", 79 | "Choose one from the list below. Compression is applied whenever a worker communicates with the master. " 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "- No Compression" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "compression = {'wrapper': False, 'compression': None} " 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "- Natural Compression: [paper](https://arxiv.org/pdf/1905.10988.pdf)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "compression = {'wrapper': False, 'compression': c_nat} " 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "- Natural/Standard Dithering: [paper](https://arxiv.org/pdf/1905.10988.pdf)\n", 119 | " - `'p'`: norm\n", 120 | " - `'s'`: number of levels\n", 121 | " - `'natural'`: if `True`, Natural Dithering is used; otherwise, Standard Dithering" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "compression = {'wrapper': True, 'compression': random_dithering_wrap, 'p': np.inf, 's': 1, 'natural': True} " 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "- Gradient Sparsification: [paper](https://papers.nips.cc/paper/7405-gradient-sparsification-for-communication-efficient-distributed-optimization.pdf)\n", 138 | " - `'h'`: density, $h \\in [0,1]$ (fraction of coordinates kept)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "compression = {'wrapper': True, 'compression': grad_spars_wrap, 'h': 1/20} " 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "- Top-K Sparsification: \n", 155 | " - `'h'`: density, $h \\in [0,1]$ (fraction of coordinates kept)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "compression = {'wrapper': True, 'compression': top_k_wrap, 'h': 1/20}" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "- Random Sparsification: \n", 172 | " - `'h'`: density, $h \\in [0,1]$ (fraction of coordinates kept)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "compression = {'wrapper': True, 'compression': rand_spars_wrap, 'h': 1/20} " 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "- Combination of two compression operators: \n", 189 | " - `'comp_1'`: compression operator applied to the original vector\n", 190 | " - `'comp_2'`: (should be unbiased) if `'func'` is `biased_unbiased_wrap`, then this compression operator is applied to the error \n", 191 | " `e = g - 
comp_1(g)` and the resulting compression returns `comp_1(g) + comp_2(e)`; if `'func'` is `combine_two_wrap`, then the \n", 192 | " resulting compression returns `comp_1(comp_2(g))`\n", 193 | " - An example combining Top-K and Gradient Sparsification is shown below " 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "compression = {'combine': {\n", 203 | " 'func': biased_unbiased_wrap, \n", 204 | " 'comp_1': {'wrapper': True, 'compression': top_k_wrap, 'h': 1/40},\n", 205 | " 'comp_2': {'wrapper': True, 'compression': grad_spars_wrap, 'h': 1/40}\n", 206 | " }\n", 207 | " }" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "To also compress communication from the master to the workers, change the default `master_compression = None` to any compression mentioned above. Master compression does not support Error Feedback, so an unbiased compression should be selected." 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "master_compression = None" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "### Error Feedback\n", 231 | "If `True`, Error Feedback is used." 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "error_feedback = True " 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": {}, 246 | "source": [ 247 | "### Wrap Experiment" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "name = exp_name + '_unique_identifier' # name under which your experiment will be stored\n", 257 | "exp = create_exp(name, dataset, net, n_workers, epochs, seed, batch_size, lrs,\n", 258 | " compression, error_feedback, criterion, master_compression, momentum)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": {}, 264 | "source": [ 265 | "### Tune Step Size and Save Experiment" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "exp['lr'] = tune_step_size(exp)\n", 275 | "save_exp(exp)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "### Run Experiment with tuned Step Size\n", 283 | "Each experiment is run `RUNS = 5` times. This value can be adjusted in the [optim/train.py](optim/train.py) file." 
284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "run_tuned_exp(exp)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "metadata": {}, 298 | "source": [ 299 | "### Compare methods -- Plotting" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "exp_1 = load_exp(exp_name + '_identifier_1')\n", 309 | "exp_2 = load_exp(exp_name + '_identifier_2')\n", 310 | "exp_3 = load_exp(exp_name + '_identifier_3')" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "kind = 'test_acc' # options: 'train_loss', 'test_loss', 'test_acc' \n", 320 | "exp_type = 'experiment_identifier'\n", 321 | "plot([exp_1, exp_2, exp_3], kind, log_scale=False,\n", 322 | " legend=['Exp_1_Name', 'Exp_2_Name', 'Exp_3_Name'],\n", 323 | " y_label='Test accuracy', file='File_To_Store_Plot')" 324 | ] 325 | } 326 | ], 327 | "metadata": { 328 | "kernelspec": { 329 | "display_name": "Python 3", 330 | "language": "python", 331 | "name": "python3" 332 | }, 333 | "language_info": { 334 | "codemirror_mode": { 335 | "name": "ipython", 336 | "version": 3 337 | }, 338 | "file_extension": ".py", 339 | "mimetype": "text/x-python", 340 | "name": "python", 341 | "nbconvert_exporter": "python", 342 | "pygments_lexer": "ipython3", 343 | "version": "3.7.4" 344 | } 345 | }, 346 | "nbformat": 4, 347 | "nbformat_minor": 4 348 | } 349 | -------------------------------------------------------------------------------- /optim/gen_sgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | 5 | class SGDGen(Optimizer): 6 | r""" 7 | based on torch.optim.SGD implementation 8 | """ 9 | 10 | def __init__(self, params, lr, n_workers, momentum=0, dampening=0, 11 | weight_decay=0, nesterov=False, comp=None, master_comp=None, 12 | error_feedback=False): 13 | if lr < 0.0: 14 | raise ValueError("Invalid learning rate: {}".format(lr)) 15 | if momentum < 0.0: 16 | raise ValueError("Invalid momentum value: {}".format(momentum)) 17 | if weight_decay < 0.0: 18 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 19 | 20 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 21 | weight_decay=weight_decay, nesterov=nesterov) 22 | if nesterov and (momentum <= 0 or dampening != 0): 23 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 24 | super(SGDGen, self).__init__(params, defaults) 25 | 26 | self.comp = comp 27 | self.error_feedback = error_feedback 28 | if self.error_feedback and self.comp is None: 29 | raise ValueError("For Error-Feedback, compression can't be None") 30 | 31 | self.master_comp = master_comp # should be unbiased, Error-Feedback is not supported at the moment 32 | 33 | self.n_workers = n_workers 34 | self.grads_received = 0 35 | 36 | def __setstate__(self, state): 37 | super(SGDGen, self).__setstate__(state) 38 | for group in self.param_groups: 39 | group.setdefault('nesterov', False) 40 | 41 | @torch.no_grad() 42 | def step_local_global(self, w_id, closure=None): 43 | """Performs a single optimization step. 44 | 45 | Arguments: 46 | w_id: integer, id of the worker 47 | closure (callable, optional): A closure that reevaluates the model 48 | and returns the loss. 
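Note: this method is expected to be called once per worker (`w_id`) in every iteration. Each call compresses the worker's learning-rate-scaled gradient (with error correction if `error_feedback` is enabled) and accumulates it into an internal buffer; the parameters are updated only on the call in which gradients from all `n_workers` workers have been received, after which the accumulator is reset.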
49 | """ 50 | loss = None 51 | if closure is not None: 52 | loss = closure() 53 | 54 | self.grads_received += 1 55 | 56 | for group in self.param_groups: 57 | weight_decay = group['weight_decay'] 58 | momentum = group['momentum'] 59 | dampening = group['dampening'] 60 | nesterov = group['nesterov'] 61 | 62 | for p in group['params']: 63 | if p.grad is None: 64 | continue 65 | 66 | param_state = self.state[p] 67 | 68 | d_p = p.grad.data 69 | 70 | if self.error_feedback: 71 | error_name = 'error_' + str(w_id) 72 | if error_name not in param_state: 73 | loc_grad = d_p.mul(group['lr']) 74 | else: 75 | loc_grad = d_p.mul(group['lr']) + param_state[error_name] 76 | 77 | d_p = self.comp(loc_grad) 78 | param_state[error_name] = loc_grad - d_p 79 | 80 | else: 81 | if self.comp is not None: 82 | d_p = self.comp(d_p).mul(group['lr']) 83 | else: 84 | d_p = d_p.mul(group['lr']) 85 | 86 | if 'full_grad' not in param_state or self.grads_received == 1: 87 | param_state['full_grad'] = torch.clone(d_p).detach() 88 | else: 89 | param_state['full_grad'] += torch.clone(d_p).detach() 90 | 91 | if self.grads_received == self.n_workers: 92 | grad = param_state['full_grad'] / self.n_workers 93 | 94 | if self.master_comp is not None: 95 | grad = self.master_comp(grad) 96 | 97 | if weight_decay != 0: 98 | grad.add_(p, alpha=weight_decay) # in-place add; a plain add() would discard the result and weight decay would have no effect 99 | if momentum != 0: 100 | if 'momentum_buffer' not in param_state: 101 | buf = param_state['momentum_buffer'] = torch.clone(grad).detach() 102 | else: 103 | buf = param_state['momentum_buffer'] 104 | buf.mul_(momentum).add_(grad, alpha=1 - dampening) 105 | if nesterov: 106 | grad = grad.add(buf, alpha=momentum) 107 | else: 108 | grad = buf 109 | 110 | p.data.add_(grad, alpha=-1) 111 | 112 | if self.grads_received == self.n_workers: 113 | self.grads_received = 0 114 | 115 | return loss 116 | -------------------------------------------------------------------------------- /optim/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | import math 5 | 6 | 7 | # MNIST Fully Connected Net ---------------------------------------------------- 8 | class MNISTNet(nn.Module): 9 | def __init__(self): 10 | super(MNISTNet, self).__init__() 11 | self.fc1 = nn.Linear(28 * 28, 512) 12 | self.fc2 = nn.Linear(512, 10) 13 | 14 | def forward(self, x): 15 | # flatten image input 16 | x = x.view(-1, 28 * 28) 17 | x = F.relu(self.fc1(x)) 18 | x = self.fc2(x) 19 | return x 20 | 21 | 22 | class MNISTLogReg(nn.Module): 23 | def __init__(self): 24 | super(MNISTLogReg, self).__init__() 25 | self.fc = nn.Linear(28 * 28, 10) 26 | 27 | def forward(self, x): 28 | # flatten image input 29 | x = x.view(-1, 28 * 28) 30 | x = self.fc(x) 31 | return x 32 | 33 | 34 | # ResNets (credit to: https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py) 35 | '''ResNet in PyTorch. 36 | For Pre-activation ResNet, see 'preact_resnet.py'. 37 | Reference: 38 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 39 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 40 | ''' 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion = 1 45 | 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(BasicBlock, self).__init__() 48 | self.conv1 = nn.Conv2d( 49 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 52 | stride=1, padding=1, bias=False) 53 | self.bn2 = nn.BatchNorm2d(planes) 54 | 55 | self.shortcut = nn.Sequential() 56 | if stride != 1 or in_planes != self.expansion*planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion*planes, 59 | kernel_size=1, stride=stride, bias=False), 60 | nn.BatchNorm2d(self.expansion*planes) 61 | ) 62 | 63 | def forward(self, x): 64 | out = F.relu(self.bn1(self.conv1(x))) 65 | out = self.bn2(self.conv2(out)) 66 | out += self.shortcut(x) 67 | out = F.relu(out) 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, in_planes, planes, stride=1): 75 | super(Bottleneck, self).__init__() 76 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(planes) 78 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 79 | stride=stride, padding=1, bias=False) 80 | self.bn2 = nn.BatchNorm2d(planes) 81 | self.conv3 = nn.Conv2d(planes, self.expansion * 82 | planes, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 84 | 85 | self.shortcut = nn.Sequential() 86 | if stride != 1 or in_planes != self.expansion*planes: 87 | self.shortcut = nn.Sequential( 88 | nn.Conv2d(in_planes, self.expansion*planes, 89 | kernel_size=1, stride=stride, bias=False), 90 | nn.BatchNorm2d(self.expansion*planes) 91 | ) 92 | 93 | def forward(self, x): 94 | out = F.relu(self.bn1(self.conv1(x))) 95 | out = F.relu(self.bn2(self.conv2(out))) 96 | out = self.bn3(self.conv3(out)) 97 | out += self.shortcut(x) 98 | out = F.relu(out) 99 | return out 100 | 101 | 102 | class ResNet(nn.Module): 103 | def __init__(self, block, num_blocks, num_classes=10): 104 | super(ResNet, self).__init__() 105 | self.in_planes = 64 106 | 107 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 108 | stride=1, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(64) 110 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 111 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 112 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 113 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 114 | self.linear = nn.Linear(512*block.expansion, num_classes) 115 | 116 | def _make_layer(self, block, planes, num_blocks, stride): 117 | strides = [stride] + [1]*(num_blocks-1) 118 | layers = [] 119 | for stride in strides: 120 | layers.append(block(self.in_planes, planes, stride)) 121 | self.in_planes = planes * block.expansion 122 | return nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | out = F.relu(self.bn1(self.conv1(x))) 126 | out = self.layer1(out) 127 | out = self.layer2(out) 128 | out = self.layer3(out) 129 | out = self.layer4(out) 130 | out = F.avg_pool2d(out, 4) 131 | out = out.view(out.size(0), -1) 132 | out = self.linear(out) 133 | return out 134 | 135 | 136 | def resnet18(): 137 | return ResNet(BasicBlock, [2, 2, 2, 2]) 138 | 139 | 140 | def resnet34(): 141 | return ResNet(BasicBlock, [3, 4, 6, 3]) 142 | 143 | 144 | def resnet50(): 145 | return ResNet(Bottleneck, [3, 4, 6, 3]) 146 | 147 | 148 | def resnet101(): 
149 | return ResNet(Bottleneck, [3, 4, 23, 3]) 150 | 151 | 152 | def resnet152(): 153 | return ResNet(Bottleneck, [3, 8, 36, 3]) 154 | 155 | 156 | # VGG Nets (credit to https://github.com/chengyangfu/pytorch-vgg-cifar10) 157 | ''' 158 | Modified from https://github.com/pytorch/vision.git 159 | ''' 160 | 161 | __all__ = [ 162 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 163 | 'vgg19_bn', 'vgg19', 164 | ] 165 | 166 | 167 | class VGG(nn.Module): 168 | """ 169 | VGG model 170 | """ 171 | def __init__(self, features): 172 | super(VGG, self).__init__() 173 | self.features = features 174 | self.classifier = nn.Sequential( 175 | nn.Dropout(), 176 | nn.Linear(512, 512), 177 | nn.ReLU(True), 178 | nn.Dropout(), 179 | nn.Linear(512, 512), 180 | nn.ReLU(True), 181 | nn.Linear(512, 10), 182 | ) 183 | # Initialize weights 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 187 | m.weight.data.normal_(0, math.sqrt(2. / n)) 188 | m.bias.data.zero_() 189 | 190 | def forward(self, x): 191 | x = self.features(x) 192 | x = x.view(x.size(0), -1) 193 | x = self.classifier(x) 194 | return x 195 | 196 | 197 | def make_layers(config, batch_norm=False): 198 | layers = [] 199 | in_channels = 3 200 | for v in config: 201 | if v == 'M': 202 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 203 | else: 204 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 205 | if batch_norm: 206 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 207 | else: 208 | layers += [conv2d, nn.ReLU(inplace=True)] 209 | in_channels = v 210 | return nn.Sequential(*layers) 211 | 212 | 213 | cfg = { 214 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 215 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 216 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 217 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 218 | 512, 512, 512, 512, 'M'], 219 | } 220 | 221 | 222 | def vgg11(): 223 | """VGG 11-layer model (configuration "A")""" 224 | return VGG(make_layers(cfg['A'])) 225 | 226 | 227 | def vgg11_bn(): 228 | """VGG 11-layer model (configuration "A") with batch normalization""" 229 | return VGG(make_layers(cfg['A'], batch_norm=True)) 230 | 231 | 232 | def vgg13(): 233 | """VGG 13-layer model (configuration "B")""" 234 | return VGG(make_layers(cfg['B'])) 235 | 236 | 237 | def vgg13_bn(): 238 | """VGG 13-layer model (configuration "B") with batch normalization""" 239 | return VGG(make_layers(cfg['B'], batch_norm=True)) 240 | 241 | 242 | def vgg16(): 243 | """VGG 16-layer model (configuration "D")""" 244 | return VGG(make_layers(cfg['D'])) 245 | 246 | 247 | def vgg16_bn(): 248 | """VGG 16-layer model (configuration "D") with batch normalization""" 249 | return VGG(make_layers(cfg['D'], batch_norm=True)) 250 | 251 | 252 | def vgg19(): 253 | """VGG 19-layer model (configuration "E")""" 254 | return VGG(make_layers(cfg['E'])) 255 | 256 | 257 | def vgg19_bn(): 258 | """VGG 19-layer model (configuration 'E') with batch normalization""" 259 | return VGG(make_layers(cfg['E'], batch_norm=True)) 260 | -------------------------------------------------------------------------------- /optim/prep_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets 3 | import torchvision.transforms as transforms 4 | from 
torch.utils.data import DataLoader, Subset 5 | 6 | 7 | def create_loaders(dataset_name, n_workers, batch_size, seed=42): 8 | 9 | train_data, test_data = load_data(dataset_name) 10 | 11 | train_loader_workers = dict() 12 | n = len(train_data) 13 | 14 | # preparing iterators for workers and validation set 15 | np.random.seed(seed) 16 | indices = np.arange(n) 17 | np.random.shuffle(indices) 18 | 19 | n_val = np.int(np.floor(0.1 * n)) 20 | val_data = Subset(train_data, indices=indices[:n_val]) 21 | 22 | indices = indices[n_val:] 23 | n = len(indices) 24 | a = np.int(np.floor(n / n_workers)) 25 | top_ind = a * n_workers 26 | seq = range(a, top_ind, a) 27 | split = np.split(indices[:top_ind], seq) 28 | 29 | b = 0 30 | for ind in split: 31 | train_loader_workers[b] = DataLoader(Subset(train_data, ind), batch_size=batch_size, shuffle=True) 32 | b = b + 1 33 | 34 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False) 35 | val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False) 36 | 37 | return train_loader_workers, val_loader, test_loader 38 | 39 | 40 | def load_data(dataset_name): 41 | 42 | if dataset_name == 'mnist': 43 | 44 | transform = transforms.ToTensor() 45 | 46 | train_data = datasets.MNIST(root='data', train=True, 47 | download=True, transform=transform) 48 | 49 | test_data = datasets.MNIST(root='data', train=False, 50 | download=True, transform=transform) 51 | elif dataset_name == 'cifar10': 52 | 53 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 54 | std=[0.229, 0.224, 0.225]) 55 | transform = transforms.Compose([ 56 | transforms.ToTensor(), 57 | normalize, 58 | ]) 59 | 60 | train_data = datasets.CIFAR10(root='data', train=True, 61 | download=True, transform=transform) 62 | 63 | test_data = datasets.CIFAR10(root='data', train=False, 64 | download=True, transform=transform) 65 | elif dataset_name == 'cifar100': 66 | transform = transforms.ToTensor() # add extra transforms 67 | train_data = datasets.CIFAR100(root='data', train=True, 68 | download=True, transform=transform) 69 | 70 | test_data = datasets.CIFAR100(root='data', train=False, 71 | download=True, transform=transform) 72 | else: 73 | raise ValueError(dataset_name + ' is not known.') 74 | 75 | return train_data, test_data 76 | -------------------------------------------------------------------------------- /optim/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .utils import create_run, update_run, save_run, seed_everything 5 | from .prep_data import create_loaders 6 | from .gen_sgd import SGDGen 7 | 8 | RUNS = 5 9 | 10 | 11 | def train_workers(suffix, model, optimizer, criterion, epochs, train_loader_workers, 12 | val_loader, test_loader, n_workers, hpo=False): 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | model.to(device) 15 | run = create_run() 16 | train_loss = np.inf 17 | 18 | best_val_loss = np.inf 19 | test_loss = np.inf 20 | test_acc = 0 21 | 22 | for e in range(epochs): 23 | model.train() 24 | running_loss = 0 25 | train_loader_iter = [iter(train_loader_workers[w]) for w in range(n_workers)] 26 | iter_steps = len(train_loader_workers[0]) 27 | for _ in range(iter_steps): 28 | for w_id in range(n_workers): 29 | data, labels = next(train_loader_iter[w_id]) 30 | data, labels = data.to(device), labels.to(device) 31 | output = model(data) 32 | loss = criterion(output, labels) 33 | loss.backward() 34 | running_loss += loss.item() 35 | 
optimizer.step_local_global(w_id) 36 | optimizer.zero_grad() 37 | 38 | train_loss = running_loss/(iter_steps*n_workers) 39 | 40 | val_loss, _ = accuracy_and_loss(model, val_loader, criterion, device) 41 | 42 | if val_loss < best_val_loss: 43 | test_loss, test_acc = accuracy_and_loss(model, test_loader, criterion, device) 44 | best_val_loss = val_loss 45 | 46 | update_run(train_loss, test_loss, test_acc, run) 47 | 48 | print("Epoch: {}/{}.. Training Loss: {:.5f}, Test Loss: {:.5f}, Test accuracy: {:.2f} " 49 | .format(e + 1, epochs, train_loss, test_loss, test_acc), end='\r') 50 | 51 | print('') 52 | if not hpo: 53 | save_run(suffix, run) 54 | 55 | return best_val_loss 56 | 57 | 58 | def accuracy_and_loss(model, loader, criterion, device): 59 | correct = 0 60 | total_loss = 0 61 | 62 | model.eval() 63 | for data, labels in loader: 64 | data, labels = data.to(device), labels.to(device) 65 | output = model(data) 66 | loss = criterion(output, labels) 67 | total_loss += loss.item() 68 | 69 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 70 | correct += pred.eq(labels.view_as(pred)).sum().item() 71 | 72 | accuracy = 100. * correct / len(loader.dataset) 73 | total_loss = total_loss / len(loader) 74 | 75 | return total_loss, accuracy 76 | 77 | 78 | def tune_step_size(exp): 79 | best_val_loss = np.inf 80 | best_lr = 0 81 | 82 | seed = exp['seed'] 83 | seed_everything(seed) 84 | hpo = True 85 | 86 | for lr in exp['lrs']: 87 | print('Learning rate {:2.4f}:'.format(lr)) 88 | val_loss = run_workers(lr, exp, hpo=hpo) 89 | 90 | if val_loss < best_val_loss: 91 | best_lr = lr 92 | best_val_loss = val_loss 93 | return best_lr 94 | 95 | 96 | def run_workers(lr, exp, suffix=None, hpo=False): 97 | dataset_name = exp['dataset_name'] 98 | n_workers = exp['n_workers'] 99 | batch_size = exp['batch_size'] 100 | epochs = exp['epochs'] 101 | criterion = exp['criterion'] 102 | error_feedback = exp['error_feedback'] 103 | momentum = exp['momentum'] 104 | weight_decay = exp['weight_decay'] 105 | compression = get_compression(**exp['compression']) 106 | master_compression = exp['master_compression'] 107 | 108 | net = exp['net'] 109 | model = net() 110 | 111 | train_loader_workers, val_loader, test_loader = create_loaders(dataset_name, n_workers, batch_size) 112 | 113 | optimizer = SGDGen(model.parameters(), lr=lr, n_workers=n_workers, error_feedback=error_feedback, 114 | comp=compression, momentum=momentum, weight_decay=weight_decay, master_comp=master_compression) 115 | 116 | val_loss = train_workers(suffix, model, optimizer, criterion, epochs, train_loader_workers, 117 | val_loader, test_loader, n_workers, hpo=hpo) 118 | return val_loss 119 | 120 | 121 | def run_tuned_exp(exp, runs=RUNS, suffix=None): 122 | if suffix is None: 123 | suffix = exp['name'] 124 | 125 | lr = exp['lr'] 126 | 127 | if lr is None: 128 | raise ValueError("Tune step size first") 129 | 130 | seed = exp['seed'] 131 | seed_everything(seed) 132 | 133 | for i in range(runs): 134 | print('Run {:3d}/{:3d}, Name {}:'.format(i+1, runs, suffix)) 135 | suffix_run = suffix + '_' + str(i+1) 136 | run_workers(lr, exp, suffix_run) 137 | 138 | 139 | def get_single_compression(wrapper, compression, **kwargs): 140 | if wrapper: 141 | return compression(**kwargs) 142 | else: 143 | return compression 144 | 145 | 146 | def get_compression(combine=None, **kwargs): 147 | if combine is None: 148 | return get_single_compression(**kwargs) 149 | else: 150 | compression_1 = get_single_compression(**combine['comp_1']) 151 | 
compression_2 = get_single_compression(**combine['comp_2']) 152 | return combine['func'](compression_1, compression_2) 153 | -------------------------------------------------------------------------------- /optim/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import random 5 | import torch 6 | from pickle import load, dump 7 | 8 | SAVED_RUNS_PATH = 'saved_data/' 9 | EXP_PATH = 'exps_setup/' 10 | 11 | 12 | def save_run(suffix, run): 13 | if not os.path.isdir(SAVED_RUNS_PATH): 14 | os.mkdir(SAVED_RUNS_PATH) 15 | 16 | file = SAVED_RUNS_PATH + suffix + '.pickle' 17 | with open(file, 'wb') as f: 18 | dump(run, f) 19 | 20 | 21 | def read_all_runs(exp, suffix=None): 22 | if suffix is None: 23 | suffix = exp['name'] 24 | 25 | runs = list() 26 | runs_files = glob.glob(SAVED_RUNS_PATH + suffix + '_' + '[1-9]*.pickle') # reads at most first ten runs 27 | for run_file in runs_files: 28 | runs.append(read_run(run_file)) 29 | return runs 30 | 31 | 32 | def read_run(file): 33 | with open(file, 'rb') as f: 34 | run = load(f) 35 | return run 36 | 37 | 38 | def create_run(): 39 | run = {'train_loss': [], 40 | 'test_loss': [], 41 | 'test_acc': [] 42 | } 43 | return run 44 | 45 | 46 | def update_run(train_loss, test_loss, test_acc, run): 47 | run['train_loss'].append(train_loss) 48 | run['test_loss'].append(test_loss) 49 | run['test_acc'].append(test_acc) 50 | 51 | 52 | def save_exp(exp): 53 | if not os.path.isdir(EXP_PATH): 54 | os.mkdir(EXP_PATH) 55 | 56 | file = EXP_PATH + exp['name'] + '.pickle' 57 | with open(file, 'wb') as f: 58 | dump(exp, f) 59 | 60 | 61 | def load_exp(exp_name): 62 | file = EXP_PATH + exp_name + '.pickle' 63 | with open(file, 'rb') as f: 64 | exp = load(f) 65 | return exp 66 | 67 | 68 | def create_exp(name, dataset, net, n_workers, epochs, seed, batch_size, lrs, compression, error_feedback, criterion, 69 | master_compression=None, momentum=0, weight_decay=0): 70 | exp = { 71 | 'name': name, 72 | 'dataset_name': dataset, 73 | 'net': net, 74 | 'n_workers': n_workers, 75 | 'epochs': epochs, 76 | 'seed': seed, 77 | 'batch_size': batch_size, 78 | 'lrs': lrs, 79 | 'lr': None, 80 | 'compression': compression, 81 | 'master_compression': master_compression, 82 | 'error_feedback': error_feedback, 83 | 'criterion': criterion, 84 | 'momentum': momentum, 85 | 'weight_decay': weight_decay 86 | } 87 | return exp 88 | 89 | 90 | def seed_everything(seed=42): 91 | """ 92 | :param seed: 93 | :return: 94 | """ 95 | random.seed(seed) 96 | os.environ['PYTHONHASHSEED'] = str(seed) 97 | np.random.seed(seed) 98 | torch.manual_seed(seed) 99 | torch.cuda.manual_seed(seed) 100 | torch.cuda.manual_seed_all(seed) 101 | # some cudnn methods can be random even after fixing the seed 102 | # unless you tell it to be deterministic 103 | torch.backends.cudnn.deterministic = True 104 | -------------------------------------------------------------------------------- /quant/quant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def prep_grad(x): 6 | x_flat = torch.unsqueeze(x, 0).flatten() 7 | dim = x.shape 8 | d = x_flat.shape[0] 9 | return x_flat, dim, d 10 | 11 | 12 | def c_nat(x): 13 | x, dim, d = prep_grad(x) 14 | # get 2^n out of input 15 | h1 = torch.floor(torch.where(x != 0, torch.log2(torch.abs(x)), x)) 16 | h2 = torch.where(x != 0, torch.pow(2, h1), x) 17 | # extract probability 18 | p = torch.where(x != 0, 
torch.div(torch.abs(x) - h2, h2), x) 19 | # sample random uniform vector 20 | unif = torch.rand_like(x) 21 | # generate zero one with probability p 22 | zero_one = torch.floor(unif + p) 23 | # generate output 24 | nat = torch.sign(x) * h2 * (1 + zero_one) 25 | return nat.reshape(dim) 26 | 27 | 28 | def random_dithering_opt(x, p, s, natural): 29 | """ 30 | :param x: vector to quantize 31 | :param p: norm parameter 32 | :param s: number of levels 33 | :param natural: if True, natural dithering is used 34 | :return: compressed vector 35 | """ 36 | x, dim, d = prep_grad(x) 37 | # definition of random dithering 38 | norm = torch.norm(x, p=p) 39 | if norm == 0: 40 | return x.reshape(dim) 41 | 42 | if natural: 43 | s = int(2 ** (s - 1)) 44 | f = torch.floor(s * torch.abs(x) / norm + torch.rand_like(x))/s 45 | 46 | if natural: 47 | f = c_nat(f) 48 | res = torch.sign(x) * f 49 | k = res * norm 50 | return k.reshape(dim) 51 | 52 | 53 | def random_dithering_wrap(p=np.inf, s=2, natural=True): 54 | def random_dithering(x): 55 | return random_dithering_opt(x, p=p, s=s, natural=natural) 56 | return random_dithering 57 | 58 | 59 | def rand_spars_opt(x, h): 60 | """ 61 | :param x: vector to sparsify 62 | :param h: density 63 | :return: compressed vector 64 | """ 65 | x, dim, d = prep_grad(x) 66 | # number of coordinates to keep 67 | r = int(np.maximum(1, np.floor(d * h))) 68 | # random vector of r ones and d-r zeros 69 | mask = torch.zeros_like(x) 70 | mask[torch.randperm(d)[:r]] = 1 71 | # just r random coordinates are kept 72 | t = mask * x * (d/r) 73 | t = t.reshape(dim) 74 | return t 75 | 76 | 77 | def rand_spars_wrap(h=0.1): 78 | def rand_spars(x): 79 | return rand_spars_opt(x, h=h) 80 | return rand_spars 81 | 82 | 83 | def top_k_opt(x, h): 84 | """ 85 | :param x: vector to sparsify 86 | :param h: density 87 | :return: compressed vector 88 | """ 89 | x, dim, d = prep_grad(x) 90 | # number of coordinates kept 91 | r = int(np.maximum(1, np.floor(d * h))) 92 | # positions of top_k coordinates 93 | _, ind = torch.topk(torch.abs(x), r) 94 | mask = torch.zeros_like(x) 95 | mask[ind] = 1 96 | t = mask * x 97 | t = t.reshape(dim) 98 | return t 99 | 100 | 101 | def top_k_wrap(h=0.1): 102 | def top_k(x): 103 | return top_k_opt(x, h=h) 104 | return top_k 105 | 106 | 107 | def grad_spars_opt(x, h, max_it): 108 | """ 109 | :param x: vector to sparsify 110 | :param h: density 111 | :param max_it: maximum number of iterations of greedy algorithm 112 | :return: compressed vector 113 | """ 114 | x, dim, d = prep_grad(x) 115 | # number of coordinates kept 116 | r = int(np.maximum(1, np.floor(d * h))) 117 | 118 | abs_x = torch.abs(x) 119 | abs_sum = torch.sum(abs_x) 120 | ones = torch.ones_like(x) 121 | p_0 = r * abs_x / abs_sum 122 | p = torch.min(p_0, ones) 123 | for _ in range(max_it): 124 | p_sub = p[(p != 1).nonzero(as_tuple=True)] 125 | p = torch.where(p >= ones, ones, p) 126 | if len(p_sub) == 0 or torch.sum(torch.abs(p_sub)) == 0: 127 | break 128 | else: 129 | c = (r - d + len(p_sub))/torch.sum(p_sub) 130 | p = torch.min(c * p, ones) 131 | if c <= 1: 132 | break 133 | prob = torch.rand_like(x) 134 | # avoid making very small gradient too big 135 | s = torch.where(p <= 10**(-6), x, x / p) 136 | # we keep just coordinates with high probability 137 | t = torch.where(prob <= p, s, torch.zeros_like(x)) 138 | t = t.reshape(dim) 139 | return t 140 | 141 | 142 | def grad_spars_wrap(h=0.1, max_it=4): 143 | def grad_spars(x): 144 | return grad_spars_opt(x, h=h, max_it=max_it) 145 | return grad_spars 146 | 147 | 148 | def 
biased_unbiased_wrap(biased_comp, unbiased_comp): 149 | def error_quant(x): 150 | c_1 = biased_comp(x) 151 | error = x - c_1 152 | c_2 = unbiased_comp(error) 153 | return c_1 + c_2 154 | return error_quant 155 | 156 | 157 | def combine_two_wrap(comp_1, comp_2): 158 | def combine(x): 159 | t_1 = comp_1(x) 160 | t_2 = comp_2(t_1) 161 | return t_2 162 | return combine 163 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.1.3 2 | numpy==1.18.1 3 | torch==1.5.0 4 | torchvision==0.6.0 -------------------------------------------------------------------------------- /utils/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | import matplotlib.pyplot as plt 3 | from optim.utils import read_all_runs 4 | import numpy as np 5 | import os 6 | 7 | plt.style.use('ggplot') 8 | mpl.rcParams['pdf.fonttype'] = 42 9 | mpl.rcParams['ps.fonttype'] = 42 10 | mpl.rcParams['lines.linewidth'] = 2.0 11 | mpl.rcParams['legend.fontsize'] = 'x-large' 12 | mpl.rcParams['xtick.labelsize'] = 'x-large' 13 | mpl.rcParams['ytick.labelsize'] = 'x-large' 14 | mpl.rcParams['axes.labelsize'] = 'x-large' 15 | 16 | PLOT_PATH = 'plots/' 17 | 18 | 19 | def plot(exps, kind, suffix=None, log_scale=True, legend=None, file=None, 20 | x_label='epochs', y_label=None): 21 | fig, ax = plt.subplots() 22 | 23 | for exp in exps: 24 | runs = read_all_runs(exp, suffix=suffix) 25 | plot_mean_std(ax, runs, kind) 26 | 27 | if log_scale: 28 | ax.set_yscale('log') 29 | if legend is not None: 30 | ax.legend(legend) 31 | 32 | ax.set_xlabel(x_label) 33 | if y_label is None: 34 | ax.set_ylabel(kind) 35 | else: 36 | ax.set_ylabel(y_label) 37 | 38 | fig.tight_layout() 39 | if file is not None: 40 | if not os.path.isdir(PLOT_PATH): 41 | os.mkdir(PLOT_PATH) 42 | plt.savefig(PLOT_PATH + file + '.pdf') 43 | plt.show() 44 | 45 | 46 | def plot_mean_std(ax, runs, kind): 47 | quant = np.array([run[kind] for run in runs]) 48 | mean = np.mean(quant, axis=0) 49 | std = np.std(quant, axis=0) 50 | 51 | x = np.arange(1, len(mean) + 1) 52 | ax.plot(x, mean) 53 | ax.fill_between(x, mean + std, mean - std, alpha=0.4) 54 | 55 | --------------------------------------------------------------------------------
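For reference, below is a minimal end-to-end sketch that drives the same pipeline from a plain Python script instead of the notebook: it builds an experiment, tunes the step size, runs the tuned experiment, and plots the stored runs. The experiment name, model, dataset, learning-rate grid, and output file name are illustrative placeholders, not values prescribed by the code base.

```python
from torch import nn
import numpy as np

from optim.train import tune_step_size, run_tuned_exp
from optim.models import MNISTNet
from optim.utils import create_exp, save_exp, load_exp
from utils.plotting import plot
from quant.quant import c_nat

# worker-side Natural Compression, no compression from the master back to the workers
compression = {'wrapper': False, 'compression': c_nat}

# illustrative experiment setup ('mnist_net_cnat' is a placeholder name)
exp = create_exp('mnist_net_cnat', 'mnist', MNISTNet, n_workers=4, epochs=10, seed=40,
                 batch_size=32, lrs=np.array([0.1, 0.05, 0.01]), compression=compression,
                 error_feedback=False, criterion=nn.CrossEntropyLoss())

exp['lr'] = tune_step_size(exp)  # pick the best learning rate on the validation split
save_exp(exp)                    # stores the setup under exps_setup/
run_tuned_exp(exp)               # repeats the run RUNS times; results go to saved_data/

# compare stored runs; the plot is written to plots/mnist_net_cnat.pdf
plot([load_exp('mnist_net_cnat')], 'test_acc', log_scale=False,
     legend=['MNISTNet + natural compression'], y_label='Test accuracy',
     file='mnist_net_cnat')
```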