├── utils
│   ├── __init__.py
│   ├── enums.py
│   ├── clients.py
│   ├── server.py
│   ├── utils.py
│   ├── meter.py
│   ├── dataloader.py
│   └── model.py
├── unlearn
│   ├── __init__.py
│   ├── flipping.py
│   ├── federaser.py
│   └── pga.py
├── .isort.cfg
├── .gitignore
├── .flake8
├── requirements.txt
├── result_sample
│   ├── config.txt
│   └── usage.ipynb
├── .pre-commit-config.yaml
├── README.md
├── config.py
├── case0.py
├── case1.py
├── case2.py
├── case3.py
├── case4.py
└── case5.py

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/unlearn/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | profile = black

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .idea/
3 | venv/
4 | data/

--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203, W503, E501, E266
3 | max-line-length = 88
4 | max-complexity = 13

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | adversarial-robustness-toolbox==1.13.0
2 | jupyter==1.0.0
3 | matplotlib==3.3.4
4 | numpy==1.22
5 | scikit-learn
6 | torch==1.13.1
7 | torchvision==0.14.1
8 | tqdm

--------------------------------------------------------------------------------
/result_sample/config.txt:
--------------------------------------------------------------------------------
1 | num_clients: 5
2 | batch_size: 128
3 | num_rounds
4 |   mnist: 50
5 |   cifar10: 100
6 | num_unlearn_rounds
7 |   mnist: 5
8 |   cifar10: 10
9 | num_post_training_rounds
10 |   mnist: 15
11 |   cifar10: 30
12 | poisoned_percent: 0.9
13 | lr: 1e-2

--------------------------------------------------------------------------------
/utils/enums.py:
--------------------------------------------------------------------------------
1 | class EnumBase:
2 |     @classmethod
3 |     def get_list(cls):
4 |         return [getattr(cls, attr) for attr in dir(cls) if attr.isupper()]
5 | 
6 | 
7 | class Cifar100(EnumBase):
8 |     MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
9 |     STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
10 | 

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/psf/black
3 |     rev: 22.10.0
4 |     hooks:
5 |       - id: black
6 | 
7 |   - repo: https://github.com/pycqa/isort
8 |     rev: 5.10.1
9 |     hooks:
10 |       - id: isort
11 | 
12 |   - repo: https://github.com/pycqa/flake8
13 |     rev: 5.0.4
14 |     hooks:
15 |       - id: flake8

--------------------------------------------------------------------------------
/unlearn/flipping.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | 
3 | from tqdm import tqdm
4 | 
5 | from utils import clients, server
6 | 
7 | 
8 | def unlearn(
9 |     args,
10 |     param,
11 |     loaders,
12 |     chosen_clients,
13 |     epochs=1,
14 |     lr=0.01,  # currently unused; clients.client_train reads the rate from args.lr
15 | ):
16 |     list_params = []
17 | 
18 |     for client in tqdm(chosen_clients):
19 |         print(f"-----------client {client} starts training----------")
20 | 
21 |         if client == 0:  # client 0 is the party being unlearned
22 |             print("-----------flip----------")
23 | 
24 |             tem_param, train_summ = clients.client_train(
25 |                 args, deepcopy(param), loaders[client], epochs=epochs, is_flip=True
26 |             )
27 |         else:
28 |             print("-----------not flip----------")
29 | 
30 |             tem_param, train_summ = clients.client_train(
31 |                 args, deepcopy(param), loaders[client], epochs=epochs, is_flip=False
32 |             )
33 | 
34 |         list_params.append(tem_param)
35 | 
36 |     # server aggregation
37 |     global_param = server.FedAvg(list_params)
38 | 
39 |     return global_param, train_summ  # train_summ is the summary of the last client only
40 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Structure of the repository
2 | - result_sample: This folder contains sample results and the usage file.
3 | - results: This folder stores all information generated when running an experiment.
4 | - unlearn: This folder contains the implementation of the unlearning methods.
5 | - utils: This folder contains the utility files for Federated Learning.
6 | 
7 | # How to reproduce the experiment results:
8 | - Step 1: Edit config.py to configure the experiment. Important factors include: dataset, num_rounds, num_unlearn_rounds, num_post_training_rounds, num_onboarding_rounds and poisoned_percent. Note that num_clients is currently fixed at 5.
9 | - Step 2: Create a folder named "models" in the results folder.
10 | - Step 3: In config.py, set is_onboarding to False and run case0.py. Then, run case1.py, case2.py, case3.py, case4.py and case5.py.
11 | - Step 4: In config.py, set is_onboarding to True and run case1.py, case2.py, case3.py, case4.py and case5.py again.
12 | - Step 5: Create the folders with_onboarding, without_onboarding and plot in the result_sample folder.
13 | - Step 6: Copy the generated pkl files from the results folder into the folder result_sample/with_onboarding.
14 | - Step 7: Adjust the configuration and run the cells in usage.ipynb in the result_sample folder.
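
For a quick look at a finished run without the notebook, the result pickles can also be loaded directly. A minimal sketch (the filename below is illustrative; real names follow the pattern assembled in config.py):

```python
import pickle

import matplotlib.pyplot as plt

# Illustrative path, built from the default configuration in config.py.
with open("results/case0_cifar100_C5_BS128_R20_UR2_PR30_E1_LR0.01.pkl", "rb") as fp:
    res = pickle.load(fp)

# Per-round accuracy of the global model on the clean and backdoored test sets.
plt.plot(res["val"]["acc"]["clean"], label="clean")
plt.plot(res["val"]["acc"]["backdoor"], label="backdoor")
plt.xlabel("round")
plt.ylabel("accuracy")
plt.legend()
plt.show()
```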
--------------------------------------------------------------------------------
/utils/clients.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from random import randint
3 | 
4 | import torch
5 | 
6 | from utils import meter
7 | from utils.model import get_model
8 | 
9 | 
10 | def client_train(args, param, loader, epochs=1, is_flip=False):
11 |     model = get_model(args)
12 |     model.load_state_dict(param)
13 | 
14 |     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
15 |     model.train()
16 | 
17 |     summ = meter.Meter()
18 | 
19 |     for epoch in range(epochs):
20 |         for data, target in loader:
21 |             data = data.to(args.device)
22 | 
23 |             if is_flip:  # overwrite every label in the batch with one random class
24 |                 if args.dataset != "cifar100":
25 |                     new_label = randint(0, 9)
26 |                 else:
27 |                     new_label = randint(0, 99)
28 |                 for i in range(len(target)):
29 |                     target[i] = new_label
30 | 
31 |             target = target.to(args.device)
32 | 
33 |             output = model(data)
34 |             loss = args.loss_fn(output, target)
35 | 
36 |             optimizer.zero_grad()
37 |             loss.backward()
38 |             optimizer.step()
39 | 
40 |             # Only keep track of the last epoch
41 |             if epoch == epochs - 1:
42 |                 summ.update(
43 |                     output.argmax(dim=1).detach().cpu(), target.cpu(), loss.item()
44 |                 )
45 | 
46 |     return deepcopy(model.cpu().state_dict()), summ.get()
47 | 

--------------------------------------------------------------------------------
/utils/server.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | 
3 | import torch
4 | 
5 | from utils import meter
6 | from utils.model import get_model
7 | 
8 | 
9 | def FedAvg(list_params):
10 |     agg_param = deepcopy(list_params[0])
11 |     for k in agg_param.keys():
12 |         agg_param[k] = torch.stack([param[k].float() for param in list_params], 0).mean(
13 |             0
14 |         )
15 |     return agg_param
16 | 
17 | 
18 | def test(args, param, loader, base_model_path=None):
19 |     """
20 |     Evaluate the scheme
21 |     - args: configuration
22 |     - param: model state dict
23 |     - loader: test set
24 |     - base_model_path: optional reference checkpoint used for similarity metrics
25 |     """
26 | 
27 |     model = get_model(args)
28 |     model.load_state_dict(param)
29 |     model.eval()
30 | 
31 |     if base_model_path is not None:
32 |         eval_metric = meter.EvaluationMetrics()
33 |         state = torch.load(base_model_path)
34 | 
35 |         base_model = get_model(args)
36 |         base_model.load_state_dict(state)
37 |         base_model.eval()
38 | 
39 |     summ = meter.Meter()
40 | 
41 |     with torch.no_grad():
42 |         for data, target in loader:
43 |             data = data.to(args.device)
44 |             target = target.to(args.device)
45 | 
46 |             output = model(data)
47 |             loss = args.loss_fn(output, target)
48 |             summ.update(output.argmax(dim=1).detach().cpu(), target.cpu(), loss.item())
49 | 
50 |             if base_model_path is not None:
51 |                 base_res = base_model(data)
52 | 
53 |                 # update the evaluation metric, comparing the similarity between
54 |                 # the fc2 layers of base_model and model
55 |                 eval_metric.update(base_res.cpu(), output.cpu(), base_model, model)
56 | 
57 |     if base_model_path is not None:
58 |         return summ.get(), eval_metric.get()
59 | 
60 |     return summ.get()
61 | 

--------------------------------------------------------------------------------
/unlearn/federaser.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def fed_eraser_one_step(
5 |     old_client_models,
6 |     new_client_models,
7 |     global_model_before_forget,
8 |     global_model_after_forget,
9 | ):
10 |     old_param_update = dict()  # oldCM - oldGM_t
11 |     new_param_update = dict()  # newCM - newGM_t
12 | 
13 |     new_global_model_state = global_model_after_forget  # newGM_t
14 |     return_model_state = (
15 |         dict()
16 |     )  # newGM_t + ||oldCM - oldGM_t||*(newCM - newGM_t)/||newCM - newGM_t||
17 | 
18 |     assert len(old_client_models) == len(new_client_models)
19 |     for layer in global_model_before_forget.keys():
20 |         old_param_update[layer] = 0 * global_model_before_forget[layer]
21 |         new_param_update[layer] = 0 * global_model_before_forget[layer]
22 |         return_model_state[layer] = 0 * global_model_before_forget[layer]
23 | 
24 |         for i in range(len(new_client_models)):
25 |             old_param_update[layer] += old_client_models[i][layer]
26 |             new_param_update[layer] += new_client_models[i][layer]
27 | 
28 |         old_param_update[layer] /= len(new_client_models)  # oldCM
29 |         new_param_update[layer] /= len(new_client_models)  # newCM
30 | 
31 |         old_param_update[layer] = (
32 |             old_param_update[layer] - global_model_before_forget[layer]
33 |         )  # oldCM - oldGM_t
34 |         new_param_update[layer] = (
35 |             new_param_update[layer] - global_model_after_forget[layer]
36 |         )  # newCM - newGM_t
37 | 
38 |         step_length = torch.norm(old_param_update[layer])  # ||oldCM - oldGM_t||
39 |         step_direction = new_param_update[layer] / torch.norm(
40 |             new_param_update[layer]
41 |         )  # (newCM - newGM_t)/||newCM - newGM_t||
42 | 
43 |         return_model_state[layer] = (
44 |             new_global_model_state[layer] + step_length * step_direction
45 |         )
46 | 
47 |     return return_model_state
48 | 

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | 
4 | import numpy as np
5 | import torch
6 | 
7 | np.random.seed(42)
8 | 
9 | 
10 | def get_args():
11 |     parser = argparse.ArgumentParser()
12 |     parser.add_argument("--dataset", type=str, default="cifar100")
13 |     parser.add_argument("--num_clients", type=int, default=5)
14 |     parser.add_argument("--batch_size", type=int, default=128)
15 |     parser.add_argument("--num_rounds", type=int, default=20)
16 |     parser.add_argument("--num_unlearn_rounds", type=int, default=2)
17 |     parser.add_argument("--num_post_training_rounds", type=int, default=30)
18 | 
19 |     # Note: argparse parses any non-empty string as True for type=bool flags,
20 |     # so toggle these by editing the defaults rather than on the command line.
21 |     parser.add_argument("--is_saving_client", type=bool, default=False)
22 | 
23 |     # onboarding
24 |     parser.add_argument("--is_onboarding", type=bool, default=True)
25 |     parser.add_argument("--num_onboarding_rounds", type=int, default=30)
26 | 
27 |     # backdoor
28 |     parser.add_argument("--poisoned_percent", type=float, default=0.9)
29 | 
30 |     parser.add_argument("--local_epochs", type=int, default=1)
31 |     parser.add_argument("--lr", type=float, default=1e-2)
32 | 
33 |     parser.add_argument("--saved", action="store_true")
34 |     parser.add_argument("--no_saved", dest="saved", action="store_false")
35 | 
36 |     parser.set_defaults(saved=True)
37 | 
38 |     args = parser.parse_args()
39 |     args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40 |     args.loss_fn = torch.nn.CrossEntropyLoss()
41 | 
42 |     # strip any leading path so that "python ./case0.py" still yields "case0"
43 |     case = sys.argv[0].split("/")[-1].split(".")[0]
44 | 
45 |     args.out_file = (
46 |         f"results/{case}_"
47 |         f"{args.dataset}_"
48 |         f"C{args.num_clients}_"
49 |         f"BS{args.batch_size}_"
50 |         f"R{args.num_rounds}_"
51 |         f"UR{args.num_unlearn_rounds}_"
52 |         f"PR{args.num_post_training_rounds}_"
53 |         f"E{args.local_epochs}_"
54 |         f"LR{args.lr}"
55 |         f".pkl"
56 |     )
57 | 
58 |     return args
59 | 
60 | 
61 | if __name__ == "__main__":
62 |     args = get_args()
63 | 
64 |     print(args.out_file)
65 | 
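# With the defaults above, running "python case0.py" would write, for example,
# results/case0_cifar100_C5_BS128_R20_UR2_PR30_E1_LR0.01.pkl (illustrative; the
# exact name tracks whatever values are configured). The case*.py scripts and
# unlearn/pga.py rebuild this same pattern when reloading checkpoints from
# results/models/, so the configuration must match across the initial training,
# unlearning and onboarding runs.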
-------------------------------------------------------------------------------- /case0.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | import config 11 | from utils import clients, server 12 | from utils.dataloader import get_loaders 13 | from utils.model import get_model 14 | from utils.utils import get_results, save_param, update_results 15 | 16 | np.random.seed(42) 17 | torch.manual_seed(42) 18 | torch.cuda.manual_seed(42) 19 | torch.backends.cudnn.enabled = False 20 | torch.backends.cudnn.deterministic = True 21 | 22 | 23 | if __name__ == "__main__": 24 | args = config.get_args() 25 | train_loaders, test_loader, test_loader_poison = get_loaders(args) 26 | 27 | model = get_model(args) 28 | global_param = model.state_dict() 29 | 30 | res = get_results(args) 31 | 32 | # train and evaluate the FL model 33 | num_rounds = args.num_rounds 34 | 35 | start_time = time.time() 36 | for round in range(num_rounds): 37 | print( 38 | "Round {}/{}: lr {} {}".format( 39 | round + 1, args.num_rounds, args.lr, args.out_file 40 | ) 41 | ) 42 | 43 | train_loss, test_loss = 0, 0 44 | train_corr, test_acc = 0, 0 45 | train_total = 0 46 | list_params = [] 47 | 48 | chosen_clients = [i for i in range(args.num_clients)] 49 | 50 | for client in tqdm(chosen_clients): 51 | print(f"-----------client {client} starts training----------") 52 | tem_param, train_summ = clients.client_train( 53 | args, 54 | deepcopy(global_param), 55 | train_loaders[client], 56 | epochs=args.local_epochs, 57 | ) 58 | 59 | save_param( 60 | args, 61 | param=tem_param, 62 | case=0, 63 | client=client, 64 | round=round, 65 | is_global=False, 66 | ) 67 | 68 | train_loss += train_summ["loss"] 69 | train_corr += train_summ["correct"] 70 | train_total += train_summ["total"] 71 | 72 | list_params.append(tem_param) 73 | 74 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 75 | res["train"]["acc"]["avg"].append(train_corr / train_total) 76 | 77 | print( 78 | "Train loss: {:5f} acc: {:5f}".format( 79 | res["train"]["loss"]["avg"][-1], 80 | res["train"]["acc"]["avg"][-1], 81 | ) 82 | ) 83 | 84 | # server aggregation 85 | global_param = server.FedAvg(list_params) 86 | 87 | save_param(args, param=global_param, case=0, round=round) 88 | 89 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 90 | 91 | total_time = time.time() - start_time 92 | res["time"] = total_time 93 | print(f"Time {total_time}") 94 | 95 | with open(args.out_file, "wb") as fp: 96 | pickle.dump(res, fp) 97 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from utils import server 9 | import pickle 10 | 11 | 12 | class Utils: 13 | @staticmethod 14 | def get_distance(model1, model2): 15 | with torch.no_grad(): 16 | model1_flattened = nn.utils.parameters_to_vector(model1.parameters()) 17 | model2_flattened = nn.utils.parameters_to_vector(model2.parameters()) 18 | distance = torch.square(torch.norm(model1_flattened - model2_flattened)) 19 | return distance 20 | 21 | @staticmethod 22 | def get_distances_from_current_model(current_model, party_models): 23 | num_updates = len(party_models) 24 | 
        distances = np.zeros(num_updates)
25 |         for i in range(num_updates):
26 |             distances[i] = Utils.get_distance(current_model, party_models[i])
27 |         return distances
28 | 
29 |     @staticmethod
30 |     def evaluate(testloader, model):
31 |         model.eval()
32 |         correct = 0
33 |         total = 0
34 |         with torch.no_grad():
35 |             for data in testloader:
36 |                 images, labels = data
37 |                 outputs = model(images)
38 |                 _, predicted = torch.max(outputs.data, 1)
39 |                 total += labels.size(0)
40 |                 correct += (predicted == labels).sum().item()
41 | 
42 |         return 100 * correct / total
43 | 
44 | 
45 | def get_results(args):
46 |     res = {}
47 |     for k1 in ("train", "val"):
48 |         res[k1] = {}
49 |         for k2 in ("loss", "acc"):
50 |             res[k1][k2] = {}
51 |             res[k1][k2]["avg"] = []
52 |             res[k1][k2]["clean"] = []
53 |             res[k1][k2]["backdoor"] = []
54 |             for k3 in range(args.num_clients):
55 |                 res[k1][k2][k3] = []
56 | 
57 |     return res
58 | 
59 | 
60 | def load_results(filename):
61 |     with open(filename, "rb") as fp:
62 |         data = pickle.load(fp)
63 | 
64 |     return data
65 | 
66 | 
67 | def save_client_param(args, param, case, client, round):
68 |     folder_path = f"./results/models/case{case}/client{client}"
69 |     os.makedirs(folder_path, exist_ok=True)
70 |     torch.save(
71 |         param,
72 |         f'{folder_path}/{args.out_file.split("/")[-1].split(".pkl")[0]}_round{round}.pt',
73 |     )
74 | 
75 | 
76 | def save_global_param(args, param, case, round):
77 |     folder_path = f"./results/models/case{case}"
78 |     os.makedirs(folder_path, exist_ok=True)
79 |     torch.save(
80 |         param,
81 |         f'{folder_path}/{args.out_file.split("/")[-1].split(".pkl")[0]}_round{round}.pt',
82 |     )
83 | 
84 | 
85 | def save_param(args, param, case, client=None, round=None, is_global=True):
86 |     if args.saved:
87 |         if is_global:
88 |             save_global_param(args, param, case, round)
89 |         else:
90 |             # Comment out the line below to skip saving the client models
91 |             save_client_param(args, param, case, client, round)
92 | 
93 | 
94 | def update_results(args, res, global_param, test_loader, test_loader_poison):
95 |     clean_test_summ = server.test(args, global_param, test_loader)
96 |     res["val"]["loss"]["clean"].append(clean_test_summ["loss"])
97 |     res["val"]["acc"]["clean"].append(
98 |         clean_test_summ["correct"] / clean_test_summ["total"]
99 |     )
100 | 
101 |     backdoor_test_summ = server.test(args, global_param, test_loader_poison)
102 |     res["val"]["loss"]["backdoor"].append(backdoor_test_summ["loss"])
103 |     res["val"]["acc"]["backdoor"].append(
104 |         backdoor_test_summ["correct"] / backdoor_test_summ["total"]
105 |     )
106 | 
107 |     print(f'Global clean accuracy: {res["val"]["acc"]["clean"][-1]}')
108 |     print(f'Global backdoor accuracy: {res["val"]["acc"]["backdoor"][-1]}')
109 | 
110 |     return res
111 | 

--------------------------------------------------------------------------------
/utils/meter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
4 | 
5 | import config
6 | 
7 | args = config.get_args()  # parsed at import time; not otherwise used in this module
8 | 
9 | 
10 | class Meter:
11 |     def __init__(self):
12 |         self.correct = 0
13 |         self.total = 0
14 |         self.losses = []
15 |         self.precision = 0
16 |         self.recall = 0
17 |         self.f1 = 0
18 | 
19 |     def update(self, pt, gt, loss):
20 |         """
21 |         pt: n x 1, integer predictions
22 |         gt: n x 1, integer ground-truth labels
23 |         loss: float scalar
24 |         """
25 | 
26 |         self.total += len(gt)
27 |         self.correct += (pt == gt).sum().item()
28 |         self.losses.append(loss)
29 | 
30 |     def get(self):
31 |         """
32 |         return a dict with:
33 |             "loss": averaged loss
34 |             "correct": number of correct predictions
35 |             "total": number of samples seen
36 |         """
37 |         avg_loss = sum(self.losses) / len(self.losses)
38 |         return {
39 |             "loss": avg_loss,
40 |             "correct": self.correct,
41 |             "total": self.total,
42 |         }
43 | 
44 | 
45 | class EvaluationMetrics:
46 |     """
47 |     Evaluate based on accuracy, precision, recall and f1_score
48 |     Evaluate based on the second fully connected layer (fc2)
49 |     """
50 | 
51 |     def __init__(self):
52 |         self.l2norm = torch.tensor([])
53 |         self.similarity = torch.tensor([])
54 |         self.sape = torch.tensor([])
55 |         self.fc2_similarity = torch.tensor([])
56 |         self.model_similarity = torch.tensor([])  # never filled in: see the commented block below
57 | 
58 |     def update(self, base, pt, base_model, pt_model):
59 |         self.l2norm = torch.cat([self.l2norm, (base - pt).square().sum(dim=1)])
60 |         self.similarity = torch.cat(
61 |             [self.similarity, torch.nn.CosineSimilarity(dim=1)(base, pt)]
62 |         )
63 |         self.sape = torch.cat(
64 |             [
65 |                 self.sape,
66 |                 (base - pt).abs().sum(dim=1)
67 |                 / (base.abs().sum(dim=1) + pt.abs().sum(dim=1)),
68 |             ]
69 |         )
70 | 
71 |         # # similarity between 2 models
72 |         # local_updates = [base_model, pt_model]
73 |         #
74 |         # update_flats = []
75 |         # for update in local_updates:
76 |         #     update_f = [var.flatten() for key, var in update.items()]
77 |         #     update_f = torch.cat(update_f)
78 |         #     update_flats.append(update_f)
79 |         #
80 |         # for i in range(len(update_flats)):
81 |         #     list1 = []
82 |         #     for j in range(len(update_flats)):
83 |         #         if i != j:
84 |         #             vector_a = np.array(update_flats[i].view(-1, 1).cpu())
85 |         #             vector_b = np.array(update_flats[j].view(-1, 1).cpu())
86 |         #
87 |         #             a = vector_a[vector_a < 1.0]
88 |         #             b = vector_b[vector_a < 1.0]
89 |         #
90 |         #             num = np.dot(a, b)
91 |         #             denom = np.linalg.norm(a) * np.linalg.norm(b)
92 |         #             cos = num / denom
93 |         #
94 |         #             # list1.append(cos)
95 |         #             self.model_similarity = torch.cat(
96 |         #                 [self.model_similarity, torch.tensor(cos)]
97 |         #             )
98 | 
99 |         # base_flat_weight = torch.flatten(base_model.fc2.weight)
100 |         # pt_flat_weight = torch.flatten(pt_model.fc2.weight)
101 | 
102 |         base_flat_weight = base_model.fc2.weight
103 |         pt_flat_weight = pt_model.fc2.weight
104 | 
105 |         self.fc2_similarity = torch.cat(
106 |             [
107 |                 self.fc2_similarity,
108 |                 torch.nn.CosineSimilarity(dim=1)(
109 |                     base_flat_weight.cpu(), pt_flat_weight.cpu()
110 |                 ),
111 |             ]
112 |         )
113 | 
114 |     def get(self):
115 |         return [
116 |             self.l2norm.mean().item(),
117 |             self.similarity.mean().item(),
118 |             self.sape.mean().item(),
119 |             self.fc2_similarity.mean().item(),
120 |             self.model_similarity,  # a tensor; empty unless the commented block is restored
121 |         ]
122 | 

--------------------------------------------------------------------------------
/unlearn/pga.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | 
3 | import numpy as np
4 | import torch
5 | from torch.nn.utils import clip_grad_norm_, parameters_to_vector, vector_to_parameters
6 | 
7 | from utils import meter
8 | from utils.model import get_model
9 | from utils.utils import Utils
10 | 
11 | 
12 | def compute_ref_vec(global_param, party0_param, num_parties):
13 |     model_ref_vec = num_parties / (num_parties - 1) * parameters_to_vector(
14 |         global_param
15 |     ) - 1 / (num_parties - 1) * parameters_to_vector(party0_param)
16 | 
17 |     return model_ref_vec
18 | 
19 | 
20 | def get_ref_vec(args):
21 |     global_param = torch.load(
22 | 
f"./results/models/case0/case0_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{args.num_rounds-1}.pt" 23 | ) 24 | party0_param = torch.load( 25 | f"./results/models/case0/client0/case0_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{args.num_rounds - 1}.pt" 26 | ) 27 | global_model = get_model(args) 28 | unlearn_client_model = get_model(args) 29 | 30 | global_model.load_state_dict(global_param) 31 | global_param = global_model.parameters() 32 | 33 | unlearn_client_model.load_state_dict(party0_param) 34 | party0_param = unlearn_client_model.parameters() 35 | 36 | num_parties = args.num_clients 37 | 38 | ref_param = compute_ref_vec(global_param, party0_param, num_parties) 39 | 40 | return ref_param 41 | 42 | 43 | def get_model_ref(args): 44 | model_ref_vec = get_ref_vec(args) 45 | model_ref = get_model(args) 46 | vector_to_parameters(model_ref_vec, model_ref.parameters()) 47 | 48 | return model_ref 49 | 50 | 51 | def get_threshold(args, model_ref): 52 | dist_ref_random_lst = [] 53 | for _ in range(10): 54 | random_model = get_model(args) 55 | dist_ref_random_lst.append(Utils.get_distance(model_ref, random_model).cpu()) 56 | 57 | threshold = np.mean(dist_ref_random_lst) / 3 58 | print(f"Radius for model_ref: {threshold}") 59 | return threshold 60 | 61 | 62 | def unlearn( 63 | args, 64 | param, 65 | param_ref, 66 | party0_param, 67 | distance_threshold, 68 | loader, 69 | threshold, 70 | clip_grad=1, 71 | epochs=1, 72 | lr=0.01, 73 | ): 74 | model = get_model(args) 75 | model.load_state_dict(param) 76 | 77 | model_ref = get_model(args) 78 | model_ref.load_state_dict(param_ref) 79 | 80 | party0_model = get_model(args) 81 | party0_model.load_state_dict(party0_param) 82 | 83 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) 84 | model.train() 85 | 86 | summ = meter.Meter() 87 | 88 | flag = False 89 | for epoch in range(epochs): 90 | if flag: 91 | break 92 | for data, target in loader: 93 | data = data.to(args.device) 94 | target = target.to(args.device) 95 | 96 | output = model(data) 97 | loss = args.loss_fn(output, target) 98 | 99 | optimizer.zero_grad() 100 | loss = -loss # negate the loss for gradient ascent 101 | loss.backward() 102 | if clip_grad > 0: 103 | clip_grad_norm_(model.parameters(), clip_grad) 104 | optimizer.step() 105 | 106 | with torch.no_grad(): 107 | distance = Utils.get_distance(model, model_ref) 108 | if distance > threshold: 109 | dist_vec = parameters_to_vector( 110 | model.parameters() 111 | ) - parameters_to_vector(model_ref.parameters()) 112 | dist_vec = dist_vec / torch.norm(dist_vec) * np.sqrt(threshold) 113 | proj_vec = parameters_to_vector(model_ref.parameters()) + dist_vec 114 | vector_to_parameters(proj_vec, model.parameters()) 115 | distance = Utils.get_distance(model, model_ref) 116 | 117 | distance_ref_party_0 = Utils.get_distance(model, party0_model) 118 | print( 119 | "Distance from the unlearned model to party 0:", 120 | distance_ref_party_0.item(), 121 | ) 122 | 123 | if distance_ref_party_0 > distance_threshold: 124 | flag = True 125 | summ.update( 126 | output.argmax(dim=1).detach().cpu(), target.cpu(), loss.item() 127 | ) 128 | break 129 | 130 | summ.update(output.argmax(dim=1).detach().cpu(), target.cpu(), loss.item()) 131 | 132 | return deepcopy(model.cpu().state_dict()), summ.get() 
133 | -------------------------------------------------------------------------------- /case1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | import config 11 | from utils import clients, server 12 | from utils.dataloader import get_loaders 13 | from utils.model import get_model 14 | from utils.utils import get_results, save_param, update_results, load_results 15 | 16 | np.random.seed(42) 17 | torch.manual_seed(42) 18 | torch.cuda.manual_seed(42) 19 | torch.backends.cudnn.enabled = False 20 | torch.backends.cudnn.deterministic = True 21 | 22 | """ 23 | Retrain 24 | """ 25 | 26 | if __name__ == "__main__": 27 | 28 | args = config.get_args() 29 | train_loaders, test_loader, test_loader_poison = get_loaders(args) 30 | 31 | model = get_model(args) 32 | global_param = model.state_dict() 33 | 34 | num_rounds = args.num_rounds 35 | num_unlearn_rounds = args.num_unlearn_rounds 36 | num_post_training_rounds = args.num_post_training_rounds 37 | num_onboarding_rounds = args.num_onboarding_rounds 38 | 39 | if not args.is_onboarding: 40 | start_time = time.time() 41 | 42 | # train and evaluate the FL model 43 | end_round = ( 44 | args.num_rounds + args.num_unlearn_rounds + args.num_post_training_rounds 45 | ) 46 | 47 | res = get_results(args) 48 | 49 | for round in range(end_round): 50 | if round == num_rounds: 51 | total_time = time.time() - start_time 52 | res["time"] = total_time 53 | print(f"Time {total_time}") 54 | 55 | # print(" -------------- saving time .... --------------") 56 | 57 | 58 | print( 59 | "Round {}/{}: lr {} {}".format( 60 | round + 1, end_round, args.lr, args.out_file 61 | ) 62 | ) 63 | 64 | train_loss, test_loss = 0, 0 65 | train_corr, test_acc = 0, 0 66 | train_total = 0 67 | list_params = [] 68 | 69 | chosen_clients = [i for i in range(1, args.num_clients)] 70 | 71 | for client in tqdm(chosen_clients): 72 | print(f"-----------client {client} starts training----------") 73 | tem_param, train_summ = clients.client_train( 74 | args, 75 | deepcopy(global_param), 76 | train_loaders[client], 77 | epochs=args.local_epochs, 78 | ) 79 | 80 | # save_param( 81 | # args, 82 | # param=tem_param, 83 | # case=1, 84 | # client=client, 85 | # round=round, 86 | # is_global=False, 87 | # ) 88 | 89 | train_loss += train_summ["loss"] 90 | train_corr += train_summ["correct"] 91 | train_total += train_summ["total"] 92 | 93 | list_params.append(tem_param) 94 | 95 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 96 | res["train"]["acc"]["avg"].append(train_corr / train_total) 97 | 98 | print( 99 | "Train loss: {:5f} acc: {:5f}".format( 100 | res["train"]["loss"]["avg"][-1], 101 | res["train"]["acc"]["avg"][-1], 102 | ) 103 | ) 104 | 105 | # server aggregation 106 | global_param = server.FedAvg(list_params) 107 | 108 | save_param(args, param=global_param, case=1, round=round) 109 | 110 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 111 | 112 | with open(args.out_file, "wb") as fp: 113 | pickle.dump(res, fp) 114 | else: 115 | ######################## onboarding round ############################ 116 | start_round = num_rounds + num_unlearn_rounds + num_post_training_rounds 117 | end_round = start_round + num_onboarding_rounds 118 | 119 | global_param = torch.load( 120 | 
f"./results/models/case1/case1_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{start_round-1}.pt" 121 | ) 122 | 123 | res = load_results( 124 | f"./results/case1_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}.pkl" 125 | ) 126 | 127 | for round in range(start_round, end_round): 128 | print( 129 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 130 | ) 131 | 132 | # print("debug check time:", res["time"]) 133 | 134 | train_loss, test_loss = 0, 0 135 | train_corr, test_acc = 0, 0 136 | train_total = 0 137 | list_params = [] 138 | 139 | chosen_clients = [i for i in range(args.num_clients)] 140 | 141 | for client in tqdm(chosen_clients): 142 | print(f"-----------client {client} starts training----------") 143 | tem_param, train_summ = clients.client_train( 144 | args, 145 | deepcopy(global_param), 146 | train_loaders[client], 147 | epochs=args.local_epochs, 148 | ) 149 | 150 | # save client params 151 | # save_param( 152 | # args, 153 | # param=tem_param, 154 | # case=1, 155 | # client=client, 156 | # round=round, 157 | # is_global=False, 158 | # ) 159 | 160 | train_loss += train_summ["loss"] 161 | train_corr += train_summ["correct"] 162 | train_total += train_summ["total"] 163 | 164 | list_params.append(tem_param) 165 | 166 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 167 | res["train"]["acc"]["avg"].append(train_corr / train_total) 168 | 169 | print( 170 | "Train loss: {:5f} acc: {:5f}".format( 171 | res["train"]["loss"]["avg"][-1], 172 | res["train"]["acc"]["avg"][-1], 173 | ) 174 | ) 175 | 176 | # server aggregation 177 | global_param = server.FedAvg(list_params) 178 | 179 | # save global param 180 | save_param(args, param=global_param, case=1, round=round) 181 | 182 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 183 | 184 | # print("debug check time:", res["time"]) 185 | 186 | with open(args.out_file, "wb") as fp: 187 | pickle.dump(res, fp) 188 | 189 | # total_time = time.time() - start_time 190 | # res["time"] = total_time 191 | # print(f"Time {total_time}") 192 | 193 | # with open(args.out_file, "wb") as fp: 194 | # pickle.dump(res, fp) 195 | -------------------------------------------------------------------------------- /case2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | import config 11 | from utils import clients, server 12 | from utils.dataloader import get_loaders 13 | from utils.model import get_model 14 | from utils.utils import get_results, save_param, update_results, load_results 15 | 16 | np.random.seed(42) 17 | torch.manual_seed(42) 18 | torch.cuda.manual_seed(42) 19 | torch.backends.cudnn.enabled = False 20 | torch.backends.cudnn.deterministic = True 21 | 22 | """ 23 | Continue training 24 | """ 25 | 26 | if __name__ == "__main__": 27 | 28 | args = config.get_args() 29 | train_loaders, test_loader, test_loader_poison = get_loaders(args) 30 | 31 | model = get_model(args) 32 | global_param = model.state_dict() 33 | 34 | num_rounds = args.num_rounds 35 | num_unlearn_rounds = args.num_unlearn_rounds 36 | num_post_training_rounds = 
args.num_post_training_rounds 37 | num_onboarding_rounds = args.num_onboarding_rounds 38 | 39 | if not args.is_onboarding: 40 | start_time = time.time() 41 | 42 | 43 | global_param = torch.load( 44 | f"./results/models/case0/case0_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{args.num_rounds-1}.pt" 45 | ) 46 | 47 | res = get_results(args) 48 | 49 | # train and evaluate the FL model 50 | end_round = ( 51 | num_rounds + num_unlearn_rounds + num_post_training_rounds 52 | ) 53 | 54 | for round in range(num_rounds, end_round): 55 | 56 | if round == num_rounds + num_unlearn_rounds: 57 | total_time = time.time() - start_time 58 | res["time"] = total_time 59 | print(f"Time {total_time}") 60 | 61 | print( 62 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 63 | ) 64 | 65 | train_loss, test_loss = 0, 0 66 | train_corr, test_acc = 0, 0 67 | train_total = 0 68 | list_params = [] 69 | 70 | chosen_clients = [i for i in range(1, args.num_clients)] 71 | 72 | for client in tqdm(chosen_clients): 73 | print(f"-----------client {client} starts training----------") 74 | tem_param, train_summ = clients.client_train( 75 | args, 76 | deepcopy(global_param), 77 | train_loaders[client], 78 | epochs=args.local_epochs, 79 | ) 80 | 81 | # save client params 82 | # save_param( 83 | # args, 84 | # param=tem_param, 85 | # case=2, 86 | # client=client, 87 | # round=round, 88 | # is_global=False, 89 | # ) 90 | 91 | train_loss += train_summ["loss"] 92 | train_corr += train_summ["correct"] 93 | train_total += train_summ["total"] 94 | 95 | list_params.append(tem_param) 96 | 97 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 98 | res["train"]["acc"]["avg"].append(train_corr / train_total) 99 | 100 | print( 101 | "Train loss: {:5f} acc: {:5f}".format( 102 | res["train"]["loss"]["avg"][-1], 103 | res["train"]["acc"]["avg"][-1], 104 | ) 105 | ) 106 | 107 | # server aggregation 108 | global_param = server.FedAvg(list_params) 109 | 110 | # save global param 111 | save_param(args, param=global_param, case=2, round=round) 112 | 113 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 114 | 115 | with open(args.out_file, "wb") as fp: 116 | pickle.dump(res, fp) 117 | else: 118 | ######################## onboarding round ############################ 119 | start_round = num_rounds + num_unlearn_rounds + num_post_training_rounds 120 | end_round = start_round + num_onboarding_rounds 121 | 122 | global_param = torch.load( 123 | f"./results/models/case2/case2_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{start_round-1}.pt" 124 | ) 125 | 126 | res = load_results( 127 | f"./results/case2_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}.pkl" 128 | ) 129 | 130 | for round in range(start_round, end_round): 131 | print( 132 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 133 | ) 134 | 135 | train_loss, test_loss = 0, 0 136 | train_corr, test_acc = 0, 0 137 | train_total = 0 138 | list_params = [] 139 | 140 | chosen_clients = [i for i in range(args.num_clients)] 141 | 142 | for client in tqdm(chosen_clients): 143 | print(f"-----------client {client} starts 
training----------") 144 | tem_param, train_summ = clients.client_train( 145 | args, 146 | deepcopy(global_param), 147 | train_loaders[client], 148 | epochs=args.local_epochs, 149 | ) 150 | 151 | # save client params 152 | # save_param( 153 | # args, 154 | # param=tem_param, 155 | # case=2, 156 | # client=client, 157 | # round=round, 158 | # is_global=False, 159 | # ) 160 | 161 | train_loss += train_summ["loss"] 162 | train_corr += train_summ["correct"] 163 | train_total += train_summ["total"] 164 | 165 | list_params.append(tem_param) 166 | 167 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 168 | res["train"]["acc"]["avg"].append(train_corr / train_total) 169 | 170 | print( 171 | "Train loss: {:5f} acc: {:5f}".format( 172 | res["train"]["loss"]["avg"][-1], 173 | res["train"]["acc"]["avg"][-1], 174 | ) 175 | ) 176 | 177 | # server aggregation 178 | global_param = server.FedAvg(list_params) 179 | 180 | # save global param 181 | save_param(args, param=global_param, case=2, round=round) 182 | 183 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 184 | 185 | with open(args.out_file, "wb") as fp: 186 | pickle.dump(res, fp) 187 | 188 | # total_time = time.time() - start_time 189 | # res["time"] = total_time 190 | # print(f"Time {total_time}") 191 | 192 | # with open(args.out_file, "wb") as fp: 193 | # pickle.dump(res, fp) 194 | -------------------------------------------------------------------------------- /case5.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | import config 11 | from unlearn.flipping import unlearn 12 | from utils import clients, server 13 | from utils.dataloader import get_loaders 14 | from utils.model import get_model 15 | from utils.utils import get_results, save_param, update_results, load_results 16 | 17 | np.random.seed(42) 18 | torch.manual_seed(42) 19 | torch.cuda.manual_seed(42) 20 | torch.backends.cudnn.enabled = False 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | """ 25 | Flipping method 26 | """ 27 | 28 | if __name__ == "__main__": 29 | 30 | args = config.get_args() 31 | train_loaders, test_loader, test_loader_poison = get_loaders(args) 32 | 33 | model = get_model(args) 34 | global_param = model.state_dict() 35 | 36 | num_rounds = args.num_rounds 37 | num_unlearn_rounds = args.num_unlearn_rounds 38 | num_post_training_rounds = args.num_post_training_rounds 39 | num_onboarding_rounds = args.num_onboarding_rounds 40 | 41 | 42 | if not args.is_onboarding: 43 | start_time = time.time() 44 | 45 | res = get_results(args) 46 | 47 | global_param = torch.load( 48 | f"./results/models/case0/case0_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{args.num_rounds-1}.pt" 49 | ) 50 | 51 | party_to_be_erased = 0 52 | 53 | # train and evaluate the FL model 54 | end_round = num_rounds + num_unlearn_rounds 55 | 56 | print("------------Unlearn------------") 57 | for round in range(num_rounds, end_round): 58 | print( 59 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 60 | ) 61 | 62 | train_loss, test_loss = 0, 0 63 | train_corr, test_acc = 0, 0 64 | train_total = 0 65 | list_params = [] 66 | 67 | tem_param, unlearn_summ = unlearn( 68 | args=args, 69 | param=global_param, 70 
| loaders=train_loaders, 71 | chosen_clients=[i for i in range(args.num_clients)], 72 | epochs=1, 73 | ) 74 | 75 | global_param = tem_param 76 | 77 | # save global param 78 | folder_path = "./results/models/case5" 79 | os.makedirs(folder_path, exist_ok=True) 80 | torch.save( 81 | global_param, 82 | f'{folder_path}/{args.out_file.split("/")[-1].split(".pkl")[0]}_round{round}.pt', 83 | ) 84 | 85 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 86 | 87 | print(f'Global clean accuracy: {res["val"]["acc"]["clean"][-1]}') 88 | print(f'Global backdoor accuracy: {res["val"]["acc"]["backdoor"][-1]}') 89 | 90 | total_time = time.time() - start_time 91 | res["time"] = total_time 92 | print(f"Time {total_time}") 93 | 94 | ######################## post train ############################ 95 | start_round = end_round 96 | end_round = start_round + num_post_training_rounds 97 | for round in range(start_round, end_round): 98 | print( 99 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 100 | ) 101 | 102 | train_loss, test_loss = 0, 0 103 | train_corr, test_acc = 0, 0 104 | train_total = 0 105 | list_params = [] 106 | 107 | chosen_clients = [i for i in range(1, args.num_clients)] 108 | 109 | for client in tqdm(chosen_clients): 110 | print(f"-----------client {client} starts training----------") 111 | tem_param, train_summ = clients.client_train( 112 | args, 113 | deepcopy(global_param), 114 | train_loaders[client], 115 | epochs=args.local_epochs, 116 | ) 117 | 118 | # save client params 119 | # save_param( 120 | # args, 121 | # param=tem_param, 122 | # case=5, 123 | # client=client, 124 | # round=round, 125 | # is_global=False, 126 | # ) 127 | 128 | train_loss += train_summ["loss"] 129 | train_corr += train_summ["correct"] 130 | train_total += train_summ["total"] 131 | 132 | list_params.append(tem_param) 133 | 134 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 135 | res["train"]["acc"]["avg"].append(train_corr / train_total) 136 | 137 | print( 138 | "Train loss: {:5f} acc: {:5f}".format( 139 | res["train"]["loss"]["avg"][-1], 140 | res["train"]["acc"]["avg"][-1], 141 | ) 142 | ) 143 | 144 | # server aggregation 145 | global_param = server.FedAvg(list_params) 146 | 147 | # save global param 148 | save_param(args, param=global_param, case=5, round=round) 149 | 150 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 151 | 152 | with open(args.out_file, "wb") as fp: 153 | pickle.dump(res, fp) 154 | 155 | else: 156 | ######################## onboarding round ############################ 157 | start_round = num_rounds + num_unlearn_rounds + num_post_training_rounds 158 | end_round = start_round + num_onboarding_rounds 159 | 160 | global_param = torch.load( 161 | f"./results/models/case5/case5_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{start_round-1}.pt" 162 | ) 163 | 164 | res = load_results( 165 | f"./results/case5_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}.pkl" 166 | ) 167 | 168 | for round in range(start_round, end_round): 169 | print( 170 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 171 | ) 172 | 173 | train_loss, test_loss = 0, 0 174 | train_corr, test_acc = 0, 0 175 | train_total = 0 176 | list_params = [] 
177 | 178 | chosen_clients = [i for i in range(args.num_clients)] 179 | 180 | for client in tqdm(chosen_clients): 181 | print(f"-----------client {client} starts training----------") 182 | tem_param, train_summ = clients.client_train( 183 | args, 184 | deepcopy(global_param), 185 | train_loaders[client], 186 | epochs=args.local_epochs, 187 | ) 188 | 189 | # save client params 190 | # save_param( 191 | # args, 192 | # param=tem_param, 193 | # case=5, 194 | # client=client, 195 | # round=round, 196 | # is_global=False, 197 | # ) 198 | 199 | train_loss += train_summ["loss"] 200 | train_corr += train_summ["correct"] 201 | train_total += train_summ["total"] 202 | 203 | list_params.append(tem_param) 204 | 205 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 206 | res["train"]["acc"]["avg"].append(train_corr / train_total) 207 | 208 | print( 209 | "Train loss: {:5f} acc: {:5f}".format( 210 | res["train"]["loss"]["avg"][-1], 211 | res["train"]["acc"]["avg"][-1], 212 | ) 213 | ) 214 | 215 | # server aggregation 216 | global_param = server.FedAvg(list_params) 217 | 218 | # save global param 219 | save_param(args, param=global_param, case=5, round=round) 220 | 221 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 222 | 223 | with open(args.out_file, "wb") as fp: 224 | pickle.dump(res, fp) 225 | 226 | # total_time = time.time() - start_time 227 | # res["time"] = total_time 228 | # print(f"Time {total_time}") 229 | 230 | # with open(args.out_file, "wb") as fp: 231 | # pickle.dump(res, fp) 232 | -------------------------------------------------------------------------------- /case3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import pickle 4 | import time 5 | from copy import deepcopy 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn.utils import parameters_to_vector, vector_to_parameters 10 | from tqdm import tqdm 11 | 12 | import config 13 | from unlearn.pga import get_model_ref, get_ref_vec, get_threshold, unlearn 14 | from utils import clients, server 15 | from utils.dataloader import get_loaders 16 | from utils.model import get_model 17 | from utils.utils import get_results, save_param, update_results, load_results 18 | 19 | np.random.seed(42) 20 | torch.manual_seed(42) 21 | torch.cuda.manual_seed(42) 22 | torch.backends.cudnn.enabled = False 23 | torch.backends.cudnn.deterministic = True 24 | 25 | 26 | """ 27 | PGA 28 | https://arxiv.org/pdf/2207.05521.pdf 29 | """ 30 | 31 | if __name__ == "__main__": 32 | args = config.get_args() 33 | train_loaders, test_loader, test_loader_poison = get_loaders(args) 34 | 35 | num_rounds = args.num_rounds 36 | num_unlearn_rounds = args.num_unlearn_rounds 37 | num_post_training_rounds = args.num_post_training_rounds 38 | num_onboarding_rounds = args.num_onboarding_rounds 39 | 40 | 41 | if not args.is_onboarding: 42 | start_time = time.time() 43 | 44 | res = get_results(args) 45 | 46 | model_ref = get_model_ref(args) 47 | 48 | # train and evaluate the FL model 49 | model = copy.deepcopy(model_ref) 50 | 51 | param_ref = model_ref.state_dict() 52 | party0_param = torch.load( 53 | f"./results/models/case0/client0/case0_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{args.num_rounds - 1}.pt" 54 | ) 55 | 56 | global_param = copy.deepcopy(param_ref) 57 | 58 | party_to_be_erased = 0 59 | threshold = 
get_threshold(args, model_ref) 60 | end_round = num_rounds + num_unlearn_rounds 61 | 62 | print(f"------------Unlearn------------") 63 | 64 | for round in range(num_rounds, end_round): 65 | print( 66 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 67 | ) 68 | 69 | train_loss, test_loss = 0, 0 70 | train_corr, test_acc = 0, 0 71 | train_total = 0 72 | list_params = [] 73 | 74 | tem_param, unlearn_summ = unlearn( 75 | args=args, 76 | param=global_param, 77 | param_ref=param_ref, 78 | party0_param=party0_param, 79 | distance_threshold=2.2, 80 | loader=train_loaders[party_to_be_erased], 81 | threshold=threshold, 82 | clip_grad=5, 83 | epochs=1, 84 | ) 85 | 86 | global_param = tem_param 87 | 88 | # save global param 89 | folder_path = f"./results/models/case3" 90 | os.makedirs(folder_path, exist_ok=True) 91 | torch.save( 92 | global_param, 93 | f'{folder_path}/{args.out_file.split("/")[-1].split(".pkl")[0]}_round{round}.pt', 94 | ) 95 | 96 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 97 | 98 | print(f'Global clean accuracy: {res["val"]["acc"]["clean"][-1]}') 99 | print(f'Global backdoor accuracy: {res["val"]["acc"]["backdoor"][-1]}') 100 | 101 | total_time = time.time() - start_time 102 | res["time"] = total_time 103 | print(f"Time {total_time}") 104 | 105 | ######################## post train ############################ 106 | start_round = end_round 107 | end_round = start_round + args.num_post_training_rounds 108 | for round in range(start_round, end_round): 109 | print( 110 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 111 | ) 112 | 113 | train_loss, test_loss = 0, 0 114 | train_corr, test_acc = 0, 0 115 | train_total = 0 116 | list_params = [] 117 | 118 | chosen_clients = [i for i in range(1, args.num_clients)] 119 | 120 | for client in tqdm(chosen_clients): 121 | print(f"-----------client {client} starts training----------") 122 | tem_param, train_summ = clients.client_train( 123 | args, 124 | deepcopy(global_param), 125 | train_loaders[client], 126 | epochs=args.local_epochs, 127 | ) 128 | 129 | # save client params 130 | # save_param( 131 | # args, 132 | # param=tem_param, 133 | # case=3, 134 | # client=client, 135 | # round=round, 136 | # is_global=False, 137 | # ) 138 | 139 | train_loss += train_summ["loss"] 140 | train_corr += train_summ["correct"] 141 | train_total += train_summ["total"] 142 | 143 | list_params.append(tem_param) 144 | 145 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 146 | res["train"]["acc"]["avg"].append(train_corr / train_total) 147 | 148 | print( 149 | "Train loss: {:5f} acc: {:5f}".format( 150 | res["train"]["loss"]["avg"][-1], 151 | res["train"]["acc"]["avg"][-1], 152 | ) 153 | ) 154 | 155 | # server aggregation 156 | global_param = server.FedAvg(list_params) 157 | 158 | # save global param 159 | save_param(args, param=global_param, case=3, round=round) 160 | 161 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 162 | 163 | with open(args.out_file, "wb") as fp: 164 | pickle.dump(res, fp) 165 | 166 | else: 167 | ######################## onboarding round ############################ 168 | model = get_model(args) 169 | global_param = model.state_dict() 170 | 171 | start_round = num_rounds + num_unlearn_rounds + num_post_training_rounds 172 | end_round = start_round + num_onboarding_rounds 173 | 174 | global_param = torch.load( 175 | 
f"./results/models/case3/case3_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{start_round-1}.pt" 176 | ) 177 | 178 | res = load_results( 179 | f"./results/case3_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}.pkl" 180 | ) 181 | 182 | for round in range(start_round, end_round): 183 | print( 184 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 185 | ) 186 | 187 | train_loss, test_loss = 0, 0 188 | train_corr, test_acc = 0, 0 189 | train_total = 0 190 | list_params = [] 191 | 192 | chosen_clients = [i for i in range(args.num_clients)] 193 | 194 | for client in tqdm(chosen_clients): 195 | print(f"-----------client {client} starts training----------") 196 | tem_param, train_summ = clients.client_train( 197 | args, 198 | deepcopy(global_param), 199 | train_loaders[client], 200 | epochs=args.local_epochs, 201 | ) 202 | 203 | # save client params 204 | # save_param( 205 | # args, 206 | # param=tem_param, 207 | # case=3, 208 | # client=client, 209 | # round=round, 210 | # is_global=False, 211 | # ) 212 | 213 | train_loss += train_summ["loss"] 214 | train_corr += train_summ["correct"] 215 | train_total += train_summ["total"] 216 | 217 | list_params.append(tem_param) 218 | 219 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 220 | res["train"]["acc"]["avg"].append(train_corr / train_total) 221 | 222 | print( 223 | "Train loss: {:5f} acc: {:5f}".format( 224 | res["train"]["loss"]["avg"][-1], 225 | res["train"]["acc"]["avg"][-1], 226 | ) 227 | ) 228 | 229 | # server aggregation 230 | global_param = server.FedAvg(list_params) 231 | 232 | # save global param 233 | save_param(args, param=global_param, case=3, round=round) 234 | 235 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 236 | 237 | with open(args.out_file, "wb") as fp: 238 | pickle.dump(res, fp) 239 | 240 | # total_time = time.time() - start_time 241 | # res["time"] = total_time 242 | # print(f"Time {total_time}") 243 | 244 | # with open(args.out_file, "wb") as fp: 245 | # pickle.dump(res, fp) 246 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from art.attacks.poisoning import PoisoningAttackBackdoor 6 | from art.attacks.poisoning.perturbations import add_pattern_bd 7 | from art.utils import load_dataset, to_categorical 8 | from torch.utils.data import DataLoader, TensorDataset 9 | 10 | from utils.enums import Cifar100 11 | 12 | # seeds 13 | torch.manual_seed(0) 14 | np.random.seed(0) 15 | 16 | 17 | def insert_backdoor(args, x_train_party, y_train_party, example_target, backdoor, plotting = False): 18 | # Insert backdoor 19 | if plotting: 20 | percent_poison = args["poisoned_percent"] 21 | else: 22 | percent_poison = args.poisoned_percent 23 | 24 | all_indices = np.arange(len(x_train_party)) 25 | remove_indices = all_indices[np.all(y_train_party == example_target, axis=1)] 26 | 27 | target_indices = list(set(all_indices) - set(remove_indices)) 28 | num_poison = int(percent_poison * len(target_indices)) 29 | selected_indices = np.random.choice(target_indices, num_poison, 
replace=False) 30 | 31 | poisoned_data, poisoned_labels = backdoor.poison( 32 | x_train_party[selected_indices], y=example_target, broadcast=True 33 | ) 34 | 35 | poisoned_x_train = np.copy(x_train_party) 36 | poisoned_y_train = np.argmax(y_train_party, axis=1) 37 | for s, i in zip(selected_indices, range(len(selected_indices))): 38 | poisoned_x_train[s] = poisoned_data[i] 39 | poisoned_y_train[s] = int(np.argmax(poisoned_labels[i])) 40 | 41 | return poisoned_x_train, poisoned_y_train 42 | 43 | 44 | def create_dataset_from_poisoned_data( 45 | args, x_train_party, y_train_party, example_target, backdoor, plotting = False 46 | ): 47 | poisoned_x_train, poisoned_y_train = insert_backdoor( 48 | args, x_train_party, y_train_party, example_target, backdoor, plotting = plotting 49 | ) 50 | # poisoned_x_train_ch = np.expand_dims(poisoned_x_train, axis=1) 51 | poisoned_x_train_ch = np.transpose(poisoned_x_train, (0, 3, 1, 2)) 52 | 53 | poisoned_dataset_train = TensorDataset( 54 | torch.Tensor(poisoned_x_train_ch), torch.Tensor(poisoned_y_train).long() 55 | ) 56 | if plotting: 57 | poisoned_dataloader_train = DataLoader( 58 | poisoned_dataset_train, batch_size=args["batch_size"], shuffle=True 59 | ) 60 | else: 61 | poisoned_dataloader_train = DataLoader( 62 | poisoned_dataset_train, batch_size=args.batch_size, shuffle=True 63 | ) 64 | 65 | return poisoned_dataloader_train 66 | 67 | 68 | def create_dataset_for_normal_clients( 69 | args, x_train_parties, y_train_parties, num_samples_per_party, plotting = False 70 | ): 71 | # x_train_parties_ch = np.expand_dims(x_train_parties, axis=1) 72 | x_train_parties_ch = np.transpose(x_train_parties, (0, 3, 1, 2)) 73 | y_train_parties_c = np.argmax(y_train_parties, axis=1).astype(int) 74 | 75 | # Create PyTorch datasets for other parties 76 | x_train_parties = TensorDataset( 77 | torch.Tensor(x_train_parties_ch), torch.Tensor(y_train_parties_c).long() 78 | ) 79 | 80 | if plotting: 81 | clean_dataset_train = torch.utils.data.random_split( 82 | x_train_parties, [num_samples_per_party for _ in range(1, args["num_clients"])] 83 | ) 84 | else: 85 | clean_dataset_train = torch.utils.data.random_split( 86 | x_train_parties, [num_samples_per_party for _ in range(1, args.num_clients)] 87 | ) 88 | 89 | 90 | 91 | return clean_dataset_train 92 | 93 | 94 | def load_cifar100(): 95 | transform_train = transforms.Compose( 96 | [ 97 | transforms.RandomCrop(32, padding=4), 98 | transforms.RandomHorizontalFlip(), 99 | transforms.RandomRotation(15), 100 | transforms.ToTensor(), 101 | transforms.Normalize(Cifar100.MEAN, Cifar100.STD), 102 | ] 103 | ) 104 | 105 | transform_test = transforms.Compose( 106 | [transforms.ToTensor(), transforms.Normalize(Cifar100.MEAN, Cifar100.STD)] 107 | ) 108 | cifar100_train = torchvision.datasets.CIFAR100( 109 | root="./data", train=True, download=True, transform=transform_train 110 | ) 111 | cifar100_test = torchvision.datasets.CIFAR100( 112 | root="./data", train=False, download=True, transform=transform_test 113 | ) 114 | 115 | x_train = [] 116 | y_train = [] 117 | for i in range(len(cifar100_train)): 118 | data, label = cifar100_train[i] 119 | x_train.append(data.numpy()) 120 | y_train.append(label) 121 | 122 | x_train = np.array(x_train) 123 | y_train = np.array(y_train) 124 | 125 | x_test = [] 126 | y_test = [] 127 | for i in range(len(cifar100_test)): 128 | data, label = cifar100_test[i] 129 | x_test.append(data.numpy()) 130 | y_test.append(label) 131 | 132 | x_test = np.array(x_test) 133 | y_test = np.array(y_test) 134 | 135 | # Set channels 
last 136 | x_train = x_train.transpose((0, 2, 3, 1)) 137 | x_test = x_test.transpose((0, 2, 3, 1)) 138 | 139 | y_train = to_categorical(y_train, 100) 140 | y_test = to_categorical(y_test, 100) 141 | 142 | return (x_train, y_train), (x_test, y_test) 143 | 144 | 145 | def load_data(dataset): 146 | if dataset == "cifar100": 147 | (x_train, y_train), (x_test, y_test) = load_cifar100() 148 | else: # dataset in [mnist, cifar10] 149 | (x_train, y_train), (x_test, y_test), min_, max_ = load_dataset(dataset) 150 | 151 | # label must be one hot encoded 152 | n_train = np.shape(y_train)[0] 153 | shuffled_indices = np.arange(n_train) 154 | np.random.shuffle(shuffled_indices) 155 | x_train = x_train[shuffled_indices] 156 | y_train = y_train[shuffled_indices] 157 | 158 | return x_train, y_train, x_test, y_test 159 | 160 | 161 | def create_train_loaders(args, x_train, y_train, example_target, backdoor, plotting = False): 162 | if plotting: 163 | num_samples = y_train.shape[0] 164 | num_samples_erased_party = int(num_samples / args["num_clients"]) 165 | num_samples_per_party = int( 166 | (num_samples - num_samples_erased_party) / (args["num_clients"] - 1) 167 | ) 168 | 169 | num_samples = (args["num_clients"] - 1) * num_samples_per_party 170 | 171 | x_train_party = x_train[0:num_samples_erased_party] 172 | y_train_party = y_train[0:num_samples_erased_party] 173 | 174 | x_train_parties = x_train[ 175 | num_samples_erased_party : num_samples_erased_party + num_samples 176 | ] 177 | y_train_parties = y_train[ 178 | num_samples_erased_party : num_samples_erased_party + num_samples 179 | ] 180 | 181 | poisoned_dataloader_train = create_dataset_from_poisoned_data( 182 | args, x_train_party, y_train_party, example_target, backdoor, plotting = plotting 183 | ) 184 | clean_dataset_train = create_dataset_for_normal_clients( 185 | args, x_train_parties, y_train_parties, num_samples_per_party, plotting = plotting 186 | ) 187 | 188 | train_loaders = [poisoned_dataloader_train] 189 | for i in range(len(clean_dataset_train)): 190 | train_loaders.append( 191 | DataLoader(clean_dataset_train[i], batch_size=args["batch_size"], shuffle=True) 192 | ) 193 | 194 | else: 195 | num_samples = y_train.shape[0] 196 | num_samples_erased_party = int(num_samples / args.num_clients) 197 | num_samples_per_party = int( 198 | (num_samples - num_samples_erased_party) / (args.num_clients - 1) 199 | ) 200 | 201 | num_samples = (args.num_clients - 1) * num_samples_per_party 202 | 203 | x_train_party = x_train[0:num_samples_erased_party] 204 | y_train_party = y_train[0:num_samples_erased_party] 205 | 206 | x_train_parties = x_train[ 207 | num_samples_erased_party : num_samples_erased_party + num_samples 208 | ] 209 | y_train_parties = y_train[ 210 | num_samples_erased_party : num_samples_erased_party + num_samples 211 | ] 212 | 213 | poisoned_dataloader_train = create_dataset_from_poisoned_data( 214 | args, x_train_party, y_train_party, example_target, backdoor, plotting = plotting 215 | ) 216 | clean_dataset_train = create_dataset_for_normal_clients( 217 | args, x_train_parties, y_train_parties, num_samples_per_party, plotting = plotting 218 | ) 219 | 220 | train_loaders = [poisoned_dataloader_train] 221 | for i in range(len(clean_dataset_train)): 222 | train_loaders.append( 223 | DataLoader(clean_dataset_train[i], batch_size=args.batch_size, shuffle=True) 224 | ) 225 | 226 | 227 | return train_loaders 228 | 229 | 230 | def create_test_loaders(args, x_test, y_test, example_target, backdoor, plotting = False): 231 | all_indices = 
np.arange(len(x_test))
232 |     remove_indices = all_indices[np.all(y_test == example_target, axis=1)]
233 | 
234 |     target_indices = list(set(all_indices) - set(remove_indices))
235 |     poisoned_data, poisoned_labels = backdoor.poison(
236 |         x_test[target_indices], y=example_target, broadcast=True
237 |     )
238 | 
239 |     poisoned_x_test = np.copy(x_test)
240 |     poisoned_y_test = np.argmax(y_test, axis=1)
241 | 
242 |     for i, s in enumerate(target_indices):
243 |         poisoned_x_test[s] = poisoned_data[i]
244 |         poisoned_y_test[s] = int(np.argmax(poisoned_labels[i]))
245 | 
246 |     # poisoned_x_test_ch = np.expand_dims(poisoned_x_test, axis=1)
247 |     poisoned_x_test_ch = np.transpose(poisoned_x_test, (0, 3, 1, 2))
248 | 
249 |     poisoned_dataset_test = TensorDataset(
250 |         torch.Tensor(poisoned_x_test_ch), torch.Tensor(poisoned_y_test).long()
251 |     )
252 |     testloader_poison = DataLoader(
253 |         poisoned_dataset_test, batch_size=1000, shuffle=False
254 |     )
255 | 
256 |     # x_test_pt = np.expand_dims(x_test, axis=1)
257 |     x_test_pt = np.transpose(x_test, (0, 3, 1, 2))
258 | 
259 |     y_test_pt = np.argmax(y_test, axis=1).astype(int)
260 |     dataset_test = TensorDataset(
261 |         torch.Tensor(x_test_pt), torch.Tensor(y_test_pt).long()
262 |     )
263 |     testloader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
264 | 
265 |     return testloader, testloader_poison
266 | 
267 | 
268 | def get_loaders(args, plotting=False):
269 |     if plotting:
270 |         x_train, y_train, x_test, y_test = load_data(args["dataset"])
271 |     else:
272 |         x_train, y_train, x_test, y_test = load_data(args.dataset)
273 | 
274 |     # Init backdoor pattern
275 |     backdoor = PoisoningAttackBackdoor(add_pattern_bd)
276 | 
277 |     example_target = np.zeros(y_train.shape[1])
278 |     example_target[-1] = 1
279 | 
280 |     train_loaders = create_train_loaders(
281 |         args, x_train, y_train, example_target, backdoor, plotting=plotting
282 |     )
283 |     test_loader, test_loader_poison = create_test_loaders(
284 |         args, x_test, y_test, example_target, backdoor, plotting=plotting
285 |     )
286 | 
287 |     return train_loaders, test_loader, test_loader_poison
288 | 
--------------------------------------------------------------------------------
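For orientation, a minimal sketch of how these loaders are typically consumed. This assumes, as the case scripts below do, that `config.get_args()` provides `dataset`, `num_clients`, and `batch_size`; `train_loaders[0]` belongs to the backdoored party that is later unlearned.

import config
from utils.dataloader import get_loaders

args = config.get_args()
train_loaders, test_loader, test_loader_poison = get_loaders(args)

# train_loaders[0] serves the poisoned client; the rest are clean clients
images, labels = next(iter(train_loaders[0]))
print(images.shape, labels.shape)  # (batch_size, C, H, W), (batch_size,)
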
/utils/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torchvision import models
5 | from typing import List
6 | 
7 | # Note: This model is taken from McMahan et al.'s FL paper
8 | class FLNet(nn.Module):
9 |     def __init__(self):
10 |         super(FLNet, self).__init__()
11 |         self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
12 |         self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
13 |         self.fc1 = nn.Linear(64 * 7 * 7, 512)
14 |         self.fc2 = nn.Linear(512, 10)
15 | 
16 |     def forward(self, x):
17 |         x = F.max_pool2d(F.relu(self.conv1(x)), 2)
18 |         x = F.max_pool2d(F.relu(self.conv2(x)), 2)
19 |         x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
20 |         x = F.relu(self.fc1(x))
21 |         x = self.fc2(x)
22 |         return x
23 | 
24 | 
25 | class CNNCifar(nn.Module):
26 |     def __init__(self):
27 |         super(CNNCifar, self).__init__()
28 |         self.conv1 = nn.Conv2d(3, 6, 5)
29 |         self.pool = nn.MaxPool2d(2, 2)
30 |         self.conv2 = nn.Conv2d(6, 16, 5)
31 |         self.fc1 = nn.Linear(16 * 5 * 5, 120)
32 |         self.fc2 = nn.Linear(120, 84)
33 |         self.fc3 = nn.Linear(84, 10)
34 | 
35 |     def forward(self, x):
36 |         x = self.pool(F.relu(self.conv1(x)))
37 |         x = self.pool(F.relu(self.conv2(x)))
38 |         x = torch.flatten(x, 1)
39 |         x = F.relu(self.fc1(x))
40 |         x = F.relu(self.fc2(x))
41 |         x = self.fc3(x)
42 |         return x
43 | 
44 | 
45 | class BasicBlock(nn.Module):
46 |     """Basic block for ResNet-18 and ResNet-34."""
47 | 
48 |     # BasicBlock and BottleNeck blocks
49 |     # have different output sizes;
50 |     # we use the class attribute `expansion`
51 |     # to distinguish between them
52 |     expansion = 1
53 | 
54 |     def __init__(self, in_channels, out_channels, stride=1):
55 |         super().__init__()
56 | 
57 |         # residual function
58 |         self.residual_function = nn.Sequential(
59 |             nn.Conv2d(
60 |                 in_channels,
61 |                 out_channels,
62 |                 kernel_size=3,
63 |                 stride=stride,
64 |                 padding=1,
65 |                 bias=False,
66 |             ),
67 |             nn.BatchNorm2d(out_channels),
68 |             nn.ReLU(inplace=True),
69 |             nn.Conv2d(
70 |                 out_channels,
71 |                 out_channels * BasicBlock.expansion,
72 |                 kernel_size=3,
73 |                 padding=1,
74 |                 bias=False,
75 |             ),
76 |             nn.BatchNorm2d(out_channels * BasicBlock.expansion),
77 |         )
78 | 
79 |         # shortcut
80 |         self.shortcut = nn.Sequential()
81 | 
82 |         # if the shortcut's output dimension does not match the residual function's,
83 |         # use a 1x1 convolution to match the dimensions
84 |         if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
85 |             self.shortcut = nn.Sequential(
86 |                 nn.Conv2d(
87 |                     in_channels,
88 |                     out_channels * BasicBlock.expansion,
89 |                     kernel_size=1,
90 |                     stride=stride,
91 |                     bias=False,
92 |                 ),
93 |                 nn.BatchNorm2d(out_channels * BasicBlock.expansion),
94 |             )
95 | 
96 |     def forward(self, x):
97 |         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
98 | 
99 | 
100 | class ResNet(nn.Module):
101 |     def __init__(self, block, num_block, num_classes=100):
102 |         super().__init__()
103 | 
104 |         self.in_channels = 64
105 | 
106 |         self.conv1 = nn.Sequential(
107 |             nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
108 |             nn.BatchNorm2d(64),
109 |             nn.ReLU(inplace=True),
110 |         )
111 |         # we use a different input size than the original paper,
112 |         # so conv2_x's stride is 1
113 |         self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
114 |         self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
115 |         self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
116 |         self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
117 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
118 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
119 | 
120 |     def _make_layer(self, block, out_channels, num_blocks, stride):
121 |         """Make a ResNet "layer" (not a single neural-network layer such
122 |         as a conv layer; one layer here may
123 |         contain more than one residual block).
124 |         Args:
125 |             block: block type, BasicBlock or BottleNeck block
126 |             out_channels: output depth (number of channels) of this layer
127 |             num_blocks: how many blocks per layer
128 |             stride: the stride of the first block of this layer
129 |         Return:
130 |             a resnet layer as an nn.Sequential
131 |         """
132 | 
133 |         # we have num_blocks blocks per layer; the stride of the first block
134 |         # can be 1 or 2, all other blocks always use stride 1
135 |         strides = [stride] + [1] * (num_blocks - 1)
136 |         layers = []
137 |         for stride in strides:
138 |             layers.append(block(self.in_channels, out_channels, stride))
139 |             self.in_channels = out_channels * block.expansion
140 | 
141 |         return nn.Sequential(*layers)
142 | 
143 |     def forward(self, x):
144 |         output = self.conv1(x)
145 |         output = self.conv2_x(output)
146 |         output = self.conv3_x(output)
147 |         output = self.conv4_x(output)
148 |         output = self.conv5_x(output)
149 |         output = self.avg_pool(output)
150 |         output = output.view(output.size(0), -1)
151 |         output = self.fc(output)
152 | 
153 |         return output
154 | 
155 | 
156 | def resnet18():
157 |     """return a ResNet-18 object (note: 100-way head by default)"""
158 |     return ResNet(BasicBlock, [2, 2, 2, 2])
159 | 
160 | class Cifar100(nn.Module):
161 |     def __init__(self, pretrained=True) -> None:
162 |         super().__init__()
163 | 
164 |         self.prenet = models.resnet50(pretrained=pretrained)  # `pretrained=` is deprecated in torchvision >= 0.13
165 |         num_ftrs = self.prenet.fc.in_features
166 |         self.prenet.fc = nn.Sequential(
167 |             nn.Linear(num_ftrs, 100)
168 |         )
169 | 
170 |     def forward(self, x):
171 |         x = self.prenet(x)
172 |         return x
173 | 
174 | class DecoupledModel(nn.Module):
175 |     def __init__(self):
176 |         super(DecoupledModel, self).__init__()
177 |         self.need_all_features_flag = False
178 |         self.all_features = []
179 |         self.base: nn.Module = None
180 |         self.classifier: nn.Module = None
181 |         self.dropout: List[nn.Module] = []
182 | 
183 |     def need_all_features(self):
184 |         self.need_all_features_flag = True
185 |         target_modules = [
186 |             module
187 |             for module in self.base.modules()
188 |             if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)
189 |         ]
190 | 
191 |         def get_feature_hook_fn(model, input, output):
192 |             self.all_features.append(output)
193 | 
194 |         for module in target_modules:
195 |             module.register_forward_hook(get_feature_hook_fn)
196 | 
197 |     def check_avaliability(self):
198 |         if self.base is None or self.classifier is None:
199 |             raise RuntimeError(
200 |                 "You need to re-write the base and classifier in your custom model class."
201 |             )
202 |         self.dropout = [
203 |             module
204 |             for module in list(self.base.modules()) + list(self.classifier.modules())
205 |             if isinstance(module, nn.Dropout)
206 |         ]
207 | 
208 |     def forward(self, x: torch.Tensor):
209 |         out = self.classifier(F.relu(self.base(x)))
210 |         if self.need_all_features_flag:
211 |             self.all_features = []
212 |         return out
213 | 
214 |     def get_final_features(self, x: torch.Tensor, detach=True):
215 |         if len(self.dropout) > 0:
216 |             for dropout in self.dropout:
217 |                 dropout.eval()
218 | 
219 |         func = (lambda x: x.clone().detach()) if detach else (lambda x: x)
220 |         out = self.base(x)
221 | 
222 |         if len(self.dropout) > 0:
223 |             for dropout in self.dropout:
224 |                 dropout.train()
225 | 
226 |         return func(out)
227 | 
228 |     def get_all_features(self, x: torch.Tensor, detach=True):
229 |         feature_list = None
230 |         if self.need_all_features_flag:
231 |             if len(self.dropout) > 0:
232 |                 for dropout in self.dropout:
233 |                     dropout.eval()
234 | 
235 |             func = (lambda x: x.clone().detach()) if detach else (lambda x: x)
236 |             _ = self.base(x)
237 |             feature_list = [func(feature) for feature in self.all_features]
238 |             self.all_features = []
239 | 
240 |             if len(self.dropout) > 0:
241 |                 for dropout in self.dropout:
242 |                     dropout.train()
243 | 
244 |         return feature_list
245 | 
246 | 
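# --- Usage sketch for DecoupledModel's feature hooks (illustrative, not
# --- repo code; `ResNet18` is the subclass defined just below):
#
#     model = ResNet18(dataset="cifar10")
#     model.check_avaliability()   # collects Dropout modules (repo's spelling)
#     model.need_all_features()    # hooks every Conv2d/Linear in the base
#
#     x = torch.randn(4, 3, 32, 32)
#     feats = model.get_all_features(x)    # list of per-layer activations
#     emb = model.get_final_features(x)    # penultimate embedding, detached
#     logits = model(x)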
247 | 
248 | class ResNet18(DecoupledModel):
249 |     def __init__(self, dataset):
250 |         super(ResNet18, self).__init__()
251 |         config = {
252 |             "mnist": 10,
253 |             "medmnistS": 11,
254 |             "medmnistC": 11,
255 |             "medmnistA": 11,
256 |             "fmnist": 10,
257 |             "svhn": 10,
258 |             "emnist": 62,
259 |             "femnist": 62,
260 |             "cifar10": 10,
261 |             "cinic10": 10,
262 |             "cifar100": 100,
263 |             "covid19": 4,
264 |             "usps": 10,
265 |             "celeba": 2,
266 |             "tiny_imagenet": 200,
267 |         }
268 |         # NOTE: If you don't want pretrained parameters, set `pretrained` to False
269 |         pretrained = True
270 |         self.base = models.resnet18(
271 |             weights=models.ResNet18_Weights.DEFAULT if pretrained else None
272 |         )
273 |         self.classifier = nn.Linear(self.base.fc.in_features, config[dataset])
274 |         self.base.fc = nn.Identity()
275 | 
276 |     def forward(self, x):
277 |         if x.shape[1] == 1:
278 |             x = torch.expand_copy(x, (x.shape[0], 3, *x.shape[2:]))
279 |         return super().forward(x)
280 | 
281 |     def get_all_features(self, x, detach=True):
282 |         if x.shape[1] == 1:
283 |             x = torch.expand_copy(x, (x.shape[0], 3, *x.shape[2:]))
284 |         return super().get_all_features(x, detach)
285 | 
286 |     def get_final_features(self, x, detach=True):
287 |         if x.shape[1] == 1:
288 |             x = torch.expand_copy(x, (x.shape[0], 3, *x.shape[2:]))
289 |         return super().get_final_features(x, detach)
290 | 
291 | def get_model(args, plotting=False):
292 |     if plotting:
293 |         dataset = args["dataset"]
294 |         device = args["device"]
295 |     else:
296 |         dataset = args.dataset
297 |         device = args.device
298 | 
299 |     if dataset == "mnist":
300 |         model = FLNet().to(device)
301 |     elif dataset == "cifar10":
302 |         model = resnet18().to(device)  # note: resnet18() returns a 100-way head by default
303 |     elif dataset == "cifar100":
304 |         model = resnet18().to(device)
305 |     else:
306 |         raise ValueError(f"dataset {dataset!r} is not supported")
307 | 
308 |     return model
309 | 
--------------------------------------------------------------------------------
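Throughout the repo, get_model and get_loaders accept either the argparse namespace produced by config.py or, with plotting=True, a plain dict; the dict form is the convention result_sample/usage.ipynb relies on. A minimal sketch of the dict form, with field names taken from the code above and illustrative values:

import torch
from utils.model import get_model

args_dict = {"dataset": "mnist", "device": torch.device("cpu")}
model = get_model(args_dict, plotting=True)  # -> FLNet for MNIST
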
/case4.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import time
3 | from copy import deepcopy
4 | 
5 | import numpy as np
6 | import torch
7 | from tqdm import tqdm
8 | 
9 | import config
10 | from unlearn.federaser import fed_eraser_one_step
11 | from utils import clients, server
12 | from utils.dataloader import get_loaders
13 | from utils.model import get_model
14 | from utils.utils import get_results, save_param, update_results, load_results
15 | 
16 | np.random.seed(42)
17 | torch.manual_seed(42)
18 | torch.cuda.manual_seed(42)
19 | torch.backends.cudnn.enabled = False
20 | torch.backends.cudnn.deterministic = True
21 | 
22 | 
23 | if __name__ == "__main__":
24 |     args = config.get_args()
25 |     train_loaders, test_loader, test_loader_poison = get_loaders(args)
26 | 
27 |     model = get_model(args)
28 |     global_param = model.state_dict()
29 | 
30 |     num_rounds = args.num_rounds
31 |     num_unlearn_rounds = args.num_unlearn_rounds
32 |     num_post_training_rounds = args.num_post_training_rounds
33 |     num_onboarding_rounds = args.num_onboarding_rounds
34 | 
35 | 
36 |     if not args.is_onboarding:
37 |         start_time = time.time()
38 | 
39 |         res = get_results(args)
40 | 
41 |         # load the FL global model parameters saved by case0 for every round
42 |         old_global_models = []
43 |         for round in range(args.num_rounds):
44 |             global_param = torch.load(
45 |                 f"./results/models/case0/case0_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{round}.pt"
46 |             )
47 |             old_global_models.append(global_param)
48 | 
49 |         new_global_models = []
50 | 
51 |         # train and evaluate the FL model
52 |         chosen_clients = list(range(1, args.num_clients))
53 | 
54 |         rounds = list(range(0, num_rounds, num_rounds // args.num_unlearn_rounds))
55 |         print(rounds)
56 | 
57 |         for i, round in enumerate(rounds):
58 |             roundth = args.num_rounds + i
59 |             print(
60 |                 "Round {}/{}: lr {} {}".format(
61 |                     roundth + 1,
62 |                     num_rounds + args.num_unlearn_rounds,
63 |                     args.lr,
64 |                     args.out_file,
65 |                 )
66 |             )
67 | 
68 |             train_loss, test_loss = 0, 0
69 |             train_corr, test_acc = 0, 0
70 |             train_total = 0
71 |             list_params = []
72 | 
73 |             old_client_updates = []
74 |             new_client_updates = []
75 | 
76 |             # first unlearning round: just FedAvg the stored updates of the non-malicious clients
77 |             if round == 0:
78 |                 for client in chosen_clients:
79 |                     old_client_update = torch.load(
80 |                         f"./results/models/case0/client{client}/case0_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{round}.pt"
81 |                     )
82 |                     old_client_updates.append(old_client_update)
83 | 
84 |                 new_global_model = server.FedAvg(old_client_updates)
85 |                 new_global_models.append(new_global_model)
86 | 
87 |                 save_param(args, param=new_global_model, case=4, round=roundth)
88 |                 res = update_results(
89 |                     args, res, new_global_model, test_loader, test_loader_poison
90 |                 )
91 |                 continue
92 | 
93 |             old_global_model = old_global_models[round]
94 |             new_prev_global_model = new_global_models[-1]
95 | 
96 |             for client in tqdm(chosen_clients):
97 |                 print(f"-----------client {client} starts training----------")
98 |                 old_client_update = torch.load(
99 |                     f"./results/models/case0/client{client}/case0_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{round}.pt"
100 |                 )
101 |                 old_client_updates.append(old_client_update)
102 | 
103 |                 # local_cali_round = int(math.ceil(args.local_epochs * FedEraser.CALI_RATIO))
104 |                 local_cali_round = 1
105 |                 new_client_update, train_summ = clients.client_train(
106 |                     args,
107 |                     deepcopy(new_prev_global_model),
108 |                     train_loaders[client],
109 |                     epochs=local_cali_round,
110 |                 )
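                # ^ calibration step: each retained client briefly fine-tunes
                # the current unlearned global model; fed_eraser_one_step
                # (unlearn/federaser.py) then, roughly speaking, keeps the
                # *direction* of these fresh updates but scales them by the
                # norms of the corresponding stored historical updates, so the
                # global model is rebuilt without ever re-aggregating the
                # erased client. (Qualitative description of FedEraser's
                # calibration; see that file for the exact formula.)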
111 | 
112 |                 new_client_updates.append(new_client_update)
113 | 
114 |                 train_loss += train_summ["loss"]
115 |                 train_corr += train_summ["correct"]
116 |                 train_total += train_summ["total"]
117 |                 list_params.append(new_client_update)
118 | 
119 |             res["train"]["loss"]["avg"].append(train_loss / len(list_params))
120 |             res["train"]["acc"]["avg"].append(train_corr / train_total)
121 |             print(
122 |                 "Train loss {:.5f} acc {:.5f}".format(
123 |                     res["train"]["loss"]["avg"][-1], res["train"]["acc"]["avg"][-1]
124 |                 )
125 |             )
126 | 
127 |             new_global_model = fed_eraser_one_step(
128 |                 old_client_updates,
129 |                 new_client_updates,
130 |                 old_global_model,
131 |                 new_prev_global_model,
132 |             )
133 |             new_global_models.append(new_global_model)
134 | 
135 |             save_param(args, param=new_global_model, case=4, round=roundth)
136 |             res = update_results(
137 |                 args, res, new_global_model, test_loader, test_loader_poison
138 |             )
139 | 
140 |         total_time = time.time() - start_time
141 |         res["time"] = total_time
142 |         print(f"Time {total_time}")
143 | 
144 |         ######################## post train ############################
145 |         global_param = new_global_model
146 |         end_round = args.num_rounds + len(rounds)
147 |         start_round = end_round
148 |         end_round = start_round + args.num_post_training_rounds
149 |         for round in range(start_round, end_round):
150 |             print(
151 |                 "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file)
152 |             )
153 | 
154 |             train_loss, test_loss = 0, 0
155 |             train_corr, test_acc = 0, 0
156 |             train_total = 0
157 |             list_params = []
158 | 
159 |             chosen_clients = list(range(1, args.num_clients))
160 | 
161 |             for client in tqdm(chosen_clients):
162 |                 print(f"-----------client {client} starts training----------")
163 |                 tem_param, train_summ = clients.client_train(
164 |                     args,
165 |                     deepcopy(global_param),
166 |                     train_loaders[client],
167 |                     epochs=args.local_epochs,
168 |                 )
169 | 
170 |                 # save client params
171 |                 # save_param(
172 |                 #     args,
173 |                 #     param=tem_param,
174 |                 #     case=4,
175 |                 #     client=client,
176 |                 #     round=round,
177 |                 #     is_global=False,
178 |                 # )
179 | 
180 |                 train_loss += train_summ["loss"]
181 |                 train_corr += train_summ["correct"]
182 |                 train_total += train_summ["total"]
183 | 
184 |                 list_params.append(tem_param)
185 | 
186 |             res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients))
187 |             res["train"]["acc"]["avg"].append(train_corr / train_total)
188 | 
189 |             print(
190 |                 "Train loss: {:.5f} acc: {:.5f}".format(
191 |                     res["train"]["loss"]["avg"][-1],
192 |                     res["train"]["acc"]["avg"][-1],
193 |                 )
194 |             )
195 | 
196 |             # server aggregation
197 |             global_param = server.FedAvg(list_params)
198 | 
199 |             # save global param
200 |             save_param(args, param=global_param, case=4, round=round)
201 | 
202 |             res = update_results(args, res, global_param, test_loader, test_loader_poison)
203 | 
204 |             with open(args.out_file, "wb") as fp:
205 |                 pickle.dump(res, fp)
206 | 
207 |     else:
208 |         ######################## onboarding round ############################
209 |         start_round = num_rounds + num_unlearn_rounds + num_post_training_rounds
210 |         end_round = start_round + num_onboarding_rounds
211 | 
212 |         global_param = torch.load(
213 |             f"./results/models/case4/case4_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}_round{start_round-1}.pt"
214 |         )
215 | 
216 |         res = load_results(
f"./results/case4_{args.dataset}_C{args.num_clients}_BS{args.batch_size}_R{args.num_rounds}_UR{args.num_unlearn_rounds}_PR{args.num_post_training_rounds}_E{args.local_epochs}_LR{args.lr}.pkl" 218 | ) 219 | 220 | for round in range(start_round, end_round): 221 | print( 222 | "Round {}/{}: lr {} {}".format(round + 1, end_round, args.lr, args.out_file) 223 | ) 224 | 225 | train_loss, test_loss = 0, 0 226 | train_corr, test_acc = 0, 0 227 | train_total = 0 228 | list_params = [] 229 | 230 | chosen_clients = [i for i in range(args.num_clients)] 231 | 232 | for client in tqdm(chosen_clients): 233 | print(f"-----------client {client} starts training----------") 234 | tem_param, train_summ = clients.client_train( 235 | args, 236 | deepcopy(global_param), 237 | train_loaders[client], 238 | epochs=args.local_epochs, 239 | ) 240 | 241 | # save client params 242 | # save_param( 243 | # args, 244 | # param=tem_param, 245 | # case=4, 246 | # client=client, 247 | # round=round, 248 | # is_global=False, 249 | # ) 250 | 251 | train_loss += train_summ["loss"] 252 | train_corr += train_summ["correct"] 253 | train_total += train_summ["total"] 254 | 255 | list_params.append(tem_param) 256 | 257 | res["train"]["loss"]["avg"].append(train_loss / len(chosen_clients)) 258 | res["train"]["acc"]["avg"].append(train_corr / train_total) 259 | 260 | print( 261 | "Train loss: {:5f} acc: {:5f}".format( 262 | res["train"]["loss"]["avg"][-1], 263 | res["train"]["acc"]["avg"][-1], 264 | ) 265 | ) 266 | 267 | # server aggregation 268 | global_param = server.FedAvg(list_params) 269 | 270 | # save global param 271 | save_param(args, param=global_param, case=4, round=round) 272 | 273 | res = update_results(args, res, global_param, test_loader, test_loader_poison) 274 | 275 | with open(args.out_file, "wb") as fp: 276 | pickle.dump(res, fp) 277 | 278 | # total_time = time.time() - start_time 279 | # res["time"] = total_time 280 | # print(f"Time {total_time}") 281 | 282 | # with open(args.out_file, "wb") as fp: 283 | # pickle.dump(res, fp) 284 | -------------------------------------------------------------------------------- /result_sample/usage.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\"\"\"\n", 10 | " Cases:\n", 11 | " Case 0: normal federated learning\n", 12 | " Case 1: baseline, retrain from scratch\n", 13 | " Case 2: method 1: continue train\n", 14 | " Case 3: method 2: PGA\n", 15 | " Case 4: method 3: federaser\n", 16 | " Case 5: method 4: flipping\n", 17 | "\"\"\"" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "\"\"\" \n", 27 | " List of settings:\n", 28 | " 1. MNIST: \n", 29 | " - R10, UR5, PR15, OR15\n", 30 | " - R10, UR1, PR15, OR15\n", 31 | " - R50, UR5, PR15, OR15\n", 32 | " 2. CIFAR10\n", 33 | " - R20, UR10, PR30, OR30\n", 34 | " - R20, UR2, PR30, OR30\n", 35 | " - R100, UR10, PR30, OR30\n", 36 | " List of experiments:\n", 37 | " 1. Accuracy\n", 38 | " - compare case 2 with case 1\n", 39 | " - compare case 3 with case 1\n", 40 | " - compare case 4 with case 1\n", 41 | " - compare case 5 with case 1\n", 42 | " 2. Accuracy on the last round before onboarding\n", 43 | " 3. Params similarity\n", 44 | " 4. Prediction Similarity\n", 45 | " 5. 
Unlearning time\n", 46 | "\"\"\"" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# import modules\n", 56 | "\n", 57 | "import pickle\n", 58 | "import matplotlib.pyplot as plt\n", 59 | "import os\n", 60 | "import numpy as np\n", 61 | "import sys\n", 62 | "sys.path.insert(0, '..')\n", 63 | "\n", 64 | "from utils.model import get_model\n", 65 | "import torch\n", 66 | "\n", 67 | "import pandas as pd\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "# configs for experiments\n", 77 | "\n", 78 | "configs = {\n", 79 | " \"mnist\": {\n", 80 | " \"num_round\": 50,\n", 81 | " \"num_unlearn_round\": 5,\n", 82 | " \"num_post_training_round\": 15\n", 83 | " },\n", 84 | " \"cifar10\": {\n", 85 | " \"num_round\": 100,\n", 86 | " \"num_unlearn_round\": 10,\n", 87 | " \"num_post_training_round\": 30\n", 88 | " },\n", 89 | " \"cifar100\": {\n", 90 | " \"num_round\": 100,\n", 91 | " \"num_unlearn_round\": 10,\n", 92 | " \"num_post_training_round\": 30\n", 93 | " },\n", 94 | " \"dataset\": \"cifar100\"\n", 95 | "}" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "# result structure\n", 105 | "res = {}\n", 106 | "\n", 107 | "for k1 in (\"train\", \"val\"):\n", 108 | " res[k1] = {}\n", 109 | " for k2 in (\"loss\", \"acc\"):\n", 110 | " res[k1][k2] = {}\n", 111 | " res[k1][k2][\"avg\"] = []\n", 112 | " res[k1][k2][\"clean\"] = []\n", 113 | " res[k1][k2][\"backdoor\"] = []\n", 114 | " for k3 in range(5):\n", 115 | " res[k1][k2][k3] = []\n", 116 | "\n", 117 | "# or, for better visualization, this is the architecture of res\n", 118 | "\n", 119 | "res = {\n", 120 | " \"train\": {\n", 121 | " \"loss\": {\n", 122 | " \"avg\": [],\n", 123 | " \"clean\": [],\n", 124 | " \"backdoor\": [],\n", 125 | " 0: [],\n", 126 | " 1: [],\n", 127 | " 2: [],\n", 128 | " 3: [],\n", 129 | " 4: []\n", 130 | " },\n", 131 | " \"acc\": {\n", 132 | " \"avg\": [],\n", 133 | " \"clean\": [],\n", 134 | " \"backdoor\": [],\n", 135 | " 0: [],\n", 136 | " 1: [],\n", 137 | " 2: [],\n", 138 | " 3: [],\n", 139 | " 4: []\n", 140 | " }\n", 141 | " },\n", 142 | " \"val\": {\n", 143 | " \"loss\": {\n", 144 | " \"avg\": [],\n", 145 | " \"clean\": [],\n", 146 | " \"backdoor\": [],\n", 147 | " 0: [],\n", 148 | " 1: [],\n", 149 | " 2: [],\n", 150 | " 3: [],\n", 151 | " 4: []\n", 152 | " },\n", 153 | " \"acc\": {\n", 154 | " \"avg\": [],\n", 155 | " \"clean\": [],\n", 156 | " \"backdoor\": [],\n", 157 | " 0: [],\n", 158 | " 1: [],\n", 159 | " 2: [],\n", 160 | " 3: [],\n", 161 | " 4: []\n", 162 | " }\n", 163 | " }\n", 164 | "}" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "paths = os.listdir(\"with_onboarding\")" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "def to_csv(X,Ys, filename, is_cuda = False):\n", 183 | " df = pd.DataFrame({\n", 184 | " X[\"label\"]: X[\"value\"],\n", 185 | " })\n", 186 | "\n", 187 | " if is_cuda:\n", 188 | " for label, Y in Ys.items():\n", 189 | " df[label] = [y.cpu().item() for y in Y]\n", 190 | " else:\n", 191 | " for label, Y in Ys.items():\n", 192 | " df[label] = Y\n", 193 | "\n", 194 | " df.to_csv(\"csvs/\" + filename, index = False)" 195 | 
] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "def load_gen(filename, type=\"acc\"):\n", 204 | " with open(filename, 'rb') as fp:\n", 205 | " data = pickle.load(fp)['val'][type]\n", 206 | " return data\n", 207 | "\n", 208 | "\n", 209 | "onboarding = True\n", 210 | "num_onboarding_rounds = 30\n", 211 | "\n", 212 | "\n", 213 | "if onboarding:\n", 214 | " folder = \"with_onboarding/\"\n", 215 | "else:\n", 216 | " folder = \"without_onboarding/\"\n", 217 | "\n", 218 | "\n", 219 | "name = {\n", 220 | " \"case0\": \"normal\",\n", 221 | " \"case1\": \"Retrain\",\n", 222 | " \"case2\": \"Continue to Train\",\n", 223 | " \"case3\": \"PGA\",\n", 224 | " \"case4\": \"FedEraser\",\n", 225 | " \"case5\": \"Flipping\"\n", 226 | "}\n", 227 | "\n", 228 | "\n", 229 | "def show_result(path, methods=[1, 2, 3, 4], is_marked=False):\n", 230 | "\n", 231 | " markers = [\"\", \"bo--\", \"gx--\", \"m^-\", \"c+-\", \"r>-\", \"y<-\", \"ks-\", \"yd-\"]\n", 232 | "\n", 233 | " num_rounds = 0\n", 234 | "\n", 235 | " for i in [3, 4, 5]:\n", 236 | " temp = 0\n", 237 | " if i == 3:\n", 238 | " temp = int(path.split(\"_\")[i][1:])\n", 239 | " else:\n", 240 | " temp = int(path.split(\"_\")[i][2:])\n", 241 | "\n", 242 | " num_rounds += temp\n", 243 | "\n", 244 | " num_rounds += num_onboarding_rounds \n", 245 | "\n", 246 | " fl_rounds = [i for i in range(1, num_rounds + 1)]\n", 247 | "\n", 248 | " filename_baseline = f\"case0_{path}\"\n", 249 | " baseline = load_gen(folder + filename_baseline)\n", 250 | "\n", 251 | " for i in methods:\n", 252 | " filename = f\"case{i}_{path}\"\n", 253 | " try:\n", 254 | " data = load_gen(folder + filename)\n", 255 | " except:\n", 256 | " continue\n", 257 | " case = f\"case{i}\"\n", 258 | "\n", 259 | " if i != 1:\n", 260 | " clean_data = baseline[\"clean\"] + data[\"clean\"]\n", 261 | " backdoor_data = baseline[\"backdoor\"] + data[\"backdoor\"]\n", 262 | " else:\n", 263 | " clean_data = data[\"clean\"]\n", 264 | " backdoor_data = data[\"backdoor\"]\n", 265 | "\n", 266 | " to_csv( \n", 267 | " {\n", 268 | " \"label\": \"Rounds\",\n", 269 | " \"value\": fl_rounds\n", 270 | " },\n", 271 | " {\n", 272 | " \"clean_data\": clean_data,\n", 273 | " \"backdoor_data\": backdoor_data\n", 274 | " },\n", 275 | " f\"exp1_accuracy/{configs['dataset']}_case{i}_R{configs[configs['dataset']]['num_round']}_UR{configs[configs['dataset']]['num_unlearn_round']}_PR{configs[configs['dataset']]['num_post_training_round']}.csv\"\n", 276 | " )\n", 277 | "\n", 278 | " if is_marked:\n", 279 | " plt.plot(fl_rounds, clean_data, markers[2*i-1], label=f\"{name[case]} clean\")\n", 280 | " plt.plot(fl_rounds, backdoor_data, markers[2*i], label=f\"{name[case]} backdoor\")\n", 281 | " else:\n", 282 | " plt.plot(fl_rounds, clean_data, label=f\"{name[case]} clean\")\n", 283 | " plt.plot(fl_rounds, backdoor_data, label=f\"{name[case]} backdoor\")\n", 284 | "\n", 285 | " plt.xlabel('Rounds')\n", 286 | " plt.ylabel('Accuracy')\n", 287 | " plt.locator_params(axis=\"x\", integer=True)\n", 288 | " plt.grid()\n", 289 | " plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))\n", 290 | "\n", 291 | " method_string = \"\"\n", 292 | " for i in methods:\n", 293 | " method_string += str(i)\n", 294 | "\n", 295 | " plt.savefig(f\"plot/{configs['dataset']}/{path[:-4]}_M{method_string}.png\", dpi=1200, bbox_inches='tight')\n", 296 | " \n", 297 | " plt.show()\n" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | 
"metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "def show_result_all(path, methods=[1, 2, 3, 4, 5], is_clean = True, is_marked=True):\n", 307 | "\n", 308 | " markers = [\"\", \"^\", \"s\", \"<\", \"o\", \"v\"]\n", 309 | " colors = [\"\", \"b\", \"orange\", \"g\", \"r\", \"k\"]\n", 310 | "\n", 311 | " num_rounds = 0\n", 312 | "\n", 313 | " for i in [3, 4, 5]:\n", 314 | " temp = 0\n", 315 | " if i == 3:\n", 316 | " temp = int(path.split(\"_\")[i][1:])\n", 317 | " else:\n", 318 | " temp = int(path.split(\"_\")[i][2:])\n", 319 | "\n", 320 | " num_rounds += temp\n", 321 | "\n", 322 | " num_rounds += num_onboarding_rounds \n", 323 | "\n", 324 | " fl_rounds = [i for i in range(1, num_rounds + 1)]\n", 325 | "\n", 326 | " filename_baseline = f\"case0_{path}\"\n", 327 | " baseline = load_gen(folder + filename_baseline)\n", 328 | "\n", 329 | " for i in methods:\n", 330 | " filename = f\"case{i}_{path}\"\n", 331 | " try:\n", 332 | " data = load_gen(folder + filename)\n", 333 | " except:\n", 334 | " continue\n", 335 | " case = f\"case{i}\"\n", 336 | "\n", 337 | " if is_clean:\n", 338 | " if i != 1:\n", 339 | " clean_data = baseline[\"clean\"] + data[\"clean\"]\n", 340 | " # backdoor_data = baseline[\"backdoor\"] + data[\"backdoor\"]\n", 341 | " else:\n", 342 | " clean_data = data[\"clean\"]\n", 343 | " # backdoor_data = data[\"backdoor\"]\n", 344 | "\n", 345 | " if is_marked:\n", 346 | " plt.plot(fl_rounds, clean_data, marker = markers[i], markevery= 10, color = colors[i], label=f\"{name[case]}\")\n", 347 | " # plt.plot(fl_rounds, backdoor_data, markers[2*i], label=f\"{name[case]} backdoor\")\n", 348 | " else:\n", 349 | " plt.plot(fl_rounds, clean_data, color = colors[i], label=f\"{name[case]}\")\n", 350 | " # plt.plot(fl_rounds, backdoor_data, label=f\"{name[case]} backdoor\")\n", 351 | " else:\n", 352 | " if i != 1:\n", 353 | " # clean_data = baseline[\"clean\"] + data[\"clean\"]\n", 354 | " backdoor_data = baseline[\"backdoor\"] + data[\"backdoor\"]\n", 355 | " else:\n", 356 | " # clean_data = data[\"clean\"]\n", 357 | " backdoor_data = data[\"backdoor\"]\n", 358 | "\n", 359 | " if is_marked:\n", 360 | " # plt.plot(fl_rounds, clean_data, markers[2*i-1], label=f\"{name[case]} clean\")\n", 361 | " plt.plot(fl_rounds, backdoor_data, marker = markers[i], markevery=10, color = colors[i], label=f\"{name[case]}\")\n", 362 | " else:\n", 363 | " # plt.plot(fl_rounds, clean_data, label=f\"{name[case]} clean\")\n", 364 | " plt.plot(fl_rounds, backdoor_data, color = colors[i], label=f\"{name[case]}\")\n", 365 | "\n", 366 | " plt.xlabel('Rounds')\n", 367 | " plt.ylabel('Accuracy')\n", 368 | " plt.locator_params(axis=\"x\", integer=True)\n", 369 | " plt.grid()\n", 370 | " plt.legend(loc='best')\n", 371 | "\n", 372 | " method_string = \"\"\n", 373 | " for i in methods:\n", 374 | " method_string += str(i)\n", 375 | "\n", 376 | " type = \"\"\n", 377 | " if is_clean:\n", 378 | " type = \"clean\"\n", 379 | " else:\n", 380 | " type = \"backdoor\"\n", 381 | "\n", 382 | " plt.savefig(f\"plot/{configs['dataset']}/{path[:-4]}_M{method_string}_{type}.png\", dpi=1200, bbox_inches='tight')\n", 383 | " plt.savefig(f\"plot/{configs['dataset']}/{path[:-4]}_M{method_string}_{type}.pdf\", dpi=1200, bbox_inches='tight')\n", 384 | " \n", 385 | " plt.show()\n" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": null, 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "# 1. 
Accuracy of all methods\n", 395 | "\n", 396 | "path = f\"{configs['dataset']}_C5_BS128_R{configs[configs['dataset']]['num_round']}_UR{configs[configs['dataset']]['num_unlearn_round']}_PR{configs[configs['dataset']]['num_post_training_round']}_E1_LR0.01.pkl\"\n", 397 | "\n", 398 | "show_result_all(path, is_clean=False)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "def show_last_round_result_before_onboarding(path, methods=[1, 2, 3, 4]):\n", 408 | " filename_baseline = f\"case0_{path}\"\n", 409 | " baseline = load_gen(folder + filename_baseline)\n", 410 | "\n", 411 | " clean_data = []\n", 412 | " backdoor_data = []\n", 413 | "\n", 414 | " clean_labels = []\n", 415 | " backdoor_labels = []\n", 416 | "\n", 417 | " method_names = [name[f\"case{i}\"] for i in methods]\n", 418 | " x_axis = np.arange(len(method_names))\n", 419 | "\n", 420 | " for i in methods:\n", 421 | " filename = f\"case{i}_{path}\"\n", 422 | " try:\n", 423 | " data = load_gen(folder + filename)\n", 424 | " except:\n", 425 | " continue\n", 426 | " case = f\"case{i}\"\n", 427 | "\n", 428 | "\n", 429 | " clean_data.append(data[\"clean\"][-configs[configs[\"dataset\"]][\"num_post_training_round\"]-1])\n", 430 | " backdoor_data.append(data[\"backdoor\"][-configs[configs[\"dataset\"]][\"num_post_training_round\"]-1])\n", 431 | "\n", 432 | " clean_label = f\"{name[case]} clean\"\n", 433 | " backdoor_label = f\"{name[case]} backdoor\"\n", 434 | " clean_labels.append(clean_label)\n", 435 | " backdoor_labels.append(backdoor_label)\n", 436 | "\n", 437 | " plt.bar(x_axis-0.2, clean_data, 0.4, label=\"clean\")\n", 438 | " plt.bar(x_axis+0.2, backdoor_data, 0.4, label=\"backdoor\")\n", 439 | "\n", 440 | " plt.xticks(x_axis, method_names)\n", 441 | " plt.xlabel('Methods')\n", 442 | " plt.ylabel('Accuracy')\n", 443 | " plt.title(\"Last Round Accuracy\")\n", 444 | " plt.grid()\n", 445 | " plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))\n", 446 | " \n", 447 | " plt.show()" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "# 2. 
Last round accuracy\n",
457 |     "\n",
458 |     "\"\"\"\n",
459 |     "    This cell is to run the second experiment: accuracy on the last round before onboarding\n",
460 |     "\"\"\"\n",
461 |     "\n",
462 |     "path = f\"{configs['dataset']}_C5_BS128_R{configs[configs['dataset']]['num_round']}_UR{configs[configs['dataset']]['num_unlearn_round']}_PR{configs[configs['dataset']]['num_post_training_round']}_E1_LR0.01.pkl\"\n",
463 |     "\n",
464 |     "show_last_round_result_before_onboarding(path, methods=[1, 2, 3, 4, 5])"
465 |    ]
466 |   },
467 |   {
468 |    "cell_type": "code",
469 |    "execution_count": null,
470 |    "metadata": {},
471 |    "outputs": [],
472 |    "source": [
473 |     "def show_numerical_result(path, methods=[1, 2, 3, 4], dataset=\"mnist\"):\n",
474 |     "    filename_baseline = f\"case0_{path}\"\n",
475 |     "    baseline = load_gen(folder + filename_baseline)\n",
476 |     "\n",
477 |     "    clean_data = []\n",
478 |     "    backdoor_data = []\n",
479 |     "\n",
480 |     "    clean_labels = []\n",
481 |     "    backdoor_labels = []\n",
482 |     "\n",
483 |     "    # method_names = [name[f\"case{i}\"] for i in methods]\n",
484 |     "    # x_axis = np.arange(len(method_names))\n",
485 |     "\n",
486 |     "    for i in methods:\n",
487 |     "        filename = f\"case{i}_{path}\"\n",
488 |     "        try:\n",
489 |     "            data = load_gen(folder + filename)\n",
490 |     "        except FileNotFoundError:\n",
491 |     "            continue\n",
492 |     "\n",
493 |     "        case = f\"case{i}\"\n",
494 |     "\n",
495 |     "        # clean_data.append(data[\"clean\"][-1])\n",
496 |     "        # backdoor_data.append(data[\"backdoor\"][-1])\n",
497 |     "\n",
498 |     "        clean_label = f\"{name[case]} clean\"\n",
499 |     "        # clean_labels.append(clean_label)\n",
500 |     "        # backdoor_labels.append(backdoor_label)\n",
501 |     "\n",
502 |     "        print(clean_label)\n",
503 |     "        \n",
504 |     "        res_str = \"\"\n",
505 |     "\n",
506 |     "        if i == 1:\n",
507 |     "            # i == 1: Retrain, whose log already spans the full timeline\n",
508 |     "            res_str += f\"{data['clean'][configs[dataset]['num_round'] - 1]} & {data['backdoor'][configs[dataset]['num_round'] - 1]} & \"\n",
509 |     "            res_str += f\"{data['clean'][configs[dataset]['num_round']]} & {data['backdoor'][configs[dataset]['num_round']]} & \"\n",
510 |     "            res_str += f\"{data['clean'][configs[dataset]['num_round'] + configs[dataset]['num_unlearn_round'] - 1]} & {data['backdoor'][configs[dataset]['num_round'] + configs[dataset]['num_unlearn_round'] - 1]} & \"\n",
511 |     "            res_str += f\"{data['clean'][configs[dataset]['num_round'] + configs[dataset]['num_unlearn_round']]} & {data['backdoor'][configs[dataset]['num_round'] + configs[dataset]['num_unlearn_round']]} & \"\n",
512 |     "            res_str += f\"{data['clean'][configs[dataset]['num_round'] + configs[dataset]['num_unlearn_round'] + configs[dataset]['num_post_training_round'] - 1]} & {data['backdoor'][configs[dataset]['num_round'] + configs[dataset]['num_unlearn_round'] + configs[dataset]['num_post_training_round'] - 1]} & \"\n",
513 |     "            res_str += f\"{data['clean'][configs[dataset]['num_round'] + configs[dataset]['num_unlearn_round'] + configs[dataset]['num_post_training_round']]} & {data['backdoor'][configs[dataset]['num_round'] + configs[dataset]['num_unlearn_round'] + configs[dataset]['num_post_training_round']]} & \"\n",
514 |     "            res_str += f\"{data['clean'][-1]} & {data['backdoor'][-1]}\"\n",
515 |     "        else:\n",
516 |     "            res_str += f\"{baseline['clean'][configs[dataset]['num_round'] - 1]} & {baseline['backdoor'][configs[dataset]['num_round'] - 1]} & \"\n",
517 |     "            res_str += f\"{data['clean'][0]} & {data['backdoor'][0]} & \"\n",
518 |     "            res_str += f\"{data['clean'][configs[dataset]['num_unlearn_round'] - 1]} & {data['backdoor'][configs[dataset]['num_unlearn_round'] - 1]} & 
\"\n", 519 | " res_str += f\"{data['clean'][configs[dataset]['num_unlearn_round']]} & {data['backdoor'][configs[dataset]['num_unlearn_round']]} & \"\n", 520 | " res_str += f\"{data['clean'][configs[dataset]['num_unlearn_round'] + configs[dataset]['num_post_training_round'] - 1]} & {data['backdoor'][configs[dataset]['num_unlearn_round'] + configs[dataset]['num_post_training_round'] - 1]} & \"\n", 521 | " res_str += f\"{data['clean'][configs[dataset]['num_unlearn_round'] + configs[dataset]['num_post_training_round']]} & {data['backdoor'][configs[dataset]['num_unlearn_round'] + configs[dataset]['num_post_training_round']]} & \"\n", 522 | " res_str += f\"{data['clean'][-1]} & {data['backdoor'][-1]}\"\n", 523 | "\n", 524 | "\n", 525 | " print(res_str)\n", 526 | " " 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": null, 532 | "metadata": {}, 533 | "outputs": [], 534 | "source": [ 535 | "path = f\"{configs['dataset']}_C5_BS128_R{configs[configs['dataset']]['num_round']}_UR{configs[configs['dataset']]['num_unlearn_round']}_PR{configs[configs['dataset']]['num_post_training_round']}_E1_LR0.01.pkl\"\n", 536 | "show_numerical_result(path, methods=[1, 2, 3, 4, 5], dataset = configs['dataset'])" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": null, 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "def load_time(filename, type=\"acc\"):\n", 546 | " with open(filename, 'rb') as fp:\n", 547 | " data = pickle.load(fp)[\"time\"]\n", 548 | " return data\n", 549 | "\n", 550 | "\n", 551 | "\n", 552 | "def show_time(path, methods):\n", 553 | " num_rounds = 0\n", 554 | " for i in [3, 4, 5]:\n", 555 | " temp = 0\n", 556 | " if i == 3:\n", 557 | " temp = int(path.split(\"_\")[i][1:])\n", 558 | " else:\n", 559 | " temp = int(path.split(\"_\")[i][2:])\n", 560 | " num_rounds += temp\n", 561 | "\n", 562 | " num_rounds += num_onboarding_rounds\n", 563 | "\n", 564 | " fl_rounds = [i for i in range(1, num_rounds + 1)]\n", 565 | "\n", 566 | " method_names = [name[f\"case{i}\"] for i in methods]\n", 567 | " x_axis = np.arange(len(method_names))\n", 568 | "\n", 569 | " retrain_time = 0\n", 570 | " factors = []\n", 571 | "\n", 572 | " for i in methods:\n", 573 | " filename = f\"case{i}_{path}\"\n", 574 | " try:\n", 575 | " time = load_time(folder + filename)\n", 576 | " except:\n", 577 | " print(filename)\n", 578 | " continue\n", 579 | "\n", 580 | " if i == 1:\n", 581 | " retrain_time = time\n", 582 | "\n", 583 | " factor = time/retrain_time\n", 584 | "\n", 585 | " factors.append(factor)\n", 586 | "\n", 587 | " # print(method_names)\n", 588 | " # print(factors)\n", 589 | " plt.bar(method_names, factors)\n", 590 | " plt.ylabel('Unit')\n", 591 | " plt.grid()\n", 592 | " #plt.locator_params(axis=\"x\", integer=True)\n", 593 | " #plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))\n", 594 | " plt.show()" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": null, 600 | "metadata": {}, 601 | "outputs": [], 602 | "source": [ 603 | "# 5. 
Unlearning time\n", 604 | "\n", 605 | "\"\"\"\n", 606 | " This cell is to run the fifth experiment: measuring unlearning time\n", 607 | "\"\"\"\n", 608 | "\n", 609 | "path = f\"{configs['dataset']}_C5_BS128_R{configs[configs['dataset']]['num_round']}_UR{configs[configs['dataset']]['num_unlearn_round']}_PR{configs[configs['dataset']]['num_post_training_round']}_E1_LR0.01.pkl\"\n", 610 | "show_time(path, methods=[1, 2, 3, 4, 5])" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": null, 616 | "metadata": {}, 617 | "outputs": [], 618 | "source": [ 619 | "def show_time_detail(path, methods):\n", 620 | " num_rounds = 0\n", 621 | " for i in [3, 4, 5]:\n", 622 | " temp = 0\n", 623 | " if i == 3:\n", 624 | " temp = int(path.split(\"_\")[i][1:])\n", 625 | " else:\n", 626 | " temp = int(path.split(\"_\")[i][2:])\n", 627 | " num_rounds += temp\n", 628 | "\n", 629 | " num_rounds += num_onboarding_rounds\n", 630 | "\n", 631 | " fl_rounds = [i for i in range(1, num_rounds + 1)]\n", 632 | "\n", 633 | " method_names = [name[f\"case{i}\"] for i in methods]\n", 634 | " x_axis = np.arange(len(method_names))\n", 635 | "\n", 636 | " retrain_time = 0\n", 637 | " factors = []\n", 638 | "\n", 639 | " for i in methods:\n", 640 | " filename = f\"case{i}_{path}\"\n", 641 | " try:\n", 642 | " time = load_time(folder + filename)\n", 643 | " except:\n", 644 | " print(filename)\n", 645 | " continue\n", 646 | "\n", 647 | " if i == 1:\n", 648 | " retrain_time = time\n", 649 | "\n", 650 | " factor = time/retrain_time\n", 651 | "\n", 652 | " factors.append(factor)\n", 653 | "\n", 654 | " print(f\"{method_names[i-1]}: {time:.2f}({(retrain_time / time):.2f}x)\")\n" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": null, 660 | "metadata": {}, 661 | "outputs": [], 662 | "source": [ 663 | "# 5. 
Unlearning time\n", 664 | "\n", 665 | "\"\"\"\n", 666 | " This cell is to run the fifth experiment: measuring unlearning time\n", 667 | "\"\"\"\n", 668 | "\n", 669 | "path = f\"{configs['dataset']}_C5_BS128_R{configs[configs['dataset']]['num_round']}_UR{configs[configs['dataset']]['num_unlearn_round']}_PR{configs[configs['dataset']]['num_post_training_round']}_E1_LR0.01.pkl\"\n", 670 | "show_time_detail(path, methods=[1, 2, 3, 4, 5])" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": null, 676 | "metadata": {}, 677 | "outputs": [], 678 | "source": [ 679 | "args = {\n", 680 | " \"dataset\": configs['dataset'],\n", 681 | " \"num_clients\": 5,\n", 682 | " \"batch_size\": 128,\n", 683 | " \"num_rounds\": configs[configs['dataset']]['num_round'],\n", 684 | " \"num_unlearn_rounds\": configs[configs['dataset']]['num_unlearn_round'],\n", 685 | " \"num_post_training_rounds\": configs[configs['dataset']]['num_post_training_round'],\n", 686 | " \"local_epochs\": 1,\n", 687 | " \"lr\": 0.01,\n", 688 | " \"device\": torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n", 689 | " \"poisoned_percent\": 0.9\n", 690 | "}" 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": null, 696 | "metadata": {}, 697 | "outputs": [], 698 | "source": [ 699 | "def load_model(path):\n", 700 | " model = get_model(args, plotting=True)\n", 701 | " model.load_state_dict(torch.load(path))\n", 702 | "\n", 703 | " return model" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": null, 709 | "metadata": {}, 710 | "outputs": [], 711 | "source": [ 712 | "# load the baseline model after learning phase\n", 713 | "case = 2\n", 714 | "\n", 715 | "path = f\"../results/models/case1/case1_{args['dataset']}_C{args['num_clients']}_BS{args['batch_size']}_R{args['num_rounds']}_UR{args['num_unlearn_rounds']}_PR{args['num_post_training_rounds']}_E{args['local_epochs']}_LR{args['lr']}_round{args['num_rounds'] - 1}.pt\"\n", 716 | "baseline_model = load_model(path)\n", 717 | "\n", 718 | "# path2 = f\"../results/models/case{case}/case{case}_{args['dataset']}_C{args['num_clients']}_BS{args['batch_size']}_R{args['num_rounds']}_UR{args['num_unlearn_rounds']}_PR{args['num_post_training_rounds']}_E{args['local_epochs']}_LR{args['lr']}_round{args['num_rounds']}.pt\"\n", 719 | "# model2 = load_model(path2)" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": null, 725 | "metadata": {}, 726 | "outputs": [], 727 | "source": [ 728 | "def plot_diff(X, Y, title):\n", 729 | " Y = [y.cpu().numpy() for y in Y]\n", 730 | " \n", 731 | " plt.plot(X, Y)\n", 732 | "\n", 733 | " plt.xlabel('Rounds')\n", 734 | " plt.ylabel('Difference')\n", 735 | "\n", 736 | " plt.xticks(np.arange(min(X), max(X)+1, len(X) // 10))\n", 737 | " \n", 738 | " plt.title(title)\n", 739 | " \n", 740 | " plt.show()" 741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "execution_count": null, 746 | "metadata": {}, 747 | "outputs": [], 748 | "source": [ 749 | "def compare_prediction(model1, model2, data_loader):\n", 750 | " model1.eval()\n", 751 | " model2.eval()\n", 752 | "\n", 753 | " output1s = torch.tensor([])\n", 754 | " output2s = torch.tensor([])\n", 755 | "\n", 756 | " with torch.no_grad():\n", 757 | " for data, target in data_loader:\n", 758 | " data = data.to(args[\"device\"])\n", 759 | " target = target.to(args[\"device\"])\n", 760 | "\n", 761 | " output1 = model1(data).argmax(dim=1).detach().cpu().float()\n", 762 | " output2 = 
model2(data).argmax(dim=1).detach().cpu().float()\n", 763 | "\n", 764 | " output1s = torch.cat((output1s, output1))\n", 765 | " output2s = torch.cat((output2s, output2))\n", 766 | " \n", 767 | " \n", 768 | " cos = torch.nn.CosineSimilarity(dim=0, eps=1e-9)\n", 769 | " return cos(output1s, output2s)\n", 770 | "\n" 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "execution_count": null, 776 | "metadata": {}, 777 | "outputs": [], 778 | "source": [ 779 | "from utils.dataloader import get_loaders\n", 780 | "train_loaders, test_loader, test_loader_poison = get_loaders(args, plotting=True)\n", 781 | "\n", 782 | "\n", 783 | "markers = [\"\", \"\", \"^\", \"s\", \"<\", \"o\", \"v\"]\n", 784 | "colors = [\"\", \"\", \"b\", \"orange\", \"g\", \"r\", \"k\"]\n", 785 | "\n", 786 | "for case in [2,3,4,5]:\n", 787 | " X = []\n", 788 | " Y = []\n", 789 | "\n", 790 | " for i in range(args['num_rounds'], args['num_rounds'] + args['num_unlearn_rounds'] + args['num_post_training_rounds']):\n", 791 | " path = f\"../results/models/case{case}/case{case}_{args['dataset']}_C{args['num_clients']}_BS{args['batch_size']}_R{args['num_rounds']}_UR{args['num_unlearn_rounds']}_PR{args['num_post_training_rounds']}_E{args['local_epochs']}_LR{args['lr']}_round{i}.pt\"\n", 792 | " unlearned_model = load_model(path)\n", 793 | "\n", 794 | " cos_sim = compare_prediction(unlearned_model, baseline_model, test_loader)\n", 795 | " # print(cos_sim)\n", 796 | " X.append(i)\n", 797 | " Y.append(cos_sim)\n", 798 | "\n", 799 | " \n", 800 | " Y = [y.cpu().numpy() for y in Y]\n", 801 | " \n", 802 | " case_name = f\"case{case}\"\n", 803 | "\n", 804 | " plt.plot(X, Y, marker = markers[case], markevery= 10, color = colors[case], label=f\"{name[case_name]}\")\n", 805 | " \n", 806 | "\n", 807 | "plt.xticks(np.arange(min(X), max(X)+1, len(X) // 10))\n", 808 | "\n", 809 | "plt.xlabel('Rounds')\n", 810 | "plt.ylabel('Cosine Similarity')\n", 811 | "\n", 812 | "plt.grid()\n", 813 | "plt.legend(loc='best')\n", 814 | "\n", 815 | "\n", 816 | "plt.savefig(f\"plot/{configs['dataset']}/Cosine_Similarity.png\", dpi=1200, bbox_inches='tight')\n", 817 | "plt.savefig(f\"plot/{configs['dataset']}/Cosine_Similarity.pdf\", dpi=1200, bbox_inches='tight')\n", 818 | "\n", 819 | "plt.show()" 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": null, 825 | "metadata": {}, 826 | "outputs": [], 827 | "source": [] 828 | } 829 | ], 830 | "metadata": { 831 | "kernelspec": { 832 | "display_name": "Python 3 (ipykernel)", 833 | "language": "python", 834 | "name": "python3" 835 | }, 836 | "language_info": { 837 | "codemirror_mode": { 838 | "name": "ipython", 839 | "version": 3 840 | }, 841 | "file_extension": ".py", 842 | "mimetype": "text/x-python", 843 | "name": "python", 844 | "nbconvert_exporter": "python", 845 | "pygments_lexer": "ipython3", 846 | "version": "3.10.9" 847 | }, 848 | "orig_nbformat": 4, 849 | "vscode": { 850 | "interpreter": { 851 | "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" 852 | } 853 | } 854 | }, 855 | "nbformat": 4, 856 | "nbformat_minor": 2 857 | } 858 | --------------------------------------------------------------------------------
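The experiment list at the top of usage.ipynb includes "3. Params similarity", but the sample notebook only shows the prediction-similarity comparison above. A minimal sketch of a parameter-space comparison, assuming the `load_model` helper and `args` dict defined in the notebook; the function name `param_similarity` is illustrative, not from the repo:

import torch

def param_similarity(model_a, model_b):
    # flatten every tensor in each state_dict into one long vector
    va = torch.cat([p.flatten().float() for p in model_a.state_dict().values()])
    vb = torch.cat([p.flatten().float() for p in model_b.state_dict().values()])
    return torch.nn.functional.cosine_similarity(va, vb, dim=0)

# e.g. sim = param_similarity(baseline_model, unlearned_model)
# print(f"parameter cosine similarity: {sim.item():.4f}")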