├── .gitattributes ├── models ├── convmixer.py ├── vgg.py ├── resnet.py ├── simple.py ├── randaug.py └── wideresnet.py ├── README.md ├── options.py ├── federated_main.py ├── federated_main-ef.py ├── sampling.py ├── utils.py ├── compressors.py └── update.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /models/convmixer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | 4 | class Residual(nn.Module): 5 | def __init__(self, fn): 6 | super().__init__() 7 | self.fn = fn 8 | 9 | def forward(self, x): 10 | return self.fn(x) + x 11 | 12 | 13 | def ConvMixer(dim=256, depth=8, kernel_size=5, patch_size=2, n_classes=10): 14 | return nn.Sequential( 15 | nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), 16 | nn.GELU(), 17 | nn.BatchNorm2d(dim), 18 | *[nn.Sequential( 19 | Residual(nn.Sequential( 20 | nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), 21 | nn.GELU(), 22 | nn.BatchNorm2d(dim) 23 | )), 24 | nn.Conv2d(dim, dim, kernel_size=1), 25 | nn.GELU(), 26 | nn.BatchNorm2d(dim) 27 | ) for i in range(depth)], 28 | nn.AdaptiveAvgPool2d((1,1)), 29 | nn.Flatten(), 30 | nn.Linear(dim, n_classes) 31 | ) 32 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | cfg = { 8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 12 | } 13 | 14 | 15 | class VGG(nn.Module): 16 | def __init__(self, vgg_name, num_classes = 10): 17 | super(VGG, self).__init__() 18 | self.features = self._make_layers(cfg[vgg_name]) 19 | self.classifier = nn.Linear(512, num_classes) 20 | 21 | def forward(self, x): 22 | out = self.features(x) 23 | out = out.view(out.size(0), -1) 24 | out = self.classifier(out) 25 | return out 26 | 27 | def _make_layers(self, cfg): 28 | layers = [] 29 | in_channels = 3 30 | for x in cfg: 31 | if x == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(x), 36 | nn.ReLU(inplace=True)] 37 | in_channels = x 38 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 39 | return nn.Sequential(*layers) 40 | 41 | # net = VGG('VGG11') 42 | # x = torch.randn(2,3,32,32) 43 | # print(net(Variable(x)).size()) 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedCAMS 2 | 3 | This repository contains the PyTorch implementation of Federated AMSGrad with Max Stabilization (FedAMS), and Federated Communication compressed AMSGrad with Max Stabilization (FedCAMS) in (accepted by ICML 2022). 4 | 5 | ## Prerequisites 6 | Pytorch 1.11.0 7 | 8 | CUDA 11.3 9 | 10 | ## Running the experiments 11 | 12 | To run the experiment for FedAMS: 13 | 14 | ``` 15 | python3 federated_main.py --model=resnet --dataset=cifar10 --gpu=0 --local_bs=20 --epochs=500 --iid=1 --optimizer=fedams --local_lr=0.01 --lr=1.0 --local_ep=3 --eps=0 --max_init=1e-3 16 | ``` 17 | To run the experiment for FedCAMS: 18 | ``` 19 | python3 federated_main-ef.py --model=resnet --dataset=cifar10 --gpu=0 --local_bs=20 --epochs=500 --iid=1 --optimizer=fedams --local_lr=0.01 --lr=1.0 --local_ep=3 --eps=0 --max_init=1e-3 20 | ``` 21 | ## Options 22 | The default values for various paramters parsed to the experiment are given in ```options.py```. 23 | 24 | ```--dataset:``` Default: 'cifar10'. Options: 'mnist', 'fmnist', 'cifar100'. 25 | 26 | ```--model:``` Default: 'cnn'. Options: 'mlp', 'resnet', 'convmixer'. 27 | 28 | ```--gpu:``` To use cuda, set to a specific GPU ID. 29 | 30 | ```--epochs:``` Number of rounds of training. 31 | 32 | ```--local_ep:``` Number of local epochs. 33 | 34 | ```--local_lr:``` Learning rate for local update. 35 | 36 | ```--lr:``` Learning rate for global update. 37 | 38 | ```--local_bs:``` Local update batch size. 39 | 40 | ```--iid:``` Default set to IID. Set to 0 for non-IID. 41 | 42 | ```--num_users:``` Number of users. Default is 100. 43 | 44 | ```--frac:``` Fraction of users to be used for federated updates. Default is 0.1. 45 | 46 | ```--optimizer:``` Default: 'fedavg'. Options: 'fedadam', 'fedams'. 47 | 48 | ```--compressor:``` Compression strategy. Default: 'sign'. Options: 'topk64', 'topk128', 'topk256'. 49 | 50 | ## Citation 51 | Please check our paper for technical details and full results. 52 | ``` 53 | @inproceedings{wang2022communication, 54 | title={Communication-Efficient Adaptive Federated Learning}, 55 | author={Wang, Yujia and Lin, Lu and Chen, Jinghui}, 56 | booktitle={Proceedings of the International Conference on Machine Learning (ICML)}, 57 | year={2022} 58 | } 59 | 60 | ``` 61 | 62 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import argparse 6 | 7 | 8 | def args_parser(): 9 | parser = argparse.ArgumentParser() 10 | 11 | # federated arguments (Notation for the arguments followed from paper) 12 | parser.add_argument('--epochs', type=int, default=500, 13 | help="number of rounds of training") 14 | parser.add_argument('--num_users', type=int, default=100, 15 | help="number of users: K") 16 | parser.add_argument('--frac', type=float, default=0.1, 17 | help='the fraction of clients: C') 18 | parser.add_argument('--local_ep', type=int, default=3, 19 | help="the number of local epochs: E") 20 | parser.add_argument('--local_bs', type=int, default=20, 21 | help="local batch size: B") 22 | 23 | parser.add_argument('--local_lr', type=float, default=0.01, 24 | help='learning rate for local update') 25 | parser.add_argument('--lr', type=float, default=1.0, 26 | help='learning rate for global update') 27 | parser.add_argument('--momentum', type=float, default=0.0, 28 | help='SGD momentum (default: 0.0)') 29 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam') 30 | parser.add_argument('--beta2', type=float, default=0.99, help='beta2 for adam') 31 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for adam') 32 | parser.add_argument('--max_init', type=float, default=0.0, help='initialize max_v for adam') 33 | 34 | 35 | 36 | # model arguments 37 | parser.add_argument('--model', type=str, default='cnn', help='model name') 38 | 39 | # other arguments 40 | parser.add_argument('--dataset', type=str, default='cifar10', help="name \ 41 | of dataset") 42 | parser.add_argument('--num_classes', type=int, default=10, help="number \ 43 | of classes") 44 | 45 | parser.add_argument('--gpu', default=0, help="To use cuda, set \ 46 | to a specific GPU ID. Default set to use CPU.") 47 | parser.add_argument('--optimizer', type=str, default='fedavg', help="type \ 48 | of optimizer") 49 | parser.add_argument('--iid', type=int, default=1, 50 | help='Default set to IID. Set to 0 for non-IID.') 51 | parser.add_argument('--unequal', type=int, default=0, 52 | help='whether to use unequal data splits for \ 53 | non-i.i.d setting (use 0 for equal splits)') 54 | parser.add_argument('--stopping_rounds', type=int, default=10, 55 | help='rounds of early stopping') 56 | parser.add_argument('--verbose', type=int, default=0, help='verbose') 57 | parser.add_argument('--seed', type=int, default=1, help='random seed') 58 | parser.add_argument('--save', type=int, default=1, help='whether to save results') 59 | parser.add_argument('--outfolder', type=str, default='./results') 60 | 61 | parser.add_argument('--compressor', type=str, default='sign', help='compressor strategy') 62 | args = parser.parse_args() 63 | return args 64 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.autograd import Variable 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion*planes) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = self.bn2(self.conv2(out)) 36 | out += self.shortcut(x) 37 | out = F.relu(out) 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, in_planes, planes, stride=1): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 52 | 53 | self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != self.expansion*planes: 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 57 | nn.BatchNorm2d(self.expansion*planes) 58 | ) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(self.conv1(x))) 62 | out = F.relu(self.bn2(self.conv2(out))) 63 | out = self.bn3(self.conv3(out)) 64 | out += self.shortcut(x) 65 | out = F.relu(out) 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, num_blocks, num_classes=10): 71 | super(ResNet, self).__init__() 72 | self.in_planes = 64 73 | 74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 77 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 78 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 79 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 80 | self.linear = nn.Linear(512*block.expansion, num_classes) 81 | 82 | def _make_layer(self, block, planes, num_blocks, stride): 83 | strides = [stride] + [1]*(num_blocks-1) 84 | layers = [] 85 | for stride in strides: 86 | layers.append(block(self.in_planes, planes, stride)) 87 | self.in_planes = planes * block.expansion 88 | return nn.Sequential(*layers) 89 | 90 | def forward(self, x): 91 | out = F.relu(self.bn1(self.conv1(x))) 92 | out = self.layer1(out) 93 | out = self.layer2(out) 94 | out = self.layer3(out) 95 | out = self.layer4(out) 96 | out = F.avg_pool2d(out, 4) 97 | out = out.view(out.size(0), -1) 98 | out = self.linear(out) 99 | return out 100 | 101 | 102 | def ResNet18(num_classes = 10): 103 | return ResNet(BasicBlock, [2,2,2,2], num_classes = num_classes) 104 | 105 | def ResNet34(num_classes = 10): 106 | return ResNet(BasicBlock, [3,4,6,3], num_classes = num_classes) 107 | 108 | def ResNet50(num_classes = 10): 109 | return ResNet(Bottleneck, [3,4,6,3], num_classes = num_classes) 110 | 111 | def ResNet101(num_classes = 10): 112 | return ResNet(Bottleneck, [3,4,23,3], num_classes = num_classes) 113 | 114 | def ResNet152(num_classes = 10): 115 | return ResNet(Bottleneck, [3,8,36,3], num_classes = num_classes) 116 | 117 | 118 | def test(): 119 | net = ResNet18() 120 | y = net(Variable(torch.randn(1,3,32,32))) 121 | print(y.size()) 122 | 123 | # test() 124 | -------------------------------------------------------------------------------- /models/simple.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, dim_in, dim_hidden, dim_out): 11 | super(MLP, self).__init__() 12 | self.layer_input = nn.Linear(dim_in, dim_hidden) 13 | self.relu = nn.ReLU() 14 | self.dropout = nn.Dropout() 15 | self.layer_hidden = nn.Linear(dim_hidden, dim_out) 16 | self.softmax = nn.Softmax(dim=1) 17 | 18 | def forward(self, x): 19 | x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1]) 20 | x = self.layer_input(x) 21 | x = self.dropout(x) 22 | x = self.relu(x) 23 | x = self.layer_hidden(x) 24 | return x 25 | 26 | 27 | class CNNMnist(nn.Module): 28 | def __init__(self, num_classes, num_channels): 29 | super(CNNMnist, self).__init__() 30 | self.conv1 = nn.Conv2d(num_channels, 10, kernel_size=5) 31 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 32 | self.conv2_drop = nn.Dropout2d() 33 | self.fc1 = nn.Linear(320, 50) 34 | self.fc2 = nn.Linear(50, num_classes) 35 | 36 | def forward(self, x): 37 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 38 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 39 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 40 | x = F.relu(self.fc1(x)) 41 | x = F.dropout(x, training=self.training) 42 | x = self.fc2(x) 43 | return x 44 | 45 | 46 | class CNNFashion_Mnist(nn.Module): 47 | def __init__(self, num_classes): 48 | super(CNNFashion_Mnist, self).__init__() 49 | self.layer1 = nn.Sequential( 50 | nn.Conv2d(1, 16, kernel_size=5, padding=2), 51 | nn.BatchNorm2d(16), 52 | nn.ReLU(), 53 | nn.MaxPool2d(2)) 54 | self.layer2 = nn.Sequential( 55 | nn.Conv2d(16, 32, kernel_size=5, padding=2), 56 | nn.BatchNorm2d(32), 57 | nn.ReLU(), 58 | nn.MaxPool2d(2)) 59 | self.fc = nn.Linear(7*7*32, num_classes) 60 | 61 | def forward(self, x): 62 | out = self.layer1(x) 63 | out = self.layer2(out) 64 | out = out.view(out.size(0), -1) 65 | out = self.fc(out) 66 | return out 67 | 68 | 69 | class CNNCifar(nn.Module): 70 | def __init__(self, num_classes): 71 | super(CNNCifar, self).__init__() 72 | self.conv1 = nn.Conv2d(3, 6, 5) 73 | self.pool = nn.MaxPool2d(2, 2) 74 | self.conv2 = nn.Conv2d(6, 16, 5) 75 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 76 | self.fc2 = nn.Linear(120, 84) 77 | self.fc3 = nn.Linear(84, num_classes) 78 | 79 | def forward(self, x): 80 | x = self.pool(F.relu(self.conv1(x))) 81 | x = self.pool(F.relu(self.conv2(x))) 82 | x = x.view(-1, 16 * 5 * 5) 83 | x = F.relu(self.fc1(x)) 84 | x = F.relu(self.fc2(x)) 85 | out = self.fc3(x) 86 | return out 87 | 88 | 89 | class CNNLarge(nn.Module): 90 | def __init__(self): 91 | super().__init__() 92 | self.network = nn.Sequential( 93 | nn.Conv2d(3, 32, kernel_size=3, padding=1), 94 | nn.ReLU(), 95 | nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(), 97 | nn.MaxPool2d(2, 2), # output: 64 x 16 x 16 98 | nn.BatchNorm2d(64), 99 | 100 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 101 | nn.ReLU(), 102 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), 103 | nn.ReLU(), 104 | nn.MaxPool2d(2, 2), # output: 128 x 8 x 8 105 | nn.BatchNorm2d(128), 106 | 107 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), 108 | nn.ReLU(), 109 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 110 | nn.ReLU(), 111 | nn.MaxPool2d(2, 2), # output: 256 x 4 x 4 112 | nn.BatchNorm2d(256), 113 | 114 | nn.Flatten(), 115 | nn.Linear(256*4*4, 1024), 116 | nn.ReLU(), 117 | nn.Linear(1024, 512), 118 | nn.ReLU(), 119 | nn.Linear(512, 10)) 120 | 121 | def forward(self, xb): 122 | return self.network(xb) 123 | 124 | 125 | 126 | class Autoencoder(nn.Module): 127 | def __init__(self): 128 | super(Autoencoder,self).__init__() 129 | self.encoder = nn.Sequential( 130 | # 28 x 28 131 | nn.Conv2d(1, 4, kernel_size=5), 132 | # 4 x 24 x 24 133 | nn.ReLU(True), 134 | nn.Conv2d(4, 8, kernel_size=5), 135 | nn.ReLU(True), 136 | # 8 x 20 x 20 = 3200 137 | nn.Flatten(), 138 | nn.Linear(3200, 10), 139 | # 10 140 | nn.Softmax(), 141 | ) 142 | self.decoder = nn.Sequential( 143 | # 10 144 | nn.Linear(10, 400), 145 | # 400 146 | nn.ReLU(True), 147 | nn.Linear(400, 4000), 148 | # 4000 149 | nn.ReLU(True), 150 | nn.Unflatten(1, (10, 20, 20)), 151 | # 10 x 20 x 20 152 | nn.ConvTranspose2d(10, 10, kernel_size=5), 153 | # 24 x 24 154 | nn.ConvTranspose2d(10, 1, kernel_size=5), 155 | # 28 x 28 156 | nn.Sigmoid(), 157 | ) 158 | def forward(self, x): 159 | enc = self.encoder(x) 160 | dec = self.decoder(enc) 161 | return dec -------------------------------------------------------------------------------- /federated_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | 6 | import os 7 | import copy 8 | import time 9 | import pickle 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | import math 14 | import torch 15 | from torch import nn 16 | from tensorboardX import SummaryWriter 17 | 18 | from options import args_parser 19 | from update import LocalUpdate, update_model_inplace, test_inference 20 | from utils import get_model, get_dataset, average_weights, exp_details, average_parameter_delta 21 | 22 | if __name__ == '__main__': 23 | start_time = time.time() 24 | 25 | args = args_parser() 26 | exp_details(args) 27 | 28 | # define paths 29 | # out_dir_name = args.model + args.dataset + args.optimizer + '_lr' + str(args.lr) + '_locallr' + str(args.local_lr) + '_localep' + str(args.local_ep) +'_localbs' + str(args.local_bs) + '_eps' + str(args.eps) 30 | file_name = '/{}_{}_{}_llr[{}]_glr[{}]_eps[{}]_le[{}]_bs[{}]_iid[{}]_mi[{}]_frac[{}].pkl'.\ 31 | format(args.dataset, args.model, args.optimizer, 32 | args.local_lr, args.lr, args.eps, 33 | args.local_ep, args.local_bs, args.iid, args.max_init, args.frac) 34 | logger = SummaryWriter('./logs/'+file_name) 35 | 36 | device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else "cpu") 37 | torch.set_num_threads(1) # limit cpu use 38 | print ('-- pytorch version: ', torch.__version__) 39 | 40 | np.random.seed(args.seed) 41 | torch.manual_seed(args.seed) 42 | if device != 'cpu': 43 | torch.cuda.manual_seed(args.seed) 44 | 45 | if not os.path.exists(args.outfolder): 46 | os.mkdir(args.outfolder) 47 | 48 | # load dataset and user groups 49 | train_dataset, test_dataset, num_classes, user_groups = get_dataset(args) 50 | 51 | # Set the model to train and send it to device. 52 | global_model = get_model(args.model, args.dataset, train_dataset[0][0].shape, num_classes) 53 | global_model.to(device) 54 | global_model.train() 55 | 56 | 57 | momentum_buffer_list = [] 58 | exp_avgs = [] 59 | exp_avg_sqs = [] 60 | max_exp_avg_sqs = [] 61 | for i, p in enumerate(global_model.parameters()): 62 | momentum_buffer_list.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 63 | exp_avgs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 64 | exp_avg_sqs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 65 | max_exp_avg_sqs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)+args.max_init) # 1e-2 66 | 67 | 68 | 69 | 70 | 71 | # Training 72 | train_loss_sampled, train_loss, train_accuracy = [], [], [] 73 | test_loss, test_accuracy = [], [] 74 | start_time = time.time() 75 | for epoch in tqdm(range(args.epochs)): 76 | ep_time = time.time() 77 | 78 | local_weights, local_params, local_losses = [], [], [] 79 | print(f'\n | Global Training Round : {epoch+1} |\n') 80 | 81 | par_before = [] 82 | for p in global_model.parameters(): # get trainable parameters 83 | par_before.append(p.data.detach().clone()) 84 | # this is to store parameters before update 85 | w0 = global_model.state_dict() # get all parameters, includeing batch normalization related ones 86 | 87 | global_model.train() 88 | m = max(int(args.frac * args.num_users), 1) 89 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 90 | 91 | for idx in idxs_users: 92 | 93 | local_model = LocalUpdate(args=args, dataset=train_dataset, 94 | idxs=user_groups[idx], logger=logger) 95 | 96 | w, p, loss = local_model.update_weights_local( 97 | model=copy.deepcopy(global_model), global_round=epoch) 98 | 99 | 100 | local_weights.append(copy.deepcopy(w)) 101 | local_params.append(copy.deepcopy(p)) 102 | local_losses.append(copy.deepcopy(loss)) 103 | 104 | bn_weights = average_weights(local_weights) 105 | global_model.load_state_dict(bn_weights) 106 | 107 | # this is to update trainable parameters via different optimizers 108 | global_delta = average_parameter_delta(local_params, par_before) # calculate compression in this function 109 | 110 | update_model_inplace( 111 | global_model, par_before, global_delta, args, epoch, 112 | momentum_buffer_list, exp_avgs, exp_avg_sqs, max_exp_avg_sqs) 113 | 114 | 115 | # report and store loss and accuracy 116 | # this is local training loss on sampled users 117 | loss_avg = sum(local_losses) / len(local_losses) 118 | train_loss.append(loss_avg) 119 | 120 | print('Epoch Run Time: {0:0.4f} of {1} global rounds'.format(time.time()-ep_time, epoch+1)) 121 | print(f'Training Loss : {train_loss[-1]}') 122 | logger.add_scalar('train loss', train_loss[-1], epoch) 123 | 124 | global_model.eval() 125 | 126 | 127 | # Test inference after completion of training 128 | test_acc, test_ls = test_inference(args, global_model, test_dataset) 129 | test_accuracy.append(test_acc) 130 | test_loss.append(test_ls) 131 | 132 | # print global training loss after every rounds 133 | 134 | print(f'Test Loss : {test_loss[-1]}') 135 | print(f'Test Accuracy : {test_accuracy[-1]} \n') 136 | 137 | logger.add_scalar('test loss', test_loss[-1], epoch) 138 | logger.add_scalar('test acc', test_accuracy[-1], epoch) 139 | 140 | if args.save: 141 | # Saving the objects train_loss and train_accuracy: 142 | with open(args.outfolder + file_name, 'wb') as f: 143 | pickle.dump([train_loss, test_loss, test_accuracy], f) 144 | 145 | 146 | print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time)) 147 | 148 | -------------------------------------------------------------------------------- /federated_main-ef.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | 6 | import os 7 | import copy 8 | import time 9 | import pickle 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | import math 14 | import torch 15 | from tensorboardX import SummaryWriter 16 | 17 | from options import args_parser 18 | from update import LocalUpdate, update_model_inplace, test_inference 19 | from utils import get_model, get_dataset, average_weights, exp_details, average_parameter_delta 20 | import utils 21 | 22 | if __name__ == '__main__': 23 | start_time = time.time() 24 | 25 | args = args_parser() 26 | exp_details(args) 27 | 28 | # define paths 29 | # out_dir_name = args.model + '_compress_' + args.dataset + args.optimizer + '_lr' + str(args.lr) + '_locallr' + str( args.local_lr) + '_localep' + str(args.local_ep) +'_localbs' + str(args.local_bs) + '_eps' + str(args.eps) 30 | file_name = '/ef_{}_{}_{}_llr[{}]_glr[{}]_eps[{}]_le[{}]_bs[{}]_iid[{}]_mi[{}]_frac[{}]_{}.pkl'.\ 31 | format(args.dataset, args.model, args.optimizer, 32 | args.local_lr, args.lr, args.eps, 33 | args.local_ep, args.local_bs, args.iid, args.max_init, args.frac, args.compressor) 34 | logger = SummaryWriter('./logs/'+file_name) 35 | 36 | device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else "cpu") 37 | torch.set_num_threads(1) # limit cpu use 38 | print ('-- pytorch version: ', torch.__version__) 39 | 40 | np.random.seed(args.seed) 41 | torch.manual_seed(args.seed) 42 | if device != 'cpu': 43 | torch.cuda.manual_seed(args.seed) 44 | 45 | if not os.path.exists(args.outfolder): 46 | os.mkdir(args.outfolder) 47 | 48 | # load dataset and user groups 49 | train_dataset, test_dataset, num_classes, user_groups = get_dataset(args) 50 | 51 | # Set the model to train and send it to device. 52 | global_model = get_model(args.model, args.dataset, train_dataset[0][0].shape, num_classes) 53 | global_model.to(device) 54 | global_model.train() 55 | 56 | momentum_buffer_list = [] 57 | exp_avgs = [] 58 | exp_avg_sqs = [] 59 | max_exp_avg_sqs = [] 60 | for i, p in enumerate(global_model.parameters()): 61 | momentum_buffer_list.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 62 | exp_avgs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 63 | exp_avg_sqs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 64 | max_exp_avg_sqs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)+args.max_init) # 1e-2 65 | 66 | 67 | ### init error ------- 68 | e = [] 69 | for id in range(args.num_users): 70 | ei = [] 71 | for i, p in enumerate(global_model.parameters()): 72 | ei.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 73 | e.append(ei) 74 | D = sum(p.numel() for p in global_model.parameters()) 75 | print('total dimension:', D) 76 | print('compressor:', args.compressor) 77 | 78 | 79 | # Training 80 | train_loss_sampled, train_loss, train_accuracy = [], [], [] 81 | test_loss, test_accuracy = [], [] 82 | start_time = time.time() 83 | for epoch in tqdm(range(args.epochs)): 84 | ep_time = time.time() 85 | 86 | local_weights, local_params, local_losses = [], [], [] 87 | print(f'\n | Global Training Round : {epoch+1} |\n') 88 | 89 | 90 | par_before = [] 91 | for p in global_model.parameters(): # get trainable parameters 92 | par_before.append(p.data.detach().clone()) 93 | # this is to store parameters before update 94 | w0 = global_model.state_dict() # get all parameters, includeing batch normalization related ones 95 | 96 | 97 | global_model.train() 98 | m = max(int(args.frac * args.num_users), 1) 99 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 100 | 101 | for idx in idxs_users: 102 | 103 | local_model = LocalUpdate(args=args, dataset=train_dataset, 104 | idxs=user_groups[idx], logger=logger) 105 | 106 | w, p, loss = local_model.update_weights_local( 107 | model=copy.deepcopy(global_model), global_round=epoch) 108 | 109 | 110 | ####### add error feedback ####### 111 | delta = utils.sub_params(p, par_before) 112 | tmp = utils.add_params(e[idx], delta) 113 | delta_out = local_model.compressSignal(tmp, D) 114 | e[idx] = utils.sub_params(tmp, delta_out) 115 | 116 | local_weights.append(copy.deepcopy(w)) 117 | local_params.append(copy.deepcopy(utils.add_params(delta_out, par_before))) 118 | local_losses.append(copy.deepcopy(loss)) 119 | 120 | 121 | 122 | bn_weights = average_weights(local_weights) 123 | global_model.load_state_dict(bn_weights) 124 | 125 | global_delta = average_parameter_delta(local_params, par_before) # calculate compression in this function 126 | 127 | update_model_inplace( 128 | global_model, par_before, global_delta, args, epoch, 129 | momentum_buffer_list, exp_avgs, exp_avg_sqs, max_exp_avg_sqs) 130 | 131 | # report and store loss and accuracy 132 | # this is local training loss on sampled users 133 | loss_avg = sum(local_losses) / len(local_losses) 134 | train_loss.append(loss_avg) 135 | 136 | 137 | 138 | global_model.eval() 139 | 140 | 141 | # Test inference after completion of training 142 | test_acc, test_ls = test_inference(args, global_model, test_dataset) 143 | test_accuracy.append(test_acc) 144 | test_loss.append(test_ls) 145 | 146 | # print global training loss after every rounds 147 | print('Epoch Run Time: {0:0.4f} of {1} global rounds'.format(time.time()-ep_time, epoch+1)) 148 | print(f'Training Loss : {train_loss[-1]}') 149 | print(f'Test Loss : {test_loss[-1]}') 150 | print(f'Test Accuracy : {test_accuracy[-1]} \n') 151 | logger.add_scalar('train loss', train_loss[-1], epoch) 152 | logger.add_scalar('test loss', test_loss[-1], epoch) 153 | logger.add_scalar('test acc', test_accuracy[-1], epoch) 154 | 155 | if args.save: 156 | # Saving the objects train_loss and train_accuracy: 157 | 158 | 159 | with open(args.outfolder + file_name, 'wb') as f: 160 | pickle.dump([train_loss, test_loss, test_accuracy], f) 161 | 162 | print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time)) 163 | 164 | 165 | -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | 6 | import numpy as np 7 | from torchvision import datasets, transforms 8 | 9 | 10 | def mnist_iid(dataset, num_users): 11 | """ 12 | Sample I.I.D. client data from MNIST dataset 13 | :param dataset: 14 | :param num_users: 15 | :return: dict of image index 16 | """ 17 | num_items = int(len(dataset)/num_users) 18 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 19 | for i in range(num_users): 20 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 21 | replace=False)) 22 | all_idxs = list(set(all_idxs) - dict_users[i]) 23 | return dict_users 24 | 25 | 26 | def mnist_noniid(dataset, num_users): 27 | """ 28 | Sample non-I.I.D client data from MNIST dataset 29 | :param dataset: 30 | :param num_users: 31 | :return: 32 | """ 33 | # 60,000 training imgs --> 200 imgs/shard X 300 shards 34 | num_shards, num_imgs = 200, 300 35 | idx_shard = [i for i in range(num_shards)] 36 | dict_users = {i: np.array([]) for i in range(num_users)} 37 | idxs = np.arange(num_shards*num_imgs) 38 | labels = dataset.train_labels.numpy() 39 | 40 | # sort labels 41 | idxs_labels = np.vstack((idxs, labels)) 42 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 43 | idxs = idxs_labels[0, :] 44 | 45 | # divide and assign 2 shards/client 46 | for i in range(num_users): 47 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 48 | idx_shard = list(set(idx_shard) - rand_set) 49 | for rand in rand_set: 50 | dict_users[i] = np.concatenate( 51 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 52 | return dict_users 53 | 54 | 55 | def mnist_noniid_unequal(dataset, num_users): 56 | """ 57 | Sample non-I.I.D client data from MNIST dataset s.t clients 58 | have unequal amount of data 59 | :param dataset: 60 | :param num_users: 61 | :returns a dict of clients with each clients assigned certain 62 | number of training imgs 63 | """ 64 | # 60,000 training imgs --> 50 imgs/shard X 1200 shards 65 | num_shards, num_imgs = 1200, 50 66 | idx_shard = [i for i in range(num_shards)] 67 | dict_users = {i: np.array([]) for i in range(num_users)} 68 | idxs = np.arange(num_shards*num_imgs) 69 | labels = dataset.train_labels.numpy() 70 | 71 | # sort labels 72 | idxs_labels = np.vstack((idxs, labels)) 73 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 74 | idxs = idxs_labels[0, :] 75 | 76 | # Minimum and maximum shards assigned per client: 77 | min_shard = 1 78 | max_shard = 30 79 | 80 | # Divide the shards into random chunks for every client 81 | # s.t the sum of these chunks = num_shards 82 | random_shard_size = np.random.randint(min_shard, max_shard+1, 83 | size=num_users) 84 | random_shard_size = np.around(random_shard_size / 85 | sum(random_shard_size) * num_shards) 86 | random_shard_size = random_shard_size.astype(int) 87 | 88 | # Assign the shards randomly to each client 89 | if sum(random_shard_size) > num_shards: 90 | 91 | for i in range(num_users): 92 | # First assign each client 1 shard to ensure every client has 93 | # atleast one shard of data 94 | rand_set = set(np.random.choice(idx_shard, 1, replace=False)) 95 | idx_shard = list(set(idx_shard) - rand_set) 96 | for rand in rand_set: 97 | dict_users[i] = np.concatenate( 98 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), 99 | axis=0) 100 | 101 | random_shard_size = random_shard_size-1 102 | 103 | # Next, randomly assign the remaining shards 104 | for i in range(num_users): 105 | if len(idx_shard) == 0: 106 | continue 107 | shard_size = random_shard_size[i] 108 | if shard_size > len(idx_shard): 109 | shard_size = len(idx_shard) 110 | rand_set = set(np.random.choice(idx_shard, shard_size, 111 | replace=False)) 112 | idx_shard = list(set(idx_shard) - rand_set) 113 | for rand in rand_set: 114 | dict_users[i] = np.concatenate( 115 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), 116 | axis=0) 117 | else: 118 | 119 | for i in range(num_users): 120 | shard_size = random_shard_size[i] 121 | rand_set = set(np.random.choice(idx_shard, shard_size, 122 | replace=False)) 123 | idx_shard = list(set(idx_shard) - rand_set) 124 | for rand in rand_set: 125 | dict_users[i] = np.concatenate( 126 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), 127 | axis=0) 128 | 129 | if len(idx_shard) > 0: 130 | # Add the leftover shards to the client with minimum images: 131 | shard_size = len(idx_shard) 132 | # Add the remaining shard to the client with lowest data 133 | k = min(dict_users, key=lambda x: len(dict_users.get(x))) 134 | rand_set = set(np.random.choice(idx_shard, shard_size, 135 | replace=False)) 136 | idx_shard = list(set(idx_shard) - rand_set) 137 | for rand in rand_set: 138 | dict_users[k] = np.concatenate( 139 | (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]), 140 | axis=0) 141 | 142 | return dict_users 143 | 144 | 145 | def cifar_iid(dataset, num_users): 146 | """ 147 | Sample I.I.D. client data from CIFAR10 dataset 148 | :param dataset: 149 | :param num_users: 150 | :return: dict of image index 151 | """ 152 | num_items = int(len(dataset)/num_users) 153 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 154 | for i in range(num_users): 155 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 156 | replace=False)) 157 | all_idxs = list(set(all_idxs) - dict_users[i]) 158 | return dict_users 159 | 160 | 161 | def cifar_noniid(dataset, num_users): 162 | """ 163 | Sample non-I.I.D client data from CIFAR10 dataset 164 | :param dataset: 165 | :param num_users: 166 | :return: 167 | """ 168 | num_shards, num_imgs = 200, 250 169 | idx_shard = [i for i in range(num_shards)] 170 | dict_users = {i: np.array([]) for i in range(num_users)} 171 | idxs = np.arange(num_shards*num_imgs) 172 | labels = [dataset[i][1] for i in range(len(dataset))] 173 | # labels = np.array(dataset.train_labels) 174 | 175 | # sort labels 176 | idxs_labels = np.vstack((idxs, labels)) 177 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 178 | idxs = idxs_labels[0, :] 179 | 180 | # divide and assign 181 | for i in range(num_users): 182 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 183 | idx_shard = list(set(idx_shard) - rand_set) 184 | for rand in rand_set: 185 | dict_users[i] = np.concatenate( 186 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 187 | return dict_users 188 | 189 | 190 | if __name__ == '__main__': 191 | dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True, 192 | transform=transforms.Compose([ 193 | transforms.ToTensor(), 194 | transforms.Normalize((0.1307,), 195 | (0.3081,)) 196 | ])) 197 | num = 100 198 | d = mnist_noniid(dataset_train, num) 199 | -------------------------------------------------------------------------------- /models/randaug.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def ShearX(img, v): # [-0.3, 0.3] 12 | assert -0.3 <= v <= 0.3 13 | if random.random() > 0.5: 14 | v = -v 15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3, 0.3] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 | if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, threshold=128): 84 | img_np = np.array(img).astype(np.int) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 128 | # assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(w) 133 | y0 = np.random.uniform(h) 134 | 135 | x0 = int(max(0, x0 - v / 2.)) 136 | y0 = int(max(0, y0 - v / 2.)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 oeprations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | # (ShearX, 0., 0.3), # 0 166 | # (ShearY, 0., 0.3), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | (Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0., 0.3), 197 | (ShearY, 0., 0.3), 198 | (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0., 100), 200 | (TranslateYabs, 0., 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | class Lighting(object): 207 | """Lighting noise(AlexNet - style PCA - based noise)""" 208 | 209 | def __init__(self, alphastd, eigval, eigvec): 210 | self.alphastd = alphastd 211 | self.eigval = torch.Tensor(eigval) 212 | self.eigvec = torch.Tensor(eigvec) 213 | 214 | def __call__(self, img): 215 | if self.alphastd == 0: 216 | return img 217 | 218 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | rgb = self.eigvec.type_as(img).clone() \ 220 | .mul(alpha.view(1, 3).expand(3, 3)) \ 221 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 222 | .sum(1).squeeze() 223 | 224 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 225 | 226 | 227 | class CutoutDefault(object): 228 | """ 229 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 230 | """ 231 | def __init__(self, length): 232 | self.length = length 233 | 234 | def __call__(self, img): 235 | h, w = img.size(1), img.size(2) 236 | mask = np.ones((h, w), np.float32) 237 | y = np.random.randint(h) 238 | x = np.random.randint(w) 239 | 240 | y1 = np.clip(y - self.length // 2, 0, h) 241 | y2 = np.clip(y + self.length // 2, 0, h) 242 | x1 = np.clip(x - self.length // 2, 0, w) 243 | x2 = np.clip(x + self.length // 2, 0, w) 244 | 245 | mask[y1: y2, x1: x2] = 0. 246 | mask = torch.from_numpy(mask) 247 | mask = mask.expand_as(img) 248 | img *= mask 249 | return img 250 | 251 | 252 | class RandAugment: 253 | def __init__(self, n, m): 254 | self.n = n 255 | self.m = m # [0, 30] 256 | self.augment_list = augment_list() 257 | 258 | def __call__(self, img): 259 | ops = random.choices(self.augment_list, k=self.n) 260 | for op, minval, maxval in ops: 261 | val = (float(self.m) / 30) * float(maxval - minval) + minval 262 | img = op(img, val) 263 | 264 | return img -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import torch 7 | from torchvision import datasets, transforms 8 | from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal 9 | from sampling import cifar_iid, cifar_noniid 10 | # from models.randaug import RandAugment 11 | 12 | def get_model(model_name, dataset, img_size, nclass): 13 | if model_name == 'vggnet': 14 | from models import vgg 15 | model = vgg.VGG('VGG11', num_classes=nclass) 16 | 17 | elif model_name == 'resnet': 18 | from models import resnet 19 | model = resnet.ResNet18(num_classes=nclass) 20 | 21 | elif model_name == 'wideresnet': 22 | from models import wideresnet 23 | model = wideresnet.WResNet_cifar10(num_classes=nclass, depth=16, multiplier=4) 24 | 25 | elif model_name == 'cnnlarge': 26 | from models import simple 27 | model = simple.CNNLarge() 28 | 29 | elif model_name == 'convmixer': 30 | from models import convmixer 31 | model = convmixer.ConvMixer(n_classes=nclass) 32 | 33 | elif model_name == 'cnn': 34 | from models import simple 35 | 36 | if dataset == 'mnist': 37 | model = simple.CNNMnist(num_classes=nclass, num_channels=1) 38 | elif dataset == 'fmnist': 39 | model = simple.CNNFashion_Mnist(num_classes=nclass) 40 | elif dataset == 'cifar': 41 | model = simple.CNNCifar(num_classes=nclass) 42 | elif model_name == 'ae': 43 | from models import simple 44 | 45 | if dataset == 'mnist' or dataset == 'fmnist': 46 | model = simple.Autoencoder() 47 | 48 | elif model_name == 'mlp': 49 | from models import simple 50 | 51 | len_in = 1 52 | for x in img_size: 53 | len_in *= x 54 | model = simple.MLP(dim_in=len_in, dim_hidden=64, 55 | dim_out=nclass) 56 | else: 57 | exit('Error: unrecognized model') 58 | 59 | return model 60 | 61 | 62 | def get_dataset(args): 63 | """ Returns train and test datasets and a user group which is a dict where 64 | the keys are the user index and the values are the corresponding data for 65 | each of those users. 66 | """ 67 | 68 | if args.dataset == 'cifar10' or 'cifar100': 69 | 70 | transform_train = transforms.Compose([ 71 | transforms.RandomCrop(32, padding=4), 72 | transforms.RandomHorizontalFlip(), 73 | # transforms.RandAugment(num_ops=2, magnitude=14), 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 76 | ]) 77 | 78 | # transform_train.transforms.insert(0, RandAugment(2, 14)) 79 | 80 | transform_test = transforms.Compose([ 81 | transforms.ToTensor(), 82 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 83 | ]) 84 | 85 | if args.dataset == 'cifar10': 86 | data_dir = '../data/cifar/' 87 | 88 | train_dataset = datasets.CIFAR10(data_dir, train=True, download=True, 89 | transform=transform_train) 90 | 91 | test_dataset = datasets.CIFAR10(data_dir, train=False, download=True, 92 | transform=transform_test) 93 | 94 | num_classes = 10 95 | elif args.dataset == 'cifar100': 96 | data_dir = '../data/cifar100/' 97 | 98 | train_dataset = datasets.CIFAR100(data_dir, train=True, download=True, 99 | transform=transform_train) 100 | 101 | test_dataset = datasets.CIFAR100(data_dir, train=False, download=True, 102 | transform=transform_test) 103 | 104 | num_classes = 100 105 | # sample training data amongst users 106 | if args.iid: 107 | # Sample IID user data from Mnist 108 | user_groups = cifar_iid(train_dataset, args.num_users) 109 | else: 110 | # Sample Non-IID user data from Mnist 111 | if args.unequal: 112 | # Chose uneuqal splits for every user 113 | raise NotImplementedError() 114 | else: 115 | # Chose euqal splits for every user 116 | user_groups = cifar_noniid(train_dataset, args.num_users) 117 | 118 | 119 | 120 | elif args.dataset == 'mnist' or 'fmnist': 121 | apply_transform = transforms.Compose([ 122 | transforms.ToTensor(), 123 | # transforms.Normalize((0.1307,), (0.3081,)) 124 | ]) 125 | 126 | if args.dataset == 'mnist': 127 | data_dir = '../data/mnist/' 128 | train_dataset = datasets.MNIST(data_dir, train=True, download=True, 129 | transform=apply_transform) 130 | 131 | test_dataset = datasets.MNIST(data_dir, train=False, download=True, 132 | transform=apply_transform) 133 | else: 134 | data_dir = '../data/fmnist/' 135 | train_dataset = datasets.FashionMNIST(data_dir, train=True, download=True, 136 | transform=apply_transform) 137 | 138 | test_dataset = datasets.FashionMNIST(data_dir, train=False, download=True, 139 | transform=apply_transform) 140 | 141 | 142 | train_dataset = datasets.MNIST(data_dir, train=True, download=True, 143 | transform=apply_transform) 144 | 145 | test_dataset = datasets.MNIST(data_dir, train=False, download=True, 146 | transform=apply_transform) 147 | num_classes = 10 148 | 149 | 150 | # sample training data amongst users 151 | if args.iid: 152 | # Sample IID user data from Mnist 153 | user_groups = mnist_iid(train_dataset, args.num_users) 154 | else: 155 | # Sample Non-IID user data from Mnist 156 | if args.unequal: 157 | # Chose uneuqal splits for every user 158 | user_groups = mnist_noniid_unequal(train_dataset, args.num_users) 159 | else: 160 | # Chose euqal splits for every user 161 | user_groups = mnist_noniid(train_dataset, args.num_users) 162 | 163 | 164 | 165 | return train_dataset, test_dataset, num_classes, user_groups 166 | 167 | 168 | def average_weights(w): 169 | """ 170 | Returns the average of the weights. 171 | """ 172 | w_avg = copy.deepcopy(w[0]) 173 | for key in w_avg.keys(): 174 | for i in range(1, len(w)): 175 | w_avg[key] += w[i][key] 176 | w_avg[key] = torch.div(w_avg[key], len(w)) 177 | return w_avg 178 | 179 | 180 | def average_parameter_delta(ws, w0): 181 | w_avg = copy.deepcopy(ws[0]) 182 | for key in range(len(w_avg)): 183 | w_avg[key] = torch.zeros_like(w_avg[key]) 184 | for i in range(0, len(ws)): 185 | w_avg[key] += ws[i][key] - w0[key] 186 | w_avg[key] = torch.div(w_avg[key], len(ws)) 187 | return w_avg 188 | 189 | 190 | def exp_details(args): 191 | print('\nExperimental details:') 192 | print(f' Model : {args.model}') 193 | print(f' Optimizer : {args.optimizer}') 194 | print(f' Learning : {args.lr}') 195 | print(f' Global Rounds : {args.epochs}\n') 196 | 197 | print(' Federated parameters:') 198 | if args.iid: 199 | print(' IID') 200 | else: 201 | print(' Non-IID') 202 | print(f' Fraction of users : {args.frac}') 203 | print(f' Local Batch size : {args.local_bs}') 204 | print(f' Local Epochs : {args.local_ep}\n') 205 | return 206 | 207 | 208 | def add_params(x, y): 209 | z = [] 210 | for i in range(len(x)): 211 | z.append(x[i] + y[i]) 212 | return z 213 | 214 | 215 | def sub_params(x, y): 216 | z = [] 217 | for i in range(len(x)): 218 | z.append(x[i] - y[i]) 219 | return z 220 | 221 | 222 | def mult_param(alpha, x): 223 | z = [] 224 | for i in range(len(x)): 225 | z.append(alpha*x[i]) 226 | return z 227 | 228 | 229 | def norm_of_param(x): 230 | z = 0 231 | for i in range(len(x)): 232 | z += torch.norm(x[i].flatten(0)) 233 | return z 234 | -------------------------------------------------------------------------------- /models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | import math 4 | 5 | __all__ = ['wide_WResNet'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | def init_model(model): 15 | for m in model.modules(): 16 | if isinstance(m, nn.Conv2d): 17 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 18 | m.weight.data.normal_(0, math.sqrt(2. / n)) 19 | elif isinstance(m, nn.BatchNorm2d): 20 | m.weight.data.fill_(1) 21 | m.bias.data.zero_() 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class WResNet(nn.Module): 96 | 97 | def __init__(self): 98 | super(WResNet, self).__init__() 99 | 100 | def _make_layer(self, block, planes, blocks, stride=1): 101 | downsample = None 102 | if stride != 1 or self.inplanes != planes * block.expansion: 103 | downsample = nn.Sequential( 104 | nn.Conv2d(self.inplanes, planes * block.expansion, 105 | kernel_size=1, stride=stride, bias=False), 106 | nn.BatchNorm2d(planes * block.expansion), 107 | ) 108 | 109 | layers = [] 110 | layers.append(block(self.inplanes, planes, stride, downsample)) 111 | self.inplanes = planes * block.expansion 112 | for i in range(1, blocks): 113 | layers.append(block(self.inplanes, planes)) 114 | 115 | return nn.Sequential(*layers) 116 | 117 | def forward(self, x): 118 | x = self.feats(x) 119 | x = x.view(x.size(0), -1) 120 | x = self.fc(x) 121 | 122 | return x 123 | 124 | 125 | class WResNet_imagenet(WResNet): 126 | 127 | def __init__(self, num_classes=1000, 128 | block=Bottleneck, layers=[3, 4, 23, 3]): 129 | super(WResNet_imagenet, self).__init__() 130 | self.inplanes = 64 131 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 132 | bias=False) 133 | self.bn1 = nn.BatchNorm2d(64) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 136 | self.layer1 = self._make_layer(block, 64, layers[0]) 137 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 138 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 139 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 140 | self.avgpool = nn.AvgPool2d(7) 141 | self.feats = nn.Sequential(self.conv1, 142 | self.bn1, 143 | self.relu, 144 | self.maxpool, 145 | 146 | self.layer1, 147 | self.layer2, 148 | self.layer3, 149 | self.layer4, 150 | 151 | self.avgpool) 152 | self.fc = nn.Linear(512 * block.expansion, num_classes) 153 | 154 | init_model(self) 155 | self.regime = { 156 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 'weight_decay': 1e-4, 'momentum': 0.9}, 157 | 30: {'lr': 1e-2}, 158 | 60: {'lr': 1e-3}, 159 | 90: {'lr': 1e-4} 160 | } 161 | 162 | 163 | class WResNet_cifar10(WResNet): 164 | 165 | def __init__(self, num_classes=10, multiplier=1, 166 | block=BasicBlock, depth=18): 167 | super(WResNet_cifar10, self).__init__() 168 | self.inplanes = 16 * multiplier 169 | n = int((depth - 2) / 6) 170 | self.conv1 = nn.Conv2d(3, 16 * multiplier, kernel_size=3, stride=1, padding=1, 171 | bias=False) 172 | self.bn1 = nn.BatchNorm2d(16 * multiplier) 173 | self.relu = nn.ReLU(inplace=True) 174 | self.maxpool = lambda x: x 175 | self.layer1 = self._make_layer(block, 16 * multiplier, n) 176 | self.layer2 = self._make_layer(block, 32 * multiplier, n, stride=2) 177 | self.layer3 = self._make_layer(block, 64 * multiplier, n, stride=2) 178 | self.layer4 = lambda x: x 179 | self.avgpool = nn.AvgPool2d(8) 180 | self.fc = nn.Linear(64 * multiplier, num_classes) 181 | self.feats = nn.Sequential(self.conv1, 182 | self.bn1, 183 | self.relu, 184 | self.layer1, 185 | self.layer2, 186 | self.layer3, 187 | self.avgpool) 188 | init_model(self) 189 | 190 | self.regime = { 191 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 192 | 'weight_decay': 1e-4, 'momentum': 0.9}, 193 | 60: {'lr': 2e-2}, 194 | 120: {'lr': 4e-3}, 195 | 140: {'lr': 1e-4} 196 | } 197 | 198 | # def wideresnet_cifar(num_classes=num_classes): 199 | # return WResNet_cifar10(num_classes=num_classes, block=BasicBlock, depth=16, multiplier=4) 200 | 201 | def wide_WResNet(**kwargs): 202 | num_classes, depth, dataset = map( 203 | kwargs.get, ['num_classes', 'depth', 'dataset']) 204 | if dataset == 'imagenet': 205 | num_classes = num_classes or 1000 206 | depth = depth or 50 207 | if depth == 18: 208 | return WResNet_imagenet(num_classes=num_classes, 209 | block=BasicBlock, layers=[2, 2, 2, 2]) 210 | if depth == 34: 211 | return WResNet_imagenet(num_classes=num_classes, 212 | block=BasicBlock, layers=[3, 4, 6, 3]) 213 | if depth == 50: 214 | return WResNet_imagenet(num_classes=num_classes, 215 | block=Bottleneck, layers=[3, 4, 6, 3]) 216 | if depth == 101: 217 | return WResNet_imagenet(num_classes=num_classes, 218 | block=Bottleneck, layers=[3, 4, 23, 3]) 219 | if depth == 152: 220 | return WResNet_imagenet(num_classes=num_classes, 221 | block=Bottleneck, layers=[3, 8, 36, 3]) 222 | 223 | elif dataset == 'cifar10': 224 | num_classes = num_classes or 10 225 | depth = depth or 16 226 | return WResNet_cifar10(num_classes=num_classes, 227 | block=BasicBlock, depth=depth, multiplier=4) 228 | elif dataset == 'cifar100': 229 | num_classes = num_classes or 100 230 | depth = depth or 16 231 | return WResNet_cifar10(num_classes=num_classes, 232 | block=BasicBlock, depth=depth, multiplier=4) -------------------------------------------------------------------------------- /compressors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import numpy as np 6 | import random, math 7 | 8 | 9 | class CompressorType: 10 | IDENTICAL = 1 # Identical compressor 11 | LAZY_COMPRESSOR = 2 # Lazy or Bernulli compressor 12 | RANDK_COMPRESSOR = 3 # Rank-K compressor 13 | NATURAL_COMPRESSOR_FP64 = 4 # Natural compressor with FP64 14 | NATURAL_COMPRESSOR_FP32 = 5 # Natural compressor with FP32 15 | STANDARD_DITHERING_FP64 = 6 # Standard dithering with FP64 16 | STANDARD_DITHERING_FP32 = 7 # Standard dithering with FP32 17 | NATURAL_DITHERING_FP32 = 8 # Natural Dithering applied for FP32 components vectors 18 | NATURAL_DITHERING_FP64 = 9 # Natural Dithering applied for FP64 components vectors 19 | TOPK_COMPRESSOR = 10 # Top-K compressor 20 | SIGN_COMPRESSOR = 11 # Sign compressor 21 | ONEBIT_SIGN_COMPRESSOR = 12 # One bit sign compressor 22 | 23 | class Compressor: 24 | def __init__(self, compressorName = ""): 25 | self.compressorName = compressorName 26 | self.compressorType = CompressorType.IDENTICAL 27 | self.w = 0.0 28 | self.last_need_to_send_advance = 0 29 | self.component_bits_size = 32 30 | 31 | def name(self): 32 | omega = r'$\omega$' 33 | if self.compressorType == CompressorType.IDENTICAL: return f"Identical" 34 | if self.compressorType == CompressorType.LAZY_COMPRESSOR: return f"Bernoulli(Lazy) [p={self.P:g},{omega}={self.getW():.1f}]" 35 | if self.compressorType == CompressorType.RANDK_COMPRESSOR: return f" (K={self.K})" 36 | if self.compressorType == CompressorType.NATURAL_COMPRESSOR_FP64: return f"Natural for fp64 [{omega}={self.getW():.1f}]" 37 | if self.compressorType == CompressorType.NATURAL_COMPRESSOR_FP32: return f"Natural for fp32 [{omega}={self.getW():.1f}]" 38 | if self.compressorType == CompressorType.STANDARD_DITHERING_FP64: return f"Standard Dithering for fp64[s={self.s}]" 39 | if self.compressorType == CompressorType.STANDARD_DITHERING_FP64: return f"Standard Dithering for fp32[s={self.s}]" 40 | if self.compressorType == CompressorType.NATURAL_DITHERING_FP32: return f"Natural Dithering for fp32[s={self.s},{omega}={self.getW():.1f}]" 41 | if self.compressorType == CompressorType.NATURAL_DITHERING_FP64: return f"Natural Dithering for fp64[s={self.s},{omega}={self.getW():.1f}]" 42 | if self.compressorType == CompressorType.TOPK_COMPRESSOR: return f" Top (K={self.K})" 43 | if self.compressorType == CompressorType.SIGN_COMPRESSOR: return f"Sign" 44 | if self.compressorType == CompressorType.ONEBIT_SIGN_COMPRESSOR: return f"One Bit Sign" 45 | return "?" 46 | 47 | def fullName(self): 48 | omega = r'$\omega$' 49 | if self.compressorType == CompressorType.IDENTICAL: return f"Identical" 50 | if self.compressorType == CompressorType.LAZY_COMPRESSOR: return f"Bernoulli(Lazy) [p={self.P:g},{omega}={self.getW():.1f}]" 51 | if self.compressorType == CompressorType.RANDK_COMPRESSOR: return f"Rand [K={self.K}]" 52 | if self.compressorType == CompressorType.NATURAL_COMPRESSOR_FP64: return f"Natural for fp64 [{omega}={self.getW():.1f}]" 53 | if self.compressorType == CompressorType.NATURAL_COMPRESSOR_FP32: return f"Natural for fp32 [{omega}={self.getW():.1f}]" 54 | if self.compressorType == CompressorType.STANDARD_DITHERING_FP64: return f"Standard Dithering for fp64[s={self.s}]" 55 | if self.compressorType == CompressorType.STANDARD_DITHERING_FP64: return f"Standard Dithering for fp32[s={self.s}]" 56 | if self.compressorType == CompressorType.NATURAL_DITHERING_FP32: return f"Natural Dithering for fp32[s={self.s},{omega}={self.getW():.1f}]" 57 | if self.compressorType == CompressorType.NATURAL_DITHERING_FP64: return f"Natural Dithering for fp64[s={self.s},{omega}={self.getW():.1f}]" 58 | if self.compressorType == CompressorType.TOPK_COMPRESSOR: return f"Top [K={self.K}]" 59 | if self.compressorType == CompressorType.SIGN_COMPRESSOR: return f"Sign" 60 | if self.compressorType == CompressorType.ONEBIT_SIGN_COMPRESSOR: return f"One Bit Sign" 61 | return "?" 62 | 63 | def resetStats(self): 64 | self.last_need_to_send_advance = 0 65 | 66 | def makeIdenticalCompressor(self): 67 | self.compressorType = CompressorType.IDENTICAL 68 | self.resetStats() 69 | 70 | def makeLazyCompressor(self, P): 71 | self.compressorType = CompressorType.LAZY_COMPRESSOR 72 | self.P = P 73 | self.w = 1.0 / P - 1.0 74 | self.resetStats() 75 | 76 | def makeStandardDitheringFP64(self, levels, vectorNormCompressor, p = float("inf")): 77 | self.compressorType = CompressorType.STANDARD_DITHERING_FP64 78 | self.levelsValues = np.arange(0.0, 1.1, 1.0/levels) # levels + 1 values in range [0.0, 1.0] which uniformly split this segment 79 | self.s = len(self.levelsValues) - 1 # # should be equal to level 80 | assert self.s == levels 81 | 82 | self.p = p 83 | self.vectorNormCompressor = vectorNormCompressor 84 | self.w = 0.0 # TODO 85 | 86 | self.resetStats() 87 | 88 | def makeStandardDitheringFP32(self, levels, vectorNormCompressor, p = float("inf")): 89 | self.compressorType = CompressorType.STANDARD_DITHERING_FP32 90 | self.levelsValues = torch.arange(0.0, 1.1, 1.0/levels) # levels + 1 values in range [0.0, 1.0] which uniformly split this segment 91 | self.s = len(self.levelsValues) - 1 # should be equal to level 92 | assert self.s == levels 93 | 94 | self.p = p 95 | self.vectorNormCompressor = vectorNormCompressor 96 | self.w = 0.0 # TODO 97 | 98 | self.resetStats() 99 | 100 | def makeQSGD_FP64(self, levels, dInput): 101 | norm_compressor = Compressor("norm_compressor") 102 | norm_compressor.makeIdenticalCompressor() 103 | self.makeStandardDitheringFP64(levels, norm_compressor, p = 2) 104 | # Lemma 3.1. from https://arxiv.org/pdf/1610.02132.pdf, page 5 105 | self.w = min(dInput/(levels*levels), dInput**0.5/levels) 106 | 107 | def makeNaturalDitheringFP64(self, levels, dInput, p = float("inf")): 108 | self.compressorType = CompressorType.NATURAL_DITHERING_FP64 109 | self.levelsValues = torch.zeros(levels + 1) 110 | for i in range(levels): 111 | self.levelsValues[i] = (1.0/2.0)**i 112 | self.levelsValues = torch.flip(self.levelsValues, dims = [0]) 113 | self.s = len(self.levelsValues) - 1 114 | assert self.s == levels 115 | 116 | self.p = p 117 | 118 | r = min(p, 2) 119 | self.w = 1.0/8.0 + (dInput** (1.0/r)) / (2**(self.s - 1)) * min(1, (dInput**(1.0/r)) / (2**(self.s-1))) 120 | self.resetStats() 121 | 122 | def makeNaturalDitheringFP32(self, levels, dInput, p = float("inf")): 123 | self.compressorType = CompressorType.NATURAL_DITHERING_FP32 124 | self.levelsValues = torch.zeros(levels + 1) 125 | for i in range(levels): 126 | self.levelsValues[i] = (1.0/2.0)**i 127 | self.levelsValues = torch.flip(self.levelsValues, dims=[0]) 128 | self.s = len(self.levelsValues) - 1 129 | assert self.s == levels 130 | 131 | self.p = p 132 | 133 | r = min(p, 2) 134 | self.w = 1.0/8.0 + (dInput** (1.0/r)) / (2**(self.s - 1)) * min(1, (dInput**(1.0/r)) / (2**(self.s-1))) 135 | self.resetStats() 136 | 137 | # K - how much component we leave from input vector 138 | def makeRandKCompressor(self, K): 139 | self.compressorType = CompressorType.RANDK_COMPRESSOR 140 | self.K = K 141 | self.resetStats() 142 | 143 | def makeTopKCompressor(self, K): 144 | self.compressorType = CompressorType.TOPK_COMPRESSOR 145 | self.K = K 146 | self.resetStats() 147 | 148 | def makeNaturalCompressorFP64(self): 149 | self.compressorType = CompressorType.NATURAL_COMPRESSOR_FP64 150 | self.w = 1.0/8.0 151 | self.resetStats() 152 | 153 | def makeNaturalCompressorFP32(self): 154 | self.compressorType = CompressorType.NATURAL_COMPRESSOR_FP32 155 | self.w = 1.0/8.0 156 | self.resetStats() 157 | 158 | def makeSignCompressor(self, freeze_iteration=0): 159 | self.compressorType = CompressorType.SIGN_COMPRESSOR 160 | self.freeze_iteration = freeze_iteration 161 | self.resetStats() 162 | 163 | def makeOneBitSignCompressor(self, freeze_iteration=0): 164 | self.compressorType = CompressorType.ONEBIT_SIGN_COMPRESSOR 165 | self.freeze_iteration = freeze_iteration 166 | self.resetStats() 167 | 168 | def getW(self): 169 | return self.w 170 | 171 | def compressVector(self, x, iteration=0): 172 | d = max(x.shape) 173 | 174 | if self.compressorType == CompressorType.IDENTICAL: 175 | out = x.clone() 176 | self.last_need_to_send_advance = d * self.component_bits_size 177 | 178 | 179 | elif self.compressorType == CompressorType.TOPK_COMPRESSOR: 180 | #S = torch.arange(d) 181 | # np.random.shuffle(S) 182 | top_size = max(int(self.K*d), 1) 183 | _, S = torch.topk(torch.abs(x), top_size) 184 | out = torch.zeros_like(x) 185 | out[S] = x[S] 186 | # !!! in real case, one needs to send the out vector and a support set to indicate the indices of top K 187 | self.last_need_to_send_advance = 2 * top_size * self.component_bits_size 188 | 189 | elif self.compressorType == CompressorType.SIGN_COMPRESSOR: 190 | if iteration < self.freeze_iteration: 191 | out = x.clone() 192 | self.last_need_to_send_advance = d * self.component_bits_size 193 | else: 194 | 195 | out = torch.sign(x) 196 | 197 | scale = torch.norm(x, p=1) / torch.numel(x) 198 | 199 | out.mul_(scale) # <-- we use this just for similation 200 | 201 | 202 | # !!! in real case, one needs to send D bits for {0, 1} and 32 bits for the scale constant 203 | self.last_need_to_send_advance = d + self.component_bits_size 204 | 205 | elif self.compressorType == CompressorType.ONEBIT_SIGN_COMPRESSOR: 206 | # according to one bit adam paper, 207 | # during warmup, the signal is not compressed 208 | if iteration < self.freeze_iteration: 209 | out = x.clone() 210 | self.last_need_to_send_advance = d * self.component_bits_size 211 | else: 212 | out = torch.sign(x) 213 | # out.add_(1).bool().float().add_(-0.5).mul_(2.0) 214 | scale = torch.norm(x) / np.sqrt(torch.numel(x)) 215 | # out = torch.cat((scale, out), 0) <-- in real case, only send a scale, and a {0,1}^D output 216 | # this is just for similate 217 | out.mul_(scale) # <-- we use this just for similation 218 | # !!! in real case, one needs to send D bits for {0, 1} and 32 bits for the scale constant 219 | self.last_need_to_send_advance = d + self.component_bits_size 220 | 221 | return out 222 | -------------------------------------------------------------------------------- /update.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader, Dataset 10 | import compressors 11 | 12 | 13 | class DatasetSplit(Dataset): 14 | """An abstract Dataset class wrapped around Pytorch Dataset class. 15 | """ 16 | 17 | def __init__(self, dataset, idxs): 18 | self.dataset = dataset 19 | self.idxs = [int(i) for i in idxs] 20 | 21 | def __len__(self): 22 | return len(self.idxs) 23 | 24 | def __getitem__(self, item): 25 | image, label = self.dataset[self.idxs[item]] 26 | return torch.tensor(image), torch.tensor(label) 27 | 28 | 29 | class LocalUpdate(object): 30 | def __init__(self, args, dataset, idxs, logger): 31 | self.args = args 32 | self.logger = logger 33 | self.trainloader, self.validloader, self.testloader = self.train_val_test( 34 | dataset, list(idxs)) 35 | self.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else "cpu") 36 | # Default criterion set to NLL loss function 37 | self.criterion = nn.CrossEntropyLoss().to(self.device) 38 | 39 | ###### define compressors ####### 40 | self.compressor = compressors.Compressor() 41 | if args.compressor == 'identical': 42 | self.compressor.makeIdenticalCompressor() 43 | elif args.compressor == 'topk256': 44 | self.compressor.makeTopKCompressor(1/256) 45 | elif args.compressor == 'topk128': 46 | self.compressor.makeTopKCompressor(1/128) 47 | elif args.compressor == 'topk64': 48 | self.compressor.makeTopKCompressor(1/64) 49 | elif args.compressor == 'sign': 50 | self.compressor.makeSignCompressor() 51 | else: 52 | exit('unknown compressor: {}'.format(args.compressor)) 53 | 54 | 55 | 56 | def train_val_test(self, dataset, idxs): 57 | """ 58 | Returns train, validation and test dataloaders for a given dataset 59 | and user indexes. 60 | """ 61 | # split indexes for train, validation, and test (80, 10, 10) 62 | idxs_train = idxs[:int(0.8*len(idxs))] 63 | idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))] 64 | idxs_test = idxs[int(0.9*len(idxs)):] 65 | 66 | trainloader = DataLoader(DatasetSplit(dataset, idxs_train), 67 | batch_size=self.args.local_bs, shuffle=True) 68 | validloader = DataLoader(DatasetSplit(dataset, idxs_val), 69 | batch_size=int(len(idxs_val)/10), shuffle=False) 70 | testloader = DataLoader(DatasetSplit(dataset, idxs_test), 71 | batch_size=int(len(idxs_test)/10), shuffle=False) 72 | return trainloader, validloader, testloader 73 | 74 | def update_weights_local(self, model, global_round): 75 | # Set mode to train model 76 | model.train() 77 | epoch_loss = [] 78 | 79 | optimizer = torch.optim.SGD(model.parameters(), lr=self.args.local_lr, momentum=0) 80 | 81 | for iter in range(self.args.local_ep): 82 | batch_loss = [] 83 | total = 0 84 | for batch_idx, (images, labels) in enumerate(self.trainloader): 85 | images, labels = images.to(self.device), labels.to(self.device) 86 | 87 | model.zero_grad() 88 | logits = model(images) 89 | loss = self.criterion(logits, labels) 90 | loss.backward() 91 | optimizer.step() 92 | 93 | if self.args.verbose and (batch_idx % 10 == 0): 94 | print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 95 | global_round, iter, batch_idx * len(images), 96 | len(self.trainloader.dataset), 97 | 100. * batch_idx / len(self.trainloader), loss.item())) 98 | 99 | batch_loss.append(loss.item() * len(labels)) 100 | total += len(labels) 101 | epoch_loss.append(sum(batch_loss)/total) 102 | 103 | par_after = [] 104 | for p in model.parameters(): 105 | par_after.append(p.data.detach().clone()) 106 | 107 | 108 | return model.state_dict(), par_after, sum(epoch_loss) / len(epoch_loss) 109 | 110 | 111 | def compressSignal(self, signal, D): 112 | # transit_bits = 0 113 | signal_compressed = [] 114 | for p in signal: 115 | signal_compressed.append(torch.zeros_like(p)) 116 | 117 | signal_flatten = torch.zeros(D).to(self.device) 118 | 119 | signal_offset = 0 120 | for t in range(len(signal)): 121 | offset = len(signal[t].flatten(0)) 122 | signal_flatten[(signal_offset):(signal_offset + offset)] = signal[t].flatten(0) 123 | signal_offset += offset 124 | 125 | 126 | signal_flatten = self.compressor.compressVector(signal_flatten) 127 | # transit_bits += compressors.Compressor.last_need_to_send_advance 128 | 129 | signal_offset = 0 130 | for t in range(len(signal)): 131 | offset = len(signal[t].flatten(0)) 132 | signal_compressed[t].flatten(0)[:] = signal_flatten[(signal_offset):(signal_offset + offset)] 133 | signal_offset += offset 134 | 135 | return signal_compressed 136 | 137 | def compressSignal_layerwise(self, signal, D): 138 | transit_bits = 0 139 | # signal_compressed = [] 140 | for p in signal: 141 | signal_compressed.append(torch.zeros_like(p)) 142 | 143 | signal_flatten = torch.zeros(D).to(self.device) 144 | 145 | signal_offset = 0 146 | for t in range(len(signal)): 147 | offset = len(signal[t].flatten(0)) 148 | signal_flatten[(signal_offset):(signal_offset + offset)] = self.compressor.compressVector(signal[t].flatten(0), self.iteration) 149 | # transit_bits += compressors.Compressor.last_need_to_send_advance 150 | signal_offset += offset 151 | 152 | signal_offset = 0 153 | for t in range(len(signal)): 154 | offset = len(signal[t].flatten(0)) 155 | signal_compressed[t].flatten(0)[:] = signal_flatten[(signal_offset):(signal_offset + offset)] 156 | signal_offset += offset 157 | 158 | return signal_compressed 159 | 160 | 161 | 162 | def update_weights(self, model, global_round): 163 | # Set mode to train model 164 | model.train() 165 | epoch_loss = [] 166 | 167 | optimizer = torch.optim.SGD(model.parameters(), lr=self.args.local_lr, momentum=0) 168 | 169 | for iter in range(self.args.local_ep): 170 | batch_loss = [] 171 | for batch_idx, (images, labels) in enumerate(self.trainloader): 172 | images, labels = images.to(self.device), labels.to(self.device) 173 | 174 | model.zero_grad() 175 | log_probs = model(images) 176 | loss = self.criterion(log_probs, labels) 177 | loss.backward() 178 | optimizer.step() 179 | 180 | if self.args.verbose and (batch_idx % 10 == 0): 181 | print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 182 | global_round, iter, batch_idx * len(images), 183 | len(self.trainloader.dataset), 184 | 100. * batch_idx / len(self.trainloader), loss.item())) 185 | # self.logger.add_scalar('loss', loss.item()) 186 | batch_loss.append(loss.item()/len(labels)) 187 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 188 | 189 | par_after = [] 190 | for p in model.parameters(): 191 | par_after.append(p.data.detach().clone()) 192 | 193 | return par_after, sum(epoch_loss) / len(epoch_loss) 194 | 195 | def inference(self, model): 196 | """ Returns the inference accuracy and loss. 197 | """ 198 | 199 | model.eval() 200 | loss, total, correct = 0.0, 0.0, 0.0 201 | 202 | for batch_idx, (images, labels) in enumerate(self.testloader): 203 | images, labels = images.to(self.device), labels.to(self.device) 204 | 205 | # Inference 206 | outputs = model(images) 207 | batch_loss = self.criterion(outputs, labels) 208 | loss += batch_loss.item() * len(labels) 209 | 210 | # Prediction 211 | _, pred_labels = torch.max(outputs, 1) 212 | pred_labels = pred_labels.view(-1) 213 | correct += torch.sum(torch.eq(pred_labels, labels)).item() 214 | total += len(labels) 215 | 216 | accuracy = correct/total 217 | loss = loss/total 218 | return accuracy, loss 219 | 220 | 221 | def update_model_inplace(model, par_before, delta, args, cur_iter, momentum_buffer_list, exp_avgs, exp_avg_sqs, max_exp_avg_sqs): 222 | grads = copy.deepcopy(delta) 223 | 224 | # learning rate decay 225 | iteration = cur_iter + 1 # add 1 is to make sure nonzero denominator in adam calculation 226 | # if iteration < int(args.epochs/2): 227 | # lr_decay = 1.0 228 | # elif iteration < int(3*args.epochs/4): 229 | # lr_decay = 0.1 230 | # else: 231 | # lr_decay = 0.01 232 | lr_decay=1.0 233 | 234 | for i, param in enumerate(model.parameters()): 235 | grad = grads[i] # recieve the aggregated (averaged) gradient 236 | 237 | # SGD calculation 238 | if args.optimizer == 'fedavg': 239 | # need to reset the trainable parameter 240 | # because we have updated the model via state_dict when dealing with batch normalization 241 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).add_(grad, alpha=args.lr * lr_decay) 242 | # param.data.add_(grad, alpha=args.lr * lr_decay) 243 | # SGD+momentum calculation 244 | elif args.optimizer == 'fedavgm': 245 | buf = momentum_buffer_list[i] 246 | buf.mul_(args.momentum).add_(grad, alpha=1) 247 | grad = buf 248 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).add_(grad, alpha=args.lr * lr_decay) 249 | # adam calculation 250 | elif args.optimizer == 'fedadam': 251 | exp_avg = exp_avgs[i] 252 | exp_avg_sq = exp_avg_sqs[i] 253 | 254 | bias_correction1 = 1 - args.beta1 ** iteration 255 | bias_correction2 = 1 - args.beta2 ** iteration 256 | 257 | exp_avg.mul_(args.beta1).add_(grad, alpha=1 - args.beta1) 258 | exp_avg_sq.mul_(args.beta2).addcmul_(grad, grad.conj(), value=1 - args.beta2) 259 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(args.eps) # without maximum 260 | 261 | step_size = args.lr * lr_decay / bias_correction1 262 | 263 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(exp_avg, denom, value=step_size) 264 | elif args.optimizer == 'fedams': 265 | exp_avg = exp_avgs[i] 266 | exp_avg_sq = exp_avg_sqs[i] 267 | 268 | bias_correction1 = 1 - args.beta1 ** iteration 269 | bias_correction2 = 1 - args.beta2 ** iteration 270 | 271 | exp_avg.mul_(args.beta1).add_(grad, alpha=1 - args.beta1) 272 | exp_avg_sq.mul_(args.beta2).addcmul_(grad, grad.conj(), value=1 - args.beta2) 273 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 274 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(args.eps) 275 | 276 | step_size = args.lr * lr_decay / bias_correction1 277 | 278 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(exp_avg, denom, value=step_size) 279 | elif args.optimizer == 'fedamsd': 280 | lr_decay=1.0/math.sqrt(iteration) 281 | 282 | exp_avg = exp_avgs[i] 283 | exp_avg_sq = exp_avg_sqs[i] 284 | 285 | bias_correction1 = 1 - args.beta1 ** iteration 286 | bias_correction2 = 1 - args.beta2 ** iteration 287 | 288 | exp_avg.mul_(args.beta1).add_(grad, alpha=1 - args.beta1) 289 | exp_avg_sq.mul_(args.beta2).addcmul_(grad, grad.conj(), value=1 - args.beta2) 290 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 291 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(args.eps) 292 | 293 | step_size = args.lr * lr_decay / bias_correction1 294 | 295 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(exp_avg, denom, value=step_size) 296 | elif args.optimizer == 'fedadagrad': 297 | exp_avg_sq = exp_avg_sqs[i] 298 | exp_avg_sq.addcmul_(1, grad, grad) 299 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(grad, exp_avg_sq.sqrt().add_(args.eps), value=args.lr * lr_decay) 300 | elif args.optimizer == 'fedyogi': 301 | exp_avg = exp_avgs[i] 302 | exp_avg_sq = exp_avg_sqs[i] 303 | 304 | bias_correction1 = 1 - args.beta1 ** iteration 305 | bias_correction2 = 1 - args.beta2 ** iteration 306 | 307 | exp_avg.mul_(args.beta1).add_(grad, alpha=1 - args.beta1) 308 | tmp_sq = grad ** 2 309 | tmp_diff = exp_avg_sq - tmp_sq 310 | exp_avg_sq.add_( - (1 - args.beta2), torch.sign(tmp_diff) * tmp_sq) 311 | 312 | denom = exp_avg_sq.sqrt().add_(args.eps) 313 | 314 | step_size = args.lr * lr_decay * math.sqrt(bias_correction2) / bias_correction1 315 | 316 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(exp_avg, denom, value=step_size) 317 | 318 | else: 319 | exit('unknown optimizer: {}'.format(args.optimizer)) 320 | 321 | 322 | def test_inference(args, model, test_dataset): 323 | """ Returns the test accuracy and loss. 324 | """ 325 | 326 | model.eval() 327 | loss, total, correct = 0.0, 0.0, 0.0 328 | 329 | device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else "cpu") 330 | criterion = nn.CrossEntropyLoss().to(device) 331 | testloader = DataLoader(test_dataset, batch_size=128, 332 | shuffle=False) 333 | 334 | for batch_idx, (images, labels) in enumerate(testloader): 335 | images, labels = images.to(device), labels.to(device) 336 | 337 | # Inference 338 | outputs = model(images) 339 | batch_loss = criterion(outputs, labels) 340 | loss += batch_loss.item() * len(labels) 341 | 342 | # Prediction 343 | _, pred_labels = torch.max(outputs, 1) 344 | pred_labels = pred_labels.view(-1) 345 | correct += torch.sum(torch.eq(pred_labels, labels)).item() 346 | total += len(labels) 347 | 348 | accuracy = correct/total 349 | loss = loss/total 350 | return accuracy, loss 351 | 352 | 353 | --------------------------------------------------------------------------------