├── assets └── FedPHE.png ├── requirements.txt ├── utils ├── subset.py ├── sampling.py ├── data_split.py ├── util.py ├── min_hash.py ├── functional.py ├── dataset.py └── partition.py ├── encryption ├── gmpy_math.py ├── complement.py ├── compress.py ├── aciq.py ├── encrypt.py ├── bfv.py ├── quantize.py ├── ckks.py └── paillier.py ├── models ├── resnet50.py └── model.py ├── fed.py ├── README.md ├── main.py ├── LICENSE ├── client.py └── server.py /assets/FedPHE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lunan0320/FedPHE/HEAD/assets/FedPHE.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python-3.9.18 2 | numpy-1.26.2 3 | pytorch-2.1.2 4 | pytorch-cuda-12.1 5 | tenseal-0.3.14 6 | gmpy2-2.1.5 7 | gap-stat-2.0.3 8 | pillow-6.1 9 | tenseal-0.3.14 10 | scikit-learn-1.3.0 11 | pandas-2.1.4 -------------------------------------------------------------------------------- /utils/subset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Subset 2 | class CustomSubset(Subset): 3 | '''A custom subset class with customizable data transformation''' 4 | def __init__(self, dataset, indices, subset_transform=None): 5 | super().__init__(dataset, indices) 6 | self.subset_transform = subset_transform 7 | 8 | def __getitem__(self, idx): 9 | 10 | x, y = self.dataset[self.indices[idx]] 11 | 12 | if self.subset_transform: 13 | x = self.subset_transform(x) 14 | 15 | return x, y 16 | 17 | def __len__(self): 18 | return len(self.indices) -------------------------------------------------------------------------------- /encryption/gmpy_math.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import gmpy2 4 | 5 | POWMOD_GMP_SIZE = pow(2, 64) 6 | 7 | def powmod(a, b, c): 8 | """ 9 | return int: (a ** b) % c 10 | """ 11 | 12 | if a == 1: 13 | return 1 14 | 15 | if max(a, b, c) < POWMOD_GMP_SIZE: 16 | return pow(a, b, c) 17 | 18 | else: 19 | return int(gmpy2.powmod(a, b, c)) 20 | 21 | 22 | def invert(a, b): 23 | """return int: x, where a * x == 1 mod b 24 | """ 25 | x = int(gmpy2.invert(a, b)) 26 | 27 | if x == 0: 28 | raise ZeroDivisionError('invert(a, b) no inverse exists') 29 | 30 | return x 31 | 32 | 33 | def getprimeover(n): 34 | """return a random n-bit prime number 35 | """ 36 | r = gmpy2.mpz(random.SystemRandom().getrandbits(n)) 37 | r = gmpy2.bit_set(r, n - 1) 38 | 39 | return int(gmpy2.next_prime(r)) 40 | 41 | 42 | def isqrt(n): 43 | """ return the integer square root of N """ 44 | 45 | return int(gmpy2.isqrt(n)) 46 | 47 | 48 | -------------------------------------------------------------------------------- /encryption/complement.py: -------------------------------------------------------------------------------- 1 | 2 | # true value trans to 2's complement 3 | def true2two(value, padding_bits, quan_bits): 4 | if value >= 0: 5 | binary = format(value,'0{}b'.format(quan_bits)) 6 | return int(binary.zfill(padding_bits + quan_bits),2) 7 | else: 8 | binary = format(abs(value),'0{}b'.format(quan_bits)) 9 | inversed = ''.join('1' if bit == '0' else '0' for bit in binary) 10 | complement = int(bin(int(inversed,2)+1)[2:].zfill(padding_bits + quan_bits),2) 11 | return complement 12 | 13 | # 2's complement trans to true value 14 | def two2true(value, padding_bits, quan_bits): 15 
| #value = int(value ) 16 | mod = pow(2,quan_bits) 17 | #value = value % mod 18 | if value > mod: 19 | value %= mod 20 | value = bin(value)[2:].zfill(quan_bits) 21 | 22 | if value[0] == '0': 23 | return int(value,2) 24 | elif value[0] == '1': 25 | inversed = ''.join('1' if bit == '0' else '0' for bit in value) 26 | value = int(inversed,2) + 1 27 | if value >= pow(2, quan_bits) - 1: 28 | value -= pow(2, quan_bits) 29 | return (-1*value) 30 | else: 31 | raise ValueError("Overflow {}".format(value)) -------------------------------------------------------------------------------- /encryption/compress.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiprocessing import Pool 3 | import sys 4 | 5 | N_JOBS = 4 6 | 7 | 8 | def chunks_idx(l, n): 9 | d, r = divmod(len(l), n) 10 | for i in range(n): 11 | si = (d+1)*(i if i < r else r) + d*(0 if i < r else i - r) 12 | yield si, si+(d+1 if i < r else d) 13 | 14 | ####################################### 15 | def _compress(flatten_array, num_bits): 16 | res = 0 17 | l = len(flatten_array) 18 | for element in flatten_array: 19 | res <<= num_bits 20 | res += element 21 | 22 | return res, l 23 | 24 | def compress_multi(flatten_array, num_bits): 25 | l = len(flatten_array) 26 | 27 | pool_inputs = [] 28 | sizes = [] 29 | pool = Pool(N_JOBS) 30 | 31 | for begin, end in chunks_idx(range(l), N_JOBS): 32 | sizes.append(end - begin) 33 | 34 | pool_inputs.append([flatten_array[begin:end], num_bits]) 35 | 36 | pool_outputs = pool.starmap(_compress, pool_inputs) 37 | pool.close() 38 | pool.join() 39 | 40 | res = 0 41 | 42 | for idx, output in enumerate(pool_outputs): 43 | res += output[0] << (int(np.sum(sizes[idx + 1:])) * num_bits) 44 | 45 | num_bytes = (num_bits * l - 1) // 8 + 1 46 | res = res.to_bytes(num_bytes, 'big') 47 | return res, l -------------------------------------------------------------------------------- /encryption/aciq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ACIQ(object): 5 | 6 | def __init__(self, num_bits): 7 | super(ACIQ, self).__init__() 8 | self.num_bits = num_bits 9 | 10 | def get_alpha_gaus(self, min, max, size): 11 | alpha_gaus = [None, None, 1.710635, 2.151593, 2.559136, 2.936201, 3.286914, 3.615114, 12 | 3.924035, 4.216331, 4.494167, 4.759313, 5.013188, 5.257151, 5.491852, 5.719160, 13 | 5.938345, 6.150141, 6.356593, 6.560495, 6.752936, 6.931921, 7.106395, 7.350340, 14 | 7.482915, 7.691728, 7.668494, 7.583591, 7.583591, 8.326501, 8.171210, 8.171210] 15 | gaussian_const = (0.5 * 0.35) * (1 + (np.pi * np.log(4)) ** 0.5) 16 | sigma = ((max - min) * gaussian_const) / ((2 * np.log(size)) ** 0.5) 17 | 18 | alpha_opt = alpha_gaus[31] if self.num_bits > 31 else alpha_gaus[self.num_bits] 19 | return alpha_opt * sigma 20 | 21 | def get_alpha_gaus_direct(self, sigma): 22 | alpha_gaus = [None, None, 1.710635, 2.151593, 2.559136, 2.936201, 3.286914, 3.615114, 23 | 3.924035, 4.216331, 4.494167, 4.759313, 5.013188, 5.257151, 5.491852, 5.719160, 24 | 5.938345, 6.150141, 6.356593, 6.560495, 6.752936, 6.931921, 7.106395, 7.350340, 25 | 7.482915, 7.691728, 7.668494, 7.583591, 7.583591, 8.326501, 8.171210, 8.171210] 26 | alpha_opt = alpha_gaus[31] if self.num_bits > 31 else alpha_gaus[self.num_bits] 27 | return alpha_opt * sigma -------------------------------------------------------------------------------- /encryption/encrypt.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import Iterable 3 | 4 | 5 | class Encrypt(object): 6 | def __init__(self): 7 | self.public_key = None 8 | self.privacy_key = None 9 | 10 | def generate_key(self, n_length=0): 11 | pass 12 | 13 | def set_public_key(self, public_key): 14 | pass 15 | 16 | def get_public_key(self): 17 | pass 18 | 19 | def set_privacy_key(self, privacy_key): 20 | pass 21 | 22 | def get_privacy_key(self): 23 | pass 24 | 25 | def encrypt(self, value): 26 | pass 27 | 28 | def decrypt(self, value): 29 | pass 30 | 31 | def encrypt_list(self, values): 32 | result = [self.encrypt(msg) for msg in values] 33 | return result 34 | 35 | def decrypt_list(self, values): 36 | result = [self.decrypt(msg) for msg in values] 37 | return result 38 | 39 | def distribute_decrypt(self, X): 40 | decrypt_table = X.mapValues(lambda x: self.decrypt(x)) 41 | return decrypt_table 42 | 43 | def distribute_encrypt(self, X): 44 | encrypt_table = X.mapValues(lambda x: self.encrypt(x)) 45 | return encrypt_table 46 | 47 | def _recursive_func(self, obj, func): 48 | if isinstance(obj, np.ndarray): 49 | if len(obj.shape) == 1: 50 | return np.reshape([func(val) for val in obj], obj.shape) 51 | else: 52 | return np.reshape([self._recursive_func(o, func) for o in obj], obj.shape) 53 | elif isinstance(obj, Iterable): 54 | return type(obj)( 55 | self._recursive_func(o, func) if isinstance(o, Iterable) else func(o) for o in obj) 56 | else: 57 | return func(obj) 58 | 59 | def recursive_encrypt(self, X): 60 | return self._recursive_func(X, self.encrypt) 61 | 62 | def recursive_decrypt(self, X): 63 | return self._recursive_func(X, self.decrypt) 64 | -------------------------------------------------------------------------------- /utils/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | 6 | # cosine similarity 7 | def importance_get(g1,g2): 8 | flat_tensor1 = torch.from_numpy(np.array(g1)) 9 | flat_tensor2 = torch.from_numpy(np.array(g2)) 10 | cos_sim = torch.nn.functional.cosine_similarity(flat_tensor1, flat_tensor2, dim=-1) 11 | return cos_sim 12 | 13 | 14 | # sort according to importance 15 | def importance_sort(list_imp): 16 | dict_sort = sorted(enumerate(list_imp), key=lambda list_imp:list_imp[1]) # x[1]是因为在enumerate(a)中,a数值在第1位 17 | list_sort = sorted(list_imp) 18 | list_index = [x[0] for x in dict_sort] 19 | return list_sort,list_index 20 | 21 | 22 | # dynamic sampling 23 | def dynamic_B(list_imp,threshold,B): 24 | K = 0 25 | L = len(list_imp) 26 | for i in range(L): 27 | if list_imp[i] > threshold: 28 | K += 1 29 | B = max(B-1,L-K) 30 | return B,K 31 | 32 | 33 | # prob for client sync 34 | def get_p(list_imp,B,L): 35 | list_k = [] 36 | list_p = [0] * L 37 | for i in range(len(list_imp)): 38 | for k in range(L,0,-1): 39 | tmp = sum(list_imp[:k]) 40 | if B + k-1 - L <= list_imp[i]/tmp: 41 | list_k.append(k) 42 | list_p[i] = ((B+k-L)*list_imp[i]/tmp) 43 | break 44 | 45 | K = max(list_k) 46 | for i in range(L-1,K-1,-1): 47 | list_p[i] = 1 48 | return list_p 49 | 50 | 51 | def client_sampling(list_p,global_index): 52 | list_index = [] 53 | for i in range(len(list_p)): 54 | rand = random.random() 55 | if list_p[i] >= rand: 56 | list_index.append(global_index[i]) 57 | return sorted(list_index) 58 | 59 | 60 | def Adaptive_samping(global_importance,threshold,B): 61 | global_sort,global_index = importance_sort(global_importance) 
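# The remaining steps follow the adaptive sampling scheme: dynamic_B() shrinks the
# sampling budget B according to how many importance scores exceed the threshold,
# get_p() converts the sorted scores into per-client inclusion probabilities, and
# client_sampling() draws the participating clients with independent Bernoulli trials.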
62 | B,K = dynamic_B(global_importance,threshold,B) 63 | global_p = get_p(global_sort,B,len(global_sort)) 64 | global_index = client_sampling(global_p,global_index) 65 | B_real_list.append(len(global_index)) 66 | return global_index 67 | 68 | 69 | def Adaptive_samping_bar(global_importance,threshold,B,B_real_list): 70 | global_sort,global_index = importance_sort(global_importance) 71 | B,K = dynamic_B(global_importance,threshold,B) 72 | global_p = get_p(global_sort,B,len(global_sort)) 73 | global_index = client_sampling(global_p,global_index) 74 | B_real_list.append(len(global_index)) 75 | 76 | return global_index,B,B_real_list 77 | -------------------------------------------------------------------------------- /utils/data_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def noniid_split_dirichlet(train_labels, alpha, n_clients): 5 | """ 6 | Splits a list of data indices with corresponding labels 7 | into subsets according to a dirichlet distribution with parameter 8 | alpha 9 | Args: 10 | train_labels: ndarray of train_labels 11 | alpha: the parameter of dirichlet distribution 12 | n_clients: number of clients 13 | Returns: 14 | client_idcs: a list containing sample idcs of clients 15 | """ 16 | 17 | n_classes = train_labels.max()+1 18 | label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes) 19 | 20 | class_idcs = [np.argwhere(train_labels==y).flatten() 21 | for y in range(n_classes)] 22 | 23 | 24 | client_idcs = [[] for _ in range(n_clients)] 25 | for c, fracs in zip(class_idcs, label_distribution): 26 | for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1]*len(c)).astype(int))): 27 | client_idcs[i] += [idcs] 28 | 29 | client_idcs = [np.concatenate(idcs) for idcs in client_idcs] 30 | 31 | return client_idcs 32 | 33 | def noniid_split_pathological(train_labels, n_classes_per_client, n_clients): 34 | n_classes = train_labels.max()+1 35 | data_idcs = list(range(len(train_labels))) 36 | label2index = {k: [] for k in range(n_classes)} 37 | for idx in data_idcs: 38 | label = train_labels[idx] 39 | label2index[label].append(idx) 40 | 41 | sorted_idcs = [] 42 | for label in label2index: 43 | sorted_idcs += label2index[label] 44 | 45 | def iid_divide(l, g): 46 | num_elems = len(l) 47 | group_size = int(len(l) / g) 48 | num_big_groups = num_elems - g * group_size 49 | num_small_groups = g - num_big_groups 50 | glist = [] 51 | for i in range(num_small_groups): 52 | glist.append(l[group_size * i: group_size * (i + 1)]) 53 | bi = group_size * num_small_groups 54 | group_size += 1 55 | for i in range(num_big_groups): 56 | glist.append(l[bi + group_size * i:bi + group_size * (i + 1)]) 57 | return glist 58 | 59 | 60 | n_shards = n_clients * n_classes_per_client 61 | 62 | shards = iid_divide(sorted_idcs, n_shards) 63 | np.random.shuffle(shards) 64 | 65 | tasks_shards = iid_divide(shards, n_clients) 66 | 67 | clients_idcs = [[] for _ in range(n_clients)] 68 | for client_id in range(n_clients): 69 | for shard in tasks_shards[client_id]: 70 | 71 | clients_idcs[client_id] += shard 72 | 73 | return clients_idcs -------------------------------------------------------------------------------- /encryption/bfv.py: -------------------------------------------------------------------------------- 1 | import tenseal as ts 2 | import numpy as np 3 | from encryption.quantize import quantize, unquantize 4 | import random 5 | 6 | 7 | def bfv_enc(plain_list,bfv_ctx,args): 8 | isBatch=args.isBatch 9 | 
batch_size=args.enc_batch_size 10 | topk=args.topk 11 | is_spars = args.isSpars 12 | 13 | quan_bits = args.quan_bits 14 | plain_quan = quantize(plain_list,quan_bits,args.n_clients).tolist() 15 | 16 | batch_num = int(np.ceil(len(plain_list) / batch_size)) 17 | if isBatch: 18 | # padding 19 | if len(plain_list) % batch_num != 0: 20 | padding_num = batch_num * batch_size - len(plain_list) 21 | plain_list.extend([0]*padding_num) 22 | plain_quan.extend([0]*padding_num) 23 | if is_spars == 'topk': 24 | topk = int(np.ceil(batch_num * topk)) 25 | sign = np.sign(np.array(plain_list)) 26 | tmp_list = (np.array(plain_list) * sign).tolist() 27 | plain_batchs = [tmp_list[i * batch_size : (i+1) * batch_size ]for i in range(batch_num)] 28 | avg_list = [np.average(np.abs(batch)) for batch in plain_batchs] 29 | max_avg_list = np.sort(avg_list)[::-1][:topk] 30 | mask_list = [] 31 | for i in range(len(max_avg_list)): 32 | mask_list.append(avg_list.index(max_avg_list[i])) 33 | mask_list.sort() 34 | 35 | res_mask = [0 for i in range(batch_num) ] 36 | for i in range(batch_num): 37 | if i in mask_list: 38 | res_mask[i] = 1 39 | # batch for encryption 40 | plain_list = [plain_quan[mask_list[i] * batch_size : (mask_list[i] + 1) * batch_size] for i in range(len(mask_list))] 41 | 42 | cipher_list = [] 43 | for i in range(len(mask_list)): 44 | cipher = ts.bfv_vector(bfv_ctx,plain_list[i]) 45 | cipher_list.append(cipher.serialize()) 46 | return cipher_list, res_mask 47 | else: 48 | cipher_list = [] 49 | for i in range(batch_num): 50 | cipher = ts.bfv_vector(bfv_ctx, plain_quan[i * batch_size : (i + 1) * batch_size]) 51 | cipher_list.append(cipher.serialize()) 52 | return cipher_list 53 | else: 54 | cipher = [ts.bfv_vector(bfv_ctx, [i]).serialize() for i in plain_quan] 55 | return cipher 56 | 57 | 58 | def bfv_dec(cipher_list,bfv_ctx,sk,isBatch,quan_bits,n_clients,sum_masks = [],batch_size = 0): 59 | 60 | if isBatch: 61 | plain_list = [] 62 | for idx, cipher_serial in enumerate(cipher_list): 63 | if cipher_serial == 0: 64 | zero_pad = [0] * batch_size 65 | plain_list.extend(zero_pad) 66 | else: 67 | plain = ts.BFVVector.load(bfv_ctx, cipher_serial).decrypt(sk) 68 | plain = unquantize(plain,quan_bits,n_clients) 69 | if sum_masks != []: 70 | plain = np.array(plain)/sum_masks[idx] 71 | plain_list.extend(plain) 72 | 73 | return np.array(plain_list) 74 | else: 75 | plains = [ts.BFVVector.load(bfv_ctx, i).decrypt() for i in cipher_list] 76 | plains = [plains.squeeze()] 77 | res = [] 78 | for plain in plains: 79 | tmp = unquantize(plain,quan_bits,n_clients) 80 | res.append(tmp) 81 | return np.array(res) 82 | 83 | 84 | -------------------------------------------------------------------------------- /models/resnet50.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch.nn import init 4 | 5 | class LambdaLayer(nn.Module): 6 | def __init__(self, lambd): 7 | super(LambdaLayer, self).__init__() 8 | self.lambd = lambd 9 | 10 | def forward(self, x): 11 | return self.lambd(x) 12 | 13 | def _weights_init(m): 14 | classname = m.__class__.__name__ 15 | if isinstance(m, nn.Conv2d): 16 | init.kaiming_normal_(m.weight) 17 | elif isinstance(m, nn.BatchNorm2d): 18 | init.constant_(m.weight, 1) 19 | init.constant_(m.bias, 0) 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, in_planes, planes, stride=1, option="B"): 25 | super(BasicBlock, self).__init__() 26 | self.conv1 = nn.Conv2d( 27 | in_planes, 
planes, kernel_size=3, stride=stride, padding=1, bias=False) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 30 | stride=1, padding=1, bias=False) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | 33 | self.shortcut = nn.Sequential() 34 | if stride != 1 or in_planes != planes * self.expansion: 35 | if option == "A": 36 | self.shortcut = LambdaLayer(lambda x: 37 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 38 | elif option == "B": 39 | self.shortcut = nn.Sequential( 40 | nn.Conv2d(in_planes, planes * self.expansion, 41 | kernel_size=1, stride=stride, bias=False), 42 | nn.BatchNorm2d(planes * self.expansion) 43 | ) 44 | 45 | def forward(self, x): 46 | out = F.relu(self.bn1(self.conv1(x))) 47 | out = self.bn2(self.conv2(out)) 48 | out += self.shortcut(x) 49 | out = F.relu(out) 50 | return out 51 | 52 | class ResNet(nn.Module): 53 | def __init__(self, block, num_blocks, in_channels=3, num_classes=100): 54 | super(ResNet, self).__init__() 55 | self.in_planes = 16 56 | 57 | self.conv1 = nn.Conv2d( 58 | in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(16) 60 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 61 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 62 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 63 | self.linear = nn.Linear(64 * block.expansion, num_classes) 64 | 65 | self.apply(_weights_init) 66 | 67 | def _make_layer(self, block, planes, num_blocks, stride): 68 | strides = [stride] + [1]*(num_blocks-1) 69 | layers = [] 70 | for stride in strides: 71 | layers.append(block(self.in_planes, planes, stride)) 72 | self.in_planes = planes * block.expansion 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = self.layer1(out) 78 | out = self.layer2(out) 79 | out = self.layer3(out) 80 | out = F.avg_pool2d(out, 8) 81 | out = out.view(out.size(0), -1) 82 | out = self.linear(out) 83 | return out 84 | 85 | def resnet50(in_channels=3, num_classes=100): 86 | return ResNet(BasicBlock, [5, 5, 5], in_channels, num_classes) 87 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import numpy as np 3 | import os 4 | import torch 5 | import random 6 | from models.model import CNN_cifar,LeNet_mnist,CNN_fmnist,resnet20 7 | from sklearn.cluster import KMeans 8 | from models.resnet50 import resnet50 9 | 10 | def model_init(dataset,device): 11 | if dataset=='MNIST': 12 | model = LeNet_mnist().to(device) 13 | elif dataset == 'FashionMNIST': 14 | model = CNN_fmnist().to(device) 15 | elif dataset == 'CIFAR10': 16 | model = resnet20(in_channels=3,num_classes=10).to(device) 17 | #model = CNN_cifar().to(device) 18 | elif dataset == 'CIFAR100': 19 | model = resnet50(in_channels=3,num_classes=100).to(device) 20 | else: 21 | raise ValueError("Datset name is invalid, please input MNIST, FashionMNIST or CIFAR10") 22 | return model 23 | 24 | def logging(str,args): 25 | log_file = open(os.path.join(args.log_dir, args.dataset + '.log'), "a+") 26 | print("{} | {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), str)) 27 | print("{} | {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), str), file=log_file) 28 | 29 | 30 | def init_prop(train_dataset,test_dataset,n_clients): 31 | """ 32 | Initialize 
weights of aggregation according to the samples of clients. 33 | 34 | Args: 35 | train_dataset (`dict`): 36 | Training dataset. 37 | test_dataset (`dict`): 38 | Test dataset. 39 | n_clients (`int`): 40 | Number of clients to participate. 41 | Returns: 42 | train_props (`list`): 43 | Training weight for each client. 44 | test_props (`list`): 45 | Test weight for each client. 46 | """ 47 | client_n_samples_train = [] 48 | client_n_samples_test = [] 49 | for idx in range(n_clients): 50 | client_n_samples_train.append(len(train_dataset[idx])) 51 | client_n_samples_test.append(len(test_dataset[idx])) 52 | samples_sum_train = np.sum(client_n_samples_train) 53 | samples_sum_test = np.sum(client_n_samples_test) 54 | test_props = [] 55 | train_props = [] 56 | for idx in range(n_clients): 57 | train_props.append(client_n_samples_train[idx]/samples_sum_train) 58 | test_props.append(client_n_samples_test[idx]/samples_sum_test) 59 | return train_props,test_props 60 | 61 | def seed_everything(seed): 62 | random.seed(seed) 63 | torch.manual_seed(seed) 64 | np.random.seed(seed) 65 | os.environ["PYTHONHASHSEED"] = str(seed) 66 | 67 | torch.cuda.manual_seed_all(seed) 68 | torch.backends.cudnn.deterministic = True 69 | torch.backends.cudnn.benchmark = False 70 | 71 | def pseudo_random(seed, batch_num, topk, round_t): 72 | random.seed(seed + round_t) 73 | number_list = list(range(batch_num)) 74 | topk = int(np.ceil(batch_num * topk)) 75 | # 对数字列表进行洗牌 76 | random.shuffle(number_list) 77 | random_list = sorted(number_list[:topk]) 78 | 79 | return random_list 80 | 81 | def jaccard_similarity(x, y): 82 | intersection = np.intersect1d(x, y) 83 | union = np.union1d(x, y) 84 | return len(intersection) / len(union) 85 | 86 | def jaccard_distance_matrix(matrix): 87 | n = matrix.shape[0] 88 | dist_matrix = np.zeros((n, n)) 89 | for i in range(n): 90 | for j in range(n): 91 | if i != j: 92 | dist_matrix[i][j] = 1 - jaccard_similarity(matrix[i], matrix[j]) 93 | return dist_matrix 94 | 95 | def jaccard_kmeans_clustering(matrix, k): 96 | distance_matrix = jaccard_distance_matrix(matrix) 97 | clusters = np.zeros(len(matrix)) # Initializing clusters array 98 | while True: 99 | kmeans = KMeans(n_clusters=k, random_state=42) 100 | clusters = kmeans.fit_predict(distance_matrix) 101 | cluster_indices = [np.where(clusters == i)[0] for i in range(k)] 102 | empty_clusters = [i for i, indices in enumerate(cluster_indices) if len(indices) == 0] 103 | 104 | if len(empty_clusters) == 0: 105 | break # No empty clusters, exit loop 106 | 107 | return cluster_indices -------------------------------------------------------------------------------- /fed.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import torch.multiprocessing as mp 5 | import numpy as np 6 | 7 | import utils.min_hash as lsh 8 | from utils.util import logging 9 | from client import client_process 10 | from server import server_process 11 | from utils.dataset import load_dataset,load_exist 12 | from utils.util import init_prop 13 | from client import params_tolist 14 | from models.model import LeNet_mnist,CNN_fmnist,resnet20,CNN_cifar 15 | from models.resnet50 import resnet50 16 | 17 | def model_init(dataset,device): 18 | """ 19 | Model initialization. 20 | 21 | Args: 22 | dataset (`str`): 23 | Name of dataset. 24 | device (`str`): 25 | Training on GPU or MPS or CPU. 26 | 27 | Returns: 28 | model (`OrderDict`): 29 | Model for dataset. 
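Raises: ValueError: If `dataset` is not one of MNIST, FashionMNIST, CIFAR10 or CIFAR100.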
30 | """ 31 | if dataset == 'MNIST': 32 | model = LeNet_mnist().to(device) 33 | elif dataset == 'FashionMNIST': 34 | model = CNN_fmnist().to(device) 35 | #model = resnet20(in_channels=1,num_classes=10).to(device) 36 | elif dataset == 'CIFAR10': 37 | model = resnet20(in_channels=3,num_classes=10).to(device) 38 | elif dataset == 'CIFAR100': 39 | model = resnet50(in_channels=3,num_classes=100).to(device) 40 | else: 41 | raise ValueError("Datset name is invalid, please input MNIST, FashionMNIST, CIFAR10 or CIFAR100") 42 | return model 43 | 44 | 45 | def run(args,kwargs_IPC,device): 46 | """ 47 | Run fucntion to launch server and clients processes. 48 | 49 | Args: 50 | args (`arg_parse`): 51 | Hyper-parameters. 52 | kwargs_IPC (`dict`): 53 | Parameters for IPC communication. 54 | device (`str`): 55 | Training on GPU or MPS or CPU. 56 | Returns: 57 | None. 58 | """ 59 | train_file = os.path.join(args.data_dir, args.dataset + '_train') 60 | if not os.path.exists(train_file): 61 | client_train_datasets, client_test_datasets, data_info,server_test_sets = load_dataset(args) 62 | print("Generate new files!") 63 | else: 64 | client_train_datasets, client_test_datasets, data_info,server_test_sets = load_exist(args) 65 | print("Load last files!") 66 | 67 | train_weights,test_weights = init_prop(client_train_datasets,client_test_datasets, args.n_clients) 68 | 69 | logging("training weights: {}".format(train_weights),args) 70 | logging("testing weights:{}".format(test_weights),args) 71 | for idx in range(args.n_clients): 72 | logging('client{}, train samples {},test samples {}'.format( 73 | idx,len(client_train_datasets[idx]),len(client_test_datasets[idx])),args) 74 | logging("data split finished!",args) 75 | 76 | kwargs = {'batch_size': args.batch_size, 77 | 'shuffle': True,'drop_last':True} 78 | if args.cuda and torch.cuda.is_available(): 79 | kwargs.update({'num_workers': 0, 80 | 'pin_memory': True, 81 | }) 82 | 83 | 84 | model = model_init(args.dataset,device) 85 | params_list,params_num,layer_shape = params_tolist(model) 86 | total_sum = sum(params_num.values()) 87 | 88 | # enc_tool for paillier algorithm 89 | if args.enc and args.algorithm == 'paillier': 90 | enc_tools = kwargs_IPC['enc_tools'] 91 | enc_tools.update({'total_params':total_sum}) 92 | kwargs_IPC.update({'enc_tools':enc_tools}) 93 | 94 | # number of batchs for processing 95 | batch_num = int(np.ceil(total_sum / args.enc_batch_size)) 96 | if args.enc and args.isBatch: 97 | logging("Batch num:{}".format(batch_num),args) 98 | 99 | if args.isSelection: 100 | random_R = lsh.gen_random_R(input_len = total_sum, sim_len=args.sim_len) 101 | kwargs_IPC.update({'random_R':random_R,}) 102 | 103 | 104 | processes = [] 105 | for rank in range(args.n_clients+1): 106 | # for server 107 | if rank == 0: 108 | p = mp.Process(target=server_process,args=(args,kwargs_IPC,total_sum,batch_num,train_weights,test_weights,server_test_sets,kwargs)) 109 | # for clients 110 | else: 111 | p = mp.Process(target=client_process, args=(rank-1, args, model, device, 112 | client_train_datasets[rank-1], client_test_datasets[rank-1], kwargs,kwargs_IPC,train_weights)) 113 | p.start() 114 | processes.append(p) 115 | for p in processes: 116 | p.join() 117 | 118 | 119 | logging('Final End',args) 120 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import 
torch.nn.init as init 5 | 6 | 7 | 8 | 9 | class LeNet_mnist(nn.Module): 10 | def __init__(self, in_channels=1, num_classes=10): 11 | super(LeNet_mnist, self).__init__() 12 | self.conv1 = nn.Conv2d(1, 6, 5) 13 | self.pool1 = nn.MaxPool2d(2, 2) 14 | self.conv2 = nn.Conv2d(6, 16, 5) 15 | self.pool2 = nn.MaxPool2d(2, 2) 16 | self.conv3 = nn.Conv2d(16, 120, 4) 17 | self.fc1 = nn.Linear(120, 84) 18 | self.fc2 = nn.Linear(84, 10) 19 | 20 | def forward(self, x): 21 | x = F.relu(self.conv1(x)) # input(3, 32, 32) output(16, 28, 28) 22 | x = self.pool1(x) # output(16, 14, 14) 23 | x = F.relu(self.conv2(x)) # output(32, 10, 10) 24 | x = self.pool2(x) # output(32, 5, 5) 25 | x = F.relu(self.conv3(x)) 26 | x = x.view(-1, 120) # output(32*5*5) 27 | x = F.relu(self.fc1(x)) # output(120) 28 | x = self.fc2(x) # output(10) 29 | return x 30 | 31 | class CNN_cifar(nn.Module): 32 | def __init__(self, in_channels=3, num_classes=10): 33 | super(CNN_cifar, self).__init__() 34 | self.conv1 = nn.Conv2d(in_channels, 6, 5) 35 | self.pool = nn.MaxPool2d(2, 2) 36 | self.conv2 = nn.Conv2d(6, 16, 5) 37 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 38 | self.fc2 = nn.Linear(120, 84) 39 | self.fc3 = nn.Linear(84, num_classes) 40 | 41 | def forward(self, x): 42 | x = self.pool(F.relu(self.conv1(x))) 43 | x = self.pool(F.relu(self.conv2(x))) 44 | x = x.view(-1, 16 * 5 * 5) 45 | x = F.relu(self.fc1(x)) 46 | x = F.relu(self.fc2(x)) 47 | x = self.fc3(x) 48 | return x 49 | 50 | class CNN_fmnist(nn.Module): 51 | def __init__(self,in_channels=1, num_classes=10): 52 | super(CNN_fmnist, self).__init__() 53 | self.layer1 = nn.Sequential( 54 | nn.Conv2d(in_channels, 16, kernel_size=(5, 5), padding=2), 55 | nn.BatchNorm2d(16), 56 | nn.ReLU() 57 | ) 58 | self.pool1 = nn.MaxPool2d(2) 59 | self.layer2 = nn.Sequential( 60 | nn.Conv2d(16, 32, kernel_size=(3, 3)), 61 | nn.BatchNorm2d(32), 62 | nn.ReLU() 63 | ) 64 | self.layer3 = nn.Sequential( 65 | nn.Conv2d(32, 64, kernel_size=(3, 3)), 66 | nn.BatchNorm2d(64), 67 | nn.ReLU() 68 | ) 69 | self.pool2 = nn.MaxPool2d(2) 70 | self.fc = nn.Linear(5 * 5 * 64, num_classes) 71 | 72 | def forward(self, x): 73 | out = self.pool1(self.layer1(x)) 74 | out = self.pool2(self.layer3(self.layer2(out))) 75 | out = out.view(out.size(0), -1) 76 | out = self.fc(out) 77 | return out 78 | 79 | 80 | # resnet20 81 | def resnet20(in_channels, num_classes): 82 | return ResNet(BasicBlock, [3, 3, 3], in_channels, num_classes) 83 | # resnet32 84 | def resnet32(in_channels, num_classes): 85 | return ResNet(BasicBlock, [5, 5, 5], in_channels, num_classes) 86 | 87 | class LambdaLayer(nn.Module): 88 | def __init__(self, lambd): 89 | super(LambdaLayer, self).__init__() 90 | self.lambd = lambd 91 | 92 | def forward(self, x): 93 | return self.lambd(x) 94 | 95 | def _weights_init(m): 96 | classname = m.__class__.__name__ 97 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 98 | init.kaiming_normal_(m.weight) 99 | 100 | class BasicBlock(nn.Module): 101 | expansion = 1 102 | 103 | def __init__(self, in_planes, planes, stride=1, option="B"): 104 | super(BasicBlock, self).__init__() 105 | self.conv1 = nn.Conv2d( 106 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 107 | self.bn1 = nn.BatchNorm2d(planes) 108 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 109 | stride=1, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(planes) 111 | 112 | self.shortcut = nn.Sequential() 113 | if stride != 1 or in_planes != planes: 114 | if option == "A": 115 | """For CIFAR10 ResNet paper uses option A. 
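Option A is the parameter-free shortcut: the input is downsampled by strided slicing and the extra channels are zero-padded. Option B below instead uses a 1x1 convolution followed by BatchNorm as a projection shortcut.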
116 | """ 117 | self.shortcut = LambdaLayer(lambda x: 118 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 119 | elif option == "B": 120 | self.shortcut = nn.Sequential( 121 | nn.Conv2d(in_planes, self.expansion * planes, 122 | kernel_size=1, stride=stride, bias=False), 123 | nn.BatchNorm2d(self.expansion * planes) 124 | ) 125 | 126 | def forward(self, x): 127 | out = self.conv1(x) 128 | out = self.bn1(out) 129 | out = F.relu(out) 130 | 131 | out = self.conv2(out) 132 | out = self.bn2(out) 133 | 134 | out += self.shortcut(x) 135 | out = F.relu(out) 136 | 137 | return out 138 | 139 | class ResNet(nn.Module): 140 | def __init__(self, block, num_blocks, in_channels=3, num_classes=10): 141 | super(ResNet, self).__init__() 142 | self.in_planes = 16 143 | 144 | self.conv1 = nn.Conv2d( 145 | in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False) 146 | self.bn1 = nn.BatchNorm2d(16) 147 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 148 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 149 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 150 | self.linear = nn.Linear(64, num_classes) 151 | 152 | self.apply(_weights_init) 153 | 154 | def _make_layer(self, block, planes, num_blocks, stride): 155 | strides = [stride] + [1]*(num_blocks-1) 156 | layers = [] 157 | for stride in strides: 158 | layers.append(block(self.in_planes, planes, stride)) 159 | self.in_planes = planes * block.expansion 160 | 161 | return nn.Sequential(*layers) 162 | 163 | def forward(self, x): 164 | out = self.conv1(x) # -> (batch, 16, 32, 32) 165 | out = self.bn1(out) 166 | out = F.relu(out) 167 | 168 | out = self.layer1(out) # -> (batch, 16, 32, 32) 169 | out = self.layer2(out) # -> (batch, 32, 16, 16) 170 | out = self.layer3(out) # -> (batch, 64, 8, 8) 171 | 172 | out = F.avg_pool2d(out, out.size()[3]) # -> (batch, 64, 1, 1) 173 | out = out.view(out.size(0), -1) # -> (batch, 64) 174 | out = self.linear(out) # -> (batch, num_classes) 175 | 176 | return out 177 | -------------------------------------------------------------------------------- /encryption/quantize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from encryption.aciq import ACIQ 3 | from encryption.complement import true2two, two2true 4 | 5 | 6 | 7 | precision = 5 8 | 9 | # get aciq alpha 10 | def get_alpha_r_max(plains,elements_bits, num_clients): 11 | list_min = np.min(plains) 12 | list_max = np.max(plains) 13 | list_size = len(plains) 14 | 15 | aciq = ACIQ(elements_bits) 16 | alpha = aciq.get_alpha_gaus(list_min,list_max,list_size) 17 | r_max = alpha * num_clients 18 | return alpha, r_max 19 | 20 | # quantize and padding return two_complement 21 | def quantize_padding(value, alpha, quan_bits, num_clients): 22 | # clipping 23 | value = np.clip(value, -alpha, alpha) 24 | 25 | # quantizing 26 | sign = np.sign(value) 27 | unsigned_value = value * sign 28 | unsigned_value = unsigned_value * (pow(2, quan_bits - 1) - 1.0) / (alpha * num_clients) 29 | value = unsigned_value * sign 30 | 31 | # stochastic round 32 | size = value.shape 33 | value = np.floor(value + np.random.random(size)).astype(int) 34 | value = value.astype(object) 35 | 36 | # add padding bits and show with complement 37 | padding_bits = int(np.ceil(np.log2(num_clients))) 38 | res_val = [] 39 | for elem in value: 40 | res_val.append(true2two(elem, padding_bits, quan_bits)) 41 | return res_val 42 | 43 | # unquantize and return true value 44 | 
def unquantize_padding(value, alpha, quan_bits, num_clients): 45 | value = np.array(value) 46 | # stochastic round 47 | size = value.shape 48 | value = np.floor(value + np.random.random(size)).astype(int) 49 | value = value.astype(object) 50 | padding_bits = int(np.ceil(np.log2(num_clients))) 51 | #alpha *= pow(2,padding_bits) 52 | 53 | # extract 2's complement to true value 54 | res_value = [] 55 | for elem in value: 56 | # if elem < 0: 57 | # raise ValueError("Overflow {}".format(elem)) 58 | # # print("Overflow {}".format(elem)) 59 | # # res_value.append(0) 60 | # # continue 61 | if elem == 0: 62 | res_value.append(elem) 63 | continue 64 | #elem &= ((1 << (quan_bits-1)) - 1) 65 | elem = two2true(elem, padding_bits, quan_bits) 66 | res_value.append(elem) 67 | value = res_value 68 | 69 | # unquantize 70 | sign = np.sign(value) 71 | unsigned_value = value * sign 72 | unsigned_value = unsigned_value * (alpha * num_clients) / (pow(2, quan_bits - 1) - 1.0) 73 | value = unsigned_value * sign 74 | return value 75 | 76 | def quan_no_compl(value, quan_bits, num_clients): 77 | alpha = 1.0 78 | # clipping 79 | value = np.clip(value, -alpha, alpha) 80 | 81 | # quantizing 82 | sign = np.sign(value) 83 | unsigned_value = value * sign 84 | unsigned_value = unsigned_value * (pow(2, quan_bits - 1) - 1.0) / (alpha * num_clients) 85 | value = unsigned_value * sign 86 | 87 | # stochastic round 88 | size = value.shape 89 | value = np.floor(value + np.random.random(size)).astype(int) 90 | value = value.astype(object) 91 | 92 | return value 93 | 94 | def unquan_no_compl(value, quan_bits, num_clients): 95 | alpha = 1.0 96 | value = np.array(value) 97 | # stochastic round 98 | size = value.shape 99 | value = np.floor(value + np.random.random(size)).astype(int) 100 | value = value.astype(object) 101 | 102 | # unquantize 103 | sign = np.sign(value) 104 | unsigned_value = value * sign 105 | unsigned_value = unsigned_value * (alpha * num_clients) / (pow(2, quan_bits - 1) - 1.0) 106 | value = unsigned_value * sign 107 | return value 108 | 109 | def quantize_postive(value,quan_bits,num_clients): 110 | alpha = 1.0 111 | # clipping 112 | value = np.clip(value, -alpha, alpha) 113 | 114 | value = np.add(value,1.0) 115 | print(np.max(value)) 116 | 117 | # quantizing 118 | value = value * (pow(2, quan_bits - 1) - 1.0) / (alpha * num_clients) 119 | 120 | # stochastic round 121 | size = value.shape 122 | value = np.floor(value + np.random.random(size)).astype(int) 123 | value = value.astype(object) 124 | 125 | return value 126 | 127 | def unquantize_postive(value,quan_bits,num_clients,idx_weights): 128 | value = np.array(value) 129 | # stochastic round 130 | size = value.shape 131 | value = np.floor(value + np.random.random(size)).astype(int) 132 | value = value.astype(object) 133 | alpha = 1.0 134 | 135 | # unquantize 136 | value = value * (alpha * num_clients) / (pow(2, quan_bits - 1) - 1.0) 137 | #idx_weights = [0.14,0.35,0.45,0.5] 138 | 139 | offset = np.sum(idx_weights) 140 | value = np.add(value, -1 * offset) 141 | return value 142 | 143 | # quantize function 144 | def quantize(plains,element_bits,num_clients): 145 | #alpha,r_max = get_alpha_r_max(plains,element_bits,num_clients) 146 | alpha = 5.0 147 | quantized = quantize_padding(plains,alpha,element_bits,num_clients) 148 | return np.array(quantized) 149 | 150 | # unquantize function 151 | def unquantize(plains, element_bits, num_clients, alpha = 5.0): 152 | unquantized = unquantize_padding(plains, alpha, element_bits, num_clients) 153 | unquantized = 
np.array(unquantized) 154 | return np.round(unquantized, precision) 155 | 156 | # batch elems function 157 | def batch_padding(array, max_bits, elem_bits,batch_size): 158 | elem_bits = elem_bits + 4 159 | array = array.tolist() 160 | #batch_size = max_bits // elem_bits 161 | if len(array) % batch_size != 0: 162 | pad_zero_nums = batch_size - len(array) % batch_size 163 | array += [0] * pad_zero_nums 164 | 165 | # how many batches is needed 166 | batch_nums = len(array) // batch_size 167 | 168 | # carry batches 169 | batch_list = [] 170 | mod = pow(2, elem_bits) 171 | 172 | # for each batch 173 | for b in range(batch_nums): 174 | tmp = 0 175 | # for each elem 176 | for i in range(batch_size): 177 | tmp *= mod 178 | tmp += array[i + b * batch_size] 179 | batch_list.append(tmp) 180 | 181 | return np.array(batch_list).astype(object) 182 | 183 | def unbatching_padding(array, elem_bits,batch_size ): 184 | elem_bits = elem_bits + 4 185 | res = [] 186 | mask = pow(2,elem_bits) - 1 187 | 188 | for item in array: 189 | item = int(item) 190 | if item == 0 : 191 | tmp = [0] * batch_size 192 | res.extend(tmp) 193 | continue 194 | tmp = [] 195 | for i in range(batch_size): 196 | num = item & mask 197 | item >>= elem_bits 198 | tmp.append(num) 199 | tmp.reverse() 200 | res += tmp 201 | res = np.array(res) 202 | 203 | return res -------------------------------------------------------------------------------- /utils/min_hash.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import warnings 4 | from sklearn.cluster import KMeans 5 | from sklearn.metrics import jaccard_score 6 | from gap_statistic import OptimalK 7 | 8 | warnings.filterwarnings('ignore', category=FutureWarning) 9 | warnings.filterwarnings('ignore', category=UserWarning) 10 | 11 | from sklearn.metrics.pairwise import cosine_similarity 12 | from utils.util import jaccard_kmeans_clustering 13 | 14 | def sigGen(matrix,randomSeq): 15 | """ 16 | * generate the signature vector 17 | :param matrix: a ndarray var 18 | :return a signature vector: a list var 19 | """ 20 | # initialize the sig vector as [-1, -1, ..., -1] 21 | result = [-1 for i in range(matrix.shape[1])] 22 | 23 | count = 0 24 | 25 | for row in randomSeq: 26 | for i in range(matrix.shape[1]): 27 | if matrix[row][i] != 0 and result[i] == -1: 28 | result[i] = row 29 | count += 1 30 | if count == matrix.shape[1]: 31 | break 32 | return result 33 | 34 | def sigMatrixGen(input_matrix,random_R, n): 35 | """ 36 | generate the sig matrix 37 | :param input_matrix: naarray var 38 | :param n: the row number of sig matrix which we set 39 | :return sig matrix: ndarray var 40 | """ 41 | 42 | result = [] 43 | 44 | for i in range(n): 45 | sig = sigGen(input_matrix,random_R[i]) 46 | result.append(sig) 47 | return np.array(result) 48 | 49 | def quan_params(input_matrix, threshold): 50 | for idx, row in enumerate(input_matrix): 51 | for i in range(len(row)): 52 | if abs(row[i]) < threshold: 53 | input_matrix[idx][i] = 0 54 | else: 55 | input_matrix[idx][i] = 1 56 | 57 | return input_matrix 58 | 59 | def real_sim(input_matrix): 60 | row_num = input_matrix.shape[0] 61 | total = 0 62 | sim = 0 63 | for row in range(row_num): 64 | if input_matrix[row][0] == 1 or input_matrix[row][1] == 1: 65 | total += 1 66 | if input_matrix[row][0] == input_matrix[row][1]: 67 | sim += 1 68 | return sim / total 69 | 70 | def dim_reduce_sim(input_matrix): 71 | row_num = input_matrix.shape[0] 72 | sim = 0 73 | for row in range(row_num): 74 | if 
input_matrix[row][0] == input_matrix[row][1]: 75 | sim += 1 76 | return sim / row_num 77 | 78 | 79 | def gen_random_R(input_len, sim_len): 80 | """ 81 | Random matrix is needed for sketch-based client selection. 82 | 83 | Args: 84 | input_len (`int`): 85 | Input dimension. 86 | sim_len (`int`): 87 | Output dimension. 88 | Returns: 89 | random_R (`list`): 90 | Random matrix. 91 | """ 92 | 93 | random_R = [] 94 | for i in range(sim_len): 95 | seq_list = np.arange(input_len) 96 | np.random.shuffle(seq_list) 97 | random_R.append(seq_list) 98 | return random_R 99 | 100 | ''' 101 | def gap_statistic(data,max_k): 102 | optimalK = OptimalK(n_jobs=4, parallel_backend='joblib') 103 | n_clusters = optimalK(data, cluster_array=np.arange(1, max_k+1)) 104 | return n_clusters 105 | ''' 106 | 107 | def sse_statistic(X,max_k): 108 | sse = [] 109 | for k in range(1, max_k+1): 110 | kmeans = KMeans(n_clusters=k).fit(X) 111 | sse.append(kmeans.inertia_) 112 | diff = np.diff(sse) 113 | diff_r = diff[1:] / diff[:-1] 114 | k_opt = np.argmax(diff_r) + 2 115 | return k_opt 116 | 117 | def gap_statistic(data,max_k): 118 | data = np.array(data).reshape(-1,1) 119 | optimalK = OptimalK(n_jobs=20, parallel_backend='joblib') 120 | n_clusters = optimalK(data, cluster_array=np.arange(1, max_k+1)) 121 | return n_clusters 122 | 123 | def clusters_selection_L2(hash_mat,max_k,train_weights=[],weights_clusters=[]): 124 | k_opt = gap_statistic(hash_mat.astype(float),max_k) 125 | clusters = jaccard_kmeans_clustering(hash_mat, k_opt) 126 | tmp = hash_mat.tolist() 127 | rep_num = [1] * len(train_weights) 128 | selected_clients = [] 129 | print("Num:{} ,Next round Selected clients:{}".format(k_opt, clusters)) 130 | 131 | for i, indices in enumerate(clusters): 132 | tmp_weights = [train_weights[i] for i in indices] 133 | tmp_list = list(indices) 134 | zipped = zip(tmp_weights, tmp_list) 135 | max_tuple = max(zipped, key=lambda x: x[0]) 136 | client_index = max_tuple[1] 137 | if train_weights != []: 138 | rep_num[client_index] = len(indices) 139 | tmp = 0 140 | for idx in indices: 141 | tmp += train_weights[idx] 142 | weights_clusters[client_index] = tmp 143 | selected_clients.append(client_index) 144 | return selected_clients,rep_num 145 | 146 | def clusters_selection(hash_mat,max_k,train_weights=[],weights_clusters=[]): 147 | k_opt = gap_statistic(hash_mat.astype(float),max_k) 148 | while True: 149 | kmeans = KMeans(n_clusters=k_opt).fit(hash_mat) 150 | cluster_counts = np.unique(kmeans.labels_) 151 | if len(cluster_counts) == k_opt: 152 | break 153 | labels = kmeans.labels_ 154 | unique_labels = np.unique(labels) 155 | selected_clients = [] 156 | 157 | rep_num = [1] * len(train_weights) 158 | for label in unique_labels: 159 | indices = np.where(labels == label)[0] 160 | tmp_weights = [train_weights[i] for i in indices] 161 | tmp_list = list(indices) 162 | zipped = zip(tmp_weights, tmp_list) 163 | max_tuple = max(zipped, key=lambda x: x[0]) 164 | client_index = max_tuple[1] 165 | 166 | if train_weights != []: 167 | rep_num[client_index] = len(indices) 168 | 169 | if train_weights != []: 170 | tmp = 0 171 | for idx in indices: 172 | tmp += train_weights[idx] 173 | weights_clusters[client_index] = tmp 174 | selected_clients.append(client_index) 175 | 176 | return selected_clients,rep_num 177 | 178 | def client_selection(sampled_clients,labels,train_weights=[],weights_clusters=[]): 179 | unique_labels = np.unique(labels) 180 | selected_clients = [] 181 | rep_num = [1] * len(train_weights) 182 | for label in unique_labels: 183 | 
indices = np.where(labels == label)[0] # 找到属于当前标签的数据点索引 184 | retA = [i for i in indices if i in sampled_clients] 185 | if retA == []: 186 | continue 187 | else: 188 | client_index = np.random.choice(np.array(retA)) 189 | 190 | if train_weights != []: 191 | rep_num[client_index] = len(indices) 192 | 193 | if train_weights != []: 194 | tmp = 0 195 | for idx in retA: 196 | tmp += train_weights[idx] 197 | weights_clusters[client_index] = tmp 198 | selected_clients.append(client_index) 199 | return selected_clients,rep_num 200 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Efficient and Straggler-Resistant Homomorphic Encryption for Heterogeneous Federated Learning [[paper](https://ieeexplore.ieee.org/abstract/document/10621440)] 2 | 3 | > Nan Yan, Yuqing Li, Jing Chen, Xiong Wang, Jianan Hong, Kun He, and Wei Wang. *IEEE INFOCOM 2024* 4 | 5 | ![](https://cdn.jsdelivr.net/gh/lunan0320/pics@main/images/202403/image-20240320201923223.png) 6 | 7 | # FedPHE: A Secure and Efficient Federated Learning via Packed Homomorphic Encryption[[paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10989521)] 8 | 9 | > Yuqing Li, Nan Yan, Jing Chen, Xiong Wang, Jianan Hong, Kun He, Wei Wang, and Bo Li. *IEEE TDSC* 10 | 11 | ![](./assets/FedPHE.png) 12 | 13 | ## News 14 | 15 | - [2025.05.06] FedPHE has been accepted in IEEE TDSC. 16 | - [2024.03.21] Usage and Acknowledgments section. 17 | - [2024.03.20] FedPHE source code has been released. 18 | - [2023.12.01] Paper has been accepted in IEEE INFOCOM 2024. 19 | 20 | ## Abstract 21 | 22 | Cross-silo federated learning (FL) enables multiple institutions (clients) to collaboratively build a global model without sharing their private data. To prevent privacy leakage during aggregation, homomorphic encryption (HE) is widely used to encrypt model updates, yet incurs high computation and communication overheads. To reduce these overheads, packed HE (PHE) has been proposed to encrypt multiple plaintexts into a single ciphertext. However, the original design of PHE does not consider the heterogeneity among different clients, an intrinsic problem in cross-silo FL, often resulting in undermined training efficiency with slow convergence and stragglers. In this work, we propose FedPHE, an efficiently packed homomorphically encrypted FL framework with secure weighted aggregation and client selection to tackle the heterogeneity problem. Specifically, using CKKS with sparsification, FedPHE can achieve efficient encrypted weighted aggregation by accounting for contributions of local updates to the global model. To mitigate the straggler effect, we devise a sketching-based client selection scheme to cherry-pick representative clients with heterogeneous models and computing capabilities. We show, through rigorous security analysis and extensive experiments, that FedPHE can efficiently safeguard clients’ privacy, achieve a training speedup of 1.85 − 4.44×, cut the communication overhead by 1.24 − 22.62×, and reduce the straggler effect by up to 1.71 − 2.39×. 23 | 24 | ## Citation 25 | 26 | > If you find FedPHE useful or relevant to your research, please kindly cite our paper using the following bibtex. 
27 | 28 | ``` 29 | @inproceedings{yan2024efficient, 30 | title={Efficient and Straggler-Resistant Homomorphic Encryption for Heterogeneous Federated Learning}, 31 | author={Yan, Nan and Li, Yuqing and Chen, Jing and Wang, Xiong and Hong, Jianan and He, Kun and Wang, Wei}, 32 | booktitle={IEEE INFOCOM 2024-IEEE Conference on Computer Communications}, 33 | pages={791--800}, 34 | year={2024}, 35 | organization={IEEE} 36 | } 37 | ``` 38 | 39 | ``` 40 | @article{li2025fedphe, 41 | title={FedPHE: A Secure and Efficient Federated Learning via Packed Homomorphic Encryption}, 42 | author={Li, Yuqing and Yan, Nan and Chen, Jing and Wang, Xiong and Hong, Jianan and He, Kun and Wang, Wei and Li, Bo}, 43 | journal={IEEE Transactions on Dependable and Secure Computing}, 44 | year={2025}, 45 | publisher={IEEE} 46 | } 47 | ``` 48 | 49 | ## Folder Structure 50 | 51 | ``` 52 | ├── workspace 53 | │ └── data 54 | │ └── log 55 | │ └── encryption 56 | | | └── paillier,bfv,ckks 57 | │ ├── models 58 | │ ├── utils 59 | │ │ └── dataset 60 | │ │ └── min_hash 61 | │ ├── client 62 | │ ├── server 63 | │ ├── fed 64 | │ ├── main 65 | ``` 66 | 67 | ## Usage 68 | 69 | ### Installation 70 | 71 | ``` 72 | # create an environment called "FedPHE" 73 | conda create -n FedPHE python 74 | conda activate FedPHE 75 | 76 | # git clone the repo first 77 | git clone git@github.com:lunan0320/FedPHE.git 78 | 79 | cd FedPHE 80 | mkdir data, log, data_dir 81 | 82 | # install the correct packages required 83 | pip install -r requirements.txt 84 | ``` 85 | 86 | ### Run 87 | 88 | ``` 89 | python main.py --dataset {dataset_name} --epochs {epoch_num} --lr {learing_rate} --n_clients {client_num} --topk {sparse_rate} --algorithm {HE_algo} --enc_batch_size {pack_size} --sim_len {hash_func_num} 90 | ``` 91 | 92 | ### Reproduce our results 93 | 94 | In this repository, you will find all the necessary components to reproduce the results from our research. The example instruction is outlined below: 95 | 96 | ``` 97 | python main.py --dataset MNIST --epochs 100 --lr 0.001 --n_clients 8 --topk 0.1 --algorithm ckks --enc_batch_size 4096 --sim_len 200 --enc True --isSelection True --isSpars topk 98 | ``` 99 | 100 | 1. The package size for `paillier`, `bfv` and `ckks` are different. We always set `80` for paillier and `4096` for bfv and ckks. 101 | 2. You can adaptively choose whether to `encrypt`, whether to `select`, whether to `sparse`, etc. 102 | 3. You can choose which encryption method to use. If you want to calculate a more accurate time cost, you can set `--cipher_count False` (there is also a certain time cost for counting cipher traffic) 103 | 4. Please make sure your `GPU` has enough memory, `inter-process communication` ciphertext may cause `memory overflow` 104 | 105 | > Note: In order to communicate and synchronize between processes, the code uses the `torch.multiprocessing` library. We have set up more `locks`, `events`, `values`, `pipes`, and `queues`. If you are a novice in this area, do not modify these switches, as it may cause running failures. 106 | 107 | ## Results 108 | 109 | - We observe that FedPHE reduces the network footprint for MNIST, FashionMNIST, and CIFAR-10 by up to 16.49×, 16.89×, and 3.31×, respectively, compared to PackedCKKS. Moreover, it outperforms PackedBFV for 4.28−22.62× across three datasets. It is worth noting that the ciphertext size is only 0.81×, 0.77×, and 4.11× compared to the BatchCrypt. 
This indicates the efficiency of FedPHE in reducing the ciphertext generated by CKKS to the level of BatchCrypt's Paillier-based encryption. Additionally, the ciphertext size, which was previously in the “memory out” state as shown in Table I, has been reduced to only 2.07 − 9.88× the plaintext baseline, making FedPHE applicable to FL in practice. In conclusion, FedPHE reduces communication overhead by 1.24× to 22.62× compared to these baselines. 110 | 111 | - As shown in Table III, BatchCrypt requires 4.02 − 7.01× more training time than plaintext. In contrast, FedPHE incurs only 1.58 − 2.17× the training time of the plaintext baseline, greatly enhancing the efficiency of model training. Furthermore, leveraging sparsification and client selection, FedPHE achieves a training speedup of 1.85 − 4.44×. With an appropriate sparsification threshold, FedPHE does not adversely affect the quality of the trained model; instead, it achieves significant compression while maintaining high performance. 112 | 113 | ![](https://cdn.jsdelivr.net/gh/lunan0320/pics@main/images/202403/image-20240320222216046.png) 114 | 115 | ## Acknowledgments 116 | 117 | I'd like to express my gratitude to the following projects and contributors, whose work has been invaluable to this project: 118 | 119 | - [Gap_Statistic](https://github.com/milesgranger/gap_statistic) - for dynamically determining the suggested number of clusters in the data for unsupervised learning. 120 | - [BatchCrypt](https://github.com/marcoszh/BatchCrypt) - for the ATC '20 paper "BatchCrypt: Efficient Homomorphic Encryption for Cross-Silo Federated Learning". 121 | - [FLASHE](https://github.com/SamuelGong/FLASHE) - for the innovative algorithms that inspired our "packed with sparsification" optimization techniques. 122 | - [Datasketch](https://github.com/ekzhu/datasketch) - for the MinHash and LSH data sketches that can process and search very large amounts of data quickly, with little loss of accuracy. 123 | 124 | This project also stands on the shoulders of numerous open-source contributors who make their work freely available for public use and modification. A heartfelt thank you to the open-source community for making this project possible.
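For readers new to packed HE, the following is a minimal sketch of the packed weighted aggregation described above, written with TenSEAL. It is not part of the repository code; the parameters are illustrative and simply mirror the CKKS defaults used in `encryption/ckks.py`, and the weighted updates here are random placeholders.

```python
# Minimal sketch (not repository code): packed CKKS weighted aggregation with TenSEAL.
import tenseal as ts
import numpy as np

# Illustrative parameters matching the defaults in encryption/ckks.py.
ctx = ts.context(ts.SCHEME_TYPE.CKKS, poly_modulus_degree=8192, coeff_mod_bit_sizes=[60, 40, 40, 60])
ctx.global_scale = 2 ** 40
ctx.generate_galois_keys()

# Each client packs a whole batch of local update values into one ciphertext.
client_updates = [np.random.randn(4096) for _ in range(3)]
weights = [0.5, 0.3, 0.2]  # per-client aggregation weights
ciphers = [ts.ckks_vector(ctx, u.tolist()) for u in client_updates]

# The server scales each packed ciphertext by its weight and sums homomorphically.
aggregate = ciphers[0] * weights[0]
for c, w in zip(ciphers[1:], weights[1:]):
    aggregate += c * w

# Decryption recovers (approximately) the weighted sum of the packed updates.
recovered = np.array(aggregate.decrypt())
expected = sum(w * u for w, u in zip(weights, client_updates))
print(np.max(np.abs(recovered - expected)))  # small CKKS approximation error
```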
125 | -------------------------------------------------------------------------------- /encryption/ckks.py: -------------------------------------------------------------------------------- 1 | import tenseal as ts 2 | import numpy as np 3 | from encryption.encrypt import Encrypt 4 | from functools import reduce 5 | from utils.util import pseudo_random 6 | 7 | 8 | class CKKSCipher(Encrypt): 9 | def __init__(self, poly_modulus_degree=8192, 10 | coeff_mod_bit_sizes=None, 11 | global_scale=2**40): 12 | super(CKKSCipher, self).__init__() 13 | 14 | self.context = ts.context(ts.SCHEME_TYPE.CKKS, poly_modulus_degree=8192, coeff_mod_bit_sizes=[60, 40, 40, 60]) 15 | self.context.global_scale=2**40 16 | # ckks_sk=ckks_ctx.secret_key() 17 | # ckks_ctx.make_context_public() 18 | 19 | # self.poly_modulus_degree = poly_modulus_degree 20 | # if coeff_mod_bit_sizes: 21 | # self.coeff_mod_bit_sizes = coeff_mod_bit_sizes 22 | # else: 23 | # self.coeff_mod_bit_sizes = [] # should be this since we do no do any multiplication 24 | # self.global_scale = global_scale 25 | 26 | # self.context = ts.context( 27 | # scheme=ts.SCHEME_TYPE.CKKS, 28 | # poly_modulus_degree=self.poly_modulus_degree, 29 | # coeff_mod_bit_sizes=self.coeff_mod_bit_sizes, 30 | # encryption_type=ts.ENCRYPTION_TYPE.SYMMETRIC 31 | # ) 32 | # self.context.generate_galois_keys() 33 | # self.context.global_scale = self.global_scale 34 | 35 | def from_bytes(self, arr): 36 | if isinstance(arr, list): 37 | ret = [] 38 | for e in arr: 39 | ret.append(self.from_bytes(e)) 40 | return ret 41 | else: 42 | c = ts.CKKSVector.load(self.context, arr) 43 | return c 44 | 45 | def to_bytes(self, arr): 46 | if isinstance(arr, list): 47 | ret = [] 48 | for e in arr: 49 | ret.append(self.to_bytes(e)) 50 | return ret 51 | else: 52 | return arr.serialize() 53 | 54 | def encrypt(self, value): 55 | batch_size = [] 56 | # value should be a 1-d np.array 57 | cipher = ts.ckks_vector(self.context, value) 58 | cipher_serial = cipher.serialize() 59 | return cipher_serial 60 | 61 | def enc_batch(self,value): 62 | batch_size = 50 63 | batch_num = int(np.ceil(len(value) / batch_size)) 64 | cipher_list = [] 65 | for i in range(batch_num): 66 | cipher = ts.ckks_vector(self.context, value[i * batch_size : (i + 1) * batch_size]) 67 | cipher_list.append(cipher.serialize()) 68 | return cipher_list 69 | 70 | 71 | 72 | 73 | 74 | def sum(self, arr,idx_weights): 75 | loaded = [ts.CKKSVector.load(self.context, e)*idx_weights[arr.index(e)] for e in arr] 76 | res = reduce(lambda x, y: x + y, loaded) 77 | return res.serialize() 78 | 79 | def sum_batch(self,arr,idx_weights): 80 | batch_num = len(arr[0]) 81 | res_list = [] 82 | for batch in range(batch_num): 83 | res = 0 84 | for client_cipher in range(len(arr)): 85 | res += ts.CKKSVector.load(self.context, arr[client_cipher][batch]) * idx_weights[client_cipher] 86 | res_list.append(res.serialize()) 87 | return res_list 88 | 89 | def decrypt(self, value): 90 | return np.array(ts.CKKSVector.load(self.context, value).decrypt()) 91 | 92 | 93 | def encrypt_no_batch(self, value): 94 | return [ts.ckks_vector(self.context, [i]).serialize() for i in value] 95 | 96 | def encrypt_no_batch1(self, value): 97 | return [ts.ckks_vector(self.context, [i]) for i in value] 98 | 99 | def sum_no_batch(self, arr): 100 | l = len(arr[0]) 101 | result = [] 102 | for i in range(l): 103 | scalars = [ts.CKKSVector.load(self.context, e[i]) for e in arr] 104 | 105 | result.append(reduce(lambda x, y: x + y, scalars).serialize()) 106 | return result 107 | 108 | def 
sum_no_batch1(self, cipher_lists,idx_weights): 109 | l = len(cipher_lists[0]) 110 | result = [] 111 | for i in range(l): 112 | scalars = [e[i]*idx_weights[cipher_lists.index(e)] for e in cipher_lists] 113 | result.append(reduce(lambda x, y: x + y, scalars)) 114 | return result 115 | 116 | def dec_batch(self,serial_list): 117 | plain_list = [] 118 | for cipher_serial in serial_list: 119 | plain = ts.CKKSVector.load(self.context, cipher_serial).decrypt() 120 | plain_list.extend(plain) 121 | 122 | return np.array(plain_list) 123 | 124 | def decrypt_no_batch(self, value): 125 | return np.array([ts.CKKSVector.load(self.context, i).decrypt() for i in value]) 126 | 127 | def decrypt_no_batch1(self, value): 128 | return np.array([ts.CKKSVector(self.context, i).decrypt() for i in value]) 129 | 130 | def set_context(self, bytes): 131 | context = ts.Context.load(bytes) 132 | self.context = context 133 | 134 | def get_context(self, save_secret_key=False): 135 | return self.context.serialize(save_secret_key=save_secret_key, 136 | save_galois_keys=False, 137 | save_relin_keys=False) 138 | 139 | 140 | def ckks_enc(plain_list,ckks_ctx,isBatch,batch_size,topk,round,randk_seed, is_spars = 'topk'): 141 | batch_num = int(np.ceil(len(plain_list) / batch_size)) 142 | if isBatch: 143 | # padding 144 | if len(plain_list) % batch_num != 0: 145 | padding_num = batch_num * batch_size - len(plain_list) 146 | plain_list.extend([0]*padding_num) 147 | if is_spars == 'topk': 148 | topk = int(np.ceil(batch_num * topk)) 149 | plain_batchs = [plain_list[i * batch_size : (i+1) * batch_size ]for i in range(batch_num)] 150 | # L2 norm 151 | #avg_list = ((np.linalg.norm(plain_batchs,axis=1,keepdims=True)).flatten()).tolist() 152 | # avg 153 | avg_list = [np.average(np.abs(batch)) for batch in plain_batchs] 154 | max_avg_list = np.sort(avg_list)[::-1][:topk] 155 | mask_list = [] 156 | for i in range(len(max_avg_list)): 157 | mask_list.append(avg_list.index(max_avg_list[i])) 158 | mask_list.sort() 159 | res_mask = [0 for i in range(batch_num) ] 160 | for i in range(batch_num): 161 | if i in mask_list: 162 | res_mask[i] = 1 163 | # batch for encryption 164 | plain_list = [plain_list[mask_list[i] * batch_size : (mask_list[i] + 1) * batch_size] for i in range(len(mask_list))] 165 | cipher_list = [] 166 | for i in range(len(mask_list)): 167 | cipher = ts.ckks_vector(ckks_ctx,plain_list[i]) 168 | cipher_list.append(cipher.serialize()) 169 | return cipher_list, res_mask 170 | elif is_spars == 'randk': 171 | randk = topk 172 | plain_batchs = [plain_list[i * batch_size : (i+1) * batch_size ]for i in range(batch_num)] 173 | randk_list = pseudo_random(randk_seed,batch_num,randk,round) 174 | # batch for encryption 175 | plain_list = [plain_list[randk_list[i] * batch_size : (randk_list[i] + 1) * batch_size] for i in range(len(randk_list))] 176 | cipher_list = [] 177 | for i in range(len(randk_list)): 178 | cipher = ts.ckks_vector(ckks_ctx,plain_list[i]) 179 | cipher_list.append(cipher.serialize()) 180 | return cipher_list, randk_list 181 | else: 182 | cipher_list = [] 183 | for i in range(batch_num): 184 | cipher = ts.ckks_vector(ckks_ctx, plain_list[i * batch_size : (i + 1) * batch_size]) 185 | cipher_list.append(cipher.serialize()) 186 | return cipher_list 187 | else: 188 | cipher = [ts.ckks_vector(ckks_ctx, [i]).serialize() for i in plain_list] 189 | return cipher 190 | 191 | 192 | def ckks_dec(cipher_list,ckks_ctx,sk,isBatch,randk_list, sum_masks = [],batch_size = 0): 193 | if isBatch: 194 | # randk align 195 | if randk_list != []: 196 | 
tmp_list = [0] * len(cipher_list) 197 | for i,radnk_idx in enumerate(randk_list): 198 | tmp_list[radnk_idx] = cipher_list[i] 199 | cipher_list = tmp_list 200 | plain_list = [] 201 | for idx, cipher_serial in enumerate(cipher_list): 202 | if cipher_serial == 0: 203 | zero_pad = [0] * batch_size 204 | plain_list.extend(zero_pad) 205 | else: 206 | plain = ts.CKKSVector.load(ckks_ctx, cipher_serial).decrypt(sk) 207 | if sum_masks != []: 208 | plain = np.array(plain)/sum_masks[idx] 209 | plain_list.extend(plain) 210 | return np.array(plain_list) 211 | else: 212 | plains = [ts.CKKSVector.load(ckks_ctx, i).decrypt() for i in cipher_list] 213 | plains = [plains.squeeze()] 214 | 215 | return np.array(plains) 216 | 217 | 218 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import os 5 | import torch 6 | import torch.multiprocessing as mp 7 | from utils.util import logging 8 | import random 9 | import numpy as np 10 | import tenseal as ts 11 | from fed import run 12 | from encryption.paillier import PaillierCipher 13 | 14 | 15 | def arg_parse(): 16 | parser = argparse.ArgumentParser() 17 | 18 | # dataset and parameters 19 | parser.add_argument('--dataset', type=str, default='MNIST', 20 | help='datasets: MNIST, FashionMNIST, CIFAR10, CIFAR100') 21 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 22 | help='input batch size for training (default: 64)') 23 | parser.add_argument('--epochs', type=int, default=50, metavar='N', 24 | help='number of epochs to train (default: 10)') 25 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 26 | help='learning rate (default: 0.001)') 27 | parser.add_argument('--weighted',type=bool,default=True) 28 | parser.add_argument('--n_clients', type=int, default= 8, metavar='N', 29 | help='how many training processes to use (default: 10)') 30 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 31 | help='momentum (default: 0.9)') 32 | 33 | 34 | # data split 35 | parser.add_argument('--n_shards', type=int, default=5, 36 | help='number of shards') 37 | parser.add_argument('--alpha', type=float, default=1, 38 | help='parameter of dirichlet') 39 | parser.add_argument('--sgm', type=float, default=0.3, 40 | help='parameter of unbalance') 41 | parser.add_argument('--split', type=str, default='noniid', 42 | help='split method: iid or non-iid') 43 | parser.add_argument('--noniid_method', type=str, default='dirichlet', 44 | help='noniid method: pathological or dirichlet') 45 | # modules 46 | parser.add_argument('--enc',type=bool,default=True, 47 | help='enc or not') 48 | parser.add_argument('--isSelection',type=bool,default=True, 49 | help='Client selection or not') 50 | parser.add_argument('--isSpars', type=str, default='topk', 51 | help='sparsification method: topk or randk or topk') 52 | # sparsification 53 | parser.add_argument('--topk',type=float,default=0.2, 54 | help='sparfication fraction') 55 | 56 | # encryption paramsy 57 | parser.add_argument('--isBatch',type=bool,default=True, 58 | help='Batch HE or not') 59 | parser.add_argument('--cipher_count',type=bool,default=True, 60 | help='ciphertext size') 61 | parser.add_argument('--algorithm',type=str,default='ckks', 62 | help='HE algorithm: paillier,bfv, ckks') 63 | parser.add_argument('--quan_bits',type=int,default=16, 64 | help='quantification bits') 65 | 
parser.add_argument('--enc_batch_size',type=int,default=4096, 66 | help='Batch Encryption size') 67 | 68 | # selection 69 | parser.add_argument('--sim_len',type=int,default=200, 70 | help='lsh matrix width') 71 | 72 | # device and logdir 73 | parser.add_argument('--cuda', action='store_true', default=True, 74 | help='enables CUDA training') 75 | parser.add_argument('--mps', action='store_true', default=True, 76 | help='enables macOS GPU training') 77 | parser.add_argument("--log_dir", type=str, 78 | default="log", help="directory of logs") 79 | parser.add_argument("--data_dir", type=str, 80 | default="data_dir/", help="directory of logs") 81 | parser.add_argument('--seed', type=int, default=42, help='random seed') 82 | parser.add_argument('--randk_seed', type=int, default=12, help='random k packages seed') 83 | 84 | return parser.parse_args() 85 | 86 | 87 | def seed_everything(seed,is_cuda): 88 | """ 89 | Seed function for randomization 90 | 91 | Args: 92 | seed (`int`): 93 | The seed in the parameters. 94 | is_cuda (`bool`): 95 | Whether to enable CUDA training. 96 | Returns: 97 | None 98 | """ 99 | random.seed(seed) 100 | torch.manual_seed(seed) 101 | np.random.seed(seed) 102 | os.environ["PYTHONHASHSEED"] = str(seed) 103 | # initialize the gpu device id 104 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 105 | if is_cuda: 106 | torch.cuda.manual_seed_all(is_cuda) 107 | torch.backends.cudnn.deterministic = True 108 | torch.backends.cudnn.benchmark = False 109 | 110 | 111 | def init_logger(log_dir,dataset): 112 | """ 113 | Remove the historical log file 114 | 115 | Args: 116 | log_dir (`arg_parse` ): 117 | The directory to save log files. 118 | Returns: 119 | None 120 | """ 121 | log_file = os.path.join(log_dir, dataset + '.log') 122 | if os.path.exists(log_file): 123 | os.remove(log_file) 124 | 125 | def ckks_init(data_dir): 126 | """ 127 | Initialize and write context in CKKS encryption mode. 128 | 129 | Args: 130 | data_dir (`str`): 131 | Directory for data to store. 132 | 133 | Returns: 134 | None. 135 | """ 136 | ckks_ctx = ts.context(ts.SCHEME_TYPE.CKKS, poly_modulus_degree=8192, coeff_mod_bit_sizes=[60, 40, 40, 60]) 137 | ckks_ctx.global_scale=2**40 138 | ckks_ctx.generate_galois_keys() 139 | params = ckks_ctx.serialize(save_secret_key=True) 140 | ckks_file = os.path.join(data_dir + 'context_params') 141 | with open(ckks_file, "wb") as f: 142 | f.write(params) 143 | 144 | def bfv_init(data_dir): 145 | """ 146 | Initialize and write context in BFV encryption mode. 147 | 148 | Args: 149 | data_dir (`str`): 150 | Directory for data to store. 151 | 152 | Returns: 153 | None. 154 | """ 155 | bfv_ctx = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=8192, plain_modulus=1032193) 156 | params = bfv_ctx.serialize(save_secret_key=True) 157 | ckks_file = os.path.join(data_dir + 'bfv_ctx') 158 | with open(ckks_file, "wb") as f: 159 | f.write(params) 160 | 161 | def paillier_init(): 162 | """ 163 | Initialize context in paillier encryption mode. 164 | 165 | Args: 166 | None. 167 | 168 | Returns: 169 | enc_tools (`dict`): 170 | cls_paillier (`context`): 171 | PaillierCipher context. 172 | mod (`int`): 173 | mod size. 174 | num_bits_per_batch (`int`): 175 | bit length for one package. 
176 | """ 177 | enc_tools = {} 178 | cls_paillier = PaillierCipher() 179 | cls_paillier.generate_key(n_length=2048) 180 | mod = pow(cls_paillier.get_n(),2) 181 | 182 | enc_tools['cls_paillier'] = cls_paillier 183 | enc_tools['mod'] = mod 184 | enc_tools['num_bits_per_batch'] = (cls_paillier.get_n() ** 2).bit_length() 185 | return enc_tools 186 | 187 | 188 | def IPC_init(n_clients): 189 | """ 190 | IPC communication between processes. 191 | 192 | Args: 193 | n_clients (`int` ): 194 | The num of clients to participate. 195 | Returns: 196 | `dict`: The locks, pipes, queues, flag, event in multiprocessing communication. 197 | lock_print, queue_lock (`Lock`): 198 | Process lock for print and logging. 199 | flag (`Value`): 200 | Record whether the iteration is terminated. 201 | e, e_server (`Event`): 202 | Synchronization with clients. 203 | pipes, send_pipes (`Pipe`): 204 | (Encrypted) gradients send to the server. 205 | queues, acc_queue, hash_queue, clients_queues (`Queue`): 206 | Server send (encrypted) aggregated gradients to clients, 207 | clients send local accuracy to the server, 208 | clients send hash value to the server, 209 | server send clients selected in the next epoch. 210 | """ 211 | lock_print = mp.Lock() 212 | queue_lock = mp.Lock() 213 | 214 | flag = mp.Value('b', False) 215 | 216 | e = mp.Event() 217 | e_server = mp.Event() 218 | 219 | pipes = [mp.Pipe() for _ in range(n_clients)] 220 | send_pipes = [mp.Pipe() for _ in range(n_clients)] 221 | 222 | queues = [mp.Queue(1) for _ in range(n_clients)] 223 | acc_queue = mp.Queue(n_clients) 224 | hash_queue = mp.Queue(n_clients) 225 | clients_queues = mp.Queue(n_clients) 226 | 227 | 228 | kwargs_IPC = {'lock':lock_print,'e':e,'client_pipes':pipes,'queues':queues,'flag':flag,'e_server':e_server, 229 | 'acc_queue':acc_queue,'hash_queue':hash_queue,'queue_lock':queue_lock,'clients_queues':clients_queues 230 | ,'send_pipes':send_pipes,} 231 | return kwargs_IPC 232 | 233 | def device_init(is_cuda,is_mps): 234 | """ 235 | Determine which device to train on. 236 | 237 | Args: 238 | is_cuda (`bool`): 239 | Whether to enable CUDA training. 240 | is_mps (`bool`): 241 | Whether to enable mps training. 242 | Returns: 243 | None 244 | """ 245 | use_cuda = is_cuda and torch.cuda.is_available() 246 | use_mps = is_mps and torch.backends.mps.is_available() 247 | if use_cuda: 248 | device = torch.device("cuda") 249 | elif use_mps: 250 | device = torch.device("mps") 251 | else: 252 | device = torch.device("cpu") 253 | return device 254 | 255 | def main(): 256 | """ 257 | Main function. 
258 | 259 | """ 260 | mp.set_start_method('spawn', force=True) 261 | args = arg_parse() 262 | 263 | seed_everything(args.seed, args.cuda) 264 | 265 | init_logger(args.log_dir,args.dataset) 266 | 267 | device = device_init(args.cuda,args.mps) 268 | 269 | kwargs_IPC = IPC_init(args.n_clients) 270 | 271 | if args.enc : 272 | if args.algorithm == 'ckks': 273 | ckks_init(args.data_dir) 274 | elif args.algorithm == 'paillier': 275 | enc_tools = paillier_init() 276 | kwargs_IPC.update({'enc_tools':enc_tools,}) 277 | elif args.algorithm == 'bfv': 278 | bfv_init(args.data_dir) 279 | else: 280 | raise ValueError("invalid algorithm!") 281 | 282 | logging("Basic information: device {}, learning rate {}, num clients {}, epochs {},noniid_method {},\ 283 | isEnc {},isBatch {},sparsification {}, client selection {},topk {}, enc_batch_size {}".format( 284 | device, args.lr,args.n_clients,args.epochs,args.noniid_method,args.enc,args.isBatch,args.isSpars,args.isSelection,args.topk,args.enc_batch_size),args) 285 | 286 | run(args,kwargs_IPC,device) 287 | 288 | 289 | if __name__ == '__main__': 290 | main() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /encryption/paillier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import time 4 | from encryption import gmpy_math 5 | 6 | from multiprocessing import cpu_count, Pool 7 | from encryption.encrypt import Encrypt 8 | 9 | from encryption.quantize import quantize, unquantize, batch_padding, unbatching_padding 10 | 11 | 12 | 13 | N_JOBS = cpu_count() 14 | 15 | 16 | def _static_encrypt(value, n): 17 | pubkey = PaillierPublicKey(n) 18 | return pubkey.encrypt(value) 19 | 20 | 21 | def _static_decrypt(value, n, p, q): 22 | pubkey = PaillierPublicKey(n) 23 | prikey = PaillierPrivateKey(pubkey, p, q) 24 | return prikey.decrypt(value) 25 | 26 | 27 | class PaillierPublicKey(object): 28 | def __init__(self, n): 29 | self.g = n + 1 30 | self.n = n 31 | self.nsquare = n * n 32 | self.max_int = n // 3 - 1 33 | 34 | def get_n(self): 35 | return self.n 36 | 37 | def __repr__(self): 38 | hashcode = hex(hash(self))[2:] 39 | return "".format(hashcode[:10]) 40 | 41 | def __eq__(self, other): 42 | return self.n == other.n 43 | 44 | def __hash__(self): 45 | return hash(self.n) 46 | 47 | def apply_obfuscator(self, ciphertext): 48 | r = random.SystemRandom().randrange(1, self.n) 49 | obfuscator = gmpy_math.powmod(r, self.n, self.nsquare) 50 | 51 | return (ciphertext * obfuscator) % self.nsquare 52 | 53 | def encrypt(self, plaintext): 54 | if plaintext >= (self.n - self.max_int) and plaintext < self.n: 55 | # Very large plaintext, take a sneaky shortcut using inverses 56 | neg_plaintext = self.n - plaintext # = abs(plaintext - nsquare) 57 | neg_ciphertext = (self.n * neg_plaintext + 1) % self.nsquare 58 | ciphertext = gmpy_math.invert(neg_ciphertext, self.nsquare) 59 | else: 60 | ciphertext = (self.n * plaintext + 1) % self.nsquare 61 | 62 | ciphertext = self.apply_obfuscator(ciphertext) 63 | return ciphertext 64 | 65 | 66 | class PaillierPrivateKey(object): 67 | def __init__(self, public_key, p, q): 68 | if not p * q == public_key.n: 69 | raise ValueError("given public key does not match the given p and q") 70 | if p == q: 71 | raise ValueError("p and q have to be different") 72 | self.public_key = public_key 73 | if q < p: 74 | self.p = q 75 | self.q = p 76 | else: 77 | self.p = p 78 | self.q = q 79 | self.psquare = self.p * self.p 80 | self.qsquare = self.q * self.q 81 | self.q_inverse = gmpy_math.invert(self.q, self.p) 82 | self.hp = self.h_func(self.p, self.psquare) 83 | self.hq = self.h_func(self.q, self.qsquare) 84 | 85 | def __eq__(self, other): 86 | return self.p == other.p and self.q == other.q 87 | 88 | def __hash__(self): 89 | return hash((self.p, self.q)) 90 | 91 | def __repr__(self): 92 | hashcode = hex(hash(self))[2:] 93 | 94 | return "".format(hashcode[:10]) 95 | 96 | def h_func(self, x, xsquare): 97 | return gmpy_math.invert(self.l_func(gmpy_math.powmod(self.public_key.g, 98 | x - 1, xsquare), x), x) 99 | 100 | def l_func(self, x, p): 101 | return (x - 1) // p 102 | 103 | def crt(self, mp, mq): 104 | u = (mp - mq) * self.q_inverse % self.p 105 | x = (mq + (u * self.q)) % self.public_key.n 106 | 107 | return x 108 | 109 | def decrypt(self, ciphertext): 110 | mp = self.l_func(gmpy_math.powmod(ciphertext, 111 | self.p-1, self.psquare), 112 | self.p) * self.hp % self.p 113 | mq = self.l_func(gmpy_math.powmod(ciphertext, 114 | self.q-1, self.qsquare), 115 | self.q) * self.hq % self.q 116 | 117 | plaintext = self.crt(mp, mq) 118 | 119 | 
return plaintext 120 | 121 | def get_p_q(self): 122 | return self.p, self.q 123 | 124 | 125 | class PaillierKeypair(object): 126 | def __init__(self): 127 | pass 128 | 129 | @staticmethod 130 | def generate_keypair(n_length=1024): 131 | p = q = n = None 132 | n_len = 0 133 | 134 | while n_len != n_length: 135 | p = gmpy_math.getprimeover(n_length // 2) 136 | q = p 137 | while q == p: 138 | q = gmpy_math.getprimeover(n_length // 2) 139 | n = p * q 140 | n_len = n.bit_length() 141 | 142 | public_key = PaillierPublicKey(n) 143 | private_key = PaillierPrivateKey(public_key, p, q) 144 | 145 | return public_key, private_key 146 | 147 | 148 | class PaillierCipher(Encrypt): 149 | def __init__(self): 150 | super(PaillierCipher, self).__init__() 151 | self.uuid = None 152 | self.exchanged_keys = None 153 | self.n = None 154 | self.key_length = None 155 | 156 | def set_n(self, n): # for all (arbiter is necessary, while host and guest is optional since they dont add) 157 | self.n = n 158 | 159 | def get_n(self): 160 | return self.n 161 | 162 | def set_self_uuid(self, uuid): 163 | self.uuid = uuid 164 | 165 | def set_exchanged_keys(self, keys): 166 | self.exchanged_keys = keys 167 | 168 | def generate_key(self, n_length=2048): 169 | self.key_length = n_length 170 | self.public_key, self.privacy_key = \ 171 | PaillierKeypair.generate_keypair(n_length=n_length) 172 | self.set_n(self.public_key.n) 173 | 174 | def get_key_pair(self): 175 | return self.public_key, self.privacy_key 176 | 177 | def set_public_key(self, public_key): 178 | self.public_key = public_key 179 | # for host 180 | self.set_n(public_key.n) 181 | 182 | def get_public_key(self): 183 | return self.public_key 184 | 185 | def set_privacy_key(self, privacy_key): 186 | self.privacy_key = privacy_key 187 | 188 | def get_privacy_key(self): 189 | return self.privacy_key 190 | 191 | def _dynamic_encrypt(self, value): 192 | return self.public_key.encrypt(value) 193 | 194 | def _multiprocessing_encrypt(self, value): 195 | shape = value.shape 196 | value_flatten = value.flatten() 197 | n = self.public_key.get_n() 198 | 199 | pool_inputs = [] 200 | for i in range(len(value_flatten)): 201 | pool_inputs.append([value_flatten[i], n]) 202 | 203 | pool = Pool(N_JOBS) 204 | ret = pool.starmap(_static_encrypt, pool_inputs) 205 | pool.close() 206 | pool.join() 207 | 208 | ret = np.array(ret) 209 | return ret.reshape(shape) 210 | 211 | def encrypt(self, value): 212 | if self.public_key is not None: 213 | if not isinstance(value, np.ndarray): 214 | return self._dynamic_encrypt(value) 215 | else: 216 | return self._multiprocessing_encrypt(value) 217 | else: 218 | return None 219 | 220 | def _dynamic_decrypt(self, value): 221 | return self.privacy_key.decrypt(value) 222 | 223 | def _multiprocessing_decrypt(self, value): 224 | shape = value.shape 225 | value_flatten = value.flatten() 226 | n = self.public_key.get_n() 227 | p, q = self.privacy_key.get_p_q() 228 | 229 | pool_inputs = [] 230 | for i in range(len(value_flatten)): 231 | pool_inputs.append( 232 | [value_flatten[i], n, p, q] 233 | ) 234 | 235 | pool = Pool(N_JOBS) 236 | ret = pool.starmap(_static_decrypt, pool_inputs) 237 | pool.close() 238 | pool.join() 239 | 240 | ret = np.array(ret) 241 | return ret.reshape(shape) 242 | 243 | def decrypt(self, value): 244 | if self.privacy_key is not None: 245 | if not isinstance(value, np.ndarray): 246 | return self._dynamic_decrypt(value) 247 | else: 248 | return self._multiprocessing_decrypt(value) 249 | else: 250 | return None 251 | 252 | 253 | def 
paillier_enc(plain_list, cls_paillier,args): 254 | 255 | quan_bits = args.quan_bits 256 | isBatch = args.isBatch 257 | isSpars = args.isSpars 258 | num_clients = args.n_clients 259 | batch_size = args.enc_batch_size 260 | padding_bits = int(np.ceil(np.log2(num_clients + 1))) 261 | elem_bits = padding_bits + quan_bits 262 | 263 | quan_begin = time.time() 264 | 265 | quan_list = quantize(plain_list,quan_bits,num_clients) 266 | plain_shape = len(plain_list) 267 | quan_end = time.time() 268 | #print(f"*******************Client {id} Encrypt*******************") 269 | #print(f"Quantized time:{quan_end-quan_begin:.3f}s",end=" | ") 270 | 271 | if isBatch: 272 | batch_num = int(np.ceil(len(plain_list) / batch_size)) 273 | # if len(plain_list) % batch_num != 0: 274 | # padding_num = batch_num * batch_size - len(plain_list) 275 | # plain_list.extend([0]*padding_num) 276 | batch_begin = time.time() 277 | quan_list = batch_padding(quan_list,cls_paillier.key_length,elem_bits,batch_size=batch_size) 278 | # unquan_list = unbatching_padding(quan_list,elem_bits,batch_size) 279 | batch_end = time.time() 280 | #print(f"Batching time:{batch_end-batch_begin:.3f}s",end=" | ") 281 | if isSpars == 'topk': 282 | # 取最大的 topk 个 batch 283 | topk = int(np.ceil(batch_num * args.topk)) 284 | sign = np.sign(np.array(plain_list)) 285 | tmp_list = (np.array(plain_list) * sign).tolist() 286 | plain_batchs = [tmp_list[i * batch_size : (i+1) * batch_size ]for i in range(batch_num)] 287 | avg_list = [np.average(batch) for batch in plain_batchs] 288 | max_avg_list = np.sort(avg_list)[::-1][:topk] 289 | tmp_list = [] 290 | mask_list = [0] * batch_num 291 | for i in range(len(max_avg_list)): 292 | tmp_list.append(avg_list.index(max_avg_list[i])) 293 | 294 | tmp_list = np.sort(tmp_list) 295 | for i in tmp_list: 296 | mask_list[i] = 1 297 | 298 | quan_list = [quan_list[tmp_list[i]] for i in range(len(tmp_list))] 299 | 300 | # paillier encrypt 301 | enc_begin = time.time() 302 | cipher_list = [] 303 | for longint in quan_list: 304 | tmp_cipher = cls_paillier.encrypt(longint) 305 | cipher_list.append(tmp_cipher) 306 | enc_end = time.time() 307 | #print(f"Encrypt time:{enc_end-enc_begin:.3f}s") 308 | 309 | if isSpars == 'topk': 310 | return cipher_list,mask_list 311 | 312 | return cipher_list 313 | 314 | def paillier_dec(cipher_list, cls_paillier,plain_shape,args): 315 | quan_bits = args.quan_bits 316 | num_clients = args.n_clients 317 | isBatch = args.isBatch 318 | padding_bits = int(np.ceil(np.log2(num_clients + 1))) 319 | elem_bits = quan_bits + padding_bits 320 | # paillier decrypt 321 | dec_begin = time.time() 322 | decrypted_m = [] 323 | batch_size = args.enc_batch_size 324 | for cipher in cipher_list: 325 | if cipher == 0: 326 | dec_plain = 0 327 | decrypted_m.append(dec_plain) 328 | else: 329 | dec_plain = cls_paillier.decrypt(cipher) 330 | decrypted_m.append(dec_plain) 331 | quantized_sum = decrypted_m 332 | dec_end = time.time() 333 | #print(f"*******************Client {id} Decrypt*******************") 334 | #print(f"Decrypt time:{dec_end-dec_begin:.3f}s",end=" | ") 335 | 336 | if isBatch: 337 | # unbatching 338 | unbatch_begin = time.time() 339 | unbatch_plaintext = unbatching_padding(decrypted_m, elem_bits,args.enc_batch_size)[:(int(np.prod(plain_shape)))] 340 | quantized_sum = unbatch_plaintext.reshape(plain_shape) 341 | unbatch_end = time.time() 342 | #print(f"Unbatching time:{unbatch_end - unbatch_begin:.3f}s",end=" | ") 343 | 344 | # unquantized 345 | unquan_begin = time.time() 346 | unquan_plaintext = 
unquantize(quantized_sum,quan_bits,num_clients) 347 | unquan_end = time.time() 348 | #print("Quan agg:",unquan_plaintext) 349 | #print(f"Unquantized time:{unquan_end-unquan_begin:.3f}s") 350 | return unquan_plaintext -------------------------------------------------------------------------------- /utils/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import warnings 4 | 5 | 6 | def split_indices(num_cumsum, rand_perm): 7 | client_indices_pairs = [(cid, idxs) for cid, idxs in 8 | enumerate(np.split(rand_perm, num_cumsum)[:-1])] 9 | client_dict = dict(client_indices_pairs) 10 | return client_dict 11 | 12 | 13 | def balance_split(num_clients, num_samples): 14 | """Assign same sample sample for each client. 15 | 16 | Args: 17 | num_clients (int): Number of clients for partition. 18 | num_samples (int): Total number of samples. 19 | 20 | Returns: 21 | numpy.ndarray: A numpy array consisting ``num_clients`` integer elements, each represents sample number of corresponding clients. 22 | 23 | """ 24 | num_samples_per_client = int(num_samples / num_clients) 25 | client_sample_nums = (np.ones(num_clients) * num_samples_per_client).astype( 26 | int) 27 | return client_sample_nums 28 | 29 | 30 | def lognormal_unbalance_split(num_clients, num_samples, unbalance_sgm): 31 | """Assign different sample number for each client using Log-Normal distribution. 32 | 33 | Sample numbers for clients are drawn from Log-Normal distribution. 34 | 35 | Args: 36 | num_clients (int): Number of clients for partition. 37 | num_samples (int): Total number of samples. 38 | unbalance_sgm (float): Log-normal variance. When equals to ``0``, the partition is equal to :func:`balance_partition`. 39 | 40 | Returns: 41 | numpy.ndarray: A numpy array consisting ``num_clients`` integer elements, each represents sample number of corresponding clients. 42 | 43 | """ 44 | num_samples_per_client = int(num_samples / num_clients) 45 | if unbalance_sgm != 0: 46 | client_sample_nums = np.random.lognormal(mean=np.log(num_samples_per_client), 47 | sigma=unbalance_sgm, 48 | size=num_clients) 49 | client_sample_nums = ( 50 | client_sample_nums / np.sum(client_sample_nums) * num_samples).astype(int) 51 | diff = np.sum(client_sample_nums) - num_samples # diff <= 0 52 | 53 | # Add/Subtract the excess number starting from first client 54 | if diff != 0: 55 | for cid in range(num_clients): 56 | if client_sample_nums[cid] > diff: 57 | client_sample_nums[cid] -= diff 58 | break 59 | else: 60 | client_sample_nums = (np.ones(num_clients) * num_samples_per_client).astype(int) 61 | 62 | return client_sample_nums 63 | 64 | 65 | def dirichlet_unbalance_split(num_clients, num_samples, alpha): 66 | """Assign different sample number for each client using Log-Normal distribution. 67 | 68 | Sample numbers for clients are drawn from Log-Normal distribution. 69 | 70 | Args: 71 | num_clients (int): Number of clients for partition. 72 | num_samples (int): Total number of samples. 73 | alpha (float): Dirichlet concentration parameter 74 | 75 | Returns: 76 | numpy.ndarray: A numpy array consisting ``num_clients`` integer elements, each represents sample number of corresponding clients. 
77 | 78 | """ 79 | min_size = 0 80 | while min_size < 10: 81 | proportions = np.random.dirichlet(np.repeat(alpha, num_clients)) 82 | proportions = proportions / proportions.sum() 83 | min_size = np.min(proportions * num_samples) 84 | 85 | client_sample_nums = (proportions * num_samples).astype(int) 86 | return client_sample_nums 87 | 88 | 89 | def homo_partition(client_sample_nums, num_samples): 90 | """Partition data indices in IID way given sample numbers for each clients. 91 | 92 | Args: 93 | client_sample_nums (numpy.ndarray): Sample numbers for each clients. 94 | num_samples (int): Number of samples. 95 | 96 | Returns: 97 | dict: ``{ client_id: indices}``. 98 | 99 | """ 100 | rand_perm = np.random.permutation(num_samples) 101 | num_cumsum = np.cumsum(client_sample_nums).astype(int) 102 | client_dict = split_indices(num_cumsum, rand_perm) 103 | return client_dict 104 | 105 | 106 | def hetero_dir_partition(targets, num_clients, num_classes, dir_alpha, min_require_size=None): 107 | """ 108 | 109 | Non-iid partition based on Dirichlet distribution. The method is from "hetero-dir" partition of 110 | `Bayesian Nonparametric Federated Learning of Neural Networks `_ 111 | and `Federated Learning with Matched Averaging `_. 112 | 113 | This method simulates heterogeneous partition for which number of data points and class 114 | proportions are unbalanced. Samples will be partitioned into :math:`J` clients by sampling 115 | :math:`p_k \sim \text{Dir}_{J}(\alpha)` and allocating a :math:`p_{p,j}` proportion of the 116 | samples of class :math:`k` to local client :math:`j`. 117 | 118 | Sample number for each client is decided in this function. 119 | 120 | Args: 121 | targets (list or numpy.ndarray): Sample targets. Unshuffled preferred. 122 | num_clients (int): Number of clients for partition. 123 | num_classes (int): Number of classes in samples. 124 | dir_alpha (float): Parameter alpha for Dirichlet distribution. 125 | min_require_size (int, optional): Minimum required sample number for each client. If set to ``None``, then equals to ``num_classes``. 126 | 127 | Returns: 128 | dict: ``{ client_id: indices}``. 129 | """ 130 | if min_require_size is None: 131 | min_require_size = num_classes 132 | 133 | if not isinstance(targets, np.ndarray): 134 | targets = np.array(targets) 135 | num_samples = targets.shape[0] 136 | 137 | min_size = 0 138 | while min_size < min_require_size: 139 | idx_batch = [[] for _ in range(num_clients)] 140 | # for each class in the dataset 141 | for k in range(num_classes): 142 | idx_k = np.where(targets == k)[0] 143 | np.random.shuffle(idx_k) 144 | proportions = np.random.dirichlet( 145 | np.repeat(dir_alpha, num_clients)) 146 | # Balance 147 | proportions = np.array( 148 | [p * (len(idx_j) < num_samples / num_clients) for p, idx_j in 149 | zip(proportions, idx_batch)]) 150 | proportions = proportions / proportions.sum() 151 | proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] 152 | idx_batch = [idx_j + idx.tolist() for idx_j, idx in 153 | zip(idx_batch, np.split(idx_k, proportions))] 154 | min_size = min([len(idx_j) for idx_j in idx_batch]) 155 | 156 | client_dict = dict() 157 | for cid in range(num_clients): 158 | np.random.shuffle(idx_batch[cid]) 159 | client_dict[cid] = np.array(idx_batch[cid]) 160 | 161 | return client_dict 162 | 163 | 164 | def shards_partition(targets, num_clients, num_shards): 165 | """Non-iid partition used in FedAvg `paper `_. 166 | 167 | Args: 168 | targets (list or numpy.ndarray): Sample targets. Unshuffled preferred. 
169 | num_clients (int): Number of clients for partition. 170 | num_shards (int): Number of shards in partition. 171 | 172 | Returns: 173 | dict: ``{ client_id: indices}``. 174 | 175 | """ 176 | if not isinstance(targets, np.ndarray): 177 | targets = np.array(targets) 178 | num_samples = targets.shape[0] 179 | 180 | size_shard = int(num_samples / num_shards) 181 | if num_samples % num_shards != 0: 182 | warnings.warn("warning: length of dataset isn't divided exactly by num_shards. " 183 | "Some samples will be dropped.") 184 | 185 | shards_per_client = int(num_shards / num_clients) 186 | if num_shards % num_clients != 0: 187 | warnings.warn("warning: num_shards isn't divided exactly by num_clients. " 188 | "Some shards will be dropped.") 189 | 190 | indices = np.arange(num_samples) 191 | # sort sample indices according to labels 192 | indices_targets = np.vstack((indices, targets)) 193 | indices_targets = indices_targets[:, indices_targets[1, :].argsort()] 194 | # corresponding labels after sorting are [0, .., 0, 1, ..., 1, ...] 195 | sorted_indices = indices_targets[0, :] 196 | 197 | # permute shards idx, and slice shards_per_client shards for each client 198 | rand_perm = np.random.permutation(num_shards) 199 | num_client_shards = np.ones(num_clients) * shards_per_client 200 | # sample index must be int 201 | num_cumsum = np.cumsum(num_client_shards).astype(int) 202 | # shard indices for each client 203 | client_shards_dict = split_indices(num_cumsum, rand_perm) 204 | 205 | # map shard idx to sample idx for each client 206 | client_dict = dict() 207 | for cid in range(num_clients): 208 | shards_set = client_shards_dict[cid] 209 | current_indices = [ 210 | sorted_indices[shard_id * size_shard: (shard_id + 1) * size_shard] 211 | for shard_id in shards_set] 212 | client_dict[cid] = np.concatenate(current_indices, axis=0) 213 | 214 | return client_dict 215 | 216 | 217 | def client_inner_dirichlet_partition(targets, num_clients, num_classes, dir_alpha, 218 | client_sample_nums, verbose=True): 219 | """Non-iid Dirichlet partition. 220 | 221 | The method is from The method is from paper `Federated Learning Based on Dynamic Regularization `_. 222 | This function can be used by given specific sample number for all clients ``client_sample_nums``. 223 | It's different from :func:`hetero_dir_partition`. 224 | 225 | Args: 226 | targets (list or numpy.ndarray): Sample targets. 227 | num_clients (int): Number of clients for partition. 228 | num_classes (int): Number of classes in samples. 229 | dir_alpha (float): Parameter alpha for Dirichlet distribution. 230 | client_sample_nums (numpy.ndarray): A numpy array consisting ``num_clients`` integer elements, each represents sample number of corresponding clients. 231 | verbose (bool, optional): Whether to print partition process. Default as ``True``. 232 | 233 | Returns: 234 | dict: ``{ client_id: indices}``. 
235 | 236 | """ 237 | if not isinstance(targets, np.ndarray): 238 | targets = np.array(targets) 239 | 240 | rand_perm = np.random.permutation(targets.shape[0]) 241 | targets = targets[rand_perm] 242 | 243 | class_priors = np.random.dirichlet(alpha=[dir_alpha] * num_classes, 244 | size=num_clients) 245 | prior_cumsum = np.cumsum(class_priors, axis=1) 246 | idx_list = [np.where(targets == i)[0] for i in range(num_classes)] 247 | class_amount = [len(idx_list[i]) for i in range(num_classes)] 248 | 249 | client_indices = [np.zeros(client_sample_nums[cid]).astype(np.int64) for cid in 250 | range(num_clients)] 251 | 252 | while np.sum(client_sample_nums) != 0: 253 | curr_cid = np.random.randint(num_clients) 254 | # If current node is full resample a client 255 | if verbose: 256 | print('Remaining Data: %d' % np.sum(client_sample_nums)) 257 | if client_sample_nums[curr_cid] <= 0: 258 | continue 259 | client_sample_nums[curr_cid] -= 1 260 | curr_prior = prior_cumsum[curr_cid] 261 | while True: 262 | curr_class = np.argmax(np.random.uniform() <= curr_prior) 263 | # Redraw class label if no rest in current class samples 264 | if class_amount[curr_class] <= 0: 265 | continue 266 | class_amount[curr_class] -= 1 267 | client_indices[curr_cid][client_sample_nums[curr_cid]] = \ 268 | idx_list[curr_class][class_amount[curr_class]] 269 | 270 | break 271 | 272 | client_dict = {cid: client_indices[cid] for cid in range(num_clients)} 273 | return client_dict 274 | 275 | 276 | def label_skew_quantity_based_partition(targets, num_clients, num_classes, major_classes_num): 277 | """ 278 | 279 | Args: 280 | targets (np.ndarray): Labels od dataset. 281 | num_clients (int): Number of clients. 282 | num_classes (int): Number of unique classes. 283 | major_classes_num (int): Number of classes for each client, should be less then ``num_classes``. 284 | 285 | Returns: 286 | dict: ``{ client_id: indices}``. 287 | 288 | """ 289 | idx_batch = [np.ndarray(0, dtype=np.int64) for _ in range(num_clients)] 290 | # only for major_classes_num < num_classes. 291 | # if major_classes_num = num_classes, it equals to IID partition 292 | times = [0 for _ in range(num_classes)] 293 | contain = [] 294 | for cid in range(num_clients): 295 | current = [cid % num_classes] 296 | times[cid % num_classes] += 1 297 | j = 1 298 | while j < major_classes_num: 299 | ind = np.random.randint(num_classes) 300 | if ind not in current: 301 | j += 1 302 | current.append(ind) 303 | times[ind] += 1 304 | contain.append(current) 305 | 306 | for k in range(num_classes): 307 | idx_k = np.where(targets == k)[0] 308 | np.random.shuffle(idx_k) 309 | split = np.array_split(idx_k, times[k]) 310 | ids = 0 311 | for cid in range(num_clients): 312 | if k in contain[cid]: 313 | idx_batch[cid] = np.append(idx_batch[cid], split[ids]) 314 | ids += 1 315 | 316 | client_dict = {cid: idx_batch[cid] for cid in range(num_clients)} 317 | return client_dict 318 | 319 | 320 | def fcube_synthetic_partition(data): 321 | """Feature-distribution-skew:synthetic partition. 322 | 323 | Synthetic partition for FCUBE dataset. This partition is from `Federated Learning on Non-IID Data Silos: An Experimental Study `_. 324 | 325 | Args: 326 | data (np.ndarray): Data of dataset :class:`FCUBE`. 327 | 328 | Returns: 329 | dict: ``{ client_id: indices}``. 
330 | """ 331 | num_clients = 4 332 | client_indices = [[] for _ in range(num_clients)] 333 | for idx, sample in enumerate(data): 334 | p1, p2, p3 = sample 335 | if (p1 > 0 and p2 > 0 and p3 > 0) or (p1 < 0 and p2 < 0 and p3 < 0): 336 | client_indices[0].append(idx) 337 | elif (p1 > 0 and p2 > 0 and p3 < 0) or (p1 < 0 and p2 < 0 and p3 > 0): 338 | client_indices[1].append(idx) 339 | elif (p1 > 0 and p2 < 0 and p3 > 0) or (p1 < 0 and p2 > 0 and p3 < 0): 340 | client_indices[2].append(idx) 341 | else: 342 | client_indices[3].append(idx) 343 | client_dict = {cid: np.array(client_indices[cid]).astype(int) for cid in range(num_clients)} 344 | return client_dict 345 | 346 | 347 | def samples_num_count(client_dict, num_clients): 348 | client_samples_nums = [[cid, client_dict[cid].shape[0]] for cid in 349 | range(num_clients)] 350 | client_sample_count = pd.DataFrame(data=client_samples_nums, 351 | columns=['client', 'num_samples']).set_index('client') 352 | return client_sample_count 353 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from torchvision.transforms import ToTensor 3 | from torch.utils.data import ConcatDataset 4 | import numpy as np 5 | from utils.subset import CustomSubset 6 | from utils.partition import CIFAR10Partitioner,MNISTPartitioner,FMNISTPartitioner,CIFAR100Partitioner 7 | import pickle 8 | import os 9 | 10 | '''' 11 | load and split dataset for clients 12 | ''' 13 | def load_dataset(args): 14 | # dataset load 15 | if args.dataset == 'MNIST': 16 | transform = transforms.Compose([ToTensor()]) 17 | train_set = datasets.MNIST(root="./data", download=True, transform=transform, train=True) 18 | test_set = datasets.MNIST(root="./data", download=True, transform=transform, train=False) 19 | elif args.dataset == 'FashionMNIST': 20 | transform = transforms.Compose([ToTensor()]) 21 | train_set = datasets.FashionMNIST(root="./data", download=True, transform=transform, train=True) 22 | test_set = datasets.FashionMNIST(root="./data", download=True, transform=transform, train=False) 23 | elif args.dataset == 'CIFAR10': 24 | transform = transforms.Compose([ToTensor()]) 25 | train_set = datasets.CIFAR10(root="./data", download=True, transform=transform, train=True) 26 | test_set = datasets.CIFAR10(root="./data", download=True, transform=transform, train=False) 27 | elif args.dataset == "CIFAR100": 28 | transform = transforms.Compose([ToTensor()]) 29 | train_set = datasets.CIFAR100(root="./data", download=True, transform=transform, train=True) 30 | test_set = datasets.CIFAR100(root="./data", download=True, transform=transform, train=False) 31 | else: 32 | raise ValueError("Please input the correct dataset name, it must be one of:" 33 | "MNIST, FashionMNST, CIFAR10, CIFAR100") 34 | 35 | # data info 36 | dataset_info = {} 37 | dataset_info["classes"] = train_set.classes 38 | dataset_info["num_classes"] = len(train_set.classes) 39 | dataset_info["input_size"] = train_set.data[0].shape[0] 40 | 41 | if len(train_set.data[0].shape) == 2: 42 | dataset_info["num_channels"] = 1 43 | else: 44 | dataset_info["num_channels"] = train_set.data[0].shape[-1] 45 | 46 | client_train_idx, client_test_idx = [], [] 47 | 48 | # labels for train/test dataset 49 | train_labels = np.array(train_set.targets) 50 | test_labels = np.array(test_set.targets) 51 | 52 | # data split method 53 | if args.split == 'noniid': 54 | if 
args.noniid_method == 'pathological': 55 | if args.dataset == "MNIST": 56 | train_part = MNISTPartitioner(train_labels, 57 | args.n_clients, 58 | partition="noniid-#label", 59 | major_classes_num=args.n_shards, 60 | seed=args.seed) 61 | test_part = MNISTPartitioner(test_labels, 62 | args.n_clients, 63 | partition="iid", 64 | seed=args.seed) 65 | elif args.dataset == "FashionMNIST": 66 | train_part = FMNISTPartitioner(train_labels, 67 | args.n_clients, 68 | partition="noniid-#label", 69 | major_classes_num=args.n_shards, 70 | seed=args.seed) 71 | test_part = FMNISTPartitioner(test_labels, 72 | args.n_clients, 73 | partition="iid", 74 | seed=args.seed) 75 | elif args.dataset == 'CIFAR10': 76 | train_part = CIFAR10Partitioner(train_labels, 77 | args.n_clients, 78 | balance=None, 79 | partition="shards", 80 | num_shards=args.n_shards, 81 | seed=args.seed) 82 | test_part = CIFAR10Partitioner(test_labels, 83 | args.n_clients, 84 | balance=True, 85 | partition="iid", 86 | seed=args.seed) 87 | elif args.dataset == 'CIFAR100': 88 | train_part = CIFAR100Partitioner(train_labels, 89 | args.n_clients, 90 | balance=None, 91 | partition="shards", 92 | num_shards=args.n_shards, 93 | seed=args.seed) 94 | test_part = CIFAR100Partitioner(test_labels, 95 | args.n_clients, 96 | balance=True, 97 | partition="iid", 98 | seed=args.seed) 99 | else: 100 | raise ValueError("Please input the correct dataset name, it must be one of:" 101 | "MNIST, FashionMNIST, CIFAR10, CIFAR100") 102 | 103 | elif args.noniid_method == 'dirichlet': 104 | if args.dataset == "MNIST": 105 | train_part = MNISTPartitioner(train_labels, 106 | args.n_clients, 107 | # partition="noniid-labeldir", 108 | partition="unbalance", 109 | dir_alpha=args.alpha, 110 | seed=args.seed) 111 | test_part = MNISTPartitioner(test_labels, 112 | args.n_clients, 113 | partition="iid", 114 | seed=args.seed) 115 | elif args.dataset == "FashionMNIST": 116 | train_part = FMNISTPartitioner(train_labels, 117 | args.n_clients, 118 | #partition="noniid-labeldir", 119 | partition="unbalance", 120 | dir_alpha=args.alpha, 121 | seed=args.seed) 122 | test_part = FMNISTPartitioner(test_labels, 123 | args.n_clients, 124 | partition="iid", 125 | seed=args.seed) 126 | elif args.dataset == 'CIFAR10': 127 | train_part = CIFAR10Partitioner(train_labels, 128 | args.n_clients, 129 | balance=False, 130 | partition="dirichlet", 131 | unbalance_sgm=args.sgm, 132 | dir_alpha=args.alpha, 133 | seed=args.seed) 134 | test_part = CIFAR10Partitioner(test_labels, 135 | args.n_clients, 136 | balance=True, 137 | partition="iid", 138 | unbalance_sgm=args.sgm, 139 | seed=args.seed) 140 | elif args.dataset == 'CIFAR100': 141 | train_part = CIFAR100Partitioner(train_labels, 142 | args.n_clients, 143 | balance=False, 144 | unbalance_sgm=args.sgm, 145 | partition="dirichlet", 146 | dir_alpha=args.alpha, 147 | seed=args.seed) 148 | test_part = CIFAR100Partitioner(test_labels, 149 | args.n_clients, 150 | balance=True, 151 | partition="iid", 152 | seed=args.seed) 153 | else: 154 | raise ValueError("Please input the correct dataset name, it must be one of:" 155 | "MNIST, FashionMNIST, CIFAR10") 156 | else: 157 | raise ValueError("Please input the correct noniid method, it must be one of:" 158 | "pathological, dirichlet") 159 | elif args.split == 'iid': 160 | if args.dataset == "MNIST": 161 | train_part = MNISTPartitioner(train_labels, 162 | args.n_clients, 163 | partition="iid", 164 | seed=args.seed) 165 | test_part = MNISTPartitioner(test_labels, 166 | args.n_clients, 167 | partition="iid", 168 | 
seed=args.seed) 169 | elif args.dataset == "FashionMNIST": 170 | train_part = FMNISTPartitioner(train_labels, 171 | args.n_clients, 172 | partition="iid", 173 | seed=args.seed) 174 | test_part = FMNISTPartitioner(test_labels, 175 | args.n_clients, 176 | partition="iid", 177 | seed=args.seed) 178 | elif args.dataset == 'CIFAR10': 179 | train_part = CIFAR10Partitioner(train_labels, 180 | args.n_clients, 181 | balance=True, 182 | partition="iid", 183 | seed=args.seed) 184 | test_part = CIFAR10Partitioner(test_labels, 185 | args.n_clients, 186 | balance=True, 187 | partition="iid", 188 | seed=args.seed) 189 | elif args.dataset == 'CIFAR100': 190 | train_part = CIFAR100Partitioner(train_labels, 191 | args.n_clients, 192 | balance=True, 193 | partition="iid", 194 | seed=args.seed) 195 | test_part = CIFAR100Partitioner(test_labels, 196 | args.n_clients, 197 | balance=True, 198 | partition="iid", 199 | seed=args.seed) 200 | else: 201 | raise ValueError("Please input the correct dataset name, it must be one of:" 202 | "MNIST, FashionMNIST, CIFAR10, CIFAR100") 203 | else: 204 | raise ValueError("Please input the correct split method, it must be one of:" 205 | "iid, noniid") 206 | 207 | # index to value 208 | for value in train_part.client_dict.values(): 209 | client_train_idx.append(value) 210 | for value in test_part.client_dict.values(): 211 | client_test_idx.append(value) 212 | 213 | # subset of the original train/test dataset for each client 214 | client_train_sets = [CustomSubset(train_set,idx) for idx in client_train_idx] 215 | client_test_sets = [CustomSubset(test_set, idx) for idx in client_test_idx] 216 | 217 | # save the load_data results 218 | train_file = os.path.join(args.data_dir, args.dataset + '_train') 219 | test_file = os.path.join(args.data_dir, args.dataset + '_test') 220 | with open(train_file, "wb") as f: 221 | train_bytes = pickle.dumps(client_train_idx) 222 | f.write(train_bytes) 223 | with open(test_file, "wb") as f: 224 | test_bytes = pickle.dumps(client_test_idx) 225 | f.write(test_bytes) 226 | 227 | # server dataset for fedavg 228 | server_test_idxs = [] 229 | for i in range(len(client_test_idx)): 230 | server_test_idxs += client_test_sets[i] 231 | server_test_sets = server_test_idxs 232 | return client_train_sets, client_test_sets, dataset_info, server_test_sets 233 | 234 | 235 | def load_exist(args): 236 | ''' 237 | load existing dataset for clients 238 | ''' 239 | if args.dataset == 'MNIST': 240 | transform = transforms.Compose([ToTensor()]) 241 | train_set = datasets.MNIST(root="./data", download=True, transform=transform, train=True) 242 | test_set = datasets.MNIST(root="./data", download=True, transform=transform, train=False) 243 | elif args.dataset == 'FashionMNIST': 244 | transform = transforms.Compose([ToTensor()]) 245 | train_set = datasets.FashionMNIST(root="./data", download=True, transform=transform, train=True) 246 | test_set = datasets.FashionMNIST(root="./data", download=True, transform=transform, train=False) 247 | elif args.dataset == 'CIFAR10': 248 | transform = transforms.Compose([ToTensor()]) 249 | train_set = datasets.CIFAR10(root="./data", download=True, transform=transform, train=True) 250 | test_set = datasets.CIFAR10(root="./data", download=True, transform=transform, train=False) 251 | elif args.dataset == "CIFAR100": 252 | transform = transforms.Compose([ToTensor()]) 253 | train_set = datasets.CIFAR100(root="./data", download=True, transform=transform, train=True) 254 | test_set = datasets.CIFAR100(root="./data", download=True, 
transform=transform, train=False) 255 | else: 256 | raise ValueError("Please input the correct dataset name, it must be one of:" 257 | "MNIST, FashionMNST, CIFAR10") 258 | 259 | # data info 260 | dataset_info = {} 261 | dataset_info["classes"] = train_set.classes 262 | dataset_info["num_classes"] = len(train_set.classes) 263 | dataset_info["input_size"] = train_set.data[0].shape[0] 264 | 265 | if len(train_set.data[0].shape) == 2: 266 | dataset_info["num_channels"] = 1 267 | else: 268 | dataset_info["num_channels"] = train_set.data[0].shape[-1] 269 | 270 | # read dataset 271 | train_file = os.path.join(args.data_dir, args.dataset + '_train') 272 | test_file = os.path.join(args.data_dir, args.dataset + '_test') 273 | with open(train_file, "rb") as f: 274 | train_bytes = f.read() 275 | client_train_idx = pickle.loads(train_bytes) 276 | with open(test_file, "rb") as f: 277 | test_bytes = f.read() 278 | client_test_idx = pickle.loads(test_bytes) 279 | 280 | 281 | # subset of the original train/test dataset for each client 282 | client_train_sets = [CustomSubset(train_set,idx) for idx in client_train_idx] 283 | client_test_sets = [CustomSubset(test_set, idx) for idx in client_test_idx] 284 | 285 | # server dataset for fedavg 286 | server_test_idxs = [] 287 | for i in range(len(client_test_idx)): 288 | server_test_idxs += client_test_sets[i] 289 | server_test_sets = server_test_idxs 290 | 291 | return client_train_sets, client_test_sets, dataset_info,server_test_sets -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import torch 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from collections import OrderedDict 9 | import time 10 | from utils.util import logging 11 | import random 12 | import tenseal as ts 13 | from threading import Thread 14 | from encryption.ckks import ckks_enc,ckks_dec 15 | from encryption.bfv import bfv_enc,bfv_dec 16 | from encryption.paillier import paillier_enc,paillier_dec 17 | 18 | import utils.min_hash as min_lsh 19 | 20 | import pickle 21 | from multiprocessing import shared_memory 22 | 23 | 24 | def params_tolist(model): 25 | """ 26 | Model parameters converted to list 27 | 28 | Args: 29 | model: 30 | Model to be converted. 31 | Returns: 32 | params_list 33 | The converted parameter list. 34 | params_num 35 | The amount of parameters for each layer. 36 | layer_shape 37 | Shape of each layer. 38 | """ 39 | model.to('cpu') 40 | local_state = model.state_dict() 41 | params_list = [] 42 | layer_shape = {} 43 | params_num = {} 44 | layer_params = [] 45 | for key in model.state_dict().keys(): 46 | layer_shape[key] = local_state[key].shape 47 | params_num[key] = int(np.prod(local_state[key].shape)) 48 | layer_params = local_state[key].reshape(params_num[key]).tolist() 49 | params_list.append(layer_params) 50 | params_list = [b for a in params_list for b in a] 51 | return params_list,params_num,layer_shape 52 | 53 | 54 | def params_tomodel(model,global_list,params_num,layer_shape,args,params_list): 55 | """ 56 | Parameter list to model 57 | 58 | Args: 59 | model: 60 | The model obtained after parameter conversion. 61 | global_list 62 | Global model parameter list. 63 | params_num 64 | The amount of parameters for each layer. 65 | layer_shape 66 | Shape of each layer. 67 | args 68 | Hyper-parameters. 
69 | params_list 70 | Local parameter list 71 | Returns: 72 | None 73 | """ 74 | update_state = OrderedDict() 75 | model.to('cpu') 76 | idx_cnt = 0 77 | if args.isSpars == 'topk' or args.isSpars == 'randk': 78 | for idx, key in enumerate(model.state_dict().keys()): 79 | layer_size = int(params_num[key]) 80 | tmp = global_list[idx_cnt : idx_cnt + layer_size] 81 | 82 | # The part with a value of 0 is replaced by local parameters. 83 | for idx_tmp in range(len(tmp)): 84 | if tmp[idx_tmp] == 0 and ( idx_tmp == len(tmp)- 1 or tmp[idx_tmp+1]==0 ): 85 | tmp[idx_tmp] = params_list[idx_cnt + idx_tmp] 86 | # global_list[idx_cnt+idx_tmp] = tmp[idx_tmp] 87 | update_state[ 88 | key] = torch.from_numpy(np.array(tmp).reshape(layer_shape[key])) 89 | idx_cnt += layer_size 90 | else: 91 | for idx, key in enumerate(model.state_dict().keys()): 92 | layer_size = int(params_num[key]) 93 | tmp = global_list[idx_cnt:idx_cnt + layer_size] 94 | update_state[ 95 | key] = torch.from_numpy(np.array(tmp).reshape(layer_shape[key])) 96 | idx_cnt += layer_size 97 | 98 | model.load_state_dict(update_state) 99 | 100 | 101 | def minHash(rank,random_R,global_list,params_list, args, quan_thres = 0.05): 102 | ''' 103 | quan_thres: Tthreshold value used for quantization 104 | sim_len: Number of hash functions 105 | ''' 106 | sim_len = args.sim_len 107 | 108 | mat = np.concatenate((np.array(global_list).reshape(-1,1),np.array(params_list).reshape(-1,1)),axis=1) 109 | 110 | quan_matrix = min_lsh.quan_params(mat,quan_thres) 111 | 112 | sim_mat = min_lsh.sigMatrixGen(quan_matrix,random_R, sim_len) 113 | 114 | # client_sim2 = min_lsh.dim_reduce_sim(sim_mat) 115 | 116 | minHash = (sim_mat[:,1]).tolist() 117 | return minHash 118 | 119 | # Simulate straggler 120 | def straggler(rank): 121 | timewait = np.random.randint(10,15) 122 | if rank == 1: 123 | time.sleep(timewait) 124 | if rank == 2: 125 | time.sleep(timewait) 126 | 127 | def client_process(rank, args, model, device,dataset, test_dataset, kwargs,kwargs_IPC,train_weights): 128 | torch.manual_seed(args.seed + rank) 129 | queue = kwargs_IPC['queues'][rank] 130 | e = kwargs_IPC['e'] 131 | lock = kwargs_IPC['lock'] 132 | pipe = kwargs_IPC['client_pipes'][rank][0] 133 | flag = kwargs_IPC['flag'] 134 | e_server = kwargs_IPC['e_server'] 135 | acc_queue = kwargs_IPC['acc_queue'] 136 | self_weight = train_weights[rank] 137 | acc_pipe = kwargs_IPC['send_pipes'][rank][0] 138 | 139 | if args.enc and args.algorithm == 'paillier': 140 | enc_tools = kwargs_IPC['enc_tools'] 141 | else: 142 | enc_tools = {} 143 | 144 | if args.isSelection: 145 | random_R = kwargs_IPC['random_R'] 146 | 147 | hash_queue = kwargs_IPC['hash_queue'] 148 | train_loader = torch.utils.data.DataLoader(dataset, **kwargs) 149 | test_loader = torch.utils.data.DataLoader(test_dataset, **kwargs) 150 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 151 | 152 | if rank == 0: 153 | spars_list = [] 154 | sum_masks = [] 155 | epoch = 0 156 | 157 | self_flag = True 158 | acc_list = [] 159 | 160 | while not flag.value: 161 | 162 | #epoch_begin = time.time() 163 | 164 | #train_begin = time.time() 165 | train_epoch(epoch, args, model, device, train_loader, optimizer,rank) 166 | #train_end = time.time() 167 | #logging("id:{},train time:{}".format(rank,train_end-train_begin),args) 168 | 169 | params_list,params_num,layer_shape = params_tolist(model) 170 | total_sum = sum(params_num.values()) 171 | if args.enc and args.algorithm == 'paillier': 172 | enc_tools.update({'total_params':total_sum}) 173 | 174 | # if selected 175 | if 
self_flag: 176 | # if epoch > 0 : 177 | # straggler(rank) 178 | if args.enc: 179 | if args.algorithm == 'paillier': 180 | params_list = (np.array(params_list) * self_weight).tolist() 181 | if args.algorithm == 'bfv': 182 | params_list = (np.array(params_list) * self_weight).tolist() 183 | if args.isSpars == 'topk' : 184 | #enc_begin = time.time() 185 | cipher, mask = enc_params(params_list,enc_tools, args) 186 | pipe.send([rank,mask,cipher]) 187 | #enc_end = time.time() 188 | #logging("id:{},enc time:{}".format(rank,enc_end-enc_begin),args) 189 | # lock.acquire() 190 | # logging("client {}, send mask {}.".format(rank,mask),args) 191 | # lock.release() 192 | elif args.isSpars == 'randk': 193 | cipher, randk_list = enc_params(params_list,enc_tools,args,epoch = epoch) 194 | if rank == 0: 195 | logging("epoch:{},rand_K:{}".format(epoch,randk_list),args) 196 | pipe.send([rank,cipher]) 197 | 198 | elif args.isSpars == 'full': 199 | #enc_begin = time.time() 200 | cipher = enc_params(params_list,enc_tools,args,epoch = epoch) 201 | pipe.send([rank,cipher]) 202 | #enc_end = time.time() 203 | #logging("id:{},enc time:{}".format(rank,enc_end-enc_begin),args) 204 | else: 205 | pipe.send([rank,params_list]) 206 | # lock.acquire() 207 | # logging("client {}, send params {}.".format(rank,params_list[0]),args) 208 | # lock.release() 209 | 210 | if flag.value: 211 | break 212 | 213 | # Waiting for server aggregation 214 | e.wait() 215 | 216 | global_list = queue.get() 217 | involved_frac = global_list[0] 218 | global_weights = global_list[1] 219 | 220 | if args.enc: 221 | if args.isSpars == 'topk': 222 | #dec_begin = time.time() 223 | sum_masks = involved_frac 224 | global_weights = (dec_params(global_weights,sum_masks,enc_tools, args)).tolist() 225 | #dec_end = time.time() 226 | #logging("id:{},dec time:{}".format(rank,dec_end-dec_begin),args) 227 | elif args.isSpars == 'randk': 228 | global_weights = (dec_params(global_weights,sum_masks,enc_tools, args, randk_list)).tolist() 229 | else: 230 | #dec_begin = time.time() 231 | global_weights = (dec_params(global_weights,sum_masks, enc_tools,args) / involved_frac).tolist() 232 | #dec_end = time.time() 233 | #logging("id:{},dec time:{}".format(rank,dec_end-dec_begin),args) 234 | global_weights = global_weights[:total_sum] 235 | else: 236 | global_weights = (np.array(global_weights) / involved_frac).tolist() 237 | 238 | # lock.acquire() 239 | # print('client{},receive{}'.format(rank,global_weights[0])) 240 | # lock.release() 241 | params_list,params_num,layer_shape = params_tolist(model) 242 | 243 | params_tomodel(model,global_weights,params_num,layer_shape,args,params_list) 244 | 245 | if args.enc: 246 | client_acc,client_loss = test_epoch(model, device, test_loader) 247 | acc_pipe.send([rank,client_acc,client_loss]) 248 | print('client{},acc:{},loss:{}'.format(rank,client_acc,client_loss)) 249 | 250 | if args.isSelection: 251 | client_hash = minHash(rank, random_R,global_weights,params_list,args) 252 | hash_queue.put([rank,client_hash]) 253 | 254 | if flag.value: 255 | break 256 | 257 | # Wait for server to make client selection 258 | e_server.wait() 259 | 260 | selected_file = os.path.join(args.data_dir, args.dataset + 'selected') 261 | with open(selected_file, "rb") as f: 262 | clients_bytes = f.read() 263 | clients_share = list(pickle.loads(clients_bytes))[0] 264 | clients_weights = list(pickle.loads(clients_bytes))[1] 265 | 266 | if rank not in clients_share: 267 | self_flag = False 268 | else: 269 | self_flag = True 270 | idx = clients_share.index(rank) 
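# Selection handshake (as implemented above): the server pickles
# [client_selected, new_weights] into args.data_dir + dataset + 'selected';
# a client whose rank is missing from client_selected sets self_flag = False
# and skips the next upload, while a selected client takes the aggregation
# weight stored at its index in clients_weights.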
271 | self_weight = clients_weights[idx] 272 | #epoch_end = time.time() 273 | 274 | epoch += 1 275 | 276 | lock.acquire() 277 | logging("client {} finished!".format(rank),args) 278 | lock.release() 279 | return 280 | 281 | 282 | def test(args, model, device, dataset, kwargs): 283 | torch.manual_seed(args.seed) 284 | 285 | test_loader = torch.utils.data.DataLoader(dataset, **kwargs) 286 | 287 | return test_epoch(model, device, test_loader) 288 | 289 | 290 | def train_epoch(epoch, args, model, device, data_loader, optimizer,rank): 291 | model.to(device) 292 | model.train() 293 | loss_fn = torch.nn.CrossEntropyLoss() 294 | for batch_idx, (data, target) in enumerate(data_loader): 295 | 296 | output = model(data.to(device)) 297 | target = target.to(device) 298 | # loss = F.nll_loss(output, target.to(device)) 299 | # loss = torch.nn.CrossEntropyLoss()(output, target.to(device)) 300 | loss = loss_fn(output, target) 301 | optimizer.zero_grad() 302 | loss.backward() 303 | optimizer.step() 304 | # if batch_idx == len(data_loader) - 1: 305 | # logging('client {}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 306 | # rank, epoch, batch_idx * len(data), len(data_loader.dataset), 307 | # 100. * batch_idx / len(data_loader), loss.item())) 308 | 309 | 310 | def test_epoch(model, device, data_loader): 311 | model.to(device) 312 | model.eval() 313 | correct = 0 314 | test_loss = 0 315 | with torch.no_grad(): 316 | for data, target in data_loader: 317 | output = model(data.to(device)) 318 | #test_loss += F.nll_loss(output, target.to(device), reduction='sum').item() # sum up batch loss 319 | #test_loss += torch.nn.CrossEntropyLoss()(output, target.to(device),reduction='sum') 320 | test_loss +=F.cross_entropy(output, target.to(device), reduction='sum') 321 | pred = output.max(1)[1] # get the index of the max log-probability 322 | correct += pred.eq(target.to(device)).sum().item() 323 | test_loss /= len(data_loader.dataset) 324 | #print("loss",test_loss) 325 | # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 326 | # test_loss, correct, len(data_loader.dataset), 327 | # 100. 
* correct / len(data_loader.dataset))) 328 | client_acc = correct / len(data_loader.dataset) 329 | 330 | # set the return variable format 331 | test_loss = round(float(test_loss),3) 332 | client_acc = round(client_acc,4)*100 333 | return client_acc, test_loss 334 | 335 | 336 | # Encrypt all parameters of a client 337 | def enc_params(params_list,enc_tools,args,epoch = 0): 338 | if args.algorithm == 'ckks': 339 | ckks_file = os.path.join(args.data_dir + 'context_params') 340 | with open(ckks_file, "rb") as f: 341 | params = f.read() 342 | ckks_ctx = ts.context_from(params) 343 | return ckks_enc(params_list,ckks_ctx,isBatch=args.isBatch,batch_size=args.enc_batch_size, 344 | topk=args.topk,round = epoch,randk_seed=args.randk_seed, is_spars = args.isSpars) 345 | elif args.algorithm =='paillier': 346 | cls_paillier = enc_tools['cls_paillier'] 347 | return paillier_enc(params_list,cls_paillier,args) 348 | elif args.algorithm == 'bfv': 349 | bfv_file = os.path.join(args.data_dir + 'bfv_ctx') 350 | with open(bfv_file, "rb") as f: 351 | params = f.read() 352 | bfv_ctx = ts.context_from(params) 353 | return bfv_enc(params_list,bfv_ctx,args) 354 | else: 355 | raise ValueError("please select valid algorithm") 356 | 357 | # Decrypt all parameters of a client 358 | def dec_params(cipher_list,sum_masks, enc_tools,args, randk_list = []): 359 | if args.algorithm == 'ckks': 360 | ckks_file = os.path.join(args.data_dir + 'context_params') 361 | with open(ckks_file, "rb") as f: 362 | params = f.read() 363 | ckks_ctx = ts.context_from(params) 364 | sk = ckks_ctx.secret_key() 365 | return ckks_dec(cipher_list,ckks_ctx,sk,args.isBatch,randk_list,sum_masks,args.enc_batch_size) 366 | elif args.algorithm =='paillier': 367 | cls_paillier = enc_tools['cls_paillier'] 368 | total_params = enc_tools['total_params'] 369 | return paillier_dec(cipher_list,cls_paillier,total_params,args) 370 | elif args.algorithm == 'bfv': 371 | bfv_file = os.path.join(args.data_dir + 'bfv_ctx') 372 | with open(bfv_file, "rb") as f: 373 | params = f.read() 374 | bfv_ctx = ts.context_from(params) 375 | sk = bfv_ctx.secret_key() 376 | return bfv_dec(cipher_list,bfv_ctx,sk,args.isBatch,args.quan_bits,args.n_clients,sum_masks,args.enc_batch_size) 377 | else: 378 | raise ValueError("please select valid algorithm") 379 | 380 | -------------------------------------------------------------------------------- /utils/partition.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | 5 | from . import functional as F 6 | 7 | 8 | class DataPartitioner(ABC): 9 | """Base class for data partition in federated learning. 10 | """ 11 | 12 | @abstractmethod 13 | def _perform_partition(self): 14 | raise NotImplementedError 15 | 16 | @abstractmethod 17 | def __getitem__(self, index): 18 | raise NotImplementedError 19 | 20 | @abstractmethod 21 | def __len__(self): 22 | raise NotImplementedError 23 | 24 | 25 | class CIFAR10Partitioner(DataPartitioner): 26 | """CIFAR10 data partitioner. 27 | 28 | Partition CIFAR10 given specific client number. Currently 6 supported partition schemes can be 29 | achieved by passing different combination of parameters in initialization: 30 | 31 | - ``balance=None`` 32 | 33 | - ``partition="dirichlet"``: non-iid partition used in 34 | `Bayesian Nonparametric Federated Learning of Neural Networks `_ 35 | and `Federated Learning with Matched Averaging `_. 
Refer 36 | to :func:`fedlab.utils.dataset.functional.hetero_dir_partition` for more information. 37 | 38 | - ``partition="shards"``: non-iid method used in FedAvg `paper `_. 39 | Refer to :func:`fedlab.utils.dataset.functional.shards_partition` for more information. 40 | 41 | 42 | - ``balance=True``: "Balance" refers to FL scenario that sample numbers for different clients 43 | are the same. Refer to :func:`fedlab.utils.dataset.functional.balance_partition` for more 44 | information. 45 | 46 | - ``partition="iid"``: Random select samples from complete dataset given sample number for 47 | each client. 48 | 49 | - ``partition="dirichlet"``: Refer to :func:`fedlab.utils.dataset.functional.client_inner_dirichlet_partition` 50 | for more information. 51 | 52 | - ``balance=False``: "Unbalance" refers to FL scenario that sample numbers for different clients 53 | are different. For unbalance method, sample number for each client is drown from Log-Normal 54 | distribution with variance ``unbalanced_sgm``. When ``unbalanced_sgm=0``, partition is 55 | balanced. Refer to :func:`fedlab.utils.dataset.functional.lognormal_unbalance_partition` 56 | for more information. The method is from paper `Federated Learning Based on Dynamic Regularization `_. 57 | 58 | - ``partition="iid"``: Random select samples from complete dataset given sample number for 59 | each client. 60 | 61 | - ``partition="dirichlet"``: Refer to :func:`fedlab.utils.dataset.functional.client_inner_dirichlet_partition` 62 | for more information. 63 | 64 | Args: 65 | targets (list or numpy.ndarray): Targets of dataset for partition. Each element is in range of [0, 1, ..., 9]. 66 | num_clients (int): Number of clients for data partition. 67 | balance (bool, optional): Balanced partition over all clients or not. Default as ``True``. 68 | partition (str, optional): Partition type, only ``"iid"``, ``shards``, ``"dirichlet"`` are supported. Default as ``"iid"``. 69 | unbalance_sgm (float, optional): Log-normal distribution variance for unbalanced data partition over clients. Default as ``0`` for balanced partition. 70 | num_shards (int, optional): Number of shards in non-iid ``"shards"`` partition. Only works if ``partition="shards"``. Default as ``None``. 71 | dir_alpha (float, optional): Dirichlet distribution parameter for non-iid partition. Only works if ``partition="dirichlet"``. Default as ``None``. 72 | verbose (bool, optional): Whether to print partition process. Default as ``True``. 73 | seed (int, optional): Random seed. Default as ``None``. 74 | """ 75 | 76 | num_classes = 10 77 | 78 | def __init__(self, targets, num_clients, 79 | balance=True, partition="iid", 80 | unbalance_sgm=0, 81 | num_shards=None, 82 | dir_alpha=None, 83 | verbose=True, 84 | seed=None): 85 | 86 | self.targets = np.array(targets) # with shape (num_samples,) 87 | self.num_samples = self.targets.shape[0] 88 | self.num_clients = num_clients 89 | self.client_dict = dict() 90 | self.partition = partition 91 | self.balance = balance 92 | self.dir_alpha = dir_alpha 93 | self.num_shards = num_shards 94 | self.unbalance_sgm = unbalance_sgm 95 | self.verbose = verbose 96 | # self.rng = np.random.default_rng(seed) # rng currently not supports randint 97 | np.random.seed(seed) 98 | 99 | # partition scheme check 100 | if balance is None: 101 | assert partition in ["dirichlet", "shards"], f"When balance=None, 'partition' only " \ 102 | f"accepts 'dirichlet' and 'shards'." 
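# Valid (balance, partition) combinations checked here, matching the class
# docstring: balance=None -> "dirichlet" or "shards"; balance=True/False ->
# "iid" or "dirichlet" (equal vs. log-normal sample counts per client).
# For illustration, data_split.py builds the pathological CIFAR10 split as
#   CIFAR10Partitioner(train_labels, n_clients, balance=None,
#                      partition="shards", num_shards=n_shards, seed=seed)
# with n_clients, n_shards and seed taken from the command-line arguments.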
103 | elif isinstance(balance, bool): 104 | assert partition in ["iid", "dirichlet"], f"When balance is bool, 'partition' only " \ 105 | f"accepts 'dirichlet' and 'iid'." 106 | else: 107 | raise ValueError(f"'balance' can only be NoneType or bool, not {type(balance)}.") 108 | 109 | # perform partition according to setting 110 | self.client_dict = self._perform_partition() 111 | # get sample number count for each client 112 | self.client_sample_count = F.samples_num_count(self.client_dict, self.num_clients) 113 | 114 | def _perform_partition(self): 115 | if self.balance is None: 116 | if self.partition == "dirichlet": 117 | client_dict = F.hetero_dir_partition(self.targets, 118 | self.num_clients, 119 | self.num_classes, 120 | self.dir_alpha, 121 | min_require_size=10) 122 | 123 | else: # partition is 'shards' 124 | client_dict = F.shards_partition(self.targets, self.num_clients, self.num_shards) 125 | 126 | else: # if balance is True or False 127 | # perform sample number balance/unbalance partition over all clients 128 | if self.balance is True: 129 | client_sample_nums = F.balance_split(self.num_clients, self.num_samples) 130 | else: 131 | client_sample_nums = F.lognormal_unbalance_split(self.num_clients, 132 | self.num_samples, 133 | self.unbalance_sgm) 134 | 135 | # perform iid/dirichlet partition for each client 136 | if self.partition == "iid": 137 | client_dict = F.homo_partition(client_sample_nums, self.num_samples) 138 | else: # for dirichlet 139 | client_dict = F.client_inner_dirichlet_partition(self.targets, self.num_clients, 140 | self.num_classes, self.dir_alpha, 141 | client_sample_nums, self.verbose) 142 | 143 | return client_dict 144 | 145 | def __getitem__(self, index): 146 | """Obtain sample indices for client ``index``. 147 | 148 | Args: 149 | index (int): Client ID. 150 | 151 | Returns: 152 | list: List of sample indices for client ID ``index``. 153 | 154 | """ 155 | return self.client_dict[index] 156 | 157 | def __len__(self): 158 | """Usually equals to number of clients.""" 159 | return len(self.client_dict) 160 | 161 | 162 | class CIFAR100Partitioner(CIFAR10Partitioner): 163 | """CIFAR100 data partitioner. 164 | 165 | This is a subclass of the :class:`CIFAR10Partitioner`. 166 | """ 167 | num_classes = 100 168 | 169 | 170 | class BasicPartitioner(DataPartitioner): 171 | """ 172 | - label-distribution-skew:quantity-based 173 | - label-distribution-skew:distributed-based (Dirichlet) 174 | - quantity-skew (Dirichlet) 175 | - IID 176 | 177 | Args: 178 | targets: 179 | num_clients: 180 | partition: 181 | dir_alpha: 182 | major_classes_num: 183 | verbose: 184 | seed: 185 | """ 186 | num_classes = 2 187 | 188 | def __init__(self, targets, num_clients, 189 | partition='iid', 190 | dir_alpha=None, 191 | major_classes_num=1, 192 | verbose=True, 193 | seed=None): 194 | self.targets = np.array(targets) # with shape (num_samples,) 195 | self.num_samples = self.targets.shape[0] 196 | self.num_clients = num_clients 197 | self.client_dict = dict() 198 | self.partition = partition 199 | self.dir_alpha = dir_alpha 200 | self.verbose = verbose 201 | # self.rng = np.random.default_rng(seed) # rng currently not supports randint 202 | np.random.seed(seed) 203 | 204 | if partition == "noniid-#label": 205 | # label-distribution-skew:quantity-based 206 | assert isinstance(major_classes_num, int), f"'major_classes_num' should be integer, " \ 207 | f"not {type(major_classes_num)}." 208 | assert major_classes_num > 0, f"'major_classes_num' should be positive." 
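# "noniid-#label" is the quantity-based label skew described in the class
# docstring: each client only receives samples drawn from major_classes_num
# of the num_classes labels (see label_skew_quantity_based_partition below).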
209 | assert major_classes_num < self.num_classes, f"'major_classes_num' for each client " \ 210 | f"should be less than number of total " \ 211 | f"classes {self.num_classes}." 212 | self.major_classes_num = major_classes_num 213 | elif partition in ["noniid-labeldir", "unbalance"]: 214 | # label-distribution-skew:distributed-based (Dirichlet) and quantity-skew (Dirichlet) 215 | assert dir_alpha > 0, f"Parameter 'dir_alpha' for Dirichlet distribution should be " \ 216 | f"positive." 217 | elif partition == "iid": 218 | # IID 219 | pass 220 | else: 221 | raise ValueError( 222 | f"tabular data partition only supports 'noniid-#label', 'noniid-labeldir', " 223 | f"'unbalance', 'iid'. {partition} is not supported.") 224 | 225 | self.client_dict = self._perform_partition() 226 | # get sample number count for each client 227 | self.client_sample_count = F.samples_num_count(self.client_dict, self.num_clients) 228 | 229 | def _perform_partition(self): 230 | if self.partition == "noniid-#label": 231 | # label-distribution-skew:quantity-based 232 | client_dict = F.label_skew_quantity_based_partition(self.targets, self.num_clients, 233 | self.num_classes, 234 | self.major_classes_num) 235 | 236 | elif self.partition == "noniid-labeldir": 237 | # label-distribution-skew:distributed-based (Dirichlet) 238 | client_dict = F.hetero_dir_partition(self.targets, self.num_clients, self.num_classes, 239 | self.dir_alpha, 240 | min_require_size=10) 241 | 242 | elif self.partition == "unbalance": 243 | # quantity-skew (Dirichlet) 244 | client_sample_nums = F.dirichlet_unbalance_split(self.num_clients, self.num_samples, 245 | self.dir_alpha) 246 | client_dict = F.homo_partition(client_sample_nums, self.num_samples) 247 | 248 | else: 249 | # IID 250 | client_sample_nums = F.balance_split(self.num_clients, self.num_samples) 251 | client_dict = F.homo_partition(client_sample_nums, self.num_samples) 252 | 253 | return client_dict 254 | 255 | def __getitem__(self, index): 256 | return self.client_dict[index] 257 | 258 | def __len__(self): 259 | return len(self.client_dict) 260 | 261 | 262 | class VisionPartitioner(BasicPartitioner): 263 | num_classes = 10 264 | 265 | def __init__(self, targets, num_clients, 266 | partition='iid', 267 | dir_alpha=None, 268 | major_classes_num=None, 269 | verbose=True, 270 | seed=None): 271 | super(VisionPartitioner, self).__init__(targets=targets, num_clients=num_clients, 272 | partition=partition, 273 | dir_alpha=dir_alpha, 274 | major_classes_num=major_classes_num, 275 | verbose=verbose, 276 | seed=seed) 277 | 278 | 279 | class MNISTPartitioner(VisionPartitioner): 280 | num_features = 784 281 | 282 | 283 | class FMNISTPartitioner(VisionPartitioner): 284 | num_features = 784 285 | 286 | 287 | class SVHNPartitioner(VisionPartitioner): 288 | num_features = 1024 289 | 290 | 291 | # class FEMNISTPartitioner(DataPartitioner): 292 | # def __init__(self): 293 | # """ 294 | # - feature-distribution-skew:real-world 295 | # - IID 296 | # """ 297 | # # num_classes = 298 | # pass 299 | # 300 | # def _perform_partition(self): 301 | # pass 302 | # 303 | # def __getitem__(self, index): 304 | # return self.client_dict[index] 305 | # 306 | # def __len__(self): 307 | # return len(self.client_dict) 308 | 309 | 310 | class FCUBEPartitioner(DataPartitioner): 311 | """FCUBE data partitioner. 312 | 313 | FCUBE is a synthetic dataset for research in non-IID scenario with feature imbalance. 
This 314 | dataset and its partition methods are proposed in `Federated Learning on Non-IID Data Silos: An 315 | Experimental Study `_. 316 | 317 | Supported partition methods for FCUBE: 318 | 319 | - feature-distribution-skew:synthetic 320 | 321 | - IID 322 | 323 | For more details, please refer to Section (IV-B-b) of original paper. 324 | 325 | Args: 326 | data (numpy.ndarray): Data of dataset :class:`FCUBE`. 327 | """ 328 | num_classes = 2 329 | num_clients = 4 # only accept partition for 4 clients 330 | 331 | def __init__(self, data, partition): 332 | if partition not in ['synthetic', 'iid']: 333 | raise ValueError( 334 | f"FCUBE only supports 'synthetic' and 'iid' partition, not {partition}.") 335 | self.partition = partition 336 | self.data = data 337 | if isinstance(data, np.ndarray): 338 | self.num_samples = data.shape[0] 339 | else: 340 | self.num_samples = len(data) 341 | 342 | self.client_dict = self._perform_partition() 343 | 344 | def _perform_partition(self): 345 | if self.partition == 'synthetic': 346 | # feature-distribution-skew:synthetic 347 | client_dict = F.fcube_synthetic_partition(self.data) 348 | else: 349 | # IID partition 350 | client_sample_nums = F.balance_split(self.num_clients, self.num_samples) 351 | client_dict = F.homo_partition(client_sample_nums, self.num_samples) 352 | 353 | return client_dict 354 | 355 | def __getitem__(self, index): 356 | return self.client_dict[index] 357 | 358 | def __len__(self): 359 | return self.num_clients 360 | 361 | 362 | class AdultPartitioner(BasicPartitioner): 363 | num_features = 123 364 | num_classes = 2 365 | 366 | 367 | class RCV1Partitioner(BasicPartitioner): 368 | num_features = 47236 369 | num_classes = 2 370 | 371 | 372 | class CovtypePartitioner(BasicPartitioner): 373 | num_features = 54 374 | num_classes = 2 375 | 376 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import torch 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from collections import OrderedDict 9 | import time 10 | import random 11 | from threading import Thread 12 | from utils.util import logging 13 | import tenseal as ts 14 | import pickle 15 | import utils.min_hash as lsh 16 | import time 17 | from functools import reduce 18 | from multiprocessing import Pool,cpu_count 19 | from utils import sampling 20 | from sklearn.cluster import KMeans 21 | from encryption.paillier import paillier_dec,paillier_enc 22 | from encryption.bfv import bfv_dec 23 | from client import params_tolist,params_tomodel 24 | from utils.util import model_init 25 | from client import test_epoch 26 | ####################################### 27 | def chunks_idx(l, n): 28 | d, r = divmod(len(l), n) 29 | for i in range(n): 30 | si = (d+1)*(i if i < r else r) + d*(0 if i < r else i - r) 31 | yield si, si+(d+1 if i < r else d) 32 | 33 | 34 | def _compress(flatten_array, num_bits): 35 | res = 0 36 | l = len(flatten_array) 37 | for element in flatten_array: 38 | res <<= num_bits 39 | res += element 40 | 41 | return res, l 42 | 43 | def compress_multi(flatten_array, num_bits): 44 | l = len(flatten_array) 45 | MAGIC_N_JOBS = 10 46 | pool_inputs = [] 47 | sizes = [] 48 | pool = Pool(MAGIC_N_JOBS) 49 | 50 | for begin, end in chunks_idx(range(l), MAGIC_N_JOBS): 51 | sizes.append(end - begin) 52 | 53 | pool_inputs.append([flatten_array[begin:end], num_bits]) 54 | 55 | pool_outputs = 
pool.starmap(_compress, pool_inputs) 56 | pool.close() 57 | pool.join() 58 | 59 | res = 0 60 | 61 | for idx, output in enumerate(pool_outputs): 62 | res += output[0] << (int(np.sum(sizes[idx + 1:])) * num_bits) 63 | 64 | num_bytes = (num_bits * l - 1) // 8 + 1 65 | res = res.to_bytes(num_bytes, 'big') 66 | return res, l 67 | 68 | def device_init(args): 69 | use_cuda = args.cuda and torch.cuda.is_available() 70 | use_mps = args.mps and torch.backends.mps.is_available() 71 | if use_cuda: 72 | device = torch.device("cuda") 73 | elif use_mps: 74 | device = torch.device("mps") 75 | else: 76 | device = torch.device("cpu") 77 | return device 78 | 79 | def recv_msg(idx,pipe,lock,recv_list,rec,participation_list): 80 | recv_list[idx] = 1 81 | msg = pipe.recv() 82 | participation_list[idx]+=1 83 | # lock.acquire() 84 | # print("Server receive: client{}".format(idx)) 85 | # lock.release() 86 | rec[str(idx)] = msg 87 | recv_list[idx] = 0 88 | return 89 | 90 | def recv_acc(idx,pipe,recv_list,rec): 91 | recv_list[idx] = 1 92 | msg = pipe.recv() 93 | rec[str(idx)] = msg 94 | recv_list[idx] = 0 95 | return 96 | 97 | def test_epoch1(model, device, data_loaders): 98 | model.to(device) 99 | model.eval() 100 | correct = 0 101 | total_data = 0 102 | test_loss = 0 103 | with torch.no_grad(): 104 | for data_loader in data_loaders: 105 | for data, target in data_loader: 106 | output = model(data.to(device)) 107 | #test_loss += F.nll_loss(output, target.to(device), reduction='sum').item() # sum up batch loss 108 | test_loss += torch.nn.CrossEntropyLoss()(output, target.to(device)) 109 | pred = output.max(1)[1] # get the index of the max log-probability 110 | correct += pred.eq(target.to(device)).sum().item() 111 | total_data += len(data_loader.dataset) 112 | test_loss /= total_data 113 | #print("loss",test_loss) 114 | # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 115 | # test_loss, correct, len(data_loader.dataset), 116 | # 100. 
* correct / len(data_loader.dataset))) 117 | client_acc = correct / total_data 118 | # print(test_loss.item()) 119 | # print(client_acc) 120 | test_loss = float(test_loss) 121 | return client_acc, test_loss 122 | 123 | def cipher_size(cipher): 124 | ciphertext_size = 0 125 | for batch_cipher in cipher: 126 | compressed_ciphertext_bytes = pickle.dumps(batch_cipher) 127 | ciphertext_size += len(compressed_ciphertext_bytes) 128 | return ciphertext_size 129 | 130 | def client_selection( mat, num_clients,train_weights,weights_clusters): 131 | client_idxs,rep_num = lsh.clusters_selection_L2(np.array(mat), num_clients, train_weights,weights_clusters) 132 | return client_idxs,rep_num 133 | 134 | 135 | def aggregatie_weights(rec,recv_list,weights_client,total_sum,batch_num,id_list,args,enc_tools = {},rep_num = []): 136 | weights = 0 137 | if args.enc: 138 | ciphertext_size = 0 139 | global_cipher = [0] * batch_num 140 | if args.algorithm == 'paillier': 141 | global_cipher1 = [0] * batch_num 142 | global_cipher2 = [0] * batch_num 143 | if args.isSpars == 'topk': 144 | sum_mask = [0] * batch_num 145 | else: 146 | agg_res = np.zeros(total_sum) 147 | add_count = 0 148 | for idx,value in enumerate(rec.values()): 149 | c_id = value[0] 150 | if args.enc: 151 | if recv_list[c_id] != 0: 152 | continue 153 | if args.algorithm == 'ckks': 154 | ckks_file = os.path.join(args.data_dir + 'context_params') 155 | with open(ckks_file, "rb") as f: 156 | params = f.read() 157 | ckks_ctx = ts.context_from(params) 158 | frac = ts.ckks_vector(ckks_ctx,[weights_client[c_id]]) 159 | if args.isSpars == 'topk': 160 | mask = value[1] 161 | cipher = value[2] 162 | #print("id:",c_id,"mask:",mask) 163 | if args.cipher_count: 164 | ciphertext_size += cipher_size(cipher) 165 | for batch in range(batch_num): 166 | res = 0 167 | 168 | if mask[batch]: 169 | cnt = 0 170 | for i in range(batch): 171 | if mask[i]: 172 | cnt += 1 173 | res = ts.CKKSVector.load(ckks_ctx,cipher[cnt]) * frac 174 | sum_mask[batch] += weights_client[c_id] 175 | 176 | if global_cipher[batch]: 177 | res += ts.CKKSVector.load(ckks_ctx, global_cipher[batch]) 178 | global_cipher[batch] = res.serialize() 179 | elif args.isSpars == 'randk' or args.isSpars == 'full': 180 | cipher = value[1] 181 | if args.cipher_count: 182 | ciphertext_size += cipher_size(cipher) 183 | for batch in range(len(cipher)): 184 | add_cipher_batch = ts.CKKSVector.load(ckks_ctx, cipher[batch]) * frac 185 | if global_cipher[batch]: 186 | global_cipher_batch = ts.CKKSVector.load(ckks_ctx, global_cipher[batch]) 187 | add_cipher_batch += global_cipher_batch 188 | global_cipher[batch] = add_cipher_batch.serialize() 189 | weights += weights_client[c_id] 190 | elif args.algorithm == 'paillier': 191 | mod = enc_tools['mod'] 192 | num_bits_per_batch = enc_tools['num_bits_per_batch'] 193 | if args.isSpars == 'topk': 194 | mask = value[1] 195 | cipher = value[2] 196 | if args.cipher_count: 197 | compressed_ciphertext= compress_multi(np.array(cipher).flatten().astype(object), num_bits_per_batch) 198 | ciphertext_size += cipher_size(compressed_ciphertext) 199 | for batch in range(batch_num): 200 | res = 0 201 | if mask[batch]: 202 | cnt = 0 203 | for i in range(batch): 204 | if mask[i]: 205 | cnt += 1 206 | res = cipher[cnt] 207 | sum_mask[batch] += weights_client[c_id] 208 | for i in range(rep_num[idx]-1): 209 | res = (res * cipher[cnt])%mod 210 | if global_cipher[batch]: 211 | global_cipher_batch = global_cipher[batch] 212 | global_cipher[batch] = (global_cipher_batch * res) % mod 213 | else: 214 | 
global_cipher[batch] = res 215 | else: 216 | cipher = value[1] 217 | add_count += 1 218 | # if args.algorithm == 'paillier' and add_count == 3: 219 | # global_cipher1 = global_cipher 220 | # global_cipher = global_cipher2 221 | if args.cipher_count: 222 | compressed_ciphertext= compress_multi(np.array(cipher).flatten().astype(object), num_bits_per_batch) 223 | ciphertext_size += cipher_size(compressed_ciphertext) 224 | 225 | # test code 226 | ''' 227 | cls_paillier = enc_tools['cls_paillier'] 228 | total_params = enc_tools['total_params'] 229 | global_weights = paillier_dec(cipher,cls_paillier,total_params,args) 230 | print("server: id :",c_id,"global_weights[0]:",global_weights[0]) 231 | ''' 232 | for batch in range(len(cipher)): 233 | add_cipher_batch = cipher[batch] 234 | for i in range(rep_num[idx]-1): 235 | add_cipher_batch = (add_cipher_batch * cipher[batch])%mod 236 | if global_cipher[batch]: 237 | add_cipher_batch = (add_cipher_batch * global_cipher[batch])%mod 238 | global_cipher[batch] = add_cipher_batch 239 | weights += weights_client[c_id] 240 | elif args.algorithm == 'bfv': 241 | bfv_file = os.path.join(args.data_dir + 'bfv_ctx') 242 | with open(bfv_file, "rb") as f: 243 | params = f.read() 244 | bfv_ctx = ts.context_from(params) 245 | sk = bfv_ctx.secret_key() 246 | if args.isSpars == 'topk': 247 | mask = value[1] 248 | cipher = value[2] 249 | # tmp_list = bfv_dec(cipher,bfv_ctx,sk,args.isBatch,args.quan_bits,args.n_clients,batch_size = args.enc_batch_size) 250 | # print("id:",c_id,tmp_list[0],mask) 251 | if args.cipher_count: 252 | ciphertext_size += cipher_size(cipher) 253 | for batch in range(batch_num): 254 | res = 0 255 | 256 | if mask[batch]: 257 | cnt = 0 258 | for i in range(batch): 259 | if mask[i]: 260 | cnt += 1 261 | res = ts.BFVVector.load(bfv_ctx,cipher[cnt]) 262 | sum_mask[batch] += weights_client[c_id] 263 | 264 | if global_cipher[batch]: 265 | res += ts.BFVVector.load(bfv_ctx, global_cipher[batch]) 266 | global_cipher[batch] = res.serialize() 267 | 268 | else: 269 | cipher = value[1] 270 | #sk = bfv_ctx.secret_key() 271 | # tmp_plain = bfv_dec(cipher,bfv_ctx,sk,args.isBatch,args.quan_bits,args.n_clients,batch_size = args.enc_batch_size) 272 | # print("server dec id:",c_id,tmp_plain[0]) 273 | if args.cipher_count: 274 | ciphertext_size += cipher_size(cipher) 275 | for batch in range(len(cipher)): 276 | add_cipher_batch = ts.BFVVector.load(bfv_ctx, cipher[batch]) 277 | if global_cipher[batch]: 278 | global_cipher_batch = ts.BFVVector.load(bfv_ctx, global_cipher[batch]) 279 | add_cipher_batch += global_cipher_batch 280 | global_cipher[batch] = add_cipher_batch.serialize() 281 | weights += weights_client[c_id] 282 | else: 283 | raise ValueError("invalid enc algorithm",args.algorithm) 284 | else: 285 | value = value[1] 286 | if recv_list[c_id] == 0: 287 | add_params = np.array(value)*weights_client[c_id] 288 | weights += weights_client[c_id] 289 | agg_res += add_params 290 | if args.enc: 291 | if args.isSpars == 'topk': 292 | if args.cipher_count: 293 | logging('server receive: ciphertext size:{} bytes'.format(ciphertext_size),args) 294 | return sum_mask, ciphertext_size, global_cipher 295 | else: 296 | return sum_mask,global_cipher 297 | else: 298 | if args.cipher_count: 299 | logging('server receive: ciphertext size:{} bytes'.format(ciphertext_size),args) 300 | return weights, ciphertext_size, global_cipher 301 | else: 302 | return weights,global_cipher 303 | else: 304 | agg_res = agg_res.tolist() 305 | return weights,agg_res 306 | 307 | def 
server_process(args,kwargs_IPC,total_sum,batch_num,train_weights,test_weights,server_test_sets,kwargs): 308 | n_clients = args.n_clients 309 | rec = {} 310 | acc_rec = {} 311 | n_epochs = args.epochs 312 | queues = kwargs_IPC['queues'] 313 | acc_queue = kwargs_IPC['acc_queue'] 314 | e = kwargs_IPC['e'] 315 | lock = kwargs_IPC['lock'] 316 | recv_list = [0 for i in range(n_clients)] 317 | recv_acc_list = [0 for i in range(n_clients)] 318 | pipe = kwargs_IPC['client_pipes'] 319 | pipes = kwargs_IPC['send_pipes'] 320 | send_pipes= [pipes[idx][1] for idx in range(n_clients)] 321 | server_pipes = [pipe[idx][1] for idx in range(n_clients)] 322 | flag = kwargs_IPC['flag'] 323 | e_server = kwargs_IPC['e_server'] 324 | if args.enc and args.algorithm == 'paillier': 325 | enc_tools = kwargs_IPC['enc_tools'] 326 | else: 327 | enc_tools = {} 328 | rep_num = [1] * args.n_clients 329 | select_flag=False 330 | hash_queue = kwargs_IPC['hash_queue'] 331 | 332 | participation_list = [0 for _ in range(n_clients)] 333 | accuracy_list = [] 334 | loss_list = [] 335 | total_ciphertext_size = 0 336 | cipher_size_list = [] 337 | id_list = [range(n_clients)] 338 | weights_client = [weight for weight in train_weights] 339 | time_list = [] 340 | tmp_len_clusters = [] 341 | # If it is plain text training, the server has a global model 342 | if args.enc == False: 343 | device = device_init(args) 344 | model = model_init(args.dataset,device) 345 | params_list,params_num,layer_shape = params_tolist(model) 346 | server_test_sets = torch.utils.data.DataLoader(server_test_sets, **kwargs) 347 | begin = time.time() 348 | 349 | for epoch in range(n_epochs): 350 | 351 | if epoch > 0 and epoch % 10 == 0: 352 | select_flag = True 353 | e.clear() 354 | 355 | threads = [] 356 | for idx in range(n_clients): 357 | # If the previous listening thread ends or there is no listening thread 358 | if recv_list[idx] == 0: 359 | client_pipe = server_pipes[idx] 360 | thread = Thread(target=recv_msg,args = (idx,client_pipe,lock,recv_list,rec,participation_list)) 361 | threads.append(thread) 362 | thread.start() 363 | for thread in threads: 364 | thread.join(timeout=3) 365 | 366 | if args.isSelection and epoch > 0: 367 | wait_bound = len(client_selected) 368 | else: 369 | wait_bound = n_clients 370 | wait_time = n_clients 371 | for i in range(args.n_clients): 372 | wait_time -= recv_list[i] 373 | while wait_time != wait_bound: 374 | time.sleep(1) 375 | wait_time = n_clients 376 | for i in range(args.n_clients): 377 | wait_time -= recv_list[i] 378 | 379 | 380 | # average weight 381 | if not args.weighted: 382 | weights_client = [1/n_clients for _ in range(n_clients)] 383 | train_weights = [1/n_clients for _ in range(n_clients)] 384 | 385 | # Encryption weight aggregation 386 | if args.enc: 387 | if args.cipher_count: 388 | weights, *agg_res= aggregatie_weights(rec,recv_list,weights_client, 389 | total_sum,batch_num,id_list,args,enc_tools,rep_num) 390 | total_ciphertext_size += agg_res[0] 391 | cipher_size_list.append(agg_res[0]) 392 | agg_res = agg_res[1] 393 | else: 394 | weights, agg_res= aggregatie_weights(rec,recv_list,weights_client, 395 | total_sum,batch_num,id_list,args,enc_tools,rep_num) 396 | else: 397 | 398 | weights, agg_res= aggregatie_weights(rec,recv_list,weights_client, 399 | total_sum,batch_num,id_list,args) 400 | global_weights = (np.array(agg_res) / weights).tolist() 401 | params_list,params_num,layer_shape = params_tolist(model) 402 | params_tomodel(model,global_weights,params_num,layer_shape,args,params_list) 403 | 
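# After aggregation, agg_res is either the list of summed ciphertext batches
# (args.enc is set, so only clients holding the key can decrypt and average)
# or the weighted plaintext sum, which the branch above has already divided
# by the total weight and loaded into the server-side model for evaluation.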
lock.acquire() 404 | logging('server agg: epoch {}.'.format(epoch),args) 405 | lock.release() 406 | 407 | # The aggregation is completed 408 | if epoch > 0 and args.isSelection: 409 | e_server.clear() 410 | 411 | # send to client 412 | for queue in queues: 413 | if queue.empty() == False: 414 | a = queue.get() 415 | queue.put([weights,agg_res]) 416 | 417 | # The aggregated content has been sent and can be read by the client 418 | e.set() 419 | 420 | if args.enc == True: 421 | acc_rec = {} 422 | threads = [] 423 | for idx in range(n_clients): 424 | if recv_acc_list[idx] == 0: 425 | client_pipe = send_pipes[idx] 426 | thread = Thread(target=recv_acc,args = (idx,client_pipe,recv_acc_list,acc_rec)) 427 | threads.append(thread) 428 | thread.start() 429 | for thread in threads: 430 | thread.join(timeout = 3) 431 | 432 | # wait for client accuracy 433 | time.sleep(1) 434 | acc_epoch_list = [] 435 | acc_weights = 0 436 | epoch_acc = 0 437 | epoch_loss = 0 438 | loss_epoch_list = [] 439 | for idx,value in enumerate(acc_rec.values()): 440 | id_acc = value 441 | c_id = id_acc[0] 442 | acc = id_acc[1] 443 | loss = id_acc[2] 444 | 445 | # lock.acquire() 446 | # logging('client:{}, accuracy:{}%.'.format(c_id,acc),args) 447 | # lock.release() 448 | 449 | acc_weights += test_weights[c_id] 450 | loss_epoch_list.append(loss*test_weights[c_id]) 451 | acc_epoch_list.append(acc*test_weights[c_id]) 452 | 453 | # current epoch accuraacy 454 | epoch_acc = round(np.sum(np.array(acc_epoch_list)) / acc_weights,2) 455 | epoch_loss = round(np.sum(np.array(loss_epoch_list)) / acc_weights,2) 456 | 457 | # save each epoch accuracy 458 | accuracy_list.append(epoch_acc) 459 | loss_list.append(epoch_loss) 460 | lock.acquire() 461 | logging("***********Server epoch {}, Clients accuracy:{}, loss:{}%***********\n".format( 462 | epoch,epoch_acc,epoch_loss),args) 463 | lock.release() 464 | else: 465 | server_acc,server_loss = test_epoch(model, device, server_test_sets) 466 | accuracy_list.append(server_acc) 467 | loss_list.append(server_loss) 468 | lock.acquire() 469 | logging("***********Server epoch {}, Clients accuracy:{}%***********\n".format(epoch,server_acc),args) 470 | lock.release() 471 | end = time.time() 472 | time_cost = round(end-begin,2) 473 | print("time:{}s".format(time_cost)) 474 | time_list.append(time_cost) 475 | 476 | if args.isSelection: 477 | time.sleep(1) 478 | # wait for client sketch 479 | weights_clusters = [weight for weight in train_weights] 480 | weights_client = [weight for weight in train_weights] 481 | hash_list = [] 482 | id_list = [] 483 | while not hash_queue.empty(): 484 | id_hash = hash_queue.get() 485 | id_list.append(id_hash[0]) 486 | hash_list.append(id_hash[1]) 487 | 488 | hash_list = sorted(hash_list, key=lambda x: id_list[hash_list.index(x)]) 489 | id_list = sorted(id_list) 490 | client_selected,rep_num = client_selection(np.array(hash_list),len(id_list),weights_client,weights_clusters) 491 | if args.isSelection: 492 | logging("Num:{} ,Next round Selected clients:{}".format(len(client_selected), client_selected),args) 493 | tmp_len_clusters.append(len(client_selected)) 494 | weights_client = weights_clusters 495 | new_weights = [] 496 | for i in client_selected: 497 | new_weights.append(weights_client[i]) 498 | selected_file = os.path.join(args.data_dir, args.dataset + 'selected') 499 | with open(selected_file, "wb") as f: 500 | clients_bytes = pickle.dumps([client_selected,new_weights]) 501 | f.write(clients_bytes) 502 | # Set the flag bit to indicate that the client selection is 
completed 503 | e_server.set() 504 | 505 | logging('server end!',args) 506 | 507 | flag.value = True 508 | e.clear() 509 | e_server.clear() 510 | if args.enc and args.cipher_count: 511 | logging("Total ciphertext size: {} bytes, size list: {}.".format(total_ciphertext_size,cipher_size_list),args) 512 | logging("Accuracy list (%): {}.".format(accuracy_list), args) 513 | logging("Loss list: {}.".format(loss_list),args) 514 | logging("Time list (s): {}.".format(time_list),args) 515 | logging("Participation list: {}.".format(participation_list), args) 516 | logging("Selected clients per selection round: {}.".format(tmp_len_clusters),args) 517 | 518 | return 519 | 520 | --------------------------------------------------------------------------------
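The CKKS branches in client.py and server.py above revolve around one pattern: flatten the model, encrypt chunks as TenSEAL CKKS vectors, let the server add weight-scaled ciphertexts, and have clients decrypt and divide by the participating weight. The sketch below condenses that round trip using only the pinned tenseal package; the context parameters, the tiny update vectors, and the plaintext weights are illustrative, not the repository's settings (the code instead loads a serialized context from args.data_dir + 'context_params' and multiplies by an encrypted one-element weight vector).

import tenseal as ts

# Illustrative CKKS context; FedPHE loads its own serialized context instead.
ctx = ts.context(ts.SCHEME_TYPE.CKKS, poly_modulus_degree=8192,
                 coeff_mod_bit_sizes=[60, 40, 40, 60])
ctx.global_scale = 2 ** 40

# Two clients' flattened (toy) updates and their aggregation weights.
client_updates = [[0.1, -0.2, 0.3], [0.5, 0.0, -0.1]]
weights = [0.5, 0.5]

# Server side: homomorphically accumulate weight-scaled ciphertexts.
agg = None
for update, w in zip(client_updates, weights):
    cipher = ts.ckks_vector(ctx, update) * w
    agg = cipher if agg is None else agg + cipher

# Client side: decrypt and divide by the total participating weight.
total_weight = sum(weights)
global_update = [v / total_weight for v in agg.decrypt()]
print(global_update)   # approximately [0.3, -0.1, 0.1]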