├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   ├── __init__.py
│   ├── cifar
│   │   └── .gitkeep
│   └── mnist
│       └── .gitkeep
├── imgs
│   ├── 05cifar_lenet.pdf
│   ├── 05cifar_vgg.pdf
│   ├── 05fmnist_lenet.pdf
│   ├── 09cifar_lenet.pdf
│   ├── 09cifar_vgg.pdf
│   ├── 09fmnist_lenet.pdf
│   ├── 2cifar_lenet.pdf
│   ├── 2cifar_vgg.pdf
│   ├── 2fmnist_lenet.pdf
│   ├── fed_acc.pdf
│   └── local_acc.pdf
├── main_fed.py
├── main_gate.py
├── main_local.py
├── main_nn.py
├── main_per_fb.py
├── models
│   ├── Fed.py
│   ├── Nets.py
│   ├── Test.py
│   ├── Update.py
│   └── __init__.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── options.py
    ├── sampling.py
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | imgs/*_f.pdf
2 | runs/*
3 | save/
4 | plot*
5 | test.py
6 | main__gate.py
7 | _config.yml
8 | main_gate_single.py
9 |
10 | # pycharm
11 | .idea/*
12 |
13 | # documents
14 | *.csv
15 | *.xls
16 | *.xlsx
17 | *.pdf
18 | *.json
19 |
20 | # macOS
21 | .DS_Store
22 |
23 | # Byte-compiled / optimized / DLL files
24 | __pycache__/
25 | *.py[cod]
26 | *$py.class
27 |
28 | # C extensions
29 | *.so
30 |
31 | # Distribution / packaging
32 | .Python
33 | env/
34 | build/
35 | develop-eggs/
36 | dist/
37 | downloads/
38 | eggs/
39 | .eggs/
40 | lib/
41 | lib64/
42 | parts/
43 | sdist/
44 | var/
45 | wheels/
46 | *.egg-info/
47 | .installed.cfg
48 | *.egg
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # virtualenv
57 | .venv
58 | venv/
59 | ENV/
60 |
61 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Shaoxiong Ji
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PFL-MoE: Personalized Federated Learning Based on Mixture of Experts
2 |
3 | In our experiments, we use two image recognition datasets, Fashion-MNIST and CIFAR-10, and two network models, LeNet-5 and VGG-16, giving three dataset+model combinations: Fashion-MNIST + LeNet-5, CIFAR-10 + LeNet-5, and CIFAR-10 + VGG-16.
4 |
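5 | PFL-MoE treats the federated global model and a locally fine-tuned personal model as two experts and learns a per-client gating network that blends their outputs. A minimal sketch of the gating idea, following the `CNNGate` module in [models/Nets.py](models/Nets.py) (the function and names here are illustrative, not the repo's API):
6 |
7 | ```python
8 | import torch
9 |
10 | def moe_forward(x, features, global_head, personal_head, gate):
11 |     """Blend global and personalized predictions with a learned gate g in (0, 1)."""
12 |     h = features(x)                 # shared (frozen) feature extractor
13 |     g = torch.sigmoid(gate(h))      # per-example gating weight
14 |     y = global_head(h)              # global expert logits
15 |     z = personal_head(h)            # personalized expert logits
16 |     return g * y + (1 - g) * z
17 | ```
18 |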
19 | ## Requirements
20 | python>=3.6
21 | pytorch>=0.4
22 |
23 | ## Run
24 | Dataset+model combinations: fmnist+lenet, cifar+lenet, cifar+vgg.
25 | For each combination, the non-IID level is set with $\alpha \in \{0.5, 0.9, 2.0\}$ (see the partition sketch at the end of this section).
26 |
27 | Local:
28 | > python [main_local.py](main_local.py) --dataset fmnist --model lenet --epochs 100 --gpu 0 --num_users 100 --alpha 0.5
29 |
30 | FedAvg:
31 | > python [main_fed.py](main_fed.py) --dataset fmnist --model lenet --epochs 1000 --gpu 0 --lr 0.01 --num_users 100 --frac 0.1 --alpha 0.5
32 |
33 | PFL-FB + PFL-MF:
34 | > python [main_gate.py](main_gate.py) --dataset fmnist --model lenet --epochs 200 --num_users 100 --gpu 1 --alpha 0.5
35 |
36 | PFL-FB + PFL-MFE:
37 | > python [main_gate.py](main_gate.py) --dataset fmnist --model lenet --epochs 200 --num_users 100 --gpu 1 --alpha 0.5 --struct
38 |
39 | See the arguments in [options.py](utils/options.py).
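40 |
41 | `--alpha` controls how skewed each client's label distribution is (smaller $\alpha$ means more non-IID). A minimal sketch of a Dirichlet-based partition, assuming the same scheme as `cifar_noniid` in [utils/sampling.py](utils/sampling.py) (details of the actual implementation may differ):
42 |
43 | ```python
44 | import numpy as np
45 |
46 | def dirichlet_partition(targets, num_users, alpha, num_classes=10, seed=1):
47 |     """Assign sample indices to clients with per-class shares drawn from Dir(alpha)."""
48 |     rng = np.random.default_rng(seed)
49 |     dict_users = {u: [] for u in range(num_users)}
50 |     for c in range(num_classes):
51 |         idxs = np.where(np.array(targets) == c)[0]
52 |         rng.shuffle(idxs)
53 |         props = rng.dirichlet(alpha * np.ones(num_users))  # class-c share per client
54 |         cuts = (np.cumsum(props)[:-1] * len(idxs)).astype(int)
55 |         for u, part in enumerate(np.split(idxs, cuts)):
56 |             dict_users[u].extend(part.tolist())
57 |     return dict_users
58 | ```
59 |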
60 | ## Results
61 | Each client runs two kinds of tests: a local test and a global test.
62 |
63 | Table 1. Average **local test** accuracy over all clients for the three baselines and the proposed algorithms.
64 |
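65 | The local test weights per-class accuracy on the shared test set by the client's own training label distribution, while the global test is plain accuracy on the full test set. A minimal sketch of the local metric, assuming the weighting scheme used by `local_test`/`user_test` in [models/Test.py](models/Test.py):
66 |
67 | ```python
68 | import numpy as np
69 |
70 | def local_test_acc(per_class_acc, class_weight):
71 |     """per_class_acc: test accuracy per class; class_weight: client's training label frequencies."""
72 |     return 100.0 * float(np.dot(per_class_acc, class_weight))
73 | ```
74 |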
75 | | Dataset & Model | non-IID α | Local (%) | FedAvg (%) | PFL-FB (%) | PFL-MF (%) | PFL-MFE (%) |
76 | | --- | --- | --- | --- | --- | --- | --- |
77 | | Fashion-MNIST & LeNet-5 | 0.5 | 84.87 | 90.00 | 92.84 | 92.85 | 92.89 |
78 | | Fashion-MNIST & LeNet-5 | 0.9 | 82.23 | 90.31 | 91.84 | 92.02 | 92.01 |
79 | | Fashion-MNIST & LeNet-5 | 2.0 | 78.63 | 90.50 | 90.47 | 90.97 | 90.93 |
80 | | CIFAR-10 & LeNet-5 | 0.5 | 65.58 | 68.92 | 77.46 | 75.49 | 77.23 |
81 | | CIFAR-10 & LeNet-5 | 0.9 | 61.49 | 70.70 | 74.70 | 74.10 | 74.74 |
82 | | CIFAR-10 & LeNet-5 | 2.0 | 55.80 | 72.69 | 72.50 | 73.24 | 73.44 |
83 | | CIFAR-10 & VGG-16 | 0.5 | 52.77 | 88.16 | 91.92 | 90.63 | 91.71 |
84 | | CIFAR-10 & VGG-16 | 0.9 | 45.24 | 88.45 | 91.34 | 90.63 | 91.18 |
85 | | CIFAR-10 & VGG-16 | 2.0 | 34.20 | 89.17 | 90.40 | 90.15 | 90.40 |
86 |
87 | Table 2. Average **global test** accuracy over all clients.
88 |
89 | | Dataset & Model | non-IID α | Local (%) | FedAvg (%) | PFL-FB (%) | PFL-MF (%) | PFL-MFE (%) |
90 | | --- | --- | --- | --- | --- | --- | --- |
91 | | Fashion-MNIST & LeNet-5 | 0.5 | 57.77 | 90.00 | 83.35 | 85.45 | 85.30 |
92 | | Fashion-MNIST & LeNet-5 | 0.9 | 65.28 | 90.31 | 85.91 | 87.69 | 87.67 |
93 | | Fashion-MNIST & LeNet-5 | 2.0 | 71.06 | 90.50 | 87.77 | 89.37 | 89.18 |
94 | | CIFAR-10 & LeNet-5 | 0.5 | 28.89 | 68.92 | 54.28 | 62.33 | 58.27 |
95 | | CIFAR-10 & LeNet-5 | 0.9 | 32.10 | 70.70 | 59.93 | 65.78 | 64.13 |
96 | | CIFAR-10 & LeNet-5 | 2.0 | 35.32 | 72.69 | 66.06 | 69.79 | 69.78 |
97 | | CIFAR-10 & VGG-16 | 0.5 | 21.53 | 88.16 | 82.39 | 85.81 | 84.05 |
98 | | CIFAR-10 & VGG-16 | 0.9 | 22.45 | 88.45 | 82.62 | 88.15 | 87.90 |
99 | | CIFAR-10 & VGG-16 | 2.0 | 21.27 | 89.17 | 88.77 | 89.30 | 89.30 |
100 |
101 | Fig 1. Fashion-MNIST + LeNet-5, $\alpha=0.9$. Global and local test accuracy of all clients for the PFL-FB, PFL-MF, and PFL-MFE algorithms. Each x-axis is the FedAvg local test accuracy of the corresponding client (so it can be read as a client index), and each point compares a PFL algorithm with FedAvg for one client.
102 |
103 | 
104 |
105 | Fig 2. CIFAR-10 + LeNet-5, $\alpha=0.9$.
106 |
107 | 
108 |
109 | Fig 3. CIFAR-10 + VGG-16, $\alpha=2.0$.
110 |
111 | 
112 |
113 | ## Acknowledgements
114 |
115 | Thanks to [shaoxiongji](https://github.com/shaoxiongji) for the federated learning code this repository builds on.
116 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data
2 |
3 | The MNIST, Fashion-MNIST, and CIFAR-10 datasets are downloaded automatically by the torchvision package the first time the training scripts run.
4 |
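5 | For example, the training scripts create the datasets like this (adapted from [main_fed.py](../main_fed.py); the real scripts pass their own transforms, and the `../data/...` root is relative to the scripts' working directory):
6 |
7 | ```python
8 | from torchvision import datasets, transforms
9 |
10 | # downloaded to ../data/cifar on first use, reused from the cache afterwards
11 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True,
12 |                                  transform=transforms.ToTensor())
13 | ```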
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 |
5 |
--------------------------------------------------------------------------------
/data/cifar/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/data/cifar/.gitkeep
--------------------------------------------------------------------------------
/data/mnist/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/data/mnist/.gitkeep
--------------------------------------------------------------------------------
/imgs/05cifar_lenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/05cifar_lenet.pdf
--------------------------------------------------------------------------------
/imgs/05cifar_vgg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/05cifar_vgg.pdf
--------------------------------------------------------------------------------
/imgs/05fmnist_lenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/05fmnist_lenet.pdf
--------------------------------------------------------------------------------
/imgs/09cifar_lenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/09cifar_lenet.pdf
--------------------------------------------------------------------------------
/imgs/09cifar_vgg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/09cifar_vgg.pdf
--------------------------------------------------------------------------------
/imgs/09fmnist_lenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/09fmnist_lenet.pdf
--------------------------------------------------------------------------------
/imgs/2cifar_lenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/2cifar_lenet.pdf
--------------------------------------------------------------------------------
/imgs/2cifar_vgg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/2cifar_vgg.pdf
--------------------------------------------------------------------------------
/imgs/2fmnist_lenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/2fmnist_lenet.pdf
--------------------------------------------------------------------------------
/imgs/fed_acc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/fed_acc.pdf
--------------------------------------------------------------------------------
/imgs/local_acc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobbin/PFL-MoE/256dbf5a768b6b08a60713fd342395d01f0d6954/imgs/local_acc.pdf
--------------------------------------------------------------------------------
/main_fed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import copy
6 | import numpy as np
7 | from torchvision import datasets, transforms
8 | import torch
9 |
10 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid
11 | from utils.options import args_parser
12 | from models.Update import LocalUpdate
13 | from models.Nets import vgg16, CNNCifar
14 | from models.Fed import FedAvg
15 | from models.Test import test_img
16 | from utils.util import setup_seed
17 | from datetime import datetime
18 | from torch.utils.tensorboard import SummaryWriter
19 |
20 |
21 | if __name__ == '__main__':
22 | # parse args
23 | args = args_parser()
24 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
25 | setup_seed(args.seed)
26 |
27 | # log
28 | current_time = datetime.now().strftime('%b.%d_%H.%M.%S')
29 | TAG = 'exp/fed/{}_{}_{}_C{}_iid{}_{}_user{}_{}'.format(args.dataset, args.model, args.epochs, args.frac, args.iid,
30 | args.alpha, args.num_users, current_time)
31 | # TAG = f'alpha_{alpha}/data_distribution'
32 | logdir = f'runs/{TAG}' if not args.debug else f'runs2/{TAG}'
33 | writer = SummaryWriter(logdir)
34 |
35 | # load dataset and split users
36 | if args.dataset == 'mnist':
37 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
38 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
39 | dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
40 | # sample users
41 | if args.iid:
42 | dict_users = mnist_iid(dataset_train, args.num_users)
43 | else:
44 | dict_users = mnist_noniid(dataset_train, args.num_users)
45 | elif args.dataset == 'cifar':
46 | transform_train = transforms.Compose([
47 | transforms.RandomCrop(32, padding=4),
48 | transforms.RandomHorizontalFlip(),
49 | transforms.ToTensor(),
50 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
51 | ])
52 | transform_test = transforms.Compose([
53 | transforms.ToTensor(),
54 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
55 | ])
56 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=transform_train)
57 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=transform_test)
58 | elif args.dataset == 'fmnist':
59 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True,
60 | transform=transforms.Compose([
61 | transforms.Resize((32, 32)),
62 | transforms.RandomCrop(32, padding=4),
63 | transforms.RandomHorizontalFlip(),
64 | transforms.ToTensor(),
65 | transforms.Normalize((0.1307,), (0.3081,)),
66 | ]))
67 |
68 | # testing
69 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True,
70 | transform=transforms.Compose([
71 | transforms.Resize((32, 32)),
72 | transforms.ToTensor(),
73 | transforms.Normalize((0.1307,), (0.3081,))
74 | ]))
75 | # test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
76 | else:
77 | exit('Error: unrecognized dataset')
78 |
79 | if args.iid:
80 | dict_users = cifar_iid(dataset_train, args.num_users)
81 | else:
82 | dict_users, _ = cifar_noniid(dataset_train, args.num_users, args.alpha)
83 | for k, v in dict_users.items():
84 | writer.add_histogram(f'user_{k}/data_distribution',
85 | np.array(dataset_train.targets)[v],
86 | bins=np.arange(11))
87 | writer.add_histogram(f'all_user/data_distribution',
88 | np.array(dataset_train.targets)[v],
89 | bins=np.arange(11), global_step=k)
90 |
91 | # build model
92 | if args.model == 'lenet' and (args.dataset == 'cifar' or args.dataset == 'fmnist'):
93 | net_glob = CNNCifar(args=args).to(args.device)
94 | elif args.model == 'vgg' and args.dataset == 'cifar':
95 | net_glob = vgg16().to(args.device)
96 | else:
97 | exit('Error: unrecognized model')
98 | print(net_glob)
99 | net_glob.train()
100 |
101 | # copy weights
102 | w_glob = net_glob.state_dict()
103 |
104 | # training
105 | loss_train = []
106 | cv_loss, cv_acc = [], []
107 | val_loss_pre, counter = 0, 0
108 | net_best = None
109 | best_loss = None
110 | val_acc_list, net_list = [], []
111 | test_best_acc = 0.0
112 |
113 | for iter in range(args.epochs):
114 | w_locals, loss_locals = [], []
115 | m = max(int(args.frac * args.num_users), 1)
116 | idxs_users = np.random.choice(range(args.num_users), m, replace=False)
117 | for idx in idxs_users:
118 | local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
119 | w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
120 | w_locals.append(w)
121 | loss_locals.append(loss)
122 | # update global weights
123 | w_glob = FedAvg(w_locals)
124 |
125 | # copy weight to net_glob
126 | net_glob.load_state_dict(w_glob)
127 |
128 | # print loss
129 | loss_avg = sum(loss_locals) / len(loss_locals)
130 | print('Round {:3d}, Train loss {:.3f}'.format(iter, loss_avg))
131 | loss_train.append(loss_avg)
132 | writer.add_scalar('train_loss', loss_avg, iter)
133 | test_acc, test_loss = test_img(net_glob, dataset_test, args)
134 | writer.add_scalar('test_loss', test_loss, iter)
135 | writer.add_scalar('test_acc', test_acc, iter)
136 |
137 | save_info = {
138 | "model": net_glob.state_dict(),
139 | "epoch": iter
140 | }
141 | # save model weights
142 | if (iter+1) % 500 == 0:
143 | save_path = f'./save2/{TAG}_{iter+1}es' if args.debug else f'./save/{TAG}_{iter+1}es'
144 | torch.save(save_info, save_path)
145 | if iter > 100 and test_acc > test_best_acc:
146 | test_best_acc = test_acc
147 | save_path = f'./save2/{TAG}_bst' if args.debug else f'./save/{TAG}_bst'
148 | torch.save(save_info, save_path)
149 |
150 | # plot loss curve
151 | # plt.figure()
152 | # plt.plot(range(len(loss_train)), loss_train)
153 | # plt.ylabel('train_loss')
154 | # plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))
155 |
156 | # testing
157 | net_glob.eval()
158 | acc_train, loss_train = test_img(net_glob, dataset_train, args)
159 | acc_test, loss_test = test_img(net_glob, dataset_test, args)
160 | print("Training accuracy: {:.2f}".format(acc_train))
161 | print("Testing accuracy: {:.2f}".format(acc_test))
162 | writer.close()
163 |
--------------------------------------------------------------------------------
/main_gate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.utils.data import DataLoader
8 | import torch.optim as optim
9 | from torchvision import datasets, transforms
10 | from utils.options import args_parser
11 | from models.Nets import CNNGate, gate_vgg16
12 | from utils.util import setup_seed, add_scalar
13 | from torch.utils.tensorboard import SummaryWriter
14 | from datetime import datetime
15 | from utils.sampling import cifar_noniid
16 | import numpy as np
17 | from models.Update import DatasetSplit
18 | from models.Test import user_test, user_per_test
19 |
20 |
21 | if __name__ == '__main__':
22 | # parse args
23 | args = args_parser()
24 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
25 | setup_seed(args.seed)
26 |
27 | # log
28 | current_time = datetime.now().strftime('%b.%d_%H.%M.%S')
29 | TAG = 'exp/{}gate2/{}_{}_{}_{}_user{}_{}'.format('struct/' if args.struct else '', args.dataset, args.model, args.epochs,
30 | args.alpha, args.num_users, current_time)
31 | TAG2 = 'exp/{}per_fb/{}_{}_{}_{}_user{}_{}'.format('struct/' if args.struct else '', args.dataset, args.model, args.epochs,
32 | args.alpha, args.num_users, current_time)
33 | logdir = f'runs/{TAG}'
34 | logdir2 = f'runs/{TAG2}'
35 | if args.debug:
36 | logdir = f'runs2/{TAG}'
37 | logdir2 = f'runs2/{TAG2}'
38 | writer = SummaryWriter(logdir)
39 | writer2 = SummaryWriter(logdir2)
40 |
41 | # load dataset and split users
42 | train_loader, test_loader, class_weight = None, None, None  # placeholders, assigned below
43 |
44 | save_dataset_path = f'./data/{args.dataset}_non_iid{args.alpha}_user{args.num_users}_fast_data'
45 | # global_weight = torch.load('./save/exp/fed/cifar_resnet_1000_C0.1_iidFalse_2.0_user100_Nov.28_01.41.16_bst')['model']
46 | global_weight = torch.load(
47 | f'./save/exp/fed/{args.dataset}_{args.model}_1000_C0.1_iidFalse_{args.alpha}_user{args.num_users}_bst')[
48 | 'model']
49 | if 'gate.weight' in global_weight:
50 | del (global_weight['gate.weight'])
51 | del (global_weight['gate.bias'])
52 | if args.rebuild:
53 | if args.dataset == "cifar":
54 | # training
55 | transform_train = transforms.Compose([
56 | transforms.RandomCrop(32, padding=4),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.ToTensor(),
59 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
60 | ])
61 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform_train, download=True)
62 | # testing
63 | transform_test = transforms.Compose([
64 | transforms.ToTensor(),
65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
66 | ])
67 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, transform=transform_test, download=True)
68 | elif args.dataset == "fmnist":
69 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True,
70 | transform=transforms.Compose([
71 | transforms.Resize((32, 32)),
72 | transforms.RandomCrop(32, padding=4),
73 | transforms.RandomHorizontalFlip(),
74 | transforms.ToTensor(),
75 | transforms.Normalize((0.1307,), (0.3081,)),
76 | ]))
77 |
78 | # testing
79 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True,
80 | transform=transforms.Compose([
81 | transforms.Resize((32, 32)),
82 | transforms.ToTensor(),
83 | transforms.Normalize((0.1307,), (0.3081,))
84 | ]))
85 | else:
86 | exit('Error: unrecognized dataset')
87 | # non_iid
88 | dict_users, _ = cifar_noniid(dataset_train, args.num_users, args.alpha)
89 | save_dataset = {
90 | "dataset_test": dataset_test,
91 | "dataset_train": dataset_train,
92 | "dict_users": dict_users
93 | }
94 | torch.save(save_dataset, save_dataset_path)
95 | else:
96 | save_dataset = torch.load(save_dataset_path)
97 | dataset_test = save_dataset['dataset_test']
98 | dataset_train = save_dataset['dataset_train']
99 | dict_users = save_dataset['dict_users']
100 |
101 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
102 | for k, v in dict_users.items():
103 | writer.add_histogram(f'user_{k}/data_distribution',
104 | np.array(dataset_train.targets)[v],
105 | bins=np.arange(11))
106 | writer.add_histogram(f'all_user/data_distribution',
107 | np.array(dataset_train.targets)[v],
108 | bins=np.arange(11), global_step=k)
109 |
110 | # build model
111 | if args.model == 'lenet' and (args.dataset == 'cifar' or args.dataset == 'fmnist'):
112 | net_glob = CNNGate(args=args).to(args.device)
113 | elif args.model == 'vgg' and args.dataset == 'cifar':
114 | net_glob = gate_vgg16(args=args).to(args.device)
115 | else:
116 | exit('Error: unrecognized model')
117 | image, target = next(iter(test_loader))
118 | writer.add_graph(net_glob, image.to(args.device))
119 |
120 | gate_epochs = 200
121 |
122 | local_acc = np.zeros([args.num_users, args.epochs + gate_epochs + 1])
123 | total_acc = np.zeros([args.num_users, args.epochs + gate_epochs + 1])
124 | local_acc2 = np.zeros([args.num_users, args.epochs + gate_epochs + 1])
125 | total_acc2 = np.zeros([args.num_users, args.epochs + gate_epochs + 1])
126 |
127 | for user_num in range(len(dict_users)):
128 | # user data
129 | user_train = DatasetSplit(dataset_train, dict_users[user_num])
130 |
131 | np.random.shuffle(dict_users[user_num])
132 | cut_point = len(dict_users[user_num]) // 4
133 | train_loader = DataLoader(DatasetSplit(dataset_train, dict_users[user_num][cut_point:]),
134 | batch_size=64, shuffle=True)
135 | gate_loader = DataLoader(DatasetSplit(dataset_train, dict_users[user_num][:cut_point]),
136 | batch_size=64, shuffle=True)
137 |
138 | class_weight = np.zeros(10)
139 | for image, label in user_train:
140 | class_weight[label] += 1
141 | class_weight /= sum(class_weight)
142 |
143 | # init
144 |
145 | net_glob.load_state_dict(global_weight, False)
146 |
147 | if args.model == 'lenet':
148 | keys_ind = ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias']
149 | net_glob.load_state_dict({'p' + k: global_weight[k] for k in keys_ind}, strict=False)
150 | elif args.model == 'vgg':
151 | keys_ind = ['classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias']
152 | net_glob.load_state_dict({'p' + k: global_weight[k] for k in keys_ind}, strict=False)
153 | else:
154 | exit("Error: unrecognized model")
155 | net_glob.gate.reset_parameters()
156 |
157 | # training
158 | if args.model == 'lenet':
159 | layer_set = {'p' + k[:k.rindex('.')] for k in keys_ind}
160 | optimizer = optim.SGD([{'params': getattr(net_glob, l).parameters()} for l in layer_set],
161 | lr=0.001, momentum=0.9, weight_decay=5e-4)
162 | elif args.model == 'vgg':
163 | layer_set = {k[len('pclassifier'):k.rindex('.')] for k in keys_ind}
164 | optimizer = optim.SGD([{'params': net_glob.pclassifier.parameters()}],
165 | lr=0.005, momentum=0.9, weight_decay=5e-4)
166 | else:
167 | exit('Error: unrecognized model')
168 | # optimizer_gate = optim.SGD([{'params': net_glob.gate.parameters()}], lr=0.001, momentum=0.9, weight_decay=5e-4)
169 | criterion = nn.CrossEntropyLoss()
170 |
171 | test_result = user_per_test(args, net_glob, test_loader, class_weight)
172 | add_scalar(writer2, user_num, test_result, 0)
173 | total_acc2[user_num][0] = test_result[1]
174 | local_acc2[user_num][0] = test_result[3]
175 |
176 | for epoch in range(1, args.epochs+1):
177 | net_glob.train()
178 | batch_loss = []
179 | gate_out = []
180 | for batch_idx, (data, target) in enumerate(train_loader):
181 | data, target = data.to(args.device), target.to(args.device)
182 | optimizer.zero_grad()
183 | output, g, z = net_glob(data)
184 | gate_out.append(g)
185 | loss = criterion(z, target)
186 | loss.backward()
187 | optimizer.step()
188 | batch_loss.append(loss.item())
189 | if epoch % 10 == 1:
190 | # writer.add_histogram(f"user_{user_num}/gate_out", torch.cat(gate_out[0:-1], -1), epoch)
191 | if args.model == 'lenet':
192 | for layer in layer_set:
193 | writer.add_histogram(f"user_{user_num}/{layer}/weight", getattr(net_glob, layer).weight, epoch)
194 | elif args.model == 'vgg':
195 | for layer in layer_set:
196 | writer.add_histogram(f"user_{user_num}/pclassifier.{layer}/weight", getattr(net_glob.pclassifier, layer).weight, epoch)
197 | loss_avg = sum(batch_loss) / len(batch_loss)
198 | print(f'User {user_num} train loss:', loss_avg)
199 | writer2.add_scalar(f'user_{user_num}/pfc_train_loss', loss_avg, epoch)
200 |
201 | test_result = user_per_test(args, net_glob, test_loader, class_weight)
202 | print(f'global test acc:', test_result[1])
203 | add_scalar(writer2, user_num, test_result, epoch)
204 | total_acc2[user_num][epoch] = test_result[1]
205 | local_acc2[user_num][epoch] = test_result[3]
206 |
207 | test_result = user_test(args, net_glob, test_loader, class_weight)
208 | add_scalar(writer, user_num, test_result, args.epochs)
209 | total_acc[user_num][args.epochs] = test_result[1]
210 | local_acc[user_num][args.epochs] = test_result[3]
211 |
212 | optimizer_gate = optim.Adam([{'params': net_glob.gate.parameters()}], weight_decay=5e-4)
213 |
214 | for gate_epoch in range(1, 1 + gate_epochs):
215 | net_glob.train()
216 | gate_epoch_loss = []
217 | gate_out = torch.tensor([], device=args.device)
218 | for batch_idx, (data, target) in enumerate(gate_loader):
219 | data, target = data.to(args.device), target.to(args.device)
220 | optimizer_gate.zero_grad()
221 | output, g, z = net_glob(data)
222 | gate_out = torch.cat((gate_out, g.view(-1)))
223 | loss = criterion(output, target)
224 | loss.backward()
225 | optimizer_gate.step()
226 | gate_epoch_loss.append(loss.item())
227 | if gate_epoch % 10 == 1:
228 | writer.add_histogram(f"user_{user_num}/gate_out", gate_out)
229 | writer.add_histogram(f"user_{user_num}/gate/weight", net_glob.gate.weight)
230 | writer.add_histogram(f"user_{user_num}/gate/bias", net_glob.gate.bias)
231 | loss_avg = sum(gate_epoch_loss) / len(gate_epoch_loss)
232 | print(f'User {user_num} gate loss', loss_avg)
233 | writer.add_scalar(f'user_{user_num}/gate_train_loss', loss_avg, args.epochs + gate_epoch)
234 |
235 | test_result = user_test(args, net_glob, test_loader, class_weight)
236 | add_scalar(writer, user_num, test_result, args.epochs + gate_epoch)
237 | total_acc[user_num][args.epochs + gate_epoch] = test_result[1]
238 | local_acc[user_num][args.epochs + gate_epoch] = test_result[3]
239 |
240 | save_info = {
241 | "total_acc": total_acc,
242 | "local_acc": local_acc
243 | }
244 | save_info2 = {
245 | "total_acc": total_acc2,
246 | "local_acc": local_acc2
247 | }
248 | save_path = f'{logdir}/local_train_epoch_acc'
249 | save_path2 = f'{logdir2}/local_train_epoch_acc'
250 | torch.save(save_info, save_path)
251 | torch.save(save_info2, save_path2)
252 |
253 | total_acc = total_acc.mean(axis=0)
254 | local_acc = local_acc.mean(axis=0)
255 | total_acc2 = total_acc2.mean(axis=0)
256 | local_acc2 = local_acc2.mean(axis=0)
257 | for epoch, _ in enumerate(total_acc):
258 | if epoch >= args.epochs:
259 | writer.add_scalar('test/global/test_acc', total_acc[epoch], epoch)
260 | writer.add_scalar('test/local/test_acc', local_acc[epoch], epoch)
261 | if epoch <= args.epochs:
262 | writer2.add_scalar('test/global/test_acc', total_acc2[epoch], epoch)
263 | writer2.add_scalar('test/local/test_acc', local_acc2[epoch], epoch)
264 | writer.close()
265 | writer2.close()
266 |
--------------------------------------------------------------------------------
/main_local.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 | import copy
5 | import numpy as np
6 | from torchvision import datasets, transforms
7 | import torch
8 | import torch.nn
9 | import torch.nn.functional as F
10 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid
11 | from utils.options import args_parser
12 | from models.Nets import CNNCifar, vgg16
13 | from utils.util import setup_seed
14 | from datetime import datetime
15 | from torch.utils.tensorboard import SummaryWriter
16 | from torch.utils.data import DataLoader
17 | import torch.optim as optim
18 | from models.Update import DatasetSplit
19 | from models.Test import local_test
20 | from utils.util import add_scalar
21 |
22 |
23 | def test(model, data_source):
24 | model.eval()
25 | total_loss = 0.0
26 | correct = 0.0
27 | correct_class = np.zeros(10)
28 | correct_class_acc = np.zeros(10)
29 | correct_class_size = np.zeros(10)
30 |
31 | dataset_size = len(data_source.dataset)
32 | data_iterator = data_source
33 | with torch.no_grad():
34 | for batch_id, (data, targets) in enumerate(data_iterator):
35 | data, targets = data.to(args.device), targets.to(args.device)
36 | output = model(data)
37 | total_loss += F.cross_entropy(output, targets,
38 | reduction='sum').item() # sum up batch loss
39 | pred = output.data.max(1)[1] # get the index of the max log-probability
40 | correct += pred.eq(targets.data.view_as(pred)).cpu().sum().item()
41 | for i in range(10):
42 | class_ind = targets.data.view_as(pred).eq(i*torch.ones_like(pred))
43 | correct_class_size[i] += class_ind.cpu().sum().item()
44 | correct_class[i] += (pred.eq(targets.data.view_as(pred))*class_ind).cpu().sum().item()
45 |
46 | acc = 100.0 * (float(correct) / float(dataset_size))
47 | for i in range(10):
48 | correct_class_acc[i] = (float(correct_class[i]) / float(correct_class_size[i]))
49 | total_l = total_loss / dataset_size
50 | # print(f'Average loss: {total_l}, Accuracy: {correct}/{dataset_size} ({acc}%)')
51 | return total_l, acc, correct_class_acc
52 |
53 |
54 | if __name__ == '__main__':
55 | # parse args
56 | args = args_parser()
57 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
58 | setup_seed(args.seed)
59 |
60 | # log
61 | current_time = datetime.now().strftime('%b.%d_%H.%M.%S')
62 | TAG = 'exp/local/{}_{}_{}_iid{}_{}_user{}_{}'.format(args.dataset, args.model, args.epochs, args.iid, args.alpha,
63 | args.num_users, current_time)
64 | logdir = f'runs/{TAG}' if not args.debug else f'runs2/{TAG}'
65 | writer = SummaryWriter(logdir)
66 |
67 | # load dataset and split users
68 | if args.dataset == 'mnist':
69 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
70 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
71 | dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
72 | # sample users
73 | if args.iid:
74 | dict_users = mnist_iid(dataset_train, args.num_users)
75 | else:
76 | dict_users = mnist_noniid(dataset_train, args.num_users)
77 |
78 | elif args.dataset == 'cifar':
79 | transform_train = transforms.Compose([
80 | transforms.RandomCrop(32, padding=4),
81 | transforms.RandomHorizontalFlip(),
82 | transforms.ToTensor(),
83 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
84 | ])
85 | transform_test = transforms.Compose([
86 | transforms.ToTensor(),
87 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
88 | ])
89 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=transform_train)
90 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=transform_test)
91 | elif args.dataset == 'fmnist':
92 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True,
93 | transform=transforms.Compose([
94 | transforms.Resize((32, 32)),
95 | transforms.RandomCrop(32, padding=4),
96 | transforms.RandomHorizontalFlip(),
97 | transforms.ToTensor(),
98 | transforms.Normalize((0.1307,), (0.3081,)),
99 | ]))
100 |
101 | # testing
102 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True,
103 | transform=transforms.Compose([
104 | transforms.Resize((32, 32)),
105 | transforms.ToTensor(),
106 | transforms.Normalize((0.1307,), (0.3081,))
107 | ]))
108 | # test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
109 | else:
110 | exit('Error: unrecognized dataset')
111 |
112 | if args.iid:
113 | dict_users = cifar_iid(dataset_train, args.num_users)
114 | else:
115 | dict_users, _ = cifar_noniid(dataset_train, args.num_users, args.alpha)
116 | for k, v in dict_users.items():
117 | writer.add_histogram(f'user_{k}/data_distribution',
118 | np.array(dataset_train.targets)[v],
119 | bins=np.arange(11))
120 | writer.add_histogram(f'all_user/data_distribution',
121 | np.array(dataset_train.targets)[v],
122 | bins=np.arange(11), global_step=k)
123 |
124 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
125 | img_size = dataset_train[0][0].shape
126 |
127 | # build model
128 | if args.model == 'lenet' and (args.dataset == 'cifar' or args.dataset == 'fmnist'):
129 | net_glob = CNNCifar(args=args).to(args.device)
130 | elif args.model == 'vgg' and args.dataset == 'cifar':
131 | net_glob = vgg16().to(args.device)
132 | else:
133 | exit('Error: unrecognized model')
134 | print(net_glob)
135 | net_glob.train()
136 |
137 | # copy weights
138 | w_init = copy.deepcopy(net_glob.state_dict())
139 |
140 | local_acc_final = []
141 | total_acc_final = []
142 | local_acc = np.zeros([args.num_users, args.epochs])
143 | total_acc = np.zeros([args.num_users, args.epochs])
144 |
145 | # training
146 | for idx in range(args.num_users):
147 | # print(w_init)
148 | net_glob.load_state_dict(w_init)
149 | optimizer = optim.Adam(net_glob.parameters())
150 | train_loader = DataLoader(DatasetSplit(dataset_train, dict_users[idx]), batch_size=64, shuffle=True)
151 | image_trainset_weight = np.zeros(10)
152 | for label in np.array(dataset_train.targets)[dict_users[idx]]:
153 | image_trainset_weight[label] += 1
154 | image_trainset_weight = image_trainset_weight / image_trainset_weight.sum()
155 | list_loss = []
156 | net_glob.train()
157 | for epoch in range(args.epochs):
158 | batch_loss = []
159 | for batch_idx, (data, target) in enumerate(train_loader):
160 | data, target = data.to(args.device), target.to(args.device)
161 | optimizer.zero_grad()
162 | output = net_glob(data)
163 | loss = F.cross_entropy(output, target)
164 | loss.backward()
165 | optimizer.step()
166 | # if batch_idx % 3 == 0:
167 | # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
168 | # epoch, batch_idx * len(data), len(train_loader.dataset),
169 | # 100. * batch_idx / len(train_loader), loss.item()))
170 | batch_loss.append(loss.item())
171 |
172 | loss_avg = sum(batch_loss) / len(batch_loss)
173 | print('\nLocal Train loss:', loss_avg)
174 | writer.add_scalar(f'user_{idx}/local_train_loss', loss_avg, epoch)
175 |
176 | test_result = local_test(args, net_glob, test_loader, image_trainset_weight)
177 | add_scalar(writer, idx, test_result, epoch)
178 | print('Global Test ACC:', test_result[1])
179 | print('Local Test ACC:', test_result[3])
180 |
181 | total_acc[idx][epoch] = test_result[1]
182 | local_acc[idx][epoch] = test_result[3]
183 |
184 | total_acc_final.append(test_result[1])
185 | local_acc_final.append(test_result[3])
186 | print(f'user {idx} done!')
187 |
188 | save_info = {
189 | "total_acc": total_acc,
190 | "local_acc": local_acc
191 | }
192 | save_path = f'{logdir}/local_train_epoch_acc'
193 | torch.save(save_info, save_path)
194 |
195 | total_acc = total_acc.mean(axis=0)
196 | local_acc = local_acc.mean(axis=0)
197 | for epoch in range(args.epochs):
198 | writer.add_scalar('test/global/test_acc', total_acc[epoch], epoch)
199 | writer.add_scalar('test/local/test_acc', local_acc[epoch], epoch)
200 | writer.close()
201 | #
202 | # # plot loss curve
203 | # plt.figure()
204 | # plt.title('local train acc', fontsize=20)  # set the title text and font size
205 | # labels = ['local', 'total']
206 | # plt.boxplot([local_acc_final, total_acc_final], labels=labels, notch=True, showmeans=True)
207 | # plt.ylabel('test acc')
208 | # plt.savefig(f'{logdir}/local_train_acc.png')
209 |
--------------------------------------------------------------------------------
/main_nn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.utils.data import DataLoader
8 | import torch.optim as optim
9 | from torchvision import datasets, transforms
10 | from utils.options import args_parser
11 | from models.Nets import CNNCifar, vgg16
12 | from utils.util import setup_seed
13 | from torch.utils.tensorboard import SummaryWriter
14 | from datetime import datetime
15 | from models.Test import test
16 |
17 |
18 | if __name__ == '__main__':
19 |
20 | # parse args
21 | args = args_parser()
22 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
23 | setup_seed(args.seed)
24 |
25 | # log
26 | current_time = datetime.now().strftime('%b.%d_%H.%M.%S')
27 | TAG = 'nn_{}_{}_{}_{}'.format(args.dataset, args.model, args.epochs, current_time)
28 | logdir = f'runs/{TAG}'
29 | if args.debug:
30 | logdir = f'runs2/{TAG}'
31 | writer = SummaryWriter(logdir)
32 |
33 | # load dataset and split users
34 | if args.dataset == 'cifar':
35 | transform_train = transforms.Compose([
36 | transforms.RandomCrop(32, padding=4),
37 | transforms.RandomHorizontalFlip(),
38 | transforms.ToTensor(),
39 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
40 | ])
41 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform_train, download=True)
42 |
43 | # testing
44 | transform_test = transforms.Compose([
45 | transforms.ToTensor(),
46 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
47 | ])
48 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, transform=transform_test, download=True)
49 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
50 | img_size = dataset_train[0][0].shape
51 | elif args.dataset == 'fmnist':
52 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True,
53 | transform=transforms.Compose([
54 | transforms.Resize((32, 32)),
55 | transforms.RandomCrop(32, padding=4),
56 | transforms.RandomHorizontalFlip(),
57 | transforms.ToTensor(),
58 | transforms.Normalize((0.1307,), (0.3081,)),
59 | ]))
60 |
61 | # testing
62 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True,
63 | transform=transforms.Compose([
64 | transforms.Resize((32, 32)),
65 | transforms.ToTensor(),
66 | transforms.Normalize((0.1307,), (0.3081,))
67 | ]))
68 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
69 | else:
70 | exit('Error: unrecognized dataset')
71 |
72 | # build model
73 | if args.model == 'lenet' and (args.dataset == 'cifar' or args.dataset == 'fmnist'):
74 | net_glob = CNNCifar(args=args).to(args.device)
75 | elif args.model == 'vgg' and args.dataset == 'cifar':
76 | net_glob = vgg16().to(args.device)
77 | else:
78 | exit('Error: unrecognized model')
79 | print(net_glob)
80 | img = dataset_train[0][0].unsqueeze(0).to(args.device)
81 | writer.add_graph(net_glob, img)
82 |
83 | # training
84 | criterion = nn.CrossEntropyLoss()
85 | train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)
86 | # optimizer = optim.Adam(net_glob.parameters())
87 | optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
88 | # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
89 | # # # scheduler.step()
90 |
91 | list_loss = []
92 | net_glob.train()
93 | for epoch in range(args.epochs):
94 | batch_loss = []
95 | for batch_idx, (data, target) in enumerate(train_loader):
96 | data, target = data.to(args.device), target.to(args.device)
97 | optimizer.zero_grad()
98 | output = net_glob(data)
99 | loss = criterion(output, target)
100 | loss.backward()
101 | optimizer.step()
102 | if batch_idx % 50 == 0:
103 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
104 | epoch, batch_idx * len(data), len(train_loader.dataset),
105 | 100. * batch_idx / len(train_loader), loss.item()))
106 | batch_loss.append(loss.item())
107 | # scheduler.step()
108 | loss_avg = sum(batch_loss)/len(batch_loss)
109 | print('\nTrain loss:', loss_avg)
110 | list_loss.append(loss_avg)
111 | writer.add_scalar('train_loss', loss_avg, epoch)
112 | test_acc, test_loss = test(args, net_glob, test_loader)
113 | writer.add_scalar('test_loss', test_loss, epoch)
114 | writer.add_scalar('test_acc', test_acc, epoch)
115 |
116 | # save model weights
117 | save_info = {
118 | "epochs": args.epochs,
119 | "optimizer": optimizer.state_dict(),
120 | "model": net_glob.state_dict()
121 | }
122 |
123 | save_path = f'save2/{TAG}' if args.debug else f'save/{TAG}'
124 | torch.save(save_info, save_path)
125 | writer.close()
126 |
--------------------------------------------------------------------------------
/main_per_fb.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | from torch.utils.data import DataLoader
6 | import torch.optim as optim
7 | from torchvision import datasets, transforms
8 | from utils.options import args_parser
9 | from models.Nets import CNNGate
10 | from torch.utils.tensorboard import SummaryWriter
11 | from datetime import datetime
12 | from utils.sampling import cifar_noniid
13 | from models.Update import DatasetSplit
14 | from utils.util import *
15 | from models.Test import user_per_test
16 | import copy
17 |
18 |
19 | if __name__ == '__main__':
20 | # parse args
21 | args = args_parser()
22 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
23 | setup_seed(args.seed)
24 |
25 | # log
26 | current_time = datetime.now().strftime('%b.%d_%H.%M.%S')
27 | TAG = 'exp/non_iid/per_fb_{}_{}_{}_{}_user{}_{}'.format(args.dataset, args.model, args.epochs, args.alpha,
28 | args.num_users, current_time)
29 | logdir = f'runs/{TAG}'
30 | if args.debug:
31 | logdir = f'runs2/{TAG}'
32 | writer = SummaryWriter(logdir)
33 |
34 | # load dataset and split users
35 | train_loader, test_loader, class_weight, dict_users, dataset_train = None, None, None, None, None  # placeholders, assigned below
36 | if args.dataset == 'mnist':
37 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
38 | transform=transforms.Compose([
39 | transforms.ToTensor(),
40 | transforms.Normalize((0.1307,), (0.3081,))
41 | ]))
42 | # testing
43 | dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True,
44 | transform=transforms.Compose([
45 | transforms.ToTensor(),
46 | transforms.Normalize((0.1307,), (0.3081,))
47 | ]))
48 |
49 | elif args.dataset == 'cifar':
50 | save_dataset_path = f'./data/cifar_non_iid{args.alpha}_user{args.num_users}_fast_data'
51 | # global_weight = torch.load('./save/nn_cifar_cnn_100_Oct.13_19.45.20')['model']
52 | # global_weight = torch.load(f'./save/exp/fed/{args.dataset}_{args.model}_1000_C0.1_iidFalse_{args.alpha}_user{args.num_users}_1000es')['model']
53 | global_weight = torch.load(f'./save/exp/fed/cifar_lenet_1000_C0.1_iidFalse_0.2_user30*3_Nov.16_14.37.35_1000es')['model']
54 | if args.rebuild:
55 | # training
56 | transform_train = transforms.Compose([
57 | transforms.RandomCrop(32, padding=4),
58 | transforms.RandomHorizontalFlip(),
59 | transforms.ToTensor(),
60 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
61 | ])
62 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform_train, download=True)
63 | # testing
64 | transform_test = transforms.Compose([
65 | transforms.ToTensor(),
66 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
67 | ])
68 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, transform=transform_test, download=True)
69 | # non_iid
70 | dict_users, _ = cifar_noniid(dataset_train, args.num_users, args.alpha)
71 |
72 | save_dataset = {
73 | "dataset_test": dataset_test,
74 | "dataset_train": dataset_train,
75 | "dict_users": dict_users
76 | }
77 | torch.save(save_dataset, save_dataset_path)
78 | else:
79 | save_dataset = torch.load(save_dataset_path)
80 | dataset_test = save_dataset['dataset_test']
81 | dataset_train = save_dataset['dataset_train']
82 | dict_users = save_dataset['dict_users']
83 |
84 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
85 | for k, v in dict_users.items():
86 | writer.add_histogram(f'user_{k}/data_distribution',
87 | np.array(dataset_train.targets)[v],
88 | bins=np.arange(11))
89 | writer.add_histogram(f'all_user/data_distribution',
90 | np.array(dataset_train.targets)[v],
91 | bins=np.arange(11), global_step=k)
92 | img_size = dataset_train[0][0].shape
93 | elif args.dataset == 'fmnist':
94 | pass
95 | else:
96 | exit('Error: unrecognized dataset')
97 |
98 | # build model
99 | net_glob = CNNGate(args=args).to(args.device)
100 | image, target = next(iter(test_loader))
101 | writer.add_graph(net_glob, image.to(args.device))
102 |
103 | local_acc = np.zeros([args.num_users, args.epochs + 1])
104 | total_acc = np.zeros([args.num_users, args.epochs + 1])
105 |
106 | for user_num in range(len(dict_users)):
107 | # user train data
108 | user_train = DatasetSplit(dataset_train, dict_users[user_num])
109 | train_loader = DataLoader(user_train, batch_size=64, shuffle=True)
110 |
111 | class_weight = np.zeros(10)
112 | for image, label in user_train:
113 | class_weight[label] += 1
114 | class_weight /= sum(class_weight)
115 |
116 | # init
117 |
118 | # global_weight = torch.load('./save/fed_cifar_cnn_1000_C0.1_iidFalse_0.9_Nov.05_09.31.38_500es')['model']
119 | net_glob.load_state_dict(global_weight, False)
120 | net_glob.pfc1.load_state_dict({'weight': global_weight['fc1.weight'], 'bias': global_weight['fc1.bias']})
121 | net_glob.pfc2.load_state_dict({'weight': global_weight['fc2.weight'], 'bias': global_weight['fc2.bias']})
122 | net_glob.pfc3.load_state_dict({'weight': global_weight['fc3.weight'], 'bias': global_weight['fc3.bias']})
123 |
124 | # training
125 | optimizer = optim.SGD([
126 | {'params': net_glob.pfc1.parameters()},
127 | {'params': net_glob.pfc2.parameters()},
128 | {'params': net_glob.pfc3.parameters()},
129 | ], lr=0.001, momentum=0.9, weight_decay=5e-4)
130 | criterion = nn.CrossEntropyLoss()
131 |
132 | test_result = user_per_test(args, net_glob, test_loader, class_weight)
133 | add_scalar(writer, user_num, test_result, 0)
134 | total_acc[user_num][0] = test_result[1]
135 | local_acc[user_num][0] = test_result[3]
136 |
137 | for epoch in range(1, args.epochs+1):
138 | net_glob.train()
139 | batch_loss = []
140 | gate_out = []
141 | for batch_idx, (data, target) in enumerate(train_loader):
142 | data, target = data.to(args.device), target.to(args.device)
143 | optimizer.zero_grad()
144 | output, g, z = net_glob(data)
145 | gate_out.append(g)
146 | loss = criterion(z, target)
147 | loss.backward()
148 | optimizer.step()
149 | batch_loss.append(loss.item())
150 | writer.add_histogram(f"user_{user_num}/pfc1/weight", net_glob.pfc1.weight, epoch)
151 | writer.add_histogram(f"user_{user_num}/pfc2/weight", net_glob.pfc2.weight, epoch)
152 | writer.add_histogram(f"user_{user_num}/pfc3/weight", net_glob.pfc3.weight, epoch)
153 | loss_avg = sum(batch_loss) / len(batch_loss)
154 | print(f'User {user_num} Train loss:', loss_avg)
155 | writer.add_scalar(f'user_{user_num}/pfc_train_loss', loss_avg, epoch)
156 |
157 | test_result = user_per_test(args, net_glob, test_loader, class_weight)
158 | add_scalar(writer, user_num, test_result, epoch)
159 | total_acc[user_num][epoch] = test_result[1]
160 | local_acc[user_num][epoch] = test_result[3]
161 |
162 | save_info = {
163 | "total_acc": total_acc,
164 | "local_acc": local_acc
165 | }
166 | save_path = f'{logdir}/local_train_epoch_acc'
167 | torch.save(save_info, save_path)
168 |
169 | total_acc = total_acc.mean(axis=0)
170 | local_acc = local_acc.mean(axis=0)
171 | for epoch, _ in enumerate(total_acc):
172 | writer.add_scalar('test/global/test_acc', total_acc[epoch], epoch)
173 | writer.add_scalar('test/local/test_acc', local_acc[epoch], epoch)
174 | writer.close()
175 |
--------------------------------------------------------------------------------
/models/Fed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import copy
6 | import torch
7 | from torch import nn
8 |
9 |
10 | def FedAvg(w):
11 | """Return the element-wise average of a list of model state_dicts (FedAvg aggregation)."""
12 | w_avg = copy.deepcopy(w[0])
13 | for k in w_avg.keys():
14 | # accumulate this parameter over all clients, then divide by the client count
15 | for i in range(1, len(w)):
16 | w_avg[k] += w[i][k]
17 | w_avg[k] = torch.div(w_avg[k], len(w))
18 | return w_avg
19 |
--------------------------------------------------------------------------------
/models/Nets.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from thop import profile
9 |
10 |
11 | class MLP(nn.Module):
12 | def __init__(self, dim_in, dim_hidden, dim_out):
13 | super(MLP, self).__init__()
14 | self.layer_input = nn.Linear(dim_in, dim_hidden)
15 | self.relu = nn.ReLU()
16 | self.dropout = nn.Dropout()
17 | self.layer_hidden = nn.Linear(dim_hidden, dim_out)
18 |
19 | def forward(self, x):
20 | x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
21 | x = self.layer_input(x)
22 | x = self.dropout(x)
23 | x = self.relu(x)
24 | x = self.layer_hidden(x)
25 | return x
26 |
27 |
28 | class CNNMnist(nn.Module):
29 | def __init__(self, args):
30 | super(CNNMnist, self).__init__()
31 | self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
32 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
33 | self.conv2_drop = nn.Dropout2d()
34 | self.fc1 = nn.Linear(320, 50)
35 | self.fc2 = nn.Linear(50, args.num_classes)
36 |
37 | def forward(self, x):
38 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
39 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
40 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
41 | x = F.relu(self.fc1(x))
42 | x = F.dropout(x, training=self.training)
43 | x = self.fc2(x)
44 | return x
45 |
46 |
47 | class CNNCifar(nn.Module):
48 | def __init__(self, args):
49 | super(CNNCifar, self).__init__()
50 | self.conv1 = nn.Conv2d(1 if args.dataset == 'fmnist' else 3, 6, 5)
51 | self.pool = nn.MaxPool2d(2, 2)
52 | self.conv2 = nn.Conv2d(6, 16, 5)
53 | self.fc1 = nn.Linear(16 * 5 * 5, 120)
54 | self.fc2 = nn.Linear(120, 84)
55 | self.fc3 = nn.Linear(84, args.num_classes)
56 |
57 | def forward(self, x):
58 | x = self.pool(F.relu(self.conv1(x)))
59 | x = self.pool(F.relu(self.conv2(x)))
60 | x = x.view(-1, 16 * 5 * 5)
61 | x = F.relu(self.fc1(x))
62 | x = F.relu(self.fc2(x))
63 | x = self.fc3(x)
64 | return x
65 |
66 |
67 | class CNNGate(nn.Module):
68 | def __init__(self, args):
69 | super(CNNGate, self).__init__()
70 | self.args = args
71 | self.conv1 = nn.Conv2d(1 if args.dataset == 'fmnist' else 3, 6, 5)
72 | self.conv2 = nn.Conv2d(6, 16, 5)
73 | self.fc1 = nn.Linear(16 * 5 * 5, 120)
74 | self.fc2 = nn.Linear(120, 84)
75 | self.fc3 = nn.Linear(84, self.args.num_classes)
76 |
77 | for p in self.parameters():
78 | p.requires_grad = False
79 |
80 | self.gate = nn.Linear(32 * 32 * (3 if args.dataset == 'cifar' else 1), 1) if args.struct else nn.Linear(16 * 5 * 5, 1)
81 | self.pfc1 = nn.Linear(16 * 5 * 5, 120)
82 | self.pfc2 = nn.Linear(120, 84)
83 | self.pfc3 = nn.Linear(84, args.num_classes)
84 |
85 | def forward(self, x1):
86 | x = F.max_pool2d(F.relu(self.conv1(x1)), 2)  # shared conv features (frozen global weights)
87 | x = F.max_pool2d(F.relu(self.conv2(x)), 2)
88 | x = torch.flatten(x, 1)
89 |
90 | z = F.relu(self.pfc1(x))  # personalized expert head (trained locally)
91 | z = F.relu(self.pfc2(z))
92 | z = self.pfc3(z)
93 |
94 | # gate maps the raw image (--struct) or the conv features to a scalar in (0, 1)
95 | g = torch.sigmoid(self.gate(torch.flatten(x1, 1))) if self.args.struct else torch.sigmoid(self.gate(x))
96 | y = F.relu(self.fc1(x))  # global expert head (frozen FedAvg weights)
97 | y = F.relu(self.fc2(y))
98 | y = self.fc3(y)
99 | # mixture-of-experts output: gate-weighted blend of the global and personalized experts
100 | return y * g + z * (1 - g), g, z
101 |
102 |
103 | '''
104 | Modified from https://github.com/pytorch/vision.git
105 | '''
106 | import math
107 |
108 | import torch.nn as nn
109 | import torch.nn.init as init
110 |
111 | __all__ = [
112 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
113 | 'vgg19_bn', 'vgg19',
114 | ]
115 |
116 |
117 | class VGG(nn.Module):
118 | '''
119 | VGG model
120 | '''
121 |
122 | def __init__(self, features, has_gate=False, struct=False):
123 | super(VGG, self).__init__()
124 | self.has_gate = has_gate
125 | self.struct = struct
126 | self.features = features
127 | self.classifier = nn.Sequential(
128 | nn.Dropout(),
129 | nn.Linear(512, 512),
130 | nn.ReLU(True),
131 | nn.Dropout(),
132 | nn.Linear(512, 512),
133 | nn.ReLU(True),
134 | nn.Linear(512, 10),
135 | )
136 | if has_gate:
137 | self.pclassifier = nn.Sequential(
138 | nn.Dropout(),
139 | nn.Linear(512, 512),
140 | nn.ReLU(True),
141 | nn.Dropout(),
142 | nn.Linear(512, 512),
143 | nn.ReLU(True),
144 | nn.Linear(512, 10),
145 | )
146 | if self.struct:
147 | self.gate = nn.Linear(3 * 32 * 32, 1)
148 | else:
149 | self.gate = nn.Linear(512, 1)
150 | # Initialize weights
151 | for m in self.modules():
152 | if isinstance(m, nn.Conv2d):
153 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
154 | m.weight.data.normal_(0, math.sqrt(2. / n))
155 | m.bias.data.zero_()
156 |
157 | def forward(self, input):
158 | x = self.features(input)
159 | x = x.view(x.size(0), -1)
160 | if self.has_gate:
161 | if self.struct:
162 | g = torch.sigmoid(self.gate(torch.flatten(input, 1)))
163 | else:
164 | g = torch.sigmoid(self.gate(x))
165 | y = self.classifier(x)
166 | z = self.pclassifier(x)
167 | return y * g + z * (1-g), g, z
168 | else:
169 | x = self.classifier(x)
170 | return x
171 |
172 |
173 | def make_layers(cfg, batch_norm=False):
174 | layers = []
175 | in_channels = 3
176 | for v in cfg:
177 | if v == 'M':
178 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
179 | else:
180 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
181 | if batch_norm:
182 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
183 | else:
184 | layers += [conv2d, nn.ReLU(inplace=True)]
185 | in_channels = v
186 | return nn.Sequential(*layers)
187 |
188 |
189 | cfg = {
190 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
191 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
192 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
193 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
194 | 512, 512, 512, 512, 'M'],
195 | }
196 |
197 |
198 | def vgg11():
199 | """VGG 11-layer model (configuration "A")"""
200 | return VGG(make_layers(cfg['A']))
201 |
202 |
203 | def vgg11_bn():
204 | """VGG 11-layer model (configuration "A") with batch normalization"""
205 | return VGG(make_layers(cfg['A'], batch_norm=True))
206 |
207 |
208 | def vgg13():
209 | """VGG 13-layer model (configuration "B")"""
210 | return VGG(make_layers(cfg['B']))
211 |
212 |
213 | def vgg13_bn():
214 | """VGG 13-layer model (configuration "B") with batch normalization"""
215 | return VGG(make_layers(cfg['B'], batch_norm=True))
216 |
217 |
218 | def vgg16():
219 | """VGG 16-layer model (configuration "D")"""
220 | return VGG(make_layers(cfg['D']))
221 |
222 |
223 | def gate_vgg16(args):
224 | return VGG(make_layers(cfg['D']), has_gate=True, struct=args.struct)
225 |
226 |
227 | def vgg16_bn():
228 | """VGG 16-layer model (configuration "D") with batch normalization"""
229 | return VGG(make_layers(cfg['D'], batch_norm=True))
230 |
231 |
232 | def vgg19():
233 | """VGG 19-layer model (configuration "E")"""
234 | return VGG(make_layers(cfg['E']))
235 |
236 |
237 | def vgg19_bn():
238 | """VGG 19-layer model (configuration 'E') with batch normalization"""
239 | return VGG(make_layers(cfg['E'], batch_norm=True))
240 |
241 |
242 | def cifar_test():
243 |     from thop import profile  # FLOPs/params counter; pip install thop
244 |     net = vgg16().to("cuda:0")
245 |     img = torch.randn(100, 3, 32, 32).to("cuda:0")
246 |     flops, params = profile(net, inputs=(img, ))
247 |     print(flops, params)
248 |     print(img.size())
249 |
250 | # test()
251 | # cifar_test()
252 |
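253 | # A minimal shape check for CNNGate (a sketch; gate_test is a hypothetical
254 | # helper, not part of the training pipeline): args is a throwaway namespace
255 | # carrying only the fields CNNGate reads, and the gated output mixes the
256 | # frozen global head (y) with the personalized head (z) per sample via g.
257 | def gate_test():
258 |     from argparse import Namespace
259 |     args = Namespace(dataset='cifar', num_classes=10, struct=False)
260 |     net = CNNGate(args)
261 |     img = torch.randn(4, 3, 32, 32)
262 |     out, g, z = net(img)
263 |     print(out.size(), g.size(), z.size())  # expect (4, 10), (4, 1), (4, 10)
264 |
265 | # gate_test()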
253 |
254 |
--------------------------------------------------------------------------------
/models/Test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch.utils.data import DataLoader
9 | import numpy as np
10 |
11 |
12 | def test(args, net_g, data_loader):
13 | # testing
14 | net_g.eval()
15 | test_loss = []
16 | correct = 0
17 | with torch.no_grad():
18 | for idx, (data, target) in enumerate(data_loader):
19 | data, target = data.to(args.device), target.to(args.device)
20 | log_probs = net_g(data)
21 | test_loss.append(nn.CrossEntropyLoss()(log_probs, target).item())
22 | y_pred = log_probs.data.max(1, keepdim=True)[1]
23 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum().item()
24 |
25 | loss_avg = sum(test_loss)/len(test_loss)
26 | test_acc = 100. * correct / len(data_loader.dataset)
27 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
28 | loss_avg, correct, len(data_loader.dataset), test_acc))
29 |
30 | return test_acc, loss_avg
31 |
32 |
33 | def test_img(net_g, datatest, args):
34 | net_g.eval()
35 | # testing
36 | test_loss = 0
37 | correct = 0
38 | data_loader = DataLoader(datatest, batch_size=args.test_bs)
39 |
40 | with torch.no_grad():
41 | for idx, (data, target) in enumerate(data_loader):
42 | if args.gpu != -1:
43 | data, target = data.to(args.device), target.to(args.device)
44 | log_probs = net_g(data)
45 | # sum up batch loss
46 | test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
47 | # get the index of the max log-probability
48 | y_pred = log_probs.data.max(1, keepdim=True)[1]
49 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
50 |
51 | test_loss /= len(data_loader.dataset)
52 | accuracy = 100.00 * correct.item() / len(data_loader.dataset)
53 |
54 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
55 | test_loss, correct, len(data_loader.dataset), accuracy))
56 | return accuracy, test_loss
57 |
58 |
59 | def user_test(args, net_glob, data_loader, class_weight):
60 |     """Gate model: score the mixed output y*g + z*(1-g); also return class-weighted metrics."""
61 | net_glob.eval()
62 | correct_class = np.zeros(10)
63 | class_loss = np.zeros(10)
64 | correct_class_acc = np.zeros(10)
65 | class_loss_avg = np.zeros(10)
66 | correct_class_size = np.zeros(10)
67 | correct = 0.0
68 | dataset_size = len(data_loader.dataset)
69 | total_loss = 0.0
70 | with torch.no_grad():
71 | for idx, (data, target) in enumerate(data_loader):
72 | data, target = data.to(args.device), target.to(args.device)
73 | output, g, z = net_glob(data)
74 |             # hard-gating alternative (unused):
75 |             #   g = (g > 0.5).float(); output = y * g + z * (1 - g)
76 | pred = output.max(1)[1]
77 | correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
78 | loss = nn.CrossEntropyLoss(reduction='none')(output, target)
79 | total_loss += loss.sum().item()
80 | for i in range(10):
81 | class_ind = target.data.view_as(pred).eq(i * torch.ones_like(pred))
82 | correct_class_size[i] += class_ind.cpu().sum().item()
83 | correct_class[i] += (pred.eq(target.data.view_as(pred)) * class_ind).cpu().sum().item()
84 | class_loss[i] += (loss*class_ind.float()).cpu().sum().item()
85 |
86 | acc = 100.0 * (float(correct) / float(dataset_size))
87 | total_l = total_loss / dataset_size
88 | for i in range(10):
89 | correct_class_acc[i] = (float(correct_class[i]) / float(correct_class_size[i]))
90 | class_loss_avg[i] = (float(class_loss[i]) / float(correct_class_size[i]))
91 | user_acc = correct_class_acc * class_weight
92 | user_loss = class_loss_avg * class_weight
93 | return total_l, acc, user_loss.sum(), 100*user_acc.sum()
94 |
95 |
96 | def user_per_test(args, net_glob, data_loader, class_weight):
97 |     """Personalized expert only: score z, ignoring the gate and the global head."""
98 | net_glob.eval()
99 | correct_class = np.zeros(10)
100 | class_loss = np.zeros(10)
101 | correct_class_acc = np.zeros(10)
102 | class_loss_avg = np.zeros(10)
103 | correct_class_size = np.zeros(10)
104 | correct = 0.0
105 | dataset_size = len(data_loader.dataset)
106 | total_loss = 0.0
107 | with torch.no_grad():
108 | for idx, (data, target) in enumerate(data_loader):
109 | data, target = data.to(args.device), target.to(args.device)
110 | output, g, z = net_glob(data)
111 | pred = z.max(1)[1]
112 | correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
113 | loss = nn.CrossEntropyLoss(reduction='none')(z, target)
114 | total_loss += loss.sum().item()
115 | for i in range(10):
116 | class_ind = target.data.view_as(pred).eq(i * torch.ones_like(pred))
117 | correct_class_size[i] += class_ind.cpu().sum().item()
118 | correct_class[i] += (pred.eq(target.data.view_as(pred)) * class_ind).cpu().sum().item()
119 | class_loss[i] += (loss*class_ind.float()).cpu().sum().item()
120 |
121 | acc = 100.0 * (float(correct) / float(dataset_size))
122 | total_l = total_loss / dataset_size
123 | for i in range(10):
124 | correct_class_acc[i] = (float(correct_class[i]) / float(correct_class_size[i]))
125 | class_loss_avg[i] = (float(class_loss[i]) / float(correct_class_size[i]))
126 | user_acc = correct_class_acc * class_weight
127 | user_loss = class_loss_avg * class_weight
128 | return total_l, acc, user_loss.sum(), 100*user_acc.sum()
129 |
130 |
131 | def local_test(args, net_glob, data_loader, class_weight):
132 |     """Plain (non-gate) model: score the network's direct output."""
133 | net_glob.eval()
134 | correct_class = np.zeros(10)
135 | class_loss = np.zeros(10)
136 | correct_class_acc = np.zeros(10)
137 | class_loss_avg = np.zeros(10)
138 | correct_class_size = np.zeros(10)
139 | correct = 0.0
140 | dataset_size = len(data_loader.dataset)
141 | total_loss = 0.0
142 | with torch.no_grad():
143 | for idx, (data, target) in enumerate(data_loader):
144 | data, target = data.to(args.device), target.to(args.device)
145 | output = net_glob(data)
146 | pred = output.max(1)[1]
147 | correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
148 | loss = nn.CrossEntropyLoss(reduction='none')(output, target)
149 | total_loss += loss.sum().item()
150 | for i in range(10):
151 | class_ind = target.data.view_as(pred).eq(i * torch.ones_like(pred))
152 | correct_class_size[i] += class_ind.cpu().sum().item()
153 | correct_class[i] += (pred.eq(target.data.view_as(pred)) * class_ind).cpu().sum().item()
154 | class_loss[i] += (loss*class_ind.float()).cpu().sum().item()
155 |
156 | acc = 100.0 * (float(correct) / float(dataset_size))
157 | total_l = total_loss / dataset_size
158 | for i in range(10):
159 | correct_class_acc[i] = (float(correct_class[i]) / float(correct_class_size[i]))
160 | class_loss_avg[i] = (float(class_loss[i]) / float(correct_class_size[i]))
161 | user_acc = correct_class_acc * class_weight
162 | user_loss = class_loss_avg * class_weight
163 | return total_l, acc, user_loss.sum(), 100*user_acc.sum()
164 |
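165 | # Hypothetical wiring of the evaluators above (a sketch; the real entry
166 | # points are the main_*.py scripts). class_weight is one row of the
167 | # class-weight matrix returned by utils.sampling.cifar_noniid, i.e. the
168 | # fraction of one user's training data per class, so user_loss/user_acc
169 | # re-weight the per-class test metrics by that user's label distribution.
170 | # net_glob must be a gate model returning (output, g, z).
171 | def user_test_demo(args, net_glob, dataset_test, class_weight, user_idx=0):
172 |     loader = DataLoader(dataset_test, batch_size=args.test_bs)
173 |     total_l, acc, user_loss, user_acc = user_test(
174 |         args, net_glob, loader, class_weight[user_idx])
175 |     print('global acc {:.2f}%, user-weighted acc {:.2f}%'.format(acc, user_acc))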
--------------------------------------------------------------------------------
/models/Update.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import torch
6 | from torch import nn, autograd
7 | from torch.utils.data import DataLoader, Dataset
8 | import numpy as np
9 | import random
10 | from sklearn import metrics
11 |
12 |
13 | class DatasetSplit(Dataset):
14 | def __init__(self, dataset, idxs):
15 | self.dataset = dataset
16 | self.targets = dataset.targets
17 | self.idxs = list(idxs)
18 |
19 | def __len__(self):
20 | return len(self.idxs)
21 |
22 | def __getitem__(self, item):
23 | image, label = self.dataset[self.idxs[item]]
24 | return image, label
25 |
26 |
27 | class LocalUpdate(object):
28 | def __init__(self, args, dataset=None, idxs=None):
29 | self.args = args
30 | self.loss_func = nn.CrossEntropyLoss()
31 | self.selected_clients = []
32 | self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
33 |
34 | def train(self, net):
35 | net.train()
36 | # train and update
37 |
38 |         # alternative: SGD with momentum=0.9 and weight_decay=5e-4
39 |         optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
40 |
41 | epoch_loss = []
42 |         for ep in range(self.args.local_ep):
43 |             batch_loss = []
44 |             for batch_idx, (images, labels) in enumerate(self.ldr_train):
45 |                 images, labels = images.to(self.args.device), labels.to(self.args.device)
46 |                 net.zero_grad()
47 |                 log_probs = net(images)
48 |                 loss = self.loss_func(log_probs, labels)
49 |                 loss.backward()
50 |                 optimizer.step()
51 |                 if self.args.verbose and batch_idx % 10 == 0:
52 |                     print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
53 |                         ep, batch_idx * len(images), len(self.ldr_train.dataset),
54 |                         100. * batch_idx / len(self.ldr_train), loss.item()))
55 |                 batch_loss.append(loss.item())
56 | epoch_loss.append(sum(batch_loss)/len(batch_loss))
57 | return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
58 |
59 |
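60 | # Sketch of how one client update is typically driven (local_update_demo is
61 | # a hypothetical helper; the real federated loop lives in main_fed.py and
62 | # dict_users comes from utils.sampling): train a copy of the global weights
63 | # on one user's shard and hand the resulting state_dict back for aggregation.
64 | def local_update_demo(args, net_glob, dataset_train, dict_users, user_idx=0):
65 |     import copy
66 |     local = LocalUpdate(args, dataset=dataset_train, idxs=dict_users[user_idx])
67 |     w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
68 |     print('user {} local loss: {:.4f}'.format(user_idx, loss))
69 |     return w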
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 |
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.1.0
2 | torchvision==0.3.0
3 |
4 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 |
5 |
--------------------------------------------------------------------------------
/utils/options.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import argparse
6 |
7 |
8 | def args_parser():
9 | parser = argparse.ArgumentParser()
10 | # federated arguments
11 | parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
12 | parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
13 | parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
14 | parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
15 | parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
16 | parser.add_argument('--test_bs', type=int, default=128, help="test batch size")
17 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
18 | parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
19 | parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")
20 |
21 | # model arguments
22 | parser.add_argument('--model', type=str, default='mlp', help='model name')
23 | parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
24 | parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
25 | help='comma-separated kernel size to use for convolution')
26 | parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
27 | parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
28 | parser.add_argument('--max_pool', type=str, default='True',
29 | help="Whether use max pooling rather than strided convolutions")
30 |
31 | # other arguments
32 | parser.add_argument('--rebuild', action='store_true', help="rebuild train data")
33 |     parser.add_argument('--struct', action='store_true', help="gate takes the raw input instead of intermediate features")
34 |     parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
35 |     parser.add_argument('--iid', action='store_true', help='whether i.i.d. or not')
36 |     parser.add_argument('--alpha', type=float, default=0.9, help='Dirichlet alpha controlling the degree of non-IID')
37 |     parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
38 |     parser.add_argument('--num_channels', type=int, default=3, help="number of channels of images")
39 | parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
40 | parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
41 | parser.add_argument('--verbose', action='store_true', help='verbose print')
42 |     parser.add_argument('--debug', action='store_true', help='debug mode: do not write TensorBoard events to runs/')
43 | parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
44 | args = parser.parse_args()
45 | return args
46 |
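47 | # The parser only defines --gpu, but the models and tests read args.device;
48 | # the main_*.py entry points attach it after parsing. A minimal sketch of
49 | # that step (attach_device is a hypothetical helper, not in the originals):
50 | def attach_device(args):
51 |     import torch
52 |     args.device = torch.device('cuda:{}'.format(args.gpu)
53 |                                if torch.cuda.is_available() and args.gpu != -1
54 |                                else 'cpu')
55 |     return args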
--------------------------------------------------------------------------------
/utils/sampling.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 |
6 | import numpy as np
7 | from torchvision import datasets, transforms
8 | from collections import defaultdict
9 | import random
10 |
11 |
12 | def mnist_iid(dataset, num_users):
13 | """
14 | Sample I.I.D. client data from MNIST dataset
15 | :param dataset:
16 | :param num_users:
17 | :return: dict of image index
18 | """
19 | num_items = int(len(dataset)/num_users)
20 | dict_users, all_idxs = {}, [i for i in range(len(dataset))]
21 | for i in range(num_users):
22 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
23 | all_idxs = list(set(all_idxs) - dict_users[i])
24 | return dict_users
25 |
26 |
27 | def mnist_noniid(dataset, num_users):
28 | """
29 | Sample non-I.I.D client data from MNIST dataset
30 | :param dataset:
31 | :param num_users:
32 | :return:
33 | """
34 | num_shards, num_imgs = 200, 300
35 | idx_shard = [i for i in range(num_shards)]
36 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
37 | idxs = np.arange(num_shards*num_imgs)
38 |     labels = dataset.train_labels.numpy()  # dataset.targets on newer torchvision
39 |
40 | # sort labels
41 | idxs_labels = np.vstack((idxs, labels))
42 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
43 | idxs = idxs_labels[0,:]
44 |
45 | # divide and assign
46 | for i in range(num_users):
47 | rand_set = set(np.random.choice(idx_shard, 2, replace=False))
48 | idx_shard = list(set(idx_shard) - rand_set)
49 | for rand in rand_set:
50 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
51 | return dict_users
52 |
53 |
54 | def cifar_iid(dataset, num_users):
55 | """
56 | Sample I.I.D. client data from CIFAR10 dataset
57 | :param dataset:
58 | :param num_users:
59 | :return: dict of image index
60 | """
61 | num_items = int(len(dataset)/num_users)
62 | dict_users, all_idxs = {}, [i for i in range(len(dataset))]
63 | for i in range(num_users):
64 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
65 | all_idxs = list(set(all_idxs) - dict_users[i])
66 | return dict_users
67 |
68 |
69 | def cifar_noniid(dataset, no_participants, alpha=0.9):
70 | """
71 | Input: Number of participants and alpha (param for distribution)
72 | Output: A list of indices denoting data in CIFAR training set.
73 | Requires: cifar_classes, a preprocessed class-indice dictionary.
74 | Sample Method: take a uniformly sampled 10-dimension vector as parameters for
75 | dirichlet distribution to sample number of images in each class.
76 | """
77 |     np.random.seed(666)   # fixed seeds keep the partition identical across runs
78 |     random.seed(666)
79 | cifar_classes = {}
80 | for ind, x in enumerate(dataset):
81 | _, label = x
82 | if label in cifar_classes:
83 | cifar_classes[label].append(ind)
84 | else:
85 | cifar_classes[label] = [ind]
86 |
87 | per_participant_list = defaultdict(list)
88 | no_classes = len(cifar_classes.keys())
89 | class_size = len(cifar_classes[0])
90 | datasize = {}
91 | for n in range(no_classes):
92 | random.shuffle(cifar_classes[n])
93 | sampled_probabilities = class_size * np.random.dirichlet(
94 | np.array(no_participants * [alpha]))
95 | for user in range(no_participants):
96 | no_imgs = int(round(sampled_probabilities[user]))
97 | datasize[user, n] = no_imgs
98 | sampled_list = cifar_classes[n][:min(len(cifar_classes[n]), no_imgs)]
99 | per_participant_list[user].extend(sampled_list)
100 | cifar_classes[n] = cifar_classes[n][min(len(cifar_classes[n]), no_imgs):]
101 | train_img_size = np.zeros(no_participants)
102 | for i in range(no_participants):
103 | train_img_size[i] = sum([datasize[i,j] for j in range(10)])
104 |     class_weight = np.zeros((no_participants, 10))
105 |     for i in range(no_participants):
106 |         for j in range(10):
107 |             class_weight[i, j] = float(datasize[i, j]) / float(train_img_size[i])
108 |     return per_participant_list, class_weight
109 |
110 |
111 | if __name__ == '__main__':
112 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
113 | transform=transforms.Compose([
114 | transforms.ToTensor(),
115 | transforms.Normalize((0.1307,), (0.3081,))
116 | ]))
117 | num = 100
118 | d = mnist_noniid(dataset_train, num)
119 |
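120 | # Sanity check for cifar_noniid (a sketch): each row of the returned
121 | # class-weight matrix is one participant's label distribution and sums to
122 | # 1 by construction; smaller alpha gives more skewed (more non-IID) splits.
123 | #
124 | #   per_user, class_weight = cifar_noniid(cifar_train, 20, alpha=0.5)
125 | #   assert np.allclose(class_weight.sum(axis=1), 1.0)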
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import torch.backends.cudnn
2 | import torch.cuda
3 | import numpy as np
4 | import random
5 | import torch.nn as nn
6 |
7 |
8 | def setup_seed(seed):
9 | torch.manual_seed(seed+1)
10 | torch.cuda.manual_seed_all(seed+123)
11 | np.random.seed(seed+1234)
12 | random.seed(seed+12345)
13 | torch.backends.cudnn.deterministic = True
14 |
15 |
16 | def add_scalar(writer, user_num, test_result, epoch):
17 | test_loss, test_acc, user_loss, user_acc = test_result
18 | writer.add_scalar(f'user_{user_num}/global/test_loss', test_loss, epoch)
19 | writer.add_scalar(f'user_{user_num}/global/test_acc', test_acc, epoch)
20 | writer.add_scalar(f'user_{user_num}/local/test_loss', user_loss, epoch)
21 | writer.add_scalar(f'user_{user_num}/local/test_acc', user_acc, epoch)
22 |
23 |
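24 | # Typical usage (a sketch; the main_*.py scripts own the real writer):
25 | # seed everything once at startup, then log one (test_loss, test_acc,
26 | # user_loss, user_acc) tuple per user per round to TensorBoard.
27 | #
28 | #   from torch.utils.tensorboard import SummaryWriter  # or tensorboardX
29 | #   setup_seed(args.seed)
30 | #   writer = SummaryWriter()
31 | #   add_scalar(writer, user_num=0, test_result=(0.9, 67.5, 0.8, 71.2), epoch=0)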
--------------------------------------------------------------------------------