├── .gitignore
├── README.md
├── dataset.py
├── loss_graph.py
├── model.py
├── results
│   ├── fbase.png
│   ├── ffed.png
│   ├── log_baseline_fashion_mnist_10
│   ├── log_baseline_mnist_10
│   ├── log_federated_fashion_mnist_10_2_1
│   ├── log_federated_fashion_mnist_10_3_1
│   ├── log_federated_fashion_mnist_10_3_2
│   ├── log_federated_mnist_10_2_2
│   ├── log_federated_mnist_10_3_2
│   ├── mbase.png
│   └── mfed.png
├── train_baseline.py
└── train_federated.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

#Data
data/
models/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Federated-Learning
Implementation of a CNN model in a federated learning setting. The dataset is distributed across a given number of clients, and a local model is trained for each client. The parameters from each client's model are then used to update the global model.

The experiments are performed on the MNIST and FashionMNIST datasets. A simple CNN is used as the model. The training data is split into 80% training and 20% validation data, and the validation loss is used to save the best model. The results of the model trained in the federated setting are compared with a simple (baseline) model trained centrally on the complete data.
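In each communication round, every client trains a copy of the current global model on its own shard of the data, and the server then averages the clients' parameters element-wise (see `FedAvg` in train_federated.py). A minimal sketch of one round, assuming a hypothetical `clients` list of dataloaders and a hypothetical `train_local` helper that trains a model copy and returns its `state_dict()`:

```
import copy
import torch

def federated_round(global_model, clients):
    # Each client trains on a deep copy of the current global model
    params = [train_local(copy.deepcopy(global_model), data) for data in clients]
    # Plain (unweighted) FedAvg: element-wise mean of the clients' state_dicts
    avg = copy.deepcopy(params[0])
    for key in avg.keys():
        for p in params[1:]:
            avg[key] += p[key]
        avg[key] = torch.div(avg[key], len(params))
    # The averaged parameters become the new global model
    global_model.load_state_dict(avg)
```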

## Requirements
* Python3
* PyTorch
* TorchVision

## Directory Structure
* dataset.py - script used to load the dataset
* loss_graph.py - script used to plot the training/validation loss curves from the logs in results/
* model.py - script used to initialize the CNN model
* train_baseline.py - script used to train the baseline model
* train_federated.py - script used to train the federated learning model

## Training Options
Training parameters are listed below and can be set inside train_baseline.py and train_federated.py:
* NUM_EPOCHS : number of epochs to train the model
* BATCH_SIZE : batch size for the dataloaders
* NUM_CLIENTS : number of clients used to simulate a federated setting (train_federated.py only)
* LOCAL_ITERS : number of local iterations performed by each client to update its local model (train_federated.py only)

## Execution Details:
The baseline model can be trained using:
```
python3 train_baseline.py
```
The federated learning model can be trained using:
```
python3 train_federated.py
```
## Results:
Both the baseline model and the federated model (3 clients, 2 local iterations each) are trained for 10 epochs, and the test accuracy is reported below:

| Dataset       | Federated | Baseline |
| ------------- |:---------:| --------:|
| MNIST         | 99.1%     | 98.9%    |
| FashionMNIST  | 91.3%     | 90.7%    |

The loss plots for all the models are displayed below:

### MNIST

Federated Learning Model | Baseline Model
:-------------------------:|:-------------------------:
![](https://github.com/ashar207/Federated-Learning/blob/master/results/mfed.png) | ![](https://github.com/ashar207/Federated-Learning/blob/master/results/mbase.png)

### Fashion MNIST
Federated Learning Model | Baseline Model
:-------------------------:|:-------------------------:
![](https://github.com/ashar207/Federated-Learning/blob/master/results/ffed.png) | ![](https://github.com/ashar207/Federated-Learning/blob/master/results/fbase.png)
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Script to fetch the required dataset
@author : Anant
"""
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

DATA_DIR = "./data/"


def mnist_loader(val_split=0.2, batch_size=5):
    """
    Loads the MNIST data into 3 sets: train, validation and test
    :param val_split: a float value deciding the train/validation split
    :param batch_size: an int value defining the batch size of the dataloaders
    :return train_dataloader: a PyTorch dataloader iterator for the training set
    :return val_dataloader: a PyTorch dataloader iterator for the validation set
    :return test_dataloader: a PyTorch dataloader iterator for the test set
    """
    transform = transforms.Compose([transforms.ToTensor()])

    # load the dataset
    train_dataset = datasets.MNIST(root=DATA_DIR+"mnist", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root=DATA_DIR+"mnist", train=False, download=True, transform=transform)

    #Shuffle and split into train and validation sets
    val_size = int(val_split * len(train_dataset))
    train_size = int((1-val_split) * len(train_dataset))
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    #Define dataloaders
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    print("-"*30+"MNIST DATASET"+"-"*30)
    print("Train Set size: ", len(train_dataset))
    print("Validation Set size: ", len(val_dataset))
    print("Test Set size: ", len(test_dataset))

    return train_dataloader, val_dataloader, test_dataloader

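# Editorial note: random_split above draws a fresh random permutation on every
# run, so the train/validation split differs from run to run. If a reproducible
# split is desired, one option in recent PyTorch versions is to pass a seeded
# generator, e.g.:
#   train_dataset, val_dataset = torch.utils.data.random_split(
#       train_dataset, [train_size, val_size],
#       generator=torch.Generator().manual_seed(0))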

def fashion_mnist_loader(val_split=0.2, batch_size=5):
    """
    Loads the FashionMNIST data into 3 sets: train, validation and test
    :param val_split: a float value deciding the train/validation split
    :param batch_size: an int value defining the batch size of the dataloaders
    :return train_dataloader: a PyTorch dataloader iterator for the training set
    :return val_dataloader: a PyTorch dataloader iterator for the validation set
    :return test_dataloader: a PyTorch dataloader iterator for the test set
    """
    transform = transforms.Compose([transforms.ToTensor()])

    # load the dataset
    train_dataset = datasets.FashionMNIST(root=DATA_DIR+"fashion_mnist", train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root=DATA_DIR+"fashion_mnist", train=False, download=True, transform=transform)

    #Shuffle and split into train and validation sets
    val_size = int(val_split * len(train_dataset))
    train_size = int((1-val_split) * len(train_dataset))
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    #Define dataloaders
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    print("-"*30+"FASHION MNIST DATASET"+"-"*30)
    print("Train Set size: ", len(train_dataset))
    print("Validation Set size: ", len(val_dataset))
    print("Test Set size: ", len(test_dataset))

    return train_dataloader, val_dataloader, test_dataloader


def load_dataset(val_split=0.2, batch_size=5, dataset="mnist"):
    """
    Loads the requested dataset
    :param val_split: a float value deciding the train/validation split
    :param batch_size: an int value defining the batch size of the dataloaders
    :param dataset: name of the dataset, used to call the required loader function
    :return datasets: the dataloader iterators for the requested dataset
    """
    if dataset=="mnist":
        return mnist_loader(val_split, batch_size)
    if dataset=="fashion_mnist":
        return fashion_mnist_loader(val_split, batch_size)
    raise ValueError("Unknown dataset: " + dataset)


def visualize_dataset(datasets):
    """
    Displays 5 images from each set: train, validation and test
    :param datasets: a list containing the train, validation and test set dataloaders
    """
    fig, big_axes = plt.subplots(figsize=(20, 15), nrows=3, ncols=1)
    for i in range(3):
        big_axes[i]._frameon = False
        big_axes[i].set_axis_off()
        data_iter = iter(datasets[i])
        if i==0: big_axes[0].set_title("Train Set", fontsize=16)
        if i==1: big_axes[1].set_title("Validation Set", fontsize=16)
        if i==2: big_axes[2].set_title("Test Set", fontsize=16)
        #Plot 5 images from the selected dataset
        for j in range(5):
            fig.add_subplot(3, 5, (i*5)+j+1)
            plt.imshow(transforms.ToPILImage()(next(data_iter)[0][0]), cmap=plt.get_cmap('gray'))
            plt.axis('off')
    plt.show()


if __name__ == "__main__":
    train, validation, test = mnist_loader()
    visualize_dataset([train, validation, test])
--------------------------------------------------------------------------------
/loss_graph.py:
--------------------------------------------------------------------------------
from matplotlib import pyplot as plt
# with open("results/log_federated_fashion_mnist_10_3_2") as f:
#     fed = f.readlines()
# with open("results/log_baseline_fashion_mnist_10") as f:
#     base = f.readlines()

with open("results/log_federated_mnist_10_3_2") as f:
    fed = f.readlines()
with open("results/log_baseline_mnist_10") as f:
    base = f.readlines()


fed_val = [float(f.split(" ")[7].strip(",")) for f in fed if "Train" in f]
fed_train = [float(f.split(" ")[4].strip(",")) for f in fed if "Train" in f]
base_val = [float(b.split(" ")[7].strip(",")) for b in base if "Train" in b]
base_train = [float(b.split(" ")[4].strip(",")) for b in base if "Train" in b]

epochs = range(1, 11)
plt.figure()
plt.plot(epochs, fed_train, 'g', label='Training loss')
plt.plot(epochs, fed_val, 'b', label='Validation loss')
plt.title('Federated loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.xticks()
# plt.ylim([0,3])
plt.ylim([0,1])


plt.figure()
plt.plot(epochs, base_train, 'g', label='Training loss')
plt.plot(epochs, base_val, 'b', label='Validation loss')
plt.title('Baseline loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.xticks()
# plt.ylim([0,3])
plt.ylim([0,1])
plt.show()
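# Editorial notes:
# - The split(" ") indexing above is tied to the exact log format written by
#   the training scripts (token 4 is the train loss, token 7 the val loss).
#   A sketch of a less brittle alternative using a regular expression:
#     import re
#     pat = re.compile(r"Train Loss: ([0-9.]+), Val Loss: ([0-9.]+)")
#     fed_train, fed_val = zip(*[map(float, pat.search(l).groups())
#                                for l in fed if "Train" in l])
# - The federated script accumulates loss per sample but normalizes by the
#   number of batches, so its logged train loss differs in scale from the
#   baseline's per-sample loss by roughly a factor of BATCH_SIZE; the two
#   training curves are therefore not directly comparable.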
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Script used to define the Model
@author : Anant
"""
import torch.nn as nn
import torch.nn.functional as F

#Model
#Image (1x28x28)
#Conv Layer 1 (32x28x28) + Relu Activation
#MaxPool (32x14x14)
#Conv Layer 2 (64x14x14) + Relu Activation
#MaxPool (64x7x7)
#FC layer 1 (3136) + Relu Activation
#FC Layer 2 (128)
#Output Layer (10)

#Final image size before FC layer
FLATTEN_SIZE = 64*7*7


class CNN(nn.Module):
    def __init__(self):
        """
        Initializes the CNN Model Class and the required layers
        """
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(FLATTEN_SIZE, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """
        Forms the feed-forward network by combining all the layers
        :param x: the input image batch for the network
        :return pred: log-probabilities over the 10 classes
        """
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, FLATTEN_SIZE)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        pred = F.log_softmax(x, dim=1)
        return pred
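# Editorial notes (not part of the original file):
# - Shape check: two stride-2 max-pools reduce 28x28 to 7x7, so the flattened
#   feature size is 64*7*7 = 3136 = FLATTEN_SIZE, and a forward pass such as
#   CNN()(torch.zeros(1, 1, 28, 28)) returns a (1, 10) tensor.
# - forward() returns log-probabilities, while both training scripts pair the
#   model with nn.CrossEntropyLoss, which applies log_softmax internally. The
#   model still trains this way, but nn.NLLLoss is the criterion that matches
#   a log_softmax output.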
--------------------------------------------------------------------------------
/results/fbase.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/an4nt/Federated-Learning/7fb8dbc45301441f6a66417aa4a9f8186b975c71/results/fbase.png
--------------------------------------------------------------------------------
/results/ffed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/an4nt/Federated-Learning/7fb8dbc45301441f6a66417aa4a9f8186b975c71/results/ffed.png
--------------------------------------------------------------------------------
/results/log_baseline_fashion_mnist_10:
--------------------------------------------------------------------------------
INFO:root:Epoch: 1/10, Train Loss: 0.44149528, Val Loss: 0.30370722, Val Accuracy: 0.88875000
INFO:root:Saving Model State
INFO:root:Epoch: 2/10, Train Loss: 0.29181354, Val Loss: 0.26328902, Val Accuracy: 0.90141667
INFO:root:Saving Model State
INFO:root:Epoch: 3/10, Train Loss: 0.24913410, Val Loss: 0.25483956, Val Accuracy: 0.90416667
INFO:root:Saving Model State
INFO:root:Epoch: 4/10, Train Loss: 0.22293655, Val Loss: 0.26677265, Val Accuracy: 0.90933333
INFO:root:Epoch: 5/10, Train Loss: 0.20077494, Val Loss: 0.25017438, Val Accuracy: 0.91216667
INFO:root:Saving Model State
INFO:root:Epoch: 6/10, Train Loss: 0.18448148, Val Loss: 0.29715600, Val Accuracy: 0.91008333
INFO:root:Epoch: 7/10, Train Loss: 0.16638099, Val Loss: 0.29874967, Val Accuracy: 0.91100000
INFO:root:Epoch: 8/10, Train Loss: 0.15970705, Val Loss: 0.33128412, Val Accuracy: 0.91100000
INFO:root:Epoch: 9/10, Train Loss: 0.15334789, Val Loss: 0.36495166, Val Accuracy: 0.91033333
INFO:root:Epoch: 10/10, Train Loss: 0.14396968, Val Loss: 0.38900903, Val Accuracy: 0.90633333
INFO:root:Test accuracy 0.90710000
--------------------------------------------------------------------------------
/results/log_baseline_mnist_10:
--------------------------------------------------------------------------------
INFO:root:Epoch: 1/10, Train Loss: 0.16311646, Val Loss: 0.06172110, Val Accuracy: 0.98058333
INFO:root:Saving Model State
INFO:root:Epoch: 2/10, Train Loss: 0.06833706, Val Loss: 0.04173575, Val Accuracy: 0.98883333
INFO:root:Saving Model State
INFO:root:Epoch: 3/10, Train Loss: 0.04864092, Val Loss: 0.04958712, Val Accuracy: 0.98666667
INFO:root:Epoch: 4/10, Train Loss: 0.03774780, Val Loss: 0.04391603, Val Accuracy: 0.98950000
INFO:root:Epoch: 5/10, Train Loss: 0.03307015, Val Loss: 0.04161129, Val Accuracy: 0.99075000
INFO:root:Saving Model State
INFO:root:Epoch: 6/10, Train Loss: 0.02514024, Val Loss: 0.04247383, Val Accuracy: 0.99041667
INFO:root:Epoch: 7/10, Train Loss: 0.02396091, Val Loss: 0.05071571, Val Accuracy: 0.98883333
INFO:root:Epoch: 8/10, Train Loss: 0.02082895, Val Loss: 0.06280836, Val Accuracy: 0.98958333
INFO:root:Epoch: 9/10, Train Loss: 0.02017726, Val Loss: 0.05081339, Val Accuracy: 0.99041667
INFO:root:Epoch: 10/10, Train Loss: 0.01993583, Val Loss: 0.06215932, Val Accuracy: 0.98883333
INFO:root:Test accuracy 0.98990000
--------------------------------------------------------------------------------
/results/log_federated_fashion_mnist_10_2_1:
--------------------------------------------------------------------------------
INFO:root:Epoch: 1/10, Train Loss: 2.69544075, Val Loss: 0.50212094, Val Accuracy: 0.84608333
INFO:root:Saving Model State
INFO:root:Epoch: 2/10, Train Loss: 1.84587284, Val Loss: 0.27983051, Val Accuracy: 0.89841667
INFO:root:Saving Model State
INFO:root:Epoch: 3/10, Train Loss: 1.50086015, Val Loss: 0.24836887, Val Accuracy: 0.90658333
INFO:root:Saving Model State
INFO:root:Epoch: 4/10, Train Loss: 1.31289461, Val Loss: 0.24149429, Val Accuracy: 0.91083333
INFO:root:Saving Model State
INFO:root:Epoch: 5/10, Train Loss: 1.20601426, Val Loss: 0.23800733, Val Accuracy: 0.91216667
INFO:root:Saving Model State
INFO:root:Epoch: 6/10, Train Loss: 1.10887536, Val Loss: 0.22917683, Val Accuracy: 0.91575000
INFO:root:Saving Model State
INFO:root:Epoch: 7/10, Train Loss: 1.04446593, Val Loss: 0.23318506, Val Accuracy: 0.91525000
INFO:root:Epoch: 8/10, Train Loss: 0.97117264, Val Loss: 0.24355989, Val Accuracy: 0.91408333
INFO:root:Epoch: 9/10, Train Loss: 0.92376716, Val Loss: 0.24610977, Val Accuracy: 0.91691667
INFO:root:Epoch: 10/10, Train Loss: 0.87287157, Val Loss: 0.24514329, Val Accuracy: 0.91650000
INFO:root:Test accuracy 0.91320000
--------------------------------------------------------------------------------
/results/log_federated_fashion_mnist_10_3_1:
--------------------------------------------------------------------------------
INFO:root:Epoch: 1/10, Train Loss: 2.90596093, Val Loss: 0.62189217, Val Accuracy: 0.83841667
INFO:root:Saving Model State
INFO:root:Epoch: 2/10, Train Loss: 2.01085224, Val Loss: 0.31098610, Val Accuracy: 0.88766667
INFO:root:Saving Model State
INFO:root:Epoch: 3/10, Train Loss: 1.63638050, Val Loss: 0.26960418, Val Accuracy: 0.90008333
INFO:root:Saving Model State
INFO:root:Epoch: 4/10, Train Loss: 1.44018366, Val Loss: 0.25015737, Val Accuracy: 0.90558333
INFO:root:Saving Model State
INFO:root:Epoch: 5/10, Train Loss: 1.31022458, Val Loss: 0.24375120, Val Accuracy: 0.90766667
INFO:root:Saving Model State
INFO:root:Epoch: 6/10, Train Loss: 1.20550256, Val Loss: 0.24097768, Val Accuracy: 0.90875000
INFO:root:Saving Model State
INFO:root:Epoch: 7/10, Train Loss: 1.13557870, Val Loss: 0.23666576, Val Accuracy: 0.91200000
INFO:root:Saving Model State
INFO:root:Epoch: 8/10, Train Loss: 1.07047528, Val Loss: 0.23276908, Val Accuracy: 0.91425000
INFO:root:Saving Model State
INFO:root:Epoch: 9/10, Train Loss: 1.01295750, Val Loss: 0.23586403, Val Accuracy: 0.91466667
INFO:root:Epoch: 10/10, Train Loss: 0.98473157, Val Loss: 0.23720542, Val Accuracy: 0.91408333
INFO:root:Test accuracy 0.91050000
--------------------------------------------------------------------------------
/results/log_federated_fashion_mnist_10_3_2:
--------------------------------------------------------------------------------
INFO:root:Epoch: 1/10, Train Loss: 2.43409528, Val Loss: 0.60292162, Val Accuracy: 0.86991667
INFO:root:Saving Model State
INFO:root:Epoch: 2/10, Train Loss: 1.60204656, Val Loss: 0.26183640, Val Accuracy: 0.90400000
INFO:root:Saving Model State
INFO:root:Epoch: 3/10, Train Loss: 1.25619657, Val Loss: 0.23764249, Val Accuracy: 0.91391667
INFO:root:Saving Model State
INFO:root:Epoch: 4/10, Train Loss: 1.07377891, Val Loss: 0.23737312, Val Accuracy: 0.91525000
INFO:root:Saving Model State
INFO:root:Epoch: 5/10, Train Loss: 0.95287254, Val Loss: 0.23576796, Val Accuracy: 0.91791667
INFO:root:Saving Model State
INFO:root:Epoch: 6/10, Train Loss: 0.86513598, Val Loss: 0.23956436, Val Accuracy: 0.91766667
INFO:root:Epoch: 7/10, Train Loss: 0.78979900, Val Loss: 0.24854811, Val Accuracy: 0.91933333
INFO:root:Epoch: 8/10, Train Loss: 0.73562320, Val Loss: 0.25198404, Val Accuracy: 0.91800000
INFO:root:Epoch: 9/10, Train Loss: 0.68781250, Val Loss: 0.26422832, Val Accuracy: 0.92025000
INFO:root:Epoch: 10/10, Train Loss: 0.67031168, Val Loss: 0.27261011, Val Accuracy: 0.91875000
INFO:root:Test accuracy 0.91320000
--------------------------------------------------------------------------------
/results/log_federated_mnist_10_2_2:
--------------------------------------------------------------------------------
INFO:root:Epoch: 1/10, Train Loss: 0.77747047, Val Loss: 0.10568265, Val Accuracy: 0.97708333
INFO:root:Saving Model State
INFO:root:Epoch: 2/10, Train Loss: 0.29993543, Val Loss: 0.04266676, Val Accuracy: 0.98833333
INFO:root:Saving Model State
INFO:root:Epoch: 3/10, Train Loss: 0.19241730, Val Loss: 0.04203904, Val Accuracy: 0.98841667
INFO:root:Saving Model State
INFO:root:Epoch: 4/10, Train Loss: 0.15258150, Val Loss: 0.04101090, Val Accuracy: 0.98883333
INFO:root:Saving Model State
INFO:root:Epoch: 5/10, Train Loss: 0.10989802, Val Loss: 0.04313709, Val Accuracy: 0.98933333
INFO:root:Epoch: 6/10, Train Loss: 0.10128812, Val Loss: 0.05113721, Val Accuracy: 0.98825000
INFO:root:Epoch: 7/10, Train Loss: 0.09750933, Val Loss: 0.04435758, Val Accuracy: 0.98975000
INFO:root:Epoch: 8/10, Train Loss: 0.08628875, Val Loss: 0.04487178, Val Accuracy: 0.99041667
INFO:root:Epoch: 9/10, Train Loss: 0.07953401, Val Loss: 0.04398598, Val Accuracy: 0.99125000
INFO:root:Epoch: 10/10, Train Loss: 0.08658296, Val Loss: 0.04864905, Val Accuracy: 0.98975000
INFO:root:Test accuracy 0.99110000
--------------------------------------------------------------------------------
/results/log_federated_mnist_10_3_2:
--------------------------------------------------------------------------------
INFO:root:Epoch: 1/10, Train Loss: 0.85636063, Val Loss: 0.14035963, Val Accuracy: 0.98066667
INFO:root:Saving Model State
INFO:root:Epoch: 2/10, Train Loss: 0.35715060, Val Loss: 0.04148094, Val Accuracy: 0.98816667
INFO:root:Saving Model State
INFO:root:Epoch: 3/10, Train Loss: 0.22336485, Val Loss: 0.03867630, Val Accuracy: 0.98933333
INFO:root:Saving Model State
INFO:root:Epoch: 4/10, Train Loss: 0.17369942, Val Loss: 0.03683914, Val Accuracy: 0.99041667
INFO:root:Saving Model State
INFO:root:Epoch: 5/10, Train Loss: 0.14026766, Val Loss: 0.04028376, Val Accuracy: 0.99050000
INFO:root:Epoch: 6/10, Train Loss: 0.11820753, Val Loss: 0.03946097, Val Accuracy: 0.99058333
INFO:root:Epoch: 7/10, Train Loss: 0.10293075, Val Loss: 0.04052444, Val Accuracy: 0.99033333
INFO:root:Epoch: 8/10, Train Loss: 0.09973919, Val Loss: 0.04512742, Val Accuracy: 0.99041667
INFO:root:Epoch: 9/10, Train Loss: 0.08614432, Val Loss: 0.05164980, Val Accuracy: 0.99100000
INFO:root:Epoch: 10/10, Train Loss: 0.07863043, Val Loss: 0.04251062, Val Accuracy: 0.99125000
INFO:root:Test accuracy 0.99130000
--------------------------------------------------------------------------------
/results/mbase.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/an4nt/Federated-Learning/7fb8dbc45301441f6a66417aa4a9f8186b975c71/results/mbase.png
--------------------------------------------------------------------------------
/results/mfed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/an4nt/Federated-Learning/7fb8dbc45301441f6a66417aa4a9f8186b975c71/results/mfed.png
--------------------------------------------------------------------------------
/train_baseline.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Script used to train the Baseline Model
@author : Anant
"""
import os
import torch
import numpy as np
import logging.config
from tqdm import tqdm
from model import CNN
import matplotlib.pyplot as plt
from dataset import load_dataset, visualize_dataset

NUM_EPOCHS = 10
VIS_DATA = False
BATCH_SIZE = 5
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', DEVICE)
# DATASET = "fashion_mnist"
DATASET = "mnist"

def train(model, device, dataloader, criterion, optimizer):
    """
    Trains the baseline model for the given dataset
    :param model: a CNN model required for training
    :param device: the device used to train the model - GPU/CPU
    :param dataloader: training data iterator used to train the model
    :param criterion: criterion used to calculate the training loss
    :param optimizer: optimizer used to update the model parameters using backpropagation
    :return train_loss: training loss for the current epoch
    """
    train_loss = 0.0
    model.train()
    for batch_idx, (data, target) in tqdm(enumerate(dataloader), total=len(dataloader)):
        data, target = data.to(device), target.to(device)
        #set gradients to zero
        optimizer.zero_grad()
        #Get output prediction from the model
        output = model(data)
        #Compute the loss
        loss = criterion(output, target)
        train_loss += loss.item()*data.size(0)
        #Collect new set of gradients
        loss.backward()
        #Update the model parameters
        optimizer.step()

    return train_loss / len(dataloader.dataset)


def test(model, dataloader, criterion):
    """
    Tests the baseline model for the given dataset
    :param model: trained CNN model for testing
    :param dataloader: data iterator used to test the model
    :param criterion: criterion used to calculate the test loss
    :return test_loss: test loss for the given dataset
    :return preds: predictions for the last batch of the given dataset
    :return accuracy: accuracy of the model's predictions over the whole dataset
    """
    test_loss = 0.0
    correct = 0
    model.eval()
    with torch.no_grad():
        for batch_idx, (data, target) in tqdm(enumerate(dataloader), total=len(dataloader)):
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            loss = criterion(output, target)
            test_loss += loss.item()*data.size(0)
            preds = output.argmax(dim=1, keepdim=True)
            correct += preds.eq(target.view_as(preds)).sum().item()
    accuracy = correct / len(dataloader.dataset)

    return test_loss/len(dataloader.dataset), preds, accuracy

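# Editorial note: `preds` above is overwritten on every batch, so the second
# return value only holds predictions for the final batch. A sketch for
# collecting predictions over the whole set (moving each batch to the model's
# device as above) would be:
#   all_preds = torch.cat([model(x.to(DEVICE)).argmax(dim=1)
#                          for x, _ in dataloader])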

if __name__=="__main__":
    if not os.path.isdir('models'):
        os.mkdir('models')
    if not os.path.isdir('results'):
        os.mkdir('results')

    #Initialize a logger to log epoch results
    logname = ('results/log_baseline_' + DATASET + "_" + str(NUM_EPOCHS))
    logging.basicConfig(filename=logname,level=logging.DEBUG)
    logger = logging.getLogger()

    #get data
    train_data, validation_data, test_data = load_dataset(val_split=0.2, batch_size=BATCH_SIZE, dataset=DATASET)
    if VIS_DATA: visualize_dataset([train_data, validation_data, test_data])

    #get model and define criterion for loss and optimizer for model update
    model = CNN().to(DEVICE)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    all_train_loss = list()
    all_val_loss = list()
    val_loss_min = np.Inf

    #Train the model for the given number of epochs
    for epoch in range(1, NUM_EPOCHS+1):
        print("\nEpoch :", str(epoch))
        #train using training data
        train_loss = train(model, DEVICE, train_data, criterion, optimizer)
        #test on validation data
        val_loss, _, accuracy = test(model, validation_data, criterion)
        all_train_loss.append(train_loss)
        all_val_loss.append(val_loss)
        logger.info('Epoch: {}/{}, Train Loss: {:.8f}, Val Loss: {:.8f}, Val Accuracy: {:.8f}'.format(epoch , NUM_EPOCHS, train_loss, val_loss, accuracy))
        #if validation loss decreases, save the model
        if val_loss < val_loss_min:
            val_loss_min = val_loss
            logger.info("Saving Model State")
            torch.save(model.state_dict(), "models/" + DATASET + "_baseline.sav")

    #load the best model from training
    model.load_state_dict(torch.load("models/" + DATASET + "_baseline.sav"))
    #test the model using test data
    test_loss, predictions, accuracy = test(model, test_data, criterion)
    logger.info('Test accuracy {:.8f}'.format(accuracy))
--------------------------------------------------------------------------------
/train_federated.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Script used to train the Federated Model
@author : Anant
"""
import os
import torch
import numpy as np
import logging.config
from tqdm import tqdm
from model import CNN
import matplotlib.pyplot as plt
from dataset import load_dataset, visualize_dataset
import copy

NUM_EPOCHS = 10
LOCAL_ITERS = 2
VIS_DATA = False
BATCH_SIZE = 5
NUM_CLIENTS = 2
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', DEVICE)

# DATASET = "fashion_mnist"
DATASET = "mnist"

def FedAvg(params):
    """
    Averages the parameters from each client to update the global model
    :param params: list of state_dicts holding the parameters from each client's model
    :return global_params: element-wise average of the parameters across clients
    """
    global_params = copy.deepcopy(params[0])
    for key in global_params.keys():
        for param in params[1:]:
            global_params[key] += param[key]
        global_params[key] = torch.div(global_params[key], len(params))
    return global_params
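# Editorial note: this is plain, unweighted FedAvg, which gives every client's
# parameters equal weight and matches the equal-sized shards created in
# __main__ below. With unequal shards, the weighted form of FedAvg (McMahan et
# al.) would scale each state_dict by n_k / n (client samples over total
# samples) before summing instead of taking the simple mean.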
""" 44 | Trains a local model for a given client 45 | :param local_model: a copy of global CNN model required for training 46 | :param device: the device used to train the model - GPU/CPU 47 | :param dataset: training dataset used to train the model 48 | :return local_params: parameters from the trained model from the client 49 | :return train_loss: training loss for the current epoch 50 | """ 51 | #optimzer for training the local models 52 | local_model.to(device) 53 | criterion = torch.nn.CrossEntropyLoss().to(device) 54 | optimizer = torch.optim.Adam(local_model.parameters(), lr=0.001) 55 | train_loss = 0.0 56 | local_model.train() 57 | #Iterate for the given number of Client Iterations 58 | for i in range(iters): 59 | batch_loss = 0.0 60 | for batch_idx, (data, target) in tqdm(enumerate(dataset), total=len(dataset)): 61 | data, target = data.to(device), target.to(device) 62 | #set gradients to zero 63 | optimizer.zero_grad() 64 | #Get output prediction from the Client model 65 | output = local_model(data) 66 | #Computer loss 67 | loss = criterion(output, target) 68 | batch_loss += loss.item()*data.size(0) 69 | #Collect new set of gradients 70 | loss.backward() 71 | #Update local model 72 | optimizer.step() 73 | #add loss for each iteration 74 | train_loss+=batch_loss/len(dataset) 75 | return local_model.state_dict(), train_loss/iters 76 | 77 | 78 | def test(model, dataloader): 79 | """ 80 | Tests the Federated global model for the given dataset 81 | :param model: Trained CNN model for testing 82 | :param dataloader: data iterator used to test the model 83 | :return test_loss: test loss for the given dataset 84 | :return preds: predictions for the given dataset 85 | :return accuracy: accuracy for the prediction values from the model 86 | """ 87 | criterion = torch.nn.CrossEntropyLoss() 88 | test_loss = 0.0 89 | correct = 0 90 | model.eval() 91 | for batch_idx, (data, target) in tqdm(enumerate(dataloader), total=len(dataloader.dataset)/BATCH_SIZE): 92 | data, target = data, target 93 | output = model(data) 94 | loss = criterion(output, target) 95 | test_loss += loss.item()*data.size(0) 96 | preds = output.argmax(dim=1, keepdim=True) 97 | correct += preds.eq(target.view_as(preds)).sum().item() 98 | accuracy = correct / len(dataloader.dataset) 99 | 100 | return test_loss/len(dataloader.dataset), preds, accuracy 101 | 102 | 103 | if __name__=="__main__": 104 | if not os.path.isdir('models'): 105 | os.mkdir('models') 106 | if not os.path.isdir('results'): 107 | os.mkdir('results') 108 | 109 | #Initialize a logger to log epoch results 110 | logname = ('results/log_federated_' + DATASET + "_" + str(NUM_EPOCHS) +"_"+ str(NUM_CLIENTS) + "_" + str(LOCAL_ITERS)) 111 | logging.basicConfig(filename=logname,level=logging.DEBUG) 112 | logger = logging.getLogger() 113 | 114 | 115 | #get data 116 | train_data, validation_data, test_data = load_dataset(val_split=0.2, batch_size=BATCH_SIZE, dataset=DATASET) 117 | if VIS_DATA: visualize_dataset([train, validation, test]) 118 | 119 | #distribute the trainning data across clients 120 | train_distributed_dataset = [[] for _ in range(NUM_CLIENTS)] 121 | for batch_idx, (data,target) in enumerate(train_data): 122 | train_distributed_dataset[batch_idx % NUM_CLIENTS].append((data, target)) 123 | 124 | #get model and define criterion for loss 125 | global_model = CNN() 126 | global_params = global_model.state_dict() 127 | 128 | global_model.train() 129 | all_train_loss = list() 130 | all_val_loss = list() 131 | val_loss_min = np.Inf 132 | 133 | #Train the model for 

    #get the global model and its initial parameters
    global_model = CNN()
    global_params = global_model.state_dict()

    global_model.train()
    all_train_loss = list()
    all_val_loss = list()
    val_loss_min = np.Inf

    #Train the model for the given number of epochs
    for epoch in range(1, NUM_EPOCHS+1):
        print("\nEpoch :", str(epoch))
        local_params, local_losses = [], []
        #Send a copy of the global model to each client
        for idx in range(NUM_CLIENTS):
            #Perform training on the client side and get the parameters
            param, loss = train(copy.deepcopy(global_model), DEVICE, train_distributed_dataset[idx], LOCAL_ITERS)
            local_params.append(copy.deepcopy(param))
            local_losses.append(copy.deepcopy(loss))

        #Federated averaging of the parameters from each client
        global_params = FedAvg(local_params)
        #Update the global model
        global_model.load_state_dict(global_params)
        all_train_loss.append(sum(local_losses)/len(local_losses))

        #Test the global model on the validation data
        val_loss, _, accuracy = test(global_model, validation_data)
        all_val_loss.append(val_loss)

        logger.info('Epoch: {}/{}, Train Loss: {:.8f}, Val Loss: {:.8f}, Val Accuracy: {:.8f}'\
            .format(epoch , NUM_EPOCHS, all_train_loss[-1], val_loss, accuracy))

        #if validation loss decreases, save the model
        if val_loss < val_loss_min:
            val_loss_min = val_loss
            logger.info("Saving Model State")
            torch.save(global_model.state_dict(), "models/" + DATASET + "_" + str(NUM_CLIENTS) + "_federated.sav")

    #load the best model from training
    global_model.load_state_dict(torch.load("models/" + DATASET + "_" + str(NUM_CLIENTS) + "_federated.sav"))
    #test the model using test data
    test_loss, predictions, accuracy = test(global_model, test_data)
    logger.info('Test accuracy {:.8f}'.format(accuracy))
--------------------------------------------------------------------------------