├── utils ├── __init__.py ├── __pycache__ │ ├── merge.cpython-39.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-39.pyc │ ├── datasets.cpython-37.pyc │ ├── datasets.cpython-39.pyc │ ├── arg_parser.cpython-37.pyc │ ├── arg_parser.cpython-39.pyc │ ├── connections.cpython-37.pyc │ ├── connections.cpython-39.pyc │ ├── merge_grads.cpython-37.pyc │ ├── split_dataset.cpython-37.pyc │ ├── split_dataset.cpython-39.pyc │ ├── client_simulation.cpython-37.pyc │ ├── client_simulation.cpython-39.pyc │ ├── dataset_settings.cpython-39.pyc │ └── sharing_strategy.cpython-37.pyc ├── sharing_strategy.py ├── client_simulation.py ├── connections.py ├── merge.py ├── arg_parser.py ├── preprocess_eye_dataset_1.py ├── dataset_settings.py ├── get_eye_dataset.py ├── split_dataset.py ├── preprocess_eye_dataset_2.py └── datasets.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-39.pyc │ ├── resnet18.cpython-39.pyc │ ├── MNIST_CNN.cpython-37.pyc │ ├── MNIST_CNN.cpython-39.pyc │ ├── resnet18_split.cpython-39.pyc │ └── resnet18_setting3.cpython-39.pyc └── resnet18.py ├── requirements.txt ├── citation.cff ├── profiler_ours.py ├── .gitignore ├── ConnectedClient.py ├── profiler.py ├── client.py ├── PFSL.py ├── README.md ├── PFSL_DR.py ├── PFSL_Setting124.py ├── FL_Setting3.py ├── PFSL_Setting3.py ├── system_simulation_e2.py └── FL_DR.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/__pycache__/merge.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/merge.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet18.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/resnet18.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/datasets.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/datasets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/datasets.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/MNIST_CNN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/MNIST_CNN.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/MNIST_CNN.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/MNIST_CNN.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/arg_parser.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/arg_parser.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/arg_parser.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/arg_parser.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/connections.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/connections.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/connections.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/connections.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/merge_grads.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/merge_grads.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/split_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/split_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/split_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/split_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet18_split.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/resnet18_split.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/client_simulation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/client_simulation.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/client_simulation.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/client_simulation.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataset_settings.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/dataset_settings.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/sharing_strategy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/sharing_strategy.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet18_setting3.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/resnet18_setting3.cpython-39.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.4.2 2 | numpy==1.21.0 3 | opencv_python==4.5.1.48 4 | pandas==1.4.2 5 | Pillow==9.4.0 6 | requests==2.28.1 7 | scikit_learn==1.2.1 8 | scipy==1.8.1 9 | thop==0.1.1.post2209072238 10 | torch==1.13.0 11 | torchvision==0.14.0 12 | tqdm==4.59.0 13 | openpyxl==3.1.1 14 | -------------------------------------------------------------------------------- /utils/sharing_strategy.py: -------------------------------------------------------------------------------- 1 | from heapq import nlargest 2 | from heapq import nsmallest 3 | 4 | def min_loss(clients_loss, n_smallest): 5 | smallest_loss_clients = nsmallest(N, clients_loss, key=clients_loss.get) 6 | return smallest_loss_clients 7 | 8 | def best_test_acc(clients_acc, n_largest): 9 | largest_acc_clients = nlargest(N, clients_acc, key=clients_acc.get) 10 | return largest_acc_clients 11 | -------------------------------------------------------------------------------- /citation.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - given-names: "Manas Wadhwa" 5 | email-id: manasw@iitbhilai.ac.in 6 | - given-names: "Gagan Gupta" 7 | email-id: gagan@iitbhilai.ac.in 8 | - given-names: "Ashutosh Sahu" 9 | email-id: ashutoshsahu@iitbhilai.ac.in 10 | - given-names: "Rahul Saini" 11 | email-id: rahuls@iitbhilai.ac.in 12 | - given-names: "Vidhi Mittal" 13 | email-id: vidhimittal@iitbhilai.ac.in 14 | 15 | title: "PFSL" 16 | version: 1.0.0 17 | doi: 10.5281/zenodo.7739655 18 | date-released: 2023-02-11 19 | url: "https://github.com/mnswdhw/PFSL" 20 | -------------------------------------------------------------------------------- /utils/client_simulation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import os 4 | import sys 5 | sys.path.append('..') 6 | from client import Client 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.utils.data.sampler import RandomSampler 9 | 10 | 11 | def generate_random_client_ids(num_clients, id_len=4) -> list: 12 | client_ids = [] 13 | for _ in range(num_clients): 14 | client_ids.append(''.join(random.sample("abcdefghijklmnopqrstuvwxyz1234567890", id_len))) 15 | return client_ids 16 | 17 | 18 | def generate_random_clients(num_clients) -> dict: 19 | client_ids = generate_random_client_ids(num_clients) 20 | clients = {} 21 | for id in client_ids: 22 | clients[id] = Client(id) 23 | return clients 24 | -------------------------------------------------------------------------------- /utils/connections.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import socket 3 | import pickle 4 | 5 | def is_socket_closed(sock: socket.socket) -> bool: 6 | logger = logging.getLogger(__name__) 7 | try: 8 | # this will try to read bytes without blocking without removing them from buffer (peek only) 9 | data = sock.recv(16, socket.MSG_DONTWAIT | socket.MSG_PEEK) 10 | if len(data) == 0: 11 | return True 12 | except BlockingIOError: 13 | return False # socket is open and reading from it would block 14 | except ConnectionResetError: 15 | return True # socket was closed for some other reason 16 | except Exception: 17 | logger.exception("unexpected exception when checking if a socket is closed") 18 | return False 19 | return False 20 | 21 | 22 | def send_object(socket, data): 23 | socket.send(data) 24 | 25 | 26 | def get_object(socket): 27 | data = socket.recv() 28 | return data 29 | 30 | -------------------------------------------------------------------------------- /utils/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | def merge_grads(normalized_data_sizes, params): 5 | # params = [params_client1, 6 | # params_client2, 7 | # params_client3 8 | # ... 9 | # ] 10 | num_clients = len(params) 11 | for j,col in enumerate(zip(*params)): 12 | avg = 0 13 | for i,param in enumerate(col): 14 | avg += normalized_data_sizes[i] * param.grad 15 | # avg += param.grad 16 | 17 | # avg /= num_clients # (since we are already doing weighted adding of gradients) 18 | for param in col: 19 | param.grad = copy.deepcopy(avg) 20 | # print("is para grad equal to average?", param.grad) 21 | 22 | return 23 | 24 | 25 | 26 | def merge_weights(w): 27 | #after step op, merge weights 28 | 29 | w_avg = copy.deepcopy(w[0]) 30 | for k in w_avg.keys(): 31 | for i in range(1, len(w)): 32 | w_avg[k] += w[i][k] 33 | w_avg[k] = torch.div(w_avg[k], len(w)) 34 | 35 | return w_avg -------------------------------------------------------------------------------- /profiler_ours.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import resnet50 2 | from thop import profile 3 | import os 4 | import random 5 | import string 6 | import socket 7 | import requests 8 | import sys 9 | import threading 10 | import time 11 | import torch 12 | from math import ceil 13 | from torchvision import transforms 14 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp 15 | from utils.client_simulation import generate_random_clients 16 | from utils.connections import send_object 17 | from utils.arg_parser import parse_arguments 18 | import matplotlib.pyplot as plt 19 | import time 20 | import server 21 | import multiprocessing 22 | # from opacus import PrivacyEngine 23 | # from opacus.accountants import RDPAccountant 24 | # from opacus import GradSampleModule 25 | # from opacus.optimizers import DPOptimizer 26 | # from opacus.validators import ModuleValidator 27 | import torch.optim as optim 28 | import copy 29 | from datetime import datetime 30 | from scipy.interpolate import make_interp_spline 31 | import numpy as np 32 | from ConnectedClient import ConnectedClient 33 | import importlib 34 | from utils.merge import merge_grads, merge_weights 35 | import pandas as pd 36 | import time 37 | 38 | 39 | 40 | model = importlib.import_module(f'models.resnet18') 41 | model_cf = model.front(3, pretrained=True) 42 | model_cb = model.back(pretrained = True) 43 | model_center = model.center(pretrained=True) 44 | 45 | input = torch.randn(64,3,224,224) 46 | input_back = torch.randn(64, 512, 7, 7) 47 | macs_client_CF, params_FL = profile(model_cf, inputs=(input, )) 48 | macs_client_CB, params_SL = profile(model_cb, inputs=(input_back, )) 49 | 50 | print(f"GFLOPS CF {((2 * macs_client_CF) / (10**9))} PARAMS: {params_FL}") 51 | print(f"GFLOPS CB {((2 * macs_client_CB) / (10**9))} PARAMS: {params_SL}") 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | results/ 130 | data/ -------------------------------------------------------------------------------- /utils/arg_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_arguments(): 5 | # Training settings 6 | parser = argparse.ArgumentParser( 7 | description="Split Learning Research Simulation entrypoint", 8 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 9 | ) 10 | 11 | parser.add_argument( 12 | "-c", 13 | "--number_of_clients", 14 | type=int, 15 | default=10, 16 | metavar="C", 17 | help="Number of Clients", 18 | ) 19 | 20 | parser.add_argument( 21 | "-b", 22 | "--batch_size", 23 | type=int, 24 | default=128, 25 | metavar="B", 26 | help="Batch size", 27 | ) 28 | parser.add_argument( 29 | "--test_batch_size", 30 | type=int, 31 | default=128, 32 | metavar="TB", 33 | help="Input batch size for testing", 34 | ) 35 | parser.add_argument( 36 | "-n", 37 | "--epochs", 38 | type=int, 39 | default=50, 40 | metavar="N", 41 | help="Total number of epochs to train", 42 | ) 43 | parser.add_argument( 44 | "--lr", 45 | type=float, 46 | default=0.001, 47 | metavar="LR", 48 | help="Learning rate", 49 | ) 50 | 51 | parser.add_argument( 52 | "--rate", 53 | type=float, 54 | default=0.5, 55 | help="dropoff rate", 56 | ) 57 | 58 | parser.add_argument( 59 | "--dataset", 60 | type=str, 61 | default="cifar10", 62 | help="States dataset to be used", 63 | ) 64 | parser.add_argument( 65 | "--seed", 66 | type=int, 67 | default=1234, 68 | help="Random seed", 69 | ) 70 | parser.add_argument( 71 | "--model", 72 | type=str, 73 | default="resnet18", 74 | help="Model you would like to train", 75 | ) 76 | parser.add_argument( 77 | "--epoch_batch", 78 | type=str, 79 | default="5", 80 | help="Number of epochs after which next batch of clients should join", 81 | ) 82 | parser.add_argument( 83 | "--opt_iden", 84 | type=str, 85 | default="", 86 | help="optional identifier of experiment", 87 | ) 88 | 89 | parser.add_argument( 90 | "--pretrained", 91 | action="store_true", 92 | default=False, 93 | help="Use transfer learning using a pretrained model", 94 | ) 95 | parser.add_argument( 96 | "--datapoints", 97 | type=int, 98 | default=500, 99 | help="Number of samples of training data allotted to each client", 100 | ) 101 | parser.add_argument( 102 | "--setting", 103 | type=str, 104 | default='setting1', 105 | help='Setting you would like to run for, i.e, setting1 , setting2 or setting4' 106 | 107 | ) 108 | parser.add_argument( 109 | "--checkpoint", 110 | type=int, 111 | default=50, 112 | help="Epoch at which personalisation phase will start", 113 | ) 114 | 115 | args = parser.parse_args() 116 | return args 117 | -------------------------------------------------------------------------------- /utils/preprocess_eye_dataset_1.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | import pandas as pd 7 | import os 8 | import pickle 9 | import torchvision.transforms as transforms 10 | 11 | 12 | def crop_image_from_gray(img,tol=7): 13 | if img.ndim ==2: 14 | mask = img>tol 15 | return img[np.ix_(mask.any(1),mask.any(0))] 16 | elif img.ndim==3: 17 | gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 18 | mask = gray_img>tol 19 | 20 | check_shape = img[:,:,0][np.ix_(mask.any(1),mask.any(0))].shape[0] 21 | if (check_shape == 0): # image is too dark so that we crop out everything, 22 | return img # return original image 23 | else: 24 | img1=img[:,:,0][np.ix_(mask.any(1),mask.any(0))] 25 | img2=img[:,:,1][np.ix_(mask.any(1),mask.any(0))] 26 | img3=img[:,:,2][np.ix_(mask.any(1),mask.any(0))] 27 | # print(img1.shape,img2.shape,img3.shape) 28 | img = np.stack([img1,img2,img3],axis=-1) 29 | # print(img.shape) 30 | return img 31 | 32 | 33 | def circle_crop(img, sigmaX): 34 | """ 35 | Create circular crop around image centre 36 | """ 37 | 38 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 39 | img = crop_image_from_gray(img) 40 | height, width, depth = img.shape 41 | 42 | x = int(width/2) 43 | y = int(height/2) 44 | r = np.amin((x,y)) 45 | 46 | circle_img = np.zeros((height, width), np.uint8) 47 | cv2.circle(circle_img, (x,y), int(r), 1, thickness=-1) 48 | img = cv2.bitwise_and(img, img, mask=circle_img) 49 | img = crop_image_from_gray(img) 50 | img=cv2.resize(img, (224, 224)) 51 | img=cv2.addWeighted(img,4, cv2.GaussianBlur( img , (0,0) , sigmaX) ,-4 ,128) 52 | 53 | return img 54 | 55 | 56 | def load_data(): 57 | train = pd.read_csv('data/eye_dataset1/train.csv') 58 | test = pd.read_csv('data/eye_dataset1/test.csv') 59 | 60 | train_dir = os.path.join('data/eye_dataset1/train_images') 61 | test_dir = os.path.join('data/eye_dataset1/test_images') 62 | 63 | train['file_path'] = train['id_code'].map(lambda x: os.path.join(train_dir,'{}.png'.format(x))) 64 | test['file_path'] = test['id_code'].map(lambda x: os.path.join(test_dir,'{}.png'.format(x))) 65 | 66 | train['file_name'] = train["id_code"].apply(lambda x: x + ".png") 67 | test['file_name'] = test["id_code"].apply(lambda x: x + ".png") 68 | 69 | train['diagnosis'] = train['diagnosis'].astype(str) 70 | 71 | return train,test 72 | 73 | 74 | from tqdm import tqdm 75 | 76 | df_train,df_test = load_data() 77 | print(df_train['diagnosis'].value_counts()) 78 | X_train=[] 79 | Y_train=[] 80 | count_0, count_1, count_2=0,0,0 81 | print(type(df_train.diagnosis.iloc[0])) 82 | 83 | for i in tqdm(range(0,len(df_train))): 84 | 85 | img = cv2.imread(df_train.file_path.iloc[i]) 86 | 87 | img = circle_crop(img,sigmaX=10) 88 | X_train.append(img) 89 | temp=[] 90 | 91 | if(int(df_train.diagnosis.iloc[i])==0): 92 | count_0+=1 93 | ans=0 94 | elif(int(df_train.diagnosis.iloc[i])==1 or int(df_train.diagnosis.iloc[i])==2): 95 | count_1+=1 96 | ans=1 97 | elif(int(df_train.diagnosis.iloc[i])==3 or int(df_train.diagnosis.iloc[i])==4): 98 | count_2+=1 99 | ans=2 100 | Y_train.append(ans) 101 | 102 | with open('data/x_train_eye_1_1', 'wb') as pickle_file: 103 | pickle.dump(X_train, pickle_file) 104 | with open('data/y_train_eye_1_1', 'wb') as pickle_file: 105 | pickle.dump(Y_train, pickle_file) 106 | 107 | -------------------------------------------------------------------------------- /ConnectedClient.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | from utils.connections import is_socket_closed 3 | from utils.connections import send_object 4 | from utils.connections import get_object 5 | import pickle 6 | import queue 7 | import struct 8 | import torch 9 | 10 | 11 | def handle(client, addr, file): 12 | buffsize = 1024 13 | # file = '/home/ashutosh/score_report.pdf' 14 | # print('File size:', os.path.getsize(file)) 15 | fsize = struct.pack('!I', len(file)) 16 | print('Len of file size struct:', len(fsize)) 17 | client.send(fsize) 18 | # with open(file, 'rb') as fd: 19 | while True: 20 | chunk = fd.read(buffsize) 21 | if not chunk: 22 | break 23 | client.send(chunk) 24 | fd.seek(0) 25 | hash = hashlib.sha512() 26 | while True: 27 | chunk = fd.read(buffsize) 28 | if not chunk: 29 | break 30 | hash.update(chunk) 31 | client.send(hash.digest()) 32 | 33 | 34 | class ConnectedClient(object): 35 | # def __init__(self, id, conn, address, loop_time=1/60, *args, **kwargs): 36 | def __init__(self, id, conn, *args, **kwargs): 37 | super(ConnectedClient, self).__init__(*args, **kwargs) 38 | self.id = id 39 | self.conn = conn 40 | self.front_model = None 41 | self.back_model = None 42 | self.center_model = None 43 | self.train_fun = None 44 | self.test_fun = None 45 | self.keepRunning = True 46 | self.a1 = None 47 | self.a2 = None 48 | self.center_optimizer = None 49 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 50 | 51 | 52 | # def onThread(self, function, *args, **kwargs): 53 | # self.q.put((function, args, kwargs)) 54 | 55 | 56 | # def run(self, loop_time=1/60, *args, **kwargs): 57 | # super(ConnectedClient, self).run(*args, **kwargs) 58 | # while True: 59 | # try: 60 | # function, args, kwargs = self.q.get(timeout=self.timeout) 61 | # function(*args, **kwargs) 62 | # except queue.Empty: 63 | # self.idle() 64 | 65 | def forward_center(self): 66 | self.activations2 = self.center_model(self.remote_activations1) 67 | self.remote_activations2 = self.activations2.detach().requires_grad_(True) 68 | 69 | 70 | def backward_center(self): 71 | self.activations2.backward(self.remote_activations2.grad) 72 | 73 | 74 | def idle(self): 75 | pass 76 | 77 | 78 | def connect(self): 79 | pass 80 | 81 | 82 | def disconnect(self): 83 | if not is_socket_closed(self.conn): 84 | self.conn.close() 85 | return True 86 | else: 87 | return False 88 | 89 | 90 | # def _send_model(self): 91 | def send_model(self): 92 | model = {'front': self.front_model, 'back': self.back_model} 93 | send_object(self.conn, model) 94 | # handle(self.conn, self.address, model) 95 | 96 | 97 | # def send_optimizers(self): 98 | # # This is just a sample code and NOT optimizers. Need to write code for initializing optimizers 99 | # optimizers = {'front': self.front_model.parameters(), 'back': self.back_model.parameters()} 100 | # send_object(self.conn, optimizers) 101 | 102 | 103 | def send_activations(self, activations): 104 | send_object(self.conn, activations) 105 | 106 | 107 | def get_remote_activations1(self): 108 | self.remote_activations1 = get_object(self.conn) 109 | 110 | 111 | def send_remote_activations2(self): 112 | send_object(self.conn, self.remote_activations2) 113 | 114 | 115 | def get_remote_activations2_grads(self): 116 | self.remote_activations2.grad = get_object(self.conn) 117 | 118 | 119 | def send_remote_activations1_grads(self): 120 | send_object(self.conn, self.remote_activations1.grad) 121 | 122 | # def send_model(self): 123 | # self.onThread(self._send_model) 124 | -------------------------------------------------------------------------------- /models/resnet18.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | import torch.nn as nn 3 | 4 | # change these manually for now 5 | num_front_layers = 4 6 | num_back_layers = 3 7 | # remember unfreeze will be counted from the end of each model 8 | num_unfrozen_front_layers = 0 9 | num_unfrozen_center_layers = 1 #exp with 1 or 2(max) 10 | num_unfrozen_back_layers = 4 11 | 12 | def get_resnet18(pretrained: bool): 13 | model = models.resnet18(pretrained=pretrained) # this will use cached model if available instead of downloadinig again 14 | return model 15 | 16 | 17 | class front(nn.Module): 18 | def __init__(self, input_channels=3, pretrained=False): 19 | super(front, self).__init__() 20 | model = get_resnet18(pretrained) 21 | model_children = list(model.children()) 22 | self.input_channels = input_channels 23 | if self.input_channels == 1: 24 | self.conv_channel_change = nn.Conv2d(1,3,3,1,2) #to keep the image size same as input image size to this conv layer 25 | self.front_model = nn.Sequential(*model_children[:num_front_layers]) 26 | 27 | if pretrained: 28 | layer_iterator = iter(self.front_model) 29 | for i in range(num_front_layers-num_unfrozen_front_layers): 30 | layer = layer_iterator.__next__() 31 | for param in layer.parameters(): 32 | param.requires_grad = False 33 | 34 | def forward(self, x): 35 | 36 | if self.input_channels == 1: 37 | x = self.conv_channel_change(x) 38 | x = self.front_model(x) 39 | return x 40 | 41 | 42 | class center(nn.Module): 43 | def __init__(self, pretrained=False): 44 | super(center, self).__init__() 45 | model = get_resnet18(pretrained) 46 | model_children = list(model.children()) 47 | global center_model_length 48 | center_model_length = len(model_children) - num_front_layers - num_back_layers 49 | 50 | 51 | self.center_model = nn.Sequential(*model_children[num_front_layers:center_model_length+num_front_layers]) 52 | 53 | 54 | if pretrained: 55 | 56 | layer_iterator = iter(self.center_model) 57 | 58 | for i in range(center_model_length-num_unfrozen_center_layers): 59 | layer = layer_iterator.__next__() 60 | for param in layer.parameters(): 61 | param.requires_grad = False 62 | 63 | def freeze(self, epoch, pretrained=False): 64 | 65 | 66 | num_unfrozen_center_layers=0 67 | if pretrained: 68 | 69 | layer_iterator = iter(self.center_model) 70 | for i in range(center_model_length-num_unfrozen_center_layers): 71 | layer = layer_iterator.__next__() 72 | for param in layer.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, x): 76 | x = self.center_model(x) 77 | return x 78 | 79 | 80 | 81 | class back(nn.Module): 82 | def __init__(self, pretrained=False, output_dim=10): 83 | super(back, self).__init__() 84 | model = get_resnet18(pretrained) 85 | model_children = list(model.children()) 86 | model_length = len(model_children) 87 | 88 | fc_layer = nn.Linear(512, output_dim) 89 | model_children = model_children[:-1] + [nn.Flatten()] + [fc_layer] 90 | self.back_model = nn.Sequential(*model_children[model_length-num_back_layers:]) 91 | 92 | if pretrained: 93 | layer_iterator = iter(self.back_model) 94 | for i in range(num_back_layers-num_unfrozen_back_layers): 95 | layer = layer_iterator.__next__() 96 | for param in layer.parameters(): 97 | param.requires_grad = False 98 | 99 | 100 | def forward(self, x): 101 | x = self.back_model(x) 102 | return x 103 | 104 | 105 | if __name__ == '__main__': 106 | model = front(pretrained=True) 107 | print(f'{model.front_model}\n\n') 108 | model = center(pretrained=True) 109 | print(f'{model.center_model}\n\n') 110 | model = back(pretrained=True) 111 | print(f'{model.back_model}') 112 | -------------------------------------------------------------------------------- /utils/dataset_settings.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | 5 | def setting2(train_full_dataset,test_full_dataset, num_users): 6 | dict_users, dict_users_test = {}, {} 7 | for i in range(num_users): 8 | dict_users[i]=[] 9 | dict_users_test[i]=[] 10 | 11 | df=pd.DataFrame(list(train_full_dataset), columns=['images', 'labels']) 12 | df_test=pd.DataFrame(list(test_full_dataset), columns=['images', 'labels']) 13 | num_of_classes=len(df['labels'].unique()) 14 | 15 | dict_classwise={} 16 | dict_classwise_test={} 17 | 18 | 19 | for i in range(num_of_classes): 20 | dict_classwise[i] = df[df['labels']==i].index.values.astype(int) 21 | 22 | for i in range(num_of_classes): 23 | dict_classwise_test[i] = df_test[df_test['labels']==i].index.values.astype(int) 24 | 25 | for i in range(num_users): 26 | 27 | for j in range(num_of_classes): 28 | if(i==j or (i+1)%10==j): 29 | temp=list(np.random.choice(dict_classwise[j], 225, replace = False)) 30 | dict_users[i].extend(temp) 31 | dict_classwise[j] = list(set(dict_classwise[j]) -set( temp)) 32 | 33 | elif((i+2)%10==j or (i+3)%10==j): 34 | temp=list(np.random.choice(dict_classwise[j], 7, replace = False)) 35 | dict_users[i].extend(temp) 36 | dict_classwise[j] = list(set(dict_classwise[j]) -set( temp)) 37 | 38 | else: 39 | temp=list(np.random.choice(dict_classwise[j],6, replace = False)) 40 | dict_users[i].extend(temp) 41 | dict_classwise[j] = list(set(dict_classwise[j]) -set( temp)) 42 | 43 | for i in range(num_users): 44 | 45 | for j in range(num_of_classes): 46 | if(i==j or (i+1)%10==j): 47 | temp=list(np.random.choice(dict_classwise_test[j],450 , replace = False)) 48 | dict_users_test[i].extend(temp) 49 | dict_classwise_test[j] = list(set(dict_classwise_test[j]) -set( temp)) 50 | elif((i+2)%10==j or (i+3)%10==j or (i+4)%10==j or (i+5)%10==j): 51 | temp=list(np.random.choice(dict_classwise_test[j], 13, replace = False)) 52 | dict_users_test[i].extend(temp) 53 | dict_classwise_test[j] = list(set(dict_classwise_test[j]) -set( temp)) 54 | else: 55 | temp=list(np.random.choice(dict_classwise_test[j], 12, replace = False)) 56 | dict_users_test[i].extend(temp) 57 | dict_classwise_test[j] = list(set(dict_classwise_test[j]) -set( temp)) 58 | 59 | 60 | return dict_users , dict_users_test 61 | 62 | 63 | def setting1(dataset, num_users, datapoints): 64 | 65 | dict_users = {} 66 | 67 | for i in range(num_users): 68 | dict_users[i]=[] 69 | df=pd.DataFrame(list(dataset), columns=['images', 'labels']) 70 | num_of_classes=len(df['labels'].unique()) 71 | 72 | per_class_client=int(datapoints/num_of_classes) 73 | per_class_total=per_class_client*num_users 74 | 75 | dict_classwise={} 76 | 77 | for i in range(num_of_classes): 78 | dict_classwise[i] = df[df['labels']==i].index.values.astype(int)[:per_class_total] 79 | 80 | for i in range(num_users): 81 | 82 | for j in range(num_of_classes): 83 | temp=list(np.random.choice(dict_classwise[j], per_class_client, replace = False)) 84 | dict_users[i].extend(temp) 85 | dict_classwise[j] = list(set(dict_classwise[j]) -set( temp)) 86 | 87 | return dict_users 88 | 89 | def get_test_dict(dataset, num_users): 90 | 91 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 92 | val = int(len(dataset)//num_users) 93 | #giving the same test set to all the clients 94 | test_set = set(np.random.choice(all_idxs, 2000, replace=False)) 95 | 96 | for i in range(num_users): 97 | dict_users[i] = test_set 98 | return dict_users 99 | 100 | 101 | def get_dicts(train_full_dataset, test_full_dataset, num_users, setting, datapoints): 102 | 103 | if setting == 'setting2': 104 | dict_users, dict_users_test=setting2(train_full_dataset, test_full_dataset, num_users) 105 | 106 | elif setting == 'setting1': 107 | dict_users = setting1(train_full_dataset, num_users, datapoints) 108 | dict_users_test=get_test_dict(test_full_dataset, num_users) 109 | 110 | return dict_users, dict_users_test -------------------------------------------------------------------------------- /utils/get_eye_dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from torch.utils.data import DataLoader, Dataset 3 | import numpy as np 4 | import torchvision.transforms as transforms 5 | import random 6 | 7 | with open('data/x_train_eye_1_1', 'rb') as f: 8 | x_train1 = pickle.load(f) 9 | with open('data/y_train_eye_1_1', 'rb') as f: 10 | y_train1 = pickle.load(f) 11 | 12 | with open('data/x_train_eye_2', 'rb') as f: 13 | x_train2 = pickle.load(f) 14 | with open('data/y_train_eye_2', 'rb') as f: 15 | y_train2 = pickle.load(f) 16 | 17 | # a_list = [1,2,3,4,1,2,1,2,3,4] 18 | def find_indices(list_to_check, item_to_find): 19 | array = np.array(list_to_check) 20 | indices = np.where(array == item_to_find)[0] 21 | return list(indices) 22 | 23 | 24 | # test_1= 25 | # rndm_idxs=list(np.random.choice( list(range(0,3500)),1000, replace = False)) 26 | x_test1=[] 27 | y_test1=[] 28 | x_test2=[] 29 | y_test2=[] 30 | 31 | temp = x_train1[0: 1000] 32 | x_test1.extend(temp) 33 | 34 | temp = y_train1[0: 1000] 35 | y_test1.extend(temp) 36 | 37 | x_train1=x_train1[1000:] 38 | y_train1=y_train1[1000:] 39 | 40 | 41 | temp = x_train2[0: 1000] 42 | x_test2.extend(temp) 43 | temp = y_train2[0: 1000] 44 | y_test2.extend(temp) 45 | 46 | x_train2=x_train2[1000:] 47 | y_train2= y_train2[1000:] 48 | 49 | list_0=find_indices(y_train1,0) 50 | list_1=find_indices(y_train1,1) 51 | list_2=find_indices(y_train1,2) 52 | 53 | 54 | list_0=find_indices(y_train2,0) 55 | list_1=find_indices(y_train2,1) 56 | list_2=find_indices(y_train2,2) 57 | 58 | 59 | list_0=find_indices(y_test1,0) 60 | list_1=find_indices(y_test1,1) 61 | list_2=find_indices(y_test1,2) 62 | 63 | 64 | list_0=find_indices(y_test2,0) 65 | list_1=find_indices(y_test2,1) 66 | list_2=find_indices(y_test2,2) 67 | 68 | 69 | class CreateDataset(Dataset): 70 | def __init__(self, x,y,idxs, transform=None): 71 | super().__init__() 72 | self.x=x 73 | self.y=y 74 | self.transform = transform 75 | self.idxs=list(idxs) 76 | 77 | def __len__(self): 78 | return len(self.idxs) 79 | 80 | def __getitem__(self, index): 81 | image,label = self.x[self.idxs[index]], self.y[self.idxs[index]] 82 | 83 | if self.transform is not None: 84 | image = self.transform(image) 85 | 86 | return image, label 87 | 88 | def get_eye_data(idxs, num): 89 | if(num==1): 90 | y_train=y_train1 91 | x_train=x_train1 92 | elif(num==2): 93 | y_train=y_train2 94 | x_train=x_train2 95 | elif(num==3): 96 | y_train=y_test1 97 | x_train=x_test1 98 | elif(num==4): 99 | y_train=y_test2 100 | x_train=x_test2 101 | transform_train= transforms.Compose([ 102 | transforms.ToPILImage(), 103 | transforms.Resize((224, 224)), 104 | transforms.RandomHorizontalFlip(p=0.4), 105 | #transforms.ColorJitter(brightness=2, contrast=2), 106 | transforms.ToTensor(), 107 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 108 | ]) 109 | train_data = CreateDataset(x=x_train, y=y_train, idxs=idxs, transform=transform_train) 110 | return train_data 111 | 112 | 113 | 114 | def get_idxs(): 115 | 116 | l1=[250,180,70] 117 | l2=[370,109,21] 118 | # l3=[295,157,48] 119 | dict_users, dict_users_test={}, {} 120 | client_id=0 121 | 122 | # dict_users_test=get_test_idxs() 123 | test_list=list(range(0,1000)) 124 | for num in range(0,2): 125 | data_id=num+1 126 | if(num==0): 127 | y_train=y_train1 128 | x_train=x_train1 129 | l=l1 130 | elif(num==1): 131 | y_train=y_train2 132 | x_train=x_train2 133 | l=l2 134 | 135 | list_0=find_indices(y_train,0) 136 | list_1=find_indices(y_train,1) 137 | list_2=find_indices(y_train,2) 138 | 139 | for i in range(0,5): 140 | train_idxs=[] 141 | 142 | 143 | temp=list(np.random.choice(list_0,l[0] , replace = False)) 144 | train_idxs.extend(temp) 145 | list_0 = list(set(list_0) -set( temp)) 146 | 147 | 148 | temp=list(np.random.choice(list_1,l[1] , replace = False)) 149 | train_idxs.extend(temp) 150 | list_1 = list(set(list_1) -set( temp)) 151 | 152 | 153 | temp=list(np.random.choice(list_2,l[2] , replace = False)) 154 | train_idxs.extend(temp) 155 | list_2 = list(set(list_2) -set( temp)) 156 | 157 | 158 | 159 | 160 | dict_users[client_id]=train_idxs 161 | dict_users_test[client_id]=test_list 162 | 163 | client_id+=1 164 | 165 | 166 | 167 | return dict_users, dict_users_test 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /utils/split_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | from utils import datasets 5 | import pickle 6 | from torch.utils.data import Dataset, random_split 7 | import pandas as pd 8 | import numpy as np 9 | 10 | class DatasetFromSubset(Dataset): 11 | def __init__(self, subset, transform=None): 12 | self.subset = subset 13 | self.transform = transform 14 | 15 | def __getitem__(self, index): 16 | x, y = self.subset[index] 17 | if self.transform: 18 | x = self.transform(x) 19 | return x, y 20 | 21 | def __len__(self): 22 | return len(self.subset) 23 | 24 | 25 | def split_dataset(dataset: str, client_ids: list, datapoints = None, pretrained = False, output_dir='data' ): 26 | print('Splitting dataset (may take some time)...', end='') 27 | 28 | num_clients = len(client_ids) 29 | 30 | train_dataset, test_dataset, input_channels = datasets.load_full_dataset(dataset, output_dir,num_clients, datapoints, pretrained) 31 | per_client_trainset_size = len(train_dataset)//len(client_ids) 32 | train_split = [per_client_trainset_size]*len(client_ids) 33 | # train_split.append(len(train_dataset)-per_client_trainset_size*(len(client_ids)-1)) 34 | 35 | per_client_testset_size = len(test_dataset)//len(client_ids) 36 | test_split = [per_client_testset_size]*len(client_ids) 37 | # test_split.append(len(test_dataset)-test_batch_size*(num_clients-1)) 38 | 39 | train_datasets = list(torch.utils.data.random_split(train_dataset, train_split)) 40 | test_datasets = list(torch.utils.data.random_split(test_dataset, test_split)) 41 | # print(type(train_datasets[0])) 42 | for i in range(len(client_ids)): 43 | out_dir = f'{output_dir}/{dataset}/{client_ids[i]}' 44 | os.makedirs(out_dir + '/train', exist_ok=True) 45 | os.makedirs(out_dir + '/test', exist_ok=True) 46 | torch.save(train_datasets[i], out_dir + f'/train/{client_ids[i]}.pt') 47 | torch.save(test_datasets[i], out_dir + f'/test/{client_ids[i]}.pt') 48 | print('Done') 49 | 50 | return len(train_dataset), input_channels 51 | 52 | 53 | def split_dataset_cifar10tl_exp(client_ids: list, datapoints, output_dir='data'): 54 | print('Splitting dataset (may take some time)...', end='') 55 | 56 | num_clients = len(client_ids) 57 | train_dataset, test_dataset, input_channels = datasets.load_full_dataset("cifar10_tl", output_dir, num_clients, datapoints) 58 | per_client_trainset_size = datapoints 59 | train_split = [per_client_trainset_size]*num_clients 60 | train_datasets = list(torch.utils.data.random_split(train_dataset, train_split)) 61 | 62 | for i in range(len(client_ids)): 63 | out_dir = f'{output_dir}/cifar10_tl/{client_ids[i]}' 64 | os.makedirs(out_dir + '/train', exist_ok=True) 65 | os.makedirs(out_dir + '/test', exist_ok=True) 66 | torch.save(train_datasets[i], out_dir + f'/train/{client_ids[i]}.pt') 67 | torch.save(test_dataset, out_dir + f'/test/{client_ids[i]}.pt') 68 | print('Done') 69 | 70 | return len(train_dataset), input_channels 71 | 72 | 73 | def split_dataset_cifar_setting2(client_ids: list, train_dataset, test_dataset, u_datapoints = 2000, c_datapoints = 150): 74 | 75 | print("Unique datapoints", u_datapoints) 76 | print("Common datapoints", c_datapoints) 77 | 78 | print('Splitting dataset (may take some time)...', end='') 79 | 80 | num_users = len(client_ids) 81 | 82 | #test split 83 | dict_users_test, all_idxs_test = {}, [i for i in range(len(test_dataset))] 84 | test_ids = set(np.random.choice(all_idxs_test, 2000, replace=False)) 85 | for i in range(num_users): 86 | dict_users_test[i] = test_ids 87 | 88 | 89 | #u_datapoints are the number of datapoints with the first client 90 | #c_datapoints are the number of datapoints with each of the remaining clients 91 | 92 | #train split 93 | dict_users_train = {} 94 | 95 | for i in range(num_users): 96 | dict_users_train[i]=[] 97 | 98 | df=pd.DataFrame(list(train_dataset), columns=['images', 'labels']) 99 | num_of_classes=len(df['labels'].unique()) 100 | 101 | per_class_client=int(c_datapoints/num_of_classes) 102 | per_class_total=per_class_client*num_users 103 | 104 | per_class_uclient = int(u_datapoints/num_of_classes) 105 | per_class_total += per_class_uclient #2000/10 = 200 106 | 107 | dict_classwise={} 108 | 109 | for i in range(num_of_classes): 110 | dict_classwise[i] = df[df['labels']==i].index.values.astype(int)[:per_class_total] 111 | 112 | for j in range(num_of_classes): 113 | temp=list(np.random.choice(dict_classwise[j], per_class_uclient, replace = False)) 114 | dict_users_train[0].extend(temp) 115 | dict_classwise[j] = list(set(dict_classwise[j]) -set( temp)) 116 | 117 | 118 | for i in range(1, num_users): 119 | 120 | for j in range(num_of_classes): 121 | temp=list(np.random.choice(dict_classwise[j], per_class_client, replace = False)) 122 | dict_users_train[i].extend(temp) 123 | dict_classwise[j] = list(set(dict_classwise[j]) -set( temp)) 124 | 125 | 126 | return dict_users_train, dict_users_test 127 | 128 | 129 | -------------------------------------------------------------------------------- /utils/preprocess_eye_dataset_2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | import shutil 5 | import pathlib 6 | import random 7 | import datetime 8 | import cv2 9 | import os 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import matplotlib.image as mpimg 13 | import pickle 14 | from sklearn.metrics import classification_report, precision_recall_fscore_support, accuracy_score, confusion_matrix 15 | from tqdm import tqdm 16 | 17 | dir_path='data/eye_dataset2/eyepacs_preprocess/eyepacs_preprocess' 18 | df_temp = pd.read_csv("data/eye_dataset2/trainLabels.csv") 19 | print(len(df_temp)) 20 | df_temp['level'].value_counts() 21 | 22 | 23 | def load_data(): 24 | train = pd.read_csv("data/eye_dataset2/trainLabels.csv") 25 | # test = pd.read_csv('data/test.csv') 26 | 27 | train_dir = os.path.join('data/eye_dataset2/eyepacs_preprocess/eyepacs_preprocess') 28 | # test_dir = os.path.join('data/test_images') 29 | 30 | train['file_path'] = train['image'].map(lambda x: os.path.join(train_dir,'{}.jpeg'.format(x))) 31 | # test['file_path'] = test['id_code'].map(lambda x: os.path.join(test_dir,'{}.png'.format(x))) 32 | 33 | train['file_name'] = train["image"].apply(lambda x: x + ".jpeg") 34 | # test['file_name'] = test["id_code"].apply(lambda x: x + ".png") 35 | 36 | train['diagnosis'] = train['level'].astype(str) 37 | 38 | return train 39 | 40 | 41 | def wiener_filter(img, kernel, K): 42 | kernel /= np.sum(kernel) 43 | dummy = np.copy(img) 44 | dummy = np.fft.fft2(dummy) 45 | kernel = np.fft.fft2(kernel, s = img.shape) 46 | kernel = np.conj(kernel) / (np.abs(kernel) ** 2 + K) 47 | dummy = dummy * kernel 48 | dummy = np.abs(np.fft.ifft2(dummy)) 49 | return dummy 50 | 51 | def gaussian_kernel(kernel_size = 3): 52 | h = gaussian(kernel_size, kernel_size / 3).reshape(kernel_size, 1) 53 | h = np.dot(h, h.transpose()) 54 | h /= np.sum(h) 55 | return h 56 | 57 | def isbright(image, dim=227, thresh=0.4): 58 | # Resize image to 10x10 59 | image = cv2.resize(image, (dim, dim)) 60 | # Convert color space to LAB format and extract L channel 61 | L, A, B = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2LAB)) 62 | # Normalize L channel by dividing all pixel 63 | L = L/np.max(L) 64 | # Return True if mean is greater than thresh else False 65 | return np.mean(L) > thresh 66 | 67 | 68 | def crop_image_from_gray(img,tol=7): 69 | if img.ndim ==2: 70 | mask = img>tol 71 | return img[np.ix_(mask.any(1),mask.any(0))] 72 | elif img.ndim==3: 73 | gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 74 | mask = gray_img>tol 75 | 76 | check_shape = img[:,:,0][np.ix_(mask.any(1),mask.any(0))].shape[0] 77 | if (check_shape == 0): # image is too dark so that we crop out everything, 78 | return img # return original image 79 | else: 80 | img1=img[:,:,0][np.ix_(mask.any(1),mask.any(0))] 81 | img2=img[:,:,1][np.ix_(mask.any(1),mask.any(0))] 82 | img3=img[:,:,2][np.ix_(mask.any(1),mask.any(0))] 83 | # print(img1.shape,img2.shape,img3.shape) 84 | img = np.stack([img1,img2,img3],axis=-1) 85 | # print(img.shape) 86 | return img 87 | 88 | 89 | def circle_crop(img, sigmaX): 90 | """ 91 | Create circular crop around image centre 92 | """ 93 | 94 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 95 | img = crop_image_from_gray(img) 96 | height, width, depth = img.shape 97 | 98 | x = int(width/2) 99 | y = int(height/2) 100 | r = np.amin((x,y)) 101 | 102 | circle_img = np.zeros((height, width), np.uint8) 103 | cv2.circle(circle_img, (x,y), int(r), 1, thickness=-1) 104 | img = cv2.bitwise_and(img, img, mask=circle_img) 105 | img = crop_image_from_gray(img) 106 | img=cv2.resize(img, (224, 224)) 107 | img=cv2.addWeighted(img,4, cv2.GaussianBlur( img , (0,0) , sigmaX) ,-4 ,128) 108 | 109 | return img 110 | 111 | 112 | 113 | def image_preprocessing(img): 114 | # 1. Read the image 115 | # img = mpimg.imread(img_path) 116 | img = img.astype(np.uint8) 117 | 118 | # 2. Extract the green channel of the image 119 | b, g, r = cv2.split(img) 120 | 121 | # 3.1. Apply CLAHE to intensify the green channel extracted image 122 | clh = cv2.createCLAHE(clipLimit=4.0) 123 | g = clh.apply(g) 124 | 125 | # 3.2. Convert enhanced image to grayscale 126 | merged_bgr_green_fused = cv2.merge((b, g, r)) 127 | img_bw = cv2.cvtColor(merged_bgr_green_fused, cv2.COLOR_BGR2GRAY) 128 | 129 | # 4. Remove the isolated pixels using morphological cleaning operation. 130 | kernel1 = np.ones((1, 1), np.uint8) 131 | morph_open = cv2.morphologyEx(img_bw, cv2.MORPH_OPEN, kernel1) 132 | 133 | # 5. Extract blood vessels using mean-C thresholding. 134 | thresh = cv2.adaptiveThreshold(morph_open, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 9, 5) 135 | kernel2 = np.ones((2, 2), np.uint8) 136 | morph_open2 = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel2) 137 | 138 | # 6. Stacking the image into 3 channels 139 | stacked_img = np.stack((morph_open2,)*3, axis=-1) 140 | 141 | return stacked_img.astype("float64") 142 | 143 | 144 | df_train=load_data() 145 | print(df_train['diagnosis'].value_counts()) 146 | X_train=[] 147 | Y_train=[] 148 | count_0, count_1, count_2=0,0,0 149 | print(type(df_train.diagnosis.iloc[0])) 150 | 151 | idxs=list(np.random.choice( list(range(0,35000)),5000, replace = False)) 152 | 153 | 154 | for idx in tqdm(range(0,len(idxs))): 155 | i=idxs[idx] 156 | img = cv2.imread(df_train.file_path.iloc[i]) 157 | # print(type(img)) 158 | 159 | img = circle_crop(img,sigmaX=10) 160 | X_train.append(img) 161 | 162 | temp=[] 163 | 164 | if(int(df_train.diagnosis.iloc[i])==0): 165 | count_0+=1 166 | ans=0 167 | elif(int(df_train.diagnosis.iloc[i])==1 or int(df_train.diagnosis.iloc[i])==2): 168 | count_1+=1 169 | ans=1 170 | elif(int(df_train.diagnosis.iloc[i])==3 or int(df_train.diagnosis.iloc[i])==4): 171 | count_2+=1 172 | ans=2 173 | Y_train.append(ans) 174 | 175 | with open('data/x_train_eye_2', 'wb') as pickle_file: 176 | pickle.dump(X_train, pickle_file) 177 | with open('data/y_train_eye_2', 'wb') as pickle_file: 178 | pickle.dump(Y_train, pickle_file) 179 | 180 | print(count_0) 181 | print(count_1) 182 | print(count_2) 183 | 184 | -------------------------------------------------------------------------------- /profiler.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import resnet50 2 | from thop import profile 3 | import torch 4 | from torch import nn 5 | from torchvision import transforms 6 | from torch.utils.data import DataLoader, Dataset 7 | from pandas import DataFrame 8 | import pandas as pd 9 | from sklearn.model_selection import train_test_split 10 | from PIL import Image 11 | from glob import glob 12 | import math 13 | import random 14 | import numpy as np 15 | import os 16 | import matplotlib 17 | matplotlib.use('Agg') 18 | import matplotlib.pyplot as plt 19 | import copy 20 | import argparse 21 | from utils import datasets,dataset_settings 22 | import time 23 | import torch.nn.functional as F 24 | 25 | 26 | #MODELS 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | "3x3 convolution with padding" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | class ResNet18(nn.Module): 65 | 66 | def __init__(self, block, layers, input_channels, num_classes=1000): 67 | self.inplanes = 64 68 | super(ResNet18, self).__init__() 69 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, 70 | bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 74 | self.layer1 = self._make_layer(block, 64, layers[0]) 75 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 78 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 79 | self.fc = nn.Linear(512 * block.expansion, num_classes) 80 | 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | m.weight.data.normal_(0, math.sqrt(2. / n)) 85 | elif isinstance(m, nn.BatchNorm2d): 86 | m.weight.data.fill_(1) 87 | m.bias.data.zero_() 88 | 89 | def _make_layer(self, block, planes, blocks, stride=1): 90 | downsample = None 91 | if stride != 1 or self.inplanes != planes * block.expansion: 92 | downsample = nn.Sequential( 93 | nn.Conv2d(self.inplanes, planes * block.expansion, 94 | kernel_size=1, stride=stride, bias=False), 95 | nn.BatchNorm2d(planes * block.expansion), 96 | ) 97 | 98 | layers = [] 99 | layers.append(block(self.inplanes, planes, stride, downsample)) 100 | self.inplanes = planes * block.expansion 101 | for i in range(1, blocks): 102 | layers.append(block(self.inplanes, planes)) 103 | 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | x = self.conv1(x) 108 | x = self.bn1(x) 109 | x = self.relu(x) 110 | x = self.maxpool(x) 111 | 112 | x = self.layer1(x) 113 | x = self.layer2(x) 114 | x = self.layer3(x) 115 | x = self.layer4(x) 116 | 117 | x = self.avgpool(x) 118 | x = x.view(x.size(0), -1) 119 | x = self.fc(x) 120 | 121 | return x 122 | 123 | 124 | class ResNet18_client_side(nn.Module): 125 | def __init__(self,input_channels): 126 | super(ResNet18_client_side, self).__init__() 127 | self.layer1 = nn.Sequential ( 128 | nn.Conv2d(input_channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False), 129 | nn.BatchNorm2d(64), 130 | nn.ReLU (inplace = True), 131 | nn.MaxPool2d(kernel_size = 3, stride = 2, padding =1), 132 | ) 133 | self.layer2 = nn.Sequential ( 134 | nn.Conv2d(64, 64, kernel_size = 3, stride = 1, padding = 1, bias = False), 135 | nn.BatchNorm2d(64), 136 | nn.ReLU (inplace = True), 137 | nn.Conv2d(64, 64, kernel_size = 3, stride = 1, padding = 1), 138 | nn.BatchNorm2d(64), 139 | ) 140 | 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 144 | m.weight.data.normal_(0, math.sqrt(2. / n)) 145 | elif isinstance(m, nn.BatchNorm2d): 146 | m.weight.data.fill_(1) 147 | m.bias.data.zero_() 148 | 149 | 150 | def forward(self, x): 151 | resudial1 = F.relu(self.layer1(x)) 152 | out1 = self.layer2(resudial1) 153 | out1 = out1 + resudial1 # adding the resudial inputs -- downsampling not required in this layer 154 | resudial2 = F.relu(out1) 155 | return resudial2 156 | 157 | 158 | 159 | 160 | 161 | #setting 1 Conf 162 | # 150 datapoints per client, 64 batch size 163 | 164 | 165 | #FL 166 | FL = ResNet18(BasicBlock, [2, 2, 2, 2],3, 10) #last two params are input_channels and number of classes 167 | #SL 168 | SL = ResNet18_client_side(3) 169 | # 170 | SFLv1 = ResNet18_client_side(3) 171 | # 172 | SFLv2 = ResNet18_client_side(3) 173 | 174 | 175 | 176 | 177 | input = torch.randn(64,3,224,224) 178 | macs_client_FL, params_FL = profile(FL, inputs=(input, )) 179 | macs_client_SL, params_SL = profile(SL, inputs=(input, )) 180 | 181 | print(f"GFLOPS FL {((2 * macs_client_FL) / (10**9))} PARAMS: {params_FL}") 182 | print(f"GFLOPS SL {((2 * macs_client_SL) / (10**9))} PARAMS: {params_SL}") 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import multiprocessing 5 | from threading import Thread 6 | from utils.connections import is_socket_closed 7 | from utils.connections import send_object 8 | from utils.connections import get_object 9 | from utils.split_dataset import DatasetFromSubset 10 | import pickle 11 | import queue 12 | import struct 13 | 14 | 15 | class Client(Thread): 16 | def __init__(self, id, *args, **kwargs): 17 | super(Client, self).__init__(*args, **kwargs) 18 | self.id = id 19 | self.front_model = [] 20 | self.back_model = [] 21 | self.losses = [] 22 | self.train_dataset = None 23 | self.test_dataset = None 24 | self.train_DataLoader = None 25 | self.test_DataLoader = None 26 | self.socket = None 27 | self.server_socket = None 28 | self.train_batch_size = None 29 | self.test_batch_size = None 30 | self.iterator = None 31 | self.activations1 = None 32 | self.remote_activations1 = None 33 | self.outputs = None 34 | self.loss = None 35 | self.criterion = None 36 | self.data = None 37 | self.targets = None 38 | self.n_correct = 0 39 | self.n_samples = 0 40 | self.front_optimizer = None 41 | self.back_optimizer = None 42 | self.train_acc = [] 43 | self.test_acc = [] 44 | self.front_epsilons = [] 45 | self.front_best_alphas = [] 46 | self.pred=[] 47 | self.y=[] 48 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 49 | # self.device = torch.device('cpu') 50 | 51 | 52 | def backward_back(self): 53 | self.loss.backward() 54 | 55 | 56 | 57 | def backward_front(self): 58 | self.activations1.backward(self.remote_activations1.grad) 59 | 60 | 61 | def calculate_loss(self): 62 | self.criterion = F.cross_entropy 63 | self.loss = self.criterion(self.outputs, self.targets) 64 | 65 | 66 | def calculate_test_acc(self): 67 | with torch.no_grad(): 68 | _, self.predicted = torch.max(self.outputs.data, 1) 69 | self.n_correct = (self.predicted == self.targets).sum().item() 70 | self.n_samples = self.targets.size(0) 71 | self.pred.extend(self.predicted.cpu().detach().numpy().tolist()) 72 | self.y.extend(self.targets.cpu().detach().numpy().tolist()) 73 | # self.test_acc.append(100.0 * self.n_correct/self.n_samples) 74 | return 100.0 * self.n_correct/self.n_samples 75 | # print(f'Acc: {self.test_acc[-1]}') 76 | 77 | 78 | def calculate_train_acc(self): 79 | with torch.no_grad(): 80 | _, self.predicted = torch.max(self.outputs.data, 1) 81 | self.n_correct = (self.predicted == self.targets).sum().item() 82 | self.n_samples = self.targets.size(0) 83 | # self.train_acc.append(100.0 * self.n_correct/self.n_samples) 84 | return 100.0 * self.n_correct/self.n_samples 85 | # print(f'Acc: {self.train_acc[-1]}') 86 | 87 | 88 | def connect_server(self, host='localhost', port=8000, BUFFER_SIZE=4096): 89 | self.socket, self.server_socket = multiprocessing.Pipe() 90 | print(f"[*] Client {self.id} connecting to {host}") 91 | 92 | 93 | def create_DataLoader(self, train_batch_size, test_batch_size): 94 | self.train_batch_size = train_batch_size 95 | self.test_batch_size = test_batch_size 96 | self.train_DataLoader = torch.utils.data.DataLoader(dataset=self.train_dataset, 97 | batch_size=self.train_batch_size, 98 | shuffle=True) 99 | self.test_DataLoader = torch.utils.data.DataLoader(dataset=self.test_dataset, 100 | batch_size=self.test_batch_size, 101 | shuffle=True) 102 | 103 | 104 | def disconnect_server(self) -> bool: 105 | if not is_socket_closed(self.socket): 106 | self.socket.close() 107 | return True 108 | else: 109 | return False 110 | 111 | 112 | def forward_back(self): 113 | self.back_model.to(self.device) 114 | self.outputs = self.back_model(self.remote_activations2) 115 | 116 | 117 | def forward_front(self): 118 | self.data, self.targets = next(self.iterator) 119 | self.data, self.targets = self.data.to(self.device), self.targets.to(self.device) 120 | self.front_model.to(self.device) 121 | self.activations1 = self.front_model(self.data) 122 | self.remote_activations1 = self.activations1.detach().requires_grad_(True) 123 | 124 | 125 | # def getModel(self): 126 | # self.onThread(self._getModel) 127 | 128 | 129 | def get_model(self): 130 | model = get_object(self.socket) 131 | self.front_model = model['front'] 132 | self.back_model = model['back'] 133 | 134 | 135 | def get_remote_activations1_grads(self): 136 | self.remote_activations1.grad = get_object(self.socket) 137 | 138 | 139 | def get_remote_activations2(self): 140 | self.remote_activations2 = get_object(self.socket) 141 | 142 | 143 | def idle(self): 144 | pass 145 | 146 | 147 | def load_data(self, dataset, transform): 148 | try: 149 | dataset_path = os.path.join(f'data/{dataset}/{self.id}') 150 | except: 151 | raise Exception(f'Dataset not found for client {self.id}') 152 | self.train_dataset = torch.load(f'{dataset_path}/train/{self.id}.pt') 153 | self.test_dataset = torch.load(f'{dataset_path}/test/{self.id}.pt') 154 | 155 | self.train_dataset = DatasetFromSubset( 156 | self.train_dataset, transform=transform 157 | ) 158 | self.test_dataset = DatasetFromSubset( 159 | self.test_dataset, transform=transform 160 | ) 161 | 162 | 163 | # def onThread(self, function, *args, **kwargs): 164 | # self.q.put((function, args, kwargs)) 165 | 166 | 167 | # def run(self, *args, **kwargs): 168 | # super(Client, self).run(*args, **kwargs) 169 | # while True: 170 | # try: 171 | # function, args, kwargs = self.q.get(timeout=self.timeout) 172 | # function(*args, **kwargs) 173 | # except queue.Empty: 174 | # self.idle() 175 | 176 | 177 | def send_remote_activations1(self): 178 | send_object(self.socket, self.remote_activations1) 179 | 180 | 181 | def send_remote_activations2_grads(self): 182 | send_object(self.socket, self.remote_activations2.grad) 183 | 184 | 185 | def step_front(self): 186 | self.front_optimizer.step() 187 | 188 | 189 | def step_back(self): 190 | self.back_optimizer.step() 191 | 192 | 193 | # def train_model(self): 194 | # forward_front_model() 195 | # send_activations_to_server() 196 | # forward_back_model() 197 | # loss_calculation() 198 | # backward_back_model() 199 | # send_gradients_to_server() 200 | # backward_front_model() 201 | 202 | 203 | def zero_grad_front(self): 204 | self.front_optimizer.zero_grad() 205 | 206 | 207 | def zero_grad_back(self): 208 | self.back_optimizer.zero_grad() 209 | 210 | -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from torchvision import models 6 | from torch.utils.data import Dataset, DataLoader 7 | from glob import glob 8 | import os 9 | import pandas as pd 10 | import sklearn 11 | from sklearn.model_selection import train_test_split 12 | from PIL import Image 13 | 14 | def MNIST(path, transform_train = None, transform_test = None): 15 | 16 | if transform_train is None: 17 | #TL 18 | print("TL transform Cifar") 19 | transform_train = transforms.Compose([ 20 | transforms.Resize((224, 224)), 21 | transforms.ToTensor(), 22 | transforms.Normalize(mean=[0.485], 23 | std=[0.229]) 24 | ]) 25 | if transform_test is None: 26 | #TL 27 | transform_test = transforms.Compose([ 28 | transforms.Resize((224, 224)), 29 | transforms.ToTensor(), 30 | transforms.Normalize(mean=[0.485], 31 | std=[0.229])]) 32 | 33 | train_dataset = torchvision.datasets.MNIST(root=path, 34 | train=True, 35 | transform=transform_train,download=True) 36 | 37 | test_dataset = torchvision.datasets.MNIST(root=path, 38 | train=False, 39 | transform=transform_test, download = True) 40 | return train_dataset, test_dataset 41 | 42 | 43 | def CIFAR10(path, transform_train = None, transform_test = None): 44 | 45 | 46 | if transform_train is None: 47 | print("No Tl Cifar10") 48 | transform_train = transforms.Compose([ 49 | transforms.RandomCrop(32, padding=4), 50 | transforms.RandomHorizontalFlip(), 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 53 | ]) 54 | 55 | if transform_test is None: 56 | transform_test = transforms.Compose([ 57 | transforms.RandomCrop(32, padding=4), 58 | transforms.ToTensor(), 59 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 60 | ]) 61 | 62 | train_dataset = torchvision.datasets.CIFAR10(root=path, 63 | train=True, 64 | transform=transform_train,download= True) 65 | 66 | test_dataset = torchvision.datasets.CIFAR10(root=path, 67 | train=False, 68 | transform=transform_test, download = True) 69 | return train_dataset, test_dataset 70 | 71 | def CIFAR10_iid(num_clients, datapoints, path, transform_train, transform_test): 72 | 73 | 74 | #these transforms are same as Imagenet configurations as TL is being used 75 | 76 | if transform_train is None: 77 | #TL 78 | print("TL transform Cifar") 79 | transform_train = transforms.Compose([ 80 | transforms.Resize((224, 224)), 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 83 | std=[0.229, 0.224, 0.225]) 84 | ]) 85 | if transform_test is None: 86 | #TL 87 | transform_test = transforms.Compose([ 88 | transforms.Resize((224, 224)), 89 | transforms.ToTensor(), 90 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 91 | std=[0.229, 0.224, 0.225])]) 92 | 93 | train_dataset = torchvision.datasets.CIFAR10(root=path, 94 | train=True, 95 | transform=transform_train,download= True) 96 | 97 | test_dataset = torchvision.datasets.CIFAR10(root=path, 98 | train=False, 99 | transform=transform_test, download = True) 100 | 101 | class2idx = train_dataset.class_to_idx.items() 102 | idx2class = {v: k for k, v in train_dataset.class_to_idx.items()} 103 | 104 | new_train_dataset_size = num_clients * datapoints 105 | temp = len(train_dataset) - new_train_dataset_size 106 | 107 | print(len(train_dataset), new_train_dataset_size, temp) 108 | 109 | new_train_dataset,_ = torch.utils.data.random_split(train_dataset, (new_train_dataset_size, temp)) 110 | new_test_dataset,_ = torch.utils.data.random_split(test_dataset, (2000, 8000)) #keeping 2k datapoints with each client 111 | 112 | return new_train_dataset, new_test_dataset, dict(class2idx), idx2class 113 | 114 | 115 | def cifar10_setting3(num_clients, unique_datapoint, c_datapoints, path, transform_train, transform_test): 116 | 117 | if transform_train is None: 118 | #TL 119 | print("TL transform Cifar") 120 | transform_train = transforms.Compose([ 121 | transforms.Resize((224, 224)), 122 | transforms.ToTensor(), 123 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 124 | std=[0.229, 0.224, 0.225]) 125 | ]) 126 | if transform_test is None: 127 | #TL 128 | transform_test = transforms.Compose([ 129 | transforms.Resize((224, 224)), 130 | transforms.ToTensor(), 131 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 132 | std=[0.229, 0.224, 0.225])]) 133 | 134 | train_dataset = torchvision.datasets.CIFAR10(root=path, 135 | train=True, 136 | transform=transform_train,download= True) 137 | 138 | test_dataset = torchvision.datasets.CIFAR10(root=path, 139 | train=False, 140 | transform=transform_test, download = True) 141 | 142 | class2idx = train_dataset.class_to_idx.items() 143 | idx2class = {v: k for k, v in train_dataset.class_to_idx.items()} 144 | 145 | #for setting two the datapoints are the equal datapoints 146 | # so c1: 2000, c2 - c11: 150 each 147 | 148 | new_train_dataset_size = unique_datapoint + ((num_clients - 1) * c_datapoints) # 2000 + 10 * 150 149 | temp = len(train_dataset) - new_train_dataset_size 150 | 151 | print(len(train_dataset), new_train_dataset_size, temp) 152 | 153 | new_train_dataset,_ = torch.utils.data.random_split(train_dataset, (new_train_dataset_size, temp)) 154 | unique_train_dataset, common_train_dataset = torch.utils.data.random_split(new_train_dataset, (unique_datapoint, new_train_dataset_size - unique_datapoint)) #2000, 1500 155 | new_test_dataset,_ = torch.utils.data.random_split(test_dataset, (1100, 8900)) 156 | 157 | return unique_train_dataset, common_train_dataset, new_test_dataset, dict(class2idx), idx2class 158 | 159 | 160 | 161 | def FashionMNIST(path, transform_train = None,transform_test = None): 162 | 163 | if transform_train is None: 164 | #TL 165 | print("TL transform Cifar") 166 | transform_train = transforms.Compose([ 167 | transforms.Resize((224, 224)), 168 | transforms.ToTensor(), 169 | transforms.Normalize(mean=[0.485], 170 | std=[0.229]) 171 | ]) 172 | if transform_test is None: 173 | #TL 174 | transform_test = transforms.Compose([ 175 | transforms.Resize((224, 224)), 176 | transforms.ToTensor(), 177 | transforms.Normalize(mean=[0.485], 178 | std=[0.229])]) 179 | 180 | 181 | train_dataset = torchvision.datasets.FashionMNIST(root=path,train=True,transform=transform_train,download=True) 182 | test_dataset = torchvision.datasets.FashionMNIST(root=path,train=False,transform=transform_test,download=True) 183 | 184 | return train_dataset, test_dataset 185 | 186 | 187 | class HAM10000(Dataset): 188 | def __init__(self, df, transform=None): 189 | self.df = df 190 | self.transform = transform 191 | 192 | def __len__(self): 193 | return len(self.df) 194 | 195 | def __getitem__(self, index): 196 | # Load data and get label 197 | X = Image.open(self.df['path'][index]) 198 | y = torch.tensor(int(self.df['labels'][index])) 199 | 200 | if self.transform: 201 | X = self.transform(X) 202 | 203 | return X, y 204 | 205 | 206 | def get_duplicates(x): 207 | unique_list = list(df_undup['lesion_id']) 208 | if x in unique_list: 209 | return 'unduplicated' 210 | else: 211 | return 'duplicated' 212 | 213 | def get_test_rows(x): 214 | # create a list of all the lesion_id's in the test set 215 | test_list = list(df_test['image_id']) 216 | if str(x) in test_list: 217 | return 'test' 218 | else: 219 | return 'train' 220 | 221 | def HAM10000_data(path): 222 | data_dir='/home/manas/priv_SLR/data/HAM10000_data_files' 223 | images_dir = data_dir + "/images" 224 | all_image_path = glob(os.path.join(images_dir,'*.jpg')) 225 | print("gnc", len(all_image_path)) 226 | imageid_path_dict = {os.path.splitext(os.path.basename(x))[0]: x for x in all_image_path} 227 | #{"ISIC_0024306":/home/manas/priv_SLR/data/HAM10000_data_files/images/ISIC_0024306.jpg } 228 | 229 | lesion_type_dict = { 230 | 'nv': 'Melanocytic nevi', 231 | 'mel': 'dermatofibroma', 232 | 'bkl': 'Benign keratosis-like lesions ', 233 | 'bcc': 'Basal cell carcinoma', 234 | 'akiec': 'Actinic keratoses', 235 | 'vasc': 'Vascular lesions', 236 | 'df': 'Dermatofibroma' 237 | } 238 | 239 | df_original = pd.read_csv(data_dir + '/HAM10000_metadata') 240 | df_original['path'] = df_original['image_id'].map(imageid_path_dict.get) 241 | df_original['cell_type'] = df_original['dx'].map(lesion_type_dict.get) 242 | df_original['labels'] = pd.Categorical(df_original['cell_type']).codes 243 | 244 | global df_undup 245 | df_undup = df_original.groupby('lesion_id').count() 246 | # now we filter out lesion_id's that have only one image associated with it 247 | df_undup = df_undup[df_undup['image_id'] == 1] 248 | df_undup.reset_index(inplace=True) 249 | df_original['duplicates'] = df_original['lesion_id'] 250 | # apply the function to this new column 251 | df_original['duplicates'] = df_original['duplicates'].apply(get_duplicates) 252 | df_undup = df_original[df_original['duplicates'] == 'unduplicated'] 253 | y = df_undup['labels'] 254 | global df_test, df_train 255 | _, df_test = train_test_split(df_undup, test_size=0.2, random_state=101, stratify=y) 256 | 257 | df_original['train_or_test'] = df_original['image_id'] 258 | df_original['train_or_test'] = df_original['train_or_test'].apply(get_test_rows) 259 | # filter out train rows 260 | df_train = df_original[df_original['train_or_test'] == 'train'] 261 | 262 | df_train=df_train[:8910] 263 | df_test=df_test[:1100] 264 | 265 | df_train = df_train.reset_index() 266 | df_test = df_test.reset_index() 267 | 268 | norm_mean=[0.763038, 0.54564667, 0.57004464] 269 | norm_std = [0.14092727, 0.15261286, 0.1699712] 270 | train_transform = transforms.Compose([transforms.Resize((64, 64)),transforms.RandomHorizontalFlip(), 271 | transforms.RandomVerticalFlip(),transforms.RandomRotation(20), 272 | transforms.ColorJitter(brightness=0.1, contrast=0.1, hue=0.1), 273 | transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)]) 274 | test_transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor(), 275 | transforms.Normalize(norm_mean, norm_std)]) 276 | train_dataset = HAM10000(df_train, transform=train_transform) 277 | test_dataset= HAM10000(df_test, transform=test_transform) 278 | 279 | return train_dataset, test_dataset 280 | 281 | 282 | 283 | def load_full_dataset(dataset, dataset_path, num_clients, datapoints = None, pretrained = False,transform_train = None, transform_test = None): 284 | 285 | if dataset == 'mnist': 286 | train_dataset, test_dataset = MNIST(dataset_path) 287 | input_channels = 1 288 | 289 | if dataset == 'cifar10': 290 | 291 | 292 | if pretrained: 293 | transform_train = transforms.Compose([ 294 | transforms.Resize((224, 224)), 295 | transforms.ToTensor(), 296 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 297 | std=[0.229, 0.224, 0.225]) 298 | ]) 299 | 300 | transform_test = transforms.Compose([ 301 | transforms.Resize((224, 224)), 302 | transforms.ToTensor(), 303 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 304 | std=[0.229, 0.224, 0.225])]) 305 | 306 | train_dataset, test_dataset = CIFAR10(dataset_path, transform_train, transform_test) 307 | else: 308 | train_dataset, test_dataset = CIFAR10(dataset_path) 309 | 310 | input_channels = 3 311 | 312 | 313 | if dataset == 'ham10000': 314 | train_dataset, test_dataset = HAM10000_data(dataset_path) 315 | input_channels = 3 316 | 317 | if dataset == "fmnist": 318 | train_dataset, test_dataset = FashionMNIST(dataset_path) 319 | input_channels = 1 320 | 321 | if dataset == "cifar10_tl": 322 | #this is the reduced train_dataset 323 | 324 | train_dataset, test_dataset, ci,ic = CIFAR10_iid(num_clients, datapoints, dataset_path, transform_train, transform_test) 325 | input_channels = 3 326 | 327 | if dataset == "cifar10_setting3": 328 | u_datapoints = 2000 329 | c_datapoints = 150 330 | input_channels = 3 331 | #num_clients = 11 332 | #transform_train, tranform_test = None, None 333 | 334 | u_train_dataset, c_train_dataset, test_dataset,_,_ = cifar10_setting3(num_clients, u_datapoints, c_datapoints, dataset_path, transform_train, transform_test) 335 | return u_train_dataset,c_train_dataset, test_dataset, input_channels 336 | 337 | return train_dataset, test_dataset, input_channels 338 | 339 | 340 | 341 | 342 | 343 | 344 | # if __name__ == "__main__": 345 | # train_dataset, test_dataset = MNIST('../data') 346 | # print(type(train_dataset)) 347 | -------------------------------------------------------------------------------- /PFSL.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import string 4 | import socket 5 | import requests 6 | import sys 7 | import threading 8 | import time 9 | import torch 10 | from math import ceil 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader, Dataset 13 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp 14 | from utils.client_simulation import generate_random_clients 15 | from utils.connections import send_object 16 | from utils.arg_parser import parse_arguments 17 | import matplotlib.pyplot as plt 18 | import time 19 | import multiprocessing 20 | from sklearn.metrics import classification_report 21 | import torch.optim as optim 22 | import copy 23 | from datetime import datetime 24 | from scipy.interpolate import make_interp_spline 25 | import numpy as np 26 | from ConnectedClient import ConnectedClient 27 | import importlib 28 | from utils.merge import merge_grads, merge_weights 29 | import pandas as pd 30 | import time 31 | from utils import dataset_settings, datasets 32 | import torch.nn.functional as F 33 | 34 | 35 | #To intialize every client with their train and test data 36 | def initialize_client(client, dataset, batch_size, test_batch_size, tranform): 37 | 38 | client.load_data(dataset, transform) 39 | print(f'Length of train dataset client {client.id}: {len(client.train_dataset)}') 40 | client.create_DataLoader(batch_size, test_batch_size) 41 | 42 | 43 | #Plots class distribution of train data available to each client 44 | def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids): 45 | class_distribution=dict() 46 | number_of_clients=len(client_ids) 47 | if(len(clients)<=20): 48 | plot_for_clients=client_ids 49 | else: 50 | plot_for_clients=random.sample(client_ids, 20) 51 | 52 | fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10)) 53 | j=0 54 | i=0 55 | 56 | #plot histogram 57 | for client_id in plot_for_clients: 58 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 59 | class_distribution[client_id]=df['labels'].value_counts().sort_index() 60 | df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id) 61 | j+=1 62 | if(j==5 or j==10 or j==15): 63 | i+=1 64 | j=0 65 | fig.tight_layout() 66 | plt.show() 67 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png') 68 | plt.savefig('plot_setting3_exp.png') 69 | 70 | max_len=0 71 | #plot line graphs 72 | for client_id in plot_for_clients: 73 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 74 | df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id) 75 | max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0]) 76 | plt.xticks(np.arange(0,10)) 77 | plt.ylim(0, max_len) 78 | plt.legend() 79 | plt.show() 80 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png') 81 | 82 | return class_distribution 83 | 84 | 85 | 86 | if __name__ == "__main__": 87 | 88 | 89 | args = parse_arguments() 90 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 91 | print("Arguments provided", args) 92 | 93 | 94 | 95 | 96 | random.seed(args.seed) 97 | torch.manual_seed(args.seed) 98 | 99 | overall_test_acc = [] 100 | overall_train_acc = [] 101 | 102 | print('Generating random clients...', end='') 103 | clients = generate_random_clients(args.number_of_clients) 104 | client_ids = list(clients.keys()) 105 | print('Done') 106 | 107 | train_dataset_size, input_channels = split_dataset(args.dataset, client_ids, pretrained=args.pretrained) 108 | 109 | print(f'Random client ids:{str(client_ids)}') 110 | transform=None 111 | max_epoch=0 112 | max_f1=0 113 | 114 | #Assigning train and test data to each client depending for each client 115 | print('Initializing clients...') 116 | 117 | 118 | for _, client in clients.items(): 119 | (initialize_client(client, args.dataset, args.batch_size, args.test_batch_size, transform)) 120 | 121 | 122 | 123 | 124 | print('Client Intialization complete.') 125 | # Train and test data intialisation complete 126 | 127 | # class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids) 128 | 129 | 130 | #Assigning front, center and back models and their optimizers for all the clients 131 | model = importlib.import_module(f'models.{args.model}') 132 | 133 | for _, client in clients.items(): 134 | client.front_model = model.front(input_channels, pretrained=args.pretrained) 135 | client.back_model = model.back(pretrained=args.pretrained) 136 | print('Done') 137 | 138 | for _, client in clients.items(): 139 | # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9) 140 | # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9) 141 | client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr) 142 | client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr) 143 | 144 | first_client = clients[client_ids[0]] 145 | num_iterations = ceil(len(first_client.train_DataLoader.dataset)/args.batch_size) 146 | 147 | num_test_iterations= ceil(len(first_client.test_DataLoader.dataset)/args.test_batch_size) 148 | sc_clients = {} #server copy clients 149 | 150 | for iden in client_ids: 151 | sc_clients[iden] = ConnectedClient(iden, None) 152 | 153 | for _,s_client in sc_clients.items(): 154 | s_client.center_model = model.center(pretrained=args.pretrained) 155 | s_client.center_model.to(device) 156 | # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9) 157 | s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr) 158 | 159 | st = time.time() 160 | 161 | macro_avg_f1_2classes=[] 162 | 163 | criterion=F.cross_entropy 164 | 165 | 166 | #Starting the training process 167 | for epoch in range(args.epochs): 168 | if(epoch==args.checkpoint): # When starting epoch of the perosnalisation is reached, freeze all the layers of the center model 169 | for _, s_client in sc_clients.items(): 170 | s_client.center_model.freeze(epoch, pretrained=True) 171 | 172 | overall_train_acc.append(0) 173 | 174 | 175 | for _, client in clients.items(): 176 | client.train_acc.append(0) 177 | client.iterator = iter(client.train_DataLoader) 178 | 179 | #For every batch in the current epoch 180 | for iteration in range(num_iterations): 181 | print(f'\rEpoch: {epoch+1}, Iteration: {iteration+1}/{num_iterations}', end='') 182 | 183 | 184 | for _, client in clients.items(): 185 | client.forward_front() 186 | 187 | for client_id, client in sc_clients.items(): 188 | client.remote_activations1 = clients[client_id].remote_activations1 189 | client.forward_center() 190 | 191 | for client_id, client in clients.items(): 192 | client.remote_activations2 = sc_clients[client_id].remote_activations2 193 | client.forward_back() 194 | 195 | for _, client in clients.items(): 196 | client.calculate_loss() 197 | 198 | for _, client in clients.items(): 199 | client.backward_back() 200 | 201 | for client_id, client in sc_clients.items(): 202 | client.remote_activations2 = clients[client_id].remote_activations2 203 | client.backward_center() 204 | 205 | for _, client in clients.items(): 206 | client.step_back() 207 | client.zero_grad_back() 208 | 209 | #merge grads uncomment below 210 | 211 | # if epoch%2 == 0: 212 | # params = [] 213 | # normalized_data_sizes = [] 214 | # for iden, client in clients.items(): 215 | # params.append(sc_clients[iden].center_model.parameters()) 216 | # normalized_data_sizes.append(len(client.train_dataset) / train_dataset_size) 217 | # merge_grads(normalized_data_sizes, params) 218 | 219 | for _, client in sc_clients.items(): 220 | client.center_optimizer.step() 221 | client.center_optimizer.zero_grad() 222 | 223 | for _, client in clients.items(): 224 | client.train_acc[-1] += client.calculate_train_acc() #train accuracy of every client in the current epoch in the current batch 225 | 226 | for c_id, client in clients.items(): 227 | client.train_acc[-1] /= num_iterations # train accuracy of every client of all the batches in the current epoch 228 | overall_train_acc[-1] += client.train_acc[-1] 229 | 230 | overall_train_acc[-1] /= len(clients) #avg train accuracy of all the clients in the current epoch 231 | print(f' Personalized Average Train Acc: {overall_train_acc[-1]}') 232 | 233 | # merge weights below uncomment 234 | params = [] 235 | for _, client in sc_clients.items(): 236 | params.append(copy.deepcopy(client.center_model.state_dict())) 237 | w_glob = merge_weights(params) 238 | 239 | for _, client in sc_clients.items(): 240 | client.center_model.load_state_dict(w_glob) 241 | 242 | params = [] 243 | 244 | #In the personalisation phase merging of weights of the back layers is stopped 245 | if(epoch <=args.checkpoint): 246 | for _, client in clients.items(): 247 | params.append(copy.deepcopy(client.back_model.state_dict())) 248 | w_glob_cb = merge_weights(params) 249 | del params 250 | 251 | for _, client in clients.items(): 252 | client.back_model.load_state_dict(w_glob_cb) 253 | 254 | #Testing every epoch 255 | if (epoch%1 == 0 ): 256 | if(epoch==args.checkpoint): 257 | for _, s_client in sc_clients.items(): 258 | s_client.center_model.freeze(epoch, pretrained=True) 259 | with torch.no_grad(): 260 | test_acc = 0 261 | overall_test_acc.append(0) 262 | 263 | for _, client in clients.items(): 264 | client.test_acc.append(0) 265 | client.iterator = iter(client.test_DataLoader) 266 | client.pred=[] 267 | client.y=[] 268 | 269 | #For every batch in the testing phase 270 | for iteration in range(num_test_iterations): 271 | 272 | for _, client in clients.items(): 273 | client.forward_front() 274 | 275 | for client_id, client in sc_clients.items(): 276 | client.remote_activations1 = clients[client_id].remote_activations1 277 | client.forward_center() 278 | 279 | for client_id, client in clients.items(): 280 | client.remote_activations2 = sc_clients[client_id].remote_activations2 281 | client.forward_back() 282 | 283 | for _, client in clients.items(): 284 | client.test_acc[-1] += client.calculate_test_acc() 285 | 286 | 287 | for _, client in clients.items(): 288 | client.test_acc[-1] /= num_test_iterations 289 | overall_test_acc[-1] += client.test_acc[-1] 290 | 291 | 292 | 293 | 294 | overall_test_acc[-1] /= len(clients) #average test accuracy of all the clients in the current epoch 295 | 296 | 297 | print(f' Personalized Average Test Acc: {overall_test_acc[-1]} ') 298 | 299 | 300 | 301 | timestamp = int(datetime.now().timestamp()) 302 | plot_config = f'''dataset: {args.dataset}, 303 | model: {args.model}, 304 | batch_size: {args.batch_size}, lr: {args.lr}, 305 | ''' 306 | 307 | et = time.time() 308 | print(f"Time taken for this run {(et - st)/60} mins") 309 | 310 | 311 | 312 | # calculating the train and test standarad deviation and teh confidence intervals 313 | X = range(args.epochs) 314 | all_clients_stacked_train = np.array([client.train_acc for _,client in clients.items()]) 315 | all_clients_stacked_test = np.array([client.test_acc for _,client in clients.items()]) 316 | epochs_train_std = np.std(all_clients_stacked_train,axis = 0, dtype = np.float64) 317 | epochs_test_std = np.std(all_clients_stacked_test,axis = 0, dtype = np.float64) 318 | 319 | #Y_train is the average client train accuracies at each epoch 320 | #epoch_train_std is the standard deviation of clients train accuracies at each epoch 321 | Y_train = overall_train_acc 322 | Y_train_lower = Y_train - (1.65 * epochs_train_std) #95% of the values lie between 1.65*std 323 | Y_train_upper = Y_train + (1.65 * epochs_train_std) 324 | 325 | Y_test = overall_test_acc 326 | Y_test_lower = Y_test - (1.65 * epochs_test_std) #95% of the values lie between 1.65*std 327 | Y_test_upper = Y_test + (1.65 * epochs_test_std) 328 | 329 | Y_train_cv = epochs_train_std / Y_train 330 | Y_test_cv = epochs_test_std / Y_test 331 | 332 | plt.figure(0) 333 | plt.plot(X, Y_train) 334 | plt.fill_between(X,Y_train_lower , Y_train_upper, color='blue', alpha=0.25) 335 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 336 | plt.show() 337 | 338 | 339 | plt.figure(1) 340 | plt.plot(X, Y_test) 341 | plt.fill_between(X,Y_test_lower , Y_test_upper, color='blue', alpha=0.25) 342 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 343 | plt.show() 344 | 345 | plt.figure(2) 346 | plt.plot(X, Y_train_cv) 347 | plt.show() 348 | 349 | 350 | plt.figure(3) 351 | plt.plot(X, Y_test_cv) 352 | plt.show() 353 | 354 | 355 | 356 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PFSL: Personalized & Fair Split Learning with Data & Label Privacy for thin clients 2 | 3 | ## 1) Please cite as below if you use this repository: 4 | `@software{Manas_Wadhwa_and_Gagan_Gupta_and_Ashutosh_Sahu_and_Rahul_Saini_and_Vidhi_Mittal_PFSL_2023, 5 | author = {Manas Wadhwa and Gagan Gupta and Ashutosh Sahu and Rahul Saini and Vidhi Mittal}, 6 | month = {2}, 7 | title = {{PFSL}}, 8 | url = {https://github.com/mnswdhw/PFSL}, 9 | version = {1.0.0}, 10 | year = {2023} 11 | }` 12 | 13 | 14 | ## 2) Credits 15 | 16 | To reproduce the results of the paper `Thapa, C., Chamikara, M. A., Camtepe, S., & Sun, L. (2020). SplitFed: When Federated Learning Meets Split Learning. ArXiv. https://doi.org/10.48550/arXiv.2004.12088`, we use their official source code which can be found here: https://github.com/chandra2thapa/SplitFed-When-Federated-Learning-Meets-Split-Learning 17 | 18 | For finding the FLOPs of our Pytorch split model at the client side, we use the profiler: https://github.com/Lyken17/pytorch-OpCounter 19 | 20 | 21 | ## 3) Build requirements: 22 | * Python3 (3.8) 23 | * pip3 24 | * Nvidia GPU (>=12GB) 25 | * conda 26 | 27 | 28 | ## 4)Installation 29 | Use the following steps to install the required libraries: 30 | * Change Directory into the project folder 31 | * Create a conda environment using the command 32 | `conda create --name {env_name} python=3.8` 33 | Eg- `conda create --name pfsl python=3.8` 34 | * Activate conda environment using the command 35 | `conda activate {env_name}` 36 | Eg- `conda activate pfsl` 37 | * The use the command: `pip install -r requirements.txt` 38 | 39 | ## 5) Test Run 40 | 41 | ### Parameters 42 | The parameters options for a particular file can be checked adding -–help argument. 43 |
Optional arguments available for PFSL are: 44 | * -h, --help show this help message and exit 45 | * -c, -–number of clients Number of Clients (default: 10) 46 | * -b, -–batch_size Batch size (default: 128) 47 | * –-test_batch_size Input batch size for testing (default: 128) 48 | * -n , –-epochs Total number of epochs to train (default: 10) 49 | * –-lr Learning rate (default: 0.001) 50 | * -–save model Save the trained model (default: False) 51 | * –-dataset States dataset to be used (default: cifar10) 52 | * –-seed Random seed (default: 1234) 53 | * –-model Model you would like to train (default: resnet18) 54 | * –-epoch_batch Number of epochs after which next batchof clients should join (default: 5) 55 | * –-opt_iden optional identifier of experiment (default: ) 56 | * –-pretrained Use transfer learning using a pretrained model (default: False) 57 | * –-datapoints Number of samples of training data allotted to each client (default: 500) 58 | * –-setting Setting you would like to run for, i.e, setting1 ,setting2 or setting4 (default: setting1) 59 | * –-checkpoint Epoch at which personalisation phase will start (default: 50) 60 | * --rate This arguments specifies the fraction of clients dropped off in every epoch (used in setting 5)(default: 0.5) 61 | 62 | For reproducing the results, always add argument –-pretrained while running the PFSL script. 63 | 64 | Create a results directory in the project folder to store all the resulting plots using the below commands. 65 | * `mkdir results` 66 | * `mkdir results/FL` 67 | * `mkdir results/SL` 68 | * `mkdir results/SFLv1` 69 | * `mkdir results/SFLv2` 70 | 71 | ### Commands for all the scenarios 72 | 73 | Below we state the commands for running PFSL, SL, FL, SFLv1 and SFLv2 for all the experimental scenarios. 74 | 75 |
Setting 1: Small Sample Size (Equal), i.i.d. 76 |

In this scenario, each client has a very small number of labelled data points ranging from 50 to 500, and all these samples are distributed identically across clients. There is no class imbalance in training data of each client. To run all the algorithms for setting 1 argument –-setting setting1 and –-datapoints [number of sample per client] has to be added. 77 | Rest of the arguments can be selected as per choice. Numberof data samples can be chosen from 50, 150, 250, 350 and 500 to reproduce the results. When total data sample size was 78 | 50, batch size was chosen to be 32 and for other data samples 79 | greater than 50 batch size was kept at 64. Test batch size was 80 | always taken to be 512. For data sample 150, command are 81 | given below. 82 | 83 | * `python PFSL_Setting124.py --dataset cifar10 --setting setting1 --datapoints 150 --pretrained --model resnet18 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 84 | * `python FL.py --dataset cifar10 --setting setting1 --datapoints 150 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 85 | * `python SL.py --dataset cifar10 --setting setting1 --datapoints 150 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 86 | * `python SFLv1.py --dataset cifar10 --setting setting1 --datapoints 150 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 87 | * `python SFLv2.py --dataset cifar10 --setting setting1 --datapoints 150 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 88 | 89 |

90 | 91 | 92 | 93 |
Setting 2: Small Sample Size (Equal), non-i.i.d. 94 |

In this setting, we model a situation where every client has more labelled data points from a subset of classes (prominent 95 | classes) and less from the remaining classes. We chose to experiment with heavy label imbalance and diversity. Sample size is small and each client has equal number of training samples. To run all the algorithms for setting 2 argument --setting setting2 has to be added. For PFSL, to enable personalisation phase 96 | from xth epoch, argument --checkpoint [x] has to be added. 97 | Rest of the arguments can be selected as per choice. 98 | 99 | * `python PFSL_Setting124.py --dataset cifar10 --model resnet18 --pretrained --setting setting2 --batch_size 64 --test_batch_size 512 --checkpoint 25 --epochs 30` 100 | * `python FL.py --dataset cifar10 --setting setting2 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 101 | * `python SL.py --dataset cifar10 --setting setting2 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 102 | * `python SFLv1.py --dataset cifar10 --setting setting2 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 103 | * `python SFLv2.py --dataset cifar10 --setting setting2 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 104 | 105 | 106 | 107 |

108 |
109 | 110 |
Setting 3: Small Sample Size (Unequal), i.i.d. 111 |

In this settingwe consider we there 11 clients where the Large client has 2000 labelled data points 112 | while the other ten small clients have 150 labelled data points, 113 | each distributed identically. The class distributions 114 | among all the clients are the same. For evaluation purposes, 115 | we consider a test set having 2000 data points with an identical 116 | distribution of classes as the train set. 117 | 118 | To reproduce Table IV of the paper, run setting 1 with 119 | datapoints as 150 as illustrated above. To reproduce Table V 120 | of the paper follow the below commands. In all the commands argument --datapoints that denotes the number of datapoints of the large client has to be added.In our case it was 2000. 121 | 122 | * `python PFSL_Setting3.py --datapoints 2000 --dataset cifar10 --pretrained --model resnet18 -c 11 --epochs 50` 123 | * `python SFLv1_Setting3.py --datapoints 2000 --dataset cifar10_setting3 -c 11 --epochs 100` 124 | * `python SFLv2_Setting3.py --datapoints 2000 --dataset cifar10_setting3 -c 11 --epochs 100` 125 | * `python FL_Setting3.py --datapoints 2000 --dataset cifar10_setting3 -c 11 --epochs 100` 126 | * `python SL_Setting3.py --datapoints 2000 --dataset cifar10_setting3 -c 11 --epochs 100` 127 | 128 |

129 |
130 | 131 | 132 |
133 | Setting 4: A large number of data samples 134 |

Here, all clients have large number of samples. This experiment was done with three different image classification datasets: 135 | MNIST, FMNIST, and CIFAR-10. To run all the algorithms for setting 4 argument --setting setting4 has 136 | to be added. Rest of the arguments can be selected as per choice. Dataset argument has 3 options: cifar10, mnist and fmnist. 137 | 138 | * `python PFSL_Setting124.py --dataset cifar10 --setting setting4 --pretrained --model resnet18 -c 5 --epochs 20` 139 | * `python FL.py --dataset cifar10 --setting setting4 -c 5 --epochs 20` 140 | * `python SL.py --dataset cifar10 --setting setting4 -c 5 --epochs 20` 141 | * `python SFLv1.py --dataset cifar10 --setting setting4 -c 5 --epochs 20` 142 | * `python SFLv2.py --dataset cifar10 --setting setting4 -c 5 --epochs 20` 143 |

144 |
145 | 146 | 147 |
148 | Setting 5: System simulation with 1000 client 149 |

In this setting we try to simulate an environment with 1000 clients. Each client stays in the system only for 1 round which lasts only 1 epoch. 150 | Thus, we evaluate our system for the worst possible scenario when every client cannot stay in the system for long and can only afford to make a minimal effort to participate. We assume that each client has 50 labeled data points sampled randomly but unique to the client. Within each round, we 151 | simulate a dropout, where clients begin training but are not able to complete the weight averaging. We keep the dropout probability at 50%. 152 | 153 | 154 | Use the following command to reproduce the results: Here rate argument specifies the dropoff rate which is the numberof clients that will be dropped randomly in every epoch 155 | 156 | * `python system_simulation_e2.py -c 10 --batch_size 16 --dataset cifar10 --model resnet18 --pretrained --epochs 100 --rate 0.3` 157 | 158 |

159 |
160 | 161 | 162 | 163 | 164 | 165 |
166 | Setting 6: Different Diabetic Retinopathy Datasets: 167 |

This experiment describes the realistic scenario when healthcare centers have different sets of raw patient data for the 168 | same disease. We have used two datasets EyePACS and APTOS whose references are given below. 169 | 170 | 171 | Dataset Sources: 172 | * Source of Dataset 1, https://www.kaggle.com/competitions/aptos2019-blindness-detection/data 173 | * Source of Dataset 2, https://www.kaggle.com/datasets/mariaherrerot/eyepacspreprocess 174 | 175 | To preprocess the dataset download and store the unzipped files in data/eye dataset1 folder and data/eye dataset2 folder. 176 | For this create directories using the command: 177 | * `mkdir data/eye_dataset1` 178 | * `mkdir data/eye_dataset2` 179 |
180 | 181 | 182 | The directory structure of data is as follows: 183 | * `data/eye_dataset1/train_images` 184 | * `data/eye_dataset1/test_images` 185 | * `data/eye_dataset1/test.csv` 186 | * `data/eye_dataset1/train.csv` 187 | * `data/eye_dataset2/eyepacs_preprocess/eyepacs_preprocess/` 188 | * `data/eye_dataset2/trainLabels.csv` 189 | 190 | Once verify the path of the unzipped folders in the load data function of preprocess_eye_dataset_1.py and preprocess_eye_dataset_2.py files. 191 | 192 | For Data preprocessing, run the commands mentioned below 193 | for both the datasets
194 | `python utils/preprocess_eye_dataset_1.py`
195 | `python utils/preprocess_eye_dataset 2.py` 196 | 197 | * `python PFSL_DR.py --pretrained --model resnet18 -c 10 --batch_size 64 --test_batch_size 512 --epochs 50` 198 | * `python FL_DR.py -c 10 --batch_size 64 --test_batch_size 512 --epochs 50` 199 | * `python SL_DR.py --batch_size 64 --test_batch_size 512 --epochs 50` 200 | * `python SFLv1_DR.py --batch_size 64 --test_batch_size 512 --epochs 50` 201 | * `python SFLv2_DR.py --batch_size 64 --test_batch_size 512 --epochs 50` 202 |

203 |
204 | 205 | 206 | ## (6) Test Example Outputs for different Settings 207 | 208 | ### Setting 1 209 | 210 | Command: `python PFSL_Setting124.py --dataset cifar10 --setting setting1 --datapoints 150 --pretrained --model resnet18 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100` 211 | 212 | Maximum test accuracy and time taken for the run is noted in this setting. 213 | 214 | Final Output of the above command is as follows:
215 | * Epoch: 100, Iteration: 3/3 216 | * Training Accuracy: 100.0 217 | * Maximum Test Accuracy: 82.52188846982759 218 | * Time taken for this run 49.36871874332428 mins 219 | 220 | ### Setting 2 221 | 222 | Command: `python PFSL_Setting124.py --dataset cifar10 --model resnet18 --pretrained --setting setting2 --batch_size 64 --test_batch_size 512 --checkpoint 25 --epochs 30` 223 | 224 | After the 25th layer, personalization phase begins since checkpoint is specified as 25 in the above command. It outputs F1 Score just before the start of the personalization phase and the maximum F1 Score achieved in that phase. 225 | 226 | Final Output:
227 | * Epoch: 25, Iteration: 8/8freezing the center model 228 | * Epoch: 26, Iteration: 8/8freezing the center model 229 | * F1 Score at epoch 25 : 0.8151361976947273 230 | * Epoch: 30, Iteration: 8/8 231 | * Training Accuracy: 100.0 232 | * Maximum F1 Score: 0.9509444261471961 233 | * Time taken for this run 11.245620834827424 mins 234 | 235 | 236 | ### Setting 3 237 | 238 | Command: `python PFSL_Setting3.py --datapoints 2000 --dataset cifar10 --pretrained --model resnet18 -c 11 --epochs 50` 239 |
240 | 241 | This command will print statements of the form as below every epoch
242 | * Large client train/Test accuracy xx.xx 243 | * Epoch 19 C1-C10 Average Train/Test Acc: xx.xx 244 | 245 | Final Output will be the maximum test accuracy of the large client and the maximum average test accuracy of the remaining clients which for the above command is
246 | * Average C1 - C10 test accuracy: 86.1162109375 247 | * Large Client test Accuracy: 87.109375 248 | * Time taken for this run 7.857871949672699 mins 249 | 250 | 251 | ### Setting 4 252 | 253 | Command: `python PFSL_Setting124.py --dataset cifar10 --setting setting4 --pretrained --model resnet18 -c 5 --epochs 20` 254 |
255 | 256 | Final Output of the above command is as follows
257 | * Epoch: 20, Iteration: 79/79 258 | * Training Accuracy: 98.90427215189872 259 | * Maximum Test Accuracy: 94.1484375 260 | * Time taken for this run 36.06000682512919 mins 261 | 262 | ### Setting 5 263 | 264 | Command: `python system_simulation_e2.py -c 10 --batch_size 16 --dataset cifar10 --model resnet18 --pretrained --epochs 40 --rate 0.3` 265 | 266 | For every epoch it prints the average train accuracy and the number of clients that are dropped off. Next, it prints the ids of the clients that are not dropped off. 267 | 268 | 269 | Final Output of the above command is as follows
270 | * Personalized Average Test Acc: 88.19791666666666 271 | * Time taken for this run 18.778587651252746 mins 272 | 273 | 274 | 275 | 276 | ### Setting 6 277 | 278 | Command: `python PFSL_DR.py --pretrained --model resnet18 -c 10 --batch_size 64 --test_batch_size 512 --epochs 50` 279 | 280 | Average test accuracy for clients having the APTOS dataset and clients having the EyePACS dataset is noted separately. Also, F1 Score for one representative client from each group is noted. The command above outputs these metrics for the epoch in which the maximum average test accuracy of all the clients is achieved. 281 | 282 | Final Output:
283 | * Epoch: 50, Iteration: 8/8 284 | * Time taken for this run 75.75171089967093 mins 285 | * Time taken for this run 75.75171089967093 mins 286 | * Maximum Personalized Average Test Acc: 79.24868724385246 287 | * Maximum Personalized Average Train Acc: 97.85606971153845 288 | * Client0 F1 Scores: 0.772644561137067 289 | * Client5 F1 Scores:0.5906352306590767 290 | * Personalized Average Test Accuracy for Clients 0 to 4 ": 85.21932633196721 291 | * Personalized Average Test Accuracy for Clients 5 to 9": 73.27804815573771 292 | 293 | 294 | 295 | 296 | ## (7) Quick Validation of Environment 297 | 298 | Command: `python PFSL_Setting124.py --dataset cifar10 --setting setting1 --datapoints 50 --pretrained --model resnet18 -c 5 --batch_size 64 --test_batch_size 512 --epochs 2` 299 | 300 | Output
301 | 302 | * Training Accuracy: 82.0 303 | * Maximum Test Accuracy: 57.88355334051723 304 | * Time taken for this run 0.4888111670811971 mins 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | -------------------------------------------------------------------------------- /PFSL_DR.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import string 4 | import socket 5 | import requests 6 | import sys 7 | import threading 8 | import time 9 | import torch 10 | from math import ceil 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader, Dataset 13 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp 14 | from utils.client_simulation import generate_random_clients 15 | from utils.connections import send_object 16 | from utils.arg_parser import parse_arguments 17 | import matplotlib.pyplot as plt 18 | import time 19 | import multiprocessing 20 | from sklearn.metrics import classification_report 21 | import torch.optim as optim 22 | import copy 23 | from datetime import datetime 24 | from scipy.interpolate import make_interp_spline 25 | import numpy as np 26 | from ConnectedClient import ConnectedClient 27 | import importlib 28 | from utils.merge import merge_grads, merge_weights 29 | import pandas as pd 30 | import time 31 | from utils import get_eye_dataset 32 | import torch.nn.functional as F 33 | # Clients Side Program 34 | #============================================================================================================== 35 | class DatasetSplit(Dataset): 36 | def __init__(self, dataset, idxs): 37 | self.dataset = dataset 38 | self.idxs = list(idxs) 39 | 40 | def __len__(self): 41 | return len(self.idxs) 42 | 43 | def __getitem__(self, item): 44 | image, label = self.dataset[self.idxs[item]] 45 | return image, label 46 | 47 | 48 | def initialize_client(client, dataset, batch_size, test_batch_size, tranform): 49 | client.load_data(args.dataset, transform) 50 | print(f'Length of train dataset client {client.id}: {len(client.train_dataset)}') 51 | client.create_DataLoader(batch_size, test_batch_size) 52 | 53 | 54 | def select_random_clients(clients): 55 | random_clients = {} 56 | client_ids = list(clients.keys()) 57 | random_index = random.randint(0,len(client_ids)-1) 58 | random_client_ids = client_ids[random_index] 59 | 60 | print(random_client_ids) 61 | print(clients) 62 | 63 | for random_client_id in random_client_ids: 64 | random_clients[random_client_id] = clients[random_client_id] 65 | return random_clients 66 | 67 | 68 | def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids): 69 | class_distribution=dict() 70 | number_of_clients=len(client_ids) 71 | if(len(clients)<=20): 72 | plot_for_clients=client_ids 73 | else: 74 | plot_for_clients=random.sample(client_ids, 20) 75 | 76 | fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10)) 77 | j=0 78 | i=0 79 | 80 | #plot histogram 81 | for client_id in plot_for_clients: 82 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 83 | class_distribution[client_id]=df['labels'].value_counts().sort_index() 84 | df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id) 85 | j+=1 86 | if(j==5 or j==10 or j==15): 87 | i+=1 88 | j=0 89 | fig.tight_layout() 90 | plt.show() 91 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png') 92 | plt.savefig('plot_setting_DR_exp.png') 93 | 94 | max_len=0 95 | #plot line graphs 96 | for client_id in plot_for_clients: 97 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 98 | df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id) 99 | max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0]) 100 | plt.xticks(np.arange(0,10)) 101 | plt.ylim(0, max_len) 102 | plt.legend() 103 | plt.show() 104 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png') 105 | 106 | return class_distribution 107 | 108 | if __name__ == "__main__": 109 | 110 | 111 | args = parse_arguments() 112 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 113 | print("Arguments provided", args) 114 | 115 | random.seed(args.seed) 116 | torch.manual_seed(args.seed) 117 | 118 | overall_test_acc = [] 119 | overall_test_acc1 = [] 120 | overall_test_acc2 = [] 121 | overall_train_acc = [] 122 | 123 | print('Generating random clients...', end='') 124 | clients = generate_random_clients(args.number_of_clients) 125 | client_ids = list(clients.keys()) 126 | print('Done') 127 | 128 | # train_dataset_size, input_channels = split_dataset_cifar10tl_exp(client_ids, args.datapoints) 129 | 130 | print(f'Random client ids:{str(client_ids)}') 131 | transform=None 132 | max_epoch=0 133 | max_f1=0 134 | max_accuracy=0 135 | max_train_accuracy=0 136 | max_c0_4_test=0 137 | max_c0_f1=0 138 | max_c5_9_test=0 139 | max_c5_f1=0 140 | 141 | 142 | print('Initializing clients...') 143 | 144 | 145 | 146 | d1, d2=get_eye_dataset.get_idxs() 147 | 148 | 149 | 150 | 151 | input_channels=3 152 | 153 | 154 | 155 | 156 | i=0 157 | dict_user_train=dict() 158 | dict_user_test=dict() 159 | client_idxs=dict() 160 | 161 | for _, client in clients.items(): 162 | dict_user_train[_]=d1[i] 163 | dict_user_test[_]=d2[i] 164 | # dict_user_test_generalized[_]=dict_users_test_equal[i%2] 165 | client_idxs[_]=i 166 | i+=1 167 | for _, client in clients.items(): 168 | # client.train_dataset=DatasetSplit(train_full_dataset, dict_user_train[_]) 169 | client_idx=client_idxs[_] 170 | if(client_idx>=0 and client_idx<=4 ): 171 | train_data_id=1 172 | test_data_id=3 173 | elif(client_idx>=5 and client_idx<=9): 174 | train_data_id=2 175 | test_data_id=4 176 | 177 | client.train_dataset=get_eye_dataset.get_eye_data(dict_user_train[_], train_data_id) 178 | client.test_dataset=get_eye_dataset.get_eye_data(dict_user_test[_], test_data_id) 179 | client.create_DataLoader(args.batch_size, args.test_batch_size) 180 | 181 | 182 | 183 | 184 | #class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids) 185 | print('Client Intialization complete.') 186 | model = importlib.import_module(f'models.{args.model}') 187 | count=0 188 | for _, client in clients.items(): 189 | client.front_model = model.front(input_channels, pretrained=args.pretrained) 190 | client.back_model = model.back(pretrained=args.pretrained) 191 | 192 | print('Done') 193 | 194 | 195 | 196 | 197 | 198 | for _, client in clients.items(): 199 | 200 | # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9) 201 | # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9) 202 | client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr) 203 | client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr) 204 | 205 | 206 | 207 | first_client = clients[client_ids[0]] 208 | num_iterations = ceil(len(first_client.train_DataLoader.dataset)/args.batch_size) 209 | num_test_iterations_personalization = ceil(len(first_client.test_DataLoader.dataset)/args.test_batch_size) 210 | sc_clients = {} #server copy clients 211 | 212 | for iden in client_ids: 213 | sc_clients[iden] = ConnectedClient(iden, None) 214 | 215 | for _,s_client in sc_clients.items(): 216 | s_client.center_model = model.center(pretrained=args.pretrained) 217 | 218 | s_client.center_model.to(device) 219 | # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9) 220 | s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr) 221 | 222 | 223 | st = time.time() 224 | 225 | macro_avg_f1_3classes=[] 226 | macro_avg_f1_dict={} 227 | 228 | criterion=F.cross_entropy 229 | 230 | 231 | for epoch in range(args.epochs): 232 | 233 | 234 | overall_train_acc.append(0) 235 | for _, client in clients.items(): 236 | client.train_acc.append(0) 237 | client.iterator = iter(client.train_DataLoader) 238 | 239 | 240 | for iteration in range(num_iterations): 241 | print(f'\rEpoch: {epoch+1}, Iteration: {iteration+1}/{num_iterations}', end='') 242 | 243 | for _, client in clients.items(): 244 | client.forward_front() 245 | 246 | for client_id, client in sc_clients.items(): 247 | client.remote_activations1 = clients[client_id].remote_activations1 248 | client.forward_center() 249 | 250 | for client_id, client in clients.items(): 251 | client.remote_activations2 = sc_clients[client_id].remote_activations2 252 | client.forward_back() 253 | 254 | for _, client in clients.items(): 255 | client.calculate_loss() 256 | 257 | for _, client in clients.items(): 258 | client.backward_back() 259 | 260 | for client_id, client in sc_clients.items(): 261 | client.remote_activations2 = clients[client_id].remote_activations2 262 | client.backward_center() 263 | 264 | 265 | for _, client in clients.items(): 266 | client.step_back() 267 | client.zero_grad_back() 268 | 269 | for _, client in sc_clients.items(): 270 | client.center_optimizer.step() 271 | client.center_optimizer.zero_grad() 272 | 273 | for _, client in clients.items(): 274 | client.train_acc[-1] += client.calculate_train_acc() 275 | 276 | for c_id, client in clients.items(): 277 | client.train_acc[-1] /= num_iterations 278 | overall_train_acc[-1] += client.train_acc[-1] 279 | 280 | overall_train_acc[-1] /= len(clients) 281 | 282 | 283 | # merge weights below uncomment 284 | params = [] 285 | for _, client in sc_clients.items(): 286 | params.append(copy.deepcopy(client.center_model.state_dict())) 287 | w_glob = merge_weights(params) 288 | 289 | for _, client in sc_clients.items(): 290 | client.center_model.load_state_dict(w_glob) 291 | 292 | params = [] 293 | # if(epoch <=args.checkpoint): 294 | for _, client in clients.items(): 295 | params.append(copy.deepcopy(client.back_model.state_dict())) 296 | w_glob_cb = merge_weights(params) 297 | 298 | for _, client in clients.items(): 299 | client.back_model.load_state_dict(w_glob_cb) 300 | 301 | 302 | 303 | # Testing on every 5th epoch 304 | 305 | if (epoch%1 == 0 ): 306 | 307 | with torch.no_grad(): 308 | test_acc = 0 309 | overall_test_acc.append(0) 310 | overall_test_acc1.append(0) 311 | overall_test_acc2.append(0) 312 | 313 | # for 314 | for _, client in clients.items(): 315 | client.test_acc.append(0) 316 | client.iterator = iter(client.test_DataLoader) 317 | client.pred=[] 318 | client.y=[] 319 | for iteration in range(num_test_iterations_personalization): 320 | 321 | for _, client in clients.items(): 322 | client.forward_front() 323 | 324 | for client_id, client in sc_clients.items(): 325 | client.remote_activations1 = clients[client_id].remote_activations1 326 | client.forward_center() 327 | 328 | for client_id, client in clients.items(): 329 | client.remote_activations2 = sc_clients[client_id].remote_activations2 330 | client.forward_back() 331 | 332 | for _, client in clients.items(): 333 | client.test_acc[-1] += client.calculate_test_acc() 334 | 335 | for _, client in clients.items(): 336 | client.test_acc[-1] /= num_test_iterations_personalization 337 | overall_test_acc[-1] += client.test_acc[-1] 338 | idx=client_idxs[_] 339 | if(idx>=0 and idx<5): 340 | overall_test_acc1[-1] += client.test_acc[-1] 341 | elif(idx>=5 and idx<10): 342 | overall_test_acc2[-1] += client.test_acc[-1] 343 | 344 | clr=classification_report(np.array(client.y), np.array(client.pred), output_dict=True, zero_division=0) 345 | 346 | 347 | curr_f1=(clr['0']['f1-score']+clr['1']['f1-score']+clr['2']['f1-score'])/3 348 | macro_avg_f1_3classes.append(curr_f1) 349 | macro_avg_f1_dict[idx]=curr_f1 350 | 351 | overall_test_acc[-1] /= len(clients) 352 | overall_test_acc1[-1] /=5 353 | overall_test_acc2[-1] /= 5 354 | 355 | f1_avg_all_user=sum(macro_avg_f1_3classes)/len(macro_avg_f1_3classes) 356 | macro_avg_f1_3classes=[] 357 | 358 | if(overall_test_acc[-1] > max_accuracy): 359 | max_accuracy=overall_test_acc[-1] 360 | max_train_accuracy=overall_train_acc[-1] 361 | max_epoch=epoch 362 | max_c0_f1=macro_avg_f1_dict[0] 363 | max_c5_f1=macro_avg_f1_dict[5] 364 | max_c0_4_test=overall_test_acc1[-1] 365 | max_c5_9_test=overall_test_acc2[-1] 366 | 367 | 368 | 369 | macro_avg_f1_dict={} 370 | 371 | timestamp = int(datetime.now().timestamp()) 372 | plot_config = f'''dataset: {args.dataset}, 373 | model: {args.model}, 374 | batch_size: {args.batch_size}, lr: {args.lr}, 375 | ''' 376 | 377 | et = time.time() 378 | print(f"\nTime taken for this run {(et - st)/60} mins") 379 | print(f"Time taken for this run {(et - st)/60} mins") 380 | print(f'Maximum Personalized Average Test Acc: {max_accuracy} ') 381 | print(f'Maximum Personalized Average Train Acc: {max_train_accuracy} ') 382 | print(f'Client0 F1 Scores: {max_c0_f1}') 383 | print(f'Client5 F1 Scores:{max_c5_f1}') 384 | print(f'Personalized Average Test Accuracy for Clients 0 to 4 ": {max_c0_4_test}') 385 | print(f'Personalized Average Test Accuracy for Clients 5 to 9": {max_c5_9_test}') 386 | 387 | 388 | X = range(args.epochs) 389 | all_clients_stacked_train = np.array([client.train_acc for _,client in clients.items()]) 390 | all_clients_stacked_test = np.array([client.test_acc for _,client in clients.items()]) 391 | epochs_train_std = np.std(all_clients_stacked_train,axis = 0, dtype = np.float64) 392 | epochs_test_std = np.std(all_clients_stacked_test,axis = 0, dtype = np.float64) 393 | 394 | #Y_train is the average client train accuracies at each epoch 395 | #epoch_train_std is the standard deviation of clients train accuracies at each epoch 396 | Y_train = overall_train_acc 397 | Y_train_lower = Y_train - (1.65 * epochs_train_std) #95% of the values lie between 1.65*std 398 | Y_train_upper = Y_train + (1.65 * epochs_train_std) 399 | 400 | Y_test = overall_test_acc 401 | Y_test_lower = Y_test - (1.65 * epochs_test_std) #95% of the values lie between 1.65*std 402 | Y_test_upper = Y_test + (1.65 * epochs_test_std) 403 | 404 | Y_train_cv = epochs_train_std / Y_train 405 | Y_test_cv = epochs_test_std / Y_test 406 | 407 | plt.figure(0) 408 | plt.plot(X, Y_train) 409 | plt.fill_between(X,Y_train_lower , Y_train_upper, color='blue', alpha=0.25) 410 | # plt.savefig(f'./results/train_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 411 | plt.show() 412 | 413 | 414 | plt.figure(1) 415 | plt.plot(X, Y_test) 416 | plt.fill_between(X,Y_test_lower , Y_test_upper, color='blue', alpha=0.25) 417 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 418 | plt.show() 419 | 420 | plt.figure(2) 421 | plt.plot(X, Y_train_cv) 422 | plt.show() 423 | 424 | plt.figure(3) 425 | plt.plot(X, Y_test_cv) 426 | plt.show() 427 | 428 | 429 | 430 | -------------------------------------------------------------------------------- /PFSL_Setting124.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import string 4 | import socket 5 | import requests 6 | import sys 7 | import threading 8 | import time 9 | import torch 10 | from math import ceil 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader, Dataset 13 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp 14 | from utils.client_simulation import generate_random_clients 15 | from utils.connections import send_object 16 | from utils.arg_parser import parse_arguments 17 | import matplotlib.pyplot as plt 18 | import time 19 | import multiprocessing 20 | from sklearn.metrics import classification_report 21 | import torch.optim as optim 22 | import copy 23 | from datetime import datetime 24 | from scipy.interpolate import make_interp_spline 25 | import numpy as np 26 | from ConnectedClient import ConnectedClient 27 | import importlib 28 | from utils.merge import merge_grads, merge_weights 29 | import pandas as pd 30 | import time 31 | from utils import dataset_settings, datasets 32 | import torch.nn.functional as F 33 | 34 | 35 | #To load train and test data for each client for setting 1 and setting 2 36 | class DatasetSplit(Dataset): 37 | def __init__(self, dataset, idxs): 38 | self.dataset = dataset 39 | self.idxs = list(idxs) 40 | 41 | def __len__(self): 42 | return len(self.idxs) 43 | 44 | def __getitem__(self, item): 45 | image, label = self.dataset[self.idxs[item]] 46 | return image, label 47 | 48 | 49 | #To intialize every client with their train and test data for setting 4 50 | def initialize_client(client, dataset, batch_size, test_batch_size, tranform): 51 | 52 | client.load_data(dataset, transform) 53 | print(f'Length of train dataset client {client.id}: {len(client.train_dataset)}') 54 | client.create_DataLoader(batch_size, test_batch_size) 55 | 56 | 57 | #Plots class distribution of train data available to each client 58 | def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids): 59 | class_distribution=dict() 60 | number_of_clients=len(client_ids) 61 | if(len(clients)<=20): 62 | plot_for_clients=client_ids 63 | else: 64 | plot_for_clients=random.sample(client_ids, 20) 65 | 66 | fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10)) 67 | j=0 68 | i=0 69 | 70 | #plot histogram 71 | for client_id in plot_for_clients: 72 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 73 | class_distribution[client_id]=df['labels'].value_counts().sort_index() 74 | df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id) 75 | j+=1 76 | if(j==5 or j==10 or j==15): 77 | i+=1 78 | j=0 79 | fig.tight_layout() 80 | plt.show() 81 | 82 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png') 83 | plt.savefig('plot_setting3_exp.png') 84 | 85 | max_len=0 86 | #plot line graphs 87 | for client_id in plot_for_clients: 88 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 89 | df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id) 90 | max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0]) 91 | plt.xticks(np.arange(0,10)) 92 | plt.ylim(0, max_len) 93 | plt.legend() 94 | plt.show() 95 | 96 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png') 97 | 98 | return class_distribution 99 | 100 | 101 | 102 | if __name__ == "__main__": 103 | 104 | 105 | args = parse_arguments() 106 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 107 | print("Arguments provided", args) 108 | 109 | 110 | 111 | random.seed(args.seed) 112 | torch.manual_seed(args.seed) 113 | 114 | overall_test_acc = [] 115 | overall_train_acc = [] 116 | 117 | print('Generating random clients...', end='') 118 | clients = generate_random_clients(args.number_of_clients) 119 | client_ids = list(clients.keys()) 120 | print('Done') 121 | 122 | train_dataset_size, input_channels = split_dataset(args.dataset, client_ids, pretrained=args.pretrained) 123 | 124 | print(f'Random client ids:{str(client_ids)}') 125 | transform=None 126 | max_epoch=0 127 | max_f1=0 128 | max_accuracy=0 129 | 130 | #Assigning train and test data to each client depending for each client 131 | print('Initializing clients...') 132 | 133 | if(args.setting=="setting4"): 134 | for _, client in clients.items(): 135 | (initialize_client(client, args.dataset, args.batch_size, args.test_batch_size, transform)) 136 | else: 137 | 138 | train_full_dataset, test_full_dataset, input_channels = datasets.load_full_dataset(args.dataset, "data", args.number_of_clients, args.datapoints, args.pretrained) 139 | #---------------------------------------------------------------- 140 | dict_users , dict_users2 = dataset_settings.get_dicts(train_full_dataset, test_full_dataset, args.number_of_clients, args.setting, args.datapoints) 141 | 142 | dict_users_test_equal=dataset_settings.get_test_dict(test_full_dataset, args.number_of_clients) 143 | 144 | client_idx=0 145 | dict_user_train=dict() 146 | dict_user_test=dict() 147 | client_idxs=dict() 148 | 149 | for _, client in clients.items(): 150 | dict_user_train[_]=dict_users[client_idx] 151 | dict_user_test[_]=dict_users2[client_idx] 152 | client_idxs[_]=client_idx 153 | client_idx+=1 154 | for _, client in clients.items(): 155 | client.train_dataset=DatasetSplit(train_full_dataset, dict_user_train[_]) 156 | client.test_dataset=DatasetSplit(test_full_dataset, dict_user_test[_]) 157 | client.create_DataLoader(args.batch_size, args.test_batch_size) 158 | print('Client Intialization complete.') 159 | # Train and test data intialisation complete 160 | 161 | #Setting the start of personalisation phase 162 | if(args.setting!='setting2'): 163 | args.checkpoint=args.epochs+10 164 | 165 | 166 | # class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids) 167 | 168 | 169 | #Assigning front, center and back models and their optimizers for all the clients 170 | model = importlib.import_module(f'models.{args.model}') 171 | 172 | for _, client in clients.items(): 173 | client.front_model = model.front(input_channels, pretrained=args.pretrained) 174 | client.back_model = model.back(pretrained=args.pretrained) 175 | print('Done') 176 | 177 | for _, client in clients.items(): 178 | # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9) 179 | # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9) 180 | client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr) 181 | client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr) 182 | 183 | first_client = clients[client_ids[0]] 184 | num_iterations = ceil(len(first_client.train_DataLoader.dataset)/args.batch_size) 185 | 186 | num_test_iterations= ceil(len(first_client.test_DataLoader.dataset)/args.test_batch_size) 187 | sc_clients = {} #server copy clients 188 | 189 | for iden in client_ids: 190 | sc_clients[iden] = ConnectedClient(iden, None) 191 | 192 | for _,s_client in sc_clients.items(): 193 | s_client.center_model = model.center(pretrained=args.pretrained) 194 | s_client.center_model.to(device) 195 | # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9) 196 | s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr) 197 | 198 | st = time.time() 199 | 200 | macro_avg_f1_2classes=[] 201 | 202 | criterion=F.cross_entropy 203 | 204 | 205 | 206 | #Starting the training process 207 | for epoch in range(args.epochs): 208 | if(epoch==args.checkpoint): # When starting epoch of the perosnalisation is reached, freeze all the layers of the center model 209 | print("freezing the center model") 210 | for _, s_client in sc_clients.items(): 211 | s_client.center_model.freeze(epoch, pretrained=True) 212 | 213 | overall_train_acc.append(0) 214 | 215 | 216 | for _, client in clients.items(): 217 | client.train_acc.append(0) 218 | client.iterator = iter(client.train_DataLoader) 219 | 220 | #For every batch in the current epoch 221 | for iteration in range(num_iterations): 222 | print(f'\rEpoch: {epoch+1}, Iteration: {iteration+1}/{num_iterations}', end='') 223 | 224 | for _, client in clients.items(): 225 | client.forward_front() 226 | 227 | for client_id, client in sc_clients.items(): 228 | client.remote_activations1 = clients[client_id].remote_activations1 229 | client.forward_center() 230 | 231 | for client_id, client in clients.items(): 232 | client.remote_activations2 = sc_clients[client_id].remote_activations2 233 | client.forward_back() 234 | 235 | for _, client in clients.items(): 236 | client.calculate_loss() 237 | 238 | for _, client in clients.items(): 239 | client.backward_back() 240 | 241 | for client_id, client in sc_clients.items(): 242 | client.remote_activations2 = clients[client_id].remote_activations2 243 | client.backward_center() 244 | 245 | for _, client in clients.items(): 246 | client.step_back() 247 | client.zero_grad_back() 248 | 249 | #merge grads uncomment below 250 | 251 | # if epoch%2 == 0: 252 | # params = [] 253 | # normalized_data_sizes = [] 254 | # for iden, client in clients.items(): 255 | # params.append(sc_clients[iden].center_model.parameters()) 256 | # normalized_data_sizes.append(len(client.train_dataset) / train_dataset_size) 257 | # merge_grads(normalized_data_sizes, params) 258 | 259 | for _, client in sc_clients.items(): 260 | client.center_optimizer.step() 261 | client.center_optimizer.zero_grad() 262 | 263 | for _, client in clients.items(): 264 | client.train_acc[-1] += client.calculate_train_acc() #train accuracy of every client in the current epoch in the current batch 265 | 266 | for c_id, client in clients.items(): 267 | client.train_acc[-1] /= num_iterations # train accuracy of every client of all the batches in the current epoch 268 | overall_train_acc[-1] += client.train_acc[-1] 269 | 270 | overall_train_acc[-1] /= len(clients) #avg train accuracy of all the clients in the current epoch 271 | 272 | 273 | # merge weights below uncomment 274 | params = [] 275 | for _, client in sc_clients.items(): 276 | params.append(copy.deepcopy(client.center_model.state_dict())) 277 | w_glob = merge_weights(params) 278 | 279 | for _, client in sc_clients.items(): 280 | client.center_model.load_state_dict(w_glob) 281 | 282 | params = [] 283 | 284 | #In the personalisation phase merging of weights of the back layers is stopped 285 | if(epoch <=args.checkpoint): 286 | for _, client in clients.items(): 287 | params.append(copy.deepcopy(client.back_model.state_dict())) 288 | w_glob_cb = merge_weights(params) 289 | del params 290 | 291 | for _, client in clients.items(): 292 | client.back_model.load_state_dict(w_glob_cb) 293 | 294 | #Testing every epoch 295 | if (epoch%1 == 0 ): 296 | if(epoch==args.checkpoint): 297 | print("freezing the center model") 298 | print("F1 Score at epoch ", epoch, " : ",f1_avg_all_user) 299 | 300 | for _, s_client in sc_clients.items(): 301 | s_client.center_model.freeze(epoch, pretrained=True) 302 | with torch.no_grad(): 303 | test_acc = 0 304 | overall_test_acc.append(0) 305 | 306 | for _, client in clients.items(): 307 | client.test_acc.append(0) 308 | client.iterator = iter(client.test_DataLoader) 309 | client.pred=[] 310 | client.y=[] 311 | 312 | #For every batch in the testing phase 313 | for iteration in range(num_test_iterations): 314 | 315 | for _, client in clients.items(): 316 | client.forward_front() 317 | 318 | for client_id, client in sc_clients.items(): 319 | client.remote_activations1 = clients[client_id].remote_activations1 320 | client.forward_center() 321 | 322 | for client_id, client in clients.items(): 323 | client.remote_activations2 = sc_clients[client_id].remote_activations2 324 | client.forward_back() 325 | 326 | for _, client in clients.items(): 327 | client.test_acc[-1] += client.calculate_test_acc() 328 | 329 | 330 | for _, client in clients.items(): 331 | client.test_acc[-1] /= num_test_iterations 332 | overall_test_acc[-1] += client.test_acc[-1] 333 | #Calculating the F1 scores using the classification report from sklearn metrics 334 | if(args.setting=='setting2'): 335 | clr=classification_report(np.array(client.y), np.array(client.pred), output_dict=True, zero_division=0) 336 | idx=client_idxs[_] 337 | 338 | macro_avg_f1_2classes.append((clr[str(idx)]['f1-score']+clr[str((idx+1)%10)]['f1-score'])/2) #macro f1 score of the 2 prominent classes in setting2 339 | 340 | 341 | overall_test_acc[-1] /= len(clients) #average test accuracy of all the clients in the current epoch 342 | 343 | if(args.setting=='setting2'): 344 | f1_avg_all_user=sum(macro_avg_f1_2classes)/len(macro_avg_f1_2classes) #average f1 scores of the clients for the prominent 2 classes in the current epoch 345 | macro_avg_f1_2classes=[] 346 | 347 | 348 | 349 | #Noting the maximum f1 score 350 | if(f1_avg_all_user> max_f1): 351 | max_f1=f1_avg_all_user 352 | max_epoch=epoch 353 | 354 | else: 355 | if(overall_test_acc[-1]> max_accuracy): 356 | max_accuracy=overall_test_acc[-1] 357 | max_epoch=epoch 358 | 359 | 360 | 361 | timestamp = int(datetime.now().timestamp()) 362 | plot_config = f'''dataset: {args.dataset}, 363 | model: {args.model}, 364 | batch_size: {args.batch_size}, lr: {args.lr}, 365 | ''' 366 | 367 | et = time.time() 368 | print("\nTraining Accuracy: ", overall_train_acc[max_epoch]) 369 | if(args.setting=='setting2'): 370 | print("Maximum F1 Score: ", max_f1) 371 | else: 372 | print("Maximum Test Accuracy: ", max_accuracy) 373 | print(f"Time taken for this run {(et - st)/60} mins") 374 | 375 | 376 | 377 | # calculating the train and test standarad deviation and teh confidence intervals 378 | X = range(args.epochs) 379 | all_clients_stacked_train = np.array([client.train_acc for _,client in clients.items()]) 380 | all_clients_stacked_test = np.array([client.test_acc for _,client in clients.items()]) 381 | epochs_train_std = np.std(all_clients_stacked_train,axis = 0, dtype = np.float64) 382 | epochs_test_std = np.std(all_clients_stacked_test,axis = 0, dtype = np.float64) 383 | 384 | #Y_train is the average client train accuracies at each epoch 385 | #epoch_train_std is the standard deviation of clients train accuracies at each epoch 386 | Y_train = overall_train_acc 387 | Y_train_lower = Y_train - (1.65 * epochs_train_std) #95% of the values lie between 1.65*std 388 | Y_train_upper = Y_train + (1.65 * epochs_train_std) 389 | 390 | Y_test = overall_test_acc 391 | Y_test_lower = Y_test - (1.65 * epochs_test_std) #95% of the values lie between 1.65*std 392 | Y_test_upper = Y_test + (1.65 * epochs_test_std) 393 | 394 | Y_train_cv = epochs_train_std / Y_train 395 | Y_test_cv = epochs_test_std / Y_test 396 | 397 | plt.figure(0) 398 | plt.plot(X, Y_train) 399 | plt.fill_between(X,Y_train_lower , Y_train_upper, color='blue', alpha=0.25) 400 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 401 | plt.show() 402 | 403 | 404 | plt.figure(1) 405 | plt.plot(X, Y_test) 406 | plt.fill_between(X,Y_test_lower , Y_test_upper, color='blue', alpha=0.25) 407 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 408 | plt.show() 409 | 410 | 411 | plt.figure(2) 412 | plt.plot(X, Y_train_cv) 413 | plt.show() 414 | 415 | 416 | plt.figure(3) 417 | plt.plot(X, Y_test_cv) 418 | plt.show() 419 | 420 | 421 | 422 | -------------------------------------------------------------------------------- /FL_Setting3.py: -------------------------------------------------------------------------------- 1 | #=========================================================== 2 | # Federated learning: ResNet18 3 | # =========================================================== 4 | import torch 5 | from torch import nn 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader, Dataset 8 | from pandas import DataFrame 9 | import pandas as pd 10 | from sklearn.model_selection import train_test_split 11 | from PIL import Image 12 | from glob import glob 13 | import math 14 | import random 15 | import numpy as np 16 | import os 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | import copy 21 | import argparse 22 | from utils import datasets,dataset_settings 23 | import time 24 | 25 | 26 | 27 | ## ARGPARSER 28 | 29 | def parse_arguments(): 30 | # Training settings 31 | parser = argparse.ArgumentParser( 32 | description="Splitfed V1 configurations", 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 34 | ) 35 | parser.add_argument( 36 | "--seed", 37 | type=int, 38 | default=1234, 39 | help="Random seed", 40 | ) 41 | parser.add_argument( 42 | "-c", 43 | "--number_of_clients", 44 | type=int, 45 | default=5, 46 | metavar="C", 47 | help="Number of Clients", 48 | ) 49 | 50 | parser.add_argument( 51 | "-n", 52 | "--epochs", 53 | type=int, 54 | default=50, 55 | metavar="N", 56 | help="Total number of epochs to train", 57 | ) 58 | 59 | parser.add_argument( 60 | "--fac", 61 | type=float, 62 | default= 1.0, 63 | metavar="N", 64 | help="fraction of active/participating clients, if 1 then all clients participate in SFLV1", 65 | ) 66 | 67 | parser.add_argument( 68 | "--lr", 69 | type=float, 70 | default=0.001, 71 | metavar="LR", 72 | help="Learning rate", 73 | ) 74 | 75 | parser.add_argument( 76 | "--dataset", 77 | type=str, 78 | default="mnist", 79 | help="States dataset to be used", 80 | ) 81 | parser.add_argument( 82 | "-b", 83 | "--batch_size", 84 | type=int, 85 | default=128, 86 | metavar="B", 87 | help="Batch size", 88 | ) 89 | 90 | 91 | parser.add_argument( 92 | "--setting", 93 | type=str, 94 | default="setting1", 95 | 96 | ) 97 | 98 | parser.add_argument( 99 | "--test_batch_size", 100 | type=int, 101 | default=128, 102 | 103 | ) 104 | 105 | 106 | parser.add_argument( 107 | "--datapoints" , 108 | type=int, 109 | default=500, 110 | ) 111 | args = parser.parse_args() 112 | args = parser.parse_args() 113 | return args 114 | 115 | 116 | 117 | #============================================================================================================== 118 | # Client Side Program 119 | #============================================================================================================== 120 | class DatasetSplit(Dataset): 121 | def __init__(self, dataset, idxs): 122 | self.dataset = dataset 123 | self.idxs = list(idxs) 124 | 125 | def __len__(self): 126 | return len(self.idxs) 127 | 128 | def __getitem__(self, item): 129 | image, label = self.dataset[self.idxs[item]] 130 | return image, label 131 | 132 | # Client-side functions associated with Training and Testing 133 | class LocalUpdate(object): 134 | def __init__(self, idx, lr, device, dataset_train = None, dataset_test = None, idxs = None, idxs_test = None): 135 | self.idx = idx 136 | self.device = device 137 | self.lr = lr 138 | self.local_ep = 1 139 | self.loss_func = nn.CrossEntropyLoss() 140 | self.selected_clients = [] 141 | self.ldr_train = DataLoader(DatasetSplit(dataset_train, idxs), batch_size = args.batch_size, shuffle = True) 142 | self.ldr_test = DataLoader(DatasetSplit(dataset_test, idxs_test), batch_size = args.test_batch_size, shuffle = True) 143 | 144 | def train(self, net): 145 | net.train() 146 | # train and update 147 | #optimizer = torch.optim.SGD(net.parameters(), lr = self.lr, momentum = 0.5) 148 | optimizer = torch.optim.Adam(net.parameters(), lr = self.lr) 149 | 150 | epoch_acc = [] 151 | epoch_loss = [] 152 | for iter in range(self.local_ep): 153 | batch_acc = [] 154 | batch_loss = [] 155 | 156 | for batch_idx, (images, labels) in enumerate(self.ldr_train): 157 | images, labels = images.to(self.device), labels.to(self.device) 158 | optimizer.zero_grad() 159 | #---------forward prop------------- 160 | fx = net(images) 161 | 162 | # calculate loss 163 | loss = self.loss_func(fx, labels) 164 | # calculate accuracy 165 | acc = calculate_accuracy(fx, labels) 166 | 167 | #--------backward prop-------------- 168 | loss.backward() 169 | optimizer.step() 170 | 171 | batch_loss.append(loss.item()) 172 | batch_acc.append(acc.item()) 173 | 174 | if self.idx == 0: 175 | prRed('Client{} Train => Local Epoch: {} \tAcc: {:.3f} \tLoss: {:.4f}'.format(self.idx, 176 | iter, acc.item(), loss.item())) 177 | 178 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 179 | epoch_acc.append(sum(batch_acc)/len(batch_acc)) 180 | 181 | return net.state_dict(), sum(epoch_loss) / len(epoch_loss), sum(epoch_acc) / len(epoch_acc) 182 | 183 | def evaluate(self, net): 184 | net.eval() 185 | 186 | epoch_acc = [] 187 | epoch_loss = [] 188 | with torch.no_grad(): 189 | batch_acc = [] 190 | batch_loss = [] 191 | for batch_idx, (images, labels) in enumerate(self.ldr_test): 192 | images, labels = images.to(self.device), labels.to(self.device) 193 | #---------forward prop------------- 194 | fx = net(images) 195 | 196 | # calculate loss 197 | loss = self.loss_func(fx, labels) 198 | # calculate accuracy 199 | acc = calculate_accuracy(fx, labels) 200 | 201 | batch_loss.append(loss.item()) 202 | batch_acc.append(acc.item()) 203 | 204 | prGreen('Client{} Test => \tLoss: {:.4f} \tAcc: {:.3f}'.format(self.idx, loss.item(), acc.item())) 205 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 206 | epoch_acc.append(sum(batch_acc)/len(batch_acc)) 207 | if self.idx == 0: 208 | unique_test.append(sum(batch_acc)/len(batch_acc)) 209 | return sum(epoch_loss) / len(epoch_loss), sum(epoch_acc) / len(epoch_acc) 210 | 211 | 212 | 213 | #===================================================================================================== 214 | # dataset_iid() will create a dictionary to collect the indices of the data samples randomly for each client 215 | 216 | 217 | def dataset_iid(dataset, num_users): 218 | 219 | num_items = int(len(dataset)/num_users) 220 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 221 | for i in range(num_users): 222 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace = False)) 223 | all_idxs = list(set(all_idxs) - dict_users[i]) 224 | return dict_users 225 | 226 | def dataset_iid_setting3(u_dataset, c_dataset, num_users): 227 | 228 | # u_user_idx = 0 229 | c_users = num_users - 1 230 | num_items = int(len(c_dataset)/c_users) #150 231 | 232 | unique_idxs = [i for i in range(len(u_dataset))] 233 | dict_users, all_idxs = {}, [i for i in range(len(c_dataset))] 234 | dict_users[0] = set(unique_idxs) 235 | for i in range(1, c_users + 1): 236 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace = False)) 237 | all_idxs = list(set(all_idxs) - dict_users[i]) 238 | return dict_users 239 | 240 | #==================================================================================================== 241 | # Server Side Program 242 | #==================================================================================================== 243 | def calculate_accuracy(fx, y): 244 | preds = fx.max(1, keepdim=True)[1] 245 | correct = preds.eq(y.view_as(preds)).sum() 246 | acc = 100.00 *correct.float()/preds.shape[0] 247 | return acc 248 | 249 | #============================================================================= 250 | # Model definition: ResNet18 251 | #============================================================================= 252 | # building a ResNet18 Architecture 253 | def conv3x3(in_planes, out_planes, stride=1): 254 | "3x3 convolution with padding" 255 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 256 | padding=1, bias=False) 257 | 258 | 259 | class BasicBlock(nn.Module): 260 | expansion = 1 261 | 262 | def __init__(self, inplanes, planes, stride=1, downsample=None): 263 | super(BasicBlock, self).__init__() 264 | self.conv1 = conv3x3(inplanes, planes, stride) 265 | self.bn1 = nn.BatchNorm2d(planes) 266 | self.relu = nn.ReLU(inplace=True) 267 | self.conv2 = conv3x3(planes, planes) 268 | self.bn2 = nn.BatchNorm2d(planes) 269 | self.downsample = downsample 270 | self.stride = stride 271 | 272 | def forward(self, x): 273 | residual = x 274 | 275 | out = self.conv1(x) 276 | out = self.bn1(out) 277 | out = self.relu(out) 278 | 279 | out = self.conv2(out) 280 | out = self.bn2(out) 281 | 282 | if self.downsample is not None: 283 | residual = self.downsample(x) 284 | 285 | out += residual 286 | out = self.relu(out) 287 | 288 | return out 289 | 290 | class ResNet18(nn.Module): 291 | 292 | def __init__(self, block, layers, input_channels, num_classes=1000): 293 | self.inplanes = 64 294 | super(ResNet18, self).__init__() 295 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, 296 | bias=False) 297 | self.bn1 = nn.BatchNorm2d(64) 298 | self.relu = nn.ReLU(inplace=True) 299 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 300 | self.layer1 = self._make_layer(block, 64, layers[0]) 301 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 302 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 303 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 304 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 305 | self.fc = nn.Linear(512 * block.expansion, num_classes) 306 | 307 | for m in self.modules(): 308 | if isinstance(m, nn.Conv2d): 309 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 310 | m.weight.data.normal_(0, math.sqrt(2. / n)) 311 | elif isinstance(m, nn.BatchNorm2d): 312 | m.weight.data.fill_(1) 313 | m.bias.data.zero_() 314 | 315 | def _make_layer(self, block, planes, blocks, stride=1): 316 | downsample = None 317 | if stride != 1 or self.inplanes != planes * block.expansion: 318 | downsample = nn.Sequential( 319 | nn.Conv2d(self.inplanes, planes * block.expansion, 320 | kernel_size=1, stride=stride, bias=False), 321 | nn.BatchNorm2d(planes * block.expansion), 322 | ) 323 | 324 | layers = [] 325 | layers.append(block(self.inplanes, planes, stride, downsample)) 326 | self.inplanes = planes * block.expansion 327 | for i in range(1, blocks): 328 | layers.append(block(self.inplanes, planes)) 329 | 330 | return nn.Sequential(*layers) 331 | 332 | def forward(self, x): 333 | x = self.conv1(x) 334 | x = self.bn1(x) 335 | x = self.relu(x) 336 | x = self.maxpool(x) 337 | 338 | x = self.layer1(x) 339 | x = self.layer2(x) 340 | x = self.layer3(x) 341 | x = self.layer4(x) 342 | 343 | x = self.avgpool(x) 344 | x = x.view(x.size(0), -1) 345 | x = self.fc(x) 346 | 347 | return x 348 | 349 | 350 | #=========================================================================================== 351 | # Federated averaging: FedAvg 352 | def FedAvg(w): 353 | w_avg = copy.deepcopy(w[0]) 354 | for k in w_avg.keys(): 355 | for i in range(1, len(w)): 356 | w_avg[k] += w[i][k] 357 | w_avg[k] = torch.div(w_avg[k], len(w)) 358 | return w_avg 359 | #==================================================== 360 | 361 | 362 | if __name__ == "__main__": 363 | 364 | 365 | unique_test = [] 366 | #=================================================================== 367 | program = "FL ResNet18" 368 | print(f"---------{program}----------") 369 | 370 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 371 | 372 | args = parse_arguments() 373 | print(args) 374 | 375 | SEED = args.seed 376 | num_users = args.number_of_clients 377 | epochs = args.epochs 378 | frac = args.fac 379 | lr = args.lr 380 | dataset = args.dataset 381 | 382 | if args.dataset == "mnist" or args.dataset == "fmnist": 383 | input_channels = 1 384 | else: 385 | input_channels = 3 386 | 387 | if args.dataset == "ham10k": 388 | no_classes = 7 389 | else: 390 | no_classes = 10 391 | 392 | random.seed(SEED) 393 | np.random.seed(SEED) 394 | torch.manual_seed(SEED) 395 | torch.cuda.manual_seed(SEED) 396 | 397 | 398 | # To print in color during test/train 399 | def prRed(skk): print("\033[91m {}\033[00m" .format(skk)) 400 | def prGreen(skk): print("\033[92m {}\033[00m" .format(skk)) 401 | 402 | 403 | transform_train = transforms.Compose([ 404 | transforms.RandomCrop(32, padding=4), 405 | transforms.RandomHorizontalFlip(), 406 | transforms.ToTensor(), 407 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 408 | ]) 409 | 410 | transform_test = transforms.Compose([ 411 | transforms.RandomCrop(32, padding=4), 412 | transforms.ToTensor(), 413 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 414 | ]) 415 | 416 | # transform_train = None 417 | # transform_test = None 418 | 419 | if dataset == "cifar10_setting3": 420 | u_train_dataset,c_train_full_dataset, test_full_dataset, input_channels = datasets.load_full_dataset(dataset, "data", num_users, args.datapoints, transform_train, transform_test) 421 | dict_users = dataset_iid_setting3(u_train_dataset, c_train_full_dataset, num_users) 422 | dict_users_test = dataset_iid(test_full_dataset, num_users) 423 | 424 | 425 | net_glob = ResNet18(BasicBlock, [2, 2, 2, 2],input_channels, no_classes) 426 | net_glob.to(device) 427 | # print(net_glob) 428 | 429 | net_glob.train() 430 | w_glob = net_glob.state_dict() 431 | 432 | loss_train_collect = [] 433 | acc_train_collect = [] 434 | loss_test_collect = [] 435 | acc_test_collect = [] 436 | 437 | st = time.time() 438 | 439 | 440 | for iter in range(epochs): 441 | w_locals, loss_locals_train, acc_locals_train, loss_locals_test, acc_locals_test = [], [], [], [], [] 442 | m = max(int(frac * num_users), 1) 443 | idxs_users = np.random.choice(range(num_users), m, replace = False) 444 | 445 | # Training/Testing simulation 446 | for idx in idxs_users: # each client 447 | 448 | 449 | if idx == 0: 450 | local = LocalUpdate(idx, lr, device, dataset_train = u_train_dataset, dataset_test = test_full_dataset, idxs = dict_users[idx], idxs_test = dict_users_test[idx]) 451 | else: 452 | local = LocalUpdate(idx, lr, device, dataset_train = c_train_full_dataset, dataset_test = test_full_dataset, idxs = dict_users[idx], idxs_test = dict_users_test[idx]) 453 | # Training ------------------ 454 | w, loss_train, acc_train = local.train(net = copy.deepcopy(net_glob).to(device)) 455 | w_locals.append(copy.deepcopy(w)) 456 | loss_locals_train.append(copy.deepcopy(loss_train)) 457 | acc_locals_train.append(copy.deepcopy(acc_train)) 458 | # Testing ------------------- 459 | loss_test, acc_test = local.evaluate(net = copy.deepcopy(net_glob).to(device)) 460 | loss_locals_test.append(copy.deepcopy(loss_test)) 461 | acc_locals_test.append(copy.deepcopy(acc_test)) 462 | 463 | 464 | 465 | # Federation process 466 | w_glob = FedAvg(w_locals) 467 | print("------------------------------------------------") 468 | print("------ Federation process at Server-Side -------") 469 | print("------------------------------------------------") 470 | 471 | # update global model --- copy weight to net_glob -- distributed the model to all users 472 | net_glob.load_state_dict(w_glob) 473 | 474 | # Train/Test accuracy 475 | acc_avg_train = sum(acc_locals_train[1:]) / len(acc_locals_train[1:]) 476 | acc_train_collect.append(acc_avg_train) 477 | acc_avg_test = sum(acc_locals_test[1:]) / len(acc_locals_test[1:]) 478 | acc_test_collect.append(acc_avg_test) 479 | 480 | # Train/Test loss 481 | loss_avg_train = sum(loss_locals_train) / len(loss_locals_train) 482 | loss_train_collect.append(loss_avg_train) 483 | loss_avg_test = sum(loss_locals_test) / len(loss_locals_test) 484 | loss_test_collect.append(loss_avg_test) 485 | 486 | 487 | print('------------------- SERVER ----------------------------------------------') 488 | print('Train: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(iter, acc_avg_train, loss_avg_train)) 489 | print('Test: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(iter, acc_avg_test, loss_avg_test)) 490 | print('-------------------------------------------------------------------------') 491 | 492 | 493 | #=================================================================================== 494 | 495 | print("Training and Evaluation completed!") 496 | et = time.time() 497 | print(f"Total time taken is {(et-st)/60} mins") 498 | print("Average C1 - C10 test accuracy: ", max(acc_test_collect)) 499 | print("Large Client Test Accuracy: ", max(unique_test)) 500 | 501 | #=============================================================================== 502 | # Save output data to .excel file (we use for comparision plots) 503 | round_process = [i for i in range(1, len(acc_train_collect)+1)] 504 | df = DataFrame({'round': round_process,'acc_train':acc_train_collect, 'acc_test':acc_test_collect}) 505 | file_name = f"results/FL/{program}_{args.batch_size}_{args.dataset}_{args.lr}_{args.epochs}_setting2"+".xlsx" 506 | df.to_excel(file_name, sheet_name= "v1_test", index = False) 507 | 508 | 509 | 510 | #============================================================================= 511 | # Program Completed 512 | #============================================================================= 513 | -------------------------------------------------------------------------------- /PFSL_Setting3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import string 4 | import socket 5 | import requests 6 | import sys 7 | import threading 8 | import time 9 | import torch 10 | from math import ceil 11 | from torchvision import transforms 12 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp, split_dataset_cifar_setting2 13 | from utils.client_simulation import generate_random_clients 14 | from utils.connections import send_object 15 | from utils.arg_parser import parse_arguments 16 | import matplotlib.pyplot as plt 17 | import time 18 | import multiprocessing 19 | import torch.optim as optim 20 | import copy 21 | from datetime import datetime 22 | from scipy.interpolate import make_interp_spline 23 | import numpy as np 24 | from ConnectedClient import ConnectedClient 25 | import importlib 26 | from utils.merge import merge_grads, merge_weights 27 | import pandas as pd 28 | import time 29 | from utils.split_dataset import DatasetFromSubset 30 | from utils import datasets,dataset_settings 31 | import torch.nn.functional as F 32 | 33 | #Helper function to load train and test data for each client 34 | class DatasetSplit(torch.utils.data.Dataset): 35 | def __init__(self, dataset, idxs): 36 | self.dataset = dataset 37 | self.idxs = list(idxs) 38 | 39 | def __len__(self): 40 | return len(self.idxs) 41 | 42 | def __getitem__(self, item): 43 | image, label = self.dataset[self.idxs[item]] 44 | return image, label 45 | 46 | 47 | class Client_try(): 48 | def __init__(self, id, *args, **kwargs): 49 | super(Client_try, self).__init__(*args, **kwargs) 50 | self.id = id 51 | self.front_model = [] 52 | self.back_model = [] 53 | self.losses = [] 54 | self.train_dataset = None 55 | self.test_dataset = None 56 | self.train_DataLoader = None 57 | self.test_DataLoader = None 58 | self.socket = None 59 | self.server_socket = None 60 | self.train_batch_size = None 61 | self.test_batch_size = None 62 | self.train_iterator = None 63 | self.test_iterator = None 64 | self.activations1 = None 65 | self.remote_activations1 = None 66 | self.outputs = None 67 | self.loss = None 68 | self.criterion = None 69 | self.data = None 70 | self.targets = None 71 | self.n_correct = 0 72 | self.n_samples = 0 73 | self.front_optimizer = None 74 | self.back_optimizer = None 75 | self.train_acc = [] 76 | self.test_acc = [] 77 | self.front_epsilons = [] 78 | self.front_best_alphas = [] 79 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 80 | # self.device = torch.device('cpu') 81 | 82 | 83 | def backward_back(self): 84 | self.loss.backward() 85 | 86 | 87 | def backward_front(self): 88 | print(self.remote_activations1.grad) 89 | self.activations1.backward(self.remote_activations1.grad) 90 | 91 | 92 | def calculate_loss(self): 93 | self.criterion = F.cross_entropy 94 | self.loss = self.criterion(self.outputs, self.targets) 95 | 96 | 97 | def calculate_test_acc(self): 98 | with torch.no_grad(): 99 | _, self.predicted = torch.max(self.outputs.data, 1) 100 | self.n_correct = (self.predicted == self.targets).sum().item() 101 | self.n_samples = self.targets.size(0) 102 | # self.test_acc.append(100.0 * self.n_correct/self.n_samples) 103 | return 100.0 * self.n_correct/self.n_samples 104 | # print(f'Acc: {self.test_acc[-1]}') 105 | 106 | 107 | def calculate_train_acc(self): 108 | with torch.no_grad(): 109 | _, self.predicted = torch.max(self.outputs.data, 1) 110 | self.n_correct = (self.predicted == self.targets).sum().item() 111 | self.n_samples = self.targets.size(0) 112 | # self.train_acc.append(100.0 * self.n_correct/self.n_samples) 113 | return 100.0 * self.n_correct/self.n_samples 114 | # print(f'Acc: {self.train_acc[-1]}') 115 | 116 | def create_DataLoader(self, dataset_train,dataset_test,idxs,idxs_test, batch_size, test_batch_size): 117 | self.train_DataLoader = torch.utils.data.DataLoader(DatasetSplit(dataset_train, idxs), batch_size = batch_size, shuffle = True) 118 | self.test_DataLoader = torch.utils.data.DataLoader(DatasetSplit(dataset_test, idxs_test), batch_size = test_batch_size, shuffle = True) 119 | 120 | def forward_back(self): 121 | self.back_model.to(self.device) 122 | self.outputs = self.back_model(self.remote_activations2) 123 | 124 | 125 | def forward_front(self, type, u_id = None): 126 | 127 | if type == "train": 128 | 129 | if self.id == u_id: 130 | 131 | try: 132 | self.data, self.targets = next(self.train_iterator) 133 | except StopIteration: 134 | self.train_iterator = iter(self.train_DataLoader) 135 | self.data, self.targets = next(self.train_iterator) 136 | else: 137 | self.data, self.targets = next(self.train_iterator) 138 | 139 | else: 140 | self.data, self.targets = next(self.test_iterator) 141 | self.data, self.targets = self.data.to(self.device), self.targets.to(self.device) 142 | self.front_model.to(self.device) 143 | self.activations1 = self.front_model(self.data) 144 | self.remote_activations1 = self.activations1.detach().requires_grad_(True) 145 | 146 | 147 | def get_model(self): 148 | model = get_object(self.socket) 149 | self.front_model = model['front'] 150 | self.back_model = model['back'] 151 | 152 | def idle(self): 153 | pass 154 | 155 | 156 | def load_data(self, dataset, transform): 157 | try: 158 | dataset_path = os.path.join(f'data/{dataset}/{self.id}') 159 | except: 160 | raise Exception(f'Dataset not found for client {self.id}') 161 | self.train_dataset = torch.load(f'{dataset_path}/train/{self.id}.pt') 162 | self.test_dataset = torch.load('data/cifar10_setting2/test/common_test.pt') 163 | 164 | self.train_dataset = DatasetFromSubset( 165 | self.train_dataset, transform=transform 166 | ) 167 | self.test_dataset = DatasetFromSubset( 168 | self.test_dataset, transform=transform 169 | ) 170 | 171 | 172 | def step_front(self): 173 | self.front_optimizer.step() 174 | 175 | 176 | def step_back(self): 177 | self.back_optimizer.step() 178 | 179 | 180 | def zero_grad_front(self): 181 | self.front_optimizer.zero_grad() 182 | 183 | 184 | def zero_grad_back(self): 185 | self.back_optimizer.zero_grad() 186 | 187 | 188 | 189 | 190 | def generate_random_client_ids_try(num_clients, id_len=4) -> list: 191 | client_ids = [] 192 | for _ in range(num_clients): 193 | client_ids.append(''.join(random.sample("abcdefghijklmnopqrstuvwxyz1234567890", id_len))) 194 | return client_ids 195 | 196 | def generate_random_clients_try(num_clients) -> dict: 197 | client_ids = generate_random_client_ids_try(num_clients) 198 | clients = {} 199 | for id in client_ids: 200 | clients[id] = Client_try(id) 201 | return clients 202 | 203 | 204 | 205 | def initialize_client(client, dataset_train,dataset_test,idxs,idxs_test, batch_size, test_batch_size): 206 | client.create_DataLoader(dataset_train,dataset_test, idxs, idxs_test, batch_size, test_batch_size) 207 | 208 | #Plots class distribution of train data available to each client 209 | def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids): 210 | class_distribution=dict() 211 | number_of_clients=len(client_ids) 212 | if(len(clients)<=20): 213 | plot_for_clients=client_ids 214 | else: 215 | plot_for_clients=random.sample(client_ids, 20) 216 | 217 | fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10)) 218 | j=0 219 | i=0 220 | 221 | #plot histogram 222 | for client_id in plot_for_clients: 223 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 224 | class_distribution[client_id]=df['labels'].value_counts().sort_index() 225 | df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id) 226 | j+=1 227 | if(j==5 or j==10 or j==15): 228 | i+=1 229 | j=0 230 | fig.tight_layout() 231 | plt.show() 232 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png') 233 | 234 | max_len=0 235 | #plot line graphs 236 | for client_id in plot_for_clients: 237 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 238 | df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id) 239 | max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0]) 240 | plt.xticks(np.arange(0,10)) 241 | plt.ylim(0, max_len) 242 | plt.legend() 243 | plt.show() 244 | 245 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png') 246 | 247 | return class_distribution 248 | 249 | if __name__ == "__main__": 250 | 251 | 252 | args = parse_arguments() 253 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 254 | print("Arguments provided", args) 255 | 256 | 257 | random.seed(args.seed) 258 | torch.manual_seed(args.seed) 259 | 260 | overall_test_acc = [] 261 | overall_train_acc = [] 262 | 263 | print('Generating random clients...', end='') 264 | clients = generate_random_clients_try(args.number_of_clients) 265 | client_ids = list(clients.keys()) 266 | print('Done') 267 | 268 | train_full_dataset, test_full_dataset, input_channels = datasets.load_full_dataset(args.dataset, "data", args.number_of_clients, args.datapoints, args.pretrained) 269 | dict_users_train , dict_users_test = split_dataset_cifar_setting2(client_ids, train_full_dataset, test_full_dataset) 270 | 271 | transform=None 272 | 273 | #Assigning train and test data to each client depending for each client 274 | print('Initializing clients...') 275 | for i,(_, client) in enumerate(clients.items()): 276 | initialize_client(client, train_full_dataset, test_full_dataset, dict_users_train[i], dict_users_test[i], args.batch_size, args.test_batch_size) 277 | 278 | print('Client Intialization complete.') 279 | # Train and test data intialisation complete 280 | 281 | # class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids) 282 | 283 | #Assigning front, center and back models and their optimizers for all the clients 284 | model = importlib.import_module(f'models.{args.model}') 285 | 286 | for _, client in clients.items(): 287 | client.front_model = model.front(input_channels, pretrained=args.pretrained) 288 | client.back_model = model.back(pretrained=args.pretrained) 289 | print('Done') 290 | 291 | 292 | for _, client in clients.items(): 293 | # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9) 294 | # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9) 295 | client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr) 296 | client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr) 297 | 298 | common_client = clients[client_ids[1]] 299 | unique_client_id = client_ids[0] 300 | unique_client = clients[unique_client_id] 301 | num_iterations_common = ceil(len(common_client.train_DataLoader.dataset)/args.batch_size) 302 | num_test_iterations = ceil(len(common_client.test_DataLoader.dataset)/args.test_batch_size) 303 | unique_client_iterations = ceil(len(unique_client.train_DataLoader.dataset)/args.batch_size) 304 | unique_client.train_iterator = iter(unique_client.train_DataLoader) 305 | 306 | sc_clients = {} #server copy clients 307 | 308 | for iden in client_ids: 309 | sc_clients[iden] = ConnectedClient(iden, None) 310 | 311 | for _,s_client in sc_clients.items(): 312 | s_client.center_model = model.center(pretrained=args.pretrained) 313 | s_client.center_model.to(device) 314 | # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9) 315 | s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr) 316 | 317 | 318 | st = time.time() 319 | 320 | max_train_small_clients = 0 321 | max_test_small_clients = 0 322 | max_train_large_client = 0 323 | max_test_large_client = 0 324 | 325 | #Starting the training process 326 | for epoch in range(args.epochs): 327 | 328 | overall_train_acc.append(0) 329 | for i,(_, client) in enumerate(clients.items()): 330 | client.train_acc.append(0) 331 | if i != 0: 332 | client.train_iterator = iter(client.train_DataLoader) 333 | 334 | for c_id, client in clients.items(): 335 | #For every batch in the current epoch 336 | for iteration in range(num_iterations_common): 337 | 338 | 339 | if c_id == unique_client_id: 340 | client.forward_front("train", u_id = c_id) 341 | else: 342 | client.forward_front("train") 343 | 344 | sc_clients[c_id].remote_activations1 = clients[c_id].remote_activations1 345 | sc_clients[c_id].forward_center() 346 | client.remote_activations2 = sc_clients[c_id].remote_activations2 347 | client.forward_back() 348 | client.calculate_loss() 349 | client.backward_back() 350 | sc_clients[c_id].remote_activations2 = clients[c_id].remote_activations2 351 | sc_clients[c_id].backward_center() 352 | 353 | # client.remote_activations1 = copy.deepcopy(sc_clients[client_id].remote_activations1) 354 | # client.backward_front() 355 | 356 | client.step_back() 357 | client.zero_grad_back() 358 | sc_clients[c_id].center_optimizer.step() 359 | sc_clients[c_id].center_optimizer.zero_grad() 360 | client.train_acc[-1] += client.calculate_train_acc() 361 | 362 | client.train_acc[-1] /= num_iterations_common 363 | if c_id == unique_client_id: 364 | print("Large client train accuracy", client.train_acc[-1]) 365 | if client.train_acc[-1] >= max_train_large_client: 366 | max_train_large_client = client.train_acc[-1] 367 | else: 368 | overall_train_acc[-1] += client.train_acc[-1] 369 | 370 | overall_train_acc[-1] /= (len(clients)-1) 371 | if overall_train_acc[-1] >= max_train_small_clients: 372 | max_train_small_clients = overall_train_acc[-1] 373 | print(f'Epoch {epoch} C1-C10 Average Train Acc: {overall_train_acc[-1]}\n') 374 | 375 | # merge weights below uncomment 376 | params = [] 377 | for _, client in sc_clients.items(): 378 | params.append(copy.deepcopy(client.center_model.state_dict())) 379 | w_glob = merge_weights(params) 380 | 381 | for _, client in sc_clients.items(): 382 | client.center_model.load_state_dict(w_glob) 383 | 384 | params = [] 385 | for _, client in clients.items(): 386 | params.append(copy.deepcopy(client.back_model.state_dict())) 387 | w_glob_cb = merge_weights(params) 388 | 389 | for _, client in clients.items(): 390 | client.back_model.load_state_dict(w_glob_cb) 391 | 392 | 393 | 394 | 395 | 396 | # Testing on every 5th epoch 397 | if epoch%5 == 0: 398 | with torch.no_grad(): 399 | test_acc = 0 400 | overall_test_acc.append(0) 401 | for _, client in clients.items(): 402 | client.test_acc.append(0) 403 | client.test_iterator = iter(client.test_DataLoader) 404 | 405 | for client_id, client in clients.items(): 406 | for iteration in range(num_test_iterations): 407 | 408 | client.forward_front("test") 409 | sc_clients[client_id].remote_activations1 = clients[client_id].remote_activations1 410 | sc_clients[client_id].forward_center() 411 | client.remote_activations2 = sc_clients[client_id].remote_activations2 412 | client.forward_back() 413 | client.test_acc[-1] += client.calculate_test_acc() 414 | 415 | client.test_acc[-1] /= num_test_iterations 416 | if client_id == unique_client_id: 417 | print("Large client test accuracy", client.test_acc[-1]) 418 | if client.test_acc[-1] >= max_test_large_client: 419 | max_test_large_client = client.train_acc[-1] 420 | else: 421 | overall_test_acc[-1] += client.test_acc[-1] #not including test accuracy of unique client 422 | 423 | overall_test_acc[-1] /= (len(clients)-1) 424 | 425 | if overall_test_acc[-1] >= max_test_small_clients: 426 | max_test_small_clients = overall_test_acc[-1] 427 | print(f' Epoch {epoch} C1-C10 Average Test Acc: {overall_test_acc[-1]}\n') 428 | 429 | # print("Average C1 - C10 Train accuracy: ", max_train_small_clients) 430 | # print("Large Client train Accuracy: ", max_train_large_client) 431 | print("Average C1 - C10 test accuracy: ", max_test_small_clients) 432 | print("Large Client test Accuracy: ", max_test_large_client) 433 | 434 | 435 | 436 | 437 | timestamp = int(datetime.now().timestamp()) 438 | plot_config = f'''dataset: {args.dataset}, 439 | model: {args.model}, 440 | batch_size: {args.batch_size}, lr: {args.lr}, 441 | ''' 442 | 443 | et = time.time() 444 | print(f"Time taken for this run {(et - st)/60} mins") 445 | 446 | 447 | #BELOW CODE TO PLOT MULTIPLE LINES ON A SINGLE PLOT ONE LINE FOR EACH CLIENT 448 | # for client_id, client in clients.items(): 449 | # plt.plot(list(range(args.epochs)), client.train_acc, label=f'{client_id} (Max:{max(client.train_acc):.4f})') 450 | # plt.plot(list(range(args.epochs)), overall_train_acc, label=f'Average (Max:{max(overall_train_acc):.4f})') 451 | # plt.title(f'{args.number_of_clients} Clients: Train Accuracy vs. Epochs') 452 | # plt.ylabel('Train Accuracy') 453 | # plt.xlabel('Epochs') 454 | # plt.legend() 455 | # plt.ioff() 456 | # plt.savefig(f'./results/train_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 457 | # plt.show() 458 | 459 | # for client_id, client in clients.items(): 460 | # plt.plot(list(range(args.epochs)), client.test_acc, label=f'{client_id} (Max:{max(client.test_acc):.4f})') 461 | # plt.plot(list(range(args.epochs)), overall_test_acc, label=f'Average (Max:{max(overall_test_acc):.4f})') 462 | # plt.title(f'{args.number_of_clients} Clients: Test Accuracy vs. Epochs') 463 | # plt.ylabel('Test Accuracy') 464 | # plt.xlabel('Epochs') 465 | # plt.legend() 466 | # plt.ioff() 467 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 468 | # plt.show() 469 | -------------------------------------------------------------------------------- /system_simulation_e2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import string 4 | import socket 5 | import requests 6 | import sys 7 | import threading 8 | import time 9 | import torch 10 | from math import ceil 11 | from torchvision import transforms 12 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp, split_dataset_cifar_setting2 13 | from utils.client_simulation import generate_random_clients 14 | from utils.connections import send_object 15 | from utils.arg_parser import parse_arguments 16 | import matplotlib.pyplot as plt 17 | import time 18 | import multiprocessing 19 | import torch.optim as optim 20 | import copy 21 | from datetime import datetime 22 | from scipy.interpolate import make_interp_spline 23 | import numpy as np 24 | from ConnectedClient import ConnectedClient 25 | import importlib 26 | from utils.merge import merge_grads, merge_weights 27 | import pandas as pd 28 | import time 29 | from utils.split_dataset import DatasetFromSubset 30 | from utils import datasets,dataset_settings 31 | import torch.nn.functional as F 32 | 33 | class DatasetSplit(torch.utils.data.Dataset): 34 | def __init__(self, dataset, idxs): 35 | self.dataset = dataset 36 | self.idxs = list(idxs) 37 | 38 | def __len__(self): 39 | return len(self.idxs) 40 | 41 | def __getitem__(self, item): 42 | image, label = self.dataset[self.idxs[item]] 43 | return image, label 44 | 45 | 46 | class Client_try(): 47 | def __init__(self, id, *args, **kwargs): 48 | super(Client_try, self).__init__(*args, **kwargs) 49 | self.id = id 50 | self.front_model = [] 51 | self.back_model = [] 52 | self.losses = [] 53 | self.train_dataset = None 54 | self.test_dataset = None 55 | self.train_DataLoader = None 56 | self.test_DataLoader = None 57 | self.socket = None 58 | self.server_socket = None 59 | self.train_batch_size = None 60 | self.test_batch_size = None 61 | self.train_iterator = None 62 | self.test_iterator = None 63 | self.activations1 = None 64 | self.remote_activations1 = None 65 | self.outputs = None 66 | self.loss = None 67 | self.criterion = None 68 | self.data = None 69 | self.targets = None 70 | self.n_correct = 0 71 | self.n_samples = 0 72 | self.front_optimizer = None 73 | self.back_optimizer = None 74 | self.train_acc = [] 75 | self.test_acc = [] 76 | self.front_epsilons = [] 77 | self.front_best_alphas = [] 78 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 79 | # self.device = torch.device('cpu') 80 | 81 | 82 | def backward_back(self): 83 | self.loss.backward() 84 | 85 | 86 | def backward_front(self): 87 | print(self.remote_activations1.grad) 88 | self.activations1.backward(self.remote_activations1.grad) 89 | 90 | 91 | def calculate_loss(self): 92 | self.criterion = F.cross_entropy 93 | self.loss = self.criterion(self.outputs, self.targets) 94 | 95 | 96 | def calculate_test_acc(self): 97 | with torch.no_grad(): 98 | _, self.predicted = torch.max(self.outputs.data, 1) 99 | self.n_correct = (self.predicted == self.targets).sum().item() 100 | self.n_samples = self.targets.size(0) 101 | # self.test_acc.append(100.0 * self.n_correct/self.n_samples) 102 | return 100.0 * self.n_correct/self.n_samples 103 | # print(f'Acc: {self.test_acc[-1]}') 104 | 105 | 106 | def calculate_train_acc(self): 107 | with torch.no_grad(): 108 | _, self.predicted = torch.max(self.outputs.data, 1) 109 | self.n_correct = (self.predicted == self.targets).sum().item() 110 | self.n_samples = self.targets.size(0) 111 | # self.train_acc.append(100.0 * self.n_correct/self.n_samples) 112 | return 100.0 * self.n_correct/self.n_samples 113 | # print(f'Acc: {self.train_acc[-1]}') 114 | 115 | def create_DataLoader(self, dataset_train,dataset_test,idxs,idxs_test, batch_size, test_batch_size): 116 | self.train_DataLoader = torch.utils.data.DataLoader(DatasetSplit(dataset_train, idxs), batch_size = batch_size, shuffle = True) 117 | self.test_DataLoader = torch.utils.data.DataLoader(DatasetSplit(dataset_test, idxs_test), batch_size = test_batch_size, shuffle = True) 118 | 119 | def forward_back(self): 120 | self.back_model.to(self.device) 121 | self.outputs = self.back_model(self.remote_activations2) 122 | 123 | 124 | def forward_front(self, type): 125 | 126 | if type == "train": 127 | self.data, self.targets = next(self.train_iterator) 128 | else: 129 | self.data, self.targets = next(self.test_iterator) 130 | self.data, self.targets = self.data.to(self.device), self.targets.to(self.device) 131 | self.front_model.to(self.device) 132 | self.activations1 = self.front_model(self.data) 133 | self.remote_activations1 = self.activations1.detach().requires_grad_(True) 134 | 135 | 136 | def get_model(self): 137 | model = get_object(self.socket) 138 | self.front_model = model['front'] 139 | self.back_model = model['back'] 140 | 141 | def idle(self): 142 | pass 143 | 144 | 145 | def load_data(self, dataset, transform): 146 | try: 147 | dataset_path = os.path.join(f'data/{dataset}/{self.id}') 148 | except: 149 | raise Exception(f'Dataset not found for client {self.id}') 150 | self.train_dataset = torch.load(f'{dataset_path}/train/{self.id}.pt') 151 | self.test_dataset = torch.load('data/cifar10_setting2/test/common_test.pt') 152 | 153 | self.train_dataset = DatasetFromSubset( 154 | self.train_dataset, transform=transform 155 | ) 156 | self.test_dataset = DatasetFromSubset( 157 | self.test_dataset, transform=transform 158 | ) 159 | 160 | 161 | def step_front(self): 162 | self.front_optimizer.step() 163 | 164 | 165 | def step_back(self): 166 | self.back_optimizer.step() 167 | 168 | 169 | def zero_grad_front(self): 170 | self.front_optimizer.zero_grad() 171 | 172 | 173 | def zero_grad_back(self): 174 | self.back_optimizer.zero_grad() 175 | 176 | 177 | 178 | 179 | def generate_random_client_ids_try(num_clients, id_len=4) -> list: 180 | client_ids = [] 181 | for _ in range(num_clients): 182 | client_ids.append(''.join(random.sample("abcdefghijklmnopqrstuvwxyz1234567890", id_len))) 183 | return client_ids 184 | 185 | def generate_random_clients_try(num_clients) -> dict: 186 | client_ids = generate_random_client_ids_try(num_clients) 187 | clients = {} 188 | for id in client_ids: 189 | clients[id] = Client_try(id) 190 | return clients 191 | 192 | 193 | 194 | def initialize_client(client, dataset_train,dataset_test,idxs,idxs_test, batch_size, test_batch_size): 195 | client.create_DataLoader(dataset_train,dataset_test, idxs, idxs_test, batch_size, test_batch_size) 196 | 197 | 198 | def select_random_clients(clients): 199 | random_clients = {} 200 | client_ids = list(clients.keys()) 201 | random_index = random.randint(0,len(client_ids)-1) 202 | random_client_ids = client_ids[random_index] 203 | 204 | print(random_client_ids) 205 | print(clients) 206 | 207 | for random_client_id in random_client_ids: 208 | random_clients[random_client_id] = clients[random_client_id] 209 | return random_clients 210 | 211 | 212 | def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids): 213 | class_distribution=dict() 214 | number_of_clients=len(client_ids) 215 | if(len(clients)<=20): 216 | plot_for_clients=client_ids 217 | else: 218 | plot_for_clients=random.sample(client_ids, 20) 219 | 220 | fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10)) 221 | j=0 222 | i=0 223 | 224 | #plot histogram 225 | for client_id in plot_for_clients: 226 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 227 | class_distribution[client_id]=df['labels'].value_counts().sort_index() 228 | df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id) 229 | j+=1 230 | if(j==5 or j==10 or j==15): 231 | i+=1 232 | j=0 233 | fig.tight_layout() 234 | plt.show() 235 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png') 236 | 237 | max_len=0 238 | #plot line graphs 239 | for client_id in plot_for_clients: 240 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 241 | df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id) 242 | max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0]) 243 | plt.xticks(np.arange(0,10)) 244 | plt.ylim(0, max_len) 245 | plt.legend() 246 | plt.show() 247 | # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png') 248 | 249 | return class_distribution 250 | 251 | if __name__ == "__main__": 252 | 253 | 254 | args = parse_arguments() 255 | print(args) 256 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 257 | print("Arguments provided", args) 258 | 259 | 260 | 261 | 262 | random.seed(args.seed) 263 | torch.manual_seed(args.seed) 264 | 265 | overall_test_acc = [] 266 | overall_train_acc = [] 267 | 268 | print('Generating random clients...', end='') 269 | clients = generate_random_clients_try(args.number_of_clients) 270 | client_ids = list(clients.keys()) 271 | print('Done') 272 | 273 | train_full_dataset, test_full_dataset, input_channels = datasets.load_full_dataset(args.dataset, "data", args.number_of_clients, args.datapoints, args.pretrained) 274 | dict_users_train , dict_users_test = split_dataset_cifar_setting2(client_ids, train_full_dataset, test_full_dataset,5000,5000) 275 | 276 | # print(f'Random client ids:{str(client_ids)}') 277 | transform=None 278 | 279 | 280 | print('Initializing clients...') 281 | for i,(_, client) in enumerate(clients.items()): 282 | initialize_client(client, train_full_dataset, test_full_dataset, dict_users_train[i], dict_users_test[i], args.batch_size, args.test_batch_size) 283 | # if(args.dataset!='ham10000') 284 | # class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids) 285 | print('Client Intialization complete.') 286 | model = importlib.import_module(f'models.{args.model}') 287 | 288 | for _, client in clients.items(): 289 | client.front_model = model.front(input_channels, pretrained=args.pretrained) 290 | client.back_model = model.back(pretrained=args.pretrained) 291 | print('Done') 292 | 293 | for _, client in clients.items(): 294 | 295 | # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9) 296 | # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9) 297 | client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr) 298 | client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr) 299 | 300 | 301 | sample_client = clients[client_ids[0]] 302 | #number of iterations will be (di)/b. 303 | num_iterations = ceil(50 // args.batch_size) 304 | num_test_iterations = ceil(len(sample_client.test_DataLoader.dataset)//args.test_batch_size) 305 | # print(num_iterations) 306 | # print(num_test_iterations) 307 | 308 | 309 | #define train iterators for all clients 310 | for _,client in clients.items(): 311 | client.train_iterator = iter(client.train_DataLoader) 312 | 313 | 314 | sc_clients = {} #server copy clients 315 | 316 | for iden in client_ids: 317 | sc_clients[iden] = ConnectedClient(iden, None) 318 | 319 | for _,s_client in sc_clients.items(): 320 | s_client.center_model = model.center(pretrained=args.pretrained) 321 | s_client.center_model.to(device) 322 | # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9) 323 | s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr) 324 | 325 | st = time.time() 326 | 327 | for epoch in range(args.epochs): 328 | 329 | overall_train_acc.append(0) 330 | for i,(_, client) in enumerate(clients.items()): 331 | client.train_acc.append(0) 332 | 333 | for c_id, client in clients.items(): 334 | for iteration in range(num_iterations): 335 | 336 | client.forward_front("train") 337 | sc_clients[c_id].remote_activations1 = clients[c_id].remote_activations1 338 | sc_clients[c_id].forward_center() 339 | client.remote_activations2 = sc_clients[c_id].remote_activations2 340 | client.forward_back() 341 | client.calculate_loss() 342 | client.backward_back() 343 | sc_clients[c_id].remote_activations2 = clients[c_id].remote_activations2 344 | sc_clients[c_id].backward_center() 345 | 346 | # client.remote_activations1 = copy.deepcopy(sc_clients[client_id].remote_activations1) 347 | # client.backward_front() 348 | 349 | client.step_back() 350 | client.zero_grad_back() 351 | sc_clients[c_id].center_optimizer.step() 352 | sc_clients[c_id].center_optimizer.zero_grad() 353 | client.train_acc[-1] += client.calculate_train_acc() 354 | 355 | client.train_acc[-1] /= num_iterations 356 | overall_train_acc[-1] += client.train_acc[-1] 357 | 358 | overall_train_acc[-1] /= len(clients) 359 | print(f'Epoch {epoch} Personalized Average Train Acc: {overall_train_acc[-1]}') 360 | 361 | num_clients = len(client_ids) 362 | drop_clients_ids = [] 363 | rate = args.rate 364 | num_dropoff = int(rate * num_clients) 365 | print("number of clients dropped off", num_dropoff) 366 | 367 | for _ in range(num_dropoff): 368 | drop_clients_ids.append(int(random.uniform(0,(num_clients-1)))) 369 | 370 | 371 | # merge weights below uncomment 372 | params = [] 373 | for i,(_, client) in enumerate(sc_clients.items()): 374 | if i not in drop_clients_ids: 375 | print(i, "ids of clients that are considered for weight merging") 376 | params.append(copy.deepcopy(client.center_model.state_dict())) 377 | w_glob = merge_weights(params) 378 | 379 | for _, client in sc_clients.items(): 380 | client.center_model.load_state_dict(w_glob) 381 | 382 | params = [] 383 | for i,(_, client) in enumerate(clients.items()): 384 | if i not in drop_clients_ids: 385 | params.append(copy.deepcopy(client.back_model.state_dict())) 386 | w_glob_cb = merge_weights(params) 387 | 388 | for _, client in clients.items(): 389 | client.back_model.load_state_dict(w_glob_cb) 390 | 391 | 392 | 393 | if epoch%1 == 0: 394 | with torch.no_grad(): 395 | test_acc = 0 396 | overall_test_acc.append(0) 397 | for _, client in clients.items(): 398 | client.test_acc.append(0) 399 | client.test_iterator = iter(client.test_DataLoader) 400 | 401 | for client_id, client in clients.items(): 402 | for iteration in range(num_test_iterations): 403 | 404 | client.forward_front("test") 405 | sc_clients[client_id].remote_activations1 = clients[client_id].remote_activations1 406 | sc_clients[client_id].forward_center() 407 | client.remote_activations2 = sc_clients[client_id].remote_activations2 408 | client.forward_back() 409 | client.test_acc[-1] += client.calculate_test_acc() 410 | 411 | client.test_acc[-1] /= num_test_iterations 412 | overall_test_acc[-1] += client.test_acc[-1] #not including test accuracy of unique client 413 | 414 | overall_test_acc[-1] /= len(clients) 415 | print(f' Personalized Average Test Acc: {overall_test_acc[-1]}') 416 | 417 | 418 | 419 | 420 | timestamp = int(datetime.now().timestamp()) 421 | plot_config = f'''dataset: {args.dataset}, 422 | model: {args.model}, 423 | batch_size: {args.batch_size}, lr: {args.lr}, 424 | ''' 425 | 426 | et = time.time() 427 | print(f"Time taken for this run {(et - st)/60} mins") 428 | 429 | 430 | X = range(args.epochs) 431 | all_clients_stacked_train = np.array([client.train_acc for _,client in clients.items()]) 432 | all_clients_stacked_test = np.array([client.test_acc for _,client in clients.items()]) 433 | epochs_train_std = np.std(all_clients_stacked_train,axis = 0, dtype = np.float64) 434 | epochs_test_std = np.std(all_clients_stacked_test,axis = 0, dtype = np.float64) 435 | 436 | #Y_train is the average client train accuracies at each epoch 437 | #epoch_train_std is the standard deviation of clients train accuracies at each epoch 438 | Y_train = overall_train_acc 439 | Y_train_lower = Y_train - (1.65 * epochs_train_std) #95% of the values lie between 1.65*std 440 | Y_train_upper = Y_train + (1.65 * epochs_train_std) 441 | 442 | Y_test = overall_test_acc 443 | Y_test_lower = Y_test - (1.65 * epochs_test_std) #95% of the values lie between 1.65*std 444 | Y_test_upper = Y_test + (1.65 * epochs_test_std) 445 | 446 | Y_train_cv = epochs_train_std / Y_train 447 | Y_test_cv = epochs_test_std / Y_test 448 | 449 | plt.figure(0) 450 | plt.plot(X, Y_train) 451 | plt.fill_between(X,Y_train_lower , Y_train_upper, color='blue', alpha=0.25) 452 | # plt.savefig(f'./results/train_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 453 | plt.show() 454 | 455 | plt.figure(1) 456 | plt.plot(X, Y_test) 457 | plt.fill_between(X,Y_test_lower , Y_test_upper, color='blue', alpha=0.25) 458 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 459 | plt.show() 460 | 461 | 462 | plt.figure(2) 463 | plt.plot(X, Y_train_cv) 464 | plt.show() 465 | 466 | 467 | plt.figure(3) 468 | plt.plot(X, Y_test_cv) 469 | plt.show() 470 | 471 | 472 | 473 | 474 | 475 | #BELOW CODE TO PLOT MULTIPLE LINES ON A SINGLE PLOT ONE LINE FOR EACH CLIENT 476 | # for client_id, client in clients.items(): 477 | # plt.plot(list(range(args.epochs)), client.train_acc, label=f'{client_id} (Max:{max(client.train_acc):.4f})') 478 | # plt.plot(list(range(args.epochs)), overall_train_acc, label=f'Average (Max:{max(overall_train_acc):.4f})') 479 | # plt.title(f'{args.number_of_clients} Clients: Train Accuracy vs. Epochs') 480 | # plt.ylabel('Train Accuracy') 481 | # plt.xlabel('Epochs') 482 | # plt.legend() 483 | # plt.ioff() 484 | # plt.savefig(f'./results/train_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 485 | # plt.show() 486 | 487 | # for client_id, client in clients.items(): 488 | # plt.plot(list(range(args.epochs)), client.test_acc, label=f'{client_id} (Max:{max(client.test_acc):.4f})') 489 | # plt.plot(list(range(args.epochs)), overall_test_acc, label=f'Average (Max:{max(overall_test_acc):.4f})') 490 | # plt.title(f'{args.number_of_clients} Clients: Test Accuracy vs. Epochs') 491 | # plt.ylabel('Test Accuracy') 492 | # plt.xlabel('Epochs') 493 | # plt.legend() 494 | # plt.ioff() 495 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight') 496 | # plt.show() 497 | 498 | 499 | -------------------------------------------------------------------------------- /FL_DR.py: -------------------------------------------------------------------------------- 1 | #=========================================================== 2 | # Federated learning: ResNet18 3 | # =========================================================== 4 | import torch 5 | from torch import nn 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader, Dataset 8 | from pandas import DataFrame 9 | import pandas as pd 10 | from sklearn.model_selection import train_test_split 11 | from PIL import Image 12 | from glob import glob 13 | import math 14 | import random 15 | import numpy as np 16 | import os 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | import copy 21 | import argparse 22 | # from utils import datasets, dataset_settings 23 | from utils import get_eye_dataset 24 | import time 25 | from math import ceil 26 | from sklearn.metrics import classification_report 27 | from sklearn.preprocessing import LabelBinarizer 28 | from sklearn.metrics import roc_curve, auc, roc_auc_score 29 | 30 | 31 | 32 | ## ARGPARSER 33 | 34 | def parse_arguments(): 35 | # Training settings 36 | parser = argparse.ArgumentParser( 37 | description="Splitfed V1 configurations", 38 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 39 | ) 40 | parser.add_argument( 41 | "--seed", 42 | type=int, 43 | default=1234, 44 | help="Random seed", 45 | ) 46 | parser.add_argument( 47 | "-c", 48 | "--number_of_clients", 49 | type=int, 50 | default=5, 51 | metavar="C", 52 | help="Number of Clients", 53 | ) 54 | 55 | parser.add_argument( 56 | "-n", 57 | "--epochs", 58 | type=int, 59 | default=20, 60 | metavar="N", 61 | help="Total number of epochs to train", 62 | ) 63 | 64 | parser.add_argument( 65 | "--fac", 66 | type=float, 67 | default= 1.0, 68 | metavar="N", 69 | help="fraction of active/participating clients, if 1 then all clients participate in SFLV1", 70 | ) 71 | 72 | parser.add_argument( 73 | "--lr", 74 | type=float, 75 | default=0.001, 76 | metavar="LR", 77 | help="Learning rate", 78 | ) 79 | 80 | parser.add_argument( 81 | "--dataset", 82 | type=str, 83 | default="mnist", 84 | help="States dataset to be used", 85 | ) 86 | parser.add_argument( 87 | "-b", 88 | "--batch_size", 89 | type=int, 90 | default=1024, 91 | metavar="B", 92 | help="Batch size", 93 | ) 94 | parser.add_argument( 95 | 96 | "--test_batch_size", 97 | type=int, 98 | default=512, 99 | metavar="B", 100 | help="Batch size", 101 | ) 102 | parser.add_argument( 103 | "--setting", 104 | type=str, 105 | default="setting1", 106 | 107 | ) 108 | 109 | parser.add_argument( 110 | "--datapoints", 111 | type=int, 112 | default=500, 113 | 114 | ) 115 | 116 | 117 | parser.add_argument( 118 | "--opt_iden", 119 | type=str, 120 | # default=False 121 | ) 122 | 123 | args = parser.parse_args() 124 | return args 125 | 126 | #============================================================================================================== 127 | # Client Side Program 128 | #============================================================================================================== 129 | class DatasetSplit(Dataset): 130 | def __init__(self, dataset, idxs): 131 | self.dataset = dataset 132 | self.idxs = list(idxs) 133 | 134 | def __len__(self): 135 | return len(self.idxs) 136 | 137 | def __getitem__(self, item): 138 | image, label = self.dataset[self.idxs[item]] 139 | return image, label 140 | 141 | # Client-side functions associated with Training and Testing 142 | class LocalUpdate(object): 143 | def __init__(self, idx, lr, device, client_idx, dataset_train = None, dataset_test = None, idxs = None, idxs_test = None): 144 | self.idx = idx 145 | self.device = device 146 | self.lr = lr 147 | self.local_ep = 1 148 | self.loss_func = nn.CrossEntropyLoss() 149 | self.selected_clients = [] 150 | 151 | if(client_idx>=0 and client_idx<=4 ): 152 | train_data_id=1 153 | test_data_id=3 154 | elif(client_idx>=5 and client_idx<=9): 155 | train_data_id=2 156 | test_data_id=4 157 | 158 | self.train_dataset=get_eye_dataset.get_eye_data(idxs, train_data_id) 159 | self.test_dataset=get_eye_dataset.get_eye_data(idxs_test, test_data_id) 160 | self.ldr_train = DataLoader(self.train_dataset, batch_size = args.batch_size, shuffle = True) 161 | self.ldr_test = DataLoader(self.test_dataset, batch_size = args.test_batch_size, shuffle = True) 162 | clients[idx]=self 163 | # # if idx == 0: 164 | # self.ldr_train = DataLoader(dataset_train, batch_size = args.batch_size, shuffle = True) 165 | # else: 166 | # self.ldr_train = DataLoader(DatasetSplit(dataset_train, idxs), batch_size = args.batch_size, shuffle = True) 167 | 168 | # self.ldr_test = DataLoader(dataset_test, batch_size = args.batch_size, shuffle = True) 169 | 170 | def train(self, net): 171 | net.train() 172 | # train and update 173 | #optimizer = torch.optim.SGD(net.parameters(), lr = self.lr, momentum = 0.5) 174 | optimizer = torch.optim.Adam(net.parameters(), lr = self.lr) 175 | 176 | epoch_acc = [] 177 | epoch_loss = [] 178 | for iter in range(self.local_ep): 179 | batch_acc = [] 180 | batch_loss = [] 181 | 182 | for batch_idx, (images, labels) in enumerate(self.ldr_train): 183 | images, labels = images.to(self.device), labels.to(self.device) 184 | optimizer.zero_grad() 185 | #---------forward prop------------- 186 | fx = net(images) 187 | 188 | # calculate loss 189 | loss = self.loss_func(fx, labels) 190 | # calculate accuracy 191 | acc = calculate_accuracy(fx, labels) 192 | 193 | #--------backward prop-------------- 194 | loss.backward() 195 | optimizer.step() 196 | 197 | batch_loss.append(loss.item()) 198 | batch_acc.append(acc.item()) 199 | 200 | 201 | 202 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 203 | epoch_acc.append(sum(batch_acc)/len(batch_acc)) 204 | return net.state_dict(), sum(epoch_loss) / len(epoch_loss), sum(epoch_acc) / len(epoch_acc) 205 | 206 | def evaluate(self, net, ell): 207 | global targets, outputs 208 | net.eval() 209 | 210 | epoch_acc = [] 211 | epoch_loss = [] 212 | with torch.no_grad(): 213 | batch_acc = [] 214 | batch_loss = [] 215 | for batch_idx, (images, labels) in enumerate(self.ldr_test): 216 | images, labels = images.to(self.device), labels.to(self.device) 217 | #---------forward prop------------- 218 | fx = net(images) 219 | _,pred_t = torch.max(fx, dim=1) 220 | outputs.extend(pred_t.cpu().detach().numpy().tolist()) 221 | targets.extend(labels.cpu().detach().numpy().tolist()) 222 | 223 | 224 | 225 | # calculate loss 226 | loss = self.loss_func(fx, labels) 227 | # calculate accuracy 228 | acc = calculate_accuracy(fx, labels) 229 | 230 | batch_loss.append(loss.item()) 231 | batch_acc.append(acc.item()) 232 | 233 | 234 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 235 | epoch_acc.append(sum(batch_acc)/len(batch_acc)) 236 | 237 | clr=classification_report(np.array(targets), np.array(outputs), output_dict=True, zero_division=0) 238 | curr_f1=(clr['0']['f1-score']+clr['1']['f1-score']+clr['2']['f1-score'])/3 239 | macro_avg_f1_3classes.append(curr_f1) 240 | macro_avg_f1_dict[idx]=curr_f1 241 | 242 | targets=[] 243 | outputs=[] 244 | 245 | return sum(epoch_loss) / len(epoch_loss), sum(epoch_acc) / len(epoch_acc) 246 | 247 | 248 | 249 | #===================================================================================================== 250 | # dataset_iid() will create a dictionary to collect the indices of the data samples randomly for each client 251 | 252 | 253 | def dataset_iid(dataset, num_users): 254 | 255 | num_items = int(len(dataset)/num_users) 256 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 257 | for i in range(num_users): 258 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace = False)) 259 | all_idxs = list(set(all_idxs) - dict_users[i]) 260 | return dict_users 261 | 262 | #==================================================================================================== 263 | # Server Side Program 264 | #==================================================================================================== 265 | def calculate_accuracy(fx, y): 266 | preds = fx.max(1, keepdim=True)[1] 267 | correct = preds.eq(y.view_as(preds)).sum() 268 | acc = 100.00 *correct.float()/preds.shape[0] 269 | return acc 270 | 271 | #============================================================================= 272 | # Model definition: ResNet18 273 | #============================================================================= 274 | # building a ResNet18 Architecture 275 | def conv3x3(in_planes, out_planes, stride=1): 276 | "3x3 convolution with padding" 277 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 278 | padding=1, bias=False) 279 | 280 | 281 | def plot_class_distribution(clients, client_ids): 282 | class_distribution=dict() 283 | number_of_clients=len(client_ids) 284 | if(len(clients)<=20): 285 | plot_for_clients=client_ids 286 | else: 287 | plot_for_clients=random.sample(client_ids, 20) 288 | 289 | fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10)) 290 | j=0 291 | i=0 292 | 293 | #plot histogram 294 | for client_id in plot_for_clients: 295 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 296 | class_distribution[client_id]=df['labels'].value_counts().sort_index() 297 | df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id) 298 | j+=1 299 | if(j==5 or j==10 or j==15): 300 | i+=1 301 | j=0 302 | fig.tight_layout() 303 | plt.show() 304 | 305 | plt.savefig('plot_fl.png') 306 | # plt.savefig(f'./results/classvsfreq/settin3{dataset}.png') 307 | 308 | max_len=0 309 | #plot line graphs 310 | for client_id in plot_for_clients: 311 | df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels']) 312 | df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id) 313 | max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0]) 314 | plt.xticks(np.arange(0,10)) 315 | plt.ylim(0, max_len) 316 | plt.legend() 317 | plt.show() 318 | 319 | # plt.savefig(f'./results/class_vs_fre/q/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png') 320 | 321 | return class_distribution 322 | 323 | 324 | class BasicBlock(nn.Module): 325 | expansion = 1 326 | 327 | def __init__(self, inplanes, planes, stride=1, downsample=None): 328 | super(BasicBlock, self).__init__() 329 | self.conv1 = conv3x3(inplanes, planes, stride) 330 | self.bn1 = nn.BatchNorm2d(planes) 331 | self.relu = nn.ReLU(inplace=True) 332 | self.conv2 = conv3x3(planes, planes) 333 | self.bn2 = nn.BatchNorm2d(planes) 334 | self.downsample = downsample 335 | self.stride = stride 336 | 337 | def forward(self, x): 338 | residual = x 339 | 340 | out = self.conv1(x) 341 | out = self.bn1(out) 342 | out = self.relu(out) 343 | 344 | out = self.conv2(out) 345 | out = self.bn2(out) 346 | 347 | if self.downsample is not None: 348 | residual = self.downsample(x) 349 | 350 | out += residual 351 | out = self.relu(out) 352 | 353 | return out 354 | 355 | class ResNet18(nn.Module): 356 | 357 | def __init__(self, block, layers, input_channels, num_classes=1000): 358 | self.inplanes = 64 359 | super(ResNet18, self).__init__() 360 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, 361 | bias=False) 362 | self.bn1 = nn.BatchNorm2d(64) 363 | self.relu = nn.ReLU(inplace=True) 364 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 365 | self.layer1 = self._make_layer(block, 64, layers[0]) 366 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 367 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 368 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 369 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 370 | self.fc = nn.Linear(512 * block.expansion, num_classes) 371 | 372 | for m in self.modules(): 373 | if isinstance(m, nn.Conv2d): 374 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 375 | m.weight.data.normal_(0, math.sqrt(2. / n)) 376 | elif isinstance(m, nn.BatchNorm2d): 377 | m.weight.data.fill_(1) 378 | m.bias.data.zero_() 379 | 380 | def _make_layer(self, block, planes, blocks, stride=1): 381 | downsample = None 382 | if stride != 1 or self.inplanes != planes * block.expansion: 383 | downsample = nn.Sequential( 384 | nn.Conv2d(self.inplanes, planes * block.expansion, 385 | kernel_size=1, stride=stride, bias=False), 386 | nn.BatchNorm2d(planes * block.expansion), 387 | ) 388 | 389 | layers = [] 390 | layers.append(block(self.inplanes, planes, stride, downsample)) 391 | self.inplanes = planes * block.expansion 392 | for i in range(1, blocks): 393 | layers.append(block(self.inplanes, planes)) 394 | 395 | return nn.Sequential(*layers) 396 | 397 | def forward(self, x): 398 | x = self.conv1(x) 399 | x = self.bn1(x) 400 | x = self.relu(x) 401 | x = self.maxpool(x) 402 | 403 | x = self.layer1(x) 404 | x = self.layer2(x) 405 | x = self.layer3(x) 406 | x = self.layer4(x) 407 | 408 | x = self.avgpool(x) 409 | x = x.view(x.size(0), -1) 410 | x = self.fc(x) 411 | 412 | return x 413 | 414 | #=========================================================================================== 415 | # Federated averaging: FedAvg 416 | def FedAvg(w): 417 | w_avg = copy.deepcopy(w[0]) 418 | for k in w_avg.keys(): 419 | for i in range(1, len(w)): 420 | w_avg[k] += w[i][k] 421 | w_avg[k] = torch.div(w_avg[k], len(w)) 422 | return w_avg 423 | #==================================================== 424 | 425 | 426 | if __name__ == "__main__": 427 | 428 | #=================================================================== 429 | program = "FL ResNet18" 430 | print(f"---------{program}----------") 431 | 432 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 433 | 434 | args = parse_arguments() 435 | 436 | SEED = args.seed 437 | num_users = 10 438 | epochs = args.epochs 439 | frac = args.fac 440 | lr = args.lr 441 | dataset = args.dataset 442 | 443 | input_channels=3 444 | no_classes=3 445 | 446 | random.seed(SEED) 447 | np.random.seed(SEED) 448 | torch.manual_seed(SEED) 449 | torch.cuda.manual_seed(SEED) 450 | global outputs, targets 451 | targets=[] 452 | outputs=[] 453 | max_accuracy=0 454 | max_train_accuracy=0 455 | max_c0_4_test=0 456 | max_c0_f1=0 457 | max_c5_9_test=0 458 | max_c5_f1=0 459 | global clients 460 | clients={} 461 | 462 | 463 | d1, d2=get_eye_dataset.get_idxs() 464 | dict_users, dict_users_test=d1, d2 465 | 466 | net_glob = ResNet18(BasicBlock, [2, 2, 2, 2],input_channels, no_classes) 467 | net_glob.to(device) 468 | 469 | 470 | net_glob.train() 471 | w_glob = net_glob.state_dict() 472 | 473 | loss_train_collect = [] 474 | acc_train_collect = [] 475 | loss_test_collect = [] 476 | acc_test_collect = [] 477 | macro_avg_f1_3classes=[] 478 | macro_avg_f1_dict={} 479 | 480 | st = time.time() 481 | max_epoch, max_f1=0,0 482 | 483 | for iter in range(epochs): 484 | w_locals, loss_locals_train, acc_locals_train, loss_locals_test, acc_locals_test, acc_locals_test1, acc_locals_test2, macro_avg_f1_3classes = [], [], [], [], [], [], [], [] 485 | m = max(int(frac * num_users), 1) 486 | idxs_users = np.random.choice(range(num_users), m, replace = False) 487 | 488 | # Training/Testing simulation 489 | for idx in idxs_users: # each client 490 | 491 | local = LocalUpdate(idx, lr, device, client_idx=idx, idxs = dict_users[idx], idxs_test = dict_users_test[idx]) 492 | # Training ------------------ 493 | w, loss_train, acc_train = local.train(net = copy.deepcopy(net_glob).to(device)) 494 | w_locals.append(copy.deepcopy(w)) 495 | loss_locals_train.append(copy.deepcopy(loss_train)) 496 | acc_locals_train.append(copy.deepcopy(acc_train)) 497 | # Testing ------------------- 498 | loss_test, acc_test = local.evaluate(net = copy.deepcopy(net_glob).to(device), ell=iter) 499 | loss_locals_test.append(copy.deepcopy(loss_test)) 500 | acc_locals_test.append(copy.deepcopy(acc_test)) 501 | if(idx>=0 and idx<5): 502 | acc_locals_test1.append(copy.deepcopy(acc_test)) 503 | elif(idx>=5 and idx<10): 504 | acc_locals_test2.append(copy.deepcopy(acc_test)) 505 | 506 | # Federation process 507 | w_glob = FedAvg(w_locals) 508 | 509 | # update global model --- copy weight to net_glob -- distributed the model to all users 510 | net_glob.load_state_dict(w_glob) 511 | 512 | # Train/Test accuracy 513 | acc_avg_train = sum(acc_locals_train) / len(acc_locals_train) 514 | acc_train_collect.append(acc_avg_train) 515 | acc_avg_test = sum(acc_locals_test) / len(acc_locals_test) 516 | acc_test_collect.append(acc_avg_test) 517 | acc_avg_test1= sum(acc_locals_test1) / len(acc_locals_test1) 518 | 519 | acc_avg_test2 = sum(acc_locals_test2) / len(acc_locals_test2) 520 | 521 | f1_avg_all_user=sum(macro_avg_f1_3classes)/ len(macro_avg_f1_3classes) 522 | 523 | # Train/Test loss 524 | loss_avg_train = sum(loss_locals_train) / len(loss_locals_train) 525 | loss_train_collect.append(loss_avg_train) 526 | loss_avg_test = sum(loss_locals_test) / len(loss_locals_test) 527 | loss_test_collect.append(loss_avg_test) 528 | 529 | 530 | # print('------------------- SERVER ----------------------------------------------') 531 | # print('Train: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(iter, acc_avg_train, loss_avg_train)) 532 | # print('Test: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f} | F1 Score {:.3f}'.format(iter, acc_avg_test, loss_avg_test, f1_avg_all_user)) 533 | # print('-------------------------------------------------------------------------') 534 | print(f'\rEpoch: {iter}', end='') 535 | 536 | 537 | if(acc_avg_test> max_accuracy): 538 | max_accuracy=acc_avg_test 539 | max_train_accuracy=acc_avg_train 540 | max_epoch=iter 541 | max_c0_f1=macro_avg_f1_dict[0] 542 | max_c5_f1=macro_avg_f1_dict[5] 543 | max_c0_4_test=acc_avg_test1 544 | max_c5_9_test=acc_avg_test2 545 | 546 | macro_avg_f1_dict={} 547 | 548 | 549 | #=================================================================================== 550 | 551 | 552 | et = time.time() 553 | print("\nTraining and Evaluation completed!") 554 | print(f"Time taken for this run {(et - st)/60} mins") 555 | print(f'Maximum Personalized Average Test Acc: {max_accuracy} ') 556 | print(f'Maximum Personalized Average Train Acc: {max_train_accuracy} ') 557 | print(f'Client0 F1 Scores: {max_c0_f1}') 558 | print(f'Client5 F1 Scores:{max_c5_f1}') 559 | print(f'Personalized Average Test Accuracy for Clients 0 to 4 ": {max_c0_4_test}') 560 | print(f'Personalized Average Test Accuracy for Clients 5 to 9": {max_c5_9_test}') 561 | #=============================================================================== 562 | # Save output data to .excel file (we use for comparision plots) 563 | round_process = [i for i in range(1, len(acc_train_collect)+1)] 564 | df = DataFrame({'round': round_process,'acc_train':acc_train_collect, 'acc_test':acc_test_collect}) 565 | file_name = f"results/FL/{program}_{args.batch_size}_{args.dataset}_{args.lr}_{args.epochs}"+".xlsx" 566 | df.to_excel(file_name, sheet_name= "v1_test", index = False) 567 | 568 | 569 | 570 | #============================================================================= 571 | # Program Completed 572 | #============================================================================= 573 | --------------------------------------------------------------------------------