├── README.md
├── requirements.txt
└── src
    ├── FedHQ_main.py
    ├── __pycache__
    │   ├── cifar_model.cpython-37.pyc
    │   ├── models.cpython-37.pyc
    │   ├── models_without_quant.cpython-37.pyc
    │   ├── options.cpython-37.pyc
    │   ├── quantizer.cpython-37.pyc
    │   ├── quantizer2.cpython-37.pyc
    │   ├── sampling.cpython-37.pyc
    │   ├── train.cpython-37.pyc
    │   ├── update.cpython-37.pyc
    │   └── utils.cpython-37.pyc
    ├── cifar_model.py
    ├── models.py
    ├── options.py
    ├── quantizer.py
    ├── sampling.py
    ├── train.py
    ├── update.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
# Dynamic Aggregation for Heterogeneous Quantization in Federated Learning (PyTorch)

Implementation of dynamic aggregation for heterogeneous quantization in federated learning (FedHQ).

## References
The experiments build on the following papers. Each entry links to the paper and to its GitHub repository.
### Papers:
* [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629) : [GitHub](https://github.com/AshwinRJ/Federated-Learning-PyTorch)
* [SWALP: Stochastic Weight Averaging in Low-Precision Training](https://arxiv.org/abs/1904.11943v2) : [GitHub](https://github.com/stevenygd/SWALP)

## Requirements
`requirements.txt` lists the exact package versions. The main dependencies are:
* Python3
* Pytorch
* Torchvision

## Data
* The FedHQ experiments are run on MNIST and CIFAR-10.
* You can download the data through the code.

## Options
#### FedHQ Parameters
* ```--epochs:``` Number of communication rounds (T in the paper). Default is 150.
* ```--num_users:``` Number of clients (n in the paper). Default is 100.
* ```--frac:``` Fraction of clients used in each federated round (C in the paper). Default is 1 (as set in `src/options.py`).
* ```--local_ep:``` Number of local training epochs on each client (K in the paper). Default is 1.
* ```--local_bs:``` Batch size of local updates on each client (B in the paper). Default is 600.
* ```--lr:``` Learning rate (η in the paper). Default is 0.1.
* ```--optimizer:``` The optimizer used. Default is sgd.
* ```--momentum:``` Momentum of the optimizer (M in the paper). Default is 0.5.
* ```--weight_decay:``` Weight decay of the optimizer (λ in the paper). Default is 0.0005.
* ```--average_scheme:``` Aggregation scheme: `FedAvg`, `Proportional`, or `FedHQ`. Default is FedHQ.
* ```--dataset:``` Name of the dataset (`mnist` or `cifar`). Default is mnist.
* ```--gpu:``` Whether to use the CPU or a GPU. Default is 1 (GPU).
* ```--iid:``` Data distribution among clients: 1 for IID, 0 for non-IID. Default is 1.
* ```--bit_4_ratio:``` Fraction of clients that use 4-bit quantization. Default is 0.6.
* ```--bit_8_ratio:``` Fraction of clients that use 8-bit quantization. Default is 0.4.

In our experiments, `bit_4_ratio` and `bit_8_ratio` always sum to 1.

## FedHQ Experiments
Detailed results are reported in Section 6 of the paper. All commands below are meant to be run from the FedHQ root folder.
#### Results on MNIST:
* To run the FedHQ experiment on MNIST with an IID partition using a GPU:
```
python src/FedHQ_main.py --dataset=mnist --frac=1 --local_bs=600 --average_scheme=FedHQ --bit_4_ratio=0 --bit_8_ratio=1
```
* To run the FedHQ experiment on MNIST with a non-IID partition using a GPU:
```
python src/FedHQ_main.py --dataset=mnist --iid=0 --frac=1 --local_ep=1 --local_bs=600 --average_scheme=FedHQ --bit_4_ratio=0 --bit_8_ratio=1
```
Parameter settings for these runs:
* ```--frac:``` 1
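
With `--iid=1`, `mnist_iid` in `src/sampling.py` gives each client `int(len(dataset)/num_users)` images, which for the 60,000-image MNIST training set and 100 clients is 600. With `--local_bs=600` and `--local_ep=1`, each client therefore takes a single local SGD step per round. The small sketch below checks that arithmetic; the helper `local_steps_per_round` is hypothetical, and it assumes the local data loader keeps a final partial batch (ceil division), which may differ from the real loader.

```python
import math


def local_steps_per_round(num_samples, num_users, local_bs, local_ep):
    """Hypothetical helper: (images per client, local SGD steps per round).

    Mirrors the equal-sized IID split in src/sampling.py, where each client
    gets int(num_samples / num_users) examples. Assumes the last partial
    batch is kept (ceil), which is an assumption about the data loader.
    """
    items_per_client = int(num_samples / num_users)
    batches_per_epoch = math.ceil(items_per_client / local_bs)
    return items_per_client, batches_per_epoch * local_ep


# MNIST defaults: 60,000 training images, n = 100 clients, B = 600, K = 1
items, steps = local_steps_per_round(60000, 100, 600, 1)
print(items, steps)  # 600 1
```

Under the same ceil assumption, and if the CIFAR-10 IID split is also equal-sized (50,000 images over 100 clients), the CIFAR-10 IID command gives 500 images per client and `local_steps_per_round(50000, 100, 128, 5)` returns 20 local steps per round.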

For the MNIST runs, the learning rate decays by a factor of 0.9 every ten rounds, and the ratio of 4-bit clients is varied over [0, 0.2, 0.4, 0.6, 0.8, 1].

```Table 1:``` Number of communication rounds needed to reach each target accuracy on MNIST, IID partition (`*`: target not reached).
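
| Quantization bits: ratio | Scheme | 60% | 70% | 80% | 90% | 92% | 94% | 95% |
|---|---|---|---|---|---|---|---|---|
| 4-bit: 0, 8-bit: 1 | FedAvg | 13 | 15 | 25 | 33 | 42 | 46 | 65 |
| | FedHQ+ | 13 | 14 | 22 | 35 | 39 | 47 | 63 |
| 4-bit: 0.2, 8-bit: 0.8 | FedAvg | 12 | 18 | 19 | 32 | 42 | 54 | 82 |
| | Proportional | 15 | 21 | 22 | 32 | 38 | 50 | 73 |
| | FedHQ+ | 13 | 15 | 25 | 33 | 35 | 47 | 61 |
| 4-bit: 0.4, 8-bit: 0.6 | FedAvg | 17 | 22 | 24 | 42 | 45 | 62 | 104 |
| | Proportional | 11 | 17 | 22 | 37 | 41 | 53 | 83 |
| | FedHQ+ | 12 | 17 | 25 | 34 | 38 | 50 | 61 |
| 4-bit: 0.6, 8-bit: 0.4 | FedAvg | 13 | 31 | * | * | * | * | * |
| | Proportional | 11 | 27 | 35 | * | * | * | * |
| | FedHQ+ | 18 | 19 | 20 | 40 | 42 | 46 | 66 |
| 4-bit: 0.8, 8-bit: 0.2 | FedAvg | 21 | * | * | * | * | * | * |
| | Proportional | 16 | 24 | 51 | * | * | * | * |
| | FedHQ+ | 13 | 18 | 23 | 35 | 47 | 53 | 69 |
| 4-bit: 1, 8-bit: 0 | FedAvg | 14 | 32 | * | * | * | * | * |
| | FedHQ+ | 16 | 20 | 32 | 52 | 79 | * | * |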
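
Table 1 and the tables that follow report, for each target accuracy, the number of communication rounds needed to reach it. The script that produced them is not included here; the sketch below shows one way to compute such a row from a per-round accuracy history. It is a hypothetical helper that assumes rounds are counted from 1 and uses `None` for the `*` entries; the accuracy values in the example are made up for illustration and are not results.

```python
def rounds_to_target(accuracy_history, targets):
    """Hypothetical helper: first round (1-based) reaching each target accuracy.

    accuracy_history[t] is the test accuracy (as a fraction) after round t + 1.
    Returns {target: round or None}; None plays the role of '*' in the tables.
    """
    result = {}
    for target in targets:
        result[target] = next(
            (t + 1 for t, acc in enumerate(accuracy_history) if acc >= target),
            None,
        )
    return result


history = [0.45, 0.62, 0.71, 0.79, 0.83, 0.85]  # illustrative values only
print(rounds_to_target(history, [0.6, 0.7, 0.8, 0.9]))
# {0.6: 2, 0.7: 3, 0.8: 5, 0.9: None}
```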

```Table 2:``` Number of communication rounds needed to reach each target accuracy on MNIST, non-IID partition (`*`: target not reached).
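
| Quantization bits: ratio | Scheme | 60% | 70% | 80% | 90% | 92% | 94% | 95% |
|---|---|---|---|---|---|---|---|---|
| 4-bit: 0, 8-bit: 1 | FedAvg | 12 | 18 | 26 | 39 | 49 | 55 | 75 |
| | FedHQ+ | 11 | 19 | 22 | 32 | 36 | 55 | 74 |
| 4-bit: 0.2, 8-bit: 0.8 | FedAvg | 13 | 17 | 19 | 41 | 43 | 60 | 96 |
| | Proportional | 14 | 18 | 22 | 34 | 44 | 57 | 92 |
| | FedHQ+ | 13 | 20 | 28 | 41 | 43 | 60 | 96 |
| 4-bit: 0.4, 8-bit: 0.6 | FedAvg | 19 | 23 | 25 | 43 | 51 | 76 | 131 |
| | Proportional | 20 | 22 | 23 | 42 | 46 | 68 | 119 |
| | FedHQ+ | 13 | 17 | 25 | 42 | 43 | 55 | 87 |
| 4-bit: 0.6, 8-bit: 0.4 | FedAvg | 21 | 37 | * | * | * | * | * |
| | Proportional | 19 | 42 | * | * | * | * | * |
| | FedHQ+ | 19 | 26 | 31 | 51 | 59 | 87 | 133 |
| 4-bit: 0.8, 8-bit: 0.2 | FedAvg | 22 | 42 | * | * | * | * | * |
| | Proportional | 16 | 41 | 50 | * | * | * | * |
| | FedHQ+ | 21 | 23 | 31 | * | * | * | * |
| 4-bit: 1, 8-bit: 0 | FedAvg | 17 | 42 | * | * | * | * | * |
| | FedHQ+ | 19 | 39 | * | * | * | * | * |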
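
The MNIST settings decay the learning rate by 0.9 every ten rounds (the CIFAR-10 runs below use the same rule). Read as a step schedule starting from the default `--lr` of 0.1, round t (counting from 0) would use 0.1 · 0.9^⌊t/10⌋. The sketch below assumes the decay is applied at rounds 10, 20, 30 and so on; how `src/train.py` actually applies it is not shown here, and the function name is hypothetical.

```python
def step_decay_lr(round_idx, base_lr=0.1, gamma=0.9, step=10):
    """Learning rate for a 0-based round index under a step decay schedule.

    Assumption: the rate is multiplied by `gamma` once every `step` rounds.
    """
    return base_lr * gamma ** (round_idx // step)


for t in (0, 9, 10, 25, 149):
    print(t, round(step_decay_lr(t), 6))
```

In PyTorch, `torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)` produces the same schedule if `scheduler.step()` is called once per communication round.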

#### Results on CIFAR10:

* To run the FedHQ experiment on CIFAR-10 with an IID partition using a GPU:
```
python src/FedHQ_main.py --dataset=cifar --epochs=300 --frac=0.1 --local_ep=5 --local_bs=128 --average_scheme=FedHQ --bit_4_ratio=0 --bit_8_ratio=1
```
* To run the FedHQ experiment on CIFAR-10 with a non-IID partition using a GPU:
```
python src/FedHQ_main.py --dataset=cifar --epochs=150 --iid=0 --frac=0.1 --local_ep=5 --local_bs=64 --momentum=0.2 --average_scheme=FedHQ --bit_4_ratio=0 --bit_8_ratio=1
```
Parameter settings for these runs:
* ```--epochs:``` 300 for IID, 150 for non-IID.
* ```--frac:``` 0.1
* ```--local_ep:``` 5
* ```--local_bs:``` 128 for IID, 64 for non-IID.
* ```--momentum:``` 0.2 for non-IID (the default 0.5 otherwise).
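
With `--frac=0.1` and the default `--num_users=100`, 10 clients take part in each CIFAR-10 round. The client-selection code in `src/train.py` is not shown here; the sketch below uses a common way to do it, picking max(int(frac · num_users), 1) distinct clients uniformly at random each round. The helper name and the use of Python's `random` module are assumptions for illustration.

```python
import random


def sample_clients(num_users, frac, rng):
    """Hypothetical helper: distinct client ids taking part in one round."""
    m = max(int(frac * num_users), 1)  # at least one client per round
    return sorted(rng.sample(range(num_users), m))


rng = random.Random(0)
for rnd in range(3):
    print(rnd, sample_clients(100, 0.1, rng))
```

With `--frac=1`, as in the MNIST commands, every client is selected in every round.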

For the CIFAR-10 runs, the learning rate decays by a factor of 0.9 every ten rounds, and the ratio of 4-bit clients is varied over [0, 0.2, 0.4, 0.6, 0.8, 1]; Tables 3 and 4 also include a ratio of 0.3.

```Table 3:``` Number of communication rounds needed to reach each target accuracy on CIFAR-10, IID partition (`*`: target not reached).
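
| Quantization bits: ratio | Scheme | 60% | 70% | 80% | 82% | 84% | 86% |
|---|---|---|---|---|---|---|---|
| 4-bit: 0, 8-bit: 1 | FedAvg | 13 | 22 | 45 | 53 | 69 | 94 |
| | FedHQ+ | 12 | 22 | 42 | 53 | 68 | 94 |
| 4-bit: 0.2, 8-bit: 0.8 | FedAvg | 58 | 126 | * | * | * | * |
| | Proportional | 31 | 58 | 179 | 285 | * | * |
| | FedHQ+ | 14 | 26 | 56 | 69 | 97 | 133 |
| 4-bit: 0.3, 8-bit: 0.7 | FedAvg | 144 | 276 | * | * | * | * |
| | Proportional | 73 | 109 | * | * | * | * |
| | FedHQ+ | 13 | 23 | 51 | 66 | 92 | 126 |
| 4-bit: 0.4, 8-bit: 0.6 | FedAvg | * | * | * | * | * | * |
| | Proportional | 119 | 227 | * | * | * | * |
| | FedHQ+ | 17 | 25 | 57 | 76 | 98 | 199 |
| 4-bit: 0.6, 8-bit: 0.4 | FedAvg | * | * | * | * | * | * |
| | Proportional | * | * | * | * | * | * |
| | FedHQ+ | 18 | 33 | 100 | 184 | * | * |
| 4-bit: 0.8, 8-bit: 0.2 | FedAvg | * | * | * | * | * | * |
| | Proportional | * | * | * | * | * | * |
| | FedHQ+ | 84 | * | * | * | * | * |
| 4-bit: 1, 8-bit: 0 | FedAvg | * | * | * | * | * | * |
| | FedHQ+ | * | * | * | * | * | * |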

```Table 4:``` Number of communication rounds needed to reach each target accuracy on CIFAR-10, non-IID partition (`*`: target not reached).
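
| Quantization bits: ratio | Scheme | 30% | 35% | 40% | 45% | 50% | 55% |
|---|---|---|---|---|---|---|---|
| 4-bit: 0, 8-bit: 1 | FedAvg | 9 | 15 | 15 | 30 | 48 | 73 |
| | FedHQ+ | 9 | 14 | 15 | 27 | 48 | 71 |
| 4-bit: 0.2, 8-bit: 0.8 | FedAvg | 31 | 48 | 93 | * | * | * |
| | Proportional | 18 | 27 | 38 | 93 | * | * |
| | FedHQ+ | 18 | 57 | 60 | 83 | 93 | * |
| 4-bit: 0.3, 8-bit: 0.7 | FedAvg | 48 | 110 | * | * | * | * |
| | Proportional | 14 | 27 | 38 | 93 | * | * |
| | FedHQ+ | 11 | 14 | 18 | 38 | 88 | 110 |
| 4-bit: 0.4, 8-bit: 0.6 | FedAvg | * | * | * | * | * | * |
| | Proportional | 93 | * | * | * | * | * |
| | FedHQ+ | 22 | 49 | 60 | 60 | 93 | * |
| 4-bit: 0.6, 8-bit: 0.4 | FedAvg | * | * | * | * | * | * |
| | Proportional | * | * | * | * | * | * |
| | FedHQ+ | 33 | 40 | 62 | * | * | * |
| 4-bit: 0.8, 8-bit: 0.2 | FedAvg | * | * | * | * | * | * |
| | Proportional | * | * | * | * | * | * |
| | FedHQ+ | 117 | * | * | * | * | * |
| 4-bit: 1, 8-bit: 0 | FedAvg | * | * | * | * | * | * |
| | FedHQ+ | * | * | * | * | * | * |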
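
Quantization in `src/quantizer.py` is done by `block_quantize`. For each value it takes the exponent e = ⌊log2|x|⌋ (clamped to the range allowed by `ebit`), scales by 2^(bits-2-e), rounds (`--quant_type` selects stochastic, the default, or nearest rounding), clamps to the signed `bits`-bit integer range and scales back. Note that the exponent is computed element-wise there rather than shared across a block. The pure-Python sketch below re-implements that per-element rule for a single float so the rounding can be inspected without PyTorch; `quantize_value` is a hypothetical name, and this is an illustration rather than the code used in the experiments.

```python
import math
import random


def quantize_value(x, bits, ebit, mode="stochastic", rng=None):
    """Per-element version of block_quantize from src/quantizer.py (illustration)."""
    rng = rng or random.Random()
    # Exponent of the leading bit; zero is treated like 1.0 (exponent 0).
    exp = math.floor(math.log2(abs(x))) if x != 0 else 0
    exp = max(-2 ** (ebit - 1), min(2 ** (ebit - 1) - 1, exp))
    # When exp is not clamped, |scaled| lies in [2**(bits-2), 2**(bits-1)).
    scaled = x * 2.0 ** (bits - 2 - exp)
    if mode == "stochastic":
        # Rounds up with probability frac(scaled), so the result is unbiased.
        q = math.floor(scaled + rng.random())
    elif mode == "nearest":
        q = round(scaled)  # Python's round() rounds halves to even
    else:
        raise ValueError("Invalid quantization mode: {}".format(mode))
    # Clamp to the signed integer range of `bits` bits, then scale back.
    q = max(-2 ** (bits - 1), min(2 ** (bits - 1) - 1, q))
    return q * 2.0 ** (exp - (bits - 2))


rng = random.Random(0)
print(quantize_value(0.3, bits=4, ebit=8, mode="nearest"))  # 0.3125
samples = [quantize_value(0.3, 4, 8, "stochastic", rng) for _ in range(20000)]
print(sorted(set(samples)))         # [0.25, 0.3125]
print(sum(samples) / len(samples))  # close to 0.3
```

At 4 bits, 0.3 sits between the grid points 0.25 and 0.3125; stochastic rounding picks the upper one with probability 0.8, so the average stays at 0.3. The clamp can still introduce bias: 0.99 scales to 7.92, rounds to 8 and is clamped to 7, giving 0.875.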

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
python=3.7.3
pytorch=1.2.0
torchvision=0.4.0
numpy=1.16.2
tensorboardx=1.9
matplotlib=3.0.3
tqdm=4.31.1
prettytable=0.7.2

--------------------------------------------------------------------------------
/src/FedHQ_main.py:
--------------------------------------------------------------------------------
import os
from options import args_parser
from train import train
args = args_parser()

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

if __name__ == '__main__':
    train()

--------------------------------------------------------------------------------
/src/__pycache__/cifar_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/cifar_model.cpython-37.pyc

--------------------------------------------------------------------------------
/src/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/models.cpython-37.pyc

--------------------------------------------------------------------------------
/src/__pycache__/models_without_quant.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/models_without_quant.cpython-37.pyc

--------------------------------------------------------------------------------
/src/__pycache__/options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/options.cpython-37.pyc

--------------------------------------------------------------------------------
/src/__pycache__/quantizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/quantizer.cpython-37.pyc

--------------------------------------------------------------------------------
/src/__pycache__/quantizer2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/quantizer2.cpython-37.pyc

--------------------------------------------------------------------------------
/src/__pycache__/sampling.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/sampling.cpython-37.pyc

--------------------------------------------------------------------------------
/src/__pycache__/train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/train.cpython-37.pyc

--------------------------------------------------------------------------------
/src/__pycache__/update.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/update.cpython-37.pyc

--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/private-monstera/data-aggregation-federated-learning/f5f58fa0e8ebc7314314496222009cde38c12f1e/src/__pycache__/utils.cpython-37.pyc

--------------------------------------------------------------------------------
/src/cifar_model.py:
--------------------------------------------------------------------------------
from torch import nn
import torch.nn.functional as F
import math

def make_layers_Cifar10(cfg, quant, batch_norm=False, conv=nn.Conv2d):
    layers = list()
    in_channels = 3
    n = 1
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            use_quant = v[-1] != 'N'
            filters = int(v) if use_quant else int(v[:-1])
            conv2d = conv(in_channels, filters, kernel_size=3, padding=1, bias=False)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(filters), nn.ReLU(True)]
            else:
                layers += [conv2d, nn.ReLU()]
            if quant!=None: layers += [quant()]
            n += 1
            in_channels = filters
    return nn.Sequential(*layers)

class CNNCifar(nn.Module):
    def __init__(self, args,quant):
        self.args=args
        super(CNNCifar, self).__init__()
        self.linear = nn.Linear
        cfg = {
            9: ['64', '64', 'M', '128', '128', 'M', '256', '256', 'M'],
            11: ['64', 'M', '128', 'M', '256', '256', 'M', '512', '512', 'M', '512', '512', 'M'],
            13: ['64', '64', 'M', '128', '128', 'M', '256', '256', 'M', '512', '512', 'M', '512', '512', 'M'],
            16: ['64', '64', 'M', '128', '128', 'M', '256', '256', '256', 'M', '512', '512', '512', 'M', '512', '512', '512', 'M'],
        }
        self.conv = nn.Conv2d
        self.features = make_layers_Cifar10(cfg[16], quant, True, self.conv)
        self.classifier=None
        if quant!=None:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                self.linear(512 * 1 * 1, 4096),
                nn.ReLU(True),
                quant(),
                self.linear(4096, 4096),
                nn.ReLU(True),
                quant(),
                self.linear(4096, args.num_classes),
                nn.ReLU(True),
                quant(),
                nn.LogSoftmax(dim=1)
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                self.linear(512 * 1 * 1, 4096),
                nn.ReLU(True),
                self.linear(4096, 4096),
                nn.ReLU(True),
                self.linear(4096, args.num_classes),
                nn.ReLU(True),
                nn.LogSoftmax(dim=1)
            )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512 * 1 * 1)
        x = self.classifier(x)
        return x


class ResnetCifar18(nn.Module):
    def __init__(self, quant, quantx, in_channel, out_channel, strides):
        super(ResnetCifar18,self).__init__()
        self.block=None
        self.residual=nn.Sequential()
        self.quantx=quantx
        if quant==None:
            self.block=nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(True),
                nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channel)
            )
            if strides!=1 or in_channel!=out_channel:
                self.residual=nn.Sequential(
                    nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides,bias=False),
                    nn.BatchNorm2d(out_channel)
                )
        else:
            self.block = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(True),
                quant(),
                nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channel),
                quant()
            )
            self.residual = nn.Sequential()
            if strides != 1 or in_channel != out_channel:
                self.residual = nn.Sequential(
                    nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, bias=False),
                    nn.BatchNorm2d(out_channel),
                    quant()
                )
    def forward(self,x):
        out=self.block(x)
        out+=self.residual(x)
        out=F.relu(out)
        if self.quantx!=None:
            out=self.quantx(out)
        return out

class ResNet(nn.Module):
    def __init__(self, args, quant, quantx):
        super(ResNet,self).__init__()
        self.in_channel=64
        self.quantx=quantx
        self.conv1=None
        if quant==None:
            self.conv1=nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True)# ,
                # nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
                quant(),
                # nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        self.layer1 = self.make_layer(quant, quantx, 64, 2, stride=1)
        self.layer2 = self.make_layer(quant, quantx, 128, 2, stride=2)
        self.layer3 = self.make_layer(quant, quantx, 256, 2, stride=2)
        self.layer4 = self.make_layer(quant, quantx, 512, 2, stride=2)
        self.fc=nn.Linear(512, args.num_classes)

    def make_layer(self, quant, quantx, channel, num_blocks, stride):
        strides=[stride] + [1]*(num_blocks-1)
        layers=[]
        for stride in strides:
            layers.append(ResnetCifar18(quant, quantx, self.in_channel, channel, stride))
            self.in_channel=channel
        return nn.Sequential(*layers)
    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        if self.quantx != None:
            out = self.quantx(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        if self.quantx != None:
            out = self.quantx(out)
        out = F.log_softmax(out,dim=1)
        return out

--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
from torch import nn
import math

def make_layers_Mnist(cfg, quant, batch_norm=False, conv=nn.Conv2d):
    layers = list()
    in_channels = 1
    n = 1
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            use_quant = v[-1] != 'N'
            filters = int(v) if use_quant else int(v[:-1])
            conv2d = conv(in_channels, filters, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(filters), nn.ReLU()]
            else:
                layers += [conv2d, nn.ReLU()]
            if quant!=None: layers += [quant()]
            n += 1
            in_channels = filters
    return nn.Sequential(*layers)

class CNNMnist(nn.Module):
    def __init__(self, args,quant):
        self.args=args
        super(CNNMnist, self).__init__()
        self.linear = nn.Linear
        cfg = {
            16: ['16', 'M', '32', 'M']
        }
        self.conv = nn.Conv2d
        self.features = make_layers_Mnist(cfg[16], quant, True, self.conv)
        self.classifier=None
        if quant!=None:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                # 28x28 input, two 2x2 max-pools -> 7x7 feature maps with 32 channels
                self.linear(7*7*32, 512),
                nn.ReLU(True),
                quant(),
                self.linear(512, 10),
                nn.ReLU(True),
                quant(),
                nn.LogSoftmax(dim=1),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                self.linear(7 * 7 * 32, 512),
                nn.ReLU(True),
                self.linear(512, 10),
                nn.ReLU(True),
                nn.LogSoftmax(dim=1)
            )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

--------------------------------------------------------------------------------
/src/options.py:
--------------------------------------------------------------------------------
import argparse


def args_parser():
    parser = argparse.ArgumentParser()

    # federated arguments (notation for the arguments follows the paper)
    parser.add_argument('--epochs', type=int, default=150,
                        help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=100,
                        help="number of clients: n")
    parser.add_argument('--frac', type=float, default=1,
                        help='the fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=1,
                        help="the number of local epochs: K")
    parser.add_argument('--local_bs', type=int, default=600,
                        help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--weight_decay', type=float, default=0.0005,
                        help='weight decay of optimizer (default: 0.0005)')
    parser.add_argument('--average_scheme', type=str, default='FedHQ', help='choose average scheme',
                        choices=['FedAvg','Proportional','FedHQ'])
    parser.add_argument('--quant_bits', type=int, default=8, help='record the current quantization bit')
    parser.add_argument('--bit_4_ratio', type=float, default=0.6, help='the ratio for 4-bit clients')
    parser.add_argument('--bit_8_ratio', type=float, default=0.4, help='the ratio for 8-bit clients')

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--gpu', default=1, help="To use CPU or GPU. Default set to use GPU.")
    parser.add_argument('--optimizer', type=str, default='sgd', help="type of optimizer")
    parser.add_argument('--iid', type=int, default=1,
                        help='Default set to IID. Set to 0 for non-IID.')
    parser.add_argument('--unequal', type=int, default=0,
                        help='whether to use unequal data splits for non-i.i.d setting (use 0 for equal splits)')
    parser.add_argument('--stopping_rounds', type=int, default=10,
                        help='rounds of early stopping')
    parser.add_argument('--verbose', type=int, default=1, help='verbose')
    parser.add_argument('--seed', type=int, default=1, help='random seed')

    parser.add_argument('--dir', type=str, default=None,
                        help='training directory (default: None)')
    parser.add_argument('--data_path', type=str, default="./data", required=False, metavar='PATH',
                        help='path to datasets location (default: "./data")')
    parser.add_argument('--num_workers', type=int, default=0, metavar='N',
                        help='number of workers (default: 0)')
    parser.add_argument('--log_name', type=str, default='', metavar='S',
                        help="Name for the log dir")
    parser.add_argument('--quant_type', type=str, default='stochastic', metavar='S',
                        help='rounding method, stochastic or nearest', choices=['stochastic', 'nearest'])

    args = parser.parse_args()
    return args

--------------------------------------------------------------------------------
/src/quantizer.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch
import torch.nn as nn
from options import args_parser
args = args_parser()
def add_r_(data):
    r = torch.rand_like(data)
    data.add_(r)

def _round(data, sigma, t_min, t_max, mode, clip=True):
    """
    Quantize a Tensor.
    """
    temp = data / sigma
    if mode=="nearest":
        temp = temp.round()
    elif mode=="stochastic":
        add_r_(temp)
        temp.floor_()
    else: raise ValueError("Invalid quantization mode: {}".format(mode))
    temp *= sigma
    if clip: temp.clamp_(t_min, t_max)
    return temp

def block_quantize(data, bits, mode, ebit):
    # Note: max_exponent is computed per element (there is no max over a block).
    max_exponent = torch.floor(torch.log2(torch.abs(torch.where(data==torch.zeros_like(data), torch.ones_like(data), data))))
    # Suppose we allocate W bits to represent each number in the block and F bits to represent the shared exponent.
    max_exponent.clamp_(-2 ** (ebit - 1), 2 ** (ebit - 1) - 1)
    i = data * 2**(-max_exponent+(bits-2))
    if mode == "stochastic":
        add_r_(i)
        i.floor_()
    elif mode == "nearest":
        i.round_()
    i.clamp_(-2**(bits-1), 2**(bits-1)-1)
    temp = i * 2**(max_exponent-(bits-2))
    return temp
def q_quantize(data, bits, mode, ebit):
    max_exponent = torch.floor(torch.log2(torch.abs(torch.where(data==torch.zeros_like(data), torch.ones_like(data), data))))
    max_exponent.clamp_(-2 ** (ebit - 1), 2 ** (ebit - 1) - 1)
    i = data * 2**(-max_exponent+(bits-2))
    cur_exp = 2 ** (max_exponent - (bits - 2))
    # Probabilities of rounding down / up under stochastic rounding.
    p4left = 1 - i % 1
    p4right = i % 1
    q_n_left = torch.floor(i).clamp_(-2 ** (bits - 1), 2 ** (bits - 1) - 1) * cur_exp
    q_n_right = torch.ceil(i).clamp_(-2 ** (bits - 1), 2 ** (bits - 1) - 1) * cur_exp
    # Expected squared quantization error, then normalised by data**2.
    e_q = torch.pow(q_n_left - data, 2) * p4left + torch.pow(q_n_right - data, 2) * p4right
    powdata=torch.pow(data, 2)
    q = e_q / powdata
    end_q=torch.where(data==torch.zeros_like(data),torch.zeros_like(data),q)
    if mode == "stochastic":
        add_r_(i)
        i.floor_()
    elif mode == "nearest":
        i.round_()
    i.clamp_(-2**(bits-1), 2**(bits-1)-1)
    temp = i * 2**(max_exponent-(bits-2))
    # max_q: largest relative error; data_q: largest input value attaining it.
    max_q=torch.max(end_q)
    ind=torch.where(torch.abs(end_q-max_q)>=1e-9, -1000*torch.ones_like(end_q), data)
    data_q=torch.max(ind)
    return temp, max_q, torch.max(data_q)

class BlockRounding(torch.autograd.Function):
    @staticmethod
    def forward(self, x, bits, ebits, mode):
        self.ebits = ebits
        self.bits=bits
        self.mode = mode
        if bits == -1: return x
        return block_quantize(x, bits, self.mode, ebits)

    @staticmethod
    def backward(self, grad_output):
        grad_input = None
        if self.needs_input_grad[0]:
            # Gradients are quantized with the same bits/ebits/mode as the forward pass.
            if self.bits != -1:
                grad_input = block_quantize(grad_output, self.bits, self.mode, self.ebits)
            else:
                grad_input = grad_output
        # One gradient per forward input: x, bits, ebits, mode.
        return grad_input, None, None, None

quantize_block = BlockRounding.apply

class BlockQuantizer(nn.Module):
    def __init__(self, bits, ebits, mode):
        super(BlockQuantizer, self).__init__()
        self.bits = bits
        self.ebits = ebits
        self.mode = mode

    def forward(self, x):
        return quantize_block(x, self.bits,self.ebits, self.mode)

--------------------------------------------------------------------------------
/src/sampling.py:
--------------------------------------------------------------------------------
import numpy as np

def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset
    :param: dataset
    :param: num_users
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


def mnist_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset
    :param: dataset
    :param: num_users
    :return: dict of image index
    """
    # 60,000 training imgs --> 300 imgs/shard X 200 shards
    num_shards, num_imgs = 200, 300
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = dataset.train_labels.numpy()

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign 2 shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users


def mnist_noniid_unequal(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset s.t clients
    have unequal amount of data
    :param: dataset
    :param: num_users
    :returns: a dict of clients with each clients assigned certain
    number of training imgs
    """
    # 60,000 training imgs --> 50 imgs/shard X 1200 shards
    num_shards, num_imgs = 1200, 50
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = np.array(dataset.targets)

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # Minimum and maximum shards assigned per client:
    min_shard = 1
    max_shard = 30

    # Divide the shards into random chunks for every client
    # s.t. the sum of these chunks = num_shards
    random_shard_size = np.random.randint(min_shard, max_shard+1,
                                          size=num_users)
    random_shard_size = np.around(random_shard_size /
                                  sum(random_shard_size) * num_shards)
    random_shard_size = random_shard_size.astype(int)

    # Assign the shards randomly to each client
    if sum(random_shard_size) > num_shards:

        for i in range(num_users):
            # First assign each client 1 shard to ensure every client has
            # at least one shard of data
            rand_set = set(np.random.choice(idx_shard, 1, replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        random_shard_size = random_shard_size-1

        # Next, randomly assign the remaining shards
        for i in range(num_users):
            if len(idx_shard) == 0:
                continue
            shard_size = random_shard_size[i]
            if shard_size > len(idx_shard):
                shard_size = len(idx_shard)
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)
    else:

        for i in range(num_users):
            shard_size = random_shard_size[i]
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        if len(idx_shard) > 0:
            # Add the leftover shards to the client with the fewest images
            shard_size = len(idx_shard)
            k = min(dict_users, key=lambda x: len(dict_users.get(x)))
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[k] = np.concatenate(
                    (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

    return dict_users


def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param: dataset
    :param: num_users
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


def cifar_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from CIFAR10 dataset
    :param: dataset
    :param: num_users
    :return: dict of image index
    """
    # 50,000 training imgs --> 200 shards X 250 imgs/shard
    num_shards, num_imgs = 200, 250
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = np.array(dataset.targets)

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]
    # divide and assign 2 shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[int(rand)*num_imgs:(int(rand)+1)*num_imgs]), axis=0)
    return dict_users

--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
import os
import copy
import time
import numpy as np
from tqdm import tqdm
import quantizer as qn

import torch
from tensorboardX import SummaryWriter

from update import LocalUpdate, Evaluation
from models import CNNMnist
from cifar_model import CNNCifar, ResNet
from utils import get_dataset, average_weights, exp_details, get_quantization_bit
from options import args_parser
import prettytable as pt


def make_dir(filename):
    if not os.path.exists(filename):
        os.makedirs(filename)


def train():
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')
    args = args_parser()
    exp_details(args)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # Build model for different dataset
    if args.dataset == 'mnist':
        quant = lambda: qn.BlockQuantizer(args.quant_bits, args.quant_bits, args.quant_type)
        global_model = CNNMnist(args=args, quant=quant)
    elif args.dataset == 'cifar':
        quant = lambda: qn.BlockQuantizer(args.quant_bits, args.quant_bits, args.quant_type)
        # Quantize a tensor directly: quantize_block(x, bits, ebits, mode)
        quantx = lambda x: qn.quantize_block(x, args.quant_bits, args.quant_bits, args.quant_type)
        if args.iid == 1:
            global_model = ResNet(args=args, quant=quant, quantx=quantx)
        else:
            global_model = CNNCifar(args=args, quant=quant)
    # Hold the global weights sent from the server
    global_model.to(device)
    global_model.train()
    global_weights = global_model.state_dict()
    # Record training loss and accuracy
    train_loss, train_accuracy = [], []
    print_every = 1
    # Target test accuracies 0.60, 0.61, ..., 1.00 for the rounds-to-accuracy table
    acc_level = np.array(list(range(41)))*0.01+0.6
    acc_true = []
    acc_table_line = []
    acc_flag = np.zeros_like(acc_level)
    quant_bit_for_user, avg_for_user = get_quantization_bit(args)
    last_max_acc = 0
    # Set the filename for saving results
    result_base_filename = 'result/' + args.dataset + '/iid/' + args.average_scheme
    if args.iid == 0:
        result_base_filename = 'result/' + args.dataset + '/noniid/' + args.average_scheme
    save_acc_filename = result_base_filename + '/c'+str(int(args.frac*10))+'result_04-'+str(int(args.bit_4_ratio*10))+'_8-'+str(int(args.bit_8_ratio*10))+'.txt'
    save_pkl_filename = result_base_filename + '/c'+str(int(args.frac*10))+'result_04-' + str(int(args.bit_4_ratio * 10)) + '_8-' + str(int(args.bit_8_ratio * 10)) + '.pkl'

    make_dir(result_base_filename)

    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        cur_q = []
        for idx in idxs_users:
            quantx = None
            quant = None
            idx = int(idx)
            # Set the quantization condition and local model for each client
            if quant_bit_for_user[idx] != 0:
                args.quant_bits = quant_bit_for_user[idx]
                quant = lambda: qn.BlockQuantizer(args.quant_bits, args.quant_bits, args.quant_type)
                quantx = lambda x: qn.quantize_block(x, args.quant_bits, args.quant_bits, args.quant_type)
            if args.dataset == 'mnist':
                user_model = CNNMnist(args=args, quant=quant)
            elif args.dataset == 'cifar' and args.iid == 1:
                user_model = ResNet(args=args, quant=quant, quantx=quantx)
            elif args.dataset == 'cifar' and args.iid == 0:
                user_model = CNNCifar(args=args, quant=quant)
            user_model.to(device)
            user_model.train()
            # Initialize the local model with the current global (server) weights
            weight_name = []
            for i in global_weights:
                weight_name.append(i)
            cnt = 0
            user_weights = user_model.state_dict()
            for i in user_weights:
                if quantx is not None:
                    user_weights[i] = quantx(global_weights[weight_name[cnt]].to(float))
                else:
                    user_weights[i] = global_weights[weight_name[cnt]]
                cnt += 1
            user_model.load_state_dict(user_weights)
            # Train the local model
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger, quant=quantx, quantbit=args.quant_bits, mode=args.quant_type)
            w, loss, q = local_model.update_weights(model=copy.deepcopy(user_model))
            cur_q.append(q)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))
            torch.cuda.empty_cache()

        # update global weights
        global_weights = average_weights(args, local_weights, avg_for_user[idxs_users], q_for_user=np.array(cur_q))
        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all clients at every epoch
        list_acc, list_loss = [], []
        weight_name = []
        for i in global_weights:
            weight_name.append(i)
        for c in range(args.num_users):
            quant = None
            quantx = None
            # Set the quantization condition and local model for each client
            if quant_bit_for_user[c] != 0:
                args.quant_bits = quant_bit_for_user[c]
                quant = lambda: qn.BlockQuantizer(args.quant_bits, args.quant_bits, args.quant_type)
                quantx = lambda x: qn.quantize_block(x, args.quant_bits, args.quant_bits, args.quant_type)
            if args.dataset == 'mnist':
                user_model = CNNMnist(args=args, quant=quant)
            elif args.dataset == 'cifar' and args.iid == 1:
                user_model = ResNet(args=args, quant=quant, quantx=quantx)
            elif args.dataset == 'cifar' and args.iid == 0:
                user_model = CNNCifar(args=args, quant=quant)
            user_model.to(device)
            user_model.eval()
            cnt = 0
            user_weights = user_model.state_dict()
            for i in user_weights:
                user_weights[i] = global_weights[weight_name[cnt]]
                cnt += 1
            user_model.load_state_dict(user_weights)
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[c], logger=logger, quant=quantx, quantbit=quant_bit_for_user[c], mode=args.quant_type)
            acc, loss = local_model.inference(model=user_model)
            list_acc.append(acc)
            list_loss.append(loss)
            torch.cuda.empty_cache()
        train_accuracy.append(sum(list_acc)/len(list_acc))

        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args, quant=None)
        elif args.dataset == 'cifar' and args.iid == 1:
            global_model = ResNet(args=args, quant=quant, quantx=quantx)
        elif args.dataset == 'cifar' and args.iid == 0:
            global_model = CNNCifar(args=args, quant=quant)
        cnt = 0
        current_weight = global_model.state_dict()
        for i in current_weight:
            current_weight[i] = global_weights[weight_name[cnt]]
            cnt += 1
        global_model.load_state_dict(current_weight)
        global_model.to(device)
        global_model.eval()
        test_acc, test_loss = Evaluation(args, global_model, test_dataset)
        # Save the global (server) model with the best test accuracy so far
        if test_acc > last_max_acc:
            last_max_acc = test_acc
            torch.save(global_model.state_dict(), save_pkl_filename)
        with open(save_acc_filename, 'a') as file:
            write_content = str(epoch+1)+' '+str(np.mean(np.array(train_loss)))+' '+str(train_accuracy[-1])+' '+str(test_acc)+"\n"
            file.write(write_content)
        for acc_index in range(len(acc_level)):
            # Skip a level when the accuracy has also passed the next level
            # (the bounds check avoids indexing past the last level)
            if (acc_flag[acc_index] == 0 and test_acc >= acc_level[acc_index]
                    and test_acc not in acc_true
                    and acc_index + 1 < len(acc_level)
                    and test_acc >= acc_level[acc_index + 1]):
                acc_flag[acc_index] = 1
            if acc_flag[acc_index] == 0 and test_acc >= acc_level[acc_index] and test_acc not in acc_true:
                acc_true.append(test_acc)
                acc_flag[acc_index] = 1
                acc_table_line.append(epoch+1)
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}%'.format(100*train_accuracy[-1]))
            print("Test Accuracy: {:.2f}%\n".format(100 * test_acc))
            print(f"Test Loss: {np.mean(np.array(test_loss))}")
            print('4: ', args.bit_4_ratio, ', 8: ', args.bit_8_ratio, ', ', args.average_scheme)
            if len(acc_table_line) == 0:
                print("cannot get the communication round for the target accuracy")
            else:
                table = pt.PrettyTable()
                table.field_names = acc_true  # acc_level
                table.add_row(acc_table_line)
                print(table)
        if (epoch+1) % 10 == 0 and args.lr >= 1e-4:
            args.lr = args.lr*0.9
            print('learning rate : ', args.lr)
    # Test inference after completion of training
    test_acc, test_loss = Evaluation(args, global_model, test_dataset)
    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))
    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))
    if len(acc_table_line) == 0:
        print("cannot get the communication round for the target accuracy")
    else:
        table = pt.PrettyTable()
        table.field_names = acc_true
        table.add_row(acc_table_line)
        print(table)
    return train_accuracy[-1], test_acc

--------------------------------------------------------------------------------
/src/update.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import quantizer as qn
from options import args_parser
args = args_parser()


class DatasetSplit(Dataset):
    """A Dataset wrapper that exposes only the samples of an underlying
    PyTorch Dataset at the given indices.
    """
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.as_tensor(image), torch.as_tensor(label)


class LocalUpdate(object):
    def __init__(self, args, dataset, idxs, logger, quant, quantbit, mode):
        self.args = args
        self.logger = logger
        self.trainloader = self.train_val_test(
            dataset, list(idxs))
        self.device = 'cuda' if args.gpu else 'cpu'
        # Default criterion set to NLL loss function
        self.criterion = nn.NLLLoss().to(self.device)
        self.quantbit = quantbit
        self.quant = quant
        self.mode = mode

    def train_val_test(self, dataset, idxs):
        """
        Returns the local training dataloader for a given dataset
        and user indexes.
        """
        trainloader = DataLoader(DatasetSplit(dataset, idxs),
                                 batch_size=self.args.local_bs, shuffle=False)
        return trainloader

    def update_weights(self, model):
        # Set mode to train model
        model.train()
        epoch_loss = []
        # Per-layer quantization statistic q; stays [0.0] for unquantized clients
        ansq = [0.0]
        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                        momentum=self.args.momentum, weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        for _ in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()
                # Weight quantization
                if self.quant is not None:
                    ansq = []
                    for name, p in model.named_parameters():
                        quant_p, q, data_q = qn.q_quantize(p.data, self.quantbit, self.mode, self.quantbit)
                        ansq.append(q.item())
                        p.data = quant_p.data
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss), max(ansq)

    def inference(self, model):
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                # Inference
                outputs = model(images)
                batch_loss = self.criterion(outputs, labels)
                loss += batch_loss.item()

                # Prediction
                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        accuracy = correct/total
        return accuracy, loss/total


def Evaluation(args, model, test_dataset):
    """ Returns the test accuracy and the mean per-batch test loss.
    """

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    device = 'cuda' if args.gpu else 'cpu'
    criterion = nn.NLLLoss().to(device)
    testloader = DataLoader(test_dataset, batch_size=128,
                            shuffle=False)
    cnt_loss = 0
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(testloader):
            images, labels = images.to(device), labels.to(device)

            # Inference
            outputs = model(images)
            batch_loss = criterion(outputs, labels)
            loss += batch_loss.item()

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
            cnt_loss += 1

    accuracy = correct/total
    return accuracy, loss/cnt_loss

--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import copy
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
from options import args_parser
import numpy as np
args = args_parser()


def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=transform_train)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=transform_test)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.unequal:
                # mnist_noniid_unequal assumes MNIST's 60,000 training images
                # (1200 shards X 50 imgs), so it cannot split CIFAR-10's 50,000
                raise NotImplementedError('Unequal non-IID splits are only implemented for MNIST')
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset == 'mnist':
        data_dir = '../data/mnist/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups


def get_quantization_bit(args):
    # The first num_users*bit_4_ratio clients use 4-bit quantization, the next
    # num_users*bit_8_ratio use 8-bit; any remaining clients are unquantized
    # (quant bit 0) and count as 64 bits for the 'Proportional' scheme
    quant_bit_for_user = np.zeros([args.num_users])
    avg_for_user = np.zeros([args.num_users])

    for user in range(args.num_users):
        if user < args.num_users * args.bit_4_ratio:
            quant_bit_for_user[user] = 4
            avg_for_user[user] = 4
        elif user < args.num_users * (args.bit_4_ratio + args.bit_8_ratio):
            quant_bit_for_user[user] = 8
            avg_for_user[user] = 8
        else:
            avg_for_user[user] = 64
    return quant_bit_for_user, avg_for_user


def average_weights(args, w, avg_for_user, q_for_user):
    """
    Returns the weighted average of the client weights w.
    The per-client coefficients p sum to 1 and depend on args.average_scheme:
    'FedAvg' uses equal weights, 'Proportional' weights by bit width, and
    'FedHQ' uses 1 / (1 + q), where q is the value returned by update_weights.
    """
    w_avg = copy.deepcopy(w[0])
    user_num = len(avg_for_user)
    # Calculate the p for each user, the sum of p is 1
    p_for_user = np.ones(user_num)
    if args.average_scheme == 'FedAvg':
        p_for_user /= sum(p_for_user)
    if args.average_scheme == 'Proportional':
        p_for_user = np.array(avg_for_user) / np.sum(avg_for_user)
    if args.average_scheme == 'FedHQ':
        p_for_user = 1 / (1 + q_for_user)
        p_for_user /= np.sum(p_for_user)
    for key in w_avg:
        w_avg[key] *= p_for_user[0]
    for i in range(1, len(w)):
        # Client state dicts are matched to w_avg by position, not by key name
        weight_name = []
        for j in w[i]:
            weight_name.append(j)
        cnt = 0
        for key in w_avg.keys():
            w_avg[key] += w[i][weight_name[cnt]] * p_for_user[i]
            cnt += 1
    return w_avg


def exp_details(args):
    print('\nExperimental details:')
    print(f' Dataset : {args.dataset}')
    model = 'CNN'
    if args.dataset == 'cifar' and args.iid == 1:
        model = 'ResNet18'
    if args.dataset == 'cifar' and args.iid == 0:
        model = 'VGG11'

    print(f' Model : {model}')
    print(f' Optimizer : {args.optimizer}')
    print(f' Learning rate : {args.lr}')
    print(f' Global Rounds : {args.epochs}\n')

    print(' Federated parameters:')
    if args.iid:
        print(' IID')
    else:
        print(' Non-IID')
    print(f' Fraction of users : {args.frac}')
    print(f' Local Batch size : {args.local_bs}')
    print(f' Local Epochs : {args.local_ep}\n')
    return

--------------------------------------------------------------------------------
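`average_weights` comes down to two steps: choose one coefficient per client, then take the weighted sum of the client parameters. The sketch below isolates those steps, using NumPy arrays as stand-ins for model state dicts. `aggregation_coefficients`, `weighted_average`, and the q values are hypothetical and exist only for illustration; they are not part of the repository. The sketch assumes q has already been computed per client. In the code above, q is the maximum of the per-layer q values that `qn.q_quantize` returns during local training.

```python
import numpy as np


def aggregation_coefficients(scheme, bit_widths, q_values):
    """Per-client coefficients p (summing to 1), following average_weights.

    scheme: 'FedAvg', 'Proportional' or 'FedHQ'
    bit_widths: per-client bit widths, like avg_for_user (4, 8 or 64)
    q_values: per-client quantization statistic q (hypothetical values here)
    """
    n = len(bit_widths)
    if scheme == 'FedAvg':
        p = np.ones(n)                                   # equal weights
    elif scheme == 'Proportional':
        p = np.asarray(bit_widths, dtype=float)          # weight by bit width
    elif scheme == 'FedHQ':
        p = 1.0 / (1.0 + np.asarray(q_values, dtype=float))  # larger q -> smaller weight
    else:
        raise ValueError(f'unknown scheme: {scheme}')
    return p / p.sum()


def weighted_average(client_params, p):
    """Average a list of {name: np.ndarray} dicts using coefficients p."""
    return {name: sum(p_i * params[name] for p_i, params in zip(p, client_params))
            for name in client_params[0]}


# Hypothetical round with three clients: two 4-bit, one 8-bit
bits = [4, 4, 8]
q = [0.5, 0.5, 0.1]
p = aggregation_coefficients('FedHQ', bits, q)
clients = [{'w': np.array([1.0, 1.0])},
           {'w': np.array([3.0, 3.0])},
           {'w': np.array([2.0, 2.0])}]
avg = weighted_average(clients, p)
print(p.round(3), avg['w'])
```

The sketch differs from `average_weights` in two deliberate ways. It matches parameters by name rather than by position. It also raises on an unknown scheme, whereas `average_weights` silently keeps unnormalized weights of 1.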
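`mnist_noniid` and `cifar_noniid` create non-IID splits in three steps. They sort the sample indices by label, cut them into equal shards, and give each client two random shards, so each client sees only a small number of classes. Below is a compact sketch of the same idea on synthetic labels, using NumPy only. `shard_partition` and its `seed` parameter are hypothetical names. Unlike the repository code, the sketch draws shards from a seeded `np.random.default_rng` generator and returns integer index arrays.

```python
import numpy as np


def shard_partition(labels, num_users, num_shards, shards_per_user=2, seed=0):
    """Sort indices by label, cut them into num_shards equal shards,
    and deal shards_per_user random shards to each client."""
    labels = np.asarray(labels)
    num_imgs = len(labels) // num_shards              # images per shard
    if num_users * shards_per_user > num_shards:
        raise ValueError('not enough shards for every client')
    rng = np.random.default_rng(seed)
    order = np.argsort(labels, kind='stable')         # indices sorted by label
    shard_ids = rng.permutation(num_shards)           # random shard order
    dict_users = {}
    for user in range(num_users):
        mine = shard_ids[user * shards_per_user:(user + 1) * shards_per_user]
        dict_users[user] = np.concatenate(
            [order[s * num_imgs:(s + 1) * num_imgs] for s in mine])
    return dict_users


# Synthetic stand-in for a labeled training set: 10 classes x 100 samples
labels = np.repeat(np.arange(10), 100)
parts = shard_partition(labels, num_users=10, num_shards=20)
for user, idxs in list(parts.items())[:3]:
    print(user, len(idxs), sorted(set(labels[idxs].tolist())))
```

Here every shard of 50 samples falls inside a single class, so no client holds more than two classes.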