├── MD_AirComp_FEEL ├── models │ ├── convmixer.py │ ├── vgg.py │ ├── resnet.py │ ├── simple.py │ ├── randaug.py │ └── wideresnet.py ├── options.py ├── resnet_s.py ├── utils.py ├── main_error_free_channel.py ├── compressors.py ├── sampling.py ├── update.py └── main_MIMO_channel.py ├── LICENSE └── README.md /MD_AirComp_FEEL/models/convmixer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | 4 | class Residual(nn.Module): 5 | def __init__(self, fn): 6 | super().__init__() 7 | self.fn = fn 8 | 9 | def forward(self, x): 10 | return self.fn(x) + x 11 | 12 | 13 | def ConvMixer(dim=256, depth=8, kernel_size=5, patch_size=2, n_classes=10): 14 | return nn.Sequential( 15 | nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), 16 | nn.GELU(), 17 | nn.BatchNorm2d(dim), 18 | *[nn.Sequential( 19 | Residual(nn.Sequential( 20 | nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), 21 | nn.GELU(), 22 | nn.BatchNorm2d(dim) 23 | )), 24 | nn.Conv2d(dim, dim, kernel_size=1), 25 | nn.GELU(), 26 | nn.BatchNorm2d(dim) 27 | ) for i in range(depth)], 28 | nn.AdaptiveAvgPool2d((1,1)), 29 | nn.Flatten(), 30 | nn.Linear(dim, n_classes) 31 | ) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LI QIAO 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MD_AirComp_FEEL/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | cfg = { 8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 12 | } 13 | 14 | 15 | class VGG(nn.Module): 16 | def __init__(self, vgg_name, num_classes = 10): 17 | super(VGG, self).__init__() 18 | self.features = self._make_layers(cfg[vgg_name]) 19 | self.classifier = nn.Linear(512, num_classes) 20 | 21 | def forward(self, x): 22 | out = self.features(x) 23 | out = out.view(out.size(0), -1) 24 | out = self.classifier(out) 25 | return out 26 | 27 | def _make_layers(self, cfg): 28 | layers = [] 29 | in_channels = 3 30 | for x in cfg: 31 | if x == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(x), 36 | nn.ReLU(inplace=True)] 37 | in_channels = x 38 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 39 | return nn.Sequential(*layers) 40 | 41 | # net = VGG('VGG11') 42 | # x = torch.randn(2,3,32,32) 43 | # print(net(Variable(x)).size()) 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MD-AirComp 2 | This repository contains code for [Massive Digital Over-the-Air Computation for Communication-Efficient Federated Edge Learning](https://arxiv.org/abs/2405.15969). The project focuses on uplink multi-user MIMO over-the-air computation systems and its application in federated edge learning. 3 | 4 | ## Installation 5 | 6 | ### Prerequisites 7 | 8 | - Python >= 3.8 9 | - Anaconda or Miniconda (for creating virtual environments) 10 | 11 | ### Setup 12 | 13 | 1. Create a Conda environment: 14 | 15 | ```bash 16 | conda create --name MDAirComp python=3.8 17 | conda activate MDAirComp 18 | 19 | 2. Install required packages: 20 | 21 | ```bash 22 | pip install numpy tqdm pandas tensorboardX scikit-learn faiss-cpu scipy torchvision 23 | 24 | 3. To run the MIMO channel simulation: 25 | 26 | ```bash 27 | python main_MIMO_channel.py 28 | 29 | 4. To run the error-free channel simulation: 30 | 31 | ```bash 32 | python main_error_free_channel.py 33 | 34 | ### Visit Our Lab 35 | 36 | Check out our lab's research at [GaoZhen Lab](https://gaozhen16.github.io/) or [IPC Lab, Imperial College](https://www.imperial.ac.uk/information-processing-and-communications-lab/publications/). Discover our latest projects and meet the team! 37 | 38 | ### Citations 39 | 40 | If you find this project useful, please cite the related original paper as: 41 | 42 | ``` 43 | @article{qiao2024massive, 44 | title={Massive digital over-the-air computation for communication-efficient federated edge learning}, 45 | author={Qiao, Li and Gao, Zhen and Mashhadi, Mahdi Boloursaz and G{\"u}und{\"u}z, Deniz}, 46 | journal={IEEE Journal on Selected Areas in Communications}, 47 | year={2024}, 48 | publisher={IEEE} 49 | } 50 | 51 | @inproceedings{qiao2023unsourced, 52 | title={Unsourced massive access-based digital over-the-air computation for efficient federated edge learning}, 53 | author={Qiao, Li and Gao, Zhen and Li, Zhongxiang and G{\"u}nd{\"u}z, Deniz}, 54 | booktitle={2023 IEEE International Symposium on Information Theory (ISIT)}, 55 | pages={2003--2008}, 56 | year={2023}, 57 | organization={IEEE} 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /MD_AirComp_FEEL/options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import argparse 6 | import numpy as np 7 | 8 | def args_parser(): 9 | parser = argparse.ArgumentParser() 10 | 11 | # federated arguments (Notation for the arguments followed from paper) 12 | parser.add_argument('--epochs', type=int, default=500, 13 | help="number of rounds of training") 14 | # parser.add_argument('--num_users', type=int, default=100, 15 | # help="number of users: K") 16 | # parser.add_argument('--frac', type=float, default=0.5, 17 | # help='the fraction of clients: C') 18 | parser.add_argument('--num_users', type=int, default=40, 19 | help="number of users: K") 20 | parser.add_argument('--frac', type=float, default=0.3, 21 | help='the fraction of clients: C') 22 | parser.add_argument('--local_ep', type=int, default=3, 23 | help="the number of local epochs: E") 24 | parser.add_argument('--local_bs', type=int, default=20, 25 | help="local batch size: B") 26 | parser.add_argument('--Vb', type=int, default=20, 27 | help="Vb-dimensional quantization") 28 | parser.add_argument('--M', type=int, default=2**6, 29 | help="quantization level") 30 | parser.add_argument('--L', type=int, default=64, # int(args.num_users * args.frac) int(6 * 2**args.Qbit/8) 31 | help="Length of each codeword in the codebook") 32 | args = parser.parse_args() 33 | Rl = np.random.rand(args.L, args.M).astype(np.float32) 34 | Im = np.random.rand(args.L, args.M).astype(np.float32) 35 | Rl[Rl >= 0.5] = 1 36 | Rl[Rl < 0.5] = -1 37 | Im[Im >= 0.5] = 1 38 | Im[Im < 0.5] = -1 39 | UMmat = (1 / np.sqrt(2)) * (Rl + 1j * Im) 40 | parser.add_argument('--UM', type=np.float32, default=UMmat, help="Length of each codeword in the codebook") # help="Length of each codeword in the codebook") 41 | # parser.add_argument('--UM', type=np.float32, default=(1 / np.sqrt(2)) * (np.random.randn(args.L, args.M).astype(np.float32) + 1j * np.random.randn(args.L, args.M).astype(np.float32)), 42 | # help="Length of each codeword in the codebook") 43 | parser.add_argument('--SNR', type=int, default=20, help="SNR") 44 | parser.add_argument('--local_lr', type=float, default=0.01, 45 | help='learning rate for local update') 46 | parser.add_argument('--lr', type=float, default=1.0, 47 | help='learning rate for global update') 48 | parser.add_argument('--momentum', type=float, default=0.0, 49 | help='SGD momentum (default: 0.0)') 50 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam') 51 | parser.add_argument('--beta2', type=float, default=0.99, help='beta2 for adam') 52 | parser.add_argument('--eps', type=float, default=0, help='eps for adam') 53 | parser.add_argument('--max_init', type=float, default=1e-3, help='initialize max_v for adam') 54 | 55 | 56 | 57 | # model arguments 58 | parser.add_argument('--model', type=str, default='convmixer', help='model name') 59 | 60 | # other arguments 61 | parser.add_argument('--dataset', type=str, default='cifar10', help="name \ 62 | of dataset") 63 | parser.add_argument('--num_classes', type=int, default=10, help="number \ 64 | of classes") 65 | 66 | parser.add_argument('--gpu', default=0, help="To use cuda, set \ 67 | to a specific GPU ID. Default set to use CPU.") 68 | parser.add_argument('--optimizer', type=str, default='fedavg', help="type \ 69 | of optimizer") 70 | parser.add_argument('--iid', type=int, default=1, 71 | help='Default set to IID. Set to 0 for non-IID.') 72 | parser.add_argument('--unequal', type=int, default=0, 73 | help='whether to use unequal data splits for \ 74 | non-i.i.d setting (use 0 for equal splits)') 75 | parser.add_argument('--stopping_rounds', type=int, default=10, 76 | help='rounds of early stopping') 77 | parser.add_argument('--verbose', type=int, default=0, help='verbose') 78 | parser.add_argument('--seed', type=int, default=1, help='random seed') 79 | parser.add_argument('--save', type=int, default=1, help='whether to save results') 80 | parser.add_argument('--outfolder', type=str, default='./results') 81 | 82 | parser.add_argument('--compressor', type=str, default='Kmeans', help='compressor strategy') 83 | args = parser.parse_args() 84 | return args 85 | -------------------------------------------------------------------------------- /MD_AirComp_FEEL/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.autograd import Variable 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion*planes) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = self.bn2(self.conv2(out)) 36 | out += self.shortcut(x) 37 | out = F.relu(out) 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, in_planes, planes, stride=1): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 52 | 53 | self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != self.expansion*planes: 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 57 | nn.BatchNorm2d(self.expansion*planes) 58 | ) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(self.conv1(x))) 62 | out = F.relu(self.bn2(self.conv2(out))) 63 | out = self.bn3(self.conv3(out)) 64 | out += self.shortcut(x) 65 | out = F.relu(out) 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, num_blocks, num_classes=10): 71 | super(ResNet, self).__init__() 72 | self.in_planes = 64 73 | 74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 77 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 78 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 79 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 80 | self.linear = nn.Linear(512*block.expansion, num_classes) 81 | 82 | def _make_layer(self, block, planes, num_blocks, stride): 83 | strides = [stride] + [1]*(num_blocks-1) 84 | layers = [] 85 | for stride in strides: 86 | layers.append(block(self.in_planes, planes, stride)) 87 | self.in_planes = planes * block.expansion 88 | return nn.Sequential(*layers) 89 | 90 | def forward(self, x): 91 | out = F.relu(self.bn1(self.conv1(x))) 92 | out = self.layer1(out) 93 | out = self.layer2(out) 94 | out = self.layer3(out) 95 | out = self.layer4(out) 96 | out = F.avg_pool2d(out, 4) 97 | out = out.view(out.size(0), -1) 98 | out = self.linear(out) 99 | return out 100 | 101 | 102 | def ResNet18(num_classes = 10): 103 | return ResNet(BasicBlock, [2,2,2,2], num_classes = num_classes) 104 | 105 | def ResNet34(num_classes = 10): 106 | return ResNet(BasicBlock, [3,4,6,3], num_classes = num_classes) 107 | 108 | def ResNet50(num_classes = 10): 109 | return ResNet(Bottleneck, [3,4,6,3], num_classes = num_classes) 110 | 111 | def ResNet101(num_classes = 10): 112 | return ResNet(Bottleneck, [3,4,23,3], num_classes = num_classes) 113 | 114 | def ResNet152(num_classes = 10): 115 | return ResNet(Bottleneck, [3,8,36,3], num_classes = num_classes) 116 | 117 | 118 | def test(): 119 | net = ResNet18() 120 | y = net(Variable(torch.randn(1,3,32,32))) 121 | print(y.size()) 122 | 123 | # test() 124 | -------------------------------------------------------------------------------- /MD_AirComp_FEEL/models/simple.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, dim_in, dim_hidden, dim_out): 11 | super(MLP, self).__init__() 12 | self.layer_input = nn.Linear(dim_in, dim_hidden) 13 | self.relu = nn.ReLU() 14 | self.dropout = nn.Dropout() 15 | self.layer_hidden = nn.Linear(dim_hidden, dim_out) 16 | self.softmax = nn.Softmax(dim=1) 17 | 18 | def forward(self, x): 19 | x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1]) 20 | x = self.layer_input(x) 21 | x = self.dropout(x) 22 | x = self.relu(x) 23 | x = self.layer_hidden(x) 24 | return x 25 | 26 | 27 | class CNNMnist(nn.Module): 28 | def __init__(self, num_classes, num_channels): 29 | super(CNNMnist, self).__init__() 30 | self.conv1 = nn.Conv2d(num_channels, 10, kernel_size=5) 31 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 32 | self.conv2_drop = nn.Dropout2d() 33 | self.fc1 = nn.Linear(320, 50) 34 | self.fc2 = nn.Linear(50, num_classes) 35 | 36 | def forward(self, x): 37 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 38 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 39 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 40 | x = F.relu(self.fc1(x)) 41 | x = F.dropout(x, training=self.training) 42 | x = self.fc2(x) 43 | return x 44 | 45 | 46 | class CNNFashion_Mnist(nn.Module): 47 | def __init__(self, num_classes): 48 | super(CNNFashion_Mnist, self).__init__() 49 | self.layer1 = nn.Sequential( 50 | nn.Conv2d(1, 16, kernel_size=5, padding=2), 51 | nn.BatchNorm2d(16), 52 | nn.ReLU(), 53 | nn.MaxPool2d(2)) 54 | self.layer2 = nn.Sequential( 55 | nn.Conv2d(16, 32, kernel_size=5, padding=2), 56 | nn.BatchNorm2d(32), 57 | nn.ReLU(), 58 | nn.MaxPool2d(2)) 59 | self.fc = nn.Linear(7*7*32, num_classes) 60 | 61 | def forward(self, x): 62 | out = self.layer1(x) 63 | out = self.layer2(out) 64 | out = out.view(out.size(0), -1) 65 | out = self.fc(out) 66 | return out 67 | 68 | 69 | class CNNCifar(nn.Module): 70 | def __init__(self, num_classes): 71 | super(CNNCifar, self).__init__() 72 | self.conv1 = nn.Conv2d(3, 6, 5) 73 | self.pool = nn.MaxPool2d(2, 2) 74 | self.conv2 = nn.Conv2d(6, 16, 5) 75 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 76 | self.fc2 = nn.Linear(120, 84) 77 | self.fc3 = nn.Linear(84, num_classes) 78 | 79 | def forward(self, x): 80 | x = self.pool(F.relu(self.conv1(x))) 81 | x = self.pool(F.relu(self.conv2(x))) 82 | x = x.view(-1, 16 * 5 * 5) 83 | x = F.relu(self.fc1(x)) 84 | x = F.relu(self.fc2(x)) 85 | out = self.fc3(x) 86 | return out 87 | 88 | 89 | class CNNLarge(nn.Module): 90 | def __init__(self): 91 | super().__init__() 92 | self.network = nn.Sequential( 93 | nn.Conv2d(3, 32, kernel_size=3, padding=1), 94 | nn.ReLU(), 95 | nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(), 97 | nn.MaxPool2d(2, 2), # output: 64 x 16 x 16 98 | nn.BatchNorm2d(64), 99 | 100 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 101 | nn.ReLU(), 102 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), 103 | nn.ReLU(), 104 | nn.MaxPool2d(2, 2), # output: 128 x 8 x 8 105 | nn.BatchNorm2d(128), 106 | 107 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), 108 | nn.ReLU(), 109 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 110 | nn.ReLU(), 111 | nn.MaxPool2d(2, 2), # output: 256 x 4 x 4 112 | nn.BatchNorm2d(256), 113 | 114 | nn.Flatten(), 115 | nn.Linear(256*4*4, 1024), 116 | nn.ReLU(), 117 | nn.Linear(1024, 512), 118 | nn.ReLU(), 119 | nn.Linear(512, 10)) 120 | 121 | def forward(self, xb): 122 | return self.network(xb) 123 | 124 | 125 | 126 | class Autoencoder(nn.Module): 127 | def __init__(self): 128 | super(Autoencoder,self).__init__() 129 | self.encoder = nn.Sequential( 130 | # 28 x 28 131 | nn.Conv2d(1, 4, kernel_size=5), 132 | # 4 x 24 x 24 133 | nn.ReLU(True), 134 | nn.Conv2d(4, 8, kernel_size=5), 135 | nn.ReLU(True), 136 | # 8 x 20 x 20 = 3200 137 | nn.Flatten(), 138 | nn.Linear(3200, 10), 139 | # 10 140 | nn.Softmax(), 141 | ) 142 | self.decoder = nn.Sequential( 143 | # 10 144 | nn.Linear(10, 400), 145 | # 400 146 | nn.ReLU(True), 147 | nn.Linear(400, 4000), 148 | # 4000 149 | nn.ReLU(True), 150 | nn.Unflatten(1, (10, 20, 20)), 151 | # 10 x 20 x 20 152 | nn.ConvTranspose2d(10, 10, kernel_size=5), 153 | # 24 x 24 154 | nn.ConvTranspose2d(10, 1, kernel_size=5), 155 | # 28 x 28 156 | nn.Sigmoid(), 157 | ) 158 | def forward(self, x): 159 | enc = self.encoder(x) 160 | dec = self.decoder(enc) 161 | return dec -------------------------------------------------------------------------------- /MD_AirComp_FEEL/resnet_s.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | 4 | The implementation and structure of this file is hugely influenced by [2] 5 | which is implemented for ImageNet and doesn't have option A for identity. 6 | Moreover, most of the implementations on the web is copy-paste from 7 | torchvision's resnet and has wrong number of params. 8 | 9 | Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following 10 | number of layers and parameters: 11 | 12 | name | layers | params 13 | ResNet20 | 20 | 0.27M 14 | ResNet32 | 32 | 0.46M 15 | ResNet44 | 44 | 0.66M 16 | ResNet56 | 56 | 0.85M 17 | ResNet110 | 110 | 1.7M 18 | ResNet1202| 1202 | 19.4m 19 | 20 | which this implementation indeed has. 21 | 22 | Reference: 23 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 25 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 26 | 27 | If you use this implementation in you work, please don't forget to mention the 28 | author, Yerlan Idelbayev. 29 | ''' 30 | import torch 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | import torch.nn.init as init 34 | 35 | from torch.autograd import Variable 36 | 37 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 38 | 39 | def _weights_init(m): 40 | classname = m.__class__.__name__ 41 | #print(classname) 42 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 43 | init.kaiming_normal_(m.weight) 44 | 45 | class LambdaLayer(nn.Module): 46 | def __init__(self, lambd): 47 | super(LambdaLayer, self).__init__() 48 | self.lambd = lambd 49 | 50 | def forward(self, x): 51 | return self.lambd(x) 52 | 53 | 54 | class BasicBlock(nn.Module): 55 | expansion = 1 56 | 57 | def __init__(self, in_planes, planes, stride=1, option='A'): 58 | super(BasicBlock, self).__init__() 59 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | 64 | self.shortcut = nn.Sequential() 65 | if stride != 1 or in_planes != planes: 66 | if option == 'A': 67 | """ 68 | For CIFAR10 ResNet paper uses option A. 69 | """ 70 | self.shortcut = LambdaLayer(lambda x: 71 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 72 | elif option == 'B': 73 | self.shortcut = nn.Sequential( 74 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 75 | nn.BatchNorm2d(self.expansion * planes) 76 | ) 77 | 78 | def forward(self, x): 79 | out = F.relu(self.bn1(self.conv1(x))) 80 | out = self.bn2(self.conv2(out)) 81 | out += self.shortcut(x) 82 | out = F.relu(out) 83 | return out 84 | 85 | 86 | class ResNet(nn.Module): 87 | def __init__(self, block, num_blocks, num_classes=10): 88 | super(ResNet, self).__init__() 89 | self.in_planes = 16 90 | 91 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 92 | self.bn1 = nn.BatchNorm2d(16) 93 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 94 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 95 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 96 | self.linear = nn.Linear(64, num_classes) 97 | 98 | self.apply(_weights_init) 99 | 100 | def _make_layer(self, block, planes, num_blocks, stride): 101 | strides = [stride] + [1]*(num_blocks-1) 102 | layers = [] 103 | for stride in strides: 104 | layers.append(block(self.in_planes, planes, stride)) 105 | self.in_planes = planes * block.expansion 106 | 107 | return nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | out = F.relu(self.bn1(self.conv1(x))) 111 | out = self.layer1(out) 112 | out = self.layer2(out) 113 | out = self.layer3(out) 114 | out = F.avg_pool2d(out, out.size()[3]) 115 | out = out.view(out.size(0), -1) 116 | out = self.linear(out) 117 | return out 118 | 119 | 120 | def resnet20(): 121 | return ResNet(BasicBlock, [3, 3, 3]) 122 | 123 | 124 | def resnet32(): 125 | return ResNet(BasicBlock, [5, 5, 5]) 126 | 127 | 128 | def resnet44(): 129 | return ResNet(BasicBlock, [7, 7, 7]) 130 | 131 | 132 | def resnet56(): 133 | return ResNet(BasicBlock, [9, 9, 9]) 134 | 135 | 136 | def resnet110(): 137 | return ResNet(BasicBlock, [18, 18, 18]) 138 | 139 | 140 | def resnet1202(): 141 | return ResNet(BasicBlock, [200, 200, 200]) 142 | 143 | 144 | def test(net): 145 | import numpy as np 146 | total_params = 0 147 | 148 | for x in filter(lambda p: p.requires_grad, net.parameters()): 149 | total_params += np.prod(x.data.numpy().shape) 150 | print("Total number of params", total_params) 151 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 152 | 153 | 154 | if __name__ == "__main__": 155 | for net_name in __all__: 156 | if net_name.startswith('resnet'): 157 | print(net_name) 158 | test(globals()[net_name]()) 159 | print() -------------------------------------------------------------------------------- /MD_AirComp_FEEL/models/randaug.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def ShearX(img, v): # [-0.3, 0.3] 12 | assert -0.3 <= v <= 0.3 13 | if random.random() > 0.5: 14 | v = -v 15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3, 0.3] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 | if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, threshold=128): 84 | img_np = np.array(img).astype(np.int) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 128 | # assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(w) 133 | y0 = np.random.uniform(h) 134 | 135 | x0 = int(max(0, x0 - v / 2.)) 136 | y0 = int(max(0, y0 - v / 2.)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 oeprations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | # (ShearX, 0., 0.3), # 0 166 | # (ShearY, 0., 0.3), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | (Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0., 0.3), 197 | (ShearY, 0., 0.3), 198 | (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0., 100), 200 | (TranslateYabs, 0., 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | class Lighting(object): 207 | """Lighting noise(AlexNet - style PCA - based noise)""" 208 | 209 | def __init__(self, alphastd, eigval, eigvec): 210 | self.alphastd = alphastd 211 | self.eigval = torch.Tensor(eigval) 212 | self.eigvec = torch.Tensor(eigvec) 213 | 214 | def __call__(self, img): 215 | if self.alphastd == 0: 216 | return img 217 | 218 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | rgb = self.eigvec.type_as(img).clone() \ 220 | .mul(alpha.view(1, 3).expand(3, 3)) \ 221 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 222 | .sum(1).squeeze() 223 | 224 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 225 | 226 | 227 | class CutoutDefault(object): 228 | """ 229 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 230 | """ 231 | def __init__(self, length): 232 | self.length = length 233 | 234 | def __call__(self, img): 235 | h, w = img.size(1), img.size(2) 236 | mask = np.ones((h, w), np.float32) 237 | y = np.random.randint(h) 238 | x = np.random.randint(w) 239 | 240 | y1 = np.clip(y - self.length // 2, 0, h) 241 | y2 = np.clip(y + self.length // 2, 0, h) 242 | x1 = np.clip(x - self.length // 2, 0, w) 243 | x2 = np.clip(x + self.length // 2, 0, w) 244 | 245 | mask[y1: y2, x1: x2] = 0. 246 | mask = torch.from_numpy(mask) 247 | mask = mask.expand_as(img) 248 | img *= mask 249 | return img 250 | 251 | 252 | class RandAugment: 253 | def __init__(self, n, m): 254 | self.n = n 255 | self.m = m # [0, 30] 256 | self.augment_list = augment_list() 257 | 258 | def __call__(self, img): 259 | ops = random.choices(self.augment_list, k=self.n) 260 | for op, minval, maxval in ops: 261 | val = (float(self.m) / 30) * float(maxval - minval) + minval 262 | img = op(img, val) 263 | 264 | return img -------------------------------------------------------------------------------- /MD_AirComp_FEEL/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import torch 7 | from torchvision import datasets, transforms 8 | from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal 9 | from sampling import cifar_iid, cifar_noniid 10 | # from models.randaug import RandAugment 11 | 12 | def get_model(model_name, dataset, img_size, nclass): 13 | if model_name == 'vggnet': 14 | from models import vgg 15 | model = vgg.VGG('VGG11', num_classes=nclass) 16 | 17 | elif model_name == 'resnet': 18 | from models import resnet 19 | model = resnet.ResNet18(num_classes=nclass) 20 | elif model_name == 'resnet-s': 21 | from resnet_s import resnet20 22 | model = resnet20() 23 | 24 | elif model_name == 'wideresnet': 25 | from models import wideresnet 26 | model = wideresnet.WResNet_cifar10(num_classes=nclass, depth=16, multiplier=4) 27 | 28 | elif model_name == 'cnnlarge': 29 | from models import simple 30 | model = simple.CNNLarge() 31 | 32 | elif model_name == 'convmixer': 33 | from models import convmixer 34 | model = convmixer.ConvMixer(n_classes=nclass) 35 | 36 | elif model_name == 'cnn': 37 | from models import simple 38 | 39 | if dataset == 'mnist': 40 | model = simple.CNNMnist(num_classes=nclass, num_channels=1) 41 | elif dataset == 'fmnist': 42 | model = simple.CNNFashion_Mnist(num_classes=nclass) 43 | elif dataset == 'cifar': 44 | model = simple.CNNCifar(num_classes=nclass) 45 | elif model_name == 'ae': 46 | from models import simple 47 | 48 | if dataset == 'mnist' or dataset == 'fmnist': 49 | model = simple.Autoencoder() 50 | 51 | elif model_name == 'mlp': 52 | from models import simple 53 | 54 | len_in = 1 55 | for x in img_size: 56 | len_in *= x 57 | model = simple.MLP(dim_in=len_in, dim_hidden=64, 58 | dim_out=nclass) 59 | else: 60 | exit('Error: unrecognized model') 61 | 62 | return model 63 | 64 | 65 | def get_dataset(args): 66 | """ Returns train and test datasets and a user group which is a dict where 67 | the keys are the user index and the values are the corresponding data for 68 | each of those users. 69 | """ 70 | 71 | if args.dataset == 'cifar10' or 'cifar100': 72 | 73 | transform_train = transforms.Compose([ 74 | transforms.RandomCrop(32, padding=4), 75 | transforms.RandomHorizontalFlip(), 76 | # transforms.RandAugment(num_ops=2, magnitude=14), 77 | transforms.ToTensor(), 78 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 79 | ]) 80 | 81 | # transform_train.transforms.insert(0, RandAugment(2, 14)) 82 | 83 | transform_test = transforms.Compose([ 84 | transforms.ToTensor(), 85 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 86 | ]) 87 | 88 | if args.dataset == 'cifar10': 89 | data_dir = '../data/cifar/' 90 | 91 | train_dataset = datasets.CIFAR10(data_dir, train=True, download=True, 92 | transform=transform_train) 93 | 94 | test_dataset = datasets.CIFAR10(data_dir, train=False, download=True, 95 | transform=transform_test) 96 | 97 | num_classes = 10 98 | elif args.dataset == 'cifar100': 99 | data_dir = '../data/cifar100/' 100 | 101 | train_dataset = datasets.CIFAR100(data_dir, train=True, download=True, 102 | transform=transform_train) 103 | 104 | test_dataset = datasets.CIFAR100(data_dir, train=False, download=True, 105 | transform=transform_test) 106 | 107 | num_classes = 100 108 | # sample training data amongst users 109 | if args.iid: 110 | # Sample IID user data from Mnist 111 | user_groups = cifar_iid(train_dataset, args.num_users) 112 | else: 113 | # Sample Non-IID user data from Mnist 114 | if args.unequal: 115 | # Chose uneuqal splits for every user 116 | raise NotImplementedError() 117 | else: 118 | # Chose euqal splits for every user 119 | user_groups = cifar_noniid(train_dataset, args) 120 | 121 | 122 | 123 | elif args.dataset == 'mnist' or 'fmnist': 124 | apply_transform = transforms.Compose([ 125 | transforms.ToTensor(), 126 | # transforms.Normalize((0.1307,), (0.3081,)) 127 | ]) 128 | 129 | if args.dataset == 'mnist': 130 | data_dir = '../data/mnist/' 131 | train_dataset = datasets.MNIST(data_dir, train=True, download=True, 132 | transform=apply_transform) 133 | 134 | test_dataset = datasets.MNIST(data_dir, train=False, download=True, 135 | transform=apply_transform) 136 | else: 137 | data_dir = '../data/fmnist/' 138 | train_dataset = datasets.FashionMNIST(data_dir, train=True, download=True, 139 | transform=apply_transform) 140 | 141 | test_dataset = datasets.FashionMNIST(data_dir, train=False, download=True, 142 | transform=apply_transform) 143 | 144 | 145 | train_dataset = datasets.MNIST(data_dir, train=True, download=True, 146 | transform=apply_transform) 147 | 148 | test_dataset = datasets.MNIST(data_dir, train=False, download=True, 149 | transform=apply_transform) 150 | num_classes = 10 151 | 152 | 153 | # sample training data amongst users 154 | if args.iid: 155 | # Sample IID user data from Mnist 156 | user_groups = mnist_iid(train_dataset, args.num_users) 157 | else: 158 | # Sample Non-IID user data from Mnist 159 | if args.unequal: 160 | # Chose uneuqal splits for every user 161 | user_groups = mnist_noniid_unequal(train_dataset, args.num_users) 162 | else: 163 | # Chose euqal splits for every user 164 | user_groups = mnist_noniid(train_dataset, args.num_users) 165 | 166 | 167 | 168 | return train_dataset, test_dataset, num_classes, user_groups 169 | 170 | 171 | def average_weights(w): 172 | """ 173 | Returns the average of the weights. 174 | """ 175 | w_avg = copy.deepcopy(w[0]) 176 | for key in w_avg.keys(): 177 | for i in range(1, len(w)): 178 | w_avg[key] += w[i][key] 179 | w_avg[key] = torch.div(w_avg[key], len(w)) 180 | return w_avg 181 | 182 | 183 | def average_parameter_delta(ws, w0): 184 | w_avg = copy.deepcopy(ws[0]) 185 | for key in range(len(w_avg)): 186 | w_avg[key] = torch.zeros_like(w_avg[key]) 187 | for i in range(0, len(ws)): 188 | # w_avg[key] += ws[i][key] - w0[key] 189 | w_avg[key] += ws[i][key] 190 | w_avg[key] = torch.div(w_avg[key], len(ws)) 191 | return w_avg 192 | 193 | 194 | def exp_details(args): 195 | print('\nExperimental details:') 196 | print(f' Model : {args.model}') 197 | print(f' Optimizer : {args.optimizer}') 198 | print(f' Learning : {args.lr}') 199 | print(f' Global Rounds : {args.epochs}\n') 200 | 201 | print(' Federated parameters:') 202 | if args.iid: 203 | print(' IID') 204 | else: 205 | print(' Non-IID') 206 | print(f' Fraction of users : {args.frac}') 207 | print(f' Local Batch size : {args.local_bs}') 208 | print(f' Local Epochs : {args.local_ep}\n') 209 | return 210 | 211 | 212 | def add_params(x, y): 213 | z = [] 214 | for i in range(len(x)): 215 | z.append(x[i] + y[i]) 216 | return z 217 | 218 | 219 | def sub_params(x, y): 220 | z = [] 221 | for i in range(len(x)): 222 | z.append(x[i] - y[i]) 223 | return z 224 | 225 | 226 | def mult_param(alpha, x): 227 | z = [] 228 | for i in range(len(x)): 229 | z.append(alpha*x[i]) 230 | return z 231 | 232 | 233 | def norm_of_param(x): 234 | z = 0 235 | for i in range(len(x)): 236 | z += torch.norm(x[i].flatten(0)) 237 | return z 238 | -------------------------------------------------------------------------------- /MD_AirComp_FEEL/models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | import math 4 | 5 | __all__ = ['wide_WResNet'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | def init_model(model): 15 | for m in model.modules(): 16 | if isinstance(m, nn.Conv2d): 17 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 18 | m.weight.data.normal_(0, math.sqrt(2. / n)) 19 | elif isinstance(m, nn.BatchNorm2d): 20 | m.weight.data.fill_(1) 21 | m.bias.data.zero_() 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class WResNet(nn.Module): 96 | 97 | def __init__(self): 98 | super(WResNet, self).__init__() 99 | 100 | def _make_layer(self, block, planes, blocks, stride=1): 101 | downsample = None 102 | if stride != 1 or self.inplanes != planes * block.expansion: 103 | downsample = nn.Sequential( 104 | nn.Conv2d(self.inplanes, planes * block.expansion, 105 | kernel_size=1, stride=stride, bias=False), 106 | nn.BatchNorm2d(planes * block.expansion), 107 | ) 108 | 109 | layers = [] 110 | layers.append(block(self.inplanes, planes, stride, downsample)) 111 | self.inplanes = planes * block.expansion 112 | for i in range(1, blocks): 113 | layers.append(block(self.inplanes, planes)) 114 | 115 | return nn.Sequential(*layers) 116 | 117 | def forward(self, x): 118 | x = self.feats(x) 119 | x = x.view(x.size(0), -1) 120 | x = self.fc(x) 121 | 122 | return x 123 | 124 | 125 | class WResNet_imagenet(WResNet): 126 | 127 | def __init__(self, num_classes=1000, 128 | block=Bottleneck, layers=[3, 4, 23, 3]): 129 | super(WResNet_imagenet, self).__init__() 130 | self.inplanes = 64 131 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 132 | bias=False) 133 | self.bn1 = nn.BatchNorm2d(64) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 136 | self.layer1 = self._make_layer(block, 64, layers[0]) 137 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 138 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 139 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 140 | self.avgpool = nn.AvgPool2d(7) 141 | self.feats = nn.Sequential(self.conv1, 142 | self.bn1, 143 | self.relu, 144 | self.maxpool, 145 | 146 | self.layer1, 147 | self.layer2, 148 | self.layer3, 149 | self.layer4, 150 | 151 | self.avgpool) 152 | self.fc = nn.Linear(512 * block.expansion, num_classes) 153 | 154 | init_model(self) 155 | self.regime = { 156 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 'weight_decay': 1e-4, 'momentum': 0.9}, 157 | 30: {'lr': 1e-2}, 158 | 60: {'lr': 1e-3}, 159 | 90: {'lr': 1e-4} 160 | } 161 | 162 | 163 | class WResNet_cifar10(WResNet): 164 | 165 | def __init__(self, num_classes=10, multiplier=1, 166 | block=BasicBlock, depth=18): 167 | super(WResNet_cifar10, self).__init__() 168 | self.inplanes = 16 * multiplier 169 | n = int((depth - 2) / 6) 170 | self.conv1 = nn.Conv2d(3, 16 * multiplier, kernel_size=3, stride=1, padding=1, 171 | bias=False) 172 | self.bn1 = nn.BatchNorm2d(16 * multiplier) 173 | self.relu = nn.ReLU(inplace=True) 174 | self.maxpool = lambda x: x 175 | self.layer1 = self._make_layer(block, 16 * multiplier, n) 176 | self.layer2 = self._make_layer(block, 32 * multiplier, n, stride=2) 177 | self.layer3 = self._make_layer(block, 64 * multiplier, n, stride=2) 178 | self.layer4 = lambda x: x 179 | self.avgpool = nn.AvgPool2d(8) 180 | self.fc = nn.Linear(64 * multiplier, num_classes) 181 | self.feats = nn.Sequential(self.conv1, 182 | self.bn1, 183 | self.relu, 184 | self.layer1, 185 | self.layer2, 186 | self.layer3, 187 | self.avgpool) 188 | init_model(self) 189 | 190 | self.regime = { 191 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 192 | 'weight_decay': 1e-4, 'momentum': 0.9}, 193 | 60: {'lr': 2e-2}, 194 | 120: {'lr': 4e-3}, 195 | 140: {'lr': 1e-4} 196 | } 197 | 198 | # def wideresnet_cifar(num_classes=num_classes): 199 | # return WResNet_cifar10(num_classes=num_classes, block=BasicBlock, depth=16, multiplier=4) 200 | 201 | def wide_WResNet(**kwargs): 202 | num_classes, depth, dataset = map( 203 | kwargs.get, ['num_classes', 'depth', 'dataset']) 204 | if dataset == 'imagenet': 205 | num_classes = num_classes or 1000 206 | depth = depth or 50 207 | if depth == 18: 208 | return WResNet_imagenet(num_classes=num_classes, 209 | block=BasicBlock, layers=[2, 2, 2, 2]) 210 | if depth == 34: 211 | return WResNet_imagenet(num_classes=num_classes, 212 | block=BasicBlock, layers=[3, 4, 6, 3]) 213 | if depth == 50: 214 | return WResNet_imagenet(num_classes=num_classes, 215 | block=Bottleneck, layers=[3, 4, 6, 3]) 216 | if depth == 101: 217 | return WResNet_imagenet(num_classes=num_classes, 218 | block=Bottleneck, layers=[3, 4, 23, 3]) 219 | if depth == 152: 220 | return WResNet_imagenet(num_classes=num_classes, 221 | block=Bottleneck, layers=[3, 8, 36, 3]) 222 | 223 | elif dataset == 'cifar10': 224 | num_classes = num_classes or 10 225 | depth = depth or 16 226 | return WResNet_cifar10(num_classes=num_classes, 227 | block=BasicBlock, depth=depth, multiplier=4) 228 | elif dataset == 'cifar100': 229 | num_classes = num_classes or 100 230 | depth = depth or 16 231 | return WResNet_cifar10(num_classes=num_classes, 232 | block=BasicBlock, depth=depth, multiplier=4) -------------------------------------------------------------------------------- /MD_AirComp_FEEL/main_error_free_channel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | 6 | import os 7 | import copy 8 | import time 9 | import pickle 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | import math 14 | import torch 15 | from tensorboardX import SummaryWriter 16 | 17 | from options import args_parser 18 | from update import LocalUpdate, update_model_inplace, test_inference 19 | from utils import get_model, get_dataset, average_weights, exp_details, average_parameter_delta 20 | import utils 21 | from sklearn.cluster import kmeans_plusplus 22 | import faiss 23 | 24 | if __name__ == '__main__': 25 | start_time = time.time() 26 | 27 | args = args_parser() 28 | # exp_details(args) 29 | args.seed = 42 30 | args.M = 2 ** 6 # quantization levels 31 | args.Vb = 20 # diemnsion of each vector quantization 32 | args.iid = 0 # non i.i.d. data distribution 33 | args.Nummm = 10000 # number of data samples for data split 34 | args.epochs = 1000 35 | args.model = 'resnet-s' 36 | args.optimizer = 'fedavg' 37 | exp_details(args) 38 | # import os 39 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 40 | 41 | # define paths 42 | # out_dir_name = args.model + '_compress_' + args.dataset + args.optimizer + '_lr' + str(args.lr) + '_locallr' + str( args.local_lr) + '_localep' + str(args.local_ep) +'_localbs' + str(args.local_bs) + '_eps' + str(args.eps) 43 | file_name = '/Results_ErrFreeChannel_{}_{}_{}_llr[{}]_glr[{}]_Vb[{}]_le[{}]_bs[{}]_iid[{}]_Ql[{}]_frac[{}]_{}.pkl'.\ 44 | format(args.dataset, args.model, args.optimizer, 45 | args.local_lr, args.lr, args.Vb, 46 | args.local_ep, args.local_bs, args.iid, args.M, args.frac, args.compressor) 47 | logger = SummaryWriter('./logs/'+file_name) 48 | 49 | device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else "cpu") 50 | torch.set_num_threads(1) # limit cpu use 51 | print ('-- pytorch version: ', torch.__version__) 52 | 53 | np.random.seed(args.seed) 54 | torch.manual_seed(args.seed) 55 | if device != 'cpu': 56 | torch.cuda.manual_seed(args.seed) 57 | 58 | if not os.path.exists(args.outfolder): 59 | os.mkdir(args.outfolder) 60 | 61 | # load dataset and user groups 62 | train_dataset, test_dataset, num_classes, user_groups = get_dataset(args) 63 | 64 | # Set the model to train and send it to device. 65 | global_model = get_model(args.model, args.dataset, train_dataset[0][0].shape, num_classes) 66 | global_model.to(device) 67 | global_model.train() 68 | 69 | momentum_buffer_list = [] 70 | exp_avgs = [] 71 | exp_avg_sqs = [] 72 | max_exp_avg_sqs = [] 73 | for i, p in enumerate(global_model.parameters()): 74 | momentum_buffer_list.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 75 | exp_avgs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 76 | exp_avg_sqs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 77 | max_exp_avg_sqs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)+args.max_init) # 1e-2 78 | 79 | 80 | ### init error ------- 81 | e = [] 82 | for id in range(args.num_users): 83 | ei = [] 84 | for i, p in enumerate(global_model.parameters()): 85 | ei.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 86 | e.append(ei) 87 | D = sum(p.numel() for p in global_model.parameters()) 88 | print('total dimension:', D) 89 | print('compressor:', args.compressor) 90 | 91 | # BS = args.Vb 92 | # M = args.M 93 | s_nExtraZero = D % args.Vb 94 | if s_nExtraZero != 0: 95 | s_nExtraZero = args.Vb - s_nExtraZero 96 | Qerr = torch.ones((int(args.frac * args.num_users), args.epochs)) 97 | 98 | # Training 99 | train_loss_sampled, train_loss, train_accuracy = [], [], [] 100 | test_loss, test_accuracy = [], [] 101 | start_time = time.time() 102 | for epoch in tqdm(range(args.epochs)): 103 | ep_time = time.time() 104 | 105 | local_weights, local_params, local_losses = [], [], [] 106 | print(f'\n | Global Training Round : {epoch+1} |\n') 107 | 108 | 109 | par_before = [] 110 | for p in global_model.parameters(): # get trainable parameters 111 | par_before.append(p.data.detach().clone()) 112 | # this is to store parameters before update 113 | w0 = global_model.state_dict() # get all parameters, includeing batch normalization related ones 114 | 115 | 116 | global_model.train() 117 | m = max(int(args.frac * args.num_users), 1) 118 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 119 | 120 | for ooo in range(len(idxs_users)): 121 | idx = idxs_users[ooo] 122 | local_model = LocalUpdate(args=args, dataset=train_dataset, 123 | idxs=user_groups[idx], logger=logger) 124 | 125 | w, p, loss = local_model.update_weights_local( 126 | model=copy.deepcopy(global_model), global_round=epoch) 127 | 128 | 129 | ####### add error feedback ####### 130 | delta = utils.sub_params(p, par_before) 131 | tmp = utils.add_params(e[idx], delta) 132 | 133 | for i in range(len(tmp)): 134 | if i == 0: 135 | VecTmp = tmp[i].cpu().reshape(1, -1) 136 | else: 137 | VecTmp = torch.cat((VecTmp, tmp[i].cpu().reshape(1, -1)), 1) 138 | # if ooo == 0: 139 | # MatSig = VecTmp 140 | # else: 141 | # MatSig = torch.cat((MatSig, VecTmp), 0) 142 | 143 | # delta_out = [] 144 | # L = 0 145 | # for i in range(len(tmp)): 146 | # Ll = L + tmp[i].numel() 147 | # delta_out.append(VecTmp[:, L:Ll].reshape(tmp[i].shape)) 148 | # L = Ll 149 | VecTmp = torch.cat((VecTmp, torch.zeros((1, s_nExtraZero))), 1) 150 | if ooo == 0: 151 | data2 = VecTmp.numpy().reshape(-1, args.Vb) 152 | centers, _ = kmeans_plusplus(data2, n_clusters=args.M, random_state=0) 153 | index = faiss.IndexFlatL2(args.Vb) 154 | index.add(centers) 155 | _, intIndex = index.search(VecTmp.numpy().reshape(-1, args.Vb), 1) 156 | SigVQ0 = torch.tensor(centers[intIndex].reshape(1, -1)) 157 | del intIndex 158 | else: 159 | _, intIndex = index.search(VecTmp.numpy().reshape(-1, args.Vb), 1) 160 | SigVQ0 = torch.tensor(centers[intIndex].reshape(1, -1)) 161 | del intIndex 162 | Qerr[ooo, epoch] = 10 * torch.log10(torch.sum((torch.abs(SigVQ0 - VecTmp) ** 2)) / torch.sum((torch.abs(VecTmp) ** 2))) 163 | if s_nExtraZero == 0: 164 | SigVQ = SigVQ0 165 | else: 166 | SigVQ = SigVQ0[:, :-s_nExtraZero] 167 | 168 | delta_out = [] 169 | L = 0 170 | for i in range(len(tmp)): 171 | Ll = L + tmp[i].numel() 172 | delta_out.append(SigVQ[:, L:Ll].reshape(tmp[i].shape).cuda()) 173 | L = Ll 174 | del SigVQ 175 | 176 | # delta_out = local_model.compressSignal(tmp, D) 177 | e[idx] = utils.sub_params(tmp, delta_out) 178 | 179 | local_weights.append(copy.deepcopy(w)) 180 | # local_params.append(copy.deepcopy(utils.add_params(delta_out, par_before))) 181 | local_params.append(copy.deepcopy(delta_out)) 182 | local_losses.append(copy.deepcopy(loss)) 183 | 184 | 185 | 186 | bn_weights = average_weights(local_weights) 187 | global_model.load_state_dict(bn_weights) 188 | 189 | global_delta = average_parameter_delta(local_params, par_before) 190 | # global_delta = average_parameter_delta(local_params, par_before) # calculate compression in this function 191 | 192 | update_model_inplace( 193 | global_model, par_before, global_delta, args, epoch, 194 | momentum_buffer_list, exp_avgs, exp_avg_sqs, max_exp_avg_sqs) 195 | 196 | # report and store loss and accuracy 197 | # this is local training loss on sampled users 198 | loss_avg = sum(local_losses) / len(local_losses) 199 | train_loss.append(loss_avg) 200 | 201 | 202 | 203 | global_model.eval() 204 | 205 | 206 | # Test inference after completion of training 207 | test_acc, test_ls = test_inference(args, global_model, test_dataset) 208 | test_accuracy.append(test_acc) 209 | test_loss.append(test_ls) 210 | 211 | # print global training loss after every rounds 212 | print('Epoch Run Time: {0:0.4f} of {1} global rounds'.format(time.time()-ep_time, epoch+1)) 213 | print(f'Training Loss : {train_loss[-1]}') 214 | print(f'Test Loss : {test_loss[-1]}') 215 | print(f'Test Accuracy : {test_accuracy[-1]} \n') 216 | logger.add_scalar('train loss', train_loss[-1], epoch) 217 | logger.add_scalar('test loss', test_loss[-1], epoch) 218 | logger.add_scalar('test acc', test_accuracy[-1], epoch) 219 | 220 | if args.save: 221 | # Saving the objects train_loss and train_accuracy: 222 | 223 | 224 | with open(args.outfolder + file_name, 'wb') as f: 225 | pickle.dump([train_loss, test_loss, test_accuracy, Qerr], f) 226 | 227 | print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time)) 228 | 229 | 230 | -------------------------------------------------------------------------------- /MD_AirComp_FEEL/compressors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import numpy as np 6 | import random, math 7 | 8 | 9 | class CompressorType: 10 | IDENTICAL = 1 # Identical compressor 11 | LAZY_COMPRESSOR = 2 # Lazy or Bernulli compressor 12 | RANDK_COMPRESSOR = 3 # Rank-K compressor 13 | NATURAL_COMPRESSOR_FP64 = 4 # Natural compressor with FP64 14 | NATURAL_COMPRESSOR_FP32 = 5 # Natural compressor with FP32 15 | STANDARD_DITHERING_FP64 = 6 # Standard dithering with FP64 16 | STANDARD_DITHERING_FP32 = 7 # Standard dithering with FP32 17 | NATURAL_DITHERING_FP32 = 8 # Natural Dithering applied for FP32 components vectors 18 | NATURAL_DITHERING_FP64 = 9 # Natural Dithering applied for FP64 components vectors 19 | TOPK_COMPRESSOR = 10 # Top-K compressor 20 | SIGN_COMPRESSOR = 11 # Sign compressor 21 | ONEBIT_SIGN_COMPRESSOR = 12 # One bit sign compressor 22 | 23 | class Compressor: 24 | def __init__(self, compressorName = ""): 25 | self.compressorName = compressorName 26 | self.compressorType = CompressorType.IDENTICAL 27 | self.w = 0.0 28 | self.last_need_to_send_advance = 0 29 | self.component_bits_size = 32 30 | 31 | def name(self): 32 | omega = r'$\omega$' 33 | if self.compressorType == CompressorType.IDENTICAL: return f"Identical" 34 | if self.compressorType == CompressorType.LAZY_COMPRESSOR: return f"Bernoulli(Lazy) [p={self.P:g},{omega}={self.getW():.1f}]" 35 | if self.compressorType == CompressorType.RANDK_COMPRESSOR: return f" (K={self.K})" 36 | if self.compressorType == CompressorType.NATURAL_COMPRESSOR_FP64: return f"Natural for fp64 [{omega}={self.getW():.1f}]" 37 | if self.compressorType == CompressorType.NATURAL_COMPRESSOR_FP32: return f"Natural for fp32 [{omega}={self.getW():.1f}]" 38 | if self.compressorType == CompressorType.STANDARD_DITHERING_FP64: return f"Standard Dithering for fp64[s={self.s}]" 39 | if self.compressorType == CompressorType.STANDARD_DITHERING_FP64: return f"Standard Dithering for fp32[s={self.s}]" 40 | if self.compressorType == CompressorType.NATURAL_DITHERING_FP32: return f"Natural Dithering for fp32[s={self.s},{omega}={self.getW():.1f}]" 41 | if self.compressorType == CompressorType.NATURAL_DITHERING_FP64: return f"Natural Dithering for fp64[s={self.s},{omega}={self.getW():.1f}]" 42 | if self.compressorType == CompressorType.TOPK_COMPRESSOR: return f" Top (K={self.K})" 43 | if self.compressorType == CompressorType.SIGN_COMPRESSOR: return f"Sign" 44 | if self.compressorType == CompressorType.ONEBIT_SIGN_COMPRESSOR: return f"One Bit Sign" 45 | return "?" 46 | 47 | def fullName(self): 48 | omega = r'$\omega$' 49 | if self.compressorType == CompressorType.IDENTICAL: return f"Identical" 50 | if self.compressorType == CompressorType.LAZY_COMPRESSOR: return f"Bernoulli(Lazy) [p={self.P:g},{omega}={self.getW():.1f}]" 51 | if self.compressorType == CompressorType.RANDK_COMPRESSOR: return f"Rand [K={self.K}]" 52 | if self.compressorType == CompressorType.NATURAL_COMPRESSOR_FP64: return f"Natural for fp64 [{omega}={self.getW():.1f}]" 53 | if self.compressorType == CompressorType.NATURAL_COMPRESSOR_FP32: return f"Natural for fp32 [{omega}={self.getW():.1f}]" 54 | if self.compressorType == CompressorType.STANDARD_DITHERING_FP64: return f"Standard Dithering for fp64[s={self.s}]" 55 | if self.compressorType == CompressorType.STANDARD_DITHERING_FP64: return f"Standard Dithering for fp32[s={self.s}]" 56 | if self.compressorType == CompressorType.NATURAL_DITHERING_FP32: return f"Natural Dithering for fp32[s={self.s},{omega}={self.getW():.1f}]" 57 | if self.compressorType == CompressorType.NATURAL_DITHERING_FP64: return f"Natural Dithering for fp64[s={self.s},{omega}={self.getW():.1f}]" 58 | if self.compressorType == CompressorType.TOPK_COMPRESSOR: return f"Top [K={self.K}]" 59 | if self.compressorType == CompressorType.SIGN_COMPRESSOR: return f"Sign" 60 | if self.compressorType == CompressorType.ONEBIT_SIGN_COMPRESSOR: return f"One Bit Sign" 61 | return "?" 62 | 63 | def resetStats(self): 64 | self.last_need_to_send_advance = 0 65 | 66 | def makeIdenticalCompressor(self): 67 | self.compressorType = CompressorType.IDENTICAL 68 | self.resetStats() 69 | 70 | def makeLazyCompressor(self, P): 71 | self.compressorType = CompressorType.LAZY_COMPRESSOR 72 | self.P = P 73 | self.w = 1.0 / P - 1.0 74 | self.resetStats() 75 | 76 | def makeStandardDitheringFP64(self, levels, vectorNormCompressor, p = float("inf")): 77 | self.compressorType = CompressorType.STANDARD_DITHERING_FP64 78 | self.levelsValues = np.arange(0.0, 1.1, 1.0/levels) # levels + 1 values in range [0.0, 1.0] which uniformly split this segment 79 | self.s = len(self.levelsValues) - 1 # # should be equal to level 80 | assert self.s == levels 81 | 82 | self.p = p 83 | self.vectorNormCompressor = vectorNormCompressor 84 | self.w = 0.0 # TODO 85 | 86 | self.resetStats() 87 | 88 | def makeStandardDitheringFP32(self, levels, vectorNormCompressor, p = float("inf")): 89 | self.compressorType = CompressorType.STANDARD_DITHERING_FP32 90 | self.levelsValues = torch.arange(0.0, 1.1, 1.0/levels) # levels + 1 values in range [0.0, 1.0] which uniformly split this segment 91 | self.s = len(self.levelsValues) - 1 # should be equal to level 92 | assert self.s == levels 93 | 94 | self.p = p 95 | self.vectorNormCompressor = vectorNormCompressor 96 | self.w = 0.0 # TODO 97 | 98 | self.resetStats() 99 | 100 | def makeQSGD_FP64(self, levels, dInput): 101 | norm_compressor = Compressor("norm_compressor") 102 | norm_compressor.makeIdenticalCompressor() 103 | self.makeStandardDitheringFP64(levels, norm_compressor, p = 2) 104 | # Lemma 3.1. from https://arxiv.org/pdf/1610.02132.pdf, page 5 105 | self.w = min(dInput/(levels*levels), dInput**0.5/levels) 106 | 107 | def makeNaturalDitheringFP64(self, levels, dInput, p = float("inf")): 108 | self.compressorType = CompressorType.NATURAL_DITHERING_FP64 109 | self.levelsValues = torch.zeros(levels + 1) 110 | for i in range(levels): 111 | self.levelsValues[i] = (1.0/2.0)**i 112 | self.levelsValues = torch.flip(self.levelsValues, dims = [0]) 113 | self.s = len(self.levelsValues) - 1 114 | assert self.s == levels 115 | 116 | self.p = p 117 | 118 | r = min(p, 2) 119 | self.w = 1.0/8.0 + (dInput** (1.0/r)) / (2**(self.s - 1)) * min(1, (dInput**(1.0/r)) / (2**(self.s-1))) 120 | self.resetStats() 121 | 122 | def makeNaturalDitheringFP32(self, levels, dInput, p = float("inf")): 123 | self.compressorType = CompressorType.NATURAL_DITHERING_FP32 124 | self.levelsValues = torch.zeros(levels + 1) 125 | for i in range(levels): 126 | self.levelsValues[i] = (1.0/2.0)**i 127 | self.levelsValues = torch.flip(self.levelsValues, dims=[0]) 128 | self.s = len(self.levelsValues) - 1 129 | assert self.s == levels 130 | 131 | self.p = p 132 | 133 | r = min(p, 2) 134 | self.w = 1.0/8.0 + (dInput** (1.0/r)) / (2**(self.s - 1)) * min(1, (dInput**(1.0/r)) / (2**(self.s-1))) 135 | self.resetStats() 136 | 137 | # K - how much component we leave from input vector 138 | def makeRandKCompressor(self, K): 139 | self.compressorType = CompressorType.RANDK_COMPRESSOR 140 | self.K = K 141 | self.resetStats() 142 | 143 | def makeTopKCompressor(self, K): 144 | self.compressorType = CompressorType.TOPK_COMPRESSOR 145 | self.K = K 146 | self.resetStats() 147 | 148 | def makeNaturalCompressorFP64(self): 149 | self.compressorType = CompressorType.NATURAL_COMPRESSOR_FP64 150 | self.w = 1.0/8.0 151 | self.resetStats() 152 | 153 | def makeNaturalCompressorFP32(self): 154 | self.compressorType = CompressorType.NATURAL_COMPRESSOR_FP32 155 | self.w = 1.0/8.0 156 | self.resetStats() 157 | 158 | def makeSignCompressor(self, freeze_iteration=0): 159 | self.compressorType = CompressorType.SIGN_COMPRESSOR 160 | self.freeze_iteration = freeze_iteration 161 | self.resetStats() 162 | 163 | def makeOneBitSignCompressor(self, freeze_iteration=0): 164 | self.compressorType = CompressorType.ONEBIT_SIGN_COMPRESSOR 165 | self.freeze_iteration = freeze_iteration 166 | self.resetStats() 167 | 168 | def getW(self): 169 | return self.w 170 | 171 | def compressVector(self, x, iteration=0): 172 | d = max(x.shape) 173 | 174 | if self.compressorType == CompressorType.IDENTICAL: 175 | out = x.clone() 176 | self.last_need_to_send_advance = d * self.component_bits_size 177 | 178 | 179 | elif self.compressorType == CompressorType.TOPK_COMPRESSOR: 180 | #S = torch.arange(d) 181 | # np.random.shuffle(S) 182 | top_size = max(int(self.K*d), 1) 183 | _, S = torch.topk(torch.abs(x), top_size) 184 | out = torch.zeros_like(x) 185 | out[S] = x[S] 186 | # !!! in real case, one needs to send the out vector and a support set to indicate the indices of top K 187 | self.last_need_to_send_advance = 2 * top_size * self.component_bits_size 188 | 189 | elif self.compressorType == CompressorType.SIGN_COMPRESSOR: 190 | if iteration < self.freeze_iteration: 191 | out = x.clone() 192 | self.last_need_to_send_advance = d * self.component_bits_size 193 | else: 194 | 195 | out = torch.sign(x) 196 | 197 | scale = torch.norm(x, p=1) / torch.numel(x) 198 | 199 | out.mul_(scale) # <-- we use this just for similation 200 | 201 | 202 | # !!! in real case, one needs to send D bits for {0, 1} and 32 bits for the scale constant 203 | self.last_need_to_send_advance = d + self.component_bits_size 204 | 205 | elif self.compressorType == CompressorType.ONEBIT_SIGN_COMPRESSOR: 206 | # according to one bit adam paper, 207 | # during warmup, the signal is not compressed 208 | if iteration < self.freeze_iteration: 209 | out = x.clone() 210 | self.last_need_to_send_advance = d * self.component_bits_size 211 | else: 212 | out = torch.sign(x) 213 | # out.add_(1).bool().float().add_(-0.5).mul_(2.0) 214 | scale = torch.norm(x) / np.sqrt(torch.numel(x)) 215 | # out = torch.cat((scale, out), 0) <-- in real case, only send a scale, and a {0,1}^D output 216 | # this is just for similate 217 | out.mul_(scale) # <-- we use this just for similation 218 | # !!! in real case, one needs to send D bits for {0, 1} and 32 bits for the scale constant 219 | self.last_need_to_send_advance = d + self.component_bits_size 220 | 221 | return out 222 | -------------------------------------------------------------------------------- /MD_AirComp_FEEL/sampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | 6 | import numpy as np 7 | from torchvision import datasets, transforms 8 | 9 | 10 | def mnist_iid(dataset, num_users): 11 | """ 12 | Sample I.I.D. client data from MNIST dataset 13 | :param dataset: 14 | :param num_users: 15 | :return: dict of image index 16 | """ 17 | num_items = int(len(dataset)/num_users) 18 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 19 | for i in range(num_users): 20 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 21 | replace=False)) 22 | all_idxs = list(set(all_idxs) - dict_users[i]) 23 | return dict_users 24 | 25 | 26 | def mnist_noniid(dataset, num_users): 27 | """ 28 | Sample non-I.I.D client data from MNIST dataset 29 | :param dataset: 30 | :param num_users: 31 | :return: 32 | """ 33 | # 60,000 training imgs --> 200 imgs/shard X 300 shards 34 | num_shards, num_imgs = 200, 300 35 | idx_shard = [i for i in range(num_shards)] 36 | dict_users = {i: np.array([]) for i in range(num_users)} 37 | idxs = np.arange(num_shards*num_imgs) 38 | labels = dataset.train_labels.numpy() 39 | 40 | # sort labels 41 | idxs_labels = np.vstack((idxs, labels)) 42 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 43 | idxs = idxs_labels[0, :] 44 | 45 | # divide and assign 2 shards/client 46 | for i in range(num_users): 47 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 48 | idx_shard = list(set(idx_shard) - rand_set) 49 | for rand in rand_set: 50 | dict_users[i] = np.concatenate( 51 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 52 | return dict_users 53 | 54 | 55 | def mnist_noniid_unequal(dataset, num_users): 56 | """ 57 | Sample non-I.I.D client data from MNIST dataset s.t clients 58 | have unequal amount of data 59 | :param dataset: 60 | :param num_users: 61 | :returns a dict of clients with each clients assigned certain 62 | number of training imgs 63 | """ 64 | # 60,000 training imgs --> 50 imgs/shard X 1200 shards 65 | num_shards, num_imgs = 1200, 50 66 | idx_shard = [i for i in range(num_shards)] 67 | dict_users = {i: np.array([]) for i in range(num_users)} 68 | idxs = np.arange(num_shards*num_imgs) 69 | labels = dataset.train_labels.numpy() 70 | 71 | # sort labels 72 | idxs_labels = np.vstack((idxs, labels)) 73 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 74 | idxs = idxs_labels[0, :] 75 | 76 | # Minimum and maximum shards assigned per client: 77 | min_shard = 1 78 | max_shard = 30 79 | 80 | # Divide the shards into random chunks for every client 81 | # s.t the sum of these chunks = num_shards 82 | random_shard_size = np.random.randint(min_shard, max_shard+1, 83 | size=num_users) 84 | random_shard_size = np.around(random_shard_size / 85 | sum(random_shard_size) * num_shards) 86 | random_shard_size = random_shard_size.astype(int) 87 | 88 | # Assign the shards randomly to each client 89 | if sum(random_shard_size) > num_shards: 90 | 91 | for i in range(num_users): 92 | # First assign each client 1 shard to ensure every client has 93 | # atleast one shard of data 94 | rand_set = set(np.random.choice(idx_shard, 1, replace=False)) 95 | idx_shard = list(set(idx_shard) - rand_set) 96 | for rand in rand_set: 97 | dict_users[i] = np.concatenate( 98 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), 99 | axis=0) 100 | 101 | random_shard_size = random_shard_size-1 102 | 103 | # Next, randomly assign the remaining shards 104 | for i in range(num_users): 105 | if len(idx_shard) == 0: 106 | continue 107 | shard_size = random_shard_size[i] 108 | if shard_size > len(idx_shard): 109 | shard_size = len(idx_shard) 110 | rand_set = set(np.random.choice(idx_shard, shard_size, 111 | replace=False)) 112 | idx_shard = list(set(idx_shard) - rand_set) 113 | for rand in rand_set: 114 | dict_users[i] = np.concatenate( 115 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), 116 | axis=0) 117 | else: 118 | 119 | for i in range(num_users): 120 | shard_size = random_shard_size[i] 121 | rand_set = set(np.random.choice(idx_shard, shard_size, 122 | replace=False)) 123 | idx_shard = list(set(idx_shard) - rand_set) 124 | for rand in rand_set: 125 | dict_users[i] = np.concatenate( 126 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), 127 | axis=0) 128 | 129 | if len(idx_shard) > 0: 130 | # Add the leftover shards to the client with minimum images: 131 | shard_size = len(idx_shard) 132 | # Add the remaining shard to the client with lowest data 133 | k = min(dict_users, key=lambda x: len(dict_users.get(x))) 134 | rand_set = set(np.random.choice(idx_shard, shard_size, 135 | replace=False)) 136 | idx_shard = list(set(idx_shard) - rand_set) 137 | for rand in rand_set: 138 | dict_users[k] = np.concatenate( 139 | (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]), 140 | axis=0) 141 | 142 | return dict_users 143 | 144 | 145 | def cifar_iid(dataset, num_users): 146 | """ 147 | Sample I.I.D. client data from CIFAR10 dataset 148 | :param dataset: 149 | :param num_users: 150 | :return: dict of image index 151 | """ 152 | num_items = int(len(dataset)/num_users) 153 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 154 | for i in range(num_users): 155 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 156 | replace=False)) 157 | all_idxs = list(set(all_idxs) - dict_users[i]) 158 | return dict_users 159 | 160 | # def cifar_noniid(args, dataset, num_users): 161 | # """ 162 | # Sample non-I.I.D client data from cifar dataset 163 | # :param dataset: 164 | # :param num_users: 165 | # :return: 166 | # Each device randomly sample 167 | # """ 168 | # lenRandom = args.Nummm 169 | # num_items = int(lenRandom/num_users) 170 | # dict_users, all_idxs = {}, [i for i in range(len(dataset))] 171 | # for ii in range(num_users): 172 | # dict_users[ii] = set(np.random.choice(all_idxs, num_items, replace=False)) 173 | # all_idxs = list(set(all_idxs) - dict_users[ii]) 174 | # 175 | # labels = np.array(dataset.targets) 176 | # labels = labels[all_idxs] 177 | # 178 | # # sort labels 179 | # idxs = np.arange(len(labels)) 180 | # idxs_labels = np.vstack((idxs, labels)) 181 | # idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 182 | # idxs = idxs_labels[0,:] 183 | # 184 | # # divide and assign 185 | # numImage = int(len(idxs)/num_users) 186 | # for ii in range(num_users): 187 | # temp = idxs[ii*numImage:(ii+1)*numImage] 188 | # dict_users[ii] = np.concatenate((list(dict_users[ii]), temp), axis=0) 189 | # 190 | # return dict_users 191 | 192 | # def cifar_noniid(dataset, args): 193 | # """ 194 | # Sample non-I.I.D client data from CIFAR10 dataset 195 | # :param dataset: 196 | # :param num_users: 197 | # :return: 198 | # """ 199 | # num_users = args.num_users 200 | # num_shards, num_imgs = 200, 250 201 | # idx_shard = [i for i in range(num_shards)] 202 | # dict_users = {i: np.array([]) for i in range(num_users)} 203 | # idxs = np.arange(num_shards*num_imgs) 204 | # labels = [dataset[i][1] for i in range(len(dataset))] 205 | # # labels = np.array(dataset.train_labels) 206 | # 207 | # # sort labels 208 | # idxs_labels = np.vstack((idxs, labels)) 209 | # idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 210 | # idxs = idxs_labels[0, :] 211 | # 212 | # # divide and assign 213 | # for i in range(num_users): 214 | # rand_set = set(np.random.choice(idx_shard, 5, replace=False)) 215 | # idx_shard = list(set(idx_shard) - rand_set) 216 | # for rand in rand_set: 217 | # dict_users[i] = np.concatenate( 218 | # (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 219 | # return dict_users 220 | 221 | def cifar_noniid(dataset, args): 222 | """ 223 | Sample non-I.I.D client data from cifar dataset 224 | :param dataset: 225 | :param num_users: 226 | :return: 227 | Each device randomly sample 228 | """ 229 | num_users = args.num_users 230 | lenRandom = args.Nummm 231 | num_items = int(lenRandom/num_users) 232 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 233 | for ii in range(num_users): 234 | dict_users[ii] = set(np.random.choice(all_idxs, num_items, replace=False)) 235 | all_idxs = list(set(all_idxs) - dict_users[ii]) 236 | 237 | labels = np.array(dataset.targets) 238 | labels = labels[all_idxs] 239 | 240 | # sort labels 241 | idxs = np.arange(len(labels)) 242 | idxs_labels = np.vstack((idxs, labels)) 243 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 244 | idxs = idxs_labels[0,:] 245 | 246 | # divide and assign 247 | numImage = int(len(idxs)/num_users) 248 | for ii in range(num_users): 249 | temp = idxs[ii*numImage:(ii+1)*numImage] 250 | dict_users[ii] = np.concatenate((list(dict_users[ii]), temp), axis=0) 251 | 252 | return dict_users 253 | 254 | # from collections import deque 255 | # def cifar_noniid(dataset, args): 256 | # """ 257 | # Sample non-I.I.D client data from CIFAR10 dataset 258 | # :param dataset: 259 | # :param num_users: 260 | # :return: 261 | # """ 262 | # num_users = args.num_users 263 | # num_shards, num_imgs = 200, 250 264 | # idx_shard = list(np.arange(num_shards)) 265 | # idx_labl = list(np.arange(10).repeat(20)) 266 | # idxs_labels_1 = np.vstack((idx_shard, idx_labl)) 267 | # np.random.shuffle(idxs_labels_1.T) 268 | # deq1 = deque(idxs_labels_1[0, :]) 269 | # deq2 = deque(idxs_labels_1[1, :]) 270 | # # idx_shard = [] 271 | # # JJJ = int(num_shards/10) 272 | # # for i in range(10): 273 | # # e = np.arange(i*JJJ, (i+1)*JJJ, 1) 274 | # # idx_shard.append(e) 275 | # dict_users = {i: np.array([]) for i in range(num_users)} 276 | # idxs = np.arange(num_shards*num_imgs) 277 | # labels = [dataset[i][1] for i in range(len(dataset))] 278 | # # labels = np.array(dataset.train_labels) 279 | # 280 | # # sort labels 281 | # idxs_labels = np.vstack((idxs, labels)) 282 | # idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 283 | # idxs = idxs_labels[0, :] 284 | # 285 | # # divide and assign 286 | # 287 | # for i in range(num_users): 288 | # rand_set = [] 289 | # sett = [] 290 | # while len(rand_set) < 5: 291 | # ttp = deq2.popleft() 292 | # if ttp in sett and i < num_users-1: 293 | # deq1.append(deq1.popleft()) 294 | # deq2.append(ttp) 295 | # else: 296 | # rand_set.append(deq1.popleft()) 297 | # sett.append(ttp) 298 | # for rand in rand_set: 299 | # dict_users[i] = np.concatenate( 300 | # (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 301 | # return dict_users 302 | 303 | if __name__ == '__main__': 304 | dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True, 305 | transform=transforms.Compose([ 306 | transforms.ToTensor(), 307 | transforms.Normalize((0.1307,), 308 | (0.3081,)) 309 | ])) 310 | num = 100 311 | d = mnist_noniid(dataset_train, num) 312 | -------------------------------------------------------------------------------- /MD_AirComp_FEEL/update.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader, Dataset 10 | import compressors 11 | from sklearn.cluster import kmeans_plusplus 12 | import faiss 13 | 14 | class DatasetSplit(Dataset): 15 | """An abstract Dataset class wrapped around Pytorch Dataset class. 16 | """ 17 | 18 | def __init__(self, dataset, idxs): 19 | self.dataset = dataset 20 | self.idxs = [int(i) for i in idxs] 21 | 22 | def __len__(self): 23 | return len(self.idxs) 24 | 25 | def __getitem__(self, item): 26 | image, label = self.dataset[self.idxs[item]] 27 | return image.clone().detach(), torch.tensor(label) 28 | 29 | 30 | class LocalUpdate(object): 31 | def __init__(self, args, dataset, idxs, logger): 32 | self.args = args 33 | self.logger = logger 34 | self.trainloader, self.validloader, self.testloader = self.train_val_test( 35 | dataset, list(idxs)) 36 | self.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else "cpu") 37 | # Default criterion set to NLL loss function 38 | self.criterion = nn.CrossEntropyLoss().to(self.device) 39 | 40 | ###### define compressors ####### 41 | self.compressor = compressors.Compressor() 42 | if args.compressor == 'identical': 43 | self.compressor.makeIdenticalCompressor() 44 | elif args.compressor == 'topk256': 45 | self.compressor.makeTopKCompressor(1/256) 46 | elif args.compressor == 'topk128': 47 | self.compressor.makeTopKCompressor(1/128) 48 | elif args.compressor == 'topk64': 49 | self.compressor.makeTopKCompressor(1/64) 50 | elif args.compressor == 'sign': 51 | self.compressor.makeSignCompressor() 52 | elif args.compressor == 'Kmeans': 53 | self.compressor.makeSignCompressor() 54 | else: 55 | exit('unknown compressor: {}'.format(args.compressor)) 56 | 57 | 58 | 59 | def train_val_test(self, dataset, idxs): 60 | """ 61 | Returns train, validation and test dataloaders for a given dataset 62 | and user indexes. 63 | """ 64 | # split indexes for train, validation, and test (80, 10, 10) 65 | idxs_train = idxs[:int(0.8*len(idxs))] 66 | idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))] 67 | idxs_test = idxs[int(0.9*len(idxs)):] 68 | 69 | trainloader = DataLoader(DatasetSplit(dataset, idxs_train), 70 | batch_size=self.args.local_bs, shuffle=True) 71 | validloader = DataLoader(DatasetSplit(dataset, idxs_val), 72 | batch_size=int(len(idxs_val)/10), shuffle=False) 73 | testloader = DataLoader(DatasetSplit(dataset, idxs_test), 74 | batch_size=int(len(idxs_test)/10), shuffle=False) 75 | return trainloader, validloader, testloader 76 | 77 | def update_weights_local(self, model, global_round): 78 | # Set mode to train model 79 | model.train() 80 | epoch_loss = [] 81 | 82 | optimizer = torch.optim.SGD(model.parameters(), lr=self.args.local_lr, momentum=0) 83 | 84 | for iter in range(self.args.local_ep): 85 | batch_loss = [] 86 | total = 0 87 | for batch_idx, (images, labels) in enumerate(self.trainloader): 88 | images, labels = images.to(self.device), labels.to(self.device) 89 | 90 | model.zero_grad() 91 | logits = model(images) 92 | loss = self.criterion(logits, labels) 93 | loss.backward() 94 | optimizer.step() 95 | 96 | if self.args.verbose and (batch_idx % 10 == 0): 97 | print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 98 | global_round, iter, batch_idx * len(images), 99 | len(self.trainloader.dataset), 100 | 100. * batch_idx / len(self.trainloader), loss.item())) 101 | 102 | batch_loss.append(loss.item() * len(labels)) 103 | total += len(labels) 104 | epoch_loss.append(sum(batch_loss)/total) 105 | 106 | par_after = [] 107 | for p in model.parameters(): 108 | par_after.append(p.data.detach().clone()) 109 | 110 | 111 | return model.state_dict(), par_after, sum(epoch_loss) / len(epoch_loss) 112 | 113 | 114 | def VQcompress(self, signal): 115 | BS = self.args.Vb 116 | M = self.args.M 117 | s_nExtraZero = signal.shape[1] % BS 118 | if s_nExtraZero != 0: 119 | s_nExtraZero = BS - s_nExtraZero 120 | signal_e = torch.cat((signal, torch.zeros((signal.shape[0], s_nExtraZero))), 1) 121 | data2 = signal_e[0, :].numpy().reshape(-1, BS) 122 | centers, _ = kmeans_plusplus(data2, n_clusters=M, random_state=0) 123 | index = faiss.IndexFlatL2(BS) 124 | # index0.add(centers0.numpy()) 125 | index.add(centers) 126 | _, intIndex = index.search(signal_e[:].numpy().reshape(-1, BS), 1) 127 | # SigKmean = centers0[intIndex000].reshape(signal000.shape[0], -1) 128 | SigKmean = torch.tensor(centers[intIndex].reshape(signal.shape[0], -1)) 129 | intIndex = intIndex.reshape(signal.shape[0], -1) 130 | if s_nExtraZero == 0: 131 | SigKmean = SigKmean 132 | else: 133 | SigKmean = SigKmean[:, :-s_nExtraZero] 134 | 135 | Ssum = torch.sum(signal, dim=0) / signal.shape[0] 136 | EsKmean = torch.sum(SigKmean, dim=0) / signal.shape[0] 137 | MSE_Kmean = torch.norm(Ssum - EsKmean.view(-1)).pow(2) 138 | Qerror = 10 * torch.log10(torch.sum((torch.abs(Ssum - EsKmean.view(-1)) ** 2)) / torch.sum((torch.abs(Ssum) ** 2))) 139 | 140 | return Qerror, intIndex, centers, SigKmean, EsKmean 141 | 142 | def compressSignal(self, signal, D): 143 | # transit_bits = 0 144 | signal_compressed = [] 145 | for p in signal: 146 | signal_compressed.append(torch.zeros_like(p)) 147 | 148 | signal_flatten = torch.zeros(D).to(self.device) 149 | 150 | signal_offset = 0 151 | for t in range(len(signal)): 152 | offset = len(signal[t].flatten(0)) 153 | signal_flatten[(signal_offset):(signal_offset + offset)] = signal[t].flatten(0) 154 | signal_offset += offset 155 | 156 | 157 | signal_flatten = self.compressor.compressVector(signal_flatten) 158 | # transit_bits += compressors.Compressor.last_need_to_send_advance 159 | 160 | signal_offset = 0 161 | for t in range(len(signal)): 162 | offset = len(signal[t].flatten(0)) 163 | signal_compressed[t].flatten(0)[:] = signal_flatten[(signal_offset):(signal_offset + offset)] 164 | signal_offset += offset 165 | 166 | return signal_compressed 167 | 168 | def compressSignal_layerwise(self, signal, D): 169 | transit_bits = 0 170 | # signal_compressed = [] 171 | for p in signal: 172 | signal_compressed.append(torch.zeros_like(p)) 173 | 174 | signal_flatten = torch.zeros(D).to(self.device) 175 | 176 | signal_offset = 0 177 | for t in range(len(signal)): 178 | offset = len(signal[t].flatten(0)) 179 | signal_flatten[(signal_offset):(signal_offset + offset)] = self.compressor.compressVector(signal[t].flatten(0), self.iteration) 180 | # transit_bits += compressors.Compressor.last_need_to_send_advance 181 | signal_offset += offset 182 | 183 | signal_offset = 0 184 | for t in range(len(signal)): 185 | offset = len(signal[t].flatten(0)) 186 | signal_compressed[t].flatten(0)[:] = signal_flatten[(signal_offset):(signal_offset + offset)] 187 | signal_offset += offset 188 | 189 | return signal_compressed 190 | 191 | 192 | 193 | def update_weights(self, model, global_round): 194 | # Set mode to train model 195 | model.train() 196 | epoch_loss = [] 197 | 198 | optimizer = torch.optim.SGD(model.parameters(), lr=self.args.local_lr, momentum=0) 199 | 200 | for iter in range(self.args.local_ep): 201 | batch_loss = [] 202 | for batch_idx, (images, labels) in enumerate(self.trainloader): 203 | images, labels = images.to(self.device), labels.to(self.device) 204 | 205 | model.zero_grad() 206 | log_probs = model(images) 207 | loss = self.criterion(log_probs, labels) 208 | loss.backward() 209 | optimizer.step() 210 | 211 | if self.args.verbose and (batch_idx % 10 == 0): 212 | print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 213 | global_round, iter, batch_idx * len(images), 214 | len(self.trainloader.dataset), 215 | 100. * batch_idx / len(self.trainloader), loss.item())) 216 | # self.logger.add_scalar('loss', loss.item()) 217 | batch_loss.append(loss.item()/len(labels)) 218 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 219 | 220 | par_after = [] 221 | for p in model.parameters(): 222 | par_after.append(p.data.detach().clone()) 223 | 224 | return par_after, sum(epoch_loss) / len(epoch_loss) 225 | 226 | def inference(self, model): 227 | """ Returns the inference accuracy and loss. 228 | """ 229 | 230 | model.eval() 231 | loss, total, correct = 0.0, 0.0, 0.0 232 | 233 | for batch_idx, (images, labels) in enumerate(self.testloader): 234 | images, labels = images.to(self.device), labels.to(self.device) 235 | 236 | # Inference 237 | outputs = model(images) 238 | batch_loss = self.criterion(outputs, labels) 239 | loss += batch_loss.item() * len(labels) 240 | 241 | # Prediction 242 | _, pred_labels = torch.max(outputs, 1) 243 | pred_labels = pred_labels.view(-1) 244 | correct += torch.sum(torch.eq(pred_labels, labels)).item() 245 | total += len(labels) 246 | 247 | accuracy = correct/total 248 | loss = loss/total 249 | return accuracy, loss 250 | 251 | 252 | def update_model_inplace(model, par_before, delta, args, cur_iter, momentum_buffer_list, exp_avgs, exp_avg_sqs, max_exp_avg_sqs): 253 | grads = copy.deepcopy(delta) 254 | 255 | # learning rate decay 256 | iteration = cur_iter + 1 # add 1 is to make sure nonzero denominator in adam calculation 257 | # if iteration < int(args.epochs/2): 258 | # lr_decay = 1.0 259 | # elif iteration < int(3*args.epochs/4): 260 | # lr_decay = 0.1 261 | # else: 262 | # lr_decay = 0.01 263 | lr_decay=1.0 264 | 265 | for i, param in enumerate(model.parameters()): 266 | grad = grads[i] # recieve the aggregated (averaged) gradient 267 | 268 | # SGD calculation 269 | if args.optimizer == 'fedavg': 270 | # need to reset the trainable parameter 271 | # because we have updated the model via state_dict when dealing with batch normalization 272 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).add_(grad, alpha=args.lr * lr_decay) 273 | # param.data.add_(grad, alpha=args.lr * lr_decay) 274 | # SGD+momentum calculation 275 | elif args.optimizer == 'fedavgm': 276 | buf = momentum_buffer_list[i] 277 | buf.mul_(args.momentum).add_(grad, alpha=1) 278 | grad = buf 279 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).add_(grad, alpha=args.lr * lr_decay) 280 | # adam calculation 281 | elif args.optimizer == 'fedadam': 282 | exp_avg = exp_avgs[i] 283 | exp_avg_sq = exp_avg_sqs[i] 284 | 285 | bias_correction1 = 1 - args.beta1 ** iteration 286 | bias_correction2 = 1 - args.beta2 ** iteration 287 | 288 | exp_avg.mul_(args.beta1).add_(grad, alpha=1 - args.beta1) 289 | exp_avg_sq.mul_(args.beta2).addcmul_(grad, grad.conj(), value=1 - args.beta2) 290 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(args.eps) # without maximum 291 | 292 | step_size = args.lr * lr_decay / bias_correction1 293 | 294 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(exp_avg, denom, value=step_size) 295 | elif args.optimizer == 'fedams': 296 | exp_avg = exp_avgs[i] 297 | exp_avg_sq = exp_avg_sqs[i] 298 | 299 | bias_correction1 = 1 - args.beta1 ** iteration 300 | bias_correction2 = 1 - args.beta2 ** iteration 301 | 302 | exp_avg.mul_(args.beta1).add_(grad, alpha=1 - args.beta1) 303 | exp_avg_sq.mul_(args.beta2).addcmul_(grad, grad.conj(), value=1 - args.beta2) 304 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 305 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(args.eps) 306 | 307 | step_size = args.lr * lr_decay / bias_correction1 308 | 309 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(exp_avg, denom, value=step_size) 310 | elif args.optimizer == 'fedamsd': 311 | lr_decay=1.0/math.sqrt(iteration) 312 | 313 | exp_avg = exp_avgs[i] 314 | exp_avg_sq = exp_avg_sqs[i] 315 | 316 | bias_correction1 = 1 - args.beta1 ** iteration 317 | bias_correction2 = 1 - args.beta2 ** iteration 318 | 319 | exp_avg.mul_(args.beta1).add_(grad, alpha=1 - args.beta1) 320 | exp_avg_sq.mul_(args.beta2).addcmul_(grad, grad.conj(), value=1 - args.beta2) 321 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 322 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(args.eps) 323 | 324 | step_size = args.lr * lr_decay / bias_correction1 325 | 326 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(exp_avg, denom, value=step_size) 327 | elif args.optimizer == 'fedadagrad': 328 | exp_avg_sq = exp_avg_sqs[i] 329 | exp_avg_sq.addcmul_(1, grad, grad) 330 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(grad, exp_avg_sq.sqrt().add_(args.eps), value=args.lr * lr_decay) 331 | elif args.optimizer == 'fedyogi': 332 | exp_avg = exp_avgs[i] 333 | exp_avg_sq = exp_avg_sqs[i] 334 | 335 | bias_correction1 = 1 - args.beta1 ** iteration 336 | bias_correction2 = 1 - args.beta2 ** iteration 337 | 338 | exp_avg.mul_(args.beta1).add_(grad, alpha=1 - args.beta1) 339 | tmp_sq = grad ** 2 340 | tmp_diff = exp_avg_sq - tmp_sq 341 | exp_avg_sq.add_( - (1 - args.beta2), torch.sign(tmp_diff) * tmp_sq) 342 | 343 | denom = exp_avg_sq.sqrt().add_(args.eps) 344 | 345 | step_size = args.lr * lr_decay * math.sqrt(bias_correction2) / bias_correction1 346 | 347 | param.data.add_(param.data, alpha=-1).add_(par_before[i], alpha=1).addcdiv_(exp_avg, denom, value=step_size) 348 | 349 | else: 350 | exit('unknown optimizer: {}'.format(args.optimizer)) 351 | 352 | 353 | def test_inference(args, model, test_dataset): 354 | """ Returns the test accuracy and loss. 355 | """ 356 | 357 | model.eval() 358 | loss, total, correct = 0.0, 0.0, 0.0 359 | 360 | device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else "cpu") 361 | criterion = nn.CrossEntropyLoss().to(device) 362 | testloader = DataLoader(test_dataset, batch_size=128, 363 | shuffle=False) 364 | 365 | for batch_idx, (images, labels) in enumerate(testloader): 366 | images, labels = images.to(device), labels.to(device) 367 | 368 | # Inference 369 | outputs = model(images) 370 | batch_loss = criterion(outputs, labels) 371 | loss += batch_loss.item() * len(labels) 372 | 373 | # Prediction 374 | _, pred_labels = torch.max(outputs, 1) 375 | pred_labels = pred_labels.view(-1) 376 | correct += torch.sum(torch.eq(pred_labels, labels)).item() 377 | total += len(labels) 378 | 379 | accuracy = correct/total 380 | loss = loss/total 381 | return accuracy, loss 382 | 383 | 384 | -------------------------------------------------------------------------------- /MD_AirComp_FEEL/main_MIMO_channel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.8 4 | import os 5 | import copy 6 | import time 7 | import pickle 8 | import numpy as np 9 | from tqdm import tqdm 10 | import pandas as pd 11 | import math 12 | import torch 13 | from tensorboardX import SummaryWriter 14 | 15 | from options import args_parser 16 | from update import LocalUpdate, update_model_inplace, test_inference 17 | from utils import get_model, get_dataset, average_weights, exp_details, average_parameter_delta 18 | import utils 19 | from sklearn.cluster import kmeans_plusplus 20 | import faiss 21 | import scipy.stats as st 22 | import math 23 | def AMP_DA(y, X, H): 24 | H = H.cuda() 25 | y = y.cuda() 26 | X = X.cuda() 27 | 28 | N_RAs = H.shape[0] 29 | N_UEs = H.shape[1] 30 | N_dim = y.shape[1] 31 | N_M = y.shape[2] 32 | tol = 1e-5 33 | exx = 1e-10 34 | damp = 0.3 35 | alphabet = torch.arange(0.0, args.num_users * args.frac + 1, 1) 36 | M = len(alphabet) - 1 37 | 38 | lam = N_RAs / N_UEs 39 | c = torch.arange(0.01, 10, 10 / 1024) 40 | rho = (1 - 2 * N_UEs * ((1 + c ** 2) * st.norm.cdf(-c) - c * st.norm.pdf(c)) / N_RAs) / ( 41 | 1 + c ** 2 - 2 * ((1 + c ** 2) * st.norm.cdf(-c) - c * st.norm.pdf(c))) 42 | alpha = lam * torch.max(rho) * torch.ones((N_UEs, N_dim)) 43 | x_hat0 = (alpha * torch.sum(alphabet) / M * torch.ones((N_UEs, N_dim)))[:, :, None].repeat(1, 1, N_M) 44 | x_hat = (x_hat0 + 1j * x_hat0).cuda() 45 | var_hat = torch.ones((N_UEs, N_dim, N_M)).cuda() 46 | 47 | V = torch.ones((N_RAs, N_dim, N_M)).cuda() 48 | V_new = torch.ones((N_RAs, N_dim, N_M)).cuda() 49 | Z_new = y.clone() 50 | sigma2 = 100 51 | t = 1 52 | Z = y.clone() 53 | maxIte = 50 54 | MSE = torch.zeros(maxIte) 55 | MSE[0] = 100 56 | hvar = (torch.norm(y) ** 2 - N_RAs * sigma2) / (N_dim * lam * torch.max(rho) * torch.norm(H) ** 2) 57 | hmean = 0 58 | alpha_new = torch.ones((N_UEs, N_dim, N_M)) 59 | x_hat_new = torch.ones((N_UEs, N_dim, N_M)) + 1j * torch.ones((N_UEs, N_dim, N_M)) 60 | var_hat_new = torch.ones((N_UEs, N_dim, N_M)) 61 | 62 | hvarnew = torch.zeros(N_M) 63 | hmeannew = torch.zeros(N_M) + 1j * torch.zeros(N_M) 64 | sigma2new = torch.zeros(N_M) 65 | 66 | alphabet = alphabet.cuda() 67 | alpha = alpha.cuda() 68 | while t < maxIte: 69 | x_hat_pre = x_hat.clone() 70 | for i in range(N_M): 71 | V_new[:, :, i] = torch.abs(H) ** 2 @ var_hat[:, :, i] 72 | Z_new[:, :, i] = H @ x_hat[:, :, i] - ((y[:, :, i] - Z[:, :, i]) / (sigma2 + V[:, :, i])) * V_new[:, :, i] # + 1e-8 73 | 74 | Z_new[:, :, i] = damp * Z[:, :, i] + (1 - damp) * Z_new[:, :, i] 75 | V_new[:, :, i] = damp * V[:, :, i] + (1 - damp) * V_new[:, :, i] 76 | 77 | var1 = (torch.abs(H) ** 2).T @ (1 / (sigma2 + V_new[:, :, i])) 78 | var2 = H.conj().T @ ((y[:, :, i] - Z_new[:, :, i]) / (sigma2 + V_new[:, :, i])) 79 | 80 | Ri = var2 / (var1) + x_hat[:, :, i] 81 | Vi = 1 / (var1) 82 | 83 | sigma2new[i] = ((torch.abs(y[:, :, i] - Z_new[:, :, i]) ** 2) / ( 84 | torch.abs(1 + V_new[:, :, i] / sigma2) ** 2) + sigma2 * V_new[:, :, i] / ( 85 | V_new[:, :, i] + sigma2)).mean() 86 | 87 | if i == 0: 88 | r_s = Ri[None, :, :].repeat(M + 1, 1, 1) - alphabet[:, None, None].repeat(1, N_UEs, N_dim) 89 | pf8 = torch.exp(-(torch.abs(r_s) ** 2 / Vi)) / Vi / math.pi 90 | pf7 = torch.zeros((M + 1, N_UEs, N_dim)).cuda() 91 | pf7[0, :, :] = pf8[0, :, :] * (torch.ones((N_UEs, N_dim)).cuda() - alpha) 92 | pf7[1:, :, :] = pf8[1:, :, :] * (alpha / M) 93 | del pf8 94 | PF7 = torch.sum(pf7, axis=0) 95 | pf6 = pf7 / PF7 96 | del pf7, PF7 97 | AAA = alphabet[None, :, None].repeat(N_dim, 1, 1) 98 | BBB = torch.permute(pf6,(2,1,0)) 99 | x_hat_new[:, :, i] = (torch.einsum("ijk,ikn->ijn", BBB, AAA).squeeze(-1)).T 100 | del AAA 101 | alphabet2 = alphabet ** 2 102 | AAA2 = alphabet2[None, :, None].repeat(N_dim, 1, 1) 103 | var_hat_new[:, :, i] = (torch.einsum("ijk,ikn->ijn", BBB, AAA2).squeeze(-1)).T.cpu() - torch.abs( 104 | x_hat_new[:, :, i]) ** 2 105 | del AAA2 106 | alpha_new[:, :, i] = torch.clamp(torch.sum(pf6[1:, :, :], axis=0), exx, 1 - exx) 107 | del pf6 108 | else: 109 | A = (hvar * Vi) / (Vi + hvar) 110 | B = (hvar * Ri + Vi * hmean) / (Vi + hvar) 111 | lll = torch.log(Vi / (Vi + hvar)) / 2 + torch.abs(Ri) ** 2 / 2 / Vi - torch.abs(Ri - hmean) ** 2 / 2 / ( 112 | Vi + hvar) 113 | pai = torch.clamp(alpha / (alpha + (1 - alpha) * torch.exp(-lll)), exx, 1 - exx, out=None) 114 | x_hat_new[:, :, i] = pai * B 115 | var_hat_new[:, :, i] = (pai * (torch.abs(B) ** 2 + A)).cpu() - torch.abs(x_hat_new[:, :, i]) ** 2 116 | # mean update 117 | hmeannew[i] = (torch.sum(pai * B, axis=0) / torch.sum(pai, axis=0)).mean() 118 | # variance update 119 | hvarnew[i] = (torch.sum(pai * (torch.abs(hmean - B) ** 2 + Vi), axis=0) / torch.sum(pai, axis=0)).mean() 120 | # activity indicator update 121 | alpha_new[:, :, i] = torch.clamp(pai, exx, 1 - exx) 122 | if N_M > 1: 123 | hvar = hvarnew[1:].mean() 124 | hmean = hmeannew[1:].mean() 125 | sigma2 = sigma2new.mean() 126 | alpha = (torch.sum(alpha_new, axis=2) / N_M).cuda() 127 | # alpha = alpha_new 128 | III = x_hat_pre.cpu() - x_hat_new 129 | NMSE_iter = torch.sum(torch.abs(III) ** 2) / torch.sum(torch.abs(x_hat_new) ** 2) 130 | # del III 131 | MSE[t] = torch.sum(torch.abs(y - torch.permute( 132 | torch.einsum("ijk,ikn->ijn", torch.permute(x_hat, (2, 1, 0)), H.T[None, :, :].repeat(N_M, 1, 1)), 133 | (2, 1, 0))) ** 2) / N_RAs / N_dim / N_M 134 | 135 | x_hat = x_hat_new.cuda().clone() 136 | if t > 15 and MSE[t] >= MSE[t - 1]: 137 | x_hat = x_hat_pre.clone() 138 | break 139 | 140 | NMSE = 10 * math.log10(torch.sum(torch.abs(x_hat[:, :, 0] - X[:, :, 0]) ** 2) / torch.sum(torch.abs(X[:, :, 0]) ** 2)) 141 | 142 | var_hat = var_hat_new.cuda().clone() 143 | # alpha = alpha_new 144 | V = V_new.clone() 145 | Z = Z_new.clone() 146 | t = t + 1 147 | return x_hat, var_hat, alpha, t, NMSE 148 | 149 | 150 | def UMA_MIMO(MatInd, args): 151 | tau = 1 # factor of channel imperfection 152 | Na = args.antenna # Number of antennas at BS 153 | Ka = MatInd.shape[0] 154 | h_a = (np.random.randn(Ka, Na).astype(np.float32) + 1j * np.random.randn(Ka, Na).astype(np.float32)) / np.sqrt(2) 155 | e = (np.random.randn(Ka).astype(np.float32) + 1j * np.random.randn(Ka).astype(np.float32)) / np.sqrt(2) 156 | h_e = tau*h_a[:, 0] + np.sqrt(1-tau)*e 157 | h_d = (1 / h_e)*h_a.T 158 | if tau == 1: 159 | h_d[0, :] = np.ones([1, Ka]).astype(np.float32) 160 | Np = MatInd.shape[1] # Number of SMV problems 161 | 162 | Ph = abs(h_e) 163 | IdPh = np.where(Ph < 0.14) 164 | h_d[:, IdPh] = 0 165 | 166 | X_eq = np.zeros((args.M, Np, Na)).astype(np.float32) + 1j * np.zeros((args.M, Np, Na)).astype(np.float32) 167 | I1tmp = np.arange(Np) 168 | for i in range(Ka): 169 | I0tmp = MatInd[i, :] 170 | X_eq[I0tmp, I1tmp, :] = X_eq[I0tmp, I1tmp, :] + h_d[:, i] 171 | 172 | 173 | Y0 = np.einsum("ijk,ikn->ijn", X_eq.T, args.UM.T[np.newaxis, :, :].repeat(Na, axis=0)).T 174 | 175 | Ps = np.linalg.norm(Y0.reshape(1, -1), ord='fro') ** 2 / Np / args.L/ Na 176 | snr = 10 ** (args.SNR / 10) 177 | Pn = Ps / snr 178 | Y = Y0 + (math.sqrt(Pn / 2) * np.random.randn(args.L, Np, Na).astype(np.float32) + 1j * np.random.randn(args.L, Np, Na).astype(np.float32)) 179 | 180 | # if args.Algo == 'AMP': 181 | H = torch.tensor(args.UM) 182 | Y = torch.tensor(Y) 183 | X_eq = torch.tensor(X_eq) 184 | x_hat, var_hat, alpha, t, NMSE = AMP_DA(Y, X_eq, H) 185 | 186 | return x_hat, alpha, NMSE, X_eq[:, :, 0] 187 | 188 | if __name__ == '__main__': 189 | start_time = time.time() 190 | 191 | # parse args 192 | args = args_parser() 193 | args.seed = 42 194 | args.M = 2**6 # quantization levels 195 | args.Vb = 20 # diemnsion of each vector quantization 196 | args.iid = 0 # non i.i.d. data distribution 197 | args.L = 20 # length of each transmit codeword 198 | UMmat = args.UM[:args.L, :] # transmit codebook matrix 199 | args.UM = UMmat 200 | args.antenna = 2 # number of antennas at BS 201 | args.Nummm = 10000 # number of data samples for data split 202 | args.epochs = 1000 # number of global rounds 203 | args.model = 'resnet-s' # model name 204 | exp_details(args) 205 | # import os 206 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 207 | 208 | # define paths 209 | file_name = '/Results_MIMOchannel_seed{}_Na{}_L{}_{}_{}_llr[{}]_glr[{}]_Vb[{}]_le[{}]_bs[{}]_iid[{}]_Ql[{}]_frac[{}]_{}.pkl'.\ 210 | format(args.seed, args.antenna, args.L, args.model, args.optimizer, 211 | args.local_lr, args.lr, args.Vb, 212 | args.local_ep, args.local_bs, args.iid, args.M, args.frac, args.compressor) 213 | logger = SummaryWriter('./logs/'+file_name) 214 | 215 | device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else "cpu") 216 | torch.set_num_threads(1) # limit cpu use 217 | print ('-- pytorch version: ', torch.__version__) 218 | 219 | np.random.seed(args.seed) 220 | torch.manual_seed(args.seed) 221 | if device != 'cpu': 222 | torch.cuda.manual_seed(args.seed) 223 | 224 | if not os.path.exists(args.outfolder): 225 | os.mkdir(args.outfolder) 226 | 227 | # load dataset and user groups 228 | train_dataset, test_dataset, num_classes, user_groups = get_dataset(args) 229 | 230 | # Set the model to train and send it to device. 231 | global_model = get_model(args.model, args.dataset, train_dataset[0][0].shape, num_classes) 232 | global_model.to(device) 233 | global_model.train() 234 | 235 | momentum_buffer_list = [] 236 | exp_avgs = [] 237 | exp_avg_sqs = [] 238 | max_exp_avg_sqs = [] 239 | for i, p in enumerate(global_model.parameters()): 240 | momentum_buffer_list.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 241 | exp_avgs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 242 | exp_avg_sqs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 243 | max_exp_avg_sqs.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)+args.max_init) # 1e-2 244 | 245 | ### init error ------- 246 | e = [] 247 | for id in range(args.num_users): 248 | ei = [] 249 | for i, p in enumerate(global_model.parameters()): 250 | ei.append(torch.zeros_like(p.data.detach().clone(), dtype=torch.float, requires_grad=False)) 251 | e.append(ei) 252 | D = sum(p.numel() for p in global_model.parameters()) 253 | print('total dimension:', D) 254 | print('compressor:', args.compressor) 255 | 256 | # Check if divisibility condition is satisfied 257 | s_nExtraZero = D % args.Vb 258 | if s_nExtraZero != 0: 259 | s_nExtraZero = args.Vb - s_nExtraZero 260 | 261 | # initialize error feedback 262 | Qerr = torch.ones((int(args.frac * args.num_users), args.epochs)) 263 | 264 | # Training 265 | train_loss_sampled, train_loss, train_accuracy = [], [], [] 266 | test_loss, test_accuracy = [], [] 267 | start_time = time.time() 268 | NMSEtot = np.zeros(args.epochs) 269 | KEST0 = [] 270 | KEST =[] 271 | KACT = [] 272 | for epoch in tqdm(range(args.epochs)): 273 | ep_time = time.time() 274 | 275 | local_weights, local_params, local_losses = [], [], [] 276 | print(f'\n | Global Training Round : {epoch+1} |\n') 277 | 278 | 279 | par_before = [] 280 | for p in global_model.parameters(): # get trainable parameters 281 | par_before.append(p.data.detach().clone()) 282 | # this is to store parameters before update 283 | w0 = global_model.state_dict() # get all parameters, including batch normalization related ones 284 | 285 | 286 | global_model.train() 287 | m = max(int(args.frac * args.num_users), 1) 288 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 289 | 290 | # user for loop 291 | for ooo in range(len(idxs_users)): 292 | idx = idxs_users[ooo] 293 | local_model = LocalUpdate(args=args, dataset=train_dataset, 294 | idxs=user_groups[idx], logger=logger) 295 | 296 | w, p, loss = local_model.update_weights_local( 297 | model=copy.deepcopy(global_model), global_round=epoch) 298 | 299 | 300 | ####### add error feedback ####### 301 | delta = utils.sub_params(p, par_before) 302 | tmp = utils.add_params(e[idx], delta) 303 | 304 | for i in range(len(tmp)): 305 | if i == 0: 306 | VecTmp = tmp[i].cpu().reshape(1, -1) 307 | else: 308 | VecTmp = torch.cat((VecTmp, tmp[i].cpu().reshape(1, -1)), 1) 309 | VecTmp = torch.cat((VecTmp, torch.zeros((1, s_nExtraZero))), 1) 310 | 311 | # Vector Quantization 312 | if ooo == 0: 313 | data2 = VecTmp.numpy().reshape(-1, args.Vb) 314 | centers, _ = kmeans_plusplus(data2, n_clusters=args.M, random_state=0) 315 | index = faiss.IndexFlatL2(args.Vb) 316 | index.add(centers) 317 | _, intIndex = index.search(VecTmp.numpy().reshape(-1, args.Vb), 1) 318 | SigVQ0 = torch.tensor(centers[intIndex].reshape(1, -1)) 319 | else: 320 | _, intIndex = index.search(VecTmp.numpy().reshape(-1, args.Vb), 1) 321 | SigVQ0 = torch.tensor(centers[intIndex].reshape(1, -1)) 322 | if ooo == 0: 323 | MatInd = intIndex.T 324 | else: 325 | MatInd = np.vstack([MatInd, intIndex.T]) 326 | del intIndex 327 | # Quantization error 328 | Qerr[ooo, epoch] = 10 * torch.log10(torch.sum((torch.abs(SigVQ0 - VecTmp) ** 2)) / torch.sum((torch.abs(VecTmp) ** 2))) 329 | if s_nExtraZero == 0: 330 | SigVQ = SigVQ0 331 | else: 332 | SigVQ = SigVQ0[:, :-s_nExtraZero] 333 | # Delta update 334 | delta_out = [] 335 | L = 0 336 | for i in range(len(tmp)): 337 | Ll = L + tmp[i].numel() 338 | delta_out.append(SigVQ[:, L:Ll].reshape(tmp[i].shape).cuda()) 339 | L = Ll 340 | del SigVQ 341 | # Error update 342 | e[idx] = utils.sub_params(tmp, delta_out) 343 | 344 | local_weights.append(copy.deepcopy(w)) 345 | local_losses.append(copy.deepcopy(loss)) 346 | 347 | # Transmission through MIMO channel 348 | x_hat, alpha, NMSE, X_eq = UMA_MIMO(MatInd, args) 349 | x_hat = x_hat.cpu().numpy() 350 | X_eq = X_eq.cpu().numpy() # actual transmitted signal 351 | 352 | Kact = int(np.real(np.sum(X_eq, 0)[0])) # actual number of active users 353 | KACT.append(Kact) 354 | 355 | NMSEtot[epoch] = NMSE 356 | temp = pd.Series(np.sum(abs(np.around(x_hat[:, :, 0])), axis=0)) 357 | Cont = temp.value_counts() 358 | Kest = int(Cont.keys()[0]) # estimated number of active users (proposed) 359 | KEST.append(Kest) 360 | Kest0 = int(np.mean(temp)) # estimated number of active users (benchmark) 361 | KEST0.append(Kest0) 362 | Est_delta = (np.abs(np.round(x_hat[:, :, 0].T)) @ centers).reshape(-1, 1).T / Kest 363 | 364 | # sparsity level for the AMP-DA algorithm 365 | Sptot = np.count_nonzero(X_eq, axis=0) 366 | if epoch == 0: 367 | SpLev = Sptot 368 | else: 369 | SpLev = np.vstack((SpLev, Sptot)) 370 | 371 | # Check if divisibility condition is satisfied 372 | if s_nExtraZero == 0: 373 | Est_delta = Est_delta 374 | else: 375 | Est_delta = Est_delta[:, :-s_nExtraZero] 376 | 377 | 378 | global_delta = [] 379 | L = 0 380 | for i in range(len(tmp)): 381 | Ll = L + tmp[i].numel() 382 | global_delta.append(torch.tensor(Est_delta[:, L:Ll]).reshape(tmp[i].shape).cuda()) 383 | L = Ll 384 | 385 | bn_weights = average_weights(local_weights) 386 | global_model.load_state_dict(bn_weights) 387 | 388 | 389 | update_model_inplace( 390 | global_model, par_before, global_delta, args, epoch, 391 | momentum_buffer_list, exp_avgs, exp_avg_sqs, max_exp_avg_sqs) 392 | 393 | # report and store loss and accuracy 394 | # this is local training loss on sampled users 395 | loss_avg = sum(local_losses) / len(local_losses) 396 | train_loss.append(loss_avg) 397 | 398 | global_model.eval() 399 | 400 | # Test inference after completion of training 401 | test_acc, test_ls = test_inference(args, global_model, test_dataset) 402 | test_accuracy.append(test_acc) 403 | test_loss.append(test_ls) 404 | 405 | # print global training loss after every rounds 406 | print('Epoch Run Time: {0:0.4f} of {1} global rounds'.format(time.time()-ep_time, epoch+1)) 407 | print(f'Training Loss : {train_loss[-1]}') 408 | print(f'Test Loss : {test_loss[-1]}') 409 | print(f'Test Accuracy : {test_accuracy[-1]} \n') 410 | print(f'NMSE UMA : {NMSE} \n') 411 | logger.add_scalar('train loss', train_loss[-1], epoch) 412 | logger.add_scalar('test loss', test_loss[-1], epoch) 413 | logger.add_scalar('test acc', test_accuracy[-1], epoch) 414 | 415 | if args.save: 416 | # Saving the objects train_loss and train_accuracy: 417 | 418 | 419 | with open(args.outfolder + file_name, 'wb') as f: 420 | pickle.dump([train_loss, test_loss, test_accuracy, Qerr, SpLev, NMSEtot, KEST, KEST0, KACT], f) 421 | 422 | print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time)) 423 | 424 | 425 | --------------------------------------------------------------------------------