├── data
│   ├── __init__.py
│   ├── mnist_flat
│   │   ├── __init__.py
│   │   └── mnist_flat_generator.py
│   ├── relational_table_preprocessor.py
│   └── noniid_partition.py
├── imgs
│   └── split_nn.PNG
├── models.py
├── readme.md
├── split_nn.py
└── data_entities.py

/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/mnist_flat/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/imgs/split_nn.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpotter/SplitLearning/HEAD/imgs/split_nn.PNG
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch

class model1(nn.Module):
    """Client-side front segment: conv + pool on the raw 1x28x28 image."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        # x = x.view(-1,1,28,28)
        x = self.pool(F.relu(self.conv1(x)))  # -> (batch, 32, 13, 13)
        return x

class model2(nn.Module):
    """Server-side (Bob) middle segment operating on the smashed activations."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32 * 13 * 13, 1000)
        self.fc2 = nn.Linear(1000, 100)

    def forward(self, x):
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        return x

class model3(nn.Module):
    """Client-side head producing the 10-class logits."""
    def __init__(self):
        super().__init__()
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.fc3(x)
        return x
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# Split Learning on Heterogeneous Distributed MNIST

## Bob (coordinator)
Bob performs two main functions:
1. *Train Request*
    1. Request Alice_x to update its model weights to the weights of the last-trained Alice_x'.
    2. Perform the forward/backward flow shown in the figure below for N*Batches iterations.
    3. Request the next Alice_x'' to begin training, in round-robin fashion.
2. *Evaluation Request*
    1. Request each Alice_x to evaluate on its local test set.
    2. Aggregate the results of each Alice_x and log the overall performance.

Details from https://dspace.mit.edu/bitstream/handle/1721.1/121966/1810.06060.pdf?sequence=2&isAllowed=y .

![Alt text](imgs/split_nn.PNG?raw=true "Decentralized Split Learning Architecture")
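The network in `models.py` is split into three segments: `model1` (Alice's front), `model2` (Bob's middle), and `model3` (Alice's head, instantiated as `self.model2` in `data_entities.py`). As a rough illustration of the flow in the figure, here is a minimal single-machine sketch of one forward/backward step — no RPC; the actual implementation runs the `model2` pass on Bob and uses `torch.distributed.autograd` with a `DistributedOptimizer`:

```python
import torch
import torch.nn as nn
from models import model1, model2, model3

alice_front, bob_middle, alice_head = model1(), model2(), model3()
criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 1, 28, 28)        # a batch of MNIST-shaped images
y = torch.randint(0, 10, (16,))

smashed = alice_front(x)              # Alice -> (16, 32, 13, 13), sent to Bob
hidden = bob_middle(smashed)          # Bob   -> (16, 100), sent back to Alice
logits = alice_head(hidden)           # Alice -> (16, 10)

loss = criterion(logits, y)
loss.backward()                       # gradients flow back Alice -> Bob -> Alice
```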

Example run:
```python split_nn.py --epochs=2 --iterations=2 --world_size=5```

```
Split Learning Initialization

optional arguments:
  -h, --help            show this help message and exit
  --world_size WORLD_SIZE
                        The world size, equal to 1 server + (world_size - 1) clients
  --epochs EPOCHS       The number of local epochs each client runs per iteration
  --iterations ITERATIONS
                        The number of communication rounds between the clients and the server
  --batch_size BATCH_SIZE
                        The batch size used during client training
  --partition_alpha PARTITION_ALPHA
                        Dirichlet concentration parameter controlling how uniform the client label distributions are (heterogeneous data generation via LDA)
  --datapath DATAPATH   Folder path to all the local datasets
  --lr LR               Learning rate of the local client optimizer (SGD)
```
--------------------------------------------------------------------------------
/split_nn.py:
--------------------------------------------------------------------------------
from data_entities import alice, bob
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
import os
import argparse
from data.mnist_flat.mnist_flat_generator import load_mnist_image

def init_env():
    print("Initialize Meetup Spot")
    os.environ['MASTER_ADDR'] = "localhost"
    os.environ["MASTER_PORT"] = "5689"

def example(rank, world_size, args):
    init_env()
    if rank == 0:
        # rank 0 hosts Bob (the coordinator/server); all other ranks host Alices (clients)
        rpc.init_rpc("bob", rank=rank, world_size=world_size)

        BOB = bob(args)

        for iteration in range(args.iterations):
            for client_id in range(1, world_size):
                print(f"Training client {client_id}")
                BOB.train_request(client_id)
            BOB.eval_request()

        rpc.shutdown()
    else:
        rpc.init_rpc(f"alice{rank}", rank=rank, world_size=world_size)
        rpc.shutdown()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Split Learning Initialization')
    parser.add_argument('--world_size', type=int, default=3,
                        help='The world size, equal to 1 server + (world_size - 1) clients')
    parser.add_argument('--epochs', type=int, default=1,
                        help='The number of local epochs each client runs per iteration')
    parser.add_argument('--iterations', type=int, default=5,
                        help='The number of communication rounds between the clients and the server')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='The batch size used during client training')
    parser.add_argument('--partition_alpha', type=float, default=0.5,
                        help='Dirichlet concentration parameter controlling how uniform the client label distributions are (heterogeneous data generation via LDA)')
    parser.add_argument('--datapath', type=str, default="data/mnist_flat",
                        help='Folder path to all the local datasets')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate of the local client optimizer (SGD)')

    args = parser.parse_args()

    args.client_num_in_total = args.world_size - 1

    load_mnist_image(args)

    world_size = args.world_size
    mp.spawn(example,
             args=(world_size, args),
             nprocs=world_size,
             join=True)
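# A hypothetical standalone way to (re)generate the per-worker datasets without
# spawning the RPC processes -- this simply mirrors the load_mnist_image(args) call
# above and is not part of the repository:
#
#   from argparse import Namespace
#   from data.mnist_flat.mnist_flat_generator import load_mnist_image
#
#   args = Namespace(datapath="data/mnist_flat", client_num_in_total=2,
#                    partition_alpha=0.5, batch_size=16)
#   load_mnist_image(args)   # writes data_worker{k}_train.pt / data_worker{k}_test.pt per client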
--------------------------------------------------------------------------------
/data/mnist_flat/mnist_flat_generator.py:
--------------------------------------------------------------------------------
from data.relational_table_preprocessor import relational_table_preprocess_dl, image_preprocess_dl
import pandas as pd
import numpy as np
import torch
from collections import Counter
import os
from sklearn.datasets import fetch_openml

def load_mnist_flat(args):
    # as_frame=False returns numpy arrays (newer scikit-learn defaults to a DataFrame)
    dataset = fetch_openml("mnist_784", as_frame=False)

    data = dataset.data
    label_list = dataset.target.astype(int).tolist()

    args.class_num = len(np.unique(label_list))

    [_, _, _, _, _, train_data_local_dict, test_data_local_dict, args.class_num] = relational_table_preprocess_dl(args,
                                                                                                                  data,
                                                                                                                  label_list,
                                                                                                                  test_partition=0.2)
    for key in train_data_local_dict.keys():
        torch.save(train_data_local_dict[key], os.path.join(args.datapath, f"data_worker{key+1}_train.pt"))
        print(dict(sorted(dict(Counter(train_data_local_dict[key].dataset[:][1].numpy().tolist())).items())))
        torch.save(test_data_local_dict[key], os.path.join(args.datapath, f"data_worker{key+1}_test.pt"))

def load_mnist_image(args):
    # as_frame=False returns numpy arrays (newer scikit-learn defaults to a DataFrame)
    dataset = fetch_openml("mnist_784", as_frame=False)

    data = dataset.data.reshape(-1, 1, 28, 28)
    label_list = dataset.target.astype(int).tolist()

    args.class_num = len(np.unique(label_list))

    [_, _, _, _, _, train_data_local_dict, test_data_local_dict, args.class_num] = image_preprocess_dl(args,
                                                                                                       data,
                                                                                                       label_list,
                                                                                                       test_partition=0.2)
    for key in train_data_local_dict.keys():
        torch.save(train_data_local_dict[key], os.path.join(args.datapath, f"data_worker{key+1}_train.pt"))
        print(dict(sorted(dict(Counter(train_data_local_dict[key].dataset[:][1].numpy().tolist())).items())))
        torch.save(test_data_local_dict[key], os.path.join(args.datapath, f"data_worker{key+1}_test.pt"))

if __name__ == '__main__':
    print("Nothing")
--------------------------------------------------------------------------------
/data/relational_table_preprocessor.py:
--------------------------------------------------------------------------------
from data.noniid_partition import non_iid_partition_with_dirichlet_distribution
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
import random

def relational_table_preprocess_dl(args, data, label_list, test_partition=0.2):

    label_list = torch.LongTensor(label_list)
    data = torch.FloatTensor(data)

    net_dataidx_map = non_iid_partition_with_dirichlet_distribution(label_list=label_list,
                                                                    client_num=args.client_num_in_total,
                                                                    classes=args.class_num,
                                                                    alpha=args.partition_alpha,
                                                                    task='classification')
    # mirror the data configuration used in FedML

    train_data_local_dict = {}
    test_data_local_dict = {}
    train_data_local_num_dict = {}
    train_data_global = []
    test_data_global = []
    train_data_num = 0
    test_data_num = 0

    for key, client_data in net_dataidx_map.items():
        N_client = len(client_data)
        N_train = int(N_client * (1 - test_partition))

        train_data_local_num_dict[key] = N_train
        random.shuffle(client_data)

        train, train_label = data[client_data[:N_train], :], label_list[client_data[:N_train]]
        train_dataset = TensorDataset(train, train_label)
        train_data_local_dict[key] = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

        test, test_label = data[client_data[N_train:], :], label_list[client_data[N_train:]]
        test_dataset = TensorDataset(test, test_label)
        test_data_local_dict[key] = DataLoader(test_dataset, batch_size=args.batch_size)

        train_data_global.append(train_dataset)
        test_data_global.append(test_dataset)

        train_data_num += N_train
        test_data_num += (N_client - N_train)

    train_data_global = DataLoader(ConcatDataset(train_data_global), batch_size=args.batch_size)
    test_data_global = DataLoader(ConcatDataset(test_data_global), batch_size=args.batch_size)

    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, args.class_num]

def image_preprocess_dl(args, data, label_list, test_partition=0.2):

    label_list = torch.LongTensor(label_list)
    data = torch.FloatTensor(data)

    net_dataidx_map = non_iid_partition_with_dirichlet_distribution(label_list=label_list,
                                                                    client_num=args.client_num_in_total,
                                                                    classes=args.class_num,
                                                                    alpha=args.partition_alpha,
                                                                    task='classification')
    # mirror the data configuration used in FedML

    train_data_local_dict = {}
    test_data_local_dict = {}
    train_data_local_num_dict = {}
    train_data_global = []
    test_data_global = []
    train_data_num = 0
    test_data_num = 0

    for key, client_data in net_dataidx_map.items():
        N_client = len(client_data)
        N_train = int(N_client * (1 - test_partition))

        train_data_local_num_dict[key] = N_train
        random.shuffle(client_data)

        train, train_label = data[client_data[:N_train]], label_list[client_data[:N_train]]
        train_dataset = TensorDataset(train, train_label)
        train_data_local_dict[key] = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

        test, test_label = data[client_data[N_train:]], label_list[client_data[N_train:]]
        test_dataset = TensorDataset(test, test_label)
        test_data_local_dict[key] = DataLoader(test_dataset, batch_size=args.batch_size)

        train_data_global.append(train_dataset)
        test_data_global.append(test_dataset)

        train_data_num += N_train
        test_data_num += (N_client - N_train)

    train_data_global = DataLoader(ConcatDataset(train_data_global), batch_size=args.batch_size)
    test_data_global = DataLoader(ConcatDataset(test_data_global), batch_size=args.batch_size)

    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, args.class_num]
--------------------------------------------------------------------------------
/data/noniid_partition.py:
--------------------------------------------------------------------------------
import logging

import numpy as np


def non_iid_partition_with_dirichlet_distribution(label_list,
                                                  client_num,
                                                  classes,
                                                  alpha,
                                                  task='classification'):
    """
    Obtain the sample index list for each client from the Dirichlet distribution.

    This LDA method was first proposed in:
    Measuring the Effects of Non-Identical Data Distribution for
    Federated Visual Classification (https://arxiv.org/pdf/1909.06335.pdf).

    It can generate non-IIDness with an unbalanced number of samples per label.
    The Dirichlet distribution is a density over a K-dimensional vector p whose K components are positive and sum to 1.
    Dirichlet can support the probabilities of a K-way categorical event.
    In FL, we can view each client's per-class sample proportions as obeying the Dirichlet distribution.
    For more details of the Dirichlet distribution, please check https://en.wikipedia.org/wiki/Dirichlet_distribution

    Parameters
    ----------
    label_list : the label list from the classification/segmentation dataset
    client_num : number of clients
    classes : the number of classes (e.g., 10 for CIFAR-10) OR a list of segmentation categories
    alpha : a concentration parameter controlling the identicalness among clients
    task : CV-specific task, e.g. classification, segmentation
    Returns
    -------
    net_dataidx_map : dict
        Maps each client index to the list of sample indices assigned to that client.
    """
    net_dataidx_map = {}
    K = classes

    # For multiclass labels, the list is ragged and not a numpy array
    N = len(label_list) if task == 'segmentation' else label_list.shape[0]

    # guarantee a minimum number of samples in each client
    min_size = 0
    while min_size < 10:
        idx_batch = [[] for _ in range(client_num)]

        if task == 'segmentation':
            # Unlike classification tasks, here one instance may have multiple categories/classes
            for c, cat in enumerate(classes):
                if c > 0:
                    idx_k = np.asarray([np.any(label_list[i] == cat) and not np.any(
                        np.in1d(label_list[i], classes[:c])) for i in
                        range(len(label_list))])
                else:
                    idx_k = np.asarray(
                        [np.any(label_list[i] == cat) for i in range(len(label_list))])

                # Get the indices of images that have category = c
                idx_k = np.where(idx_k)[0]
                idx_batch, min_size = partition_class_samples_with_dirichlet_distribution(N, alpha, client_num,
                                                                                          idx_batch, idx_k)
        else:
            # for each class in the dataset
            for k in range(K):
                # get the list of sample indices that belong to label k
                idx_k = np.where(label_list == k)[0]
                idx_batch, min_size = partition_class_samples_with_dirichlet_distribution(N, alpha, client_num,
                                                                                          idx_batch, idx_k)
    for i in range(client_num):
        np.random.shuffle(idx_batch[i])
        net_dataidx_map[i] = idx_batch[i]

    return net_dataidx_map
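
# A hypothetical usage sketch (not part of the repository): partitioning 1,000 samples
# with 10 classes across 3 clients. A small alpha (e.g. 0.1) concentrates each client's
# data on a few labels, while a large alpha (e.g. 100) approaches an even, IID-like split.
#
#   labels = np.random.randint(0, 10, size=1000)
#   net_dataidx_map = non_iid_partition_with_dirichlet_distribution(label_list=labels,
#                                                                   client_num=3,
#                                                                   classes=10,
#                                                                   alpha=0.5)
#   # net_dataidx_map: {0: [sample indices...], 1: [...], 2: [...]}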


def partition_class_samples_with_dirichlet_distribution(N, alpha, client_num, idx_batch, idx_k):
    np.random.shuffle(idx_k)
    # use the Dirichlet distribution to determine the unbalanced proportion for each client (client_num in total)
    # e.g., when client_num = 4, proportions = [0.29543505 0.38414498 0.31998781 0.00043216], sum(proportions) = 1
    proportions = np.random.dirichlet(np.repeat(alpha, client_num))

    # get the split points in idx_k according to the dirichlet proportions
    proportions = np.array([p * (len(idx_j) < N / client_num) for p, idx_j in zip(proportions, idx_batch)])
    proportions = proportions / proportions.sum()
    proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]

    # generate the batch list for each client
    idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
    min_size = min([len(idx_j) for idx_j in idx_batch])

    return idx_batch, min_size


def record_data_stats(y_train, net_dataidx_map, task='classification'):
    net_cls_counts = {}

    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = np.unique(np.concatenate(y_train[dataidx]), return_counts=True) if task == 'segmentation' \
            else np.unique(y_train[dataidx], return_counts=True)
        tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
        net_cls_counts[net_i] = tmp
    logging.debug('Data statistics: %s' % str(net_cls_counts))
    return net_cls_counts
--------------------------------------------------------------------------------
/data_entities.py:
--------------------------------------------------------------------------------
import logging
import os
from collections import Counter
from copy import deepcopy

import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd
from torch.distributed.rpc import RRef
from torch.distributed.optim import DistributedOptimizer

from models import *

class alice(object):
    def __init__(self, server, bob_model_rrefs, rank, args):
        self.client_id = rank
        self.epochs = args.epochs
        self.start_logger()

        self.bob = server

        # client-side segments: model1 is the front (conv) segment,
        # model3 is the head that produces the final logits
        self.model1 = model1()
        self.model2 = model3()

        self.criterion = nn.CrossEntropyLoss()

        # the distributed optimizer holds RRefs to the client segments *and* to Bob's
        # middle segment, so one optimizer step updates the whole split model
        self.dist_optimizer = DistributedOptimizer(
            torch.optim.SGD,
            list(map(lambda x: RRef(x), self.model2.parameters())) + bob_model_rrefs + list(map(lambda x: RRef(x), self.model1.parameters())),
            lr=args.lr,
            momentum=0.9
        )

        self.load_data(args)


    def train(self, last_alice_rref, last_alice_id):
        self.logger.info("Training")

        if last_alice_rref is None:
            self.logger.info(f"Alice{self.client_id} is the first client to train")

        else:
            self.logger.info(f"Alice{self.client_id} receiving weights from Alice{last_alice_id}")
            model1_weights, model2_weights = last_alice_rref.rpc_sync().give_weights()
            self.model1.load_state_dict(model1_weights)
            self.model2.load_state_dict(model2_weights)


        for epoch in range(self.epochs):
            for i, data in enumerate(self.train_dataloader):
                inputs, labels = data

                with dist_autograd.context() as context_id:

                    # forward pass: Alice front segment -> Bob middle segment -> Alice head
                    activation_alice1 = self.model1(inputs)
                    activation_bob = self.bob.rpc_sync().train(activation_alice1)  # model2 forward on Bob
                    activation_alice2 = self.model2(activation_bob)

                    loss = self.criterion(activation_alice2, labels)

                    # run the backward pass through the distributed autograd graph
                    dist_autograd.backward(context_id, [loss])

                    self.dist_optimizer.step(context_id)


    def give_weights(self):
        return [deepcopy(self.model1.state_dict()), deepcopy(self.model2.state_dict())]

    def eval(self):
        correct = 0
        total = 0
        # since we're not training, we don't need to calculate the gradients for our outputs
        with torch.no_grad():
            for data in self.test_dataloader:
                images, labels = data
                # calculate outputs by running images through the split network
                activation_alice1 = self.model1(images)
                activation_bob = self.bob.rpc_sync().train(activation_alice1)  # model2 forward on Bob
                outputs = self.model2(activation_bob)
                # the class with the highest energy is what we choose as the prediction
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        self.logger.info(f"Alice{self.client_id} Evaluating Data: {round(correct / total, 3)}")
        return correct, total

    def load_data(self, args):
        self.train_dataloader = torch.load(os.path.join(args.datapath, f"data_worker{self.client_id}_train.pt"))
        self.test_dataloader = torch.load(os.path.join(args.datapath, f"data_worker{self.client_id}_test.pt"))

        self.n_train = len(self.train_dataloader.dataset)
        self.logger.info("Local Data Statistics:")
        self.logger.info("Dataset Size: {}".format(self.n_train))
        self.logger.info(dict(Counter(self.test_dataloader.dataset[:][1].numpy().tolist())))

    def start_logger(self):
        self.logger = logging.getLogger(f"alice{self.client_id}")
        self.logger.setLevel(logging.INFO)

        fmt = logging.Formatter("%(asctime)s: %(message)s")

        # make sure the log directory exists before attaching the file handler
        os.makedirs("logs", exist_ok=True)
        fh = logging.FileHandler(filename=f"logs/alice{self.client_id}.log", mode='w')
        fh.setFormatter(fmt)
        fh.setLevel(logging.INFO)

        self.logger.addHandler(fh)

        self.logger.info("Alice is going insane!")



class bob(object):
    def __init__(self, args):

        self.server = RRef(self)
        self.model = model2()
        model_rrefs = list(map(lambda x: RRef(x), self.model.parameters()))

        # create one remote alice per client rank; each alice receives an RRef to this
        # server and RRefs to the server-side model parameters
        self.alices = {rank + 1: rpc.remote(f"alice{rank + 1}", alice, (self.server, model_rrefs, rank + 1, args)) for rank in range(args.client_num_in_total)}
        self.last_alice_id = None
        self.client_num_in_total = args.client_num_in_total
        self.start_logger()

    def train_request(self, client_id):
        # ask alice `client_id` to train, passing a reference to the previously
        # trained alice so she can synchronize her weights first
        self.logger.info(f"Train Request for Alice{client_id}")
        if self.last_alice_id is None:
            self.alices[client_id].rpc_sync(timeout=0).train(None, None)
        else:
            self.alices[client_id].rpc_sync(timeout=0).train(self.alices[self.last_alice_id], self.last_alice_id)
        self.last_alice_id = client_id

    def eval_request(self):
        self.logger.info("Initializing Evaluation of all Alices")
        total = []
        num_corr = []
        check_eval = [self.alices[client_id].rpc_async(timeout=0).eval() for client_id in
                      range(1, self.client_num_in_total + 1)]
        for check in check_eval:
            corr, tot = check.wait()
            total.append(tot)
            num_corr.append(corr)

        self.logger.info("Accuracy over all data: {:.3f}".format(sum(num_corr) / sum(total)))

    def train(self, x):
        # forward pass of the server-side (middle) model segment
        return self.model(x)

    def start_logger(self):
        self.logger = logging.getLogger("bob")
        self.logger.setLevel(logging.INFO)

        fmt = logging.Formatter("%(asctime)s: %(message)s")

        # make sure the log directory exists before attaching the file handler
        os.makedirs("logs", exist_ok=True)
        fh = logging.FileHandler(filename="logs/bob.log", mode='w')
        fh.setFormatter(fmt)
        fh.setLevel(logging.INFO)

        self.logger.addHandler(fh)
        self.logger.info("Bob Started Getting Tipsy")
--------------------------------------------------------------------------------