├── README.md ├── celeba-gender-partitions.csv ├── celeba-gender-test.csv ├── celeba-gender-train.csv ├── celeba-gender-valid.csv ├── data └── README.md ├── environment.yml ├── loading_data.py ├── main_fed.py ├── main_fed_celeba.py ├── main_fed_scaling.py ├── main_nn.py ├── main_quan_celeba.py ├── model_para.py ├── models ├── Fed.py ├── Nets.py ├── README.md ├── ResNet.py ├── Scaling_Net.py ├── Scaling_update.py ├── Update.py ├── __init__.py ├── quan_resnet.py ├── quan_update.py └── test.py └── utils ├── README.md ├── __init__.py ├── options.py └── sampling.py /README.md: -------------------------------------------------------------------------------- 1 | # Aggregation Service for Federated Learning: An Efficient, Secure, and More Resilient Realization 2 | --- 3 | ## Code Description 4 | ## 1. Dataset 5 | ### 1.1 Dataset download 6 | This version of the code adds the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset. CelebA (short for CelebFaces Attributes) is a celebrity face-attributes dataset containing 202,599 face images of 10,177 identities; every image is annotated with a face bounding box, the coordinates of 5 facial landmarks, and 40 attribute labels. Released by the Chinese University of Hong Kong, CelebA is widely used in face-related computer-vision training tasks such as face-attribute recognition, face detection, and landmark annotation. It is also part of [LEAF](https://leaf.cmu.edu/), a standard federated-learning (FL) benchmark suite. The dataset is too large to be fetched through PyTorch's built-in dataset wrappers, so obtain it from one of the following sources: 7 | > [1. My Google Drive](https://drive.google.com/drive/folders/1XegONA2EQzPPO5h-0sUp-hbzGNvyXHTe?usp=sharing) (recommended) 8 | > 9 | > [2. Google Drive](https://drive.google.com/open?id=0B7EVK8r0v71pWEZsZE9oNnFzTm8) 10 | > 11 | > [3. Baidu Drive](https://pan.baidu.com/s/1CRxxhoQ97A5qbsKO7iaAJg) 12 | 13 | Note: this project uses the data for a gender-classification task, so run *loading_data.py* first to generate the label files: 14 | ```bash= 15 | python loading_data.py 16 | ``` 17 | 18 | --- 19 | ### 1.2 Dataset preview 20 | ![](https://codimd.xixiaoyao.cn/uploads/upload_7a764e819d2d6efd7b447224bdbc7ca4.png) 21 | For more details of the dataset, please refer to the paper ["Deep Learning Face Attributes in the Wild".](https://liuziwei7.github.io/projects/FaceAttributes.html) 22 | 23 | --- 24 | 25 | ### 1.3 Dataset citation 26 | ``` 27 | @inproceedings{liu2015faceattributes, 28 | title = {Deep Learning Face Attributes in the Wild}, 29 | author = {Liu, Ziwei and Luo, Ping and Wang, Xiaogang and Tang, Xiaoou}, 30 | booktitle = {Proceedings of International Conference on Computer Vision (ICCV)}, 31 | month = {December}, 32 | year = {2015} 33 | } 34 | ``` 35 | 36 | 37 | ## 2. Running the code 38 | - Environment setup 39 | ```bash= 40 | conda env create -f environment.yml 41 | ``` 42 | Note: check the NVIDIA driver version on your server and make sure the environment exported from our server can actually be installed there; this is important. 43 | ### 2.1 Baseline 44 | - Baseline: no scaling and no quantization. 45 | ```bash= 46 | python main_fed.py --dataset [dataset_name] --iid --num_channels 1 --model cnn --epochs 200 --gpu [GPU_number] 47 | ``` 48 | Note: main_fed.py does not include the newly added CelebA dataset. **To run CIFAR-10, set num_channels=3 and use the AlexNet model.** MNIST and CIFAR-10 can be trained without a GPU. 49 | 50 | - CelebA dataset 51 | ```bash= 52 | python main_fed_celeba.py --dataset celeba --iid --num_channels 3 --model cnn --epochs 200 --gpu 1 53 | ``` 54 | Note: **this script requires a GPU and takes more than ten hours to run.** If you run it on a server, it is recommended to launch it with nohup. 55 | 56 | --- 57 | 58 | ### 2.2 Scaling scheme 59 | - Scaling scheme 60 | ```bash= 61 | python main_fed_scaling.py --dataset [dataset_name] --iid --num_channels 1 --model cnn --epochs 200 --scaling_factor 10 --gpu [GPU_number] 62 | ``` 63 | 64 | ---
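The scaling scheme itself is implemented by the `double`/`half` methods in *models/Scaling_Net.py*: before aggregation every parameter is multiplied by `--scaling_factor` and rounded to an integer, and after aggregation the averaged parameters are divided back down. A minimal standalone sketch of that round trip (the two-layer model here is only for illustration):

```python
import torch
from torch import nn

scaling_factor = 10  # corresponds to --scaling_factor 10

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# scale up and round, as in Scaling_Net's double()
for _, param in net.named_parameters():
    param.data = torch.round(param.data * scaling_factor)

# ... the integer-valued weights are aggregated here ...

# scale back down, as in Scaling_Net's half()
for _, param in net.named_parameters():
    param.data = (param.data / scaling_factor).type(torch.float)
```

Rounding costs at most 0.5/scaling_factor of absolute error per weight, so a larger factor preserves more precision; the point of the integer representation is, presumably, that the masking-based aggregation service operates on integers.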
65 | 66 | ### 2.3 Quan scheme 67 | - Quan-16bit scheme 68 | ```bash= 69 | python main_quan_celeba.py --dataset [dataset_name] --iid --num_channels 1 --model cnn --epochs 200 --bit_width 16 --alpha [0.5 or 0.1] --gpu [GPU_number] 70 | ``` 71 | - Quan-8bit scheme 72 | ```bash= 73 | python main_quan_celeba.py --dataset [dataset_name] --iid --num_channels 1 --model cnn --epochs 200 --bit_width 8 --alpha [0.5 or 0.1] --gpu [GPU_number] 74 | ```
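Note that *main_quan_celeba.py* only instantiates the quantized QCNNCeleba model, whose `update`/`unquant`/`add` hooks the training loop calls, when `--model qcnn` is combined with `--dataset celeba`. The quantizer itself lives in *models/quan_update.py* and *models/quan_resnet.py*, which are not reproduced in this listing, so the following is only an illustrative sketch of symmetric uniform quantization to `bit_width` bits; reading `alpha` as a clipping fraction of the weight range is an assumption, not the repo's documented semantics. For scale: at 32-bit floats one update of the 14,028,106-parameter model from section 3 is about 53.5 MB, so 16-bit roughly halves it (about 26.8 MB) and 8-bit quarters it (about 13.4 MB).

```python
import torch

def quantize(w: torch.Tensor, bit_width: int, alpha: float):
    """Illustrative uniform quantizer, not the repo's exact scheme.

    Clips weights to alpha * max|w| (assumed meaning of --alpha) and
    maps them onto signed bit_width-bit integer levels.
    """
    clip = float(alpha * w.abs().max())
    scale = clip / (2 ** (bit_width - 1) - 1)
    q = torch.clamp(w, -clip, clip).div(scale).round()
    return q, scale

def dequantize(q: torch.Tensor, scale: float) -> torch.Tensor:
    return q * scale

w = torch.randn(1000)
q, scale = quantize(w, bit_width=8, alpha=0.5)
print((w - dequantize(q, scale)).abs().max())  # worst-case reconstruction error
```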
75 | 76 | --- 77 | ## 3. Model structure, parameter count, and size 78 | ```bash= 79 | python model_para.py 80 | ``` 81 | The output is as follows: 82 | ``` 83 | Total params: 14,028,106 84 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- 85 | Total memory: 23.34MB 86 | Total MAdd: 1.54GMAdd 87 | Total Flops: 770.57MFlops 88 | Total MemR+W: 100.51MB 89 | 90 | resnet18 has 14028106 parameters in total 91 | ``` 92 | --- 93 | 94 | # Reference 95 | ``` 96 | Shaoxiong Ji. (2018, March 30). A PyTorch Implementation of Federated Learning. Zenodo. http://doi.org/10.5281/zenodo.4321561 97 | ``` 98 | 99 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # dataset -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: fl 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - blas=1.0=openblas 7 | - ca-certificates=2021.5.25=h06a4308_1 8 | - certifi=2021.5.30=py36h06a4308_0 9 | - cycler=0.10.0=py36_0 10 | - dbus=1.13.18=hb2f20db_0 11 | - decorator=5.0.9=pyhd3eb1b0_0 12 | - expat=2.4.1=h2531618_2 13 | - fontconfig=2.13.1=h6c09931_0 14 | - freetype=2.10.4=h5ab3b9f_0 15 | - glib=2.68.2=h36276a3_0 16 | - gmp=6.2.1=h2531618_2 17 | - gmpy2=2.0.8=py36h10f8cd9_2 18 | - gst-plugins-base=1.14.0=h8213a91_2 19 | - gstreamer=1.14.0=h28cd5cc_2 20 | - icu=58.2=he6710b0_3 21 | - ipykernel=5.3.4=py36h5ca1d4c_0 22 | - ipython=6.1.0=py36_0 23 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 24 | - jedi=0.18.0=py36h06a4308_1 25 | - joblib=1.0.1=pyhd3eb1b0_0 26 | - jpeg=9b=h024ee3a_2 27 | - jupyter_client=6.1.12=pyhd3eb1b0_0 28 | - jupyter_core=4.7.1=py36h06a4308_0 29 | - kiwisolver=1.3.1=py36h2531618_0 30 | - lcms2=2.12=h3be6417_0 31 | - ld_impl_linux-64=2.33.1=h53a641e_7 32 | - libffi=3.3=he6710b0_2 33 | - libgcc-ng=9.1.0=hdf63c60_0 34 | - libgfortran-ng=7.3.0=hdf63c60_0 35 | - libopenblas=0.3.13=h4367d64_0 36 | - libpng=1.6.37=hbc83047_0 37 | - libsodium=1.0.18=h7b6447c_0 38 | - libstdcxx-ng=9.1.0=hdf63c60_0 39 | - libtiff=4.2.0=h85742a9_0 40 | - libuuid=1.0.3=h1bed415_2 41 | - libwebp-base=1.2.0=h27cfd23_0 42 | - libxcb=1.14=h7b6447c_0 43 | - libxml2=2.9.10=hb55368b_3 44 | - lz4-c=1.9.3=h2531618_0 45 | - matplotlib=3.3.4=py36h06a4308_0 46 | - matplotlib-base=3.3.4=py36h62a2d02_0 47 | - mpc=1.1.0=h10f8cd9_1 48 | - mpfr=4.0.2=hb69a4c5_1 49 | - ncurses=6.2=he6710b0_1 50 | - numpy=1.17.0=py36h99e49ec_0 51 | - numpy-base=1.17.0=py36h2f8d375_0 52 | - olefile=0.46=py36_0 53 | - openssl=1.1.1k=h27cfd23_0 54 | - pandas=0.20.3=py36_0 55 | - parso=0.8.2=pyhd3eb1b0_0 56 | - pcre=8.44=he6710b0_0 57 | - pexpect=4.8.0=pyhd3eb1b0_3 58 | - pickleshare=0.7.5=pyhd3eb1b0_1003 59 | - pillow=8.2.0=py36he98fc37_0 60 | - pip=21.1.1=py36h06a4308_0 61 | - prompt_toolkit=1.0.15=py36_0 62 | - ptyprocess=0.7.0=pyhd3eb1b0_2 63 | - pygments=2.9.0=pyhd3eb1b0_0 64 | - pyparsing=2.4.7=pyhd3eb1b0_0 65 | - pyqt=5.9.2=py36h05f1152_2 66 | - python=3.6.13=hdb3f193_0 67 | - python-dateutil=2.8.1=pyhd3eb1b0_0 68 | - pytz=2021.1=pyhd3eb1b0_0 69 | - pyzmq=20.0.0=py36h2531618_1 70 | - qt=5.9.7=h5867ecd_1 71 | - readline=8.1=h27cfd23_0 72 | - scikit-learn=0.24.2=py36ha9443f7_0 73 | - scipy=1.5.2=py36habc2bb6_0 74 | - setuptools=52.0.0=py36h06a4308_0 75 | - simplegeneric=0.8.1=py36_2 76 | - sip=4.19.8=py36hf484d3e_0 77 | - six=1.15.0=py36h06a4308_0 78 | - sqlite=3.35.4=hdfb4753_0 79 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 80 | - tk=8.6.10=hbc83047_0 81 | - tornado=6.1=py36h27cfd23_0 82 | - traitlets=4.3.3=py36_0 83 | - wcwidth=0.2.5=py_0 84 | - wheel=0.36.2=pyhd3eb1b0_0 85 | - xz=5.2.5=h7b6447c_0 86 | - zeromq=4.3.4=h2531618_0 87 | - zlib=1.2.11=h7b6447c_3 88 | - zstd=1.4.9=haebb681_0 89 | - pip: 90 | - future==0.18.2 91 | - torch==1.6.0+cu101 92 | - torchstat==0.0.7 93 | - torchvision==0.7.0+cu101 94 | prefix: /home/shuyu/anaconda3/envs/fl -------------------------------------------------------------------------------- /loading_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | from torch.utils.data import Dataset 5 | 6 | from PIL import Image 7 | 8 | 9 | 10 | def data_processing(): 11 | df1 = pd.read_csv('list_attr_celeba.txt', sep="\s+", skiprows=1, usecols=['Male']) 12 | df1.loc[df1['Male'] == -1, 'Male'] = 0  # remap the raw -1/1 gender labels to 0/1 13 | 14 | df2 = pd.read_csv('list_eval_partition.txt', sep="\s+", skiprows=0, header=None) 15 | df2.columns = ['Filename', 'Partition'] 16 | df2 = df2.set_index('Filename') 17 | 18 | df3 = df1.merge(df2, left_index=True, right_index=True) 19 | df3.head() 20 | 21 | 22 | df3.to_csv('celeba-gender-partitions.csv') 23 | df4 = pd.read_csv('celeba-gender-partitions.csv', index_col=0) 24 | df4.head() 25 | 26 | 27 | df4.loc[df4['Partition'] == 0].to_csv('celeba-gender-train.csv')  # partitions 0/1/2 are the official train/valid/test split 28 | df4.loc[df4['Partition'] == 1].to_csv('celeba-gender-valid.csv') 29 | df4.loc[df4['Partition'] == 2].to_csv('celeba-gender-test.csv') 30 | 31 | 32 | 33 | class CelebaDataset(Dataset): 34 | """Custom Dataset for loading CelebA face images""" 35 | 36 | def __init__(self, csv_path, img_dir, transform=None): 37 | df = pd.read_csv(csv_path, index_col=0) 38 | self.img_dir = img_dir 39 | self.csv_path = csv_path 40 | self.img_names = df.index.values 41 | self.y = df['Male'].values 42 | self.transform = transform 43 | 44 | def __getitem__(self, index): 45 | img = Image.open(os.path.join(self.img_dir, 46 | self.img_names[index])) 47 | 48 | if self.transform is not None: 49 | img = self.transform(img) 50 | 51 | label = self.y[index] 52 | return img, label 53 | 54 | def __len__(self): 55 | return self.y.shape[0] 56 | 57 | if __name__ == '__main__': 58 | data_processing()
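Once the CSV label files exist, `CelebaDataset` can be used directly with a `DataLoader`. A minimal usage sketch, assuming the images were unpacked to data/CelebA/img_align_celeba/ and using the same transform pipeline as the training scripts:

```python
from torch.utils.data import DataLoader
from torchvision import transforms
import loading_data as dataset

transform = transforms.Compose([transforms.CenterCrop((178, 178)),
                                transforms.Resize((128, 128)),
                                transforms.ToTensor()])

train_set = dataset.CelebaDataset(csv_path='celeba-gender-train.csv',
                                  img_dir='data/CelebA/img_align_celeba/',
                                  transform=transform)
loader = DataLoader(train_set, batch_size=64, shuffle=True)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([64, 3, 128, 128])
```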
-------------------------------------------------------------------------------- /main_fed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | import copy 10 | import numpy as np 11 | from torchvision import datasets, transforms 12 | import torch 13 | 14 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, fmnist_iid 15 | from utils.options import args_parser 16 | from models.Update import LocalUpdate 17 | from models.Nets import MLP, CNNMnist, CNNCifar, AlexNetCifar, AlexNetMnist, CNNCeleba 18 | from models.Fed import FedAvg 19 | from models.test import test_img 20 | from service.client import * 21 | from service.utils.DH import * 22 | 23 | if __name__ == '__main__': 24 | # parse args 25 | #print('delta quantization') 26 | 27 | args = args_parser() 28 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 29 | 30 | # load dataset and split users 31 | if args.dataset == 'mnist': 32 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 33 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) 34 | dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist) 35 | # sample users 36 | if args.iid: 37 | dict_users = mnist_iid(dataset_train, args.num_users) 38 | else: 39 | dict_users = mnist_noniid(dataset_train, args.num_users) 40 | elif args.dataset == 'fmnist': 41 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 42 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True, transform=trans_mnist) 43 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True, transform=trans_mnist) 44 | # sample users 45 | if args.iid: 46 | dict_users = fmnist_iid(dataset_train, args.num_users) 47 | else: 48 | exit('Error: only the IID setting is considered for FashionMNIST') 49 | elif args.dataset == 'cifar': 50 | trans_cifar = transforms.Compose( 51 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 52 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar) 53 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar) 54 | if args.iid: 55 | dict_users = cifar_iid(dataset_train, args.num_users) 56 | else: 57 | exit('Error: only the IID setting is considered for CIFAR-10') 58 | else: 59 | exit('Error: unrecognized dataset') 60 | img_size = dataset_train[0][0].shape 61 | 62 | # initialise aggregate service 63 | p = random_prime(2048) 64 | g = get_generator(p) 65 | clients = [] 66 | for i in range(args.num_users): 67 | client = Client(i, p, g, args.bit_width) 68 | clients.append(client) 69 | # set bit width 70 | clients[0].set_bit_width() 71 | 72 | # fetch the public key list and compute K 73 | for i in range(args.num_users): 74 | clients[i].generate_k() 75 | 
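# NOTE: service/client.py and service/utils/DH.py are not included in this
# listing. From the calls above, each Client appears to hold a Diffie-Hellman
# key pair modulo the shared prime p with generator g, and generate_k()
# presumably derives pairwise shared secrets from the other clients' public
# keys. Those secrets would seed the pairwise masks applied by
# Client.masking() in the training loop below (currently commented out),
# which cancel out when the aggregation server sums the masked updates.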
76 | # build model 77 | if args.model == 'cnn' and args.dataset == 'cifar': 78 | net_glob = CNNCifar(args=args).to(args.device) 79 | elif args.model == 'alexnet' and args.dataset == 'cifar': 80 | net_glob = AlexNetCifar(args=args).to(args.device) 81 | elif args.model == 'cnn' and args.dataset == 'mnist': 82 | net_glob = CNNMnist(args=args).to(args.device) 83 | elif args.model == 'alexnet' and args.dataset == 'mnist': 84 | net_glob = AlexNetMnist(args=args).to(args.device) 85 | elif args.model == 'cnn' and args.dataset == 'fmnist': 86 | net_glob = CNNMnist(args=args).to(args.device) 87 | elif args.model == 'alexnet' and args.dataset == 'fmnist': 88 | net_glob = AlexNetMnist(args=args).to(args.device) 89 | elif args.model == 'cnn' and args.dataset == 'celeba': 90 | net_glob = CNNCeleba(args=args).to(args.device) 91 | elif args.model == 'mlp': 92 | len_in = 1 93 | for x in img_size: 94 | len_in *= x 95 | net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device) 96 | else: 97 | exit('Error: unrecognized model') 98 | print(net_glob) 99 | net_glob.train() 100 | 101 | # copy weights 102 | w_glob = net_glob.state_dict() 103 | 104 | # training 105 | loss_train = [] 106 | cv_loss, cv_acc = [], [] 107 | val_loss_pre, counter = 0, 0 108 | net_best = None 109 | best_loss = None 110 | val_acc_list, net_list = [], [] 111 | res = [] 112 | ogs = [] 113 | for iter in range(args.epochs): 114 | net_glob.update() 115 | w_locals, loss_locals, w_locals_masked = [], [], [] 116 | m = max(int(args.frac * args.num_users), 1) 117 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 118 | for idx in idxs_users: 119 | local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx]) 120 | w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device)) 121 | w_locals.append(copy.deepcopy(w)) 122 | # mask the local w 123 | # w_masked = clients[idx].masking(w, idxs_users, iter) 124 | # w_locals_masked.append(copy.deepcopy(w_masked)) 125 | loss_locals.append(copy.deepcopy(loss)) 126 | # update global weights 127 | #w_glob_unmasked = clients[0].done(w_locals_masked) 128 | # reshape the result from the server 129 | # for k in w_glob.keys(): 130 | # w_glob[k] = torch.tensor(w_glob_unmasked[k]).reshape(w_glob[k].shape) 131 | w_glob = FedAvg(w_locals) # average the clients' models 132 | 133 | # copy weight to net_glob 134 | net_glob.load_state_dict(w_glob) # load the averaged parameters into the global model 135 | net_glob.unquant(args) # unquantize the global model 136 | # net_glob.to(args.device) 137 | # print() 138 | net_glob.add() 139 | # print() 140 | # print loss 141 | loss_avg = sum(loss_locals) / len(loss_locals) 142 | print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg)) 143 | loss_train.append(loss_avg) 144 | 145 | # plot loss curve 146 | plt.figure() 147 | plt.plot(range(len(loss_train)), loss_train) 148 | plt.ylabel('train_loss') 149 | plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid)) 150 | 151 | # testing 152 | net_glob.eval() 153 | acc_train, loss_train = test_img(net_glob, dataset_train, args) 154 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 155 | print("Training accuracy: {:.2f}".format(acc_train)) 156 | print("Testing accuracy: {:.2f}".format(acc_test)) 157 | 
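The masking path in the loop above is commented out and service/client.py is not part of this listing, so the following is only a schematic illustration of pairwise additive masking, the standard trick behind such aggregation services: each client adds a pseudorandom mask derived from each pairwise DH secret, with opposite signs at the two endpoints, so every mask cancels in the server's sum. All names here are hypothetical.

```python
import torch

def pairwise_mask(w: torch.Tensor, my_id: int, peer_ids, seeds) -> torch.Tensor:
    """Schematic masking: seeds[j] is the DH-derived shared secret with peer j."""
    masked = w.clone()
    for j in peer_ids:
        if j == my_id:
            continue
        g = torch.Generator().manual_seed(seeds[j])
        mask = torch.randn(w.shape, generator=g)
        # opposite signs at the two endpoints, so masks cancel in the sum
        masked += mask if my_id < j else -mask
    return masked

# two clients sharing one secret: the sum of masked updates equals the true sum
seeds = {0: {1: 42}, 1: {0: 42}}
w0, w1 = torch.randn(5), torch.randn(5)
m0 = pairwise_mask(w0, 0, [0, 1], seeds[0])
m1 = pairwise_mask(w1, 1, [0, 1], seeds[1])
print(torch.allclose(m0 + m1, w0 + w1))  # True
```

In the real scheme the arithmetic would be modular over integer-valued weights, which is presumably why the scaling and quantization schemes round the weights to integers before aggregation.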
-------------------------------------------------------------------------------- /main_fed_celeba.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | import copy 10 | import numpy as np 11 | from torchvision import datasets, transforms 12 | from torch.utils.data import DataLoader 13 | import torch 14 | from torchstat import stat 15 | from models import Nets 16 | 17 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, fmnist_iid, celeba_iid 18 | from utils.options import args_parser 19 | from models.Update import LocalUpdate 20 | from models.Nets import MLP, CNNMnist, CNNCifar, CNNFMnist, CNNCeleba 21 | from models.Fed import FedAvg 22 | from models.test import test_img 23 | import loading_data as dataset 24 | from models.Nets import Bottleneck 25 | #from service.client import * 26 | #from service.utils.DH import * 27 | 28 | 29 | def test(): 30 | net_glob.eval() 31 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 32 | print("Testing accuracy: {:.2f}".format(acc_test)) 33 | 34 | 35 | if __name__ == '__main__': 36 | # parse args 37 | #print('delta quantization') 38 | 39 | args = args_parser() 40 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 41 | 42 | # load dataset and split users 43 | if args.dataset == 'mnist': 44 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 45 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) 46 | dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist) 47 | # sample users 48 | if args.iid: 49 | dict_users = mnist_iid(dataset_train, args.num_users) 50 | else: 51 | dict_users = mnist_noniid(dataset_train, args.num_users) 52 | elif args.dataset == 'fmnist': 53 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 54 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True, transform=trans_mnist) 55 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True, transform=trans_mnist) 56 | # sample users 57 | if args.iid: 58 | dict_users = fmnist_iid(dataset_train, args.num_users) 59 | else: 60 | exit('Error: only the IID setting is considered for FashionMNIST') 61 | elif args.dataset == 'celeba': 62 | custom_transform = transforms.Compose([transforms.CenterCrop((178, 178)), 63 | transforms.Resize((128, 128)), 64 | # transforms.Grayscale(), 65 | # transforms.Lambda(lambda x: x/255.), 66 | transforms.ToTensor()]) 67 | 68 | dataset_train = dataset.CelebaDataset(csv_path='celeba-gender-train.csv', 69 | img_dir='data/CelebA/img_align_celeba/', 70 | transform=custom_transform) 71 | 72 | valid_celeba = dataset.CelebaDataset(csv_path='celeba-gender-valid.csv', 73 | img_dir='data/CelebA/img_align_celeba/', 74 | transform=custom_transform) 75 | 76 | dataset_test = dataset.CelebaDataset(csv_path='celeba-gender-test.csv', 77 | img_dir='data/CelebA/img_align_celeba/', 78 | transform=custom_transform) 79 | 80 | if args.iid: 81 | dict_users = celeba_iid(dataset_train, args.num_users) 82 | else: 83 | exit('Error: only the IID setting is considered for CelebA') 84 | elif args.dataset == 'cifar': 85 | trans_cifar = transforms.Compose( 86 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 87 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar) 88 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar) 89 | if args.iid: 90 | dict_users = cifar_iid(dataset_train, args.num_users) 91 | else: 92 | exit('Error: only the IID setting is considered for CIFAR-10') 93 | else: 94 | exit('Error: unrecognized dataset') 95 | img_size = dataset_train[0][0].shape 96 | 97 | # initialise aggregate service 98 | #p = random_prime(2048) 99 | #g = get_generator(p) 100 | #clients = [] 101 | #for i in range(args.num_users): 102 | # client = Client(i, p, g, args.bit_width) 103 | # clients.append(client) 104 | # set bit width 105 | #clients[0].set_bit_width() 106 | 107 | # fetch the public key list and compute K 108 | #for i in range(args.num_users): 109 | # clients[i].generate_k() 110 | 111 | # build model 112 | if args.model == 'cnn' and args.dataset == 'cifar': 113 | net_glob = CNNCifar(args=args).to(args.device) 114 | elif args.model == 'cnn' and args.dataset == 'mnist': 115 | net_glob = CNNMnist(args=args).to(args.device) 116 | elif args.model == 'cnn' and args.dataset == 'fmnist': 117 | net_glob = CNNFMnist(args=args).to(args.device) 118 | elif args.model == 'cnn' and args.dataset == 'celeba': 119 | net_glob = CNNCeleba(args=args).to(args.device) 120 | elif args.model == 'mlp': 121 | len_in = 1 122 | for x in img_size: 123 | len_in *= x 124 | net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device) 125 | else: 126 | exit('Error: unrecognized model') 127 | 128 | #print(net_glob) 129 | net_glob.train() 130 | 131 | # copy weights 132 | w_glob = net_glob.state_dict()
133 | 134 | # training 135 | loss_train = [] 136 | cv_loss, cv_acc = [], [] 137 | val_loss_pre, counter = 0, 0 138 | net_best = None 139 | best_loss = None 140 | val_acc_list, net_list = [], [] 141 | 142 | if args.all_clients: 143 | print("Aggregation over all clients") 144 | w_locals = [w_glob for i in range(args.num_users)] 145 | for iter in range(args.epochs): 146 | loss_locals = [] 147 | if not args.all_clients: 148 | w_locals = [] 149 | m = max(int(args.frac * args.num_users), 1) 150 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 151 | for idx in idxs_users: 152 | local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx]) 153 | w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device)) 154 | if args.all_clients: 155 | w_locals[idx] = copy.deepcopy(w) 156 | else: 157 | w_locals.append(copy.deepcopy(w)) 158 | loss_locals.append(copy.deepcopy(loss)) 159 | # update global weights 160 | w_glob = FedAvg(w_locals) 161 | 162 | # copy weight to net_glob 163 | net_glob.load_state_dict(w_glob) 164 | 165 | # print loss 166 | loss_avg = sum(loss_locals) / len(loss_locals) 167 | #print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg)) 168 | loss_train.append(loss_avg) 169 | test() 170 | 171 | # plot loss curve 172 | plt.figure() 173 | plt.plot(range(len(loss_train)), loss_train) 174 | plt.ylabel('train_loss') 175 | plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid)) 176 | 177 | # testing 178 | net_glob.eval() 179 | acc_train, loss_train = test_img(net_glob, dataset_train, args) 180 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 181 | print("Training accuracy: {:.2f}".format(acc_train)) 182 | print("Testing accuracy: {:.2f}".format(acc_test)) 183 | 184 | -------------------------------------------------------------------------------- /main_fed_scaling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import copy 9 | import numpy as np 10 | from torchvision import datasets, transforms 11 | import torch 12 | 13 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid 14 | from utils.options import args_parser 15 | from models.Scaling_update import LocalUpdate 16 | from models.Scaling_Net import MLP, CNNMnist, CNNCifar, AlexNetMnist, AlexNetCifar, CNNCeleba 17 | from models.Fed import FedAvg 18 | from models.test import test_img 19 | import loading_data as dataset 20 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, fmnist_iid, celeba_iid 21 | 22 | 23 | def test(): 24 | net_glob.eval() 25 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 26 | print("Testing accuracy: {:.2f}".format(acc_test)) 27 | 28 | 29 | 30 | 31 | if __name__ == '__main__': 32 | # parse args 33 | args = args_parser() 34 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 35 | 36 | # load dataset and split users 37 | if args.dataset == 'mnist': 38 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 39 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) 40 | dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist) 41 | # sample users 42 | if args.iid: 43 | 
dict_users = mnist_iid(dataset_train, args.num_users) 44 | else: 45 | dict_users = mnist_noniid(dataset_train, args.num_users) 46 | elif args.dataset == 'FashionMnist': 47 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 48 | dataset_train = datasets.FashionMNIST('../data/FashionMnist/', train=True, download=True, transform=trans_mnist) 49 | dataset_test = datasets.FashionMNIST('../data/FashionMnist/', train=False, download=True, transform=trans_mnist) 50 | # sample users 51 | if args.iid: 52 | dict_users = mnist_iid(dataset_train, args.num_users) 53 | else: 54 | dict_users = mnist_noniid(dataset_train, args.num_users) 55 | elif args.dataset == 'cifar': 56 | trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 57 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar) 58 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar) 59 | if args.iid: 60 | dict_users = cifar_iid(dataset_train, args.num_users) 61 | else: 62 | exit('Error: only the IID setting is considered for CIFAR-10') 63 | elif args.dataset == 'celeba': 64 | custom_transform = transforms.Compose([transforms.CenterCrop((178, 178)), 65 | transforms.Resize((128, 128)), 66 | # transforms.Grayscale(), 67 | # transforms.Lambda(lambda x: x/255.), 68 | transforms.ToTensor()]) 69 | 70 | dataset_train = dataset.CelebaDataset(csv_path='celeba-gender-train.csv', 71 | img_dir='data/CelebA/img_align_celeba/', 72 | transform=custom_transform) 73 | 74 | valid_celeba = dataset.CelebaDataset(csv_path='celeba-gender-valid.csv', 75 | img_dir='data/CelebA/img_align_celeba/', 76 | transform=custom_transform) 77 | 78 | dataset_test = dataset.CelebaDataset(csv_path='celeba-gender-test.csv', 79 | img_dir='data/CelebA/img_align_celeba/', 80 | transform=custom_transform) 81 | 82 | if args.iid: 83 | dict_users = celeba_iid(dataset_train, args.num_users) 84 | else: 85 | exit('Error: only the IID setting is considered for CelebA') 86 | else: 87 | exit('Error: unrecognized dataset') 88 | img_size = dataset_train[0][0].shape 89 | 90 | # build model 91 | if args.model == 'cnn' and args.dataset == 'cifar': 92 | net_glob = CNNCifar(args=args).to(args.device) 93 | elif args.model == 'alexnet' and args.dataset == 'cifar': 94 | net_glob = AlexNetCifar(args=args).to(args.device) 95 | elif args.model == 'cnn' and args.dataset == 'mnist': 96 | net_glob = CNNMnist(args=args).to(args.device) 97 | elif args.model == 'alexnet' and args.dataset == 'mnist': 98 | net_glob = AlexNetMnist(args=args).to(args.device) 99 | elif args.model == 'cnn' and args.dataset == 'FashionMnist': 100 | net_glob = CNNMnist(args=args).to(args.device) 101 | elif args.model == 'alexnet' and args.dataset == 'FashionMnist': 102 | net_glob = AlexNetMnist(args=args).to(args.device) 103 | elif args.model == 'cnn' and args.dataset == 'celeba': 104 | net_glob = CNNCeleba(args=args).to(args.device) 105 | elif args.model == 'mlp': 106 | len_in = 1 107 | for x in img_size: 108 | len_in *= x 109 | net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device) 110 | else: 111 | exit('Error: unrecognized model') 112 | print(net_glob) 113 | net_glob.train() 114 | 115 | # copy weights 116 | w_glob = net_glob.state_dict() 117 | 118 | # training 119 | loss_train = [] 120 | cv_loss, cv_acc = [], [] 121 | val_loss_pre, counter = 0, 0 122 | net_best = None 123 | best_loss = None 124 | val_acc_list, net_list = 
[], [] 125 | 126 | for iter in range(args.epochs): 127 | w_locals, loss_locals = [], [] 128 | m = max(int(args.frac * args.num_users), 1) 129 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 130 | for idx in idxs_users: 131 | local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx]) 132 | w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device)) 133 | w_locals.append(copy.deepcopy(w)) 134 | loss_locals.append(copy.deepcopy(loss)) 135 | # update global weights 136 | w_glob = FedAvg(w_locals) 137 | 138 | # copy weight to net_glob 139 | net_glob.load_state_dict(w_glob) 140 | net_glob.half(args=args) 141 | # print loss 142 | loss_avg = sum(loss_locals) / len(loss_locals) 143 | #print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg)) 144 | # loss_train.append(loss_avg) 145 | loss_train.append(loss_avg) 146 | test() 147 | 148 | #plot loss curve 149 | plt.figure() 150 | plt.plot(range(len(loss_train)), loss_train) 151 | plt.ylabel('train_loss') 152 | plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid)) 153 | 154 | # testing 155 | net_glob.eval() 156 | acc_train, loss_train = test_img(net_glob, dataset_train, args) 157 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 158 | print("Training accuracy: {:.2f}".format(acc_train)) 159 | print("Testing accuracy: {:.2f}".format(acc_test)) 160 | print() 161 | 162 | -------------------------------------------------------------------------------- /main_nn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | import torch.optim as optim 13 | from torchvision import datasets, transforms 14 | 15 | from utils.options import args_parser 16 | from models.Nets import MLP, CNNMnist, CNNCifar 17 | 18 | 19 | def test(net_g, data_loader): 20 | # testing 21 | net_g.eval() 22 | test_loss = 0 23 | correct = 0 24 | l = len(data_loader) 25 | for idx, (data, target) in enumerate(data_loader): 26 | data, target = data.to(args.device), target.to(args.device) 27 | log_probs = net_g(data) 28 | test_loss += F.cross_entropy(log_probs, target).item() 29 | y_pred = log_probs.data.max(1, keepdim=True)[1] 30 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 31 | 32 | test_loss /= len(data_loader.dataset) 33 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format( 34 | test_loss, correct, len(data_loader.dataset), 35 | 100. 
* correct / len(data_loader.dataset))) 36 | 37 | return correct, test_loss 38 | 39 | 40 | if __name__ == '__main__': 41 | # parse args 42 | args = args_parser() 43 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 44 | 45 | torch.manual_seed(args.seed) 46 | 47 | # load dataset and split users 48 | if args.dataset == 'mnist': 49 | dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True, 50 | transform=transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.1307,), (0.3081,)) 53 | ])) 54 | img_size = dataset_train[0][0].shape 55 | elif args.dataset == 'cifar': 56 | transform = transforms.Compose( 57 | [transforms.ToTensor(), 58 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 59 | dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True) 60 | img_size = dataset_train[0][0].shape 61 | else: 62 | exit('Error: unrecognized dataset') 63 | 64 | # build model 65 | if args.model == 'cnn' and args.dataset == 'cifar': 66 | net_glob = CNNCifar(args=args).to(args.device) 67 | elif args.model == 'cnn' and args.dataset == 'mnist': 68 | net_glob = CNNMnist(args=args).to(args.device) 69 | elif args.model == 'mlp': 70 | len_in = 1 71 | for x in img_size: 72 | len_in *= x 73 | net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device) 74 | else: 75 | exit('Error: unrecognized model') 76 | print(net_glob) 77 | 78 | # training 79 | optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum) 80 | train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True) 81 | 82 | list_loss = [] 83 | net_glob.train() 84 | for epoch in range(args.epochs): 85 | batch_loss = [] 86 | for batch_idx, (data, target) in enumerate(train_loader): 87 | data, target = data.to(args.device), target.to(args.device) 88 | optimizer.zero_grad() 89 | output = net_glob(data) 90 | loss = F.cross_entropy(output, target) 91 | loss.backward() 92 | optimizer.step() 93 | if batch_idx % 50 == 0: 94 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 95 | epoch, batch_idx * len(data), len(train_loader.dataset), 96 | 100. 
* batch_idx / len(train_loader), loss.item())) 97 | batch_loss.append(loss.item()) 98 | loss_avg = sum(batch_loss)/len(batch_loss) 99 | print('\nTrain loss:', loss_avg) 100 | list_loss.append(loss_avg) 101 | 102 | # plot loss 103 | plt.figure() 104 | plt.plot(range(len(list_loss)), list_loss) 105 | plt.xlabel('epochs') 106 | plt.ylabel('train loss') 107 | plt.savefig('./log/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs)) 108 | 109 | # testing 110 | if args.dataset == 'mnist': 111 | dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True, 112 | transform=transforms.Compose([ 113 | transforms.ToTensor(), 114 | transforms.Normalize((0.1307,), (0.3081,)) 115 | ])) 116 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 117 | elif args.dataset == 'cifar': 118 | transform = transforms.Compose( 119 | [transforms.ToTensor(), 120 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 121 | dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True) 122 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 123 | else: 124 | exit('Error: unrecognized dataset') 125 | 126 | print('test on', len(dataset_test), 'samples') 127 | test_acc, test_loss = test(net_glob, test_loader)
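main_nn.py is the centralized (non-federated) baseline: it trains a single model on the whole training set with plain SGD. An example invocation, following the same flag conventions as the federated scripts (the defaults for --lr, --momentum, and --seed come from utils/options.py, which is not shown, and the epoch count here is arbitrary):

```bash
python main_nn.py --dataset mnist --num_channels 1 --model cnn --epochs 50 --gpu 0
```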
-------------------------------------------------------------------------------- /main_quan_celeba.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | import copy 10 | import numpy as np 11 | from torchvision import datasets, transforms 12 | from torch.utils.data import DataLoader 13 | import torch 14 | from torchstat import stat 15 | 16 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, fmnist_iid, celeba_iid 17 | from utils.options import args_parser 18 | #from models.Update import LocalUpdate 19 | from models.Nets import MLP, CNNMnist, CNNCifar, CNNFMnist, CNNCeleba 20 | from models.quan_resnet import QCNNCeleba 21 | from models.quan_update import LocalUpdate 22 | from models.Fed import FedAvg 23 | from models.test import test_img 24 | import loading_data as dataset 25 | from models.Nets import Bottleneck 26 | #from service.client import * 27 | #from service.utils.DH import * 28 | 29 | 30 | def test(): 31 | net_glob.eval() 32 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 33 | print("Testing accuracy: {:.2f}".format(acc_test)) 34 | 35 | 36 | if __name__ == '__main__': 37 | # parse args 38 | #print('delta quantization') 39 | 40 | args = args_parser() 41 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 42 | 43 | # load dataset and split users 44 | if args.dataset == 'mnist': 45 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 46 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) 47 | dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist) 48 | # sample users 49 | if args.iid: 50 | dict_users = mnist_iid(dataset_train, args.num_users) 51 | else: 52 | dict_users = mnist_noniid(dataset_train, args.num_users) 53 | elif args.dataset == 'fmnist': 54 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 55 | dataset_train = datasets.FashionMNIST('../data/fmnist/', train=True, download=True, transform=trans_mnist) 56 | dataset_test = datasets.FashionMNIST('../data/fmnist/', train=False, download=True, transform=trans_mnist) 57 | # sample users 58 | if args.iid: 59 | dict_users = fmnist_iid(dataset_train, args.num_users) 60 | else: 61 | exit('Error: only the IID setting is considered for FashionMNIST') 62 | elif args.dataset == 'celeba': 63 | custom_transform = transforms.Compose([transforms.CenterCrop((178, 178)), 64 | transforms.Resize((128, 128)), 65 | # transforms.Grayscale(), 66 | # transforms.Lambda(lambda x: x/255.), 67 | transforms.ToTensor()]) 68 | 69 | dataset_train = dataset.CelebaDataset(csv_path='celeba-gender-train.csv', 70 | img_dir='data/CelebA/img_align_celeba/', 71 | transform=custom_transform) 72 | 73 | valid_celeba = dataset.CelebaDataset(csv_path='celeba-gender-valid.csv', 74 | img_dir='data/CelebA/img_align_celeba/', 75 | transform=custom_transform) 76 | 77 | dataset_test = dataset.CelebaDataset(csv_path='celeba-gender-test.csv', 78 | img_dir='data/CelebA/img_align_celeba/', 79 | transform=custom_transform) 80 | 81 | if args.iid: 82 | dict_users = celeba_iid(dataset_train, args.num_users) 83 | else: 84 | exit('Error: only the IID setting is considered for CelebA') 85 | elif args.dataset == 'cifar': 86 | trans_cifar = transforms.Compose( 87 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 88 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar) 89 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar) 90 | if args.iid: 91 | dict_users = cifar_iid(dataset_train, args.num_users) 92 | else: 93 | exit('Error: only the IID setting is considered for CIFAR-10') 94 | else: 95 | exit('Error: unrecognized dataset') 96 | img_size = dataset_train[0][0].shape 97 | 98 | # initialise aggregate service 99 | #p = random_prime(2048) 100 | #g = get_generator(p) 101 | #clients = [] 102 | #for i in range(args.num_users): 103 | # client = Client(i, p, g, args.bit_width) 104 | # clients.append(client) 105 | # set bit width 106 | #clients[0].set_bit_width() 107 | 108 | # fetch the public key list and compute K 109 | #for i in range(args.num_users): 110 | # clients[i].generate_k() 111 | 112 | # build model 113 | if args.model == 'cnn' and args.dataset == 'cifar': 114 | net_glob = CNNCifar(args=args).to(args.device) 115 | elif args.model == 'cnn' and args.dataset == 'mnist': 116 | net_glob = CNNMnist(args=args).to(args.device) 117 | elif args.model == 'cnn' and args.dataset == 'fmnist': 118 | net_glob = CNNFMnist(args=args).to(args.device) 119 | elif args.model == 'cnn' and args.dataset == 'celeba': 120 | net_glob = CNNCeleba(args=args).to(args.device) 121 | elif args.model == 'qcnn' and args.dataset == 'celeba': 122 | net_glob = QCNNCeleba(args=args).to(args.device) 123 | elif args.model == 'mlp': 124 | len_in = 1 125 | for x in img_size: 126 | len_in *= x 127 | net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device) 128 | else: 129 | exit('Error: unrecognized model') 130 | 131 | #print(net_glob) 132 | net_glob.train() 133 | 134 | # copy weights 135 | w_glob = net_glob.state_dict() 136 | 137 | # training 138 | loss_train = [] 139 | cv_loss, cv_acc = [], [] 140 | val_loss_pre, counter = 0, 0 141 | net_best = None 142 | best_loss = None 143 | val_acc_list, net_list = [], [] 144 | 145 | if args.all_clients: 146 | print("Aggregation over all clients") 147 | w_locals = [w_glob for i in range(args.num_users)]
148 | for iter in range(args.epochs): 149 | net_glob.update() 150 | loss_locals = [] 151 | if not args.all_clients: 152 | w_locals = [] 153 | m = max(int(args.frac * args.num_users), 1) 154 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 155 | for idx in idxs_users: 156 | local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx]) 157 | w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device)) 158 | if args.all_clients: 159 | w_locals[idx] = copy.deepcopy(w) 160 | else: 161 | w_locals.append(copy.deepcopy(w)) 162 | loss_locals.append(copy.deepcopy(loss)) 163 | # update global weights 164 | w_glob = FedAvg(w_locals) 165 | 166 | # copy weight to net_glob 167 | net_glob.load_state_dict(w_glob) 168 | 169 | net_glob.unquant(args) 170 | 171 | net_glob.add() 172 | 173 | # print loss 174 | loss_avg = sum(loss_locals) / len(loss_locals) 175 | #print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg)) 176 | loss_train.append(loss_avg) 177 | test() 178 | 179 | # plot loss curve 180 | plt.figure() 181 | plt.plot(range(len(loss_train)), loss_train) 182 | plt.ylabel('train_loss') 183 | plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid)) 184 | 185 | # testing 186 | net_glob.eval() 187 | acc_train, loss_train = test_img(net_glob, dataset_train, args) 188 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 189 | print("Training accuracy: {:.2f}".format(acc_train)) 190 | print("Testing accuracy: {:.2f}".format(acc_test)) 191 | 192 | -------------------------------------------------------------------------------- /model_para.py: -------------------------------------------------------------------------------- 1 | from models import Nets 2 | from torchstat import stat 3 | from utils.options import args_parser 4 | import torch 5 | import torchvision.models as models 6 | args = args_parser() 7 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 8 | model = Nets.CNNCeleba(args=args).to(args.device) 9 | 10 | stat(model, (3, 128, 128)) 11 | 12 | print("resnet18 has {} parameters in total".format(sum(x.numel() for x in model.parameters()))) -------------------------------------------------------------------------------- /models/Fed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import torch 7 | from torch import nn 8 | 9 | 10 | def FedAvg(w): 11 | w_avg = copy.deepcopy(w[0]) 12 | for k in w_avg.keys(): 13 | for i in range(1, len(w)): 14 | w_avg[k] += w[i][k] 15 | w_avg[k] = torch.true_divide(w_avg[k], len(w)) 16 | return w_avg
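FedAvg simply averages the clients' state dicts key by key; torch.true_divide keeps the result correct even when the weights are integer tensors, as they are under the scaling scheme. A minimal usage sketch with two hand-made state dicts:

```python
import torch
from models.Fed import FedAvg

w1 = {'fc.weight': torch.tensor([[2., 4.]]), 'fc.bias': torch.tensor([6.])}
w2 = {'fc.weight': torch.tensor([[4., 8.]]), 'fc.bias': torch.tensor([2.])}

w_avg = FedAvg([w1, w2])
print(w_avg['fc.weight'])  # tensor([[3., 6.]])
print(w_avg['fc.bias'])    # tensor([4.])
```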
-------------------------------------------------------------------------------- /models/Nets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | import copy 9 | 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, dim_in, dim_hidden, dim_out): 13 | super(MLP, self).__init__() 14 | self.layer_input = nn.Linear(dim_in, dim_hidden) 15 | self.relu = nn.ReLU() 16 | self.dropout = nn.Dropout() 17 | self.layer_hidden = nn.Linear(dim_hidden, dim_out) 18 | 19 | def forward(self, x): 20 | x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1]) 21 | x = self.layer_input(x) 22 | x = self.dropout(x) 23 | x = self.relu(x) 24 | x = self.layer_hidden(x) 25 | return x 26 | 27 | class CNNMnist(nn.Module): 28 | def __init__(self, args): 29 | super(CNNMnist, self).__init__() 30 | self.bit_width = args.bit_width 31 | self.alpha = args.alpha 32 | self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5) 33 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 34 | self.conv2_drop = nn.Dropout2d() 35 | self.fc1 = nn.Linear(320, 50) 36 | self.fc2 = nn.Linear(50, args.num_classes) 37 | self.lastc1 = copy.deepcopy(self.conv1.weight.detach()) 38 | self.lastc2 = copy.deepcopy(self.conv2.weight.detach()) 39 | self.lastf1 = copy.deepcopy(self.fc1.weight.detach()) 40 | self.lastf2 = copy.deepcopy(self.fc2.weight.detach()) 41 | 42 | def forward(self, x): 43 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 44 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 45 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 46 | x = F.relu(self.fc1(x)) 47 | x = F.dropout(x, training=self.training) 48 | x = self.fc2(x) 49 | return x 50 | 51 | 52 | class AlexNetMnist(nn.Module): 53 | def __init__(self, args): 54 | super(AlexNetMnist, self).__init__() 55 | self.features = nn.Sequential( 56 | nn.Conv2d(args.num_channels, 64, kernel_size=3, stride=2, padding=1), 57 | nn.ReLU(inplace=True), 58 | nn.MaxPool2d(kernel_size=2), 59 | nn.Conv2d(64, 192, kernel_size=3, padding=1), 60 | nn.ReLU(inplace=True), 61 | nn.MaxPool2d(kernel_size=2), 62 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 67 | nn.ReLU(inplace=True), 68 | nn.MaxPool2d(kernel_size=2), 69 | ) 70 | self.classifier = nn.Sequential( 71 | nn.Dropout(), 72 | nn.Linear(256 * 1 * 1, 4096), 73 | nn.ReLU(inplace=True), 74 | nn.Dropout(), 75 | nn.Linear(4096, 4096), 76 | nn.ReLU(inplace=True), 77 | nn.Linear(4096, args.num_classes), 78 | ) 79 | 80 | def forward(self, x): 81 | x = self.features(x) 82 | x = x.view(x.size(0), 256 * 1 * 1) 83 | x = self.classifier(x) 84 | return x 85 | 86 | class CNNCifar(nn.Module): 87 | def __init__(self, args): 88 | super(CNNCifar, self).__init__() 89 | self.conv1 = nn.Conv2d(3, 6, 5) 90 | self.pool = nn.MaxPool2d(2, 2) 91 | self.conv2 = nn.Conv2d(6, 16, 5) 92 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 93 | self.fc2 = nn.Linear(120, 84) 94 | self.fc3 = nn.Linear(84, args.num_classes) 95 | 96 | def forward(self, x): 97 | x = self.pool(F.relu(self.conv1(x))) 98 | x = self.pool(F.relu(self.conv2(x))) 99 | x = x.view(-1, 16 * 5 * 5) 100 | x = F.relu(self.fc1(x)) 101 | x = F.relu(self.fc2(x)) 102 | x = self.fc3(x) 103 | return x 104 | 105 | class CNNFMnist(nn.Module): 106 | def __init__(self, args): 107 | super(CNNFMnist, self).__init__() 108 | self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5) 109 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 110 | self.conv2_drop = nn.Dropout2d() 111 | self.fc1 = nn.Linear(320, 50) 112 | self.fc2 = nn.Linear(50, args.num_classes) 113 | 114 | def forward(self, x): 115 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 116 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 117 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 118 | x = F.relu(self.fc1(x)) 119 | x = F.dropout(x, training=self.training) 120 | x = self.fc2(x) 121 | return x 122 | 123 | 124 | def conv3x3(in_planes, out_planes, stride=1): 125 | """3x3 convolution with padding""" 126 | return nn.Conv2d(in_planes, 
out_planes, kernel_size=3, stride=stride, 127 | padding=1, bias=False) 128 | 129 | 130 | class Bottleneck(nn.Module): 131 | expansion = 4 132 | 133 | def __init__(self, inplanes, planes, stride=1, downsample=None): 134 | super(Bottleneck, self).__init__() 135 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 136 | self.bn1 = nn.BatchNorm2d(planes) 137 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 138 | padding=1, bias=False) 139 | self.bn2 = nn.BatchNorm2d(planes) 140 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 141 | self.bn3 = nn.BatchNorm2d(planes * 4) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.downsample = downsample 144 | self.stride = stride 145 | 146 | 147 | def forward(self, x): 148 | residual = x 149 | 150 | out = self.conv1(x) 151 | out = self.bn1(out) 152 | out = self.relu(out) 153 | 154 | out = self.conv2(out) 155 | out = self.bn2(out) 156 | out = self.relu(out) 157 | 158 | out = self.conv3(out) 159 | out = self.bn3(out) 160 | 161 | if self.downsample is not None: 162 | residual = self.downsample(x) 163 | 164 | out += residual 165 | out = self.relu(out) 166 | 167 | return out 168 | 169 | class AlexNetCifar(nn.Module): 170 | def __init__(self, args): 171 | super(AlexNetCifar, self).__init__() 172 | self.features = nn.Sequential( 173 | nn.Conv2d(args.num_channels, 64, kernel_size=3, stride=2, padding=1), 174 | nn.ReLU(inplace=True), 175 | nn.MaxPool2d(kernel_size=2, stride=2), 176 | nn.Conv2d(64, 192, kernel_size=3, padding=1), 177 | nn.ReLU(inplace=True), 178 | nn.MaxPool2d(kernel_size=2, stride=2), 179 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 180 | nn.ReLU(inplace=True), 181 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 182 | nn.ReLU(inplace=True), 183 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 184 | nn.ReLU(inplace=True), 185 | nn.MaxPool2d(kernel_size=2, stride=2), 186 | ) 187 | self.classifier = nn.Sequential( 188 | nn.Dropout(), 189 | nn.Linear(256 * 2 * 2, 4096), 190 | nn.ReLU(inplace=True), 191 | nn.Dropout(), 192 | nn.Linear(4096, 4096), 193 | nn.ReLU(inplace=True), 194 | nn.Linear(4096, args.num_classes), 195 | ) 196 | 197 | def forward(self, input): 198 | output = self.features(input) 199 | output = output.view(-1, 256*2*2) 200 | output = self.classifier(output) 201 | return output 202 | 203 | 204 | 205 | class CNNCeleba(nn.Module): 206 | def __init__(self, args): 207 | args.block = Bottleneck 208 | args.grayscale = False 209 | args.layers = [2, 2, 2, 2] 210 | self.inplanes = 64 211 | if args.grayscale: 212 | in_dim = 1 213 | else: 214 | in_dim = 3 215 | super(CNNCeleba, self).__init__() 216 | self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3, 217 | bias=False) 218 | self.bn1 = nn.BatchNorm2d(64) 219 | self.relu = nn.ReLU(inplace=True) 220 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 221 | self.layer1 = self._make_layer(args.block, 64, args.layers[0]) 222 | self.layer2 = self._make_layer(args.block, 128, args.layers[1], stride=2) 223 | self.layer3 = self._make_layer(args.block, 256, args.layers[2], stride=2) 224 | self.layer4 = self._make_layer(args.block, 512, args.layers[3], stride=2) 225 | self.avgpool = nn.AvgPool2d(7, stride=1, padding=2) 226 | self.fc = nn.Linear(2048 * args.block.expansion, args.num_classes) 227 | 228 | for m in self.modules(): 229 | if isinstance(m, nn.Conv2d): 230 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 231 | m.weight.data.normal_(0, (2. 
/ n)**.5) 232 | elif isinstance(m, nn.BatchNorm2d): 233 | m.weight.data.fill_(1) 234 | m.bias.data.zero_() 235 | 236 | def _make_layer(self, block, planes, blocks, stride=1): 237 | downsample = None 238 | if stride != 1 or self.inplanes != planes * block.expansion: 239 | downsample = nn.Sequential( 240 | nn.Conv2d(self.inplanes, planes * block.expansion, 241 | kernel_size=1, stride=stride, bias=False), 242 | nn.BatchNorm2d(planes * block.expansion), 243 | ) 244 | 245 | layers = [] 246 | layers.append(block(self.inplanes, planes, stride, downsample)) 247 | self.inplanes = planes * block.expansion 248 | for i in range(1, blocks): 249 | layers.append(block(self.inplanes, planes)) 250 | 251 | return nn.Sequential(*layers) 252 | 253 | def forward(self, x): 254 | x = self.conv1(x) 255 | x = self.bn1(x) 256 | x = self.relu(x) 257 | x = self.maxpool(x) 258 | 259 | x = self.layer1(x) 260 | x = self.layer2(x) 261 | x = self.layer3(x) 262 | x = self.layer4(x) 263 | 264 | x = self.avgpool(x) 265 | x = x.view(x.size(0), -1) 266 | logits = self.fc(x) 267 | probas = F.softmax(logits, dim=1) 268 | return probas 269 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # models code 2 | -------------------------------------------------------------------------------- /models/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import copy 5 | 6 | 7 | class ResBlock(nn.Module): 8 | def __init__(self, inchannel, outchannel, stride=1): 9 | super(ResBlock, self).__init__() 10 | # two consecutive convolutional layers inside the residual block 11 | self.left = nn.Sequential( 12 | nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False), 13 | nn.BatchNorm2d(outchannel), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False), 16 | nn.BatchNorm2d(outchannel) 17 | ) 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or inchannel != outchannel: 20 | # shortcut: project x so its shape matches the output of the two conv layers 21 | self.shortcut = nn.Sequential( 22 | nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(outchannel) 24 | ) 25 | 26 | def forward(self, x): 27 | out = self.left(x) 28 | # add the conv output to the (projected) input x: the basic ResNet building block 29 | out = out + self.shortcut(x) 30 | out = F.relu(out) 31 | 32 | return out 33 | 34 | 35 | class ResNet(nn.Module): 36 | def __init__(self, args): 37 | super(ResNet, self).__init__() 38 | args.ResBlock = ResBlock 39 | self.inchannel = 64 40 | self.conv1 = nn.Sequential( 41 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False), 42 | nn.BatchNorm2d(64), 43 | nn.ReLU() 44 | ) 45 | self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1) 46 | self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2) 47 | self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2) 48 | self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2) 49 | self.fc = nn.Linear(512, args.num_classes) 50 | 51 | # helper that stacks repeated copies of the same residual block 52 | def make_layer(self, block, channels, num_blocks, stride): 53 | strides = [stride] + [1] * (num_blocks - 1) 54 | layers = [] 55 | for stride in strides: 56 | layers.append(block(self.inchannel, channels, stride)) 57 | self.inchannel = channels 58 | return nn.Sequential(*layers) 59 | 60 | def forward(self, x): 61 | # the full ResNet-18 structure is laid out here 62 | out = self.conv1(x) 63 | out = 
self.layer1(out) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = F.avg_pool2d(out, 4) 68 | out = out.view(out.size(0), -1) 69 | out = self.fc(out) 70 | return out 71 | -------------------------------------------------------------------------------- /models/Scaling_Net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from utils.options import args_parser 9 | 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, dim_in, dim_hidden, dim_out): 13 | super(MLP, self).__init__() 14 | self.layer_input = nn.Linear(dim_in, dim_hidden) 15 | self.relu = nn.ReLU() 16 | self.dropout = nn.Dropout() 17 | self.layer_hidden = nn.Linear(dim_hidden, dim_out) 18 | 19 | def forward(self, x): 20 | x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1]) 21 | x = self.layer_input(x) 22 | x = self.dropout(x) 23 | x = self.relu(x) 24 | x = self.layer_hidden(x) 25 | return x 26 | 27 | 28 | class CNNMnist(nn.Module): 29 | def __init__(self, args): 30 | super(CNNMnist, self).__init__() 31 | self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5) 32 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 33 | self.conv2_drop = nn.Dropout2d() 34 | self.fc1 = nn.Linear(320, 50) 35 | self.fc2 = nn.Linear(50, args.num_classes) 36 | 37 | def forward(self, x): 38 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 39 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 40 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 41 | x = F.relu(self.fc1(x)) 42 | x = F.dropout(x, training=self.training) 43 | x = self.fc2(x) 44 | return x 45 | 46 | def double(self, args): 47 | self.params = list(self.named_parameters()) 48 | for i in range(len(self.params)): 49 | self.params[i][1].data = torch.round((self.params[i][1].data * args.scaling_factor).to(args.device)) 50 | 51 | def half(self, args): 52 | self.params = list(self.named_parameters()) 53 | for i in range(len(self.params)): 54 | self.params[i][1].data = (self.params[i][1].data / args.scaling_factor).type(torch.float).to(args.device) 55 | 56 | class CNNCifar(nn.Module): 57 | def __init__(self, args): 58 | super(CNNCifar, self).__init__() 59 | self.conv1 = nn.Conv2d(3, 6, 5) 60 | self.pool = nn.MaxPool2d(2, 2) 61 | self.conv2 = nn.Conv2d(6, 16, 5) 62 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 63 | self.fc2 = nn.Linear(120, 84) 64 | self.fc3 = nn.Linear(84, args.num_classes) 65 | 66 | def forward(self, x): 67 | x = self.pool(F.relu(self.conv1(x))) 68 | x = self.pool(F.relu(self.conv2(x))) 69 | x = x.view(-1, 16 * 5 * 5) 70 | x = F.relu(self.fc1(x)) 71 | x = F.relu(self.fc2(x)) 72 | x = self.fc3(x) 73 | return x 74 | 75 | def double(self, args): 76 | self.params = list(self.named_parameters()) 77 | for i in range(len(self.params)): 78 | self.params[i][1].data = torch.round((self.params[i][1].data * args.scaling_factor).to(args.device)) 79 | 80 | def half(self, args): 81 | self.params = list(self.named_parameters()) 82 | for i in range(len(self.params)): 83 | self.params[i][1].data = (self.params[i][1].data / args.scaling_factor).type(torch.float).to(args.device) 84 | 85 | 86 | class AlexNetMnist(nn.Module): 87 | def __init__(self, args): 88 | super(AlexNetMnist, self).__init__() 89 | self.features = nn.Sequential( 90 | nn.Conv2d(args.num_channels, 64, kernel_size=3, stride=2, padding=1), 91 | nn.ReLU(inplace=True), 92 | 

class CNNCifar(nn.Module):
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # Integer scaling and its inverse; see CNNMnist.double()/half().
    def double(self, args):
        self.params = list(self.named_parameters())
        for i in range(len(self.params)):
            self.params[i][1].data = torch.round((self.params[i][1].data * args.scaling_factor).to(args.device))

    def half(self, args):
        self.params = list(self.named_parameters())
        for i in range(len(self.params)):
            self.params[i][1].data = (self.params[i][1].data / args.scaling_factor).type(torch.float).to(args.device)


class AlexNetMnist(nn.Module):
    def __init__(self, args):
        super(AlexNetMnist, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(args.num_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 1 * 1, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, args.num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 1 * 1)
        x = self.classifier(x)
        return x


class AlexNetCifar(nn.Module):
    def __init__(self, args):
        super(AlexNetCifar, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(args.num_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, args.num_classes),
        )

    def forward(self, input):
        output = self.features(input)
        output = output.view(-1, 256*2*2)
        output = self.classifier(output)
        return output

    # Integer scaling and its inverse; see CNNMnist.double()/half().
    def double(self, args):
        self.params = list(self.named_parameters())
        for i in range(len(self.params)):
            self.params[i][1].data = torch.round((self.params[i][1].data * args.scaling_factor).to(args.device))

    def half(self, args):
        self.params = list(self.named_parameters())
        for i in range(len(self.params)):
            self.params[i][1].data = (self.params[i][1].data / args.scaling_factor).type(torch.float).to(args.device)


class ResBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResBlock, self).__init__()
        # Two consecutive conv layers inside the residual block
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out = out + self.shortcut(x)
        out = F.relu(out)

        return out

class CNNCeleba(nn.Module):
    def __init__(self, args):
        super(CNNCeleba, self).__init__()
        args.ResBlock = ResBlock
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, args.num_classes)

    # Repeats the same residual block num_blocks times
    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        # The overall ResNet-18 structure is laid out plainly here
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

    # Integer scaling and its inverse; see CNNMnist.double()/half().
    def double(self, args):
        self.params = list(self.named_parameters())
        for i in range(len(self.params)):
            self.params[i][1].data = torch.round((self.params[i][1].data * args.scaling_factor).to(args.device))

    def half(self, args):
        self.params = list(self.named_parameters())
        for i in range(len(self.params)):
            self.params[i][1].data = (self.params[i][1].data / args.scaling_factor).type(torch.float).to(args.device)


if __name__ == '__main__':
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    model = CNNMnist(args=args)
    print(model)
    # double() takes args and modifies the model in place; the original
    # `model = model.double()` both dropped the required argument and rebound
    # `model` to None.
    model.double(args)
    params = list(model.named_parameters())
    print(params[0][1].data.to(args.device))
--------------------------------------------------------------------------------
/models/Scaling_update.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics


class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)

        epoch_loss = []
        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        # Scale the trained weights to integers before they leave the client.
        net.double(self.args)
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
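
# Editor's note (not in the original source): a hedged sketch of how a server
# round would consume this class, assuming a net with the double()/half()
# pair from models/Scaling_Net.py and an averaging helper such as the one in
# models/Fed.py (names below are illustrative, not the repo's exact API):
#
#   local = LocalUpdate(args, dataset=dataset_train, idxs=dict_users[client])
#   w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
#   # ... collect w from each sampled client and average the integer updates ...
#   net_glob.load_state_dict(w_avg)
#   net_glob.half(args)   # back from the integer domain to floats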
--------------------------------------------------------------------------------
/models/Update.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics


class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label


# Quantized local update, for nets exposing update()/diff()/quant() (see
# models/quan_resnet.py and models/quan_update.py). This was originally a
# second class also named LocalUpdate, which the plain definition below
# silently shadowed; it is renamed here so both remain usable.
class QuanLocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True, num_workers=4)

    def train(self, net):
        net.train()
        net.update()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)

        epoch_loss = []
        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        net.diff()
        # net.to(self.args.device)
        net.quant(self.args)
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True, num_workers=4)

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        epoch_loss = []
        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6
--------------------------------------------------------------------------------
/models/quan_resnet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
import copy


# Symmetric uniform quantization: map a float tensor with values in
# [-alpha, alpha] onto signed bit_width-bit integer levels.
def quantize_matrix(matrix, bit_width=8, alpha=0.1):
    og_sign = torch.sign(matrix)
    uns_matrix = matrix * og_sign
    uns_result = torch.round((uns_matrix * (pow(2, bit_width - 1) - 1.0) / alpha))
    result = (og_sign * uns_result)
    return result


# Inverse of quantize_matrix(): map integer levels back to floats.
def unquantize_matrix(matrix, bit_width=8, alpha=0.1):
    matrix = matrix.int()
    og_sign = torch.sign(matrix)
    uns_matrix = matrix * og_sign
    uns_result = uns_matrix * alpha / (pow(2, bit_width - 1) - 1.0)
    result = og_sign * uns_result
    return result.float()
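
# Editor's note (not in the original source): a worked example of the pair
# above with the defaults bit_width=8, alpha=0.1. A weight of 0.05 maps to
# round(0.05 * 127 / 0.1) = 64; unquantizing gives 64 * 0.1 / 127 ~= 0.0504,
# so the round-trip error is at most half a step, alpha / (2**7 - 1) / 2.
#
#   q = quantize_matrix(torch.tensor([0.05]))   # tensor([64.])
#   w = unquantize_matrix(q)                    # tensor([0.0504])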

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, args, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bit_width = args.bit_width
        self.alpha = args.alpha
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        # Snapshots of the conv weights from the last global model, used by
        # diff()/add() to form and reconstruct weight deltas.
        self.lastc1 = copy.deepcopy(self.conv1.weight.detach())
        self.lastc2 = copy.deepcopy(self.conv2.weight.detach())
        self.lastc3 = copy.deepcopy(self.conv3.weight.detach())

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

    # Refresh the snapshots to the current (global) weights.
    def update(self):
        self.lastc1 = copy.deepcopy(self.conv1.weight.detach())
        self.lastc2 = copy.deepcopy(self.conv2.weight.detach())
        self.lastc3 = copy.deepcopy(self.conv3.weight.detach())

    # Replace the weights with their delta against the snapshots.
    def diff(self):
        temp1 = self.conv1.weight.detach() - self.lastc1
        temp2 = self.conv2.weight.detach() - self.lastc2
        temp3 = self.conv3.weight.detach() - self.lastc3

        self.conv1.weight = copy.deepcopy(nn.Parameter(temp1))
        self.conv2.weight = copy.deepcopy(nn.Parameter(temp2))
        self.conv3.weight = copy.deepcopy(nn.Parameter(temp3))

    # Add the snapshots back onto the (aggregated) deltas.
    def add(self):
        temp1 = self.conv1.weight + self.lastc1
        temp2 = self.conv2.weight + self.lastc2
        temp3 = self.conv3.weight + self.lastc3

        self.conv1.weight = copy.deepcopy(nn.Parameter(temp1))
        self.conv2.weight = copy.deepcopy(nn.Parameter(temp2))
        self.conv3.weight = copy.deepcopy(nn.Parameter(temp3))

    def quant(self, args):
        c1 = quantize_matrix(self.conv1.weight, self.bit_width, self.alpha).type(torch.FloatTensor)
        self.conv1.weight = torch.nn.Parameter(c1.to(args.device))

        c2 = quantize_matrix(self.conv2.weight, self.bit_width, self.alpha).type(torch.FloatTensor)
        self.conv2.weight = torch.nn.Parameter(c2.to(args.device))

        c3 = quantize_matrix(self.conv3.weight, self.bit_width, self.alpha).type(torch.FloatTensor)
        self.conv3.weight = torch.nn.Parameter(c3.to(args.device))

    def unquant(self, args):
        c1 = unquantize_matrix(self.conv1.weight.detach(), self.bit_width, self.alpha).type(torch.FloatTensor)
        self.conv1.weight = torch.nn.Parameter(c1.to(args.device))

        c2 = unquantize_matrix(self.conv2.weight.detach(), self.bit_width, self.alpha).type(torch.FloatTensor)
        self.conv2.weight = torch.nn.Parameter(c2.to(args.device))

        # Bug fix: the original called quantize_matrix here, re-quantizing
        # conv3 instead of unquantizing it.
        c3 = unquantize_matrix(self.conv3.weight.detach(), self.bit_width, self.alpha).type(torch.FloatTensor)
        self.conv3.weight = torch.nn.Parameter(c3.to(args.device))

class QCNNCeleba(nn.Module):
    def __init__(self, args):
        args.block = Bottleneck
        args.grayscale = False
        args.layers = [2, 2, 2, 2]
        self.bit_width = args.bit_width
        self.alpha = args.alpha
        self.inplanes = 64
        if args.grayscale:
            in_dim = 1
        else:
            in_dim = 3
        super(QCNNCeleba, self).__init__()
        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(args.block, 64, args.layers[0], args)
        self.layer2 = self._make_layer(args.block, 128, args.layers[1], args, stride=2)
        self.layer3 = self._make_layer(args.block, 256, args.layers[2], args, stride=2)
        self.layer4 = self._make_layer(args.block, 512, args.layers[3], args, stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1, padding=2)
        self.fc = nn.Linear(2048 * args.block.expansion, args.num_classes)
        # Snapshots of the stem conv and fc weights, mirroring Bottleneck.
        self.lastc1 = copy.deepcopy(self.conv1.weight.detach())
        self.lastf1 = copy.deepcopy(self.fc.weight.detach())

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2. / n)**.5)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, args, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, args, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, args))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        # Note: this returns softmax probabilities, and the CrossEntropyLoss
        # in the update classes applies log-softmax again; kept as in the
        # original code.
        probas = F.softmax(logits, dim=1)
        return probas

    def update(self):
        self.lastc1 = copy.deepcopy(self.conv1.weight.detach())
        self.lastf1 = copy.deepcopy(self.fc.weight.detach())
        for layer in self.layer1:
            layer.update()
        for layer in self.layer2:
            layer.update()
        for layer in self.layer3:
            layer.update()
        for layer in self.layer4:
            layer.update()

    def diff(self):
        temp1 = self.conv1.weight.detach() - self.lastc1
        temp2 = self.fc.weight.detach() - self.lastf1

        self.conv1.weight = copy.deepcopy(nn.Parameter(temp1))
        self.fc.weight = copy.deepcopy(nn.Parameter(temp2))

        for layer in self.layer1:
            layer.diff()
        for layer in self.layer2:
            layer.diff()
        for layer in self.layer3:
            layer.diff()
        for layer in self.layer4:
            layer.diff()

    def add(self):
        temp1 = self.conv1.weight + self.lastc1
        temp2 = self.fc.weight + self.lastf1

        self.conv1.weight = copy.deepcopy(nn.Parameter(temp1))
        self.fc.weight = copy.deepcopy(nn.Parameter(temp2))

        for layer in self.layer1:
            layer.add()
        for layer in self.layer2:
            layer.add()
        for layer in self.layer3:
            layer.add()
        for layer in self.layer4:
            layer.add()

    def quant(self, args):
        c1 = quantize_matrix(self.conv1.weight, self.bit_width, self.alpha).type(torch.FloatTensor)
        self.conv1.weight = torch.nn.Parameter(c1.to(args.device))

        f1 = quantize_matrix(self.fc.weight, self.bit_width, self.alpha).type(torch.FloatTensor)
        self.fc.weight = torch.nn.Parameter(f1.to(args.device))

        for layer in self.layer1:
            layer.quant(args)
        for layer in self.layer2:
            layer.quant(args)
        for layer in self.layer3:
            layer.quant(args)
        for layer in self.layer4:
            layer.quant(args)

    def unquant(self, args):
        c1 = unquantize_matrix(self.conv1.weight, self.bit_width, self.alpha).type(torch.FloatTensor)
        self.conv1.weight = torch.nn.Parameter(c1.to(args.device))

        f1 = unquantize_matrix(self.fc.weight, self.bit_width, self.alpha).type(torch.FloatTensor)
        self.fc.weight = torch.nn.Parameter(f1.to(args.device))

        for layer in self.layer1:
            layer.unquant(args)
        for layer in self.layer2:
            layer.unquant(args)
        for layer in self.layer3:
            layer.unquant(args)
        for layer in self.layer4:
            layer.unquant(args)
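
# Editor's note (not in the original source): a hedged sketch of one
# quantized FL round using the hooks above. The method names come from this
# file; the aggregation step itself is assumed to live in models/Fed.py and
# main_quan_celeba.py.
#
#   net.update()        # client: snapshot the current global weights
#   # ... local SGD ...
#   net.diff()          # client: keep only the weight delta
#   net.quant(args)     # client: quantize the delta before upload
#   # server: average the quantized deltas, load them, then:
#   net.unquant(args)   # server: back to float deltas
#   net.add()           # server: global weights = snapshot + delta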
--------------------------------------------------------------------------------
/models/quan_update.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics


class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        net.train()
        net.update()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)

        epoch_loss = []
        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        net.diff()
        # net.to(self.args.device)
        net.quant(self.args)
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
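
# Editor's note (not in the original source): train() returns the state_dict
# of the *quantized weight delta*, not of the full model, so the caller is
# expected to unquant() and add() after aggregation (see models/quan_resnet.py).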
--------------------------------------------------------------------------------
/models/test.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


def test_img(net_g, datatest, args):
    net_g.eval()
    # testing
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    l = len(data_loader)
    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.to(args.device), target.to(args.device)
        log_probs = net_g(data)
        # sum up batch loss
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    return accuracy, test_loss
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
# utils code
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6
--------------------------------------------------------------------------------
/utils/options.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse

def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")
    parser.add_argument('--bit_width', type=int, default=8, help="bit width for quantize/unquantize")
    parser.add_argument('--alpha', type=float, default=0.1, help="alpha (clipping range) for quantize/unquantize")

    # model arguments
    parser.add_argument('--model', type=str, default='cnn', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel sizes to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="whether to use max pooling rather than strided convolutions")
    parser.add_argument('--scaling_factor', type=int, default=10, help="scaling factor")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether the data split is i.i.d. or not')
    parser.add_argument('--num_classes', type=int, default=2, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=1, help="number of channels of images")
    parser.add_argument('--gpu', type=int, default=-1, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    args = parser.parse_args()
    return args
--------------------------------------------------------------------------------
/utils/sampling.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import numpy as np
from torchvision import datasets, transforms

def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


def mnist_noniid(dataset, num_users):
    """
    Sample non-I.I.D. client data from the MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """
    num_shards, num_imgs = 200, 300
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = dataset.train_labels.numpy()

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users
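
# Editor's note (not in the original source): with num_shards=200 and
# num_imgs=300, mnist_noniid() partitions all 200 * 300 = 60,000 MNIST
# training images. Because the indices are sorted by label first, each shard
# of 300 consecutive images spans only one or two digit classes, and each of
# the num_users clients draws 2 shards, i.e. 600 images covering a few
# classes at most -- the standard label-skewed non-IID split from the FedAvg
# paper (McMahan et al., 2017).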

def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the CIFAR-10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

def fmnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the Fashion-MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


def celeba_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the CelebA dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

if __name__ == '__main__':
    dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                   ]))
    num = 100
    d = mnist_noniid(dataset_train, num)
--------------------------------------------------------------------------------