├── README.md
├── requirements.txt
└── src
    ├── FedHQ_main.py
    ├── __pycache__
    │   ├── cifar_model.cpython-37.pyc
    │   ├── models.cpython-37.pyc
    │   ├── models_without_quant.cpython-37.pyc
    │   ├── options.cpython-37.pyc
    │   ├── quantizer.cpython-37.pyc
    │   ├── quantizer2.cpython-37.pyc
    │   ├── sampling.cpython-37.pyc
    │   ├── train.cpython-37.pyc
    │   ├── update.cpython-37.pyc
    │   └── utils.cpython-37.pyc
    ├── cifar_model.py
    ├── models.py
    ├── options.py
    ├── quantizer.py
    ├── sampling.py
    ├── train.py
    ├── update.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
# Dynamic Aggregation for Heterogeneous Quantization in Federated Learning (PyTorch)

Implementation of dynamic aggregation for heterogeneous quantization in federated learning (FedHQ), where clients train with different quantization bit widths (4-bit or 8-bit).
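
Each ```--average_scheme``` option (FedAvg, Proportional, FedHQ) ends a round by combining the client models into one global model. The sketch below shows only that shared weighted-averaging step and rests on several assumptions. Plain Python lists stand in for PyTorch ```state_dict``` tensors, and the function names are hypothetical. Only the FedAvg weights are implemented: each client is weighted by its share of the data, n_k / n, as in McMahan et al. The weights used by the Proportional and FedHQ schemes are computed elsewhere in ```src/``` and are not modeled here.

```python
from typing import Dict, List, Sequence


def weighted_average(client_states: Sequence[Dict[str, List[float]]],
                     weights: Sequence[float]) -> Dict[str, List[float]]:
    """Average client parameters key by key, using weights normalized to sum to 1.

    Each client state maps a parameter name to a flat list of values
    (a stand-in for a PyTorch state_dict).
    """
    if not client_states or len(client_states) != len(weights):
        raise ValueError("need at least one client and one weight per client")
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    norm = [w / total for w in weights]
    averaged = {}
    for name in client_states[0]:
        length = len(client_states[0][name])
        averaged[name] = [
            sum(p * state[name][i] for p, state in zip(norm, client_states))
            for i in range(length)
        ]
    return averaged


def fedavg_weights(num_samples: Sequence[int]) -> List[float]:
    """FedAvg weighting: client k counts in proportion to n_k / n."""
    n = float(sum(num_samples))
    return [k / n for k in num_samples]
```

With equal data per client, as in the IID MNIST split, FedAvg reduces to a plain mean. For example, ```weighted_average([{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}], fedavg_weights([600, 600]))``` gives ```{'w': [2.0, 3.0]}```.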

## References
The experiments build on the following papers. Each entry links to the paper and to a GitHub implementation.
### Papers:
* [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629) : [GitHub](https://github.com/AshwinRJ/Federated-Learning-PyTorch)
* [SWALP: Stochastic Weight Averaging in Low-Precision Training](https://arxiv.org/abs/1904.11943v2) : [GitHub](https://github.com/stevenygd/SWALP)

## Requirements
```requirements.txt``` lists the exact package versions. In short:
* Python 3
* PyTorch
* torchvision

## Data
* The FedHQ experiments are run on MNIST and CIFAR-10.
* You can choose to download the data through the code.
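
For the non-IID MNIST split, ```mnist_noniid``` in ```src/sampling.py``` works in three steps. It sorts the 60,000 training images by label, cuts them into 200 shards of 300 images each, and gives each of the 100 clients 2 random shards. The stdlib-only sketch below applies the same idea to a toy label list. ```shard_noniid``` and its ```seed``` parameter are hypothetical. The sketch also uses ```random.Random``` instead of NumPy's global RNG, so the exact assignment will differ from the repository's.

```python
import random
from typing import Dict, List, Sequence


def shard_noniid(labels: Sequence[int], num_users: int, num_shards: int,
                 shards_per_user: int = 2, seed: int = 1) -> Dict[int, List[int]]:
    """Give each client `shards_per_user` random shards of the label-sorted index list."""
    if num_users * shards_per_user > num_shards:
        raise ValueError("not enough shards for every client")
    if len(labels) % num_shards != 0:
        raise ValueError("len(labels) must be divisible by num_shards")
    shard_size = len(labels) // num_shards
    # Sort sample indices by label so that each shard covers very few classes.
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    rng = random.Random(seed)
    free_shards = list(range(num_shards))
    split: Dict[int, List[int]] = {}
    for user in range(num_users):
        picked = rng.sample(free_shards, shards_per_user)
        for shard in picked:
            free_shards.remove(shard)  # each shard goes to exactly one client
        split[user] = [idx for shard in picked
                       for idx in order[shard * shard_size:(shard + 1) * shard_size]]
    return split
```

With MNIST's 60,000 labels, ```shard_noniid(labels, num_users=100, num_shards=200)``` reproduces the repository's sizes: 300 images per shard and 600 per client, with each client seeing at most a few classes.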

## Options
#### FedHQ Parameters
* ```--epochs:``` Number of communication rounds (T in the paper). Default is 150.
* ```--num_users:``` Number of clients (n in the paper). Default is 100.
* ```--frac:``` Fraction of clients selected for the federated update in each round (C in the paper). Default is 1.
* ```--local_ep:``` Number of local training epochs on each client (K in the paper). Default is 1.
* ```--local_bs:``` Local batch size on each client (B in the paper). Default is 600.
* ```--lr:``` Learning rate (η in the paper). Default is 0.1.
* ```--optimizer:``` Optimizer. Default is sgd.
* ```--momentum:``` Optimizer momentum (M in the paper). Default is 0.5.
* ```--weight_decay:``` Optimizer weight decay (λ in the paper). Default is 0.0005.
* ```--average_scheme:``` Aggregation scheme: FedAvg, Proportional or FedHQ. Default is FedHQ.
* ```--dataset:``` Name of the dataset, e.g. mnist or cifar. Default is mnist.
* ```--gpu:``` Whether to use the CPU or GPU. Default is 1 (GPU).
* ```--iid:``` Distribution of data among clients: 1 for IID, 0 for non-IID. Default is 1.
* ```--bit_4_ratio:``` Fraction of clients that use 4-bit quantization. Default is 0.6.
* ```--bit_8_ratio:``` Fraction of clients that use 8-bit quantization. Default is 0.4.
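
Both ratio options decide how many clients train at 4 bits and how many at 8 bits. To show what ```bits``` means in practice, here is a pure-Python port of the per-element rounding in ```block_quantize``` from ```src/quantizer.py```. Each value keeps its own power-of-two exponent, and its mantissa is rounded to ```bits``` signed bits. The function name, the ```ebits=8``` default and the error on an unknown mode are assumptions for illustration. The repository itself operates on PyTorch tensors.

```python
import math
import random
from typing import List, Optional, Sequence


def quantize_elementwise(values: Sequence[float], bits: int, ebits: int = 8,
                         mode: str = "nearest",
                         rng: Optional[random.Random] = None) -> List[float]:
    """Round each value to a `bits`-bit signed mantissa times its own power of two."""
    rng = rng if rng is not None else random.Random(0)
    lo, hi = -2 ** (bits - 1), 2 ** (bits - 1) - 1        # mantissa range
    e_lo, e_hi = -2 ** (ebits - 1), 2 ** (ebits - 1) - 1  # exponent range
    out = []
    for x in values:
        # Exponent of the leading bit; zero is treated like 1.0 (exponent 0).
        e = math.floor(math.log2(abs(x))) if x != 0 else 0
        e = min(max(e, e_lo), e_hi)
        scale = 2.0 ** (e - (bits - 2))
        i = x / scale
        if mode == "nearest":
            i = round(i)                      # ties to even, like torch.round
        elif mode == "stochastic":
            i = math.floor(i + rng.random())  # rounds up with probability frac(i)
        else:
            raise ValueError("mode must be 'nearest' or 'stochastic'")
        i = min(max(i, lo), hi)
        out.append(i * scale)
    return out
```

With nearest rounding, 0.3 becomes 0.3125 at 4 bits and 0.30078125 at 8 bits. The coarser 4-bit grid is what makes the 4-bit clients' updates noisier.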

In all our experiments, the sum of ```bit_4_ratio``` and ```bit_8_ratio``` is 1.
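
The sketch below shows one hypothetical way to turn the two ratios into per-client bit widths. The repository's own assignment logic may differ. ```assign_bit_widths``` and its ```seed``` parameter are illustrative names, and the choice of which clients get 4 bits is a seeded shuffle.

```python
import random
from typing import Dict


def assign_bit_widths(num_users: int, bit_4_ratio: float, bit_8_ratio: float,
                      seed: int = 1) -> Dict[int, int]:
    """Hypothetical helper: map each client id to 4 or 8 bits according to the ratios."""
    if abs(bit_4_ratio + bit_8_ratio - 1.0) > 1e-9:
        raise ValueError("bit_4_ratio + bit_8_ratio must equal 1")
    num_4bit = round(num_users * bit_4_ratio)
    client_ids = list(range(num_users))
    random.Random(seed).shuffle(client_ids)  # reproducible random choice of 4-bit clients
    return {cid: (4 if rank < num_4bit else 8) for rank, cid in enumerate(client_ids)}
```

With the defaults (100 clients, ```bit_4_ratio=0.6```, ```bit_8_ratio=0.4```), this yields 60 clients at 4 bits and 40 at 8 bits.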

## FedHQ Experiments
Detailed results are reported in Section 6 of the paper. All commands assume the FedHQ folder (the repository root) is the working directory.
#### Results on MNIST:
* To run the FedHQ experiment on MNIST with an IID partition on the GPU:
```
python src/FedHQ_main.py --dataset=mnist --frac=1 --local_bs=600 --average_scheme=FedHQ --bit_4_ratio=0 --bit_8_ratio=1
```
* To run the FedHQ experiment on MNIST with a non-IID partition on the GPU:
```
python src/FedHQ_main.py --dataset=mnist --iid=0 --frac=1 --local_ep=1 --local_bs=600 --average_scheme=FedHQ --bit_4_ratio=0 --bit_8_ratio=1
```
Parameter settings (all other parameters use the values in the commands above or their defaults):
* ```--frac:``` 1

The learning rate is multiplied by 0.9 every ten rounds. The ratio of 4-bit clients is varied over 0, 0.2, 0.4, 0.6, 0.8 and 1.
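
Here is a minimal sketch of that step schedule. It assumes rounds are counted from 0 and uses the default ```--lr``` of 0.1. ```lr_at_round``` is a hypothetical name.

```python
def lr_at_round(round_idx: int, base_lr: float = 0.1, gamma: float = 0.9,
                step: int = 10) -> float:
    """Step decay: base_lr is multiplied by gamma once every `step` rounds (0-based rounds)."""
    if round_idx < 0:
        raise ValueError("round_idx must be non-negative")
    return base_lr * gamma ** (round_idx // step)
```

Rounds 0 to 9 use 0.1, rounds 10 to 19 use 0.09, rounds 20 to 29 use 0.081, and so on. In PyTorch, an equivalent formulation is ```torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)``` stepped once per round. That is not necessarily how the repository applies the decay.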

```Table 1:``` Number of communication rounds needed to reach each target accuracy on MNIST, IID partition (```*``` = target accuracy not reached).

| Quantization bits: ratio | Scheme | 60% | 70% | 80% | 90% | 92% | 94% | 95% |
|---|---|---|---|---|---|---|---|---|
| 4-bit: 0, 8-bit: 1 | FedAvg | 13 | 15 | 25 | 33 | 42 | 46 | 65 |
| | FedHQ+ | 13 | 14 | 22 | 35 | 39 | 47 | 63 |
| 4-bit: 0.2, 8-bit: 0.8 | FedAvg | 12 | 18 | 19 | 32 | 42 | 54 | 82 |
| | Proportional | 15 | 21 | 22 | 32 | 38 | 50 | 73 |
| | FedHQ+ | 13 | 15 | 25 | 33 | 35 | 47 | 61 |
| 4-bit: 0.4, 8-bit: 0.6 | FedAvg | 17 | 22 | 24 | 42 | 45 | 62 | 104 |
| | Proportional | 11 | 17 | 22 | 37 | 41 | 53 | 83 |
| | FedHQ+ | 12 | 17 | 25 | 34 | 38 | 50 | 61 |
| 4-bit: 0.6, 8-bit: 0.4 | FedAvg | 13 | 31 | * | * | * | * | * |
| | Proportional | 11 | 27 | 35 | * | * | * | * |
| | FedHQ+ | 18 | 19 | 20 | 40 | 42 | 46 | 66 |
| 4-bit: 0.8, 8-bit: 0.2 | FedAvg | 21 | * | * | * | * | * | * |
| | Proportional | 16 | 24 | 51 | * | * | * | * |
| | FedHQ+ | 13 | 18 | 23 | 35 | 47 | 53 | 69 |
| 4-bit: 1, 8-bit: 0 | FedAvg | 14 | 32 | * | * | * | * | * |
| | FedHQ+ | 16 | 20 | 32 | 52 | 79 | * | * |
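
The round counts in these tables can be read off a per-round accuracy log. Here is a minimal sketch, with three assumptions. Accuracy is recorded once per communication round, starting at round 1. ```*``` marks a target that is never reached. The helper names and the toy history are made up.

```python
from typing import Optional, Sequence


def rounds_to_target(accuracy_per_round: Sequence[float], target: float) -> Optional[int]:
    """First 1-based round whose accuracy reaches `target`, or None if it never does."""
    for round_no, acc in enumerate(accuracy_per_round, start=1):
        if acc >= target:
            return round_no
    return None


def table_row(accuracy_per_round: Sequence[float], targets: Sequence[float]) -> str:
    """Format one table row, writing '*' for targets that are never reached."""
    cells = []
    for t in targets:
        r = rounds_to_target(accuracy_per_round, t)
        cells.append("*" if r is None else str(r))
    return " | ".join(cells)
```

For the toy history ```[0.55, 0.65, 0.72, 0.81, 0.90]``` and targets ```[0.6, 0.7, 0.8, 0.9, 0.95]```, ```table_row``` returns ```"2 | 3 | 4 | 5 | *"```.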

```Table 2:``` Number of communication rounds needed to reach each target accuracy on MNIST, non-IID partition (```*``` = target accuracy not reached).

| Quantization bits: ratio | Scheme | 60% | 70% | 80% | 90% | 92% | 94% | 95% |
|---|---|---|---|---|---|---|---|---|
| 4-bit: 0, 8-bit: 1 | FedAvg | 12 | 18 | 26 | 39 | 49 | 55 | 75 |
| | FedHQ+ | 11 | 19 | 22 | 32 | 36 | 55 | 74 |
| 4-bit: 0.2, 8-bit: 0.8 | FedAvg | 13 | 17 | 19 | 41 | 43 | 60 | 96 |
| | Proportional | 14 | 18 | 22 | 34 | 44 | 57 | 92 |
| | FedHQ+ | 13 | 20 | 28 | 41 | 43 | 60 | 96 |
| 4-bit: 0.4, 8-bit: 0.6 | FedAvg | 19 | 23 | 25 | 43 | 51 | 76 | 131 |
| | Proportional | 20 | 22 | 23 | 42 | 46 | 68 | 119 |
| | FedHQ+ | 13 | 17 | 25 | 42 | 43 | 55 | 87 |
| 4-bit: 0.6, 8-bit: 0.4 | FedAvg | 21 | 37 | * | * | * | * | * |
| | Proportional | 19 | 42 | * | * | * | * | * |
| | FedHQ+ | 19 | 26 | 31 | 51 | 59 | 87 | 133 |
| 4-bit: 0.8, 8-bit: 0.2 | FedAvg | 22 | 42 | * | * | * | * | * |
| | Proportional | 16 | 41 | 50 | * | * | * | * |
| | FedHQ+ | 21 | 23 | 31 | * | * | * | * |
| 4-bit: 1, 8-bit: 0 | FedAvg | 17 | 42 | * | * | * | * | * |
| | FedHQ+ | 19 | 39 | * | * | * | * | * |

#### Results on CIFAR-10:

* To run the FedHQ experiment on CIFAR-10 with an IID partition on the GPU:
```
python src/FedHQ_main.py --dataset=cifar --epochs=300 --frac=0.1 --local_ep=5 --local_bs=128 --average_scheme=FedHQ --bit_4_ratio=0 --bit_8_ratio=1
```
* To run the FedHQ experiment on CIFAR-10 with a non-IID partition on the GPU:
```
python src/FedHQ_main.py --dataset=cifar --epochs=150 --iid=0 --frac=0.1 --local_ep=5 --local_bs=64 --momentum=0.2 --average_scheme=FedHQ --bit_4_ratio=0 --bit_8_ratio=1
```
Parameter settings (only the parameters that differ from the defaults are listed):
* ```--epochs:``` 300 for IID, 150 for non-IID.
* ```--frac:``` 0.1
* ```--local_ep:``` 5
* ```--local_bs:``` 128 for IID, 64 for non-IID.
* ```--momentum:``` 0.2 for non-IID (IID keeps the default 0.5).

The learning rate is multiplied by 0.9 every ten rounds. The ratio of 4-bit clients is varied over 0, 0.2, 0.4, 0.6, 0.8 and 1. Tables 3 and 4 also include a ratio of 0.3.
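
These settings fix how much local work each round involves. The arithmetic below rests on three assumptions. First, data is split equally, the way ```mnist_iid``` does with ```int(len(dataset)/num_users)```. Second, the local DataLoader keeps the last partial batch. Third, each round samples ```max(int(frac * num_users), 1)``` clients, a common choice in FedAvg code. Both helpers are hypothetical.

```python
import math


def clients_per_round(num_users: int, frac: float) -> int:
    """Clients sampled per round, assuming max(int(frac * num_users), 1)."""
    return max(int(frac * num_users), 1)


def local_steps_per_round(num_train: int, num_users: int, local_bs: int,
                          local_ep: int) -> int:
    """Optimizer steps one client takes per round (equal split, partial last batch kept)."""
    samples_per_client = num_train // num_users
    batches_per_epoch = math.ceil(samples_per_client / local_bs)
    return batches_per_epoch * local_ep
```

CIFAR-10 has 50,000 training images, so each of the 100 clients holds 500. Under these assumptions, a round trains 10 clients. Each client takes 20 local steps with ```--local_bs=128``` (IID) or 40 with ```--local_bs=64``` (non-IID). For the MNIST commands (60,000 images, ```--local_bs=600```, ```--local_ep=1```), each client takes a single full-batch step per round.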

```Table 3:``` Number of communication rounds needed to reach each target accuracy on CIFAR-10, IID partition (```*``` = target accuracy not reached).

| Quantization bits: ratio | Scheme | 60% | 70% | 80% | 82% | 84% | 86% |
|---|---|---|---|---|---|---|---|
| 4-bit: 0, 8-bit: 1 | FedAvg | 13 | 22 | 45 | 53 | 69 | 94 |
| | FedHQ+ | 12 | 22 | 42 | 53 | 68 | 94 |
| 4-bit: 0.2, 8-bit: 0.8 | FedAvg | 58 | 126 | * | * | * | * |
| | Proportional | 31 | 58 | 179 | 285 | * | * |
| | FedHQ+ | 14 | 26 | 56 | 69 | 97 | 133 |
| 4-bit: 0.3, 8-bit: 0.7 | FedAvg | 144 | 276 | * | * | * | * |
| | Proportional | 73 | 109 | * | * | * | * |
| | FedHQ+ | 13 | 23 | 51 | 66 | 92 | 126 |
| 4-bit: 0.4, 8-bit: 0.6 | FedAvg | * | * | * | * | * | * |
| | Proportional | 119 | 227 | * | * | * | * |
| | FedHQ+ | 17 | 25 | 57 | 76 | 98 | 199 |
| 4-bit: 0.6, 8-bit: 0.4 | FedAvg | * | * | * | * | * | * |
| | Proportional | * | * | * | * | * | * |
| | FedHQ+ | 18 | 33 | 100 | 184 | * | * |
| 4-bit: 0.8, 8-bit: 0.2 | FedAvg | * | * | * | * | * | * |
| | Proportional | * | * | * | * | * | * |
| | FedHQ+ | 84 | * | * | * | * | * |
| 4-bit: 1, 8-bit: 0 | FedAvg | * | * | * | * | * | * |
| | FedHQ+ | * | * | * | * | * | * |

```Table 4:``` Number of communication rounds needed to reach each target accuracy on CIFAR-10, non-IID partition (```*``` = target accuracy not reached).

| Quantization bits: ratio | Scheme | 30% | 35% | 40% | 45% | 50% | 55% |
|---|---|---|---|---|---|---|---|
| 4-bit: 0, 8-bit: 1 | FedAvg | 9 | 15 | 15 | 30 | 48 | 73 |
| | FedHQ+ | 9 | 14 | 15 | 27 | 48 | 71 |
| 4-bit: 0.2, 8-bit: 0.8 | FedAvg | 31 | 48 | 93 | * | * | * |
| | Proportional | 18 | 27 | 38 | 93 | * | * |
| | FedHQ+ | 18 | 57 | 60 | 83 | 93 | * |
| 4-bit: 0.3, 8-bit: 0.7 | FedAvg | 48 | 110 | * | * | * | * |
| | Proportional | 14 | 27 | 38 | 93 | * | * |
| | FedHQ+ | 11 | 14 | 18 | 38 | 88 | 110 |
| 4-bit: 0.4, 8-bit: 0.6 | FedAvg | * | * | * | * | * | * |
| | Proportional | 93 | * | * | * | * | * |
| | FedHQ+ | 22 | 49 | 60 | 60 | 93 | * |
| 4-bit: 0.6, 8-bit: 0.4 | FedAvg | * | * | * | * | * | * |
| | Proportional | * | * | * | * | * | * |
| | FedHQ+ | 33 | 40 | 62 | * | * | * |
| 4-bit: 0.8, 8-bit: 0.2 | FedAvg | * | * | * | * | * | * |
| | Proportional | * | * | * | * | * | * |
| | FedHQ+ | 117 | * | * | * | * | * |
| 4-bit: 1, 8-bit: 0 | FedAvg | * | * | * | * | * | * |
| | FedHQ+ | * | * | * | * | * | * |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
python=3.7.3
pytorch=1.2.0
torchvision=0.4.0
numpy=1.16.2
tensorboardx=1.9
matplotlib=3.0.3
tqdm=4.31.1
prettytable=0.7.2
--------------------------------------------------------------------------------
/src/FedHQ_main.py:
--------------------------------------------------------------------------------
import os
from options import args_parser
from train import train
args = args_parser()

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

if __name__ == '__main__':
    train()
--------------------------------------------------------------------------------
/src/__pycache__/cifar_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/cifar_model.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/models_without_quant.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/models_without_quant.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/options.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/quantizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/quantizer.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/quantizer2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/quantizer2.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/sampling.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/sampling.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/update.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/update.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/src/cifar_model.py:
--------------------------------------------------------------------------------
from torch import nn
import torch.nn.functional as F
import math


def make_layers_Cifar10(cfg, quant, batch_norm=False, conv=nn.Conv2d):
    """Build the VGG feature extractor: 'M' is a 2x2 max-pool, other entries are conv widths."""
    layers = list()
    in_channels = 3
    n = 1
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            use_quant = v[-1] != 'N'
            filters = int(v) if use_quant else int(v[:-1])
            conv2d = conv(in_channels, filters, kernel_size=3, padding=1, bias=False)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(filters), nn.ReLU(True)]
            else:
                layers += [conv2d, nn.ReLU()]
            # Quantize the activations after every conv block when a quantizer is given.
            if quant is not None:
                layers += [quant()]
            n += 1
            in_channels = filters
    return nn.Sequential(*layers)


class CNNCifar(nn.Module):
    """VGG-16-style network for CIFAR-10; `quant` optionally quantizes activations."""
    def __init__(self, args, quant):
        super(CNNCifar, self).__init__()
        self.args = args
        self.linear = nn.Linear
        cfg = {
            9: ['64', '64', 'M', '128', '128', 'M', '256', '256', 'M'],
            11: ['64', 'M', '128', 'M', '256', '256', 'M', '512', '512', 'M', '512', '512', 'M'],
            13: ['64', '64', 'M', '128', '128', 'M', '256', '256', 'M', '512', '512', 'M', '512', '512', 'M'],
            16: ['64', '64', 'M', '128', '128', 'M', '256', '256', '256', 'M', '512', '512', '512', 'M', '512', '512', '512', 'M'],
        }
        self.conv = nn.Conv2d
        self.features = make_layers_Cifar10(cfg[16], quant, True, self.conv)
        self.classifier = None
        if quant is not None:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                self.linear(512 * 1 * 1, 4096),
                nn.ReLU(True),
                quant(),
                self.linear(4096, 4096),
                nn.ReLU(True),
                quant(),
                self.linear(4096, args.num_classes),
                nn.ReLU(True),
                quant(),
                nn.LogSoftmax(dim=1)
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                self.linear(512 * 1 * 1, 4096),
                nn.ReLU(True),
                self.linear(4096, 4096),
                nn.ReLU(True),
                self.linear(4096, args.num_classes),
                nn.ReLU(True),
                nn.LogSoftmax(dim=1)
            )
        # He initialization (fan-out) for the convolution weights.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

    def forward(self, x):
        x = self.features(x)
        # Five 2x2 max-pools reduce a 32x32 input to 1x1, leaving 512 features.
        x = x.view(-1, 512 * 1 * 1)
        x = self.classifier(x)
        return x


class ResnetCifar18(nn.Module):
    """Basic residual block; `quant` and `quantx` optionally quantize activations."""
    def __init__(self, quant, quantx, in_channel, out_channel, strides):
        super(ResnetCifar18, self).__init__()
        self.block = None
        self.residual = nn.Sequential()
        self.quantx = quantx
        if quant is None:
            self.block = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(True),
                nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channel)
            )
            if strides != 1 or in_channel != out_channel:
                # 1x1 projection so the shortcut matches the block's output shape.
                self.residual = nn.Sequential(
                    nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, bias=False),
                    nn.BatchNorm2d(out_channel)
                )
        else:
            self.block = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(True),
                quant(),
                nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channel),
                quant()
            )
            self.residual = nn.Sequential()
            if strides != 1 or in_channel != out_channel:
                self.residual = nn.Sequential(
                    nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, bias=False),
                    nn.BatchNorm2d(out_channel),
                    quant()
                )

    def forward(self, x):
        out = self.block(x)
        out += self.residual(x)
        out = F.relu(out)
        if self.quantx is not None:
            out = self.quantx(out)
        return out


class ResNet(nn.Module):
    """ResNet-18 for 32x32 CIFAR-10 inputs (the initial max-pool is disabled)."""
    def __init__(self, args, quant, quantx):
        super(ResNet, self).__init__()
        self.in_channel = 64
        self.quantx = quantx
        self.conv1 = None
        if quant is None:
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
                # nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
                quant(),
                # nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        self.layer1 = self.make_layer(quant, quantx, 64, 2, stride=1)
        self.layer2 = self.make_layer(quant, quantx, 128, 2, stride=2)
        self.layer3 = self.make_layer(quant, quantx, 256, 2, stride=2)
        self.layer4 = self.make_layer(quant, quantx, 512, 2, stride=2)
        self.fc = nn.Linear(512, args.num_classes)

    def make_layer(self, quant, quantx, channel, num_blocks, stride):
        # Only the first block of a stage downsamples; the others use stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(ResnetCifar18(quant, quantx, self.in_channel, channel, stride))
            self.in_channel = channel
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        if self.quantx is not None:
            out = self.quantx(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        if self.quantx is not None:
            out = self.quantx(out)
        out = F.log_softmax(out, dim=1)
        return out
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
from torch import nn
import math


def make_layers_Mnist(cfg, quant, batch_norm=False, conv=nn.Conv2d):
    """Build the conv feature extractor: 'M' is a 2x2 max-pool, other entries are conv widths."""
    layers = list()
    in_channels = 1
    n = 1
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            use_quant = v[-1] != 'N'
            filters = int(v) if use_quant else int(v[:-1])
            conv2d = conv(in_channels, filters, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(filters), nn.ReLU()]
            else:
                layers += [conv2d, nn.ReLU()]
            # Quantize the activations after every conv block when a quantizer is given.
            if quant is not None:
                layers += [quant()]
            n += 1
            in_channels = filters
    return nn.Sequential(*layers)


class CNNMnist(nn.Module):
    """Small CNN for MNIST: conv blocks with 16 and 32 channels, then a two-layer classifier."""
    def __init__(self, args, quant):
        super(CNNMnist, self).__init__()
        self.args = args
        self.linear = nn.Linear
        cfg = {
            16: ['16', 'M', '32', 'M']
        }
        self.conv = nn.Conv2d
        self.features = make_layers_Mnist(cfg[16], quant, True, self.conv)
        self.classifier = None
        if quant is not None:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                self.linear(7 * 7 * 32, 512),
                nn.ReLU(True),
                quant(),
                self.linear(512, 10),
                nn.ReLU(True),
                quant(),
                nn.LogSoftmax(dim=1),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                self.linear(7 * 7 * 32, 512),
                nn.ReLU(True),
                self.linear(512, 10),
                nn.ReLU(True),
                nn.LogSoftmax(dim=1)
            )
        # He initialization (fan-out) for the convolution weights; biases start at zero.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        # Two 2x2 max-pools reduce 28x28 inputs to 7x7 maps with 32 channels.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
--------------------------------------------------------------------------------
/src/options.py:
--------------------------------------------------------------------------------
import argparse


def args_parser():
    parser = argparse.ArgumentParser()

    # federated arguments (notation for the arguments follows the paper)
    parser.add_argument('--epochs', type=int, default=150,
                        help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=100,
                        help="number of clients: n")
    parser.add_argument('--frac', type=float, default=1,
                        help='the fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=1,
                        help="the number of local epochs: K")
    parser.add_argument('--local_bs', type=int, default=600,
                        help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--weight_decay', type=float, default=0.0005,
                        help='weight decay of optimizer (default: 0.0005)')
    parser.add_argument('--average_scheme', type=str, default='FedHQ', help='choose average scheme',
                        choices=['FedAvg', 'Proportional', 'FedHQ'])
    parser.add_argument('--quant_bits', type=int, default=8, help='record the current quantization bit')
    parser.add_argument('--bit_4_ratio', type=float, default=0.6, help='the ratio for 4-bit clients')
    parser.add_argument('--bit_8_ratio', type=float, default=0.4, help='the ratio for 8-bit clients')

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--gpu', default=1, help="To use CPU or GPU. Default set to use GPU.")
    parser.add_argument('--optimizer', type=str, default='sgd', help="type of optimizer")
    parser.add_argument('--iid', type=int, default=1,
                        help='Default set to IID. Set to 0 for non-IID.')
    parser.add_argument('--unequal', type=int, default=0,
                        help='whether to use unequal data splits for '
                             'non-i.i.d setting (use 0 for equal splits)')
    parser.add_argument('--stopping_rounds', type=int, default=10,
                        help='rounds of early stopping')
    parser.add_argument('--verbose', type=int, default=1, help='verbose')
    parser.add_argument('--seed', type=int, default=1, help='random seed')

    parser.add_argument('--dir', type=str, default=None,
                        help='training directory (default: None)')
    parser.add_argument('--data_path', type=str, default="./data", required=False, metavar='PATH',
                        help='path to datasets location (default: "./data")')
    parser.add_argument('--num_workers', type=int, default=0, metavar='N',
                        help='number of workers (default: 0)')
    parser.add_argument('--log_name', type=str, default='', metavar='S',
                        help="Name for the log dir")
    parser.add_argument('--quant_type', type=str, default='stochastic', metavar='S',
                        help='rounding method, stochastic or nearest', choices=['stochastic', 'nearest'])

    args = parser.parse_args()
    return args
--------------------------------------------------------------------------------
/src/quantizer.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch
import torch.nn as nn
from options import args_parser
args = args_parser()


def add_r_(data):
    # In place: add uniform noise in [0, 1), used before floor() for stochastic rounding.
    r = torch.rand_like(data)
    data.add_(r)


def _round(data, sigma, t_min, t_max, mode, clip=True):
    """
    Quantize a Tensor to multiples of `sigma`, optionally clipped to [t_min, t_max].
    """
    temp = data / sigma
    if mode == "nearest":
        temp = temp.round()
    elif mode == "stochastic":
        add_r_(temp)
        temp.floor_()
    else:
        raise ValueError("Invalid quantization mode: {}".format(mode))
    temp *= sigma
    if clip:
        temp.clamp_(t_min, t_max)
    return temp


def block_quantize(data, bits, mode, ebit):
    # Exponent of each element's leading bit (zeros are mapped to 1 to avoid log2(0)).
    max_exponent = torch.floor(torch.log2(torch.abs(torch.where(data == torch.zeros_like(data), torch.ones_like(data), data))))
    # Each mantissa gets `bits` signed bits and the exponent is limited to `ebit` bits.
    # The exponent is computed per element here, not shared across a block.
    max_exponent.clamp_(-2 ** (ebit - 1), 2 ** (ebit - 1) - 1)
    i = data * 2 ** (-max_exponent + (bits - 2))
    if mode == "stochastic":
        add_r_(i)
        i.floor_()
    elif mode == "nearest":
        i.round_()
    i.clamp_(-2 ** (bits - 1), 2 ** (bits - 1) - 1)
    temp = i * 2 ** (max_exponent - (bits - 2))
    return temp


def q_quantize(data, bits, mode, ebit):
    """Quantize like block_quantize and also measure the stochastic-rounding error.

    Returns the quantized tensor, the largest relative expected squared error
    E[(q - x)^2] / x^2 over all non-zero elements, and the largest input value
    that attains it.
    """
    max_exponent = torch.floor(torch.log2(torch.abs(torch.where(data == torch.zeros_like(data), torch.ones_like(data), data))))
    max_exponent.clamp_(-2 ** (ebit - 1), 2 ** (ebit - 1) - 1)
    i = data * 2 ** (-max_exponent + (bits - 2))
    cur_exp = 2 ** (max_exponent - (bits - 2))
    # Probabilities of rounding down / up under stochastic rounding.
    p4left = 1 - i % 1
    p4right = i % 1
    q_n_left = torch.floor(i).clamp_(-2 ** (bits - 1), 2 ** (bits - 1) - 1) * cur_exp
    q_n_right = torch.ceil(i).clamp_(-2 ** (bits - 1), 2 ** (bits - 1) - 1) * cur_exp
    e_q = torch.pow(q_n_left - data, 2) * p4left + torch.pow(q_n_right - data, 2) * p4right
    powdata = torch.pow(data, 2)
    q = e_q / powdata
    end_q = torch.where(data == torch.zeros_like(data), torch.zeros_like(data), q)
    if mode == "stochastic":
        add_r_(i)
        i.floor_()
    elif mode == "nearest":
        i.round_()
    i.clamp_(-2 ** (bits - 1), 2 ** (bits - 1) - 1)
    temp = i * 2 ** (max_exponent - (bits - 2))
    max_q = torch.max(end_q)
    ind = torch.where(torch.abs(end_q - max_q) >= 1e-9, -1000 * torch.ones_like(end_q), data)
    data_q = torch.max(ind)
    return temp, max_q, torch.max(data_q)


class BlockRounding(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, bits, ebits, mode):
        ctx.ebits = ebits
        ctx.bits = bits
        ctx.mode = mode
        if bits == -1:
            return x
        return block_quantize(x, bits, ctx.mode, ebits)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Quantize the gradient with the same settings as the forward pass.
            if ctx.bits != -1:
                grad_input = block_quantize(grad_output, ctx.bits, ctx.mode, ctx.ebits)
            else:
                grad_input = grad_output
        # One gradient per forward input: x, bits, ebits, mode.
        return grad_input, None, None, None


quantize_block = BlockRounding.apply


class BlockQuantizer(nn.Module):
    def __init__(self, bits, ebits, mode):
        super(BlockQuantizer, self).__init__()
        self.bits = bits
        self.ebits = ebits
        self.mode = mode

    def forward(self, x):
        return quantize_block(x, self.bits, self.ebits, self.mode)
--------------------------------------------------------------------------------
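Unlike `block_quantize`, `q_quantize` also measures how much precision a client loses. For every element it computes the expected squared error of stochastic rounding relative to the squared value. `update_weights` then keeps the largest such value over all parameter tensors after the last local step and reports it as the client's `q`. The scalar sketch below re-derives that per-element quantity in pure Python so it can be checked by hand. The helper name `stochastic_rounding_error` is hypothetical and not part of the repository. It assumes the same grid as `q_quantize`: a per-element exponent clamped to `ebits` bits and `bits`-bit signed integer codes.

```python
import math


def stochastic_rounding_error(x: float, bits: int, ebits: int) -> float:
    """Expected (Q(x) - x)^2 / x^2 under stochastic rounding for one element.

    Hypothetical scalar sketch of the element-wise quantity in q_quantize.
    """
    if x == 0.0:
        return 0.0  # q_quantize masks zero entries to an error of 0
    # Exponent of x, clamped to the range representable with `ebits` bits
    exp = math.floor(math.log2(abs(x)))
    exp = max(-2 ** (ebits - 1), min(2 ** (ebits - 1) - 1, exp))
    # Grid spacing for `bits`-bit signed codes at this exponent
    step = 2.0 ** (exp - (bits - 2))
    i = x / step  # x measured in grid units
    lo_code, hi_code = -2 ** (bits - 1), 2 ** (bits - 1) - 1
    left = max(lo_code, min(hi_code, math.floor(i))) * step
    right = max(lo_code, min(hi_code, math.ceil(i))) * step
    p_up = i - math.floor(i)  # probability of rounding up (i % 1 in q_quantize)
    expected = (left - x) ** 2 * (1.0 - p_up) + (right - x) ** 2 * p_up
    return expected / x ** 2


for b in (4, 8):
    print(b, stochastic_rounding_error(0.3, bits=b, ebits=b))
```

Values that sit exactly on the grid, such as 1.0, have zero error. Wider bit widths shrink the grid spacing and hence the error, which is why low-precision clients report larger `q`.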
/src/sampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def mnist_iid(dataset, num_users):
4 | """
5 | Sample I.I.D. client data from MNIST dataset
6 | :param: dataset
7 | :param: num_users
8 | :return: dict of image index
9 | """
10 | num_items = int(len(dataset)/num_users)
11 | dict_users, all_idxs = {}, [i for i in range(len(dataset))]
12 | for i in range(num_users):
13 | dict_users[i] = set(np.random.choice(all_idxs, num_items,
14 | replace=False))
15 | all_idxs = list(set(all_idxs) - dict_users[i])
16 | return dict_users
17 |
18 |
19 | def mnist_noniid(dataset, num_users):
20 | """
21 | Sample non-I.I.D client data from MNIST dataset
22 | :param: dataset
23 | :param: num_users
24 | :return: dict of image index
25 | """
26 | # 60,000 training imgs --> 300 imgs/shard X 200 shards
27 | num_shards, num_imgs = 200, 300
28 | idx_shard = [i for i in range(num_shards)]
29 | dict_users = {i: np.array([]) for i in range(num_users)}
30 | idxs = np.arange(num_shards*num_imgs)
31 | labels = dataset.targets.numpy()  # `train_labels` is a deprecated alias of `targets`
32 |
33 | # sort labels
34 | idxs_labels = np.vstack((idxs, labels))
35 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
36 | idxs = idxs_labels[0, :]
37 |
38 | # divide and assign 2 shards/client
39 | for i in range(num_users):
40 | rand_set = set(np.random.choice(idx_shard, 2, replace=False))
41 | idx_shard = list(set(idx_shard) - rand_set)
42 | for rand in rand_set:
43 | dict_users[i] = np.concatenate(
44 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
45 | return dict_users
46 |
47 |
48 | def mnist_noniid_unequal(dataset, num_users):
49 | """
50 | Sample non-I.I.D client data from MNIST dataset such that clients
51 | have unequal amounts of data
52 | :param: dataset
53 | :param: num_users
54 | :returns: a dict mapping each client to the indexes of its
55 | assigned training images
56 | """
57 | # 60,000 training imgs --> 50 imgs/shard X 1200 shards
58 | num_shards, num_imgs = 1200, 50
59 | idx_shard = [i for i in range(num_shards)]
60 | dict_users = {i: np.array([]) for i in range(num_users)}
61 | idxs = np.arange(num_shards*num_imgs)
62 | labels = dataset.targets.numpy()  # `train_labels` is a deprecated alias of `targets`
63 |
64 | # sort labels
65 | idxs_labels = np.vstack((idxs, labels))
66 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
67 | idxs = idxs_labels[0, :]
68 |
69 | # Minimum and maximum shards assigned per client:
70 | min_shard = 1
71 | max_shard = 30
72 |
73 | # Divide the shards into random chunks for every client
74 | # s.t the sum of these chunks = num_shards
75 | random_shard_size = np.random.randint(min_shard, max_shard+1,
76 | size=num_users)
77 | random_shard_size = np.around(random_shard_size /
78 | sum(random_shard_size) * num_shards)
79 | random_shard_size = random_shard_size.astype(int)
80 |
81 | # Assign the shards randomly to each client
82 | if sum(random_shard_size) > num_shards:
83 |
84 | for i in range(num_users):
85 | # First assign each client 1 shard to ensure every client has
86 | # at least one shard of data
87 | rand_set = set(np.random.choice(idx_shard, 1, replace=False))
88 | idx_shard = list(set(idx_shard) - rand_set)
89 | for rand in rand_set:
90 | dict_users[i] = np.concatenate(
91 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
92 | axis=0)
93 |
94 | random_shard_size = random_shard_size-1
95 |
96 | # Next, randomly assign the remaining shards
97 | for i in range(num_users):
98 | if len(idx_shard) == 0:
99 | continue
100 | shard_size = random_shard_size[i]
101 | if shard_size > len(idx_shard):
102 | shard_size = len(idx_shard)
103 | rand_set = set(np.random.choice(idx_shard, shard_size,
104 | replace=False))
105 | idx_shard = list(set(idx_shard) - rand_set)
106 | for rand in rand_set:
107 | dict_users[i] = np.concatenate(
108 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
109 | axis=0)
110 | else:
111 |
112 | for i in range(num_users):
113 | shard_size = random_shard_size[i]
114 | rand_set = set(np.random.choice(idx_shard, shard_size,
115 | replace=False))
116 | idx_shard = list(set(idx_shard) - rand_set)
117 | for rand in rand_set:
118 | dict_users[i] = np.concatenate(
119 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
120 | axis=0)
121 |
122 | if len(idx_shard) > 0:
123 | # Add the leftover shards to the client with minimum images:
124 | shard_size = len(idx_shard)
125 | # Add the remaining shard to the client with lowest data
126 | k = min(dict_users, key=lambda x: len(dict_users.get(x)))
127 | rand_set = set(np.random.choice(idx_shard, shard_size,
128 | replace=False))
129 | idx_shard = list(set(idx_shard) - rand_set)
130 | for rand in rand_set:
131 | dict_users[k] = np.concatenate(
132 | (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]),
133 | axis=0)
134 |
135 | return dict_users
136 |
137 |
138 | def cifar_iid(dataset, num_users):
139 | """
140 | Sample I.I.D. client data from CIFAR10 dataset
141 | :param: dataset
142 | :param: num_users
143 | :return: dict of image index
144 | """
145 | num_items = int(len(dataset)/num_users)
146 | dict_users, all_idxs = {}, [i for i in range(len(dataset))]
147 | for i in range(num_users):
148 | dict_users[i] = set(np.random.choice(all_idxs, num_items,replace=False))
149 | all_idxs = list(set(all_idxs) - dict_users[i])
150 | return dict_users
151 |
152 |
153 | def cifar_noniid(dataset, num_users):
154 | """
155 | Sample non-I.I.D client data from CIFAR10 dataset
156 | :param: dataset
157 | :param: num_users
158 | :return: dict of image index
159 | """
160 | num_shards, num_imgs = 200, 250
161 | idx_shard = [i for i in range(num_shards)]
162 | dict_users = {i: np.array([]) for i in range(num_users)}
163 | idxs = np.arange(num_shards*num_imgs)
164 | labels = np.array(dataset.targets)
165 |
166 | # sort labels
167 | idxs_labels = np.vstack((idxs, labels))
168 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
169 | idxs = idxs_labels[0, :]
170 | # divide and assign
171 | for i in range(num_users):
172 | rand_set = set(np.random.choice(idx_shard, 2, replace=False))
173 | idx_shard = list(set(idx_shard) - rand_set)
174 | for rand in rand_set:
175 | dict_users[i] = np.concatenate(
176 | (dict_users[i], idxs[int(rand)*num_imgs:(int(rand)+1)*num_imgs]), axis=0)
177 | return dict_users
--------------------------------------------------------------------------------
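The equal-size non-IID samplers above share one scheme. They sort sample indexes by label, cut the sorted list into equal shards, and give every client two shards drawn without replacement. MNIST is split into 200 shards of 300 images and CIFAR-10 into 200 shards of 250 images, so each client ends up with only a few classes. The pure-Python sketch below reproduces that scheme on a toy label list. The helper `shard_partition` is hypothetical. It uses the standard `random` module instead of `np.random.choice`, and it adds divisibility and shard-count checks that the repository functions do not perform.

```python
import random


def shard_partition(labels, num_users, num_shards, shards_per_user=2, seed=0):
    """Label-sorted shard split in the style of mnist_noniid / cifar_noniid.

    Hypothetical helper: returns {user: [sample indexes]}.
    """
    if len(labels) % num_shards != 0:
        raise ValueError("len(labels) must be divisible by num_shards")
    if num_users * shards_per_user > num_shards:
        raise ValueError("not enough shards for every user")
    shard_size = len(labels) // num_shards
    rng = random.Random(seed)
    # Sort indexes by label so each shard covers as few classes as possible
    idxs = sorted(range(len(labels)), key=lambda i: labels[i])
    free_shards = list(range(num_shards))
    rng.shuffle(free_shards)
    dict_users = {}
    for user in range(num_users):
        # Draw shards without replacement, like np.random.choice(..., replace=False)
        chosen = [free_shards.pop() for _ in range(shards_per_user)]
        dict_users[user] = [
            idx
            for s in chosen
            for idx in idxs[s * shard_size:(s + 1) * shard_size]
        ]
    return dict_users


# Toy dataset: 1,000 samples, 10 balanced classes, 20 shards of 50 samples
toy_labels = [i % 10 for i in range(1000)]
parts = shard_partition(toy_labels, num_users=10, num_shards=20)
print({u: sorted({toy_labels[i] for i in idxs}) for u, idxs in parts.items()})
```

With balanced classes and shards no larger than a class, every shard holds a single label, so each client sees at most two classes.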
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | import numpy as np
5 | from tqdm import tqdm
6 | import quantizer as qn
7 |
8 | import torch
9 | from tensorboardX import SummaryWriter
10 |
11 | from update import LocalUpdate, Evaluation
12 | from models import CNNMnist
13 | from cifar_model import CNNCifar, ResNet
14 | from utils import get_dataset, average_weights, exp_details, get_quantization_bit
15 | from options import args_parser
16 | import prettytable as pt
17 | def make_dir(filename):
18 | if not os.path.exists(filename):
19 | os.makedirs(filename)
20 | def train():
21 | start_time = time.time()
22 |
23 | # define paths
24 | path_project = os.path.abspath('..')
25 | logger = SummaryWriter('../logs')
26 | args = args_parser()
27 | exp_details(args)
28 | device = 'cuda'
29 |
30 | # load dataset and user groups
31 | train_dataset, test_dataset, user_groups = get_dataset(args)
32 |
33 | # Build model for different dataset
34 | if args.dataset == 'mnist':
35 | quant = lambda: qn.BlockQuantizer(args.quant_bits, args.quant_bits, args.quant_type)
36 | global_model = CNNMnist(args=args, quant=quant)
37 | elif args.dataset == 'cifar':
38 | quant = lambda: qn.BlockQuantizer(args.quant_bits, args.quant_bits, args.quant_type)
39 | quantx = lambda x: qn.quantize_block(x, args.quant_bits, args.quant_bits, args.quant_type)
40 | if args.iid==1:
41 | global_model = ResNet(args=args, quant=quant, quantx=quantx)
42 | else:
43 | global_model = CNNCifar(args=args, quant=quant)
44 | # hold global weight get from server
45 | global_model.to(device)
46 | global_model.train()
47 | global_weights = global_model.state_dict()
48 | # Record training loss and accuracy
49 | train_loss, train_accuracy = [], []
50 | print_every = 1
51 | acc_level = np.array(list(range(41)))*0.01+0.6
52 | acc_true=[]
53 | acc_table_line=[]
54 | acc_flag=np.zeros_like(acc_level)
55 | quant_bit_for_user, avg_for_user = get_quantization_bit(args)
56 | last_max_acc=0
57 | # Set the filename for saving results
58 | result_base_filename = 'result/' + args.dataset + '/iid/' + args.average_scheme
59 | if args.iid==0:
60 | result_base_filename = 'result/' + args.dataset + '/noniid/' + args.average_scheme
61 | save_acc_filename=result_base_filename + '/c'+str(int(args.frac*10))+'result_04-'+str(int(args.bit_4_ratio*10))+'_8-'+str(int(args.bit_8_ratio*10))+'.txt'
62 | save_pkl_filename = result_base_filename +'/c'+str(int(args.frac*10))+'result_04-' + str(int(args.bit_4_ratio * 10)) + '_8-' + str(int(args.bit_8_ratio * 10)) + '.pkl'
63 |
64 | make_dir(result_base_filename)
65 |
66 | for epoch in tqdm(range(args.epochs)):
67 | local_weights, local_losses = [], []
68 | print(f'\n | Global Training Round : {epoch+1} |\n')
69 | m = max(int(args.frac * args.num_users), 1)
70 | idxs_users = np.random.choice(range(args.num_users), m, replace=False)
71 | cur_q=[]
72 | for idx in idxs_users:
73 | quantx=None
74 | quant = None
75 | idx = int(idx)
76 | # Set the quantization condition and local model for each client
77 | if quant_bit_for_user[idx]!=0:
78 | args.quant_bits = quant_bit_for_user[idx]
79 | quant = lambda: qn.BlockQuantizer(args.quant_bits, args.quant_bits, args.quant_type)
80 | quantx = lambda x: qn.quantize_block(x, args.quant_bits, args.quant_bits, args.quant_type)
81 | if args.dataset == 'mnist':
82 | user_model = CNNMnist(args=args, quant=quant)
83 | elif args.dataset == 'cifar' and args.iid==1:
84 | user_model = ResNet(args=args, quant=quant, quantx=quantx)
85 | elif args.dataset == 'cifar' and args.iid==0:
86 | user_model = CNNCifar(args=args, quant=quant)
87 | user_model.to(device)
88 | user_model.train()
89 | # Initialize the local model with the global weights (quantized for low-precision clients)
90 | weight_name = []
91 | for i in global_weights:
92 | weight_name.append(i)
93 | cnt = 0
94 | user_weights = user_model.state_dict()
95 | for i in user_weights:
96 | if quantx!=None:
97 | user_weights[i] = quantx(global_weights[weight_name[cnt]].to(float))
98 | else:
99 | user_weights[i] = global_weights[weight_name[cnt]]
100 | cnt += 1
101 | user_model.load_state_dict(user_weights)
102 | # Train the local model
103 | local_model = LocalUpdate(args=args, dataset=train_dataset,
104 | idxs=user_groups[idx], logger=logger,quant=quantx, quantbit=args.quant_bits, mode=args.quant_type)
105 | w, loss, q = local_model.update_weights(model=copy.deepcopy(user_model))
106 | cur_q.append(q)
107 | local_weights.append(copy.deepcopy(w))
108 | local_losses.append(copy.deepcopy(loss))
109 | torch.cuda.empty_cache()
110 |
111 | # update global weights
112 | global_weights = average_weights(args, local_weights,avg_for_user[idxs_users],q_for_user=np.array(cur_q))
113 | loss_avg = sum(local_losses) / len(local_losses)
114 | train_loss.append(loss_avg)
115 |
116 | # Calculate avg training accuracy over all clients at every epoch
117 | list_acc, list_loss = [], []
118 | weight_name = []
119 | for i in global_weights:
120 | weight_name.append(i)
121 | for c in range(args.num_users):
122 | quant = None
123 | quantx = None
124 | # Set the quantization condition and local model for each client
125 | if quant_bit_for_user[c] != 0:
126 | args.quant_bits = quant_bit_for_user[c]
127 | quant = lambda: qn.BlockQuantizer(args.quant_bits, args.quant_bits, args.quant_type)
128 | quantx = lambda x: qn.quantize_block(x, args.quant_bits, args.quant_bits, args.quant_type)
129 | if args.dataset == 'mnist':
130 | user_model = CNNMnist(args=args, quant=quant)
131 | elif args.dataset == 'cifar' and args.iid == 1:
132 | user_model = ResNet(args=args, quant=quant, quantx=quantx)
133 | elif args.dataset == 'cifar' and args.iid == 0:
134 | user_model = CNNCifar(args=args, quant=quant)
135 | user_model.to(device)
136 | user_model.eval()
137 | cnt = 0
138 | user_weights = user_model.state_dict()
139 | for i in user_weights:
140 | user_weights[i] = global_weights[weight_name[cnt]]
141 | cnt += 1
142 | user_model.load_state_dict(user_weights)
143 | local_model = LocalUpdate(args=args, dataset=train_dataset,
144 | idxs=user_groups[c], logger=logger, quant=quantx, quantbit=quant_bit_for_user[c],mode=args.quant_type)
145 | acc, loss = local_model.inference(model=user_model)
146 | list_acc.append(acc)
147 | list_loss.append(loss)
148 | torch.cuda.empty_cache()
149 | train_accuracy.append(sum(list_acc)/len(list_acc))
150 |
151 | if args.dataset == 'mnist':
152 | global_model = CNNMnist(args=args, quant=None)
153 | elif args.dataset == 'cifar' and args.iid == 1:
154 | global_model = ResNet(args=args, quant=quant, quantx=quantx)
155 | elif args.dataset == 'cifar' and args.iid == 0:
156 | global_model = CNNCifar(args=args, quant=quant)
157 | cnt = 0
158 | current_weight = global_model.state_dict()
159 | for i in current_weight:
160 | current_weight[i] = global_weights[weight_name[cnt]]
161 | cnt += 1
162 | global_model.load_state_dict(current_weight)
163 | global_model.to(device)
164 | global_model.eval()
165 | test_acc, test_loss = Evaluation(args, global_model, test_dataset)
166 | # Save the global model whenever it reaches a new best test accuracy
167 | if test_acc>last_max_acc:
168 | last_max_acc=test_acc
169 | torch.save(global_model.state_dict(), save_pkl_filename)
170 | with open(save_acc_filename,'a') as file:
171 | write_content=str(epoch+1)+' '+str(np.mean(np.array(train_loss)))+' '+str(train_accuracy[-1])+' '+str(test_acc)+"\n"
172 | file.write(write_content)
173 | for acc_index in range(len(acc_level)):
174 | if acc_flag[acc_index]==0 and test_acc>=acc_level[acc_index] and test_acc not in acc_true and acc_index+1 < len(acc_level) and test_acc>=acc_level[acc_index+1]:
175 | acc_flag[acc_index]=1
176 | if acc_flag[acc_index]==0 and test_acc>=acc_level[acc_index] and test_acc not in acc_true:
177 | acc_true.append(test_acc)
178 | acc_flag[acc_index]=1
179 | acc_table_line.append(epoch+1)
180 | if (epoch+1) % print_every == 0:
181 | print(f' \nAvg Training Stats after {epoch+1} global rounds:')
182 | print(f'Training Loss : {np.mean(np.array(train_loss))}')
183 | print('Train Accuracy: {:.2f}%'.format(100*train_accuracy[-1]))
184 | print("Test Accuracy: {:.2f}%\n".format(100 * test_acc))
185 | print(f"Test Loss: {np.mean(np.array(test_loss))}")
186 | print('4: ',args.bit_4_ratio,', 8: ',args.bit_8_ratio,', ',args.average_scheme)
187 | if len(acc_table_line) == 0:
188 | print("cannot get the communication round for the target accuracy")
189 | else:
190 | table = pt.PrettyTable()
191 | table.field_names = acc_true # acc_level
192 | table.add_row(acc_table_line)
193 | print(table)
194 | if (epoch+1)%10==0 and args.lr>=1e-4:
195 | args.lr=args.lr*0.9
196 | print('learning rate : ',args.lr)
197 | # Test inference after completion of training
198 | test_acc, test_loss = Evaluation(args, global_model, test_dataset)
199 | print(f' \n Results after {args.epochs} global rounds of training:')
200 | print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
201 | print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))
202 | print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))
203 | if len(acc_table_line) == 0:
204 | print("cannot get the communication round for the target accuracy")
205 | else:
206 | table = pt.PrettyTable()
207 | table.field_names = acc_true
208 | table.add_row(acc_table_line)
209 | print(table)
210 | return train_accuracy[-1], test_acc
211 |
--------------------------------------------------------------------------------
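`train()` decays the learning rate by a factor of 0.9 after every tenth global round, but only while the rate is still at least 1e-4. The rate therefore follows lr0 · 0.9^⌊T/10⌋ until it first drops below 1e-4, and it stays fixed after that. The sketch below replays that rule so the final rate for a given number of rounds can be checked. The function name `decayed_lr` is hypothetical. The starting rate of 0.01 is an assumed example value, because the default `--lr` is not shown here. The 150 rounds match the README default for `--epochs`.

```python
def decayed_lr(initial_lr: float, rounds: int, decay: float = 0.9,
               every: int = 10, min_lr: float = 1e-4) -> float:
    """Learning rate after `rounds` global rounds, replaying the rule in train():
    after every `every`-th round, lr *= decay while lr >= min_lr."""
    lr = initial_lr
    for epoch in range(rounds):
        if (epoch + 1) % every == 0 and lr >= min_lr:
            lr *= decay
    return lr


# Assumed starting rate of 0.01 over the default 150 rounds (15 decay steps)
print(decayed_lr(0.01, 150))
```

Because the check happens before the multiplication, the last decay step can take the rate slightly below 1e-4, after which no further decay occurs.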
/src/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.utils.data import DataLoader, Dataset
4 | import quantizer as qn
5 | from options import args_parser
6 | args = args_parser()
7 | class DatasetSplit(Dataset):
8 | """A Dataset wrapper that exposes only the samples of `dataset` listed in `idxs`.
9 | """
10 | def __init__(self, dataset, idxs):
11 | self.dataset = dataset
12 | self.idxs = [int(i) for i in idxs]
13 |
14 | def __len__(self):
15 | return len(self.idxs)
16 |
17 | def __getitem__(self, item):
18 | image, label = self.dataset[self.idxs[item]]
19 | return torch.as_tensor(image), torch.as_tensor(label)  # avoids re-copying tensors
20 | class LocalUpdate(object):
21 | def __init__(self, args, dataset, idxs, logger, quant, quantbit, mode):
22 | self.args = args
23 | self.logger = logger
24 | self.trainloader = self.train_val_test(
25 | dataset, list(idxs))
26 | self.device = 'cuda'
27 | # Default criterion set to NLL loss function
28 | self.criterion = nn.NLLLoss().to(self.device)
29 | self.quantbit=quantbit
30 | self.quant=quant
31 | self.mode=mode
32 |
33 | def train_val_test(self, dataset, idxs):
34 | """
35 | Returns the training DataLoader over the samples of `dataset`
36 | selected by the user indexes `idxs`.
37 | """
38 | trainloader = DataLoader(DatasetSplit(dataset, idxs),
39 | batch_size=self.args.local_bs, shuffle=False)
40 | return trainloader
41 |
42 | def update_weights(self, model):
43 | # Set mode to train model
44 | model.train()
45 | epoch_loss, ansq = [], [0.0]  # full-precision clients skip quantization and report q = 0
46 | # Set optimizer for the local updates
47 | if self.args.optimizer == 'sgd':
48 | optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
49 | momentum=self.args.momentum, weight_decay=self.args.weight_decay)
50 | elif self.args.optimizer == 'adam':
51 | optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
52 | for iter in range(self.args.local_ep):
53 | batch_loss = []
54 | for batch_idx, (images, labels) in enumerate(self.trainloader):
55 | images, labels = images.to(self.device), labels.to(self.device)
56 |
57 | optimizer.zero_grad()
58 | log_probs = model(images)
59 | loss = self.criterion(log_probs, labels)
60 | loss.backward()
61 | optimizer.step()
62 | # Weight quantization
63 | if self.quant != None:
64 | ansq=[]
65 | for name, p in model.named_parameters():
66 | quant_p, q, data_q = qn.q_quantize(p.data,self.quantbit,self.mode,self.quantbit)
67 | ansq.append(q.item())
68 | p.data=quant_p.data
69 | self.logger.add_scalar('loss', loss.item())
70 | batch_loss.append(loss.item())
71 | epoch_loss.append(sum(batch_loss)/len(batch_loss))
72 | return model.state_dict(), sum(epoch_loss) / len(epoch_loss), max(ansq)
73 |
74 | def inference(self, model):
75 | model.eval()
76 | loss, total, correct = 0.0, 0.0, 0.0
77 |
78 | for batch_idx, (images, labels) in enumerate(self.trainloader):
79 | images, labels = images.to(self.device), labels.to(self.device)
80 |
81 | # Inference
82 | outputs = model(images)
83 | batch_loss = self.criterion(outputs, labels)
84 | loss += batch_loss.item()
85 |
86 | # Prediction
87 | _, pred_labels = torch.max(outputs, 1)
88 | pred_labels = pred_labels.view(-1)
89 | correct += torch.sum(torch.eq(pred_labels, labels)).item()
90 | total += len(labels)
91 |
92 | accuracy = correct/total
93 | return accuracy, loss/total
94 |
95 |
96 | def Evaluation(args, model, test_dataset):
97 | """ Returns the test accuracy and loss.
98 | """
99 |
100 | model.eval()
101 | loss, total, correct = 0.0, 0.0, 0.0
102 |
103 | device = 'cuda' if args.gpu else 'cpu'
104 | criterion = nn.NLLLoss().to(device)
105 | testloader = DataLoader(test_dataset, batch_size=128,
106 | shuffle=False)
107 | cnt_loss=0
108 | for batch_idx, (images, labels) in enumerate(testloader):
109 | images, labels = images.to(device), labels.to(device)
110 |
111 | # Inference
112 | outputs = model(images)
113 | batch_loss = criterion(outputs, labels)
114 | loss += batch_loss.item()
115 |
116 | # Prediction
117 | _, pred_labels = torch.max(outputs, 1)
118 | pred_labels = pred_labels.view(-1)
119 | correct += torch.sum(torch.eq(pred_labels, labels)).item()
120 | total += len(labels)
121 | cnt_loss+=1
122 |
123 | accuracy = correct/total
124 | return accuracy, loss/cnt_loss
125 |
--------------------------------------------------------------------------------
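Both `LocalUpdate.inference` and `Evaluation` accumulate `batch_loss.item()`. With the default `reduction='mean'` of `nn.NLLLoss`, that value is already a per-batch mean. The two functions then normalise differently. `Evaluation` divides by the number of batches (`cnt_loss`), while `inference` divides by the number of samples (`total`). The training loss from `inference` is therefore roughly the per-sample loss divided by the batch size. The pure-Python sketch below contrasts both with a size-weighted per-sample mean. All three function names are hypothetical.

```python
def evaluation_style_loss(batch_losses):
    """Evaluation(): mean of per-batch mean losses (loss / cnt_loss)."""
    return sum(batch_losses) / len(batch_losses)


def inference_style_loss(batch_losses, batch_sizes):
    """LocalUpdate.inference(): summed per-batch means over sample count (loss / total)."""
    return sum(batch_losses) / sum(batch_sizes)


def per_sample_loss(batch_losses, batch_sizes):
    """Size-weighted mean of the batch means, i.e. the average loss per sample."""
    total = sum(batch_sizes)
    return sum(l * n for l, n in zip(batch_losses, batch_sizes)) / total


# Two batches: a full batch of 128 samples and a final partial batch of 32
losses, sizes = [0.5, 1.0], [128, 32]
print(evaluation_style_loss(losses))
print(inference_style_loss(losses, sizes))
print(per_sample_loss(losses, sizes))
```

The batch-mean and per-sample figures differ only when the last batch is partial, whereas the `inference`-style figure is on a different scale altogether.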
/src/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from torchvision import datasets, transforms
3 | from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
4 | from sampling import cifar_iid, cifar_noniid
5 | from options import args_parser
6 | import numpy as np
7 | args = args_parser()
8 |
9 | def get_dataset(args):
10 | """ Returns train and test datasets and a user group which is a dict where
11 | the keys are the user index and the values are the corresponding data for
12 | each of those users.
13 | """
14 | if args.dataset == 'cifar':
15 | data_dir = '../data/cifar/'
16 | transform_train = transforms.Compose([
17 | transforms.RandomCrop(32, padding=4),
18 | transforms.RandomHorizontalFlip(),
19 | transforms.ToTensor(),
20 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
21 | ])
22 |
23 | transform_test = transforms.Compose([
24 | transforms.ToTensor(),
25 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
26 | ])
27 | train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
28 | transform=transform_train)
29 | test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
30 | transform=transform_test)
31 |
32 | # sample training data amongst users
33 | if args.iid:
34 | # Sample IID user data from CIFAR-10
35 | user_groups = cifar_iid(train_dataset, args.num_users)
36 | else:
37 | # Sample Non-IID user data from CIFAR-10
38 | if args.unequal:
39 | # Unequal splits are not implemented for CIFAR-10 (mnist_noniid_unequal
40 | # assumes the 60,000-image MNIST training set)
41 | raise NotImplementedError('unequal non-IID splits are only implemented for MNIST')
42 | else:
43 | # Choose equal splits for every user
44 | user_groups = cifar_noniid(train_dataset, args.num_users)
45 |
46 | elif args.dataset == 'mnist':
47 | data_dir = '../data/mnist/'
48 |
49 | apply_transform = transforms.Compose([
50 | transforms.ToTensor(),
51 | transforms.Normalize((0.1307,), (0.3081,))])
52 |
53 | train_dataset = datasets.MNIST(data_dir, train=True, download=True,
54 | transform=apply_transform)
55 |
56 | test_dataset = datasets.MNIST(data_dir, train=False, download=True,
57 | transform=apply_transform)
58 |
59 | # sample training data amongst users
60 | if args.iid:
61 | # Sample IID user data from Mnist
62 | user_groups = mnist_iid(train_dataset, args.num_users)
63 | else:
64 | # Sample Non-IID user data from Mnist
65 | if args.unequal:
66 | # Choose unequal splits for every user
67 | user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
68 | else:
69 | # Choose equal splits for every user
70 | user_groups = mnist_noniid(train_dataset, args.num_users)
71 |
72 | return train_dataset, test_dataset, user_groups
73 |
74 | def get_quantization_bit(args):
75 | quant_bit_for_user = np.zeros([args.num_users])
76 | avg_for_user = np.zeros([args.num_users])
77 |
78 | for user in range(args.num_users):
79 | if user < args.num_users * args.bit_4_ratio:
80 | quant_bit_for_user[user] = 4
81 | avg_for_user[user] = 4
82 | elif user < args.num_users * (args.bit_4_ratio + args.bit_8_ratio):
83 | quant_bit_for_user[user] = 8
84 | avg_for_user[user] = 8
85 | else:
86 | avg_for_user[user] = 64
87 | return quant_bit_for_user, avg_for_user
88 |
89 | def average_weights(args,w,avg_for_user,q_for_user):
90 | """
91 | Returns the weighted average of the client weights under args.average_scheme.
92 | """
93 | w_avg = copy.deepcopy(w[0])
94 | user_num=len(avg_for_user)
95 | # Compute the aggregation weight p for each user; the weights sum to 1
96 | p_for_user=np.ones(user_num)
97 | if args.average_scheme == 'FedAvg':
98 | p_for_user /= sum(p_for_user)
99 | if args.average_scheme == 'Proportional':
100 | p_for_user = np.array(avg_for_user) / np.sum(avg_for_user)
101 | if args.average_scheme == 'FedHQ':
102 | p_for_user = 1 / (1 + q_for_user)
103 | p_for_user /= np.sum(p_for_user)
104 | for key in w_avg:
105 | w_avg[key] *= p_for_user[0]
106 | for i in range(1, len(w)):
107 | weight_name = []
108 | for j in w[i]:
109 | weight_name.append(j)
110 | cnt=0
111 | for key in w_avg.keys():
112 | w_avg[key] += w[i][weight_name[cnt]] * p_for_user[i]
113 | cnt += 1
114 | return w_avg
115 |
116 |
117 | def exp_details(args):
118 | print('\nExperimental details:')
119 | print(f' Dataset : {args.dataset}')
120 | model='CNN'
121 | if args.dataset=='cifar' and args.iid==1:
122 | model='ResNet18'
123 | if args.dataset == 'cifar' and args.iid == 0:
124 | model = 'VGG11'
125 |
126 | print(f' Model : {model}')
127 | print(f' Optimizer : {args.optimizer}')
128 | print(f' Learning : {args.lr}')
129 | print(f' Global Rounds : {args.epochs}\n')
130 |
131 | print(' Federated parameters:')
132 | if args.iid:
133 | print(' IID')
134 | else:
135 | print(' Non-IID')
136 | print(f' Fraction of users : {args.frac}')
137 | print(f' Local Batch size : {args.local_bs}')
138 | print(f' Local Epochs : {args.local_ep}\n')
139 | return
140 |
--------------------------------------------------------------------------------
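`average_weights` assigns each sampled client a normalised weight p and sums the client state dicts with those weights. The three schemes work as follows:

* `FedAvg` uses uniform weights here, not dataset-size weights.
* `Proportional` weights clients by bit width, with full-precision clients counted as 64 bits by `get_quantization_bit`.
* `FedHQ` uses 1/(1+q), where q is the maximum relative quantization error returned by `update_weights`.

The pure-Python sketch below restates that rule on plain dicts of float lists. `aggregation_weights` and `weighted_average` are hypothetical names. The sketch pairs entries by parameter name, whereas `average_weights` pairs them by position. It also raises on an unknown scheme, whereas the repository code falls back to unnormalised ones.

```python
def aggregation_weights(scheme, bits_for_user, q_for_user):
    """Normalised per-client weights p for the three average_scheme options."""
    if scheme == "FedAvg":
        p = [1.0 for _ in bits_for_user]           # uniform
    elif scheme == "Proportional":
        p = [float(b) for b in bits_for_user]      # proportional to bit width
    elif scheme == "FedHQ":
        p = [1.0 / (1.0 + q) for q in q_for_user]  # smaller error -> larger weight
    else:
        raise ValueError(f"unknown average_scheme: {scheme!r}")
    total = sum(p)
    return [x / total for x in p]


def weighted_average(client_weights, p):
    """client_weights: list of {name: list of floats} with identical names and lengths."""
    return {
        name: [
            sum(p_i * w[name][j] for p_i, w in zip(p, client_weights))
            for j in range(len(client_weights[0][name]))
        ]
        for name in client_weights[0]
    }


# Hypothetical round: a 4-bit client with q = 1.0 and a full-precision client with q = 0.0
p = aggregation_weights("FedHQ", bits_for_user=[4, 64], q_for_user=[1.0, 0.0])
print(p)
print(weighted_average([{"w": [1.0, 2.0]}, {"w": [4.0, 5.0]}], p))
```

Under `FedHQ`, the client with no quantization error gets twice the weight of the client whose maximum relative error is 1.0.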