├── .gitignore
├── FLModel.py
├── LICENSE
├── MLModel.py
├── README.md
├── __init__.py
├── rdp_analysis.py
├── test_cnn.ipynb
├── test_scatter_linear.ipynb
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /.idea/
2 | /FedCF/
3 | .ipynb_checkpoints/
4 | __pycache__/
5 |
--------------------------------------------------------------------------------
/FLModel.py:
--------------------------------------------------------------------------------
1 | # Federated Learning Model in PyTorch
2 | import torch
3 | from torch import nn
4 | from torch.utils.data import Dataset, DataLoader, TensorDataset
5 | from utils import gaussian_noise
6 | from rdp_analysis import calibrating_sampled_gaussian
7 |
8 | from MLModel import *
9 |
10 | import numpy as np
11 | import copy
12 |
13 |
14 | class FLClient(nn.Module):
15 |     """ Client of Federated Learning framework.
16 |     1. Receive global model from server
17 |     2. Perform local training (compute gradients)
18 |     3. Return local model (gradients) to server
19 |     """
20 |     def __init__(self, model, output_size, data, lr, E, batch_size, q, clip, sigma, device=None):
21 |         """
22 |         :param model: ML model class to instantiate, or the string 'scatter' for the ScatterNet model
23 |         :param data: (tuple) local dataset; all data on the client side are used as training data
24 |         :param lr: learning rate
25 |         :param E: number of local epochs per communication round
26 |         """
27 |         super(FLClient, self).__init__()
28 |         self.device = device
29 |         self.BATCH_SIZE = batch_size
30 |         self.torch_dataset = TensorDataset(torch.tensor(data[0]),
31 |                                            torch.tensor(data[1]))
32 |         self.data_size = len(self.torch_dataset)
33 |         self.data_loader = DataLoader(
34 |             dataset=self.torch_dataset,
35 |             batch_size=self.BATCH_SIZE,
36 |             shuffle=True
37 |         )
38 |         self.sigma = sigma    # DP noise multiplier
39 |         self.lr = lr
40 |         self.E = E
41 |         self.clip = clip
42 |         self.q = q
43 |         if model == 'scatter':
44 |             self.model = ScatterLinear(81, (7, 7), input_norm="GroupNorm", num_groups=27).to(self.device)
45 |         else:
46 |             self.model = model(data[0].shape[1], output_size).to(self.device)
47 |
48 |     def recv(self, model_param):
49 |         """receive global model from aggregator (server)"""
50 |         self.model.load_state_dict(copy.deepcopy(model_param))
51 |
52 |     def update(self):
53 |         """local model update with DP-SGD"""
54 |         self.model.train()
55 |         criterion = nn.CrossEntropyLoss(reduction='none')
56 |         optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
57 |         # optimizer = torch.optim.Adam(self.model.parameters())
58 |
59 |         for e in range(self.E):
60 |             # Poisson subsampling: select each sample independently with probability q,
61 |             # as required by the privacy analysis (moments accountant / RDP);
62 |             # the selected samples form one training "lot"
63 |             idx = np.where(np.random.rand(len(self.torch_dataset[:][0])) < self.q)[0]
64 |
65 |             sampled_dataset = TensorDataset(self.torch_dataset[idx][0], self.torch_dataset[idx][1])
66 |             sample_data_loader = DataLoader(
67 |                 dataset=sampled_dataset,
68 |                 batch_size=self.BATCH_SIZE,
69 |                 shuffle=True
70 |             )
71 |
72 |             optimizer.zero_grad()
73 |
74 |             clipped_grads = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}
75 |             for batch_x, batch_y in sample_data_loader:
76 |                 batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
77 |                 pred_y = self.model(batch_x.float())
78 |                 loss = criterion(pred_y, batch_y.long())
79 |
80 |                 # bound the l2 sensitivity (per-sample gradient clipping):
81 |                 # clip each individual gradient in the lot
82 |                 for i in range(loss.size()[0]):
83 | 
loss[i].backward(retain_graph=True)
84 |                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip)
85 |                     for name, param in self.model.named_parameters():
86 |                         clipped_grads[name] += param.grad
87 |                     self.model.zero_grad()
88 |
89 |             # add calibrated Gaussian noise to the accumulated clipped gradients
90 |             for name, param in self.model.named_parameters():
91 |                 clipped_grads[name] += gaussian_noise(clipped_grads[name].shape, self.clip, self.sigma, device=self.device)
92 |
93 |             # scale back: divide by the expected lot size q * |D|
94 |             for name, param in self.model.named_parameters():
95 |                 clipped_grads[name] /= (self.data_size*self.q)
96 |
97 |             for name, param in self.model.named_parameters():
98 |                 param.grad = clipped_grads[name]
99 |
100 |             # update local model
101 |             optimizer.step()
102 |
103 |
104 |
105 | class FLServer(nn.Module):
106 |     """ Server of Federated Learning
107 |     1. Receive models (or gradients) from clients
108 |     2. Aggregate local models (or gradients)
109 |     3. Compute global model, broadcast global model to clients
110 |     """
111 |     def __init__(self, fl_param):
112 |         super(FLServer, self).__init__()
113 |         self.device = fl_param['device']
114 |         self.client_num = fl_param['client_num']
115 |         self.C = fl_param['C']   # (float) fraction of clients selected per round, C in (0, 1]
116 |         self.clip = fl_param['clip']
117 |         self.T = fl_param['tot_T']   # total number of global iterations (communication rounds)
118 |
119 |         self.data = []
120 |         self.target = []
121 |         for sample in fl_param['data'][self.client_num:]:
122 |             self.data += [torch.tensor(sample[0]).to(self.device)]   # test set
123 |             self.target += [torch.tensor(sample[1]).to(self.device)]   # target label
124 |
125 |         self.input_size = int(self.data[0].shape[1])
126 |         self.lr = fl_param['lr']
127 |
128 |         # compute noise using moments accountant
129 |         # self.sigma = compute_noise(1, fl_param['q'], fl_param['eps'], fl_param['E']*fl_param['tot_T'], fl_param['delta'], 1e-5)
130 |
131 |         # calibrate noise for the subsampled Gaussian mechanism under composition
132 |         self.sigma = calibrating_sampled_gaussian(fl_param['q'], fl_param['eps'], fl_param['delta'], iters=fl_param['E']*fl_param['tot_T'], err=1e-3)
133 |         print("noise scale =", self.sigma)
134 |
135 |         self.clients = [FLClient(fl_param['model'],
136 |                                  fl_param['output_size'],
137 |                                  fl_param['data'][i],
138 |                                  fl_param['lr'],
139 |                                  fl_param['E'],
140 |                                  fl_param['batch_size'],
141 |                                  fl_param['q'],
142 |                                  fl_param['clip'],
143 |                                  self.sigma,
144 |                                  self.device)
145 |                         for i in range(self.client_num)]
146 |
147 |         if fl_param['model'] == 'scatter':
148 |             self.global_model = ScatterLinear(81, (7, 7), input_norm="GroupNorm", num_groups=27).to(self.device)
149 |         else:
150 |             self.global_model = fl_param['model'](self.input_size, fl_param['output_size']).to(self.device)
151 |
152 |         self.weight = np.array([client.data_size * 1.0 for client in self.clients])
153 |         self.broadcast(self.global_model.state_dict())
154 |
155 |     def aggregated(self, idxs_users):
156 |         """FedAvg aggregation over the selected clients"""
157 |         model_par = [self.clients[idx].model.state_dict() for idx in idxs_users]
158 |         new_par = copy.deepcopy(model_par[0])
159 |         for name in new_par:
160 |             new_par[name] = torch.zeros(new_par[name].shape).to(self.device)
161 |         for idx, par in enumerate(model_par):
162 |             w = self.weight[idxs_users[idx]] / np.sum(self.weight[:])
163 |             for name in new_par:
164 |                 # new_par[name] += par[name] * (self.weight[idxs_users[idx]] / np.sum(self.weight[idxs_users]))
165 |                 new_par[name] += par[name] * (w / self.C)   # the selected weights sum to ~C, hence the division by C
166 |         self.global_model.load_state_dict(copy.deepcopy(new_par))
167 |         return self.global_model.state_dict().copy()
168 |
169 |     def broadcast(self, 
new_par): 170 | """Send aggregated model to all clients""" 171 | for client in self.clients: 172 | client.recv(new_par.copy()) 173 | 174 | def test_acc(self): 175 | self.global_model.eval() 176 | correct = 0 177 | tot_sample = 0 178 | for i in range(len(self.data)): 179 | t_pred_y = self.global_model(self.data[i]) 180 | _, predicted = torch.max(t_pred_y, 1) 181 | correct += (predicted == self.target[i]).sum().item() 182 | tot_sample += self.target[i].size(0) 183 | acc = correct / tot_sample 184 | return acc 185 | 186 | def global_update(self): 187 | # idxs_users = np.random.choice(range(len(self.clients)), int(self.C * len(self.clients)), replace=False) 188 | idxs_users = np.sort(np.random.choice(range(len(self.clients)), int(self.C * len(self.clients)), replace=False)) 189 | for idx in idxs_users: 190 | self.clients[idx].update() 191 | self.broadcast(self.aggregated(idxs_users)) 192 | acc = self.test_acc() 193 | torch.cuda.empty_cache() 194 | return acc 195 | 196 | def set_lr(self, lr): 197 | for c in self.clients: 198 | c.lr = lr 199 | 200 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 jyfan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MLModel.py: -------------------------------------------------------------------------------- 1 | # Machine learning models 2 | import torch 3 | from torch import nn 4 | from kymatio.torch import Scattering2D 5 | 6 | class MNIST_CNN(nn.Module): 7 | """ 8 | End-to-end CNN model for MNIST and Fashion-MNIST, with Tanh activations. 9 | References: 10 | - Papernot, Nicolas, et al. Tempered Sigmoid Activations for Deep Learning with Differential Privacy. In AAAI 2021. 11 | - Tramer, Florian, and Dan Boneh. Differentially Private Learning Needs Better Features (or Much More Data). In ICLR 2021. 
12 |     """
13 |     def __init__(self, input_dim, output_dim):
14 |         super(MNIST_CNN, self).__init__()
15 |
16 |         self.layer1 = nn.Sequential(
17 |             nn.Conv2d(1, 16, kernel_size=8, stride=2, padding=2),
18 |             nn.Tanh(),
19 |             nn.MaxPool2d(kernel_size=2, stride=1))
20 |
21 |         self.layer2 = nn.Sequential(
22 |             nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=0),
23 |             nn.Tanh(),
24 |             nn.MaxPool2d(kernel_size=2, stride=1))
25 |
26 |         self.fc = nn.Sequential(nn.Linear(4 * 4 * 32, 32),
27 |                                 nn.Tanh(),
28 |                                 nn.Linear(32, 10))
29 |
30 |     def forward(self, x):
31 |         x = self.layer1(x)
32 |         x = self.layer2(x)
33 |         x = x.view(x.size(0), -1)
34 |         x = self.fc(x)
35 |         return x
36 |
37 |
38 | def get_scatter_transform():
39 |     shape = (28, 28, 1)
40 |     scattering = Scattering2D(J=2, shape=shape[:2])
41 |     K = 81 * shape[2]
42 |     (h, w) = shape[:2]
43 |     return scattering, K, (h//4, w//4)
44 |
45 |
46 | class ScatterLinear(nn.Module):
47 |     """
48 |     ScatterNet model used in the following paper
49 |     - Tramer, Florian, and Dan Boneh. Differentially Private Learning Needs Better Features (or Much More Data). In ICLR 2021.
50 |     See https://github.com/ftramer/Handcrafted-DP/blob/main/models.py
51 |     """
52 |     def __init__(self, in_channels, hw_dims, input_norm=None, classes=10, clip_norm=None, **kwargs):
53 |         super(ScatterLinear, self).__init__()
54 |         self.K = in_channels
55 |         self.h = hw_dims[0]
56 |         self.w = hw_dims[1]
57 |         self.fc = None
58 |         self.norm = None
59 |         self.clip = None
60 |         self.build(input_norm, classes=classes, clip_norm=clip_norm, **kwargs)
61 |
62 |     def build(self, input_norm=None, num_groups=None, bn_stats=None, clip_norm=None, classes=10):
63 |         self.fc = nn.Linear(self.K * self.h * self.w, classes)
64 |
65 |         if input_norm is None:
66 |             self.norm = nn.Identity()
67 |         elif input_norm == "GroupNorm":
68 |             self.norm = nn.GroupNorm(num_groups, self.K, affine=False)
69 |         else:
70 |             self.norm = lambda x: standardize(x, bn_stats)   # NOTE: standardize is not defined in this repo; see the Handcrafted-DP code above
71 |
72 |         if clip_norm is None:
73 |             self.clip = nn.Identity()
74 |         else:
75 |             self.clip = ClipLayer(clip_norm)   # NOTE: ClipLayer is not defined in this repo; see the Handcrafted-DP code above
76 |
77 |     def forward(self, x):
78 |         x = self.norm(x.view(-1, self.K, self.h, self.w))
79 |         x = self.clip(x)
80 |         x = x.reshape(x.size(0), -1)
81 |         x = self.fc(x)
82 |         return x
83 |
84 |
85 | class LogisticRegression(nn.Module):
86 |     """Logistic regression"""
87 |     def __init__(self, num_feature, output_size):
88 |         super(LogisticRegression, self).__init__()
89 |         self.linear = nn.Linear(num_feature, output_size)
90 |
91 |     def forward(self, x):
92 |         return self.linear(x)
93 |
94 |
95 | class MLP(nn.Module):
96 |     """MLP with one hidden layer"""
97 |     def __init__(self, input_dim, output_dim):
98 |         super(MLP, self).__init__()
99 |         self.model = nn.Sequential(
100 |             nn.Linear(input_dim, 1000),
101 |             nn.Tanh(),
102 |
103 |             nn.Linear(1000, output_dim))
104 |
105 |     def forward(self, x):
106 |         return self.model(x)
107 |
108 |
109 | class three_layer_MLP(nn.Module):
110 |     """MLP with three hidden layers"""
111 |     def __init__(self, input_dim, output_dim):
112 |         super(three_layer_MLP, self).__init__()
113 |         self.model = nn.Sequential(
114 |             nn.Linear(input_dim, 600),
115 |             nn.Dropout(0.2),
116 |             nn.ReLU(),
117 |
118 |             nn.Linear(600, 300),
119 |             nn.Dropout(0.2),
120 |             nn.ReLU(),
121 |
122 |             nn.Linear(300, 100),
123 |             nn.Dropout(0.2),
124 |             nn.ReLU(),
125 |
126 |             nn.Linear(100, output_dim))
127 |
128 |     def forward(self, x):
129 |         return self.model(x)
130 |
131 |
132 | class MnistCNN_(nn.Module):
133 |     def __init__(self, input_dim, output_dim):
134 |         super(MnistCNN_, self).__init__()
135 |         self.layer1 = nn.Sequential(
136 |             nn.Conv2d(1, 16, kernel_size=5, padding=2),
137 |             nn.ReLU(),
138 |             nn.MaxPool2d(2))
139 |
140 |         self.layer2 = nn.Sequential(
141 |             nn.Conv2d(16, 32, kernel_size=5, padding=2),
142 |             nn.ReLU(),
143 |             nn.MaxPool2d(2))
144 |
145 |         self.fc = nn.Linear(7 * 7 * 32, 10)
146 |
147 |     def forward(self, x):
148 |         x = self.layer1(x)
149 |         x = self.layer2(x)
150 |         x = x.view(x.size(0), -1)
151 |         x = self.fc(x)
152 |         return x
153 |
154 |
155 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Federated Learning
2 |
3 | This is an implementation of **Federated Learning (FL)** with **Differential Privacy (DP)**. The FL algorithm is FedAvg, based on the paper [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629). Each client trains its local model with DP-SGD [2], which clips per-sample gradients and perturbs them with Gaussian noise. The noise multiplier is calibrated following [3-5] (see rdp_analysis.py).
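4 |
5 | The noise multiplier for a given privacy budget can be reproduced on its own; the snippet below is a minimal sketch of what `FLServer.__init__` computes internally (parameter values taken from test_cnn.ipynb):
6 |
7 | ```python
8 | from rdp_analysis import calibrating_sampled_gaussian
9 |
10 | # q: sampling rate, eps/delta: DP budget, iters = E * tot_T local iterations
11 | sigma = calibrating_sampled_gaussian(q=0.01, eps=4.0, bad_event=1e-5, iters=500 * 10, err=1e-3)
12 | print(sigma)  # ~1.077 with these parameters (cf. the notebooks)
13 | ```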
14 |
15 | ## Requirements
16 | - torch, torchvision
17 | - numpy
18 | - scipy
19 |
20 | ## Files
21 | > FLModel.py: definition of the FL client and FL server classes
22 |
23 | > MLModel.py: CNN, ScatterNet, MLP, and logistic regression models for the MNIST datasets
24 |
25 | > rdp_analysis.py: RDP of the subsampled Gaussian mechanism [3], converted to (epsilon, delta)-DP via Refs. [4, 5] (a tighter privacy analysis than [2])
26 |
27 | > utils.py: sample MNIST in a non-i.i.d. manner
28 |
29 | ## Usage
30 | Run test_cnn.ipynb (end-to-end CNN) or test_scatter_linear.ipynb (ScatterNet features + a linear classifier).
31 |
32 | ### FL model parameters
33 | ```python
34 | # code segment in test_cnn.ipynb
35 | lr = 0.15
36 | fl_param = {
37 |     'output_size': 10,          # number of units in the output layer
38 |     'client_num': client_num,   # number of clients
39 |     'model': MNIST_CNN,         # model class
40 |     'data': d,                  # dataset
41 |     'lr': lr,                   # learning rate
42 |     'E': 500,                   # number of local iterations per round
43 |     'C': 1,                     # fraction of clients selected per round
44 |     'eps': 4.0,                 # privacy budget epsilon
45 |     'delta': 1e-5,              # approximate differential privacy: (epsilon, delta)-DP
46 |     'q': 0.01,                  # sampling rate
47 |     'clip': 0.1,                # gradient clipping bound
48 |     'tot_T': 10,                # number of aggregation times (communication rounds)
49 |     'batch_size': 128,          # batch size
50 |     'device': device            # torch device
51 | }
52 | ```
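53 |
54 | The server construction and training loop then look as follows (condensed from test_cnn.ipynb; `d` and `device` are defined earlier in the notebook):
55 |
56 | ```python
57 | fl_entity = FLServer(fl_param).to(device)
58 |
59 | acc = []
60 | for t in range(fl_param['tot_T']):
61 |     acc += [fl_entity.global_update()]   # one round: local DP-SGD updates + FedAvg aggregation
62 |     print("global epochs = {:d}, acc = {:.4f}".format(t + 1, acc[-1]))
63 | ```
64 |
65 |
66 | ## References
67 | [1] McMahan, Brendan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-Efficient Learning of Deep Networks from Decentralized Data. In *AISTATS*, 2017.
68 |
69 | [2] Abadi, Martin, et al. Deep Learning with Differential Privacy. In *CCS*, 2016.
70 |
71 | [3] Mironov, Ilya, Kunal Talwar, and Li Zhang. Rényi Differential Privacy of the Sampled Gaussian Mechanism. arXiv preprint, 2019.
72 |
73 | [4] Canonne, Clément L., Gautam Kamath, and Thomas Steinke. The Discrete Gaussian for Differential Privacy. In *NeurIPS*, 2020.
74 |
75 | [5] Asoodeh, S., Liao, J., Calmon, F.P., Kosut, O., and Sankar, L. A Better Bound Gives a Hundred Rounds: Enhanced Privacy Guarantees via f-Divergences. In *ISIT*, 2020.
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # __init__.py
2 |
--------------------------------------------------------------------------------
/rdp_analysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from decimal import *
3 | from scipy.special import comb
4 |
5 | getcontext().prec = 128
6 |
7 |
8 | def rdp2dp(rdp, bad_event, alpha):
9 |     """
10 |     convert RDP to DP, Ref:
11 |     - Canonne, Clément L., Gautam Kamath, and Thomas Steinke. 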
The discrete gaussian for differential privacy. In NeurIPS, 2020. (See Proposition 12) 12 | - Asoodeh, S., Liao, J., Calmon, F.P., Kosut, O. and Sankar, L., A better bound gives a hundred rounds: Enhanced privacy guarantees via f-divergences. In ISIT, 2020. (See Lemma 1) 13 | """ 14 | return rdp + 1.0/(alpha-1) * (np.log(1.0/bad_event) + (alpha-1)*np.log(1-1.0/alpha) - np.log(alpha)) 15 | 16 | 17 | def compute_rdp(alpha, q, sigma): 18 | """ 19 | RDP for subsampled Gaussian mechanism, Ref: 20 | - Mironov, Ilya, Kunal Talwar, and Li Zhang. R\'enyi differential privacy of the sampled gaussian mechanism. arXiv preprint 2019. 21 | """ 22 | sum_ = Decimal(0.0) 23 | for k in range(0, alpha+1): 24 | sum_ += Decimal(comb(alpha, k)) * Decimal(1-q)**Decimal(alpha-k) * Decimal(q**k) * Decimal(np.e)**(Decimal(k**2-k)/Decimal(2*sigma**2)) 25 | rdp = sum_.ln() / Decimal(alpha-1) 26 | return float(rdp) 27 | 28 | 29 | def search_dp(q, sigma, bad_event, iters=1): 30 | """ 31 | Given the sampling rate, variance of Gaussian noise, and privacy parameter delta, 32 | this function returns the corresponding DP budget. 33 | """ 34 | min_dp = 1e5 35 | for alpha in list(range(2, 101)): 36 | rdp = iters * compute_rdp(alpha, q, sigma) 37 | dp = rdp2dp(rdp, bad_event, alpha) 38 | min_dp = min(min_dp, dp) 39 | return min_dp 40 | 41 | 42 | def calibrating_sampled_gaussian(q, eps, bad_event, iters=1, err=1e-3): 43 | """ 44 | Calibrate noise to privacy budgets 45 | """ 46 | sigma_max = 100 47 | sigma_min = 0.1 48 | 49 | def binary_search(left, right): 50 | mid = (left + right) / 2 51 | 52 | lbd = search_dp(q, mid, bad_event, iters) 53 | ubd = search_dp(q, left, bad_event, iters) 54 | 55 | if ubd > eps and lbd > eps: # min noise & mid noise are too small 56 | left = mid 57 | elif ubd > eps and lbd < eps: # mid noise is too large 58 | right = mid 59 | else: 60 | print("an error occurs in func: binary search!") 61 | return -1 62 | return left, right 63 | 64 | # check 65 | if search_dp(q, sigma_max, bad_event, iters) > eps: 66 | print("noise > 100") 67 | return -1 68 | 69 | while sigma_max-sigma_min > err: 70 | sigma_min, sigma_max = binary_search(sigma_min, sigma_max) 71 | return sigma_max -------------------------------------------------------------------------------- /test_cnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Application of FL task\n", 12 | "from MLModel import *\n", 13 | "from FLModel import *\n", 14 | "from utils import *\n", 15 | "\n", 16 | "from torchvision import datasets, transforms\n", 17 | "import torch\n", 18 | "import numpy as np\n", 19 | "import os\n", 20 | "\n", 21 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", 22 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 23 | "#device = torch.device(\"cpu\")" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "def load_cnn_mnist(num_users):\n", 33 | " train = datasets.MNIST(root=\"~/data/\", train=True, download=True, transform=transforms.ToTensor())\n", 34 | " train_data = train.data.float().unsqueeze(1)\n", 35 | " train_label = train.targets\n", 36 | "\n", 37 | " mean = train_data.mean()\n", 38 | " std = train_data.std()\n", 39 | " train_data = (train_data - mean) / std\n", 40 | "\n", 41 | " test = datasets.MNIST(root=\"~/data/\", 
train=False, download=True, transform=transforms.ToTensor())\n", 42 | " test_data = test.data.float().unsqueeze(1)\n", 43 | " test_label = test.targets\n", 44 | " test_data = (test_data - mean) / std\n", 45 | "\n", 46 | " # split MNIST (training set) into non-iid data sets\n", 47 | " non_iid = []\n", 48 | " user_dict = mnist_noniid(train_label, num_users)\n", 49 | " for i in range(num_users):\n", 50 | " idx = user_dict[i]\n", 51 | " d = train_data[idx]\n", 52 | " targets = train_label[idx].float()\n", 53 | " non_iid.append((d, targets))\n", 54 | " non_iid.append((test_data.float(), test_label.float()))\n", 55 | " return non_iid" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "\"\"\"\n", 65 | "1. load_data\n", 66 | "2. generate clients (step 3)\n", 67 | "3. generate aggregator\n", 68 | "4. training\n", 69 | "\"\"\"\n", 70 | "client_num = 4\n", 71 | "d = load_cnn_mnist(client_num)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "sigma = 1.0771102905273438\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "\"\"\"\n", 89 | "FL model parameters.\n", 90 | "\"\"\"\n", 91 | "import warnings\n", 92 | "warnings.filterwarnings(\"ignore\")\n", 93 | "\n", 94 | "lr = 0.15\n", 95 | "\n", 96 | "fl_param = {\n", 97 | " 'output_size': 10,\n", 98 | " 'client_num': client_num,\n", 99 | " 'model': MNIST_CNN,\n", 100 | " 'data': d,\n", 101 | " 'lr': lr,\n", 102 | " 'E': 500,\n", 103 | " 'C': 1,\n", 104 | " 'eps': 4.0,\n", 105 | " 'delta': 1e-5,\n", 106 | " 'q': 0.01,\n", 107 | " 'clip': 0.1,\n", 108 | " 'tot_T': 10,\n", 109 | " 'batch_size': 128,\n", 110 | " 'device': device\n", 111 | "}\n", 112 | "\n", 113 | "fl_entity = FLServer(fl_param).to(device)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": { 120 | "scrolled": true 121 | }, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "global epochs = 1, acc = 0.5342 Time taken: 302.74s\n", 128 | "global epochs = 2, acc = 0.7952 Time taken: 606.97s\n", 129 | "global epochs = 3, acc = 0.8969 Time taken: 910.56s\n", 130 | "global epochs = 4, acc = 0.9250 Time taken: 1211.34s\n", 131 | "global epochs = 5, acc = 0.9406 Time taken: 1516.22s\n", 132 | "global epochs = 6, acc = 0.9513 Time taken: 1818.05s\n", 133 | "global epochs = 7, acc = 0.9544 Time taken: 2125.24s\n", 134 | "global epochs = 8, acc = 0.9579 Time taken: 2427.08s\n", 135 | "global epochs = 9, acc = 0.9615 Time taken: 2741.97s\n", 136 | "global epochs = 10, acc = 0.9633 Time taken: 3044.31s\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "import time\n", 142 | "\n", 143 | "acc = []\n", 144 | "start_time = time.time()\n", 145 | "for t in range(fl_param['tot_T']):\n", 146 | " acc += [fl_entity.global_update()]\n", 147 | " print(\"global epochs = {:d}, acc = {:.4f}\".format(t+1, acc[-1]), \" Time taken: %.2fs\" % (time.time() - start_time))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "# SGD (mnt=0.9)" 157 | ] 158 | } 159 | ], 160 | "metadata": { 161 | "kernelspec": { 162 | "display_name": "Python 3 (ipykernel)", 163 | "language": "python", 164 | "name": "python3" 165 | }, 166 | "language_info": { 167 | "codemirror_mode": { 168 | "name": "ipython", 169 | "version": 3 
170 | }, 171 | "file_extension": ".py", 172 | "mimetype": "text/x-python", 173 | "name": "python", 174 | "nbconvert_exporter": "python", 175 | "pygments_lexer": "ipython3", 176 | "version": "3.10.9" 177 | } 178 | }, 179 | "nbformat": 4, 180 | "nbformat_minor": 4 181 | } 182 | -------------------------------------------------------------------------------- /test_scatter_linear.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Application of FL task\n", 12 | "from MLModel import *\n", 13 | "from FLModel import *\n", 14 | "from utils import *\n", 15 | "\n", 16 | "from torchvision import datasets, transforms\n", 17 | "import torch\n", 18 | "import numpy as np\n", 19 | "import os\n", 20 | "\n", 21 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", 22 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 23 | "#device = torch.device(\"cpu\")" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "scattering, K, (h, w) = get_scatter_transform()\n", 33 | "scattering.to(device)\n", 34 | "\n", 35 | "def get_scattered_feature(dataset):\n", 36 | " scatters = []\n", 37 | " targets = []\n", 38 | " \n", 39 | " loader = torch.utils.data.DataLoader(\n", 40 | " dataset, batch_size=256, shuffle=True, num_workers=1, pin_memory=True)\n", 41 | "\n", 42 | " \n", 43 | " for (data, target) in loader:\n", 44 | " data, target = data.to(device), target.to(device)\n", 45 | " if scattering is not None:\n", 46 | " data = scattering(data)\n", 47 | " scatters.append(data)\n", 48 | " targets.append(target)\n", 49 | "\n", 50 | " scatters = torch.cat(scatters, axis=0)\n", 51 | " targets = torch.cat(targets, axis=0)\n", 52 | "\n", 53 | " data = torch.utils.data.TensorDataset(scatters, targets)\n", 54 | " return data\n", 55 | "\n", 56 | "def load_mnist(num_users):\n", 57 | " train = datasets.MNIST(root=\"~/data/\", train=True, download=True, transform=transforms.ToTensor())\n", 58 | " test = datasets.MNIST(root=\"~/data/\", train=False, download=True, transform=transforms.ToTensor())\n", 59 | " \n", 60 | " # get scattered features\n", 61 | " train = get_scattered_feature(train)\n", 62 | " test = get_scattered_feature(test)\n", 63 | " \n", 64 | " train_data = train[:][0].squeeze().cpu().float()\n", 65 | " train_label = train[:][1].cpu()\n", 66 | " \n", 67 | " test_data = test[:][0].squeeze().cpu().float()\n", 68 | " test_label = test[:][1].cpu()\n", 69 | "\n", 70 | " # split MNIST (training set) into non-iid data sets\n", 71 | " non_iid = []\n", 72 | " user_dict = mnist_noniid(train_label, num_users)\n", 73 | " for i in range(num_users):\n", 74 | " idx = user_dict[i]\n", 75 | " d = train_data[idx]\n", 76 | " targets = train_label[idx].float()\n", 77 | " non_iid.append((d, targets))\n", 78 | " non_iid.append((test_data.float(), test_label.float()))\n", 79 | " return non_iid" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 3, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "\"\"\"\n", 89 | "1. load_data\n", 90 | "2. generate clients (step 3)\n", 91 | "3. generate aggregator\n", 92 | "4. 
training\n", 93 | "\"\"\"\n", 94 | "client_num = 4\n", 95 | "d = load_mnist(client_num)\n", 96 | "\n", 97 | "torch.cuda.empty_cache()" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 4, 103 | "metadata": { 104 | "tags": [] 105 | }, 106 | "outputs": [ 107 | { 108 | "data": { 109 | "text/plain": [ 110 | "torch.Size([81, 7, 7])" 111 | ] 112 | }, 113 | "execution_count": 4, 114 | "metadata": {}, 115 | "output_type": "execute_result" 116 | } 117 | ], 118 | "source": [ 119 | "d[1][0][0].shape" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 5, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "noise scale = 1.0771102905273438\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "\"\"\"\n", 137 | "FL model parameters.\n", 138 | "\"\"\"\n", 139 | "import warnings\n", 140 | "warnings.filterwarnings(\"ignore\")\n", 141 | "\n", 142 | "lr = 0.075\n", 143 | "\n", 144 | "fl_param = {\n", 145 | " 'output_size': 10,\n", 146 | " 'K': K,\n", 147 | " 'h': h,\n", 148 | " 'w': w,\n", 149 | " 'client_num': client_num,\n", 150 | " 'model': 'scatter',\n", 151 | " 'data': d,\n", 152 | " 'lr': lr,\n", 153 | " 'E': 500,\n", 154 | " 'C': 1,\n", 155 | " 'eps': 4.0,\n", 156 | " 'delta': 1e-5,\n", 157 | " 'q': 0.01,\n", 158 | " 'clip': 0.1,\n", 159 | " 'tot_T': 10,\n", 160 | " 'batch_size': 128,\n", 161 | " 'device': device\n", 162 | "}\n", 163 | "\n", 164 | "fl_entity = FLServer(fl_param).to(device)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 6, 170 | "metadata": { 171 | "scrolled": true 172 | }, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "global epochs = 1, acc = 0.8842 Time taken: 161.98s\n", 179 | "global epochs = 2, acc = 0.9348 Time taken: 322.99s\n", 180 | "global epochs = 3, acc = 0.9546 Time taken: 486.85s\n", 181 | "global epochs = 4, acc = 0.9600 Time taken: 648.92s\n", 182 | "global epochs = 5, acc = 0.9657 Time taken: 807.26s\n", 183 | "global epochs = 6, acc = 0.9666 Time taken: 959.25s\n", 184 | "global epochs = 7, acc = 0.9704 Time taken: 1109.23s\n", 185 | "global epochs = 8, acc = 0.9712 Time taken: 1257.56s\n", 186 | "global epochs = 9, acc = 0.9739 Time taken: 1400.09s\n", 187 | "global epochs = 10, acc = 0.9742 Time taken: 1538.17s\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "import time\n", 193 | "\n", 194 | "acc = []\n", 195 | "start_time = time.time()\n", 196 | "for t in range(fl_param['tot_T']):\n", 197 | " acc += [fl_entity.global_update()]\n", 198 | " print(\"global epochs = {:d}, acc = {:.4f}\".format(t+1, acc[-1]), \" Time taken: %.2fs\" % (time.time() - start_time))" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "# SGD (mnt=0.9)" 208 | ] 209 | } 210 | ], 211 | "metadata": { 212 | "kernelspec": { 213 | "display_name": "Python 3 (ipykernel)", 214 | "language": "python", 215 | "name": "python3" 216 | }, 217 | "language_info": { 218 | "codemirror_mode": { 219 | "name": "ipython", 220 | "version": 3 221 | }, 222 | "file_extension": ".py", 223 | "mimetype": "text/x-python", 224 | "name": "python", 225 | "nbconvert_exporter": "python", 226 | "pygments_lexer": "ipython3", 227 | "version": "3.10.9" 228 | } 229 | }, 230 | "nbformat": 4, 231 | "nbformat_minor": 4 232 | } 233 | -------------------------------------------------------------------------------- /utils.py: 
--------------------------------------------------------------------------------
1 | """
2 | Useful tools
3 | """
4 | import numpy as np
5 | import random
6 | import torch
7 |
8 |
9 | def mnist_noniid(labels, num_users):
10 |     """
11 |     Sample non-i.i.d. client data from the MNIST dataset
12 |     :param labels: tensor of the 60,000 MNIST training labels
13 |     :param num_users: number of clients
14 |     :return: dict mapping each user id to an array of sample indices (3 label-sorted shards per user)
15 |     """
16 |     # num_shards, num_imgs = 30, 2000
17 |     num_shards = int(num_users*3)
18 |     num_imgs = int(60000 / num_shards)
19 |     idx_shard = [i for i in range(num_shards)]
20 |     dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
21 |     idxs = np.arange(num_shards*num_imgs)
22 |     labels = labels.numpy()
23 |
24 |     # sort labels
25 |     idxs_labels = np.vstack((idxs, labels))
26 |     idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
27 |     idxs = idxs_labels[0,:]
28 |
29 |     # divide and assign: each user receives 3 randomly chosen shards
30 |     for i in range(num_users):
31 |         rand_set = set(np.random.choice(idx_shard, 3, replace=False))
32 |         idx_shard = list(set(idx_shard) - rand_set)
33 |         for rand in rand_set:
34 |             dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
35 |     return dict_users
36 |
37 |
38 | def gaussian_noise(data_shape, s, sigma, device=None):
39 |     """
40 |     Gaussian noise for DP-SGD: each coordinate is drawn from N(0, (s * sigma)^2), where s is the clipping bound and sigma the noise multiplier
41 |     """
42 |     return torch.normal(0, sigma * s, data_shape).to(device)
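43 |
44 |
45 | if __name__ == "__main__":
46 |     # Illustrative sanity check (editor's sketch, not used by the training code):
47 |     # with the clipping bound and the calibrated noise multiplier from test_cnn.ipynb
48 |     # (clip = 0.1, sigma ~= 1.077), the per-coordinate noise std is s * sigma.
49 |     noise = gaussian_noise((100000,), s=0.1, sigma=1.077, device="cpu")
50 |     print(float(noise.std()))   # ~= 0.1077
--------------------------------------------------------------------------------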