├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   ├── __init__.py
│   ├── cifar
│   │   └── .gitkeep
│   └── mnist
│       └── .gitkeep
├── imgs
│   ├── 05cifar_lenet.pdf
│   ├── 05cifar_vgg.pdf
│   ├── 05fmnist_lenet.pdf
│   ├── 09cifar_lenet.pdf
│   ├── 09cifar_vgg.pdf
│   ├── 09fmnist_lenet.pdf
│   ├── 2cifar_lenet.pdf
│   ├── 2cifar_vgg.pdf
│   ├── 2fmnist_lenet.pdf
│   ├── fed_acc.pdf
│   └── local_acc.pdf
├── main_fed.py
├── main_gate.py
├── main_local.py
├── main_nn.py
├── main_per_fb.py
├── models
│   ├── Fed.py
│   ├── Nets.py
│   ├── Test.py
│   ├── Update.py
│   └── __init__.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── options.py
    ├── sampling.py
    └── util.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
imgs/*_f.pdf
runs/*
save/
plot*
test.py
main__gate.py
_config.yml
main_gate_single.py

# pycharm
.idea/*

# documents
*.csv
*.xls
*.xlsx
*.pdf
*.json

# macOS
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# virtualenv
.venv
venv/
ENV/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Shaoxiong Ji

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PFL-MoE: Personalized Federated Learning Based on Mixture of Experts

In our experiments, we use two image recognition datasets for model training, Fashion-MNIST and CIFAR-10, together with two network models, LeNet-5 and VGG-16. This gives three dataset+model combinations: Fashion-MNIST + LeNet-5, CIFAR-10 + LeNet-5, and CIFAR-10 + VGG-16.
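The mixture-of-experts idea at the core of this repository is easiest to see in `CNNGate` ([models/Nets.py](models/Nets.py)): a frozen copy of the FedAvg global classifier head produces logits `y`, a locally fine-tuned head produces logits `z`, and a per-example linear-plus-sigmoid gate mixes them as `y * g + z * (1 - g)`. PFL-MF feeds the shared convolutional features to the gate, while PFL-MFE (the `--struct` variant) feeds the raw input image. The toy module below restates only that blending step; `GateBlend` and its argument names are illustrative, not part of the repository:

```python
import torch
import torch.nn as nn


class GateBlend(nn.Module):
    """Toy restatement of the gating used in models/Nets.py:CNNGate (illustrative only)."""

    def __init__(self, gate_in_dim: int):
        super().__init__()
        self.gate = nn.Linear(gate_in_dim, 1)  # one scalar gate value per example

    def forward(self, gate_in, y, z):
        # gate_in: gate input (conv features for PFL-MF, raw image for PFL-MFE)
        # y: logits of the frozen global (FedAvg) expert
        # z: logits of the personalized expert
        g = torch.sigmoid(self.gate(torch.flatten(gate_in, 1)))  # g in (0, 1)
        return y * g + z * (1 - g)  # two-expert mixture, as in CNNGate.forward
```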
## Requirements
python>=3.6
pytorch>=0.4
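The repository also ships a [requirements.txt](requirements.txt); assuming the versions pinned there are compatible with your platform, the dependencies can be installed with:
> pip install -r requirements.txt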
## Run
dataset+model: fmnist+lenet, cifar+lenet, cifar+vgg
$\alpha=[0.5, 0.9, 2.0]$ for each dataset+model combination.

Local:
> python [main_local.py](main_local.py) --dataset fmnist --model lenet --epochs 100 --gpu 0 --num_users 100 --alpha 0.5

FedAvg:
> python [main_fed.py](main_fed.py) --dataset fmnist --model lenet --epochs 1000 --gpu 0 --lr 0.01 --num_users 100 --frac 0.1 --alpha 0.5

PFL-FB + PFL-MF:
> python [main_gate.py](main_gate.py) --dataset fmnist --model lenet --epochs 200 --num_users 100 --gpu 1 --alpha 0.5

PFL-FB + PFL-MFE:
> python [main_gate.py](main_gate.py) --dataset fmnist --model lenet --epochs 200 --num_users 100 --gpu 1 --alpha 0.5 --struct

Note that [main_gate.py](main_gate.py) initializes from the best FedAvg checkpoint saved by [main_fed.py](main_fed.py) under `./save/exp/fed/`, so run FedAvg first; pass `--rebuild` on the first gate run to generate and cache the per-user data split.

See the arguments in [options.py](utils/options.py).
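The `--alpha` flag controls how non-IID the per-client data is; the actual split is produced by `cifar_noniid` in [sampling.py](utils/sampling.py), which is not reproduced in this dump. The sketch below shows one standard way such a split is implemented, assuming `cifar_noniid` draws each class's per-client proportions from a Dirichlet($\alpha$) distribution (smaller $\alpha$ means more skewed clients); `dirichlet_split` is a hypothetical stand-in, not the repository's code:

```python
import numpy as np


def dirichlet_split(labels, num_users, alpha, seed=1):
    # Hypothetical stand-in for utils/sampling.py:cifar_noniid (details may differ).
    # For each class, divide its sample indices among users with proportions
    # drawn from Dirichlet(alpha): small alpha -> skewed, large alpha -> near-IID.
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    dict_users = {i: [] for i in range(num_users)}
    for c in np.unique(labels):
        idx_c = rng.permutation(np.where(labels == c)[0])
        proportions = rng.dirichlet(alpha * np.ones(num_users))
        cut_points = (np.cumsum(proportions)[:-1] * len(idx_c)).astype(int)
        for user, shard in enumerate(np.split(idx_c, cut_points)):
            dict_users[user].extend(shard.tolist())
    return {user: np.array(idx) for user, idx in dict_users.items()}


# e.g. dict_users = dirichlet_split(dataset_train.targets, num_users=100, alpha=0.5)
```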
## Results

Each client is evaluated with two tests: a **local test** and a **global test** (defined precisely after Table 2).

Table 1. Average **local test** accuracy over all clients, for the three baselines (Local, FedAvg, PFL-FB) and the proposed algorithms (PFL-MF, PFL-MFE).

| Dataset & model | non-IID ($\alpha$) | Local (%) | FedAvg (%) | PFL-FB (%) | PFL-MF (%) | PFL-MFE (%) |
|---|---|---|---|---|---|---|
| Fashion-MNIST & LeNet-5 | 0.5 | 84.87 | 90 | 92.84 | 92.85 | 92.89 |
| | 0.9 | 82.23 | 90.31 | 91.84 | 92.02 | 92.01 |
| | 2 | 78.63 | 90.5 | 90.47 | 90.97 | 90.93 |
| CIFAR-10 & LeNet-5 | 0.5 | 65.58 | 68.92 | 77.46 | 75.49 | 77.23 |
| | 0.9 | 61.49 | 70.7 | 74.7 | 74.1 | 74.74 |
| | 2 | 55.8 | 72.69 | 72.5 | 73.24 | 73.44 |
| CIFAR-10 & VGG-16 | 0.5 | 52.77 | 88.16 | 91.92 | 90.63 | 91.71 |
| | 0.9 | 45.24 | 88.45 | 91.34 | 90.63 | 91.18 |
| | 2 | 34.2 | 89.17 | 90.4 | 90.15 | 90.4 |
Table 2. Average **global test** accuracy over all clients.
| Dataset & model | non-IID ($\alpha$) | Local (%) | FedAvg (%) | PFL-FB (%) | PFL-MF (%) | PFL-MFE (%) |
|---|---|---|---|---|---|---|
| Fashion-MNIST & LeNet-5 | 0.5 | 57.77 | 90 | 83.35 | 85.45 | 85.3 |
| | 0.9 | 65.28 | 90.31 | 85.91 | 87.69 | 87.67 |
| | 2 | 71.06 | 90.5 | 87.77 | 89.37 | 89.18 |
| CIFAR-10 & LeNet-5 | 0.5 | 28.89 | 68.92 | 54.28 | 62.33 | 58.27 |
| | 0.9 | 32.1 | 70.7 | 59.93 | 65.78 | 64.13 |
| | 2 | 35.32 | 72.69 | 66.06 | 69.79 | 69.78 |
| CIFAR-10 & VGG-16 | 0.5 | 21.53 | 88.16 | 82.39 | 85.81 | 84.05 |
| | 0.9 | 22.45 | 88.45 | 82.62 | 88.15 | 87.9 |
| | 2 | 21.27 | 89.17 | 88.77 | 89.3 | 89.3 |
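For reference, the two metrics are computed as follows (see `user_test` in [models/Test.py](models/Test.py)): the global test accuracy is the plain accuracy on the full shared test set, while a client's local test accuracy re-weights the per-class accuracies $\mathrm{acc}_c$ on that test set by the class proportions $w_c$ of the client's own training data:

$$\text{local acc} = 100 \sum_{c=0}^{9} w_c \, \mathrm{acc}_c, \qquad w_c = \frac{n_c}{\sum_{c'=0}^{9} n_{c'}},$$

where $n_c$ is the number of class-$c$ samples in the client's training split. The tables report the mean of each metric over all clients.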
Fig 1. Fashion-MNIST + LeNet-5, $\alpha=0.9$. Global test accuracy and local test accuracy of all clients for the PFL-FB, PFL-MF, and PFL-MFE algorithms. The x-axis is each client's FedAvg local test accuracy (so it can be read as a client index); each point compares the test accuracy of one PFL algorithm with FedAvg for that client.

![fmnist_lenet_0.9](https://github.com/guobbin/PFL-MoE/blob/master/imgs/09fmnist_lenet.pdf)

Fig 2. CIFAR-10 + LeNet-5, $\alpha=0.9$.

![cifar_lenet_0.9](https://github.com/guobbin/PFL-MoE/blob/master/imgs/09cifar_lenet.pdf)

Fig 3. CIFAR-10 + VGG-16, $\alpha=2.0$.

![cifar_vgg_2.0](https://github.com/guobbin/PFL-MoE/blob/master/imgs/2cifar_vgg.pdf)

## Acknowledgements

Thanks to [shaoxiongji](https://github.com/shaoxiongji), whose federated learning code this repository builds on.

--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
# Data

The MNIST & CIFAR-10 datasets are downloaded automatically by the torchvision package.

--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6

--------------------------------------------------------------------------------
/data/cifar/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/data/cifar/.gitkeep

--------------------------------------------------------------------------------
/data/mnist/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/data/mnist/.gitkeep

--------------------------------------------------------------------------------
/imgs/05cifar_lenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/05cifar_lenet.pdf

--------------------------------------------------------------------------------
/imgs/05cifar_vgg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/05cifar_vgg.pdf

--------------------------------------------------------------------------------
/imgs/05fmnist_lenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/05fmnist_lenet.pdf

--------------------------------------------------------------------------------
/imgs/09cifar_lenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/09cifar_lenet.pdf

--------------------------------------------------------------------------------
/imgs/09cifar_vgg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/09cifar_vgg.pdf
-------------------------------------------------------------------------------- /imgs/09fmnist_lenet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/09fmnist_lenet.pdf -------------------------------------------------------------------------------- /imgs/2cifar_lenet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/2cifar_lenet.pdf -------------------------------------------------------------------------------- /imgs/2cifar_vgg.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/2cifar_vgg.pdf -------------------------------------------------------------------------------- /imgs/2fmnist_lenet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/2fmnist_lenet.pdf -------------------------------------------------------------------------------- /imgs/fed_acc.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/fed_acc.pdf -------------------------------------------------------------------------------- /imgs/local_acc.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/local_acc.pdf -------------------------------------------------------------------------------- /main_fed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import numpy as np 7 | from torchvision import datasets, transforms 8 | import torch 9 | 10 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid 11 | from utils.options import args_parser 12 | from models.Update import LocalUpdate 13 | from models.Nets import vgg16, CNNCifar 14 | from models.Fed import FedAvg 15 | from models.Test import test_img 16 | from utils.util import setup_seed 17 | from datetime import datetime 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | 21 | if __name__ == '__main__': 22 | # parse args 23 | args = args_parser() 24 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 25 | setup_seed(args.seed) 26 | 27 | # log 28 | current_time = datetime.now().strftime('%b.%d_%H.%M.%S') 29 | TAG = 'exp/fed/{}_{}_{}_C{}_iid{}_{}_user{}_{}'.format(args.dataset, args.model, args.epochs, args.frac, args.iid, 30 | args.alpha, args.num_users, current_time) 31 | # TAG = f'alpha_{alpha}/data_distribution' 32 | logdir = f'runs/{TAG}' if not args.debug else f'runs2/{TAG}' 33 | writer = SummaryWriter(logdir) 34 | 35 | # load dataset and split users 36 | if args.dataset == 'mnist': 37 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 38 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) 39 | dataset_test = datasets.MNIST('../data/mnist/', train=False, 
download=True, transform=trans_mnist) 40 | # sample users 41 | if args.iid: 42 | dict_users = mnist_iid(dataset_train, args.num_users) 43 | else: 44 | dict_users = mnist_noniid(dataset_train, args.num_users) 45 | elif args.dataset == 'cifar': 46 | transform_train = transforms.Compose([ 47 | transforms.RandomCrop(32, padding=4), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 51 | ]) 52 | transform_test = transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 55 | ]) 56 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=transform_train) 57 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=transform_test) 58 | elif args.dataset == 'fmnist': 59 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True, 60 | transform=transforms.Compose([ 61 | transforms.Resize((32, 32)), 62 | transforms.RandomCrop(32, padding=4), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.1307,), (0.3081,)), 66 | ])) 67 | 68 | # testing 69 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True, 70 | transform=transforms.Compose([ 71 | transforms.Resize((32, 32)), 72 | transforms.ToTensor(), 73 | transforms.Normalize((0.1307,), (0.3081,)) 74 | ])) 75 | # test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 76 | else: 77 | exit('Error: unrecognized dataset') 78 | 79 | if args.iid: 80 | dict_users = cifar_iid(dataset_train, args.num_users) 81 | else: 82 | dict_users, _ = cifar_noniid(dataset_train, args.num_users, args.alpha) 83 | for k, v in dict_users.items(): 84 | writer.add_histogram(f'user_{k}/data_distribution', 85 | np.array(dataset_train.targets)[v], 86 | bins=np.arange(11)) 87 | writer.add_histogram(f'all_user/data_distribution', 88 | np.array(dataset_train.targets)[v], 89 | bins=np.arange(11), global_step=k) 90 | 91 | # build model 92 | if args.model == 'lenet' and (args.dataset == 'cifar' or args.dataset == 'fmnist'): 93 | net_glob = CNNCifar(args=args).to(args.device) 94 | elif args.model == 'vgg' and args.dataset == 'cifar': 95 | net_glob = vgg16().to(args.device) 96 | else: 97 | exit('Error: unrecognized model') 98 | print(net_glob) 99 | net_glob.train() 100 | 101 | # copy weights 102 | w_glob = net_glob.state_dict() 103 | 104 | # training 105 | loss_train = [] 106 | cv_loss, cv_acc = [], [] 107 | val_loss_pre, counter = 0, 0 108 | net_best = None 109 | best_loss = None 110 | val_acc_list, net_list = [], [] 111 | test_best_acc = 0.0 112 | 113 | for iter in range(args.epochs): 114 | w_locals, loss_locals = [], [] 115 | m = max(int(args.frac * args.num_users), 1) 116 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 117 | for idx in idxs_users: 118 | local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx]) 119 | w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device)) 120 | w_locals.append(w) 121 | loss_locals.append(loss) 122 | # update global weights 123 | w_glob = FedAvg(w_locals) 124 | 125 | # copy weight to net_glob 126 | net_glob.load_state_dict(w_glob) 127 | 128 | # print loss 129 | loss_avg = sum(loss_locals) / len(loss_locals) 130 | print('Round {:3d}, Train loss {:.3f}'.format(iter, loss_avg)) 131 | loss_train.append(loss_avg) 132 | writer.add_scalar('train_loss', loss_avg, iter) 133 | 
test_acc, test_loss = test_img(net_glob, dataset_test, args) 134 | writer.add_scalar('test_loss', test_loss, iter) 135 | writer.add_scalar('test_acc', test_acc, iter) 136 | 137 | save_info = { 138 | "model": net_glob.state_dict(), 139 | "epoch": iter 140 | } 141 | # save model weights 142 | if (iter+1) % 500 == 0: 143 | save_path = f'./save2/{TAG}_{iter+1}es' if args.debug else f'./save/{TAG}_{iter+1}es' 144 | torch.save(save_info, save_path) 145 | if iter > 100 and test_acc > test_best_acc: 146 | test_best_acc = test_acc 147 | save_path = f'./save2/{TAG}_bst' if args.debug else f'./save/{TAG}_bst' 148 | torch.save(save_info, save_path) 149 | 150 | # plot loss curve 151 | # plt.figure() 152 | # plt.plot(range(len(loss_train)), loss_train) 153 | # plt.ylabel('train_loss') 154 | # plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid)) 155 | 156 | # testing 157 | net_glob.eval() 158 | acc_train, loss_train = test_img(net_glob, dataset_train, args) 159 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 160 | print("Training accuracy: {:.2f}".format(acc_train)) 161 | print("Testing accuracy: {:.2f}".format(acc_test)) 162 | writer.close() 163 | -------------------------------------------------------------------------------- /main_gate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | from utils.options import args_parser 11 | from models.Nets import CNNGate, gate_vgg16 12 | from utils.util import setup_seed, add_scalar 13 | from torch.utils.tensorboard import SummaryWriter 14 | from datetime import datetime 15 | from utils.sampling import cifar_noniid 16 | import numpy as np 17 | from models.Update import DatasetSplit 18 | from models.Test import user_test, user_per_test 19 | 20 | 21 | if __name__ == '__main__': 22 | # parse args 23 | args = args_parser() 24 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 25 | setup_seed(args.seed) 26 | 27 | # log 28 | current_time = datetime.now().strftime('%b.%d_%H.%M.%S') 29 | TAG = 'exp/{}gate2/{}_{}_{}_{}_user{}_{}'.format('struct/' if args.struct else '', args.dataset, args.model, args.epochs, 30 | args.alpha, args.num_users, current_time) 31 | TAG2 = 'exp/{}per_fb/{}_{}_{}_{}_user{}_{}'.format('struct/' if args.struct else '', args.dataset, args.model, args.epochs, 32 | args.alpha, args.num_users, current_time) 33 | logdir = f'runs/{TAG}' 34 | logdir2 = f'runs/{TAG2}' 35 | if args.debug: 36 | logdir = f'runs2/{TAG}' 37 | logdir2 = f'runs2/{TAG2}' 38 | writer = SummaryWriter(logdir) 39 | writer2 = SummaryWriter(logdir2) 40 | 41 | # load dataset and split users 42 | train_loader, test_loader, class_weight = 1, 1, 1 43 | 44 | save_dataset_path = f'./data/{args.dataset}_non_iid{args.alpha}_user{args.num_users}_fast_data' 45 | # global_weight = torch.load('./save/exp/fed/cifar_resnet_1000_C0.1_iidFalse_2.0_user100_Nov.28_01.41.16_bst')['model'] 46 | global_weight = torch.load( 47 | f'./save/exp/fed/{args.dataset}_{args.model}_1000_C0.1_iidFalse_{args.alpha}_user{args.num_users}_bst')[ 48 | 'model'] 49 | if 'gate.weight' in global_weight: 50 | del (global_weight['gate.weight']) 51 | del (global_weight['gate.bias']) 52 | if args.rebuild: 
53 | if args.dataset == "cifar": 54 | # training 55 | transform_train = transforms.Compose([ 56 | transforms.RandomCrop(32, padding=4), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 60 | ]) 61 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform_train, download=True) 62 | # testing 63 | transform_test = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 66 | ]) 67 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, transform=transform_test, download=True) 68 | elif args.dataset == "fmnist": 69 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True, 70 | transform=transforms.Compose([ 71 | transforms.Resize((32, 32)), 72 | transforms.RandomCrop(32, padding=4), 73 | transforms.RandomHorizontalFlip(), 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.1307,), (0.3081,)), 76 | ])) 77 | 78 | # testing 79 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True, 80 | transform=transforms.Compose([ 81 | transforms.Resize((32, 32)), 82 | transforms.ToTensor(), 83 | transforms.Normalize((0.1307,), (0.3081,)) 84 | ])) 85 | else: 86 | exit('Error: unrecognized dataset') 87 | # non_iid 88 | dict_users, _ = cifar_noniid(dataset_train, args.num_users, args.alpha) 89 | save_dataset = { 90 | "dataset_test": dataset_test, 91 | "dataset_train": dataset_train, 92 | "dict_users": dict_users 93 | } 94 | torch.save(save_dataset, save_dataset_path) 95 | else: 96 | save_dataset = torch.load(save_dataset_path) 97 | dataset_test = save_dataset['dataset_test'] 98 | dataset_train = save_dataset['dataset_train'] 99 | dict_users = save_dataset['dict_users'] 100 | 101 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 102 | for k, v in dict_users.items(): 103 | writer.add_histogram(f'user_{k}/data_distribution', 104 | np.array(dataset_train.targets)[v], 105 | bins=np.arange(11)) 106 | writer.add_histogram(f'all_user/data_distribution', 107 | np.array(dataset_train.targets)[v], 108 | bins = np.arange(11), global_step = k) 109 | 110 | # build model 111 | if args.model == 'lenet' and (args.dataset == 'cifar' or args.dataset == 'fmnist'): 112 | net_glob = CNNGate(args=args).to(args.device) 113 | elif args.model == 'vgg' and args.dataset == 'cifar': 114 | net_glob = gate_vgg16(args=args).to(args.device) 115 | else: 116 | exit('Error: unrecognized model') 117 | image, target = next(iter(test_loader)) 118 | writer.add_graph(net_glob, image.to(args.device)) 119 | 120 | gate_epochs = 200 121 | 122 | local_acc = np.zeros([args.num_users, args.epochs + gate_epochs + 1]) 123 | total_acc = np.zeros([args.num_users, args.epochs + gate_epochs + 1]) 124 | local_acc2 = np.zeros([args.num_users, args.epochs + gate_epochs + 1]) 125 | total_acc2 = np.zeros([args.num_users, args.epochs + gate_epochs + 1]) 126 | 127 | for user_num in range(len(dict_users)): 128 | # user data 129 | user_train = DatasetSplit(dataset_train, dict_users[user_num]) 130 | 131 | np.random.shuffle(dict_users[user_num]) 132 | cut_point = len(dict_users[user_num]) // 4 133 | train_loader = DataLoader(DatasetSplit(dataset_train, dict_users[user_num][cut_point:]), 134 | batch_size=64, shuffle=True) 135 | gate_loader = DataLoader(DatasetSplit(dataset_train, dict_users[user_num][:cut_point]), 136 | batch_size=64, shuffle=True) 137 | 138 | class_weight = np.zeros(10) 139 | 
for image, label in user_train: 140 | class_weight[label] += 1 141 | class_weight /= sum(class_weight) 142 | 143 | # init 144 | 145 | net_glob.load_state_dict(global_weight, False) 146 | 147 | if args.model == 'lenet': 148 | keys_ind = ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'] 149 | net_glob.load_state_dict({'p' + k: global_weight[k] for k in keys_ind}, strict=False) 150 | elif args.model == 'vgg': 151 | keys_ind = ['classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'] 152 | net_glob.load_state_dict({'p' + k: global_weight[k] for k in keys_ind}, strict=False) 153 | else: 154 | exit("Error: unrecognized model") 155 | net_glob.gate.reset_parameters() 156 | 157 | # training 158 | if args.model == 'lenet': 159 | layer_set = {'p' + k[:k.rindex('.')] for k in keys_ind} 160 | optimizer = optim.SGD([{'params': getattr(net_glob, l).parameters()} for l in layer_set], 161 | lr=0.001, momentum=0.9, weight_decay=5e-4) 162 | elif args.model == 'vgg': 163 | layer_set = {k[len('pclassifier'):k.rindex('.')] for k in keys_ind} 164 | optimizer = optim.SGD([{'params': net_glob.pclassifier.parameters()}], 165 | lr=0.005, momentum=0.9, weight_decay=5e-4) 166 | else: 167 | exit('Error: unrecognized model') 168 | # optimizer_gate = optim.SGD([{'params': net_glob.gate.parameters()}], lr=0.001, momentum=0.9, weight_decay=5e-4) 169 | criterion = nn.CrossEntropyLoss() 170 | 171 | test_result = user_per_test(args, net_glob, test_loader, class_weight) 172 | add_scalar(writer2, user_num, test_result, 0) 173 | total_acc2[user_num][0] = test_result[1] 174 | local_acc2[user_num][0] = test_result[3] 175 | 176 | for epoch in range(1, args.epochs+1): 177 | net_glob.train() 178 | batch_loss = [] 179 | gate_out = [] 180 | for batch_idx, (data, target) in enumerate(train_loader): 181 | data, target = data.to(args.device), target.to(args.device) 182 | optimizer.zero_grad() 183 | output, g, z = net_glob(data) 184 | gate_out.append(g) 185 | loss = criterion(z, target) 186 | loss.backward() 187 | optimizer.step() 188 | batch_loss.append(loss.item()) 189 | if epoch % 10 == 1: 190 | # writer.add_histogram(f"user_{user_num}/gate_out", torch.cat(gate_out[0:-1], -1), epoch) 191 | if args.model == 'lenet': 192 | for layer in layer_set: 193 | writer.add_histogram(f"user_{user_num}/{layer}/weight", getattr(net_glob, layer).weight, epoch) 194 | elif args.model == 'vgg': 195 | for layer in layer_set: 196 | writer.add_histogram(f"user_{user_num}/pclassifier.{layer}/weight", getattr(net_glob.pclassifier, layer).weight, epoch) 197 | loss_avg = sum(batch_loss) / len(batch_loss) 198 | print(f'User {user_num} train loss:', loss_avg) 199 | writer2.add_scalar(f'user_{user_num}/pfc_train_loss', loss_avg, epoch) 200 | 201 | test_result = user_per_test(args, net_glob, test_loader, class_weight) 202 | print(f'global test acc:', test_result[1]) 203 | add_scalar(writer2, user_num, test_result, epoch) 204 | total_acc2[user_num][epoch] = test_result[1] 205 | local_acc2[user_num][epoch] = test_result[3] 206 | 207 | test_result = user_test(args, net_glob, test_loader, class_weight) 208 | add_scalar(writer, user_num, test_result, args.epochs) 209 | total_acc[user_num][args.epochs] = test_result[1] 210 | local_acc[user_num][args.epochs] = test_result[3] 211 | 212 | optimizer_gate = optim.Adam([{'params': net_glob.gate.parameters()}], weight_decay=5e-4) 213 | 214 | for gate_epoch in range(1, 1 + gate_epochs): 215 | net_glob.train() 216 | 
gate_epoch_loss = []
217 |             gate_out = torch.tensor([], device=args.device)
218 |             for batch_idx, (data, target) in enumerate(gate_loader):
219 |                 data, target = data.to(args.device), target.to(args.device)
220 |                 optimizer_gate.zero_grad()
221 |                 output, g, z = net_glob(data)
222 |                 gate_out = torch.cat((gate_out, g.view(-1)))
223 |                 loss = criterion(output, target)
224 |                 loss.backward()
225 |                 optimizer_gate.step()
226 |                 gate_epoch_loss.append(loss.item())
227 |             if gate_epoch % 10 == 1:
228 |                 writer.add_histogram(f"user_{user_num}/gate_out", gate_out)
229 |                 writer.add_histogram(f"user_{user_num}/gate/weight", net_glob.gate.weight)
230 |                 writer.add_histogram(f"user_{user_num}/gate/bias", net_glob.gate.bias)
231 |             loss_avg = sum(gate_epoch_loss) / len(gate_epoch_loss)
232 |             print(f'User {user_num} gate loss:', loss_avg)
233 |             writer.add_scalar(f'user_{user_num}/gate_train_loss', loss_avg, args.epochs + gate_epoch)
234 | 
235 |             test_result = user_test(args, net_glob, test_loader, class_weight)
236 |             add_scalar(writer, user_num, test_result, args.epochs + gate_epoch)
237 |             total_acc[user_num][args.epochs + gate_epoch] = test_result[1]
238 |             local_acc[user_num][args.epochs + gate_epoch] = test_result[3]
239 | 
240 |     save_info = {
241 |         "total_acc": total_acc,
242 |         "local_acc": local_acc
243 |     }
244 |     save_info2 = {
245 |         "total_acc": total_acc2,
246 |         "local_acc": local_acc2
247 |     }
248 |     save_path = f'{logdir}/local_train_epoch_acc'
249 |     save_path2 = f'{logdir2}/local_train_epoch_acc'
250 |     torch.save(save_info, save_path)
251 |     torch.save(save_info2, save_path2)
252 | 
253 |     total_acc = total_acc.mean(axis=0)
254 |     local_acc = local_acc.mean(axis=0)
255 |     total_acc2 = total_acc2.mean(axis=0)
256 |     local_acc2 = local_acc2.mean(axis=0)
257 |     for epoch, _ in enumerate(total_acc):
258 |         if epoch >= args.epochs:
259 |             writer.add_scalar('test/global/test_acc', total_acc[epoch], epoch)
260 |             writer.add_scalar('test/local/test_acc', local_acc[epoch], epoch)
261 |         if epoch <= args.epochs:
262 |             writer2.add_scalar('test/global/test_acc', total_acc2[epoch], epoch)
263 |             writer2.add_scalar('test/local/test_acc', local_acc2[epoch], epoch)
264 |     writer.close()
265 |     writer2.close()
266 | 

--------------------------------------------------------------------------------
/main_local.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 | import copy
5 | import numpy as np
6 | from torchvision import datasets, transforms
7 | import torch
8 | import torch.nn
9 | import torch.nn.functional as F
10 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid
11 | from utils.options import args_parser
12 | from models.Nets import CNNCifar, vgg16
13 | from utils.util import setup_seed
14 | from datetime import datetime
15 | from torch.utils.tensorboard import SummaryWriter
16 | from torch.utils.data import DataLoader
17 | import torch.optim as optim
18 | from models.Update import DatasetSplit
19 | from models.Test import local_test
20 | from utils.util import add_scalar
21 | 
22 | 
23 | def test(model, data_source):
24 |     model.eval()
25 |     total_loss = 0.0
26 |     correct = 0.0
27 |     correct_class = np.zeros(10)
28 |     correct_class_acc = np.zeros(10)
29 |     correct_class_size = np.zeros(10)
30 | 
31 |     dataset_size = len(data_source.dataset)
32 |     data_iterator = data_source
33 |     with torch.no_grad():
34 |         for batch_id, (data, targets) in enumerate(data_iterator):
35 |             data, targets = data.to(args.device), 
targets.to(args.device) 36 | output = model(data) 37 | total_loss += F.cross_entropy(output, targets, 38 | reduction='sum').item() # sum up batch loss 39 | pred = output.data.max(1)[1] # get the index of the max log-probability 40 | correct += pred.eq(targets.data.view_as(pred)).cpu().sum().item() 41 | for i in range(10): 42 | class_ind = targets.data.view_as(pred).eq(i*torch.ones_like(pred)) 43 | correct_class_size[i] += class_ind.cpu().sum().item() 44 | correct_class[i] += (pred.eq(targets.data.view_as(pred))*class_ind).cpu().sum().item() 45 | 46 | acc = 100.0 * (float(correct) / float(dataset_size)) 47 | for i in range(10): 48 | correct_class_acc[i] = (float(correct_class[i]) / float(correct_class_size[i])) 49 | total_l = total_loss / dataset_size 50 | # print(f'Average loss: {total_l}, Accuracy: {correct}/{dataset_size} ({acc}%)') 51 | return total_l, acc, correct_class_acc 52 | 53 | 54 | if __name__ == '__main__': 55 | # parse args 56 | args = args_parser() 57 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 58 | setup_seed(args.seed) 59 | 60 | # log 61 | current_time = datetime.now().strftime('%b.%d_%H.%M.%S') 62 | TAG = 'exp/local/{}_{}_{}_iid{}_{}_user{}_{}'.format(args.dataset, args.model, args.epochs, args.iid, args.alpha, 63 | args.num_users, current_time) 64 | logdir = f'runs/{TAG}' if not args.debug else f'runs2/{TAG}' 65 | writer = SummaryWriter(logdir) 66 | 67 | # load dataset and split users 68 | if args.dataset == 'mnist': 69 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 70 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) 71 | dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist) 72 | # sample users 73 | if args.iid: 74 | dict_users = mnist_iid(dataset_train, args.num_users) 75 | else: 76 | dict_users = mnist_noniid(dataset_train, args.num_users) 77 | 78 | elif args.dataset == 'cifar': 79 | transform_train = transforms.Compose([ 80 | transforms.RandomCrop(32, padding=4), 81 | transforms.RandomHorizontalFlip(), 82 | transforms.ToTensor(), 83 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 84 | ]) 85 | transform_test = transforms.Compose([ 86 | transforms.ToTensor(), 87 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 88 | ]) 89 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=transform_train) 90 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=transform_test) 91 | elif args.dataset == 'fmnist': 92 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True, 93 | transform=transforms.Compose([ 94 | transforms.Resize((32, 32)), 95 | transforms.RandomCrop(32, padding=4), 96 | transforms.RandomHorizontalFlip(), 97 | transforms.ToTensor(), 98 | transforms.Normalize((0.1307,), (0.3081,)), 99 | ])) 100 | 101 | # testing 102 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True, 103 | transform=transforms.Compose([ 104 | transforms.Resize((32, 32)), 105 | transforms.ToTensor(), 106 | transforms.Normalize((0.1307,), (0.3081,)) 107 | ])) 108 | # test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 109 | else: 110 | exit('Error: unrecognized dataset') 111 | 112 | if args.iid: 113 | dict_users = cifar_iid(dataset_train, args.num_users) 114 | else: 115 | 
dict_users, _ = cifar_noniid(dataset_train, args.num_users, args.alpha) 116 | for k, v in dict_users.items(): 117 | writer.add_histogram(f'user_{k}/data_distribution', 118 | np.array(dataset_train.targets)[v], 119 | bins=np.arange(11)) 120 | writer.add_histogram(f'all_user/data_distribution', 121 | np.array(dataset_train.targets)[v], 122 | bins=np.arange(11), global_step=k) 123 | 124 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 125 | img_size = dataset_train[0][0].shape 126 | 127 | # build model 128 | if args.model == 'lenet' and (args.dataset == 'cifar' or args.dataset == 'fmnist'): 129 | net_glob = CNNCifar(args=args).to(args.device) 130 | elif args.model == 'vgg' and args.dataset == 'cifar': 131 | net_glob = vgg16().to(args.device) 132 | else: 133 | exit('Error: unrecognized model') 134 | print(net_glob) 135 | net_glob.train() 136 | 137 | # copy weights 138 | w_init = copy.deepcopy(net_glob.state_dict()) 139 | 140 | local_acc_final = [] 141 | total_acc_final = [] 142 | local_acc = np.zeros([args.num_users, args.epochs]) 143 | total_acc = np.zeros([args.num_users, args.epochs]) 144 | 145 | # training 146 | for idx in range(args.num_users): 147 | # print(w_init) 148 | net_glob.load_state_dict(w_init) 149 | optimizer = optim.Adam(net_glob.parameters()) 150 | train_loader = DataLoader(DatasetSplit(dataset_train, dict_users[idx]), batch_size=64, shuffle=True) 151 | image_trainset_weight = np.zeros(10) 152 | for label in np.array(dataset_train.targets)[dict_users[idx]]: 153 | image_trainset_weight[label] += 1 154 | image_trainset_weight = image_trainset_weight / image_trainset_weight.sum() 155 | list_loss = [] 156 | net_glob.train() 157 | for epoch in range(args.epochs): 158 | batch_loss = [] 159 | for batch_idx, (data, target) in enumerate(train_loader): 160 | data, target = data.to(args.device), target.to(args.device) 161 | optimizer.zero_grad() 162 | output = net_glob(data) 163 | loss = F.cross_entropy(output, target) 164 | loss.backward() 165 | optimizer.step() 166 | # if batch_idx % 3 == 0: 167 | # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 168 | # epoch, batch_idx * len(data), len(train_loader.dataset), 169 | # 100. 
* batch_idx / len(train_loader), loss.item())) 170 | batch_loss.append(loss.item()) 171 | 172 | loss_avg = sum(batch_loss) / len(batch_loss) 173 | print('\nLocal Train loss:', loss_avg) 174 | writer.add_scalar(f'user_{idx}/local_train_loss', loss_avg, epoch) 175 | 176 | test_result = local_test(args, net_glob, test_loader, image_trainset_weight) 177 | add_scalar(writer, idx, test_result, epoch) 178 | print('Global Test ACC:', test_result[1]) 179 | print('Local Test ACC:', test_result[3]) 180 | 181 | total_acc[idx][epoch] = test_result[1] 182 | local_acc[idx][epoch] = test_result[3] 183 | 184 | total_acc_final.append(test_result[1]) 185 | local_acc_final.append(test_result[3]) 186 | print(f'user {idx} done!') 187 | 188 | save_info = { 189 | "total_acc": total_acc, 190 | "local_acc": local_acc 191 | } 192 | save_path = f'{logdir}/local_train_epoch_acc' 193 | torch.save(save_info, save_path) 194 | 195 | total_acc = total_acc.mean(axis=0) 196 | local_acc = local_acc.mean(axis=0) 197 | for epoch in range(args.epochs): 198 | writer.add_scalar('test/global/test_acc', total_acc[epoch], epoch) 199 | writer.add_scalar('test/local/test_acc', local_acc[epoch], epoch) 200 | writer.close() 201 | # 202 | # # plot loss curve 203 | # plt.figure() 204 | # plt.title('local train acc', fontsize=20) # 标题,并设定字号大小 205 | # labels = ['local', 'total'] 206 | # plt.boxplot([local_acc_final, total_acc_final], labels=labels, notch=True, showmeans=True) 207 | # plt.ylabel('test acc') 208 | # plt.savefig(f'{logdir}/local_train_acc.png') 209 | -------------------------------------------------------------------------------- /main_nn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | from utils.options import args_parser 11 | from models.Nets import CNNCifar, vgg16 12 | from utils.util import setup_seed 13 | from torch.utils.tensorboard import SummaryWriter 14 | from datetime import datetime 15 | from models.Test import test 16 | 17 | 18 | if __name__ == '__main__': 19 | 20 | # parse args 21 | args = args_parser() 22 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 23 | setup_seed(args.seed) 24 | 25 | # log 26 | current_time = datetime.now().strftime('%b.%d_%H.%M.%S') 27 | TAG = 'nn_{}_{}_{}_{}'.format(args.dataset, args.model, args.epochs, current_time) 28 | logdir = f'runs/{TAG}' 29 | if args.debug: 30 | logdir = f'runs2/{TAG}' 31 | writer = SummaryWriter(logdir) 32 | 33 | # load dataset and split users 34 | if args.dataset == 'cifar': 35 | transform_train = transforms.Compose([ 36 | transforms.RandomCrop(32, padding=4), 37 | transforms.RandomHorizontalFlip(), 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 40 | ]) 41 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform_train, download=True) 42 | 43 | # testing 44 | transform_test = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 47 | ]) 48 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, transform=transform_test, download=True) 49 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 50 | img_size = 
dataset_train[0][0].shape
51 |     elif args.dataset == 'fmnist':
52 |         dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True,
53 |                                               transform=transforms.Compose([
54 |                                                   transforms.Resize((32, 32)),
55 |                                                   transforms.RandomCrop(32, padding=4),
56 |                                                   transforms.RandomHorizontalFlip(),
57 |                                                   transforms.ToTensor(),
58 |                                                   transforms.Normalize((0.1307,), (0.3081,)),
59 |                                               ]))
60 | 
61 |         # testing
62 |         dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True,
63 |                                              transform=transforms.Compose([
64 |                                                  transforms.Resize((32, 32)),
65 |                                                  transforms.ToTensor(),
66 |                                                  transforms.Normalize((0.1307,), (0.3081,))
67 |                                              ]))
68 |         test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
69 |     else:
70 |         exit('Error: unrecognized dataset')
71 | 
72 |     # build model
73 |     if args.model == 'lenet' and (args.dataset == 'cifar' or args.dataset == 'fmnist'):
74 |         net_glob = CNNCifar(args=args).to(args.device)
75 |     elif args.model == 'vgg' and args.dataset == 'cifar':
76 |         net_glob = vgg16().to(args.device)
77 |     else:
78 |         exit('Error: unrecognized model')
79 |     print(net_glob)
80 |     img = dataset_train[0][0].unsqueeze(0).to(args.device)
81 |     writer.add_graph(net_glob, img)
82 | 
83 |     # training
84 |     criterion = nn.CrossEntropyLoss()
85 |     train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)
86 |     # optimizer = optim.Adam(net_glob.parameters())
87 |     optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
88 |     # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
89 |     # scheduler.step()
90 | 
91 |     list_loss = []
92 |     net_glob.train()
93 |     for epoch in range(args.epochs):
94 |         batch_loss = []
95 |         for batch_idx, (data, target) in enumerate(train_loader):
96 |             data, target = data.to(args.device), target.to(args.device)
97 |             optimizer.zero_grad()
98 |             output = net_glob(data)
99 |             loss = criterion(output, target)
100 |             loss.backward()
101 |             optimizer.step()
102 |             if batch_idx % 50 == 0:
103 |                 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
104 |                     epoch, batch_idx * len(data), len(train_loader.dataset),
105 |                     100. * batch_idx / len(train_loader), loss.item()))
106 |             batch_loss.append(loss.item())
107 |         # scheduler.step()
108 |         loss_avg = sum(batch_loss)/len(batch_loss)
109 |         print('\nTrain loss:', loss_avg)
110 |         list_loss.append(loss_avg)
111 |         writer.add_scalar('train_loss', loss_avg, epoch)
112 |         test_acc, test_loss = test(args, net_glob, test_loader)
113 |         writer.add_scalar('test_loss', test_loss, epoch)
114 |         writer.add_scalar('test_acc', test_acc, epoch)
115 | 
116 |     # save model weights
117 |     save_info = {
118 |         "epochs": args.epochs,
119 |         "optimizer": optimizer.state_dict(),
120 |         "model": net_glob.state_dict()
121 |     }
122 | 
123 |     save_path = f'save2/{TAG}' if args.debug else f'save/{TAG}'
124 |     torch.save(save_info, save_path)
125 |     writer.close()
126 | 

--------------------------------------------------------------------------------
/main_per_fb.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 | 
5 | from torch.utils.data import DataLoader
6 | import torch.optim as optim
7 | from torchvision import datasets, transforms
8 | from utils.options import args_parser
9 | from models.Nets import CNNGate
10 | from torch.utils.tensorboard import SummaryWriter
11 | from datetime import datetime
12 | from utils.sampling import cifar_noniid
13 | from models.Update import DatasetSplit
14 | from utils.util import *
15 | from models.Test import user_per_test
16 | import copy
17 | 
18 | 
19 | if __name__ == '__main__':
20 |     # parse args
21 |     args = args_parser()
22 |     args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
23 |     setup_seed(args.seed)
24 | 
25 |     # log
26 |     current_time = datetime.now().strftime('%b.%d_%H.%M.%S')
27 |     TAG = 'exp/non_iid/per_fb_{}_{}_{}_{}_user{}_{}'.format(args.dataset, args.model, args.epochs, args.alpha,
28 |                                                             args.num_users, current_time)
29 |     logdir = f'runs/{TAG}'
30 |     if args.debug:
31 |         logdir = f'runs2/{TAG}'
32 |     writer = SummaryWriter(logdir)
33 | 
34 |     # load dataset and split users
35 |     train_loader, test_loader, class_weight, dict_users, dataset_train = 1, 1, 1, 1, 1
36 |     if args.dataset == 'mnist':
37 |         dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
38 |                                        transform=transforms.Compose([
39 |                                            transforms.ToTensor(),
40 |                                            transforms.Normalize((0.1307,), (0.3081,))
41 |                                        ]))
42 |         # testing
43 |         dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True,
44 |                                       transform=transforms.Compose([
45 |                                           transforms.ToTensor(),
46 |                                           transforms.Normalize((0.1307,), (0.3081,))
47 |                                       ]))
48 | 
49 |     elif args.dataset == 'cifar':
50 |         save_dataset_path = f'./data/cifar_non_iid{args.alpha}_user{args.num_users}_fast_data'
51 |         # global_weight = torch.load('./save/nn_cifar_cnn_100_Oct.13_19.45.20')['model']
52 |         # global_weight = torch.load(f'./save/exp/fed/{args.dataset}_{args.model}_1000_C0.1_iidFalse_{args.alpha}_user{args.num_users}_1000es')['model']
53 |         global_weight = torch.load(f'./save/exp/fed/cifar_lenet_1000_C0.1_iidFalse_0.2_user30*3_Nov.16_14.37.35_1000es')['model']
54 |         if args.rebuild:
55 |             # training
56 |             transform_train = transforms.Compose([
57 |                 transforms.RandomCrop(32, padding=4),
58 |                 transforms.RandomHorizontalFlip(),
59 |                 transforms.ToTensor(),
60 |                 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
61 |             ])
62 |             dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform_train, download=True)
63 |             # testing
64 | 
transform_test = transforms.Compose([ 65 | transforms.ToTensor(), 66 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 67 | ]) 68 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, transform=transform_test, download=True) 69 | # non_iid 70 | dict_users, _ = cifar_noniid(dataset_train, args.num_users, args.alpha) 71 | 72 | save_dataset = { 73 | "dataset_test": dataset_test, 74 | "dataset_train": dataset_train, 75 | "dict_users": dict_users 76 | } 77 | torch.save(save_dataset, save_dataset_path) 78 | else: 79 | save_dataset = torch.load(save_dataset_path) 80 | dataset_test = save_dataset['dataset_test'] 81 | dataset_train = save_dataset['dataset_train'] 82 | dict_users = save_dataset['dict_users'] 83 | 84 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 85 | for k, v in dict_users.items(): 86 | writer.add_histogram(f'user_{k}/data_distribution', 87 | np.array(dataset_train.targets)[v], 88 | bins=np.arange(11)) 89 | writer.add_histogram(f'all_user/data_distribution', 90 | np.array(dataset_train.targets)[v], 91 | bins=np.arange(11), global_step=k) 92 | img_size = dataset_train[0][0].shape 93 | elif args.dataset == 'fmnist': 94 | pass 95 | else: 96 | exit('Error: unrecognized dataset') 97 | 98 | # build model 99 | net_glob = CNNGate(args=args).to(args.device) 100 | image, target = next(iter(test_loader)) 101 | writer.add_graph(net_glob, image.to(args.device)) 102 | 103 | local_acc = np.zeros([args.num_users, args.epochs + 1]) 104 | total_acc = np.zeros([args.num_users, args.epochs + 1]) 105 | 106 | for user_num in range(len(dict_users)): 107 | # user train data 108 | user_train = DatasetSplit(dataset_train, dict_users[user_num]) 109 | train_loader = DataLoader(user_train, batch_size=64, shuffle=True) 110 | 111 | class_weight = np.zeros(10) 112 | for image, label in user_train: 113 | class_weight[label] += 1 114 | class_weight /= sum(class_weight) 115 | 116 | # init 117 | 118 | # global_weight = torch.load('./save/fed_cifar_cnn_1000_C0.1_iidFalse_0.9_Nov.05_09.31.38_500es')['model'] 119 | net_glob.load_state_dict(global_weight, False) 120 | net_glob.pfc1.load_state_dict({'weight': global_weight['fc1.weight'], 'bias': global_weight['fc1.bias']}) 121 | net_glob.pfc2.load_state_dict({'weight': global_weight['fc2.weight'], 'bias': global_weight['fc2.bias']}) 122 | net_glob.pfc3.load_state_dict({'weight': global_weight['fc3.weight'], 'bias': global_weight['fc3.bias']}) 123 | 124 | # training 125 | optimizer = optim.SGD([ 126 | {'params': net_glob.pfc1.parameters()}, 127 | {'params': net_glob.pfc2.parameters()}, 128 | {'params': net_glob.pfc3.parameters()}, 129 | ], lr=0.001, momentum=0.9, weight_decay=5e-4) 130 | criterion = nn.CrossEntropyLoss() 131 | 132 | test_result = user_per_test(args, net_glob, test_loader, class_weight) 133 | add_scalar(writer, user_num, test_result, 0) 134 | total_acc[user_num][0] = test_result[1] 135 | local_acc[user_num][0] = test_result[3] 136 | 137 | for epoch in range(1, args.epochs+1): 138 | net_glob.train() 139 | batch_loss = [] 140 | gate_out = [] 141 | for batch_idx, (data, target) in enumerate(train_loader): 142 | data, target = data.to(args.device), target.to(args.device) 143 | optimizer.zero_grad() 144 | output, g, z = net_glob(data) 145 | gate_out.append(g) 146 | loss = criterion(z, target) 147 | loss.backward() 148 | optimizer.step() 149 | batch_loss.append(loss.item()) 150 | writer.add_histogram(f"user_{user_num}/pfc1/weight", net_glob.pfc1.weight, epoch) 151 | 
writer.add_histogram(f"user_{user_num}/pfc2/weight", net_glob.pfc2.weight, epoch) 152 | writer.add_histogram(f"user_{user_num}/pfc3/weight", net_glob.pfc3.weight, epoch) 153 | loss_avg = sum(batch_loss) / len(batch_loss) 154 | print(f'User {user_num} Train loss:', loss_avg) 155 | writer.add_scalar(f'user_{user_num}/pfc_train_loss', loss_avg, epoch) 156 | 157 | test_result = user_per_test(args, net_glob, test_loader, class_weight) 158 | add_scalar(writer, user_num, test_result, epoch) 159 | total_acc[user_num][epoch] = test_result[1] 160 | local_acc[user_num][epoch] = test_result[3] 161 | 162 | save_info = { 163 | "total_acc": total_acc, 164 | "local_acc": local_acc 165 | } 166 | save_path = f'{logdir}/local_train_epoch_acc' 167 | torch.save(save_info, save_path) 168 | 169 | total_acc = total_acc.mean(axis=0) 170 | local_acc = local_acc.mean(axis=0) 171 | for epoch, _ in enumerate(total_acc): 172 | writer.add_scalar('test/global/test_acc', total_acc[epoch], epoch) 173 | writer.add_scalar('test/local/test_acc', local_acc[epoch], epoch) 174 | writer.close() 175 | -------------------------------------------------------------------------------- /models/Fed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import torch 7 | from torch import nn 8 | 9 | 10 | def FedAvg(w): 11 | w_avg = copy.deepcopy(w[0]) 12 | for k in w_avg.keys(): 13 | for i in range(1, len(w)): 14 | w_avg[k] += w[i][k] 15 | w_avg[k] = torch.div(w_avg[k], len(w)) 16 | return w_avg 17 | -------------------------------------------------------------------------------- /models/Nets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from thop import profile 9 | 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, dim_in, dim_hidden, dim_out): 13 | super(MLP, self).__init__() 14 | self.layer_input = nn.Linear(dim_in, dim_hidden) 15 | self.relu = nn.ReLU() 16 | self.dropout = nn.Dropout() 17 | self.layer_hidden = nn.Linear(dim_hidden, dim_out) 18 | 19 | def forward(self, x): 20 | x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1]) 21 | x = self.layer_input(x) 22 | x = self.dropout(x) 23 | x = self.relu(x) 24 | x = self.layer_hidden(x) 25 | return x 26 | 27 | 28 | class CNNMnist(nn.Module): 29 | def __init__(self, args): 30 | super(CNNMnist, self).__init__() 31 | self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5) 32 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 33 | self.conv2_drop = nn.Dropout2d() 34 | self.fc1 = nn.Linear(320, 50) 35 | self.fc2 = nn.Linear(50, args.num_classes) 36 | 37 | def forward(self, x): 38 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 39 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 40 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 41 | x = F.relu(self.fc1(x)) 42 | x = F.dropout(x, training=self.training) 43 | x = self.fc2(x) 44 | return x 45 | 46 | 47 | class CNNCifar(nn.Module): 48 | def __init__(self, args): 49 | super(CNNCifar, self).__init__() 50 | self.conv1 = nn.Conv2d(1 if args.dataset == 'fmnist' else 3, 6, 5) 51 | self.pool = nn.MaxPool2d(2, 2) 52 | self.conv2 = nn.Conv2d(6, 16, 5) 53 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 54 | self.fc2 = nn.Linear(120, 84) 55 | self.fc3 = nn.Linear(84, args.num_classes) 56 | 57 | def forward(self, x): 
58 | x = self.pool(F.relu(self.conv1(x))) 59 | x = self.pool(F.relu(self.conv2(x))) 60 | x = x.view(-1, 16 * 5 * 5) 61 | x = F.relu(self.fc1(x)) 62 | x = F.relu(self.fc2(x)) 63 | x = self.fc3(x) 64 | return x 65 | 66 | 67 | class CNNGate(nn.Module): 68 | def __init__(self, args): 69 | super(CNNGate, self).__init__() 70 | self.args = args 71 | self.conv1 = nn.Conv2d(1 if args.dataset == 'fmnist' else 3, 6, 5) 72 | self.conv2 = nn.Conv2d(6, 16, 5) 73 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 74 | self.fc2 = nn.Linear(120, 84) 75 | self.fc3 = nn.Linear(84, self.args.num_classes) 76 | 77 | for p in self.parameters(): 78 | p.requires_grad = False 79 | 80 | self.gate = nn.Linear(32 * 32 * (3 if args.dataset == 'cifar' else 1), 1) if args.struct else nn.Linear(16 * 5 * 5, 1) 81 | self.pfc1 = nn.Linear(16 * 5 * 5, 120) 82 | self.pfc2 = nn.Linear(120, 84) 83 | self.pfc3 = nn.Linear(84, args.num_classes) 84 | 85 | def forward(self, x1): 86 | x = F.max_pool2d(F.relu(self.conv1(x1)), 2) 87 | x = F.max_pool2d(F.relu(self.conv2(x)), 2) 88 | x = torch.flatten(x, 1) 89 | 90 | z = F.relu(self.pfc1(x)) 91 | z = F.relu(self.pfc2(z)) 92 | z = self.pfc3(z) 93 | 94 | g = torch.sigmoid(self.gate(torch.flatten(x1, 1))) if self.args.struct else torch.sigmoid(self.gate(x)) 95 | # g = 96 | y = F.relu(self.fc1(x)) 97 | y = F.relu(self.fc2(y)) 98 | y = self.fc3(y) 99 | return y * g + z * (1-g), g, z 100 | # return z 101 | 102 | 103 | ''' 104 | Modified from https://github.com/pytorch/vision.git 105 | ''' 106 | import math 107 | 108 | import torch.nn as nn 109 | import torch.nn.init as init 110 | 111 | __all__ = [ 112 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 113 | 'vgg19_bn', 'vgg19', 114 | ] 115 | 116 | 117 | class VGG(nn.Module): 118 | ''' 119 | VGG model 120 | ''' 121 | 122 | def __init__(self, features, has_gate=False, struct=False): 123 | super(VGG, self).__init__() 124 | self.has_gate = has_gate 125 | self.struct = struct 126 | self.features = features 127 | self.classifier = nn.Sequential( 128 | nn.Dropout(), 129 | nn.Linear(512, 512), 130 | nn.ReLU(True), 131 | nn.Dropout(), 132 | nn.Linear(512, 512), 133 | nn.ReLU(True), 134 | nn.Linear(512, 10), 135 | ) 136 | if has_gate: 137 | self.pclassifier = nn.Sequential( 138 | nn.Dropout(), 139 | nn.Linear(512, 512), 140 | nn.ReLU(True), 141 | nn.Dropout(), 142 | nn.Linear(512, 512), 143 | nn.ReLU(True), 144 | nn.Linear(512, 10), 145 | ) 146 | if self.struct: 147 | self.gate = nn.Linear(3 * 32 * 32, 1) 148 | else: 149 | self.gate = nn.Linear(512, 1) 150 | # Initialize weights 151 | for m in self.modules(): 152 | if isinstance(m, nn.Conv2d): 153 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 154 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 155 | m.bias.data.zero_() 156 | 157 | def forward(self, input): 158 | x = self.features(input) 159 | x = x.view(x.size(0), -1) 160 | if self.has_gate: 161 | if self.struct: 162 | g = torch.sigmoid(self.gate(torch.flatten(input, 1))) 163 | else: 164 | g = torch.sigmoid(self.gate(x)) 165 | y = self.classifier(x) 166 | z = self.pclassifier(x) 167 | return y * g + z * (1-g), g, z 168 | else: 169 | x = self.classifier(x) 170 | return x 171 | 172 | 173 | def make_layers(cfg, batch_norm=False): 174 | layers = [] 175 | in_channels = 3 176 | for v in cfg: 177 | if v == 'M': 178 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 179 | else: 180 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 181 | if batch_norm: 182 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 183 | else: 184 | layers += [conv2d, nn.ReLU(inplace=True)] 185 | in_channels = v 186 | return nn.Sequential(*layers) 187 | 188 | 189 | cfg = { 190 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 191 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 192 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 193 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 194 | 512, 512, 512, 512, 'M'], 195 | } 196 | 197 | 198 | def vgg11(): 199 | """VGG 11-layer model (configuration "A")""" 200 | return VGG(make_layers(cfg['A'])) 201 | 202 | 203 | def vgg11_bn(): 204 | """VGG 11-layer model (configuration "A") with batch normalization""" 205 | return VGG(make_layers(cfg['A'], batch_norm=True)) 206 | 207 | 208 | def vgg13(): 209 | """VGG 13-layer model (configuration "B")""" 210 | return VGG(make_layers(cfg['B'])) 211 | 212 | 213 | def vgg13_bn(): 214 | """VGG 13-layer model (configuration "B") with batch normalization""" 215 | return VGG(make_layers(cfg['B'], batch_norm=True)) 216 | 217 | 218 | def vgg16(): 219 | """VGG 16-layer model (configuration "D")""" 220 | return VGG(make_layers(cfg['D'])) 221 | 222 | 223 | def gate_vgg16(args): 224 | return VGG(make_layers(cfg['D']), has_gate=True, struct=args.struct) 225 | 226 | 227 | def vgg16_bn(): 228 | """VGG 16-layer model (configuration "D") with batch normalization""" 229 | return VGG(make_layers(cfg['D'], batch_norm=True)) 230 | 231 | 232 | def vgg19(): 233 | """VGG 19-layer model (configuration "E")""" 234 | return VGG(make_layers(cfg['E'])) 235 | 236 | 237 | def vgg19_bn(): 238 | """VGG 19-layer model (configuration 'E') with batch normalization""" 239 | return VGG(make_layers(cfg['E'], batch_norm=True)) 240 | 241 | 242 | def cifar_test(): 243 | # from utils.options import args_parser 244 | net = vgg16().to("cuda:0") 245 | img = torch.randn(100, 3, 32, 32).to("cuda:0") 246 | flops, params = profile(net, inputs=(img, )) 247 | print(flops, params) 248 | print(img.size()) 249 | 250 | # test() 251 | # cifar_test() 252 | 253 | 254 | -------------------------------------------------------------------------------- /models/Test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | import numpy as np 10 | 11 | 12 | def test(args, net_g, data_loader): 13 | # testing 14 | net_g.eval() 15 | test_loss = [] 16 | correct = 0 17 | with torch.no_grad(): 18 | for idx, (data, target) in enumerate(data_loader): 19 | 
19 |             data, target = data.to(args.device), target.to(args.device)
20 |             log_probs = net_g(data)
21 |             test_loss.append(nn.CrossEntropyLoss()(log_probs, target).item())
22 |             y_pred = log_probs.data.max(1, keepdim=True)[1]
23 |             correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum().item()
24 |
25 |     loss_avg = sum(test_loss)/len(test_loss)
26 |     test_acc = 100. * correct / len(data_loader.dataset)
27 |     print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
28 |         loss_avg, correct, len(data_loader.dataset), test_acc))
29 |
30 |     return test_acc, loss_avg
31 |
32 |
33 | def test_img(net_g, datatest, args):
34 |     net_g.eval()
35 |     # testing
36 |     test_loss = 0
37 |     correct = 0
38 |     data_loader = DataLoader(datatest, batch_size=args.test_bs)
39 |
40 |     with torch.no_grad():
41 |         for idx, (data, target) in enumerate(data_loader):
42 |             if args.gpu != -1:
43 |                 data, target = data.to(args.device), target.to(args.device)
44 |             log_probs = net_g(data)
45 |             # sum up batch loss
46 |             test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
47 |             # get the index of the max log-probability
48 |             y_pred = log_probs.data.max(1, keepdim=True)[1]
49 |             correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
50 |
51 |     test_loss /= len(data_loader.dataset)
52 |     accuracy = 100.00 * correct.item() / len(data_loader.dataset)
53 |
54 |     print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
55 |         test_loss, correct, len(data_loader.dataset), accuracy))
56 |     return accuracy, test_loss
57 |
58 |
59 | def user_test(args, net_glob, data_loader, class_weight):
60 |     # evaluate the gated (mixed) output per class, weighted by this user's class distribution
61 |     net_glob.eval()
62 |     correct_class = np.zeros(10)
63 |     class_loss = np.zeros(10)
64 |     correct_class_acc = np.zeros(10)
65 |     class_loss_avg = np.zeros(10)
66 |     correct_class_size = np.zeros(10)
67 |     correct = 0.0
68 |     dataset_size = len(data_loader.dataset)
69 |     total_loss = 0.0
70 |     with torch.no_grad():
71 |         for idx, (data, target) in enumerate(data_loader):
72 |             data, target = data.to(args.device), target.to(args.device)
73 |             output, g, z = net_glob(data)
74 |             # net_glob returns the gated mixture y*g + z*(1-g), the gate value g,
75 |             # and the personalized logits z; metrics here use the mixture
76 |             pred = output.max(1)[1]
77 |             correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
78 |             loss = nn.CrossEntropyLoss(reduction='none')(output, target)
79 |             total_loss += loss.sum().item()
80 |             for i in range(10):
81 |                 class_ind = target.data.view_as(pred).eq(i * torch.ones_like(pred))
82 |                 correct_class_size[i] += class_ind.cpu().sum().item()
83 |                 correct_class[i] += (pred.eq(target.data.view_as(pred)) * class_ind).cpu().sum().item()
84 |                 class_loss[i] += (loss * class_ind.float()).cpu().sum().item()
85 |
86 |     acc = 100.0 * (float(correct) / float(dataset_size))
87 |     total_l = total_loss / dataset_size
88 |     for i in range(10):
89 |         correct_class_acc[i] = (float(correct_class[i]) / float(correct_class_size[i]))  # assumes every class appears in the loader
90 |         class_loss_avg[i] = (float(class_loss[i]) / float(correct_class_size[i]))
91 |     user_acc = correct_class_acc * class_weight
92 |     user_loss = class_loss_avg * class_weight
93 |     return total_l, acc, user_loss.sum(), 100*user_acc.sum()
94 |
95 |
96 | def user_per_test(args, net_glob, data_loader, class_weight):
97 |     # like user_test, but scores only the personalized expert's logits z
98 |     net_glob.eval()
99 |     correct_class = np.zeros(10)
100 |     class_loss = np.zeros(10)
101 |     correct_class_acc = np.zeros(10)
102 |     class_loss_avg = np.zeros(10)
103 |     correct_class_size = np.zeros(10)
104 |     correct = 0.0
105 |     dataset_size = len(data_loader.dataset)
106 |     total_loss = 0.0
107 |     with torch.no_grad():
108 |         for idx, (data, target) in enumerate(data_loader):
109 |             data, target = data.to(args.device), target.to(args.device)
110 |             output, g, z = net_glob(data)
111 |             pred = z.max(1)[1]  # predictions from the personalized head only
112 |             correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
113 |             loss = nn.CrossEntropyLoss(reduction='none')(z, target)
114 |             total_loss += loss.sum().item()
115 |             for i in range(10):
116 |                 class_ind = target.data.view_as(pred).eq(i * torch.ones_like(pred))
117 |                 correct_class_size[i] += class_ind.cpu().sum().item()
118 |                 correct_class[i] += (pred.eq(target.data.view_as(pred)) * class_ind).cpu().sum().item()
119 |                 class_loss[i] += (loss * class_ind.float()).cpu().sum().item()
120 |
121 |     acc = 100.0 * (float(correct) / float(dataset_size))
122 |     total_l = total_loss / dataset_size
123 |     for i in range(10):
124 |         correct_class_acc[i] = (float(correct_class[i]) / float(correct_class_size[i]))
125 |         class_loss_avg[i] = (float(class_loss[i]) / float(correct_class_size[i]))
126 |     user_acc = correct_class_acc * class_weight
127 |     user_loss = class_loss_avg * class_weight
128 |     return total_l, acc, user_loss.sum(), 100*user_acc.sum()
129 |
130 |
131 | def local_test(args, net_glob, data_loader, class_weight):
132 |     # like user_test, but for a plain (non-gated) model that returns a single output
133 |     net_glob.eval()
134 |     correct_class = np.zeros(10)
135 |     class_loss = np.zeros(10)
136 |     correct_class_acc = np.zeros(10)
137 |     class_loss_avg = np.zeros(10)
138 |     correct_class_size = np.zeros(10)
139 |     correct = 0.0
140 |     dataset_size = len(data_loader.dataset)
141 |     total_loss = 0.0
142 |     with torch.no_grad():
143 |         for idx, (data, target) in enumerate(data_loader):
144 |             data, target = data.to(args.device), target.to(args.device)
145 |             output = net_glob(data)
146 |             pred = output.max(1)[1]
147 |             correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
148 |             loss = nn.CrossEntropyLoss(reduction='none')(output, target)
149 |             total_loss += loss.sum().item()
150 |             for i in range(10):
151 |                 class_ind = target.data.view_as(pred).eq(i * torch.ones_like(pred))
152 |                 correct_class_size[i] += class_ind.cpu().sum().item()
153 |                 correct_class[i] += (pred.eq(target.data.view_as(pred)) * class_ind).cpu().sum().item()
154 |                 class_loss[i] += (loss * class_ind.float()).cpu().sum().item()
155 |
156 |     acc = 100.0 * (float(correct) / float(dataset_size))
157 |     total_l = total_loss / dataset_size
158 |     for i in range(10):
159 |         correct_class_acc[i] = (float(correct_class[i]) / float(correct_class_size[i]))
160 |         class_loss_avg[i] = (float(class_loss[i]) / float(correct_class_size[i]))
161 |     user_acc = correct_class_acc * class_weight
162 |     user_loss = class_loss_avg * class_weight
163 |     return total_l, acc, user_loss.sum(), 100*user_acc.sum()
--------------------------------------------------------------------------------
/models/Update.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import torch
6 | from torch import nn
7 | from torch.utils.data import DataLoader, Dataset
8 |
9 |
10 | class DatasetSplit(Dataset):
11 |     # wrap a dataset and expose only the samples at the given indices (one client's shard)
12 |     def __init__(self, dataset, idxs):
13 |         self.dataset = dataset
14 |         self.targets = dataset.targets
15 |         self.idxs = list(idxs)
16 |
17 |     def __len__(self):
18 |         return len(self.idxs)
19 |
20 |     def __getitem__(self, item):
21 |         image, label = self.dataset[self.idxs[item]]
22 |         return image, label
23 |
24 |
25 | class LocalUpdate(object):
26 |     def __init__(self, args, dataset=None, idxs=None):
27 |         self.args = args
28 |         self.loss_func = nn.CrossEntropyLoss()
29 |         self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
30 |
31 |     def train(self, net):
32 |         # run --local_ep epochs of SGD on this client's shard; return the updated
33 |         # state_dict and the mean training loss
34 |         net.train()
35 |         # alternative optimizer kept from the original code:
36 |         # optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=5e-4)
37 |         optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)  # honor --momentum (its default, 0.5, matches the value previously hard-coded here)
38 |
39 |         epoch_loss = []
40 |         for ep in range(self.args.local_ep):
41 |             batch_loss = []
42 |             for batch_idx, (images, labels) in enumerate(self.ldr_train):
43 |                 images, labels = images.to(self.args.device), labels.to(self.args.device)
44 |                 net.zero_grad()
45 |                 log_probs = net(images)
46 |                 loss = self.loss_func(log_probs, labels)
47 |                 loss.backward()
48 |                 optimizer.step()
49 |                 if self.args.verbose and batch_idx % 10 == 0:
50 |                     print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
51 |                         ep, batch_idx * len(images), len(self.ldr_train.dataset),
52 |                         100. * batch_idx / len(self.ldr_train), loss.item()))
53 |                 batch_loss.append(loss.item())
54 |             epoch_loss.append(sum(batch_loss)/len(batch_loss))
55 |         return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
56 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 |
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.1.0
2 | torchvision==0.3.0
3 | numpy  # used directly by models/Test.py and utils/sampling.py
4 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 |
5 |
--------------------------------------------------------------------------------
/utils/options.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import argparse
6 |
7 |
8 | def args_parser():
9 |     parser = argparse.ArgumentParser()
10 |     # federated arguments
11 |     parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
12 |     parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
13 |     parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
14 |     parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
15 |     parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
16 |     parser.add_argument('--test_bs', type=int, default=128, help="test batch size")
17 |     parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
18 |     parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
19 |     parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")
20 |
21 |     # model arguments
22 |     parser.add_argument('--model', type=str, default='mlp', help='model name')
23 |     parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
24 |     parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
25 |                         help='comma-separated kernel sizes to use for convolution')
26 |     parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
27 |     parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
28 |     parser.add_argument('--max_pool', type=str, default='True',
29 |                         help="whether to use max pooling rather than strided convolutions")
30 |
31 |     # other arguments
32 |     parser.add_argument('--rebuild', action='store_true', help="rebuild train data")
33 |     parser.add_argument('--struct', action='store_true', help="feed the raw input (rather than intermediate features) to the gate")
34 |     parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
35 |     parser.add_argument('--iid', action='store_true', help='whether the data split is i.i.d. or not')
36 |     parser.add_argument('--alpha', type=float, default=0.9, help='Dirichlet alpha controlling the non-IID skew (smaller = more skewed)')
37 |     parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
38 |     parser.add_argument('--num_channels', type=int, default=3, help="number of channels of images")
39 |     parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
40 |     parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
41 |     parser.add_argument('--verbose', action='store_true', help='verbose print')
42 |     parser.add_argument('--debug', action='store_true', help='debug mode: do not write event files to runs/')
43 |     parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
44 |     args = parser.parse_args()
45 |     return args
46 |
--------------------------------------------------------------------------------
/utils/sampling.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 |
6 | import numpy as np
7 | from torchvision import datasets, transforms
8 | from collections import defaultdict
9 | import random
10 |
11 |
12 | def mnist_iid(dataset, num_users):
13 |     """
14 |     Sample I.I.D. client data from MNIST dataset
15 |     :param dataset:
16 |     :param num_users:
17 |     :return: dict of image index
18 |     """
19 |     num_items = int(len(dataset)/num_users)
20 |     dict_users, all_idxs = {}, [i for i in range(len(dataset))]
21 |     for i in range(num_users):
22 |         dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
23 |         all_idxs = list(set(all_idxs) - dict_users[i])
24 |     return dict_users
25 |
26 |
27 | def mnist_noniid(dataset, num_users):
28 |     """
29 |     Sample non-I.I.D. client data from MNIST dataset
30 |     :param dataset:
31 |     :param num_users:
32 |     :return: dict of image index per user
33 |     """
34 |     num_shards, num_imgs = 200, 300  # 200 shards of 300 images; each client receives 2 shards
35 |     idx_shard = [i for i in range(num_shards)]
36 |     dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
37 |     idxs = np.arange(num_shards*num_imgs)
38 |     labels = dataset.train_labels.numpy()  # 'train_labels' matches torchvision 0.3.0 pinned in requirements.txt
39 |
40 |     # sort labels
41 |     idxs_labels = np.vstack((idxs, labels))
42 |     idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
43 |     idxs = idxs_labels[0, :]
44 |
45 |     # divide and assign
46 |     for i in range(num_users):
47 |         rand_set = set(np.random.choice(idx_shard, 2, replace=False))
48 |         idx_shard = list(set(idx_shard) - rand_set)
49 |         for rand in rand_set:
50 |             dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
51 |     return dict_users
52 |
53 |
54 | def cifar_iid(dataset, num_users):
55 |     """
56 |     Sample I.I.D. client data from CIFAR10 dataset
57 |     :param dataset:
58 |     :param num_users:
59 |     :return: dict of image index
60 |     """
61 |     num_items = int(len(dataset)/num_users)
62 |     dict_users, all_idxs = {}, [i for i in range(len(dataset))]
63 |     for i in range(num_users):
64 |         dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
65 |         all_idxs = list(set(all_idxs) - dict_users[i])
66 |     return dict_users
67 |
68 |
69 | def cifar_noniid(dataset, no_participants, alpha=0.9):
70 |     """
71 |     Input: the dataset, the number of participants, and alpha (Dirichlet concentration).
72 |     Output: per-participant lists of training-set indices plus the per-class weight matrix.
73 |     Sample method: for each class, draw a Dirichlet vector over participants and give
74 |     each participant the corresponding share of that class's images; a smaller alpha
75 |     yields a more skewed (more non-IID) split.
76 |     """
77 |     np.random.seed(666)  # fixed seed so every run reproduces the same split
78 |     random.seed(666)
79 |     cifar_classes = {}
80 |     for ind, x in enumerate(dataset):
81 |         _, label = x
82 |         if label in cifar_classes:
83 |             cifar_classes[label].append(ind)
84 |         else:
85 |             cifar_classes[label] = [ind]
86 |
87 |     per_participant_list = defaultdict(list)
88 |     no_classes = len(cifar_classes.keys())
89 |     class_size = len(cifar_classes[0])
90 |     datasize = {}
91 |     for n in range(no_classes):
92 |         random.shuffle(cifar_classes[n])
93 |         sampled_probabilities = class_size * np.random.dirichlet(
94 |             np.array(no_participants * [alpha]))
95 |         for user in range(no_participants):
96 |             no_imgs = int(round(sampled_probabilities[user]))
97 |             datasize[user, n] = no_imgs
98 |             sampled_list = cifar_classes[n][:min(len(cifar_classes[n]), no_imgs)]
99 |             per_participant_list[user].extend(sampled_list)
100 |             cifar_classes[n] = cifar_classes[n][min(len(cifar_classes[n]), no_imgs):]
101 |     train_img_size = np.zeros(no_participants)
102 |     for i in range(no_participants):
103 |         train_img_size[i] = sum([datasize[i, j] for j in range(10)])
104 |     class_weight = np.zeros((no_participants, 10))
105 |     for i in range(no_participants):
106 |         for j in range(10):
107 |             class_weight[i, j] = float(datasize[i, j]) / float(train_img_size[i])
108 |     return per_participant_list, class_weight
109 |
110 |
111 | if __name__ == '__main__':
112 |     dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
113 |                                    transform=transforms.Compose([
114 |                                        transforms.ToTensor(),
115 |                                        transforms.Normalize((0.1307,), (0.3081,))
116 |                                    ]))
117 |     num = 100
118 |     d = mnist_noniid(dataset_train, num)
119 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.backends.cudnn
3 | import torch.cuda
4 | import numpy as np
5 | import random
6 |
7 |
8 | def setup_seed(seed):
9 |     # distinct offsets keep the four RNG streams from mirroring one another
10 |     torch.manual_seed(seed+1)
11 |     torch.cuda.manual_seed_all(seed+123)
12 |     np.random.seed(seed+1234)
13 |     random.seed(seed+12345)
14 |     torch.backends.cudnn.deterministic = True  # reproducible cuDNN results
15 |
16 |
17 | def add_scalar(writer, user_num, test_result, epoch):
18 |     # log one user's global/local test metrics to a TensorBoard writer
19 |     test_loss, test_acc, user_loss, user_acc = test_result
20 |     writer.add_scalar(f'user_{user_num}/global/test_loss', test_loss, epoch)
21 |     writer.add_scalar(f'user_{user_num}/global/test_acc', test_acc, epoch)
22 |     writer.add_scalar(f'user_{user_num}/local/test_loss', user_loss, epoch)
23 |     writer.add_scalar(f'user_{user_num}/local/test_acc', user_acc, epoch)
24 |
--------------------------------------------------------------------------------
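A note on the gating rule shared by CNNGate and the gated VGG above: both return y * g + z * (1 - g), where y is the frozen global expert's output, z is the trainable personalized head, and g = sigmoid(gate(...)) is a per-sample scalar. A tiny standalone illustration with made-up numbers (not repository code):

import torch

# made-up logits for a 2-class example
y = torch.tensor([[2.0, 0.0]])   # global expert
z = torch.tensor([[0.0, 3.0]])   # personalized expert
g = torch.tensor([[0.25]])       # gate output after sigmoid, in (0, 1)

mixed = y * g + z * (1 - g)      # tensor([[0.5000, 2.2500]])
print(mixed.argmax(1))           # tensor([1]): the personalized expert dominates here

Because g is produced per sample, the model can defer to the global expert on familiar inputs and to the personalized head on inputs typical of this client's skewed local distribution.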
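For orientation, a minimal sketch of how these pieces compose into federated training. It is illustrative, not repository code: fed_avg is a hypothetical stand-in for the aggregation in models/Fed.py (whose contents are not shown above), and the args.device setup mirrors what the main_* entry points presumably do.

import copy
import numpy as np
import torch
from torchvision import datasets, transforms

from models.Nets import vgg16
from models.Test import test_img
from models.Update import LocalUpdate
from utils.options import args_parser
from utils.sampling import cifar_noniid
from utils.util import setup_seed


def fed_avg(w_locals):
    # hypothetical stand-in for models/Fed.py: element-wise mean of client state_dicts
    w_avg = copy.deepcopy(w_locals[0])
    for k in w_avg.keys():
        for w in w_locals[1:]:
            w_avg[k] += w[k]
        w_avg[k] = torch.div(w_avg[k], len(w_locals))
    return w_avg


if __name__ == '__main__':
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu)
                               if args.gpu != -1 and torch.cuda.is_available() else 'cpu')
    setup_seed(args.seed)

    trans = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset_train = datasets.CIFAR10('./data/cifar', train=True, download=True, transform=trans)
    dataset_test = datasets.CIFAR10('./data/cifar', train=False, download=True, transform=trans)
    dict_users, class_weight = cifar_noniid(dataset_train, args.num_users, args.alpha)

    net_glob = vgg16().to(args.device)
    m = max(int(args.frac * args.num_users), 1)     # C * K clients participate per round
    for rnd in range(args.epochs):
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        w_locals = []
        for idx in idxs_users:
            local = LocalUpdate(args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(copy.deepcopy(net_glob).to(args.device))
            w_locals.append(w)
        net_glob.load_state_dict(fed_avg(w_locals))  # broadcast the averaged weights
    test_img(net_glob, dataset_test, args)

The sketch omits TensorBoard logging (utils/util.add_scalar) and the later personalization stage in which main_gate.py trains the gate and personalized head on top of the converged global model.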