├── images └── unlearning.png ├── requirements.txt ├── README.md ├── LICENSE └── utils ├── model.py ├── utils.py ├── local_train.py └── fusion.py /images/unlearning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/federated-unlearning/HEAD/images/unlearning.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | adversarial-robustness-toolbox==1.13.0 2 | jupyter==1.0.0 3 | matplotlib==3.3.4 4 | numpy==1.22 5 | torch==1.13.1 6 | torchvision==0.14.1 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Federated Unlearning 2 | This repo contains the implementation of the work described in [Federated Unlearning: How to Efficiently Erase a Client in FL?](https://arxiv.org/pdf/2207.05521.pdf) 3 | 4 | ## Acknowledgement 5 | 6 | This work was supported by European Union’s Horizon 2020 research and innovation programme under grant number 951911 – AI4Media. 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 International Business Machines 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
# ---------------------------------------------------------------------------
# Imports used by utils/model.py, utils/utils.py, utils/local_train.py and
# utils/fusion.py (abc and copy are retained for the fusion classes).
# ---------------------------------------------------------------------------
import abc
import copy

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


# ---------------------------------------------------------------------------
# utils/model.py
# ---------------------------------------------------------------------------
## Note: This model is taken from McMahan et al. FL paper
class FLNet(nn.Module):
    """Small CNN: two 5x5 convolutions followed by two linear layers.

    ``conv1`` takes a single input channel and ``fc1`` expects ``64 * 7 * 7``
    features, which is what two 2x2 max-poolings produce from 28x28 inputs.
    The network outputs 10 raw scores (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return the 10 output scores for a batch ``x`` of shape (N, 1, H, W)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # flatten keeps the batch dimension explicit, so an empty batch works
        # too (``view(-1, ...)`` cannot infer the size for zero elements).
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# ---------------------------------------------------------------------------
# utils/utils.py
# ---------------------------------------------------------------------------
class Utils():
    """Stateless helpers for comparing and evaluating models."""

    @staticmethod
    def get_distance(model1, model2):
        """Return the squared L2 distance between the parameters of two models.

        Both models are flattened with ``parameters_to_vector``, so they must
        have the same total number of parameters. The result is a 0-dim tensor.
        """
        with torch.no_grad():
            flat1 = nn.utils.parameters_to_vector(model1.parameters())
            flat2 = nn.utils.parameters_to_vector(model2.parameters())
            return torch.square(torch.norm(flat1 - flat2))

    @staticmethod
    def get_distances_from_current_model(current_model, party_models):
        """Return a float64 array with the squared distance to each party model."""
        return np.array(
            [float(Utils.get_distance(current_model, party_model))
             for party_model in party_models],
            dtype=np.float64,
        )

    @staticmethod
    def evaluate(testloader, model):
        """Return the accuracy of ``model`` on ``testloader`` as a percentage.

        ``testloader`` yields ``(images, labels)`` pairs. The model is switched
        to eval mode and is left in that mode on return.

        Raises:
            ValueError: if ``testloader`` yields no samples (accuracy would be
                a division by zero).
        """
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if total == 0:
            raise ValueError("testloader yielded no samples to evaluate on")
        return 100 * correct / total


# ---------------------------------------------------------------------------
# utils/local_train.py
# ---------------------------------------------------------------------------
class LocalTraining():
    """Base class for local (per-party) training.

    Args:
        num_updates_in_epoch: optional cap on the number of optimizer steps.
            When set, training runs for a single epoch and stops after this
            many batches (values below 1 behave like 1).
        num_local_epochs: number of passes over the loader when no cap is set.
    """

    def __init__(self,
                 num_updates_in_epoch=None,
                 num_local_epochs=1):
        self.name = "local-training"
        self.num_updates = num_updates_in_epoch
        self.num_local_epochs = num_local_epochs

    def train(self, model, trainloader, criterion=None, opt=None, lr=1e-2):
        """Train ``model`` in place on ``trainloader``.

        Args:
            model: module to train; it is put in train mode and updated in place.
            trainloader: iterable yielding ``(inputs, targets)`` batches.
            criterion: loss callable; defaults to ``nn.CrossEntropyLoss()``.
            opt: optimizer; defaults to SGD with momentum 0.9.
            lr: learning rate for the default optimizer. Ignored when ``opt``
                is supplied.

        Returns:
            ``(model, mean_loss)`` where ``mean_loss`` is the average batch
            loss over every optimizer step that was taken.

        Raises:
            ValueError: if no batch was processed (empty loader or zero epochs).
        """
        if criterion is None:
            criterion = nn.CrossEntropyLoss()
        if opt is None:
            opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

        # An update cap implies a single epoch. Keep that decision local so
        # the configured number of epochs on the instance is not overwritten.
        capped = self.num_updates is not None
        num_epochs = 1 if capped else self.num_local_epochs

        model.train()
        total_loss = 0.0
        num_steps = 0
        for _ in range(num_epochs):
            for x_batch, y_batch in trainloader:
                opt.zero_grad()
                loss = criterion(model(x_batch), y_batch)
                loss.backward()
                opt.step()

                total_loss += loss.item()
                num_steps += 1

                # Stop after exactly ``num_updates`` steps.
                if capped and num_steps >= self.num_updates:
                    break

        if num_steps == 0:
            raise ValueError("no training batches were processed")
        # Average over all steps taken, across every epoch.
        return model, total_loss / num_steps
# Fusion: base class for fusion algorithms
# FusionAvg: compute average across all parties
# FusionRetrain: compute average across all parties except the target one


class Fusion(abc.ABC):
    """Base class for fusion (model aggregation) algorithms.

    Args:
        num_parties: number of parties whose models are aggregated.
    """

    def __init__(self, num_parties):
        self.name = "fusion"
        self.num_parties = num_parties

    def average_selected_models(self, selected_parties, party_models):
        """Average the parameters of the selected party models.

        Args:
            selected_parties: indices into ``party_models`` to average over.
            party_models: sequence of models with identical architecture.

        Returns:
            The ``state_dict`` of a fresh model whose parameters are the
            element-wise mean of the selected models' parameters. The input
            models are not modified.

        Raises:
            ValueError: if ``selected_parties`` is empty.
        """
        selected = list(selected_parties)
        if not selected:
            raise ValueError("cannot average an empty selection of parties")

        with torch.no_grad():
            total = None
            for party in selected:
                flat = nn.utils.parameters_to_vector(party_models[party].parameters())
                total = flat if total is None else total + flat
            mean_vec = total / len(selected)

            # Use a *selected* party as the template so that non-parameter
            # state (e.g. buffers) of an excluded party never ends up in the
            # aggregated state dict.
            template = copy.deepcopy(party_models[selected[0]])
            nn.utils.vector_to_parameters(mean_vec, template.parameters())
        return template.state_dict()

    @abc.abstractmethod
    def fusion_algo(self, party_models, current_model=None):
        """Aggregate ``party_models`` and return the resulting state dict."""
        raise NotImplementedError


class FusionAvg(Fusion):
    """Plain average over all ``num_parties`` models."""

    def __init__(self, num_parties):
        super().__init__(num_parties)
        self.name = "Fusion-Average"

    def fusion_algo(self, party_models, current_model=None):
        """Return the state dict of the average of every party's model.

        ``current_model`` is accepted for interface compatibility and unused.
        """
        return self.average_selected_models(list(range(self.num_parties)), party_models)


class FusionRetrain(Fusion):
    """Average over all parties except the one being erased.

    Args:
        num_parties: total number of parties, including the erased one.
        erased_party_id: index of the party to leave out. Defaults to 0,
            the party the original implementation always excluded.
    """

    def __init__(self, num_parties, erased_party_id=0):
        super().__init__(num_parties)
        self.name = "Fusion-Retrain"
        if not 0 <= erased_party_id < num_parties:
            raise ValueError(
                f"erased_party_id must be in [0, {num_parties}), got {erased_party_id}"
            )
        self.erased_party_id = erased_party_id

    def fusion_algo(self, party_models, current_model=None):
        """Return the state dict of the average of the remaining parties' models.

        ``current_model`` is accepted for interface compatibility and unused.
        """
        remaining = [i for i in range(self.num_parties) if i != self.erased_party_id]
        return self.average_selected_models(remaining, party_models)