├── utils
├── __init__.py
├── __pycache__
│ ├── merge.cpython-39.pyc
│ ├── __init__.cpython-37.pyc
│ ├── __init__.cpython-39.pyc
│ ├── datasets.cpython-37.pyc
│ ├── datasets.cpython-39.pyc
│ ├── arg_parser.cpython-37.pyc
│ ├── arg_parser.cpython-39.pyc
│ ├── connections.cpython-37.pyc
│ ├── connections.cpython-39.pyc
│ ├── merge_grads.cpython-37.pyc
│ ├── split_dataset.cpython-37.pyc
│ ├── split_dataset.cpython-39.pyc
│ ├── client_simulation.cpython-37.pyc
│ ├── client_simulation.cpython-39.pyc
│ ├── dataset_settings.cpython-39.pyc
│ └── sharing_strategy.cpython-37.pyc
├── sharing_strategy.py
├── client_simulation.py
├── connections.py
├── merge.py
├── arg_parser.py
├── preprocess_eye_dataset_1.py
├── dataset_settings.py
├── get_eye_dataset.py
├── split_dataset.py
├── preprocess_eye_dataset_2.py
└── datasets.py
├── models
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ ├── __init__.cpython-39.pyc
│ ├── resnet18.cpython-39.pyc
│ ├── MNIST_CNN.cpython-37.pyc
│ ├── MNIST_CNN.cpython-39.pyc
│ ├── resnet18_split.cpython-39.pyc
│ └── resnet18_setting3.cpython-39.pyc
└── resnet18.py
├── requirements.txt
├── citation.cff
├── profiler_ours.py
├── .gitignore
├── ConnectedClient.py
├── profiler.py
├── client.py
├── PFSL.py
├── README.md
├── PFSL_DR.py
├── PFSL_Setting124.py
├── FL_Setting3.py
├── PFSL_Setting3.py
├── system_simulation_e2.py
└── FL_DR.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/utils/__pycache__/merge.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/merge.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet18.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/resnet18.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/datasets.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/datasets.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/datasets.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/datasets.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/MNIST_CNN.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/MNIST_CNN.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/MNIST_CNN.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/MNIST_CNN.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/arg_parser.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/arg_parser.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/arg_parser.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/arg_parser.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/connections.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/connections.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/connections.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/connections.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/merge_grads.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/merge_grads.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/split_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/split_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/split_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/split_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet18_split.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/resnet18_split.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/client_simulation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/client_simulation.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/client_simulation.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/client_simulation.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataset_settings.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/dataset_settings.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/sharing_strategy.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/utils/__pycache__/sharing_strategy.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet18_setting3.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mnswdhw/PFSL/HEAD/models/__pycache__/resnet18_setting3.cpython-39.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.4.2
2 | numpy==1.21.0
3 | opencv_python==4.5.1.48
4 | pandas==1.4.2
5 | Pillow==9.4.0
6 | requests==2.28.1
7 | scikit_learn==1.2.1
8 | scipy==1.8.1
9 | thop==0.1.1.post2209072238
10 | torch==1.13.0
11 | torchvision==0.14.0
12 | tqdm==4.59.0
13 | openpyxl==3.1.1
14 |
--------------------------------------------------------------------------------
/utils/sharing_strategy.py:
--------------------------------------------------------------------------------
1 | from heapq import nlargest
2 | from heapq import nsmallest
3 |
def min_loss(clients_loss, n_smallest):
    """Return the ids of the `n_smallest` clients with the lowest loss.

    Args:
        clients_loss: dict mapping client id -> loss value.
        n_smallest: how many of the lowest-loss clients to return.

    Returns:
        List of client ids ordered from lowest loss upward.
    """
    # Bug fix: the original passed an undefined name `N` to nsmallest instead
    # of the `n_smallest` parameter, raising NameError on every call.
    smallest_loss_clients = nsmallest(n_smallest, clients_loss, key=clients_loss.get)
    return smallest_loss_clients
7 |
def best_test_acc(clients_acc, n_largest):
    """Return the ids of the `n_largest` clients with the highest test accuracy.

    Args:
        clients_acc: dict mapping client id -> accuracy value.
        n_largest: how many of the highest-accuracy clients to return.

    Returns:
        List of client ids ordered from highest accuracy downward.
    """
    # Bug fix: the original passed an undefined name `N` to nlargest instead
    # of the `n_largest` parameter, raising NameError on every call.
    largest_acc_clients = nlargest(n_largest, clients_acc, key=clients_acc.get)
    return largest_acc_clients
11 |
--------------------------------------------------------------------------------
/citation.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
- given-names: "Manas Wadhwa"
  email: manasw@iitbhilai.ac.in
- given-names: "Gagan Gupta"
  email: gagan@iitbhilai.ac.in
- given-names: "Ashutosh Sahu"
  email: ashutoshsahu@iitbhilai.ac.in
- given-names: "Rahul Saini"
  email: rahuls@iitbhilai.ac.in
- given-names: "Vidhi Mittal"
  email: vidhimittal@iitbhilai.ac.in
14 |
15 | title: "PFSL"
16 | version: 1.0.0
17 | doi: 10.5281/zenodo.7739655
18 | date-released: 2023-02-11
19 | url: "https://github.com/mnswdhw/PFSL"
20 |
--------------------------------------------------------------------------------
/utils/client_simulation.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import os
4 | import sys
5 | sys.path.append('..')
6 | from client import Client
7 | from torch.utils.data import Dataset, DataLoader
8 | from torch.utils.data.sampler import RandomSampler
9 |
10 |
def generate_random_client_ids(num_clients, id_len=4) -> list:
    """Generate `num_clients` distinct random client ids.

    Each id is `id_len` characters drawn without repetition (random.sample)
    from lowercase letters and digits.

    Bug fix: the original did not check for duplicates, so two clients could
    receive the same id; since the ids are used as dict keys downstream
    (generate_random_clients), a collision would silently drop a client.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz1234567890"
    client_ids = []
    seen = set()
    while len(client_ids) < num_clients:
        candidate = ''.join(random.sample(alphabet, id_len))
        if candidate not in seen:
            seen.add(candidate)
            client_ids.append(candidate)
    return client_ids
16 |
17 |
def generate_random_clients(num_clients) -> dict:
    """Create `num_clients` Client objects, keyed by freshly generated random ids."""
    return {cid: Client(cid) for cid in generate_random_client_ids(num_clients)}
24 |
--------------------------------------------------------------------------------
/utils/connections.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import socket
3 | import pickle
4 |
def is_socket_closed(sock: socket.socket) -> bool:
    """Peek at the socket without blocking to decide whether the peer closed it.

    Returns True when the peer has shut down (empty read or connection reset),
    False when the socket is still open or the state cannot be determined.
    """
    logger = logging.getLogger(__name__)
    try:
        # Non-blocking peek: look at pending bytes without consuming them.
        peeked = sock.recv(16, socket.MSG_DONTWAIT | socket.MSG_PEEK)
    except BlockingIOError:
        # Nothing to read right now -> connection is alive.
        return False
    except ConnectionResetError:
        # Peer reset the connection.
        return True
    except Exception:
        logger.exception("unexpected exception when checking if a socket is closed")
        return False
    # An empty read on a readable socket means orderly shutdown by the peer.
    return len(peeked) == 0
20 |
21 |
def send_object(socket, data):
    """Send `data` over the connection.

    NOTE(review): `data` is forwarded to .send() unserialized even though
    callers pass dicts of models (ConnectedClient.send_model); a raw
    socket.socket requires bytes. Presumably `socket` here is a higher-level
    connection object that serializes internally -- confirm against callers.
    """
    socket.send(data)
24 |
25 |
def get_object(socket):
    """Receive one object from the connection.

    NOTE(review): .recv() is called with no buffer size, which raises
    TypeError on a raw socket.socket -- this only works if `socket` is a
    wrapper whose recv() takes no arguments and returns a whole object.
    Confirm the intended connection type before relying on this.
    """
    data = socket.recv()
    return data
29 |
30 |
--------------------------------------------------------------------------------
/utils/merge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 |
def merge_grads(normalized_data_sizes, params):
    """Replace every client's gradients with the weighted average across clients.

    Args:
        normalized_data_sizes: per-client weights (dataset size fractions);
            one weight per entry of `params`, summing is implicit so no
            division by the number of clients is performed afterwards.
        params: list of per-client parameter lists; params[i][j] is client i's
            j-th parameter, and corresponding parameters are averaged together.
    """
    # Walk corresponding parameters of all clients in lockstep.
    for matched in zip(*params):
        # Weighted sum of gradients; weights are the normalized data sizes.
        merged = sum(weight * p.grad
                     for weight, p in zip(normalized_data_sizes, matched))
        # Hand every client its own independent copy of the merged gradient.
        for p in matched:
            p.grad = copy.deepcopy(merged)
    return
23 |
24 |
25 |
def merge_weights(w):
    """Element-wise average of a list of state dicts (FedAvg-style merge).

    Args:
        w: non-empty list of state dicts with identical keys/shapes.

    Returns:
        A new state dict; the inputs are left unmodified.
    """
    n = len(w)
    # Deep-copy the first state dict so the accumulation never touches w[0].
    averaged = copy.deepcopy(w[0])
    for key in averaged:
        for state in w[1:]:
            averaged[key] += state[key]
        averaged[key] = torch.div(averaged[key], n)
    return averaged
--------------------------------------------------------------------------------
/profiler_ours.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import resnet50
2 | from thop import profile
3 | import os
4 | import random
5 | import string
6 | import socket
7 | import requests
8 | import sys
9 | import threading
10 | import time
11 | import torch
12 | from math import ceil
13 | from torchvision import transforms
14 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp
15 | from utils.client_simulation import generate_random_clients
16 | from utils.connections import send_object
17 | from utils.arg_parser import parse_arguments
18 | import matplotlib.pyplot as plt
19 | import time
20 | import server
21 | import multiprocessing
22 | # from opacus import PrivacyEngine
23 | # from opacus.accountants import RDPAccountant
24 | # from opacus import GradSampleModule
25 | # from opacus.optimizers import DPOptimizer
26 | # from opacus.validators import ModuleValidator
27 | import torch.optim as optim
28 | import copy
29 | from datetime import datetime
30 | from scipy.interpolate import make_interp_spline
31 | import numpy as np
32 | from ConnectedClient import ConnectedClient
33 | import importlib
34 | from utils.merge import merge_grads, merge_weights
35 | import pandas as pd
36 | import time
37 |
38 |
39 |
# Profile FLOPs/params of the client-side split-model parts of ResNet-18.
model = importlib.import_module('models.resnet18')
model_cf = model.front(3, pretrained=True)
model_cb = model.back(pretrained=True)
model_center = model.center(pretrained=True)  # instantiated for parity; not profiled below

# Dummy batches matching the split points: raw 224x224 RGB images for the
# front model, 512x7x7 feature maps for the back model.
# (Renamed from `input`, which shadowed the builtin.)
front_input = torch.randn(64, 3, 224, 224)
back_input = torch.randn(64, 512, 7, 7)
macs_client_CF, params_FL = profile(model_cf, inputs=(front_input, ))
macs_client_CB, params_SL = profile(model_cb, inputs=(back_input, ))

# thop reports MACs; 1 MAC = 2 FLOPs, hence the factor of 2.
print(f"GFLOPS CF {((2 * macs_client_CF) / (10**9))} PARAMS: {params_FL}")
print(f"GFLOPS CB {((2 * macs_client_CB) / (10**9))} PARAMS: {params_SL}")
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 |
5 | # C extensions
6 | *.so
7 |
8 | # Distribution / packaging
9 | .Python
10 | build/
11 | develop-eggs/
12 | dist/
13 | downloads/
14 | eggs/
15 | .eggs/
16 | lib/
17 | lib64/
18 | parts/
19 | sdist/
20 | var/
21 | wheels/
22 | pip-wheel-metadata/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 |
93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94 | __pypackages__/
95 |
96 | # Celery stuff
97 | celerybeat-schedule
98 | celerybeat.pid
99 |
100 | # SageMath parsed files
101 | *.sage.py
102 |
103 | # Environments
104 | .env
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 | results/
130 | data/
--------------------------------------------------------------------------------
/utils/arg_parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
def parse_arguments():
    """Build and parse the command-line arguments for the simulation entrypoints.

    Returns:
        argparse.Namespace with all training/simulation settings.
    """
    parser = argparse.ArgumentParser(
        description="Split Learning Research Simulation entrypoint",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flags, options) table for every supported argument; keeps each
    # definition on one screen line and the registration loop trivial.
    argument_table = [
        (["-c", "--number_of_clients"],
         dict(type=int, default=10, metavar="C", help="Number of Clients")),
        (["-b", "--batch_size"],
         dict(type=int, default=128, metavar="B", help="Batch size")),
        (["--test_batch_size"],
         dict(type=int, default=128, metavar="TB", help="Input batch size for testing")),
        (["-n", "--epochs"],
         dict(type=int, default=50, metavar="N", help="Total number of epochs to train")),
        (["--lr"],
         dict(type=float, default=0.001, metavar="LR", help="Learning rate")),
        (["--rate"],
         dict(type=float, default=0.5, help="dropoff rate")),
        (["--dataset"],
         dict(type=str, default="cifar10", help="States dataset to be used")),
        (["--seed"],
         dict(type=int, default=1234, help="Random seed")),
        (["--model"],
         dict(type=str, default="resnet18", help="Model you would like to train")),
        (["--epoch_batch"],
         dict(type=str, default="5",
              help="Number of epochs after which next batch of clients should join")),
        (["--opt_iden"],
         dict(type=str, default="", help="optional identifier of experiment")),
        (["--pretrained"],
         dict(action="store_true", default=False,
              help="Use transfer learning using a pretrained model")),
        (["--datapoints"],
         dict(type=int, default=500,
              help="Number of samples of training data allotted to each client")),
        (["--setting"],
         dict(type=str, default='setting1',
              help='Setting you would like to run for, i.e, setting1 , setting2 or setting4')),
        (["--checkpoint"],
         dict(type=int, default=50,
              help="Epoch at which personalisation phase will start")),
    ]
    for flags, options in argument_table:
        parser.add_argument(*flags, **options)

    return parser.parse_args()
117 |
--------------------------------------------------------------------------------
/utils/preprocess_eye_dataset_1.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import cv2
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Dataset, DataLoader
6 | import pandas as pd
7 | import os
8 | import pickle
9 | import torchvision.transforms as transforms
10 |
11 |
def crop_image_from_gray(img, tol=7):
    """Trim away dark borders, keeping only rows/columns brighter than `tol`.

    Works on 2-D grayscale arrays directly; for 3-D RGB arrays the mask is
    computed on a grayscale conversion and applied to each channel.
    Returns the original image unchanged if cropping would remove everything.
    """
    if img.ndim == 2:
        keep = img > tol
        return img[np.ix_(keep.any(1), keep.any(0))]
    elif img.ndim == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        keep = gray > tol
        rows, cols = keep.any(1), keep.any(0)

        if img[:, :, 0][np.ix_(rows, cols)].shape[0] == 0:
            # Image is too dark: the crop would remove everything,
            # so return the original image untouched.
            return img
        else:
            channels = [img[:, :, c][np.ix_(rows, cols)] for c in range(3)]
            return np.stack(channels, axis=-1)
31 |
32 |
def circle_crop(img, sigmaX):
    """
    Create circular crop around image centre.

    Steps: convert BGR->RGB, trim dark borders, zero out everything outside
    the largest circle that fits the trimmed image, trim again, resize to
    224x224, then sharpen by blending the image (weight 4) against a
    Gaussian-blurred copy (weight -4, offset 128).

    `sigmaX` is the Gaussian blur standard deviation used in the final step.
    """

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = crop_image_from_gray(img)
    height, width, depth = img.shape

    # Circle centred on the image; radius = min(half-width, half-height).
    x = int(width/2)
    y = int(height/2)
    r = np.amin((x,y))

    # Filled circle mask (thickness=-1 fills it), applied to every channel.
    circle_img = np.zeros((height, width), np.uint8)
    cv2.circle(circle_img, (x,y), int(r), 1, thickness=-1)
    img = cv2.bitwise_and(img, img, mask=circle_img)
    img = crop_image_from_gray(img)
    img=cv2.resize(img, (224, 224))
    # High-pass style enhancement: 4*img - 4*blur(img) + 128.
    img=cv2.addWeighted(img,4, cv2.GaussianBlur( img , (0,0) , sigmaX) ,-4 ,128)

    return img
54 |
55 |
def load_data():
    """Load the eye-dataset train/test CSVs and attach per-image file columns.

    Adds `file_path` (full path to the PNG) and `file_name` columns to both
    frames, and casts the train `diagnosis` labels to strings.

    Returns:
        (train, test) pandas DataFrames.
    """
    train = pd.read_csv('data/eye_dataset1/train.csv')
    test = pd.read_csv('data/eye_dataset1/test.csv')

    train_dir = os.path.join('data/eye_dataset1/train_images')
    test_dir = os.path.join('data/eye_dataset1/test_images')

    # Same path/name derivation for both splits.
    for frame, directory in ((train, train_dir), (test, test_dir)):
        frame['file_path'] = frame['id_code'].map(
            lambda x, d=directory: os.path.join(d, '{}.png'.format(x)))
        frame['file_name'] = frame["id_code"].apply(lambda x: x + ".png")

    # Labels are consumed as categorical strings downstream.
    train['diagnosis'] = train['diagnosis'].astype(str)

    return train, test
72 |
73 |
from tqdm import tqdm

# ---- Script body: preprocess every training image and pickle the result ----
# Circle-crops each image and maps the 5-level diagnosis labels {0..4} down to
# 3 classes: 0 -> 0, {1,2} -> 1, {3,4} -> 2.

df_train,df_test = load_data()
print(df_train['diagnosis'].value_counts())
X_train=[]
Y_train=[]
count_0, count_1, count_2=0,0,0  # per-class counters (only incremented; not printed here)
print(type(df_train.diagnosis.iloc[0]))

for i in tqdm(range(0,len(df_train))):

    img = cv2.imread(df_train.file_path.iloc[i])

    img = circle_crop(img,sigmaX=10)
    X_train.append(img)
    temp=[]  # NOTE(review): unused -- looks like leftover scratch

    if(int(df_train.diagnosis.iloc[i])==0):
        count_0+=1
        ans=0
    elif(int(df_train.diagnosis.iloc[i])==1 or int(df_train.diagnosis.iloc[i])==2):
        count_1+=1
        ans=1
    elif(int(df_train.diagnosis.iloc[i])==3 or int(df_train.diagnosis.iloc[i])==4):
        count_2+=1
        ans=2
    # NOTE(review): if a diagnosis outside 0-4 ever appears, `ans` keeps its
    # previous value (or is unbound on the first row) -- confirm labels are 0-4.
    Y_train.append(ans)

# Persist raw image list and labels for the downstream training scripts.
with open('data/x_train_eye_1_1', 'wb') as pickle_file:
    pickle.dump(X_train, pickle_file)
with open('data/y_train_eye_1_1', 'wb') as pickle_file:
    pickle.dump(Y_train, pickle_file)
--------------------------------------------------------------------------------
/ConnectedClient.py:
--------------------------------------------------------------------------------
import hashlib
import os
import pickle
import queue
import struct
from threading import Thread

import torch

from utils.connections import is_socket_closed
from utils.connections import send_object
from utils.connections import get_object
9 |
10 |
def handle(client, addr, file):
    """Stream the file at path `file` over socket `client`, then its SHA-512.

    Wire format: 4-byte big-endian file size, the raw file bytes in 1 KiB
    chunks, then the 64-byte SHA-512 digest for integrity checking.
    `addr` is accepted for interface compatibility but unused.

    Bug fixes vs. the original: `fd` was never bound (the `with open` line
    was commented out, so every call raised NameError), `hashlib` was never
    imported, and the size field sent `len(file)` -- the length of the path
    string -- instead of the file's size on disk.
    """
    buffsize = 1024
    fsize = struct.pack('!I', os.path.getsize(file))
    print('Len of file size struct:', len(fsize))
    client.send(fsize)
    with open(file, 'rb') as fd:
        # First pass: stream the file contents in fixed-size chunks.
        while True:
            chunk = fd.read(buffsize)
            if not chunk:
                break
            client.send(chunk)
        # Second pass: hash the same bytes so the receiver can verify them.
        fd.seek(0)
        hash = hashlib.sha512()
        while True:
            chunk = fd.read(buffsize)
            if not chunk:
                break
            hash.update(chunk)
        client.send(hash.digest())
32 |
33 |
class ConnectedClient(object):
    """Server-side handle for one connected client in the split-learning setup.

    Holds the client's model partitions (front/center/back), the activations
    and gradients exchanged at the split points, and the connection object
    used to talk to the client.
    """

    # def __init__(self, id, conn, address, loop_time=1/60, *args, **kwargs):
    def __init__(self, id, conn, *args, **kwargs):
        super(ConnectedClient, self).__init__(*args, **kwargs)
        self.id = id      # client identifier string
        self.conn = conn  # connection object used by send_object/get_object
        # Model partitions; populated externally before training starts.
        self.front_model = None
        self.back_model = None
        self.center_model = None
        self.train_fun = None
        self.test_fun = None
        self.keepRunning = True
        self.a1 = None
        self.a2 = None
        self.center_optimizer = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    # def onThread(self, function, *args, **kwargs):
    #     self.q.put((function, args, kwargs))


    # def run(self, loop_time=1/60, *args, **kwargs):
    #     super(ConnectedClient, self).run(*args, **kwargs)
    #     while True:
    #         try:
    #             function, args, kwargs = self.q.get(timeout=self.timeout)
    #             function(*args, **kwargs)
    #         except queue.Empty:
    #             self.idle()

    def forward_center(self):
        """Run the center model on the activations received from the client.

        The output is detached with requires_grad_(True) so the autograd graph
        is cut at the split point; the gradient arriving from the client is
        later injected via backward_center.
        """
        self.activations2 = self.center_model(self.remote_activations1)
        self.remote_activations2 = self.activations2.detach().requires_grad_(True)


    def backward_center(self):
        """Backprop through the center model using the gradient set on
        remote_activations2 (see get_remote_activations2_grads)."""
        self.activations2.backward(self.remote_activations2.grad)


    def idle(self):
        # No-op placeholder (see the commented-out run() loop above).
        pass


    def connect(self):
        # Placeholder; connection is established elsewhere and passed to __init__.
        pass


    def disconnect(self):
        """Close the connection if still open; True if this call closed it."""
        if not is_socket_closed(self.conn):
            self.conn.close()
            return True
        else:
            return False


    # def _send_model(self):
    def send_model(self):
        """Send the client its front and back model partitions as a dict."""
        model = {'front': self.front_model, 'back': self.back_model}
        send_object(self.conn, model)
        # handle(self.conn, self.address, model)


    # def send_optimizers(self):
    #     # This is just a sample code and NOT optimizers. Need to write code for initializing optimizers
    #     optimizers = {'front': self.front_model.parameters(), 'back': self.back_model.parameters()}
    #     send_object(self.conn, optimizers)


    def send_activations(self, activations):
        send_object(self.conn, activations)


    def get_remote_activations1(self):
        # Activations produced by the client's front model.
        self.remote_activations1 = get_object(self.conn)


    def send_remote_activations2(self):
        send_object(self.conn, self.remote_activations2)


    def get_remote_activations2_grads(self):
        # Gradient of the loss w.r.t. the second split-point activations.
        self.remote_activations2.grad = get_object(self.conn)


    def send_remote_activations1_grads(self):
        send_object(self.conn, self.remote_activations1.grad)

    # def send_model(self):
    #     self.onThread(self._send_model)
--------------------------------------------------------------------------------
/models/resnet18.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | import torch.nn as nn
3 |
# Split-point configuration for ResNet-18 (change these manually for now).
# The first `num_front_layers` top-level children of the torchvision model go
# to the client-side front model, the last `num_back_layers` to the back
# model, and everything in between to the center model.
num_front_layers = 4
num_back_layers = 3
# Unfrozen layers are counted from the END of each sub-model; the remaining
# (earlier) layers are frozen when a pretrained backbone is used.
num_unfrozen_front_layers = 0
num_unfrozen_center_layers = 1  # exp with 1 or 2(max)
num_unfrozen_back_layers = 4
11 |
def get_resnet18(pretrained: bool):
    """Return a torchvision ResNet-18, optionally with pretrained weights.

    torchvision caches downloaded weights, so repeated pretrained calls do
    not download again.
    """
    return models.resnet18(pretrained=pretrained)
15 |
16 |
class front(nn.Module):
    """Client-side front slice of ResNet-18: the first `num_front_layers`
    top-level children.

    For single-channel inputs an extra 1->3 conv is prepended so grayscale
    images can feed the RGB stem.
    NOTE(review): Conv2d(1,3,3,1,2) has padding 2 with kernel 3, which grows
    H/W by 2 rather than preserving it as the original comment claimed --
    confirm intent before changing.
    """

    def __init__(self, input_channels=3, pretrained=False):
        super(front, self).__init__()
        children = list(get_resnet18(pretrained).children())
        self.input_channels = input_channels
        if self.input_channels == 1:
            self.conv_channel_change = nn.Conv2d(1, 3, 3, 1, 2)
        self.front_model = nn.Sequential(*children[:num_front_layers])

        if pretrained:
            # Freeze all but the final num_unfrozen_front_layers layers.
            frozen = max(num_front_layers - num_unfrozen_front_layers, 0)
            for layer in list(self.front_model)[:frozen]:
                for param in layer.parameters():
                    param.requires_grad = False

    def forward(self, x):
        if self.input_channels == 1:
            x = self.conv_channel_change(x)
        return self.front_model(x)
40 |
41 |
class center(nn.Module):
    """Server-side center slice of ResNet-18: every top-level child between the
    front and back slices.

    NOTE(review): `center_model_length` is stored as a module-level global set
    in __init__ and read again in freeze(); if multiple center instances with
    different configs ever coexist, the last constructed one wins -- confirm
    this is single-instance only.
    """
    def __init__(self, pretrained=False):
        super(center, self).__init__()
        model = get_resnet18(pretrained)
        model_children = list(model.children())
        global center_model_length
        center_model_length = len(model_children) - num_front_layers - num_back_layers


        self.center_model = nn.Sequential(*model_children[num_front_layers:center_model_length+num_front_layers])


        if pretrained:
            # Freeze all but the last num_unfrozen_center_layers layers.
            layer_iterator = iter(self.center_model)

            for i in range(center_model_length-num_unfrozen_center_layers):
                layer = layer_iterator.__next__()
                for param in layer.parameters():
                    param.requires_grad = False

    def freeze(self, epoch, pretrained=False):
        """Freeze ALL center layers (used after the main training phase).

        NOTE(review): the local assignment below shadows the module-level
        num_unfrozen_center_layers with 0, so every layer gets frozen; the
        `epoch` argument is unused -- confirm that is the intended contract.
        """
        num_unfrozen_center_layers=0
        if pretrained:

            layer_iterator = iter(self.center_model)
            for i in range(center_model_length-num_unfrozen_center_layers):
                layer = layer_iterator.__next__()
                for param in layer.parameters():
                    param.requires_grad = False

    def forward(self, x):
        x = self.center_model(x)
        return x
78 |
79 |
80 |
class back(nn.Module):
    """Client-side tail of ResNet-18: the last `num_back_layers` children, with
    the stock classifier replaced by Flatten + Linear(512, output_dim)."""

    def __init__(self, pretrained=False, output_dim=10):
        super(back, self).__init__()
        children = list(get_resnet18(pretrained).children())
        total = len(children)

        # Swap the stock fc for a fresh head sized for our task.
        head = nn.Linear(512, output_dim)
        children = children[:-1] + [nn.Flatten(), head]
        self.back_model = nn.Sequential(*children[total - num_back_layers:])

        if pretrained:
            # Freeze all but the final num_unfrozen_back_layers layers.
            # max(..., 0) keeps this a no-op when more layers are unfrozen
            # than exist (matching the original range() semantics).
            frozen = max(num_back_layers - num_unfrozen_back_layers, 0)
            for layer in list(self.back_model)[:frozen]:
                for param in layer.parameters():
                    param.requires_grad = False


    def forward(self, x):
        return self.back_model(x)
103 |
104 |
if __name__ == '__main__':
    # Smoke test: instantiate each split with pretrained weights and print
    # its layer structure for manual inspection of the split boundaries.
    model = front(pretrained=True)
    print(f'{model.front_model}\n\n')
    model = center(pretrained=True)
    print(f'{model.center_model}\n\n')
    model = back(pretrained=True)
    print(f'{model.back_model}')
--------------------------------------------------------------------------------
/utils/dataset_settings.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 |
def setting2(train_full_dataset, test_full_dataset, num_users):
    """Non-IID split: each client gets two dominant classes plus small
    amounts of the rest, for both train and test indices.

    Assumes 10 classes / at most 10 users (the modular class assignment
    uses % 10). Returns (dict_users, dict_users_test) mapping client id
    to a list of dataset indices.
    """

    def draw(pool, cls, count, bucket):
        # Sample `count` indices of class `cls` without replacement and
        # remove them from the pool so later clients cannot reuse them.
        picked = list(np.random.choice(pool[cls], count, replace=False))
        bucket.extend(picked)
        pool[cls] = list(set(pool[cls]) - set(picked))

    dict_users = {cid: [] for cid in range(num_users)}
    dict_users_test = {cid: [] for cid in range(num_users)}

    train_df = pd.DataFrame(list(train_full_dataset), columns=['images', 'labels'])
    test_df = pd.DataFrame(list(test_full_dataset), columns=['images', 'labels'])
    num_of_classes = len(train_df['labels'].unique())

    train_pool = {c: train_df[train_df['labels'] == c].index.values.astype(int)
                  for c in range(num_of_classes)}
    test_pool = {c: test_df[test_df['labels'] == c].index.values.astype(int)
                 for c in range(num_of_classes)}

    # Train: 225 samples from each of two dominant classes, 7 from the next
    # two, 6 from every remaining class.
    for cid in range(num_users):
        for c in range(num_of_classes):
            if cid == c or (cid + 1) % 10 == c:
                draw(train_pool, c, 225, dict_users[cid])
            elif (cid + 2) % 10 == c or (cid + 3) % 10 == c:
                draw(train_pool, c, 7, dict_users[cid])
            else:
                draw(train_pool, c, 6, dict_users[cid])

    # Test: 450 from the two dominant classes, 13 from the next four,
    # 12 from the rest.
    for cid in range(num_users):
        for c in range(num_of_classes):
            if cid == c or (cid + 1) % 10 == c:
                draw(test_pool, c, 450, dict_users_test[cid])
            elif ((cid + 2) % 10 == c or (cid + 3) % 10 == c
                    or (cid + 4) % 10 == c or (cid + 5) % 10 == c):
                draw(test_pool, c, 13, dict_users_test[cid])
            else:
                draw(test_pool, c, 12, dict_users_test[cid])

    return dict_users, dict_users_test
61 |
62 |
def setting1(dataset, num_users, datapoints):
    """IID split: every client receives datapoints//num_classes samples of
    each class, drawn without replacement from a shared per-class pool.

    Returns dict_users mapping client id -> list of dataset indices.
    """
    dict_users = {u: [] for u in range(num_users)}

    frame = pd.DataFrame(list(dataset), columns=['images', 'labels'])
    num_of_classes = len(frame['labels'].unique())

    per_class_client = int(datapoints / num_of_classes)
    per_class_total = per_class_client * num_users

    # Per-class index pools, capped so exactly enough remain for all users.
    pools = {c: frame[frame['labels'] == c].index.values.astype(int)[:per_class_total]
             for c in range(num_of_classes)}

    for u in range(num_users):
        for c in range(num_of_classes):
            chosen = list(np.random.choice(pools[c], per_class_client, replace=False))
            dict_users[u].extend(chosen)
            pools[c] = list(set(pools[c]) - set(chosen))

    return dict_users
88 |
def get_test_dict(dataset, num_users):
    """Give every client the SAME randomly chosen test subset.

    Samples up to 2000 indices (fewer when the dataset is smaller) and maps
    each client id to that one shared set, so test accuracy is comparable
    across clients. Returns dict: client id -> set of test indices.
    """
    all_idxs = list(range(len(dataset)))
    # Cap at the dataset size: the original hard-coded 2000 made
    # np.random.choice raise on datasets with fewer than 2000 samples.
    # (Also dropped the unused `val` local.)
    test_set = set(np.random.choice(all_idxs, min(2000, len(all_idxs)), replace=False))
    return {i: test_set for i in range(num_users)}
99 |
100 |
def get_dicts(train_full_dataset, test_full_dataset, num_users, setting, datapoints):
    """Dispatch to the requested data-partitioning strategy.

    Returns (dict_users, dict_users_test) mapping client id -> sample
    indices. Raises ValueError for an unrecognized `setting` (previously an
    unknown setting fell through and crashed with an UnboundLocalError on
    the return line).
    """
    if setting == 'setting2':
        dict_users, dict_users_test = setting2(train_full_dataset, test_full_dataset, num_users)
    elif setting == 'setting1':
        dict_users = setting1(train_full_dataset, num_users, datapoints)
        dict_users_test = get_test_dict(test_full_dataset, num_users)
    else:
        raise ValueError(f"unknown setting: {setting!r}")
    return dict_users, dict_users_test
--------------------------------------------------------------------------------
/utils/get_eye_dataset.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from torch.utils.data import DataLoader, Dataset
3 | import numpy as np
4 | import torchvision.transforms as transforms
5 | import random
6 |
# Module-level data load: both preprocessed eye datasets are read from
# pickle files at import time, so importing this module requires that the
# utils/preprocess_eye_dataset_*.py scripts have already been run.
with open('data/x_train_eye_1_1', 'rb') as f:
    x_train1 = pickle.load(f)
with open('data/y_train_eye_1_1', 'rb') as f:
    y_train1 = pickle.load(f)

with open('data/x_train_eye_2', 'rb') as f:
    x_train2 = pickle.load(f)
with open('data/y_train_eye_2', 'rb') as f:
    y_train2 = pickle.load(f)
16 |
17 | # a_list = [1,2,3,4,1,2,1,2,3,4]
def find_indices(list_to_check, item_to_find):
    """Return the positions in `list_to_check` whose value equals `item_to_find`."""
    hits = np.flatnonzero(np.asarray(list_to_check) == item_to_find)
    return list(hits)
22 |
23 |
# Hold out the first 1000 samples of each dataset as a fixed test split;
# the remainder stays as training data.
# test_1=
# rndm_idxs=list(np.random.choice( list(range(0,3500)),1000, replace = False))
x_test1=[]
y_test1=[]
x_test2=[]
y_test2=[]

temp = x_train1[0: 1000]
x_test1.extend(temp)

temp = y_train1[0: 1000]
y_test1.extend(temp)

x_train1=x_train1[1000:]
y_train1=y_train1[1000:]


temp = x_train2[0: 1000]
x_test2.extend(temp)
temp = y_train2[0: 1000]
y_test2.extend(temp)

x_train2=x_train2[1000:]
y_train2= y_train2[1000:]

# NOTE(review): the four blocks below repeatedly overwrite list_0/1/2 and
# the results are never used afterwards (get_idxs recomputes its own) --
# dead code kept for parity, presumably leftover debugging/inspection.
list_0=find_indices(y_train1,0)
list_1=find_indices(y_train1,1)
list_2=find_indices(y_train1,2)


list_0=find_indices(y_train2,0)
list_1=find_indices(y_train2,1)
list_2=find_indices(y_train2,2)


list_0=find_indices(y_test1,0)
list_1=find_indices(y_test1,1)
list_2=find_indices(y_test1,2)


list_0=find_indices(y_test2,0)
list_1=find_indices(y_test2,1)
list_2=find_indices(y_test2,2)
67 |
68 |
class CreateDataset(Dataset):
    """Index-filtered view over parallel (x, y) sequences with an optional
    per-sample transform applied to the image."""

    def __init__(self, x, y, idxs, transform=None):
        super().__init__()
        self.x = x
        self.y = y
        self.transform = transform
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, index):
        real_idx = self.idxs[index]
        image = self.x[real_idx]
        label = self.y[real_idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
87 |
def get_eye_data(idxs, num):
    """Build an augmented dataset over one of the four eye splits.

    `num` selects the split: 1/2 -> the two train sets, 3/4 -> the two
    held-out test sets. `idxs` are positions within that split.
    Raises ValueError for any other `num` (previously x_train/y_train were
    simply left unbound and the call crashed later with a NameError).
    """
    if num == 1:
        x_train, y_train = x_train1, y_train1
    elif num == 2:
        x_train, y_train = x_train2, y_train2
    elif num == 3:
        x_train, y_train = x_test1, y_test1
    elif num == 4:
        x_train, y_train = x_test2, y_test2
    else:
        raise ValueError(f"num must be 1..4, got {num!r}")

    # Resize + light augmentation + ImageNet normalization (224x224 input).
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.4),
        #transforms.ColorJitter(brightness=2, contrast=2),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    train_data = CreateDataset(x=x_train, y=y_train, idxs=idxs, transform=transform_train)
    return train_data
111 |
112 |
113 |
def get_idxs():
    """Partition the two eye-dataset train splits across 10 clients.

    Clients 0-4 draw from dataset 1 and clients 5-9 from dataset 2; each
    client receives a fixed class mix (l1/l2 hold the per-class sample
    counts for classes 0/1/2). All clients share the same 1000-index test
    list. Returns (dict_users, dict_users_test): client id -> index list.
    """

    l1=[250,180,70]
    l2=[370,109,21]
    # l3=[295,157,48]
    dict_users, dict_users_test={}, {}
    client_id=0

    # dict_users_test=get_test_idxs()
    test_list=list(range(0,1000))
    for num in range(0,2):
        data_id=num+1  # NOTE(review): computed but never used below
        if(num==0):
            y_train=y_train1
            x_train=x_train1
            l=l1
        elif(num==1):
            y_train=y_train2
            x_train=x_train2
            l=l2

        # Per-class index pools for the chosen dataset (x_train itself is
        # not used below -- only the label lists drive the split).
        list_0=find_indices(y_train,0)
        list_1=find_indices(y_train,1)
        list_2=find_indices(y_train,2)

        # Five clients per dataset, each drawing l[c] samples of class c
        # without replacement; pools shrink as clients are served.
        for i in range(0,5):
            train_idxs=[]


            temp=list(np.random.choice(list_0,l[0] , replace = False))
            train_idxs.extend(temp)
            list_0 = list(set(list_0) -set( temp))


            temp=list(np.random.choice(list_1,l[1] , replace = False))
            train_idxs.extend(temp)
            list_1 = list(set(list_1) -set( temp))


            temp=list(np.random.choice(list_2,l[2] , replace = False))
            train_idxs.extend(temp)
            list_2 = list(set(list_2) -set( temp))




            dict_users[client_id]=train_idxs
            dict_users_test[client_id]=test_list

            client_id+=1



    return dict_users, dict_users_test
168 |
169 |
170 |
171 |
172 |
173 |
--------------------------------------------------------------------------------
/utils/split_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import sys
4 | from utils import datasets
5 | import pickle
6 | from torch.utils.data import Dataset, random_split
7 | import pandas as pd
8 | import numpy as np
9 |
class DatasetFromSubset(Dataset):
    """Wrap a torch Subset (or any indexable of (x, y) pairs) so an optional
    transform is applied lazily to x on every access."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        sample, target = self.subset[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.subset)
23 |
24 |
def split_dataset(dataset: str, client_ids: list, datapoints=None, pretrained=False, output_dir='data'):
    """Load `dataset`, cut train and test into equal random shards (one per
    client) and persist each shard under
    {output_dir}/{dataset}/{client_id}/{train,test}/{client_id}.pt.

    Returns (total train-set size, input channel count).
    """
    print('Splitting dataset (may take some time)...', end='')

    num_clients = len(client_ids)
    train_dataset, test_dataset, input_channels = datasets.load_full_dataset(
        dataset, output_dir, num_clients, datapoints, pretrained)

    # Equal-sized shards for every client.
    train_split = [len(train_dataset) // num_clients] * num_clients
    test_split = [len(test_dataset) // num_clients] * num_clients

    train_shards = list(torch.utils.data.random_split(train_dataset, train_split))
    test_shards = list(torch.utils.data.random_split(test_dataset, test_split))

    for cid, shard_train, shard_test in zip(client_ids, train_shards, test_shards):
        client_dir = f'{output_dir}/{dataset}/{cid}'
        os.makedirs(client_dir + '/train', exist_ok=True)
        os.makedirs(client_dir + '/test', exist_ok=True)
        torch.save(shard_train, client_dir + f'/train/{cid}.pt')
        torch.save(shard_test, client_dir + f'/test/{cid}.pt')
    print('Done')

    return len(train_dataset), input_channels
51 |
52 |
def split_dataset_cifar10tl_exp(client_ids: list, datapoints, output_dir='data'):
    """Shard the cifar10_tl train set (`datapoints` samples per client) and
    save one shard per client; every client receives the FULL test set.

    Returns (total train-set size, input channel count).
    """
    print('Splitting dataset (may take some time)...', end='')

    num_clients = len(client_ids)
    train_dataset, test_dataset, input_channels = datasets.load_full_dataset(
        "cifar10_tl", output_dir, num_clients, datapoints)
    shards = list(torch.utils.data.random_split(train_dataset, [datapoints] * num_clients))

    for cid, shard in zip(client_ids, shards):
        base = f'{output_dir}/cifar10_tl/{cid}'
        os.makedirs(base + '/train', exist_ok=True)
        os.makedirs(base + '/test', exist_ok=True)
        torch.save(shard, base + f'/train/{cid}.pt')
        # Same (whole) test set for everyone, for comparable evaluation.
        torch.save(test_dataset, base + f'/test/{cid}.pt')
    print('Done')

    return len(train_dataset), input_channels
71 |
72 |
def split_dataset_cifar_setting2(client_ids: list, train_dataset, test_dataset, u_datapoints = 2000, c_datapoints = 150):
    """Setting-2 split: client 0 holds `u_datapoints` train samples, every
    other client holds `c_datapoints`, all class-balanced; every client
    shares one common random test subset.

    Returns (dict_users_train, dict_users_test): client position ->
    index list (train) / shared index set (test).
    """

    print("Unique datapoints", u_datapoints)
    print("Common datapoints", c_datapoints)

    print('Splitting dataset (may take some time)...', end='')

    num_users = len(client_ids)

    # Shared test subset. Capped at the dataset size: the original
    # hard-coded 2000 crashed np.random.choice on smaller test sets.
    all_idxs_test = list(range(len(test_dataset)))
    test_ids = set(np.random.choice(all_idxs_test, min(2000, len(all_idxs_test)), replace=False))
    dict_users_test = {i: test_ids for i in range(num_users)}

    # u_datapoints are the number of datapoints with the first client;
    # c_datapoints with each of the remaining clients.
    dict_users_train = {i: [] for i in range(num_users)}

    df = pd.DataFrame(list(train_dataset), columns=['images', 'labels'])
    num_of_classes = len(df['labels'].unique())

    per_class_client = int(c_datapoints / num_of_classes)   # per regular client
    per_class_uclient = int(u_datapoints / num_of_classes)  # for client 0
    per_class_total = per_class_client * num_users + per_class_uclient

    dict_classwise = {c: df[df['labels'] == c].index.values.astype(int)[:per_class_total]
                      for c in range(num_of_classes)}

    # Client 0: the large "unique" shard.
    for j in range(num_of_classes):
        temp = list(np.random.choice(dict_classwise[j], per_class_uclient, replace=False))
        dict_users_train[0].extend(temp)
        dict_classwise[j] = list(set(dict_classwise[j]) - set(temp))

    # Remaining clients: small, equally sized shards.
    for i in range(1, num_users):
        for j in range(num_of_classes):
            temp = list(np.random.choice(dict_classwise[j], per_class_client, replace=False))
            dict_users_train[i].extend(temp)
            dict_classwise[j] = list(set(dict_classwise[j]) - set(temp))

    return dict_users_train, dict_users_test
127 |
128 |
129 |
--------------------------------------------------------------------------------
/utils/preprocess_eye_dataset_2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import shutil
5 | import pathlib
6 | import random
7 | import datetime
8 | import cv2
9 | import os
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | import matplotlib.image as mpimg
13 | import pickle
14 | from sklearn.metrics import classification_report, precision_recall_fscore_support, accuracy_score, confusion_matrix
15 | from tqdm import tqdm
16 |
# Eagerly load the EyePACS label file at import time so the class
# distribution can be inspected while the script runs.
dir_path='data/eye_dataset2/eyepacs_preprocess/eyepacs_preprocess'
df_temp = pd.read_csv("data/eye_dataset2/trainLabels.csv")
print(len(df_temp))
# NOTE(review): value_counts() result is discarded -- presumably a
# leftover from interactive/notebook use.
df_temp['level'].value_counts()
21 |
22 |
def load_data():
    """Read the EyePACS training labels and attach file_path, file_name and
    a string `diagnosis` column for every image row."""
    labels = pd.read_csv("data/eye_dataset2/trainLabels.csv")
    image_dir = os.path.join('data/eye_dataset2/eyepacs_preprocess/eyepacs_preprocess')

    labels['file_path'] = labels['image'].map(
        lambda stem: os.path.join(image_dir, '{}.jpeg'.format(stem)))
    labels['file_name'] = labels["image"].apply(lambda stem: stem + ".jpeg")
    labels['diagnosis'] = labels['level'].astype(str)

    return labels
39 |
40 |
41 | def wiener_filter(img, kernel, K):
42 | kernel /= np.sum(kernel)
43 | dummy = np.copy(img)
44 | dummy = np.fft.fft2(dummy)
45 | kernel = np.fft.fft2(kernel, s = img.shape)
46 | kernel = np.conj(kernel) / (np.abs(kernel) ** 2 + K)
47 | dummy = dummy * kernel
48 | dummy = np.abs(np.fft.ifft2(dummy))
49 | return dummy
50 |
def gaussian_kernel(kernel_size = 3):
    """Return a normalized 2-D Gaussian kernel of shape
    (kernel_size, kernel_size) with std = kernel_size / 3.

    Fix: the original called a bare `gaussian` that is never imported or
    defined anywhere in this file, so every call raised NameError; use the
    SciPy Gaussian window it evidently intended.
    """
    from scipy.signal.windows import gaussian  # local import: only needed here
    h = gaussian(kernel_size, kernel_size / 3).reshape(kernel_size, 1)
    h = np.dot(h, h.transpose())   # outer product -> separable 2-D kernel
    h /= np.sum(h)                 # normalize to unit sum
    return h
56 |
def isbright(image, dim=227, thresh=0.4):
    """Heuristic brightness check: True when the mean of the max-normalized
    L channel (LAB space) of a downscaled copy exceeds `thresh`."""
    small = cv2.resize(image, (dim, dim))
    lightness, _, _ = cv2.split(cv2.cvtColor(small, cv2.COLOR_BGR2LAB))
    normalized = lightness / np.max(lightness)
    return np.mean(normalized) > thresh
66 |
67 |
def crop_image_from_gray(img, tol=7):
    """Trim border rows/columns whose pixels are all <= tol (near-black).

    2-D input is cropped directly; 3-D (color) input is cropped using a
    mask built from its grayscale version. If the mask would reject every
    pixel (image too dark), the original image is returned untouched.
    """
    if img.ndim == 2:
        mask = img > tol
        return img[np.ix_(mask.any(1), mask.any(0))]
    elif img.ndim == 3:
        mask = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) > tol
        rows, cols = mask.any(1), mask.any(0)

        if img[:, :, 0][np.ix_(rows, cols)].shape[0] == 0:
            # Cropping removed everything: keep the original image.
            return img
        else:
            channels = [img[:, :, ch][np.ix_(rows, cols)] for ch in range(3)]
            return np.stack(channels, axis=-1)
87 |
88 |
def circle_crop(img, sigmaX):
    """Circular-crop a fundus photo around its centre, resize to 224x224
    and apply Gaussian-blur-based contrast enhancement (addWeighted)."""
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = crop_image_from_gray(img)
    height, width, _ = img.shape

    centre_x = int(width / 2)
    centre_y = int(height / 2)
    radius = np.amin((centre_x, centre_y))

    # Binary circular mask centred on the image.
    mask = np.zeros((height, width), np.uint8)
    cv2.circle(mask, (centre_x, centre_y), int(radius), 1, thickness=-1)

    img = cv2.bitwise_and(img, img, mask=mask)
    img = crop_image_from_gray(img)
    img = cv2.resize(img, (224, 224))
    img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), sigmaX), -4, 128)

    return img
110 |
111 |
112 |
def image_preprocessing(img):
    """Blood-vessel extraction pipeline.

    CLAHE-enhances the green channel, fuses back to grayscale, removes
    isolated pixels, applies mean-C adaptive thresholding and returns the
    cleaned binary map stacked into 3 channels as float64.
    """
    # 1. Work on an 8-bit copy of the input.
    img = img.astype(np.uint8)

    # 2-3. Enhance the green channel (best vessel contrast), then fuse the
    # channels back together and convert to grayscale.
    b, g, r = cv2.split(img)
    clh = cv2.createCLAHE(clipLimit=4.0)
    g = clh.apply(g)
    merged_bgr_green_fused = cv2.merge((b, g, r))
    img_bw = cv2.cvtColor(merged_bgr_green_fused, cv2.COLOR_BGR2GRAY)

    # 4. Morphological opening removes isolated pixels.
    morph_open = cv2.morphologyEx(img_bw, cv2.MORPH_OPEN, np.ones((1, 1), np.uint8))

    # 5. Mean-C adaptive threshold pulls out the vessel tree; a second
    # opening cleans the binary result.
    thresh = cv2.adaptiveThreshold(morph_open, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                   cv2.THRESH_BINARY_INV, 9, 5)
    morph_open2 = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8))

    # 6. Replicate the single-channel map into 3 channels.
    return np.stack((morph_open2,) * 3, axis=-1).astype("float64")
142 |
143 |
# ---- Script body: sample 5000 of the 35k EyePACS images, preprocess each
# with circle_crop, collapse the 5 DR grades into 3 classes and pickle the
# result for utils/get_eye_dataset.py to consume. ----
df_train=load_data()
print(df_train['diagnosis'].value_counts())
X_train=[]
Y_train=[]
count_0, count_1, count_2=0,0,0
print(type(df_train.diagnosis.iloc[0]))

idxs=list(np.random.choice( list(range(0,35000)),5000, replace = False))


for idx in tqdm(range(0,len(idxs))):
    i=idxs[idx]
    img = cv2.imread(df_train.file_path.iloc[i])
    # print(type(img))

    img = circle_crop(img,sigmaX=10)
    X_train.append(img)

    temp=[]

    # Grade mapping: 0 -> healthy, {1,2} -> mild/moderate, {3,4} -> severe.
    # NOTE(review): if a label outside 0-4 ever appeared, `ans` would keep
    # its value from the previous iteration and be appended again -- this
    # relies on the label column being clean.
    if(int(df_train.diagnosis.iloc[i])==0):
        count_0+=1
        ans=0
    elif(int(df_train.diagnosis.iloc[i])==1 or int(df_train.diagnosis.iloc[i])==2):
        count_1+=1
        ans=1
    elif(int(df_train.diagnosis.iloc[i])==3 or int(df_train.diagnosis.iloc[i])==4):
        count_2+=1
        ans=2
    Y_train.append(ans)

# Persist the processed images and 3-class labels for the data loaders.
with open('data/x_train_eye_2', 'wb') as pickle_file:
    pickle.dump(X_train, pickle_file)
with open('data/y_train_eye_2', 'wb') as pickle_file:
    pickle.dump(Y_train, pickle_file)

# Final per-class counts of the sampled subset.
print(count_0)
print(count_1)
print(count_2)
183 |
184 |
--------------------------------------------------------------------------------
/profiler.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import resnet50
2 | from thop import profile
3 | import torch
4 | from torch import nn
5 | from torchvision import transforms
6 | from torch.utils.data import DataLoader, Dataset
7 | from pandas import DataFrame
8 | import pandas as pd
9 | from sklearn.model_selection import train_test_split
10 | from PIL import Image
11 | from glob import glob
12 | import math
13 | import random
14 | import numpy as np
15 | import os
16 | import matplotlib
17 | matplotlib.use('Agg')
18 | import matplotlib.pyplot as plt
19 | import copy
20 | import argparse
21 | from utils import datasets,dataset_settings
22 | import time
23 | import torch.nn.functional as F
24 |
25 |
26 | #MODELS
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 padded convolution without bias (the basic ResNet building brick)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
31 |
32 |
class BasicBlock(nn.Module):
    """Two-convolution ResNet residual block (identity shortcut unless a
    `downsample` module is supplied to match shapes)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main path: conv-bn-relu, conv-bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Residual addition followed by the final activation.
        return self.relu(out + shortcut)
63 |
class ResNet18(nn.Module):
    """Full ResNet-style classifier parameterized by a residual `block`
    type and per-stage block counts (e.g. [2, 2, 2, 2] for ResNet-18)."""

    def __init__(self, block, layers, input_channels, num_classes=1000):
        self.inplanes = 64
        super(ResNet18, self).__init__()
        # Stem: 7x7 conv + BN + ReLU + 3x3 max-pool.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; stages 2-4 halve the spatial resolution.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-style init for convs, constant init for batch-norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; only the first may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut to match channel count / resolution.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(x.view(x.size(0), -1))
122 |
123 |
class ResNet18_client_side(nn.Module):
    """Client-side front of a split ResNet-18: the stem plus the first
    residual block (identity shortcut, so no downsampling needed)."""

    def __init__(self, input_channels):
        super(ResNet18_client_side, self).__init__()
        # Stem: 7x7 conv + BN + ReLU + 3x3 max-pool.
        self.layer1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # First residual block's main path (shortcut added in forward).
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
        )

        # He-style init for convs, constant init for batch-norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        stem_out = F.relu(self.layer1(x))
        # Identity shortcut: same shape in and out, no projection needed.
        residual_sum = self.layer2(stem_out) + stem_out
        return F.relu(residual_sum)
156 |
157 |
158 |
159 |
160 |
#setting 1 Conf
# 150 datapoints per client, 64 batch size


#FL
FL = ResNet18(BasicBlock, [2, 2, 2, 2],3, 10) #last two params are input_channels and number of classes
#SL
SL = ResNet18_client_side(3)
#
SFLv1 = ResNet18_client_side(3)
#
SFLv2 = ResNet18_client_side(3)




# Profile one batch of 64 RGB 224x224 images through the full model (FL)
# and the client-side front (SL). NOTE(review): `input` shadows the
# builtin -- left as-is, documentation-only change. SFLv1/SFLv2 are built
# above but never profiled here.
input = torch.randn(64,3,224,224)
macs_client_FL, params_FL = profile(FL, inputs=(input, ))
macs_client_SL, params_SL = profile(SL, inputs=(input, ))

# thop reports MACs; x2 converts to FLOPs, /1e9 scales to GFLOPs.
print(f"GFLOPS FL {((2 * macs_client_FL) / (10**9))} PARAMS: {params_FL}")
print(f"GFLOPS SL {((2 * macs_client_SL) / (10**9))} PARAMS: {params_SL}")
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
--------------------------------------------------------------------------------
/client.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | import multiprocessing
5 | from threading import Thread
6 | from utils.connections import is_socket_closed
7 | from utils.connections import send_object
8 | from utils.connections import get_object
9 | from utils.split_dataset import DatasetFromSubset
10 | import pickle
11 | import queue
12 | import struct
13 |
14 |
class Client(Thread):
    """One simulated split-learning client.

    Owns the front (input-side) and back (output-side) halves of the split
    model; the center portion runs on the server, reached through the
    multiprocessing Pipe stored in `self.socket`. A training step is:
    forward_front -> send_remote_activations1 -> (server center) ->
    get_remote_activations2 -> forward_back -> calculate_loss ->
    backward_back -> send_remote_activations2_grads ->
    get_remote_activations1_grads -> backward_front, then the step_* /
    zero_grad_* optimizer helpers.
    """

    def __init__(self, id, *args, **kwargs):
        super(Client, self).__init__(*args, **kwargs)
        self.id = id
        self.front_model = []            # input-side model half (filled by get_model)
        self.back_model = []             # output-side model half (filled by get_model)
        self.losses = []
        self.train_dataset = None
        self.test_dataset = None
        self.train_DataLoader = None
        self.test_DataLoader = None
        self.socket = None               # client end of the Pipe to the server
        self.server_socket = None        # server end, handed to the server side
        self.train_batch_size = None
        self.test_batch_size = None
        self.iterator = None             # iterator over the active DataLoader
        self.activations1 = None         # front-model output (attached to graph)
        self.remote_activations1 = None  # detached, grad-enabled copy for the server
        self.outputs = None              # back-model predictions
        self.loss = None
        self.criterion = None
        self.data = None                 # current input batch
        self.targets = None              # current label batch
        self.n_correct = 0
        self.n_samples = 0
        self.front_optimizer = None
        self.back_optimizer = None
        self.train_acc = []
        self.test_acc = []
        self.front_epsilons = []
        self.front_best_alphas = []
        self.pred=[]                     # accumulated test predictions (for metrics)
        self.y=[]                        # accumulated test ground truths
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.device = torch.device('cpu')


    def backward_back(self):
        """Backpropagate the loss through the back model half."""
        self.loss.backward()



    def backward_front(self):
        """Continue backprop through the front half, seeded with the
        gradient the server returned for the detached activations."""
        self.activations1.backward(self.remote_activations1.grad)


    def calculate_loss(self):
        """Compute cross-entropy between back-model outputs and targets."""
        self.criterion = F.cross_entropy
        self.loss = self.criterion(self.outputs, self.targets)


    def calculate_test_acc(self):
        """Return batch test accuracy (%) and accumulate predictions and
        labels in self.pred / self.y for later aggregate metrics."""
        with torch.no_grad():
            _, self.predicted = torch.max(self.outputs.data, 1)
            self.n_correct = (self.predicted == self.targets).sum().item()
            self.n_samples = self.targets.size(0)
            self.pred.extend(self.predicted.cpu().detach().numpy().tolist())
            self.y.extend(self.targets.cpu().detach().numpy().tolist())
            # self.test_acc.append(100.0 * self.n_correct/self.n_samples)
            return 100.0 * self.n_correct/self.n_samples
            # print(f'Acc: {self.test_acc[-1]}')


    def calculate_train_acc(self):
        """Return batch train accuracy (%) for the current outputs."""
        with torch.no_grad():
            _, self.predicted = torch.max(self.outputs.data, 1)
            self.n_correct = (self.predicted == self.targets).sum().item()
            self.n_samples = self.targets.size(0)
            # self.train_acc.append(100.0 * self.n_correct/self.n_samples)
            return 100.0 * self.n_correct/self.n_samples
            # print(f'Acc: {self.train_acc[-1]}')


    def connect_server(self, host='localhost', port=8000, BUFFER_SIZE=4096):
        """Create the client<->server channel.

        NOTE(review): despite the socket-style signature, this simulation
        connects via a multiprocessing Pipe; host, port and BUFFER_SIZE are
        ignored (host is only echoed in the log line).
        """
        self.socket, self.server_socket = multiprocessing.Pipe()
        print(f"[*] Client {self.id} connecting to {host}")


    def create_DataLoader(self, train_batch_size, test_batch_size):
        """Build shuffled train/test DataLoaders over the loaded datasets."""
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.train_DataLoader = torch.utils.data.DataLoader(dataset=self.train_dataset,
                                                batch_size=self.train_batch_size,
                                                shuffle=True)
        self.test_DataLoader = torch.utils.data.DataLoader(dataset=self.test_dataset,
                                                batch_size=self.test_batch_size,
                                                shuffle=True)


    def disconnect_server(self) -> bool:
        """Close the server channel; True if it was open, False otherwise."""
        if not is_socket_closed(self.socket):
            self.socket.close()
            return True
        else:
            return False


    def forward_back(self):
        """Run the server-returned activations through the back model."""
        self.back_model.to(self.device)
        self.outputs = self.back_model(self.remote_activations2)


    def forward_front(self):
        """Pull the next batch and run it through the front model, keeping
        a detached, grad-enabled copy of the activations for the server."""
        self.data, self.targets = next(self.iterator)
        self.data, self.targets = self.data.to(self.device), self.targets.to(self.device)
        self.front_model.to(self.device)
        self.activations1 = self.front_model(self.data)
        self.remote_activations1 = self.activations1.detach().requires_grad_(True)


    # def getModel(self):
    #     self.onThread(self._getModel)


    def get_model(self):
        """Receive the initial front/back model halves from the server."""
        model = get_object(self.socket)
        self.front_model = model['front']
        self.back_model = model['back']


    def get_remote_activations1_grads(self):
        """Receive the gradient w.r.t. the activations this client sent."""
        self.remote_activations1.grad = get_object(self.socket)


    def get_remote_activations2(self):
        """Receive the center-model output activations from the server."""
        self.remote_activations2 = get_object(self.socket)


    def idle(self):
        """Do nothing (placeholder for when there is no queued work)."""
        pass


    def load_data(self, dataset, transform):
        """Load this client's train/test shards from disk and wrap them so
        `transform` is applied per sample."""
        # NOTE(review): os.path.join will not raise here, so this except
        # branch is effectively dead -- kept as-is (documentation-only edit).
        try:
            dataset_path = os.path.join(f'data/{dataset}/{self.id}')
        except:
            raise Exception(f'Dataset not found for client {self.id}')
        self.train_dataset = torch.load(f'{dataset_path}/train/{self.id}.pt')
        self.test_dataset = torch.load(f'{dataset_path}/test/{self.id}.pt')

        self.train_dataset = DatasetFromSubset(
            self.train_dataset, transform=transform
        )
        self.test_dataset = DatasetFromSubset(
            self.test_dataset, transform=transform
        )


    # def onThread(self, function, *args, **kwargs):
    #     self.q.put((function, args, kwargs))


    # def run(self, *args, **kwargs):
    #     super(Client, self).run(*args, **kwargs)
    #     while True:
    #         try:
    #             function, args, kwargs = self.q.get(timeout=self.timeout)
    #             function(*args, **kwargs)
    #         except queue.Empty:
    #             self.idle()


    def send_remote_activations1(self):
        """Ship the detached front-model activations to the server."""
        send_object(self.socket, self.remote_activations1)


    def send_remote_activations2_grads(self):
        """Ship the back-model's input gradients back to the server."""
        send_object(self.socket, self.remote_activations2.grad)


    def step_front(self):
        """Apply one optimizer step to the front model parameters."""
        self.front_optimizer.step()


    def step_back(self):
        """Apply one optimizer step to the back model parameters."""
        self.back_optimizer.step()


    # def train_model(self):
    #     forward_front_model()
    #     send_activations_to_server()
    #     forward_back_model()
    #     loss_calculation()
    #     backward_back_model()
    #     send_gradients_to_server()
    #     backward_front_model()


    def zero_grad_front(self):
        """Clear accumulated gradients on the front optimizer."""
        self.front_optimizer.zero_grad()


    def zero_grad_back(self):
        """Clear accumulated gradients on the back optimizer."""
        self.back_optimizer.zero_grad()
210 |
--------------------------------------------------------------------------------
/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | from torchvision import models
6 | from torch.utils.data import Dataset, DataLoader
7 | from glob import glob
8 | import os
9 | import pandas as pd
10 | import sklearn
11 | from sklearn.model_selection import train_test_split
12 | from PIL import Image
13 |
def MNIST(path, transform_train=None, transform_test=None):
    """Load the MNIST train and test datasets, downloading if necessary.

    Args:
        path: root directory where the dataset is stored/downloaded.
        transform_train: optional training transform; defaults to a
            224x224 resize + single-channel ImageNet-style normalization
            (transfer-learning setup).
        transform_test: optional test transform; same default.

    Returns:
        (train_dataset, test_dataset) torchvision datasets.
    """
    if transform_train is None:
        # Transfer-learning default transform.
        print("TL transform MNIST")  # fixed: message previously said "Cifar"
        transform_train = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485],
                                 std=[0.229])
        ])
    if transform_test is None:
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485],
                                 std=[0.229])])

    train_dataset = torchvision.datasets.MNIST(root=path,
                                               train=True,
                                               transform=transform_train, download=True)

    test_dataset = torchvision.datasets.MNIST(root=path,
                                              train=False,
                                              transform=transform_test, download=True)
    return train_dataset, test_dataset
41 |
42 |
def CIFAR10(path, transform_train=None, transform_test=None):
    """Load the CIFAR-10 train and test datasets, downloading if necessary.

    Args:
        path: root directory for download/storage.
        transform_train: optional training transform; defaults to the
            standard CIFAR-10 augmentation (random crop + horizontal flip).
        transform_test: optional test transform; defaults to a plain,
            deterministic ToTensor + normalization.

    Returns:
        (train_dataset, test_dataset) torchvision datasets.
    """
    if transform_train is None:
        print("No Tl Cifar10")
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

    if transform_test is None:
        # Fixed: the default test transform previously applied RandomCrop,
        # i.e. random augmentation at evaluation time, which makes test
        # accuracy noisy and non-reproducible. Evaluation is deterministic.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

    train_dataset = torchvision.datasets.CIFAR10(root=path,
                                                 train=True,
                                                 transform=transform_train, download=True)

    test_dataset = torchvision.datasets.CIFAR10(root=path,
                                                train=False,
                                                transform=transform_test, download=True)
    return train_dataset, test_dataset
70 |
def CIFAR10_iid(num_clients, datapoints, path, transform_train, transform_test):
    """Build reduced i.i.d. CIFAR-10 splits for transfer-learning experiments.

    Args:
        num_clients: number of participating clients.
        datapoints: training samples allotted per client.
        path: dataset root directory.
        transform_train / transform_test: optional transforms; default to
            224x224 ImageNet-style transfer-learning transforms.

    Returns:
        (new_train_dataset, new_test_dataset, class2idx dict, idx2class dict).

    Raises:
        ValueError: if num_clients * datapoints exceeds the dataset size.
    """
    # These transforms match ImageNet configuration since TL is being used.
    if transform_train is None:
        # TL default
        print("TL transform Cifar")
        transform_train = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    if transform_test is None:
        # TL default
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])

    train_dataset = torchvision.datasets.CIFAR10(root=path,
                                                 train=True,
                                                 transform=transform_train, download=True)

    test_dataset = torchvision.datasets.CIFAR10(root=path,
                                                train=False,
                                                transform=transform_test, download=True)

    class2idx = train_dataset.class_to_idx.items()
    idx2class = {v: k for k, v in train_dataset.class_to_idx.items()}

    new_train_dataset_size = num_clients * datapoints
    temp = len(train_dataset) - new_train_dataset_size
    # Fail early with a clear message instead of letting random_split raise.
    if temp < 0:
        raise ValueError(
            f'Requested {new_train_dataset_size} training samples but only '
            f'{len(train_dataset)} are available')

    print(len(train_dataset), new_train_dataset_size, temp)

    new_train_dataset, _ = torch.utils.data.random_split(train_dataset, (new_train_dataset_size, temp))
    # Keep 2k test datapoints with each client; derive the remainder from the
    # actual test-set length instead of hard-coding 8000.
    new_test_dataset, _ = torch.utils.data.random_split(test_dataset, (2000, len(test_dataset) - 2000))

    return new_train_dataset, new_test_dataset, dict(class2idx), idx2class
113 |
114 |
def cifar10_setting3(num_clients, unique_datapoint, c_datapoints, path, transform_train, transform_test):
    """Build the setting-3 CIFAR-10 splits: one large client + small clients.

    Args:
        num_clients: total number of clients (1 large + num_clients-1 small).
        unique_datapoint: training samples for the single large client.
        c_datapoints: training samples for each remaining small client.
        path: dataset root directory.
        transform_train / transform_test: optional transforms; default to
            224x224 ImageNet-style transfer-learning transforms.

    Returns:
        (unique_train_dataset, common_train_dataset, new_test_dataset,
         class2idx dict, idx2class dict).

    Raises:
        ValueError: if the requested split exceeds the dataset size.
    """
    if transform_train is None:
        # TL default
        print("TL transform Cifar")
        transform_train = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    if transform_test is None:
        # TL default
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])

    train_dataset = torchvision.datasets.CIFAR10(root=path,
                                                 train=True,
                                                 transform=transform_train, download=True)

    test_dataset = torchvision.datasets.CIFAR10(root=path,
                                                train=False,
                                                transform=transform_test, download=True)

    class2idx = train_dataset.class_to_idx.items()
    idx2class = {v: k for k, v in train_dataset.class_to_idx.items()}

    # e.g. 2000 unique samples for the large client plus 150 for each of the
    # (num_clients - 1) small clients.
    new_train_dataset_size = unique_datapoint + ((num_clients - 1) * c_datapoints)
    temp = len(train_dataset) - new_train_dataset_size
    # Fail early with a clear message instead of letting random_split raise.
    if temp < 0:
        raise ValueError(
            f'Requested {new_train_dataset_size} training samples but only '
            f'{len(train_dataset)} are available')

    print(len(train_dataset), new_train_dataset_size, temp)

    new_train_dataset, _ = torch.utils.data.random_split(train_dataset, (new_train_dataset_size, temp))
    unique_train_dataset, common_train_dataset = torch.utils.data.random_split(
        new_train_dataset, (unique_datapoint, new_train_dataset_size - unique_datapoint))
    # 1100 shared test points; derive the remainder from the actual test-set
    # length instead of hard-coding 8900.
    new_test_dataset, _ = torch.utils.data.random_split(test_dataset, (1100, len(test_dataset) - 1100))

    return unique_train_dataset, common_train_dataset, new_test_dataset, dict(class2idx), idx2class
158 |
159 |
160 |
def FashionMNIST(path, transform_train=None, transform_test=None):
    """Load the FashionMNIST train and test datasets, downloading if needed.

    Args:
        path: root directory where the dataset is stored/downloaded.
        transform_train: optional training transform; defaults to a
            224x224 resize + single-channel ImageNet-style normalization.
        transform_test: optional test transform; same default.

    Returns:
        (train_dataset, test_dataset) torchvision datasets.
    """
    if transform_train is None:
        # Transfer-learning default transform.
        print("TL transform FashionMNIST")  # fixed: message previously said "Cifar"
        transform_train = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485],
                                 std=[0.229])
        ])
    if transform_test is None:
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485],
                                 std=[0.229])])

    train_dataset = torchvision.datasets.FashionMNIST(root=path, train=True, transform=transform_train, download=True)
    test_dataset = torchvision.datasets.FashionMNIST(root=path, train=False, transform=transform_test, download=True)

    return train_dataset, test_dataset
185 |
186 |
class HAM10000(Dataset):
    """Skin-lesion dataset backed by a dataframe with 'path' and 'labels' columns."""

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        # One sample per dataframe row.
        return len(self.df)

    def __getitem__(self, index):
        # Read the image from disk and look up its integer label.
        image = Image.open(self.df['path'][index])
        label = torch.tensor(int(self.df['labels'][index]))

        # Apply the optional transform pipeline to the image only.
        image = self.transform(image) if self.transform else image

        return image, label
204 |
205 |
def get_duplicates(x):
    """Label a lesion_id as 'unduplicated' if it appears in the global df_undup.

    df_undup is a module-level dataframe populated by HAM10000_data().
    """
    known_unique = list(df_undup['lesion_id'])
    return 'unduplicated' if x in known_unique else 'duplicated'
212 |
def get_test_rows(x):
    """Return 'test' if image_id x belongs to the held-out df_test, else 'train'.

    df_test is a module-level dataframe populated by HAM10000_data().
    """
    held_out = list(df_test['image_id'])
    return 'test' if str(x) in held_out else 'train'
220 |
def HAM10000_data(path):
    """Build HAM10000 train/test datasets from metadata + images under *path*.

    Args:
        path: root directory containing the HAM10000_metadata CSV and an
            images/ subdirectory of .jpg files.

    Returns:
        (train_dataset, test_dataset) HAM10000 instances.

    Side effects:
        Sets the module-level globals df_undup, df_test and df_train, which
        are read by get_duplicates() and get_test_rows().
    """
    # Fixed: the directory was previously hard-coded to a developer's home
    # path ('/home/manas/...'), silently ignoring the `path` argument.
    data_dir = path
    images_dir = data_dir + "/images"
    all_image_path = glob(os.path.join(images_dir, '*.jpg'))
    print("gnc", len(all_image_path))
    # Map image id -> file path, e.g. "ISIC_0024306" -> ".../images/ISIC_0024306.jpg"
    imageid_path_dict = {os.path.splitext(os.path.basename(x))[0]: x for x in all_image_path}

    lesion_type_dict = {
        'nv': 'Melanocytic nevi',
        'mel': 'dermatofibroma',
        'bkl': 'Benign keratosis-like lesions ',
        'bcc': 'Basal cell carcinoma',
        'akiec': 'Actinic keratoses',
        'vasc': 'Vascular lesions',
        'df': 'Dermatofibroma'
    }

    df_original = pd.read_csv(data_dir + '/HAM10000_metadata')
    df_original['path'] = df_original['image_id'].map(imageid_path_dict.get)
    df_original['cell_type'] = df_original['dx'].map(lesion_type_dict.get)
    df_original['labels'] = pd.Categorical(df_original['cell_type']).codes

    global df_undup
    df_undup = df_original.groupby('lesion_id').count()
    # Keep only lesion_id's that have exactly one image associated with them.
    df_undup = df_undup[df_undup['image_id'] == 1]
    df_undup.reset_index(inplace=True)
    df_original['duplicates'] = df_original['lesion_id']
    # Tag every row as 'unduplicated'/'duplicated' via the module-level helper.
    df_original['duplicates'] = df_original['duplicates'].apply(get_duplicates)
    df_undup = df_original[df_original['duplicates'] == 'unduplicated']
    y = df_undup['labels']
    global df_test, df_train
    # Hold out a stratified 20% of the unduplicated lesions as the test set.
    _, df_test = train_test_split(df_undup, test_size=0.2, random_state=101, stratify=y)

    df_original['train_or_test'] = df_original['image_id']
    df_original['train_or_test'] = df_original['train_or_test'].apply(get_test_rows)
    # Everything not held out above becomes training data.
    df_train = df_original[df_original['train_or_test'] == 'train']

    # Cap split sizes so runs are comparable across experiments.
    df_train = df_train[:8910]
    df_test = df_test[:1100]

    df_train = df_train.reset_index()
    df_test = df_test.reset_index()

    norm_mean = [0.763038, 0.54564667, 0.57004464]
    norm_std = [0.14092727, 0.15261286, 0.1699712]
    train_transform = transforms.Compose([transforms.Resize((64, 64)), transforms.RandomHorizontalFlip(),
                                          transforms.RandomVerticalFlip(), transforms.RandomRotation(20),
                                          transforms.ColorJitter(brightness=0.1, contrast=0.1, hue=0.1),
                                          transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)])
    test_transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor(),
                                         transforms.Normalize(norm_mean, norm_std)])
    train_dataset = HAM10000(df_train, transform=train_transform)
    test_dataset = HAM10000(df_test, transform=test_transform)

    return train_dataset, test_dataset
280 |
281 |
282 |
def load_full_dataset(dataset, dataset_path, num_clients, datapoints=None, pretrained=False, transform_train=None, transform_test=None):
    """Dispatch to the right dataset loader by name.

    Args:
        dataset: one of 'mnist', 'cifar10', 'ham10000', 'fmnist',
            'cifar10_tl', 'cifar10_setting3'.
        dataset_path: root directory passed to the underlying loader.
        num_clients: number of clients (used by the reduced/iid variants).
        datapoints: per-client sample count (used by 'cifar10_tl').
        pretrained: if True, use ImageNet-style transforms for 'cifar10'.
        transform_train / transform_test: optional transform overrides.

    Returns:
        (train_dataset, test_dataset, input_channels) for most datasets;
        (unique_train, common_train, test_dataset, input_channels) for
        'cifar10_setting3'.

    Raises:
        ValueError: if `dataset` is not a recognized name.
    """
    if dataset == 'mnist':
        train_dataset, test_dataset = MNIST(dataset_path)
        input_channels = 1
    elif dataset == 'cifar10':
        if pretrained:
            # ImageNet-style transforms for transfer learning.
            transform_train = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            transform_test = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])
            train_dataset, test_dataset = CIFAR10(dataset_path, transform_train, transform_test)
        else:
            train_dataset, test_dataset = CIFAR10(dataset_path)
        input_channels = 3
    elif dataset == 'ham10000':
        train_dataset, test_dataset = HAM10000_data(dataset_path)
        input_channels = 3
    elif dataset == "fmnist":
        train_dataset, test_dataset = FashionMNIST(dataset_path)
        input_channels = 1
    elif dataset == "cifar10_tl":
        # Reduced training dataset of num_clients * datapoints samples.
        train_dataset, test_dataset, ci, ic = CIFAR10_iid(num_clients, datapoints, dataset_path, transform_train, transform_test)
        input_channels = 3
    elif dataset == "cifar10_setting3":
        u_datapoints = 2000
        c_datapoints = 150
        input_channels = 3
        # Setting 3 returns an extra split (unique + common training data).
        u_train_dataset, c_train_dataset, test_dataset, _, _ = cifar10_setting3(num_clients, u_datapoints, c_datapoints, dataset_path, transform_train, transform_test)
        return u_train_dataset, c_train_dataset, test_dataset, input_channels
    else:
        # Fixed: an unknown name previously fell through to an
        # UnboundLocalError on the return statement.
        raise ValueError(f'Unknown dataset: {dataset!r}')

    return train_dataset, test_dataset, input_channels
338 |
339 |
340 |
341 |
342 |
343 |
344 | # if __name__ == "__main__":
345 | # train_dataset, test_dataset = MNIST('../data')
346 | # print(type(train_dataset))
347 |
--------------------------------------------------------------------------------
/PFSL.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import string
4 | import socket
5 | import requests
6 | import sys
7 | import threading
8 | import time
9 | import torch
10 | from math import ceil
11 | from torchvision import transforms
12 | from torch.utils.data import DataLoader, Dataset
13 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp
14 | from utils.client_simulation import generate_random_clients
15 | from utils.connections import send_object
16 | from utils.arg_parser import parse_arguments
17 | import matplotlib.pyplot as plt
18 | import time
19 | import multiprocessing
20 | from sklearn.metrics import classification_report
21 | import torch.optim as optim
22 | import copy
23 | from datetime import datetime
24 | from scipy.interpolate import make_interp_spline
25 | import numpy as np
26 | from ConnectedClient import ConnectedClient
27 | import importlib
28 | from utils.merge import merge_grads, merge_weights
29 | import pandas as pd
30 | import time
31 | from utils import dataset_settings, datasets
32 | import torch.nn.functional as F
33 |
34 |
# To initialize every client with their train and test data
def initialize_client(client, dataset, batch_size, test_batch_size, transform):
    """Load a client's dataset and build its train/test DataLoaders.

    Fixed: the parameter was named `tranform` (typo) and was never used; the
    body silently picked up the module-level `transform` global instead. The
    argument (same position) is now actually passed through.
    """
    client.load_data(dataset, transform)
    print(f'Length of train dataset client {client.id}: {len(client.train_dataset)}')
    client.create_DataLoader(batch_size, test_batch_size)
41 |
42 |
# Plots class distribution of train data available to each client
def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids):
    """Plot per-client label histograms and an overlaid line graph.

    Args:
        clients: dict mapping client_id -> client object with a train_dataset
            of (image, label) pairs.
        dataset, batch_size, epochs, opt: experiment metadata (only used in
            the commented-out savefig filenames).
        client_ids: list of client ids; at most 20 clients are plotted.

    Returns:
        dict mapping client_id -> pandas Series of per-label frequencies.
    """
    class_distribution=dict()
    number_of_clients=len(client_ids)
    # Cap the number of plotted clients at 20 to keep the figure readable.
    if(len(clients)<=20):
        plot_for_clients=client_ids
    else:
        plot_for_clients=random.sample(client_ids, 20)

    # 5 columns, enough rows to fit one subplot per client id.
    fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10))
    j=0
    i=0

    # Plot histogram: label frequencies, one subplot per client.
    for client_id in plot_for_clients:
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        class_distribution[client_id]=df['labels'].value_counts().sort_index()
        df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id)
        j+=1
        # NOTE(review): j is reset to 0 as soon as it reaches 5, so the
        # j==10 / j==15 conditions below can never be true.
        if(j==5 or j==10 or j==15):
            i+=1
            j=0
    fig.tight_layout()
    plt.show()
    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png')
    plt.savefig('plot_setting3_exp.png')

    max_len=0
    # Plot line graphs: overlaid label-frequency curves for all plotted clients.
    for client_id in plot_for_clients:
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id)
        # Track the highest class frequency seen so far, to scale the y-axis.
        max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0])
    plt.xticks(np.arange(0,10))
    plt.ylim(0, max_len)
    plt.legend()
    plt.show()
    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png')

    return class_distribution
83 |
84 |
85 |
if __name__ == "__main__":

    # ---- Experiment setup -------------------------------------------------
    args = parse_arguments()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Arguments provided", args)

    # Seed both Python's and torch's RNGs for reproducible splits/sampling.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Per-epoch accuracy averages across all clients.
    overall_test_acc = []
    overall_train_acc = []

    print('Generating random clients...', end='')
    clients = generate_random_clients(args.number_of_clients)
    client_ids = list(clients.keys())
    print('Done')

    # Split the chosen dataset among the generated clients.
    train_dataset_size, input_channels = split_dataset(args.dataset, client_ids, pretrained=args.pretrained)

    print(f'Random client ids:{str(client_ids)}')
    transform=None
    max_epoch=0
    max_f1=0

    # Assigning train and test data to each client.
    print('Initializing clients...')

    for _, client in clients.items():
        (initialize_client(client, args.dataset, args.batch_size, args.test_batch_size, transform))

    print('Client Intialization complete.')
    # Train and test data intialisation complete

    # class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids)

    # Assigning front, center and back models and their optimizers for all the clients.
    model = importlib.import_module(f'models.{args.model}')

    for _, client in clients.items():
        client.front_model = model.front(input_channels, pretrained=args.pretrained)
        client.back_model = model.back(pretrained=args.pretrained)
    print('Done')

    for _, client in clients.items():
        # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9)
        # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9)
        client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr)
        client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr)

    # NOTE(review): iteration counts are derived from the first client only —
    # this assumes every client holds the same number of samples; confirm.
    first_client = clients[client_ids[0]]
    num_iterations = ceil(len(first_client.train_DataLoader.dataset)/args.batch_size)

    num_test_iterations= ceil(len(first_client.test_DataLoader.dataset)/args.test_batch_size)
    sc_clients = {} #server copy clients

    for iden in client_ids:
        sc_clients[iden] = ConnectedClient(iden, None)

    # Each server-side copy holds the shared center model segment.
    for _,s_client in sc_clients.items():
        s_client.center_model = model.center(pretrained=args.pretrained)
        s_client.center_model.to(device)
        # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9)
        s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr)

    st = time.time()

    macro_avg_f1_2classes=[]

    criterion=F.cross_entropy

    # ---- Training ---------------------------------------------------------
    for epoch in range(args.epochs):
        if(epoch==args.checkpoint): # start of the personalisation phase: freeze all layers of the center model
            for _, s_client in sc_clients.items():
                s_client.center_model.freeze(epoch, pretrained=True)

        overall_train_acc.append(0)

        for _, client in clients.items():
            client.train_acc.append(0)
            client.iterator = iter(client.train_DataLoader)

        # For every batch in the current epoch
        for iteration in range(num_iterations):
            print(f'\rEpoch: {epoch+1}, Iteration: {iteration+1}/{num_iterations}', end='')

            # Split-learning forward pass: client front -> server center -> client back.
            for _, client in clients.items():
                client.forward_front()

            for client_id, client in sc_clients.items():
                client.remote_activations1 = clients[client_id].remote_activations1
                client.forward_center()

            for client_id, client in clients.items():
                client.remote_activations2 = sc_clients[client_id].remote_activations2
                client.forward_back()

            for _, client in clients.items():
                client.calculate_loss()

            # Backward pass in reverse order: client back -> server center.
            for _, client in clients.items():
                client.backward_back()

            for client_id, client in sc_clients.items():
                client.remote_activations2 = clients[client_id].remote_activations2
                client.backward_center()

            for _, client in clients.items():
                client.step_back()
                client.zero_grad_back()

            # merge grads uncomment below

            # if epoch%2 == 0:
            #     params = []
            #     normalized_data_sizes = []
            #     for iden, client in clients.items():
            #         params.append(sc_clients[iden].center_model.parameters())
            #         normalized_data_sizes.append(len(client.train_dataset) / train_dataset_size)
            #     merge_grads(normalized_data_sizes, params)

            for _, client in sc_clients.items():
                client.center_optimizer.step()
                client.center_optimizer.zero_grad()

            for _, client in clients.items():
                client.train_acc[-1] += client.calculate_train_acc() # train accuracy of every client for the current batch

        for c_id, client in clients.items():
            client.train_acc[-1] /= num_iterations # average train accuracy of every client over all batches of the epoch
            overall_train_acc[-1] += client.train_acc[-1]

        overall_train_acc[-1] /= len(clients) # avg train accuracy of all the clients in the current epoch
        print(f' Personalized Average Train Acc: {overall_train_acc[-1]}')

        # Federated averaging of the server-side center models.
        params = []
        for _, client in sc_clients.items():
            params.append(copy.deepcopy(client.center_model.state_dict()))
        w_glob = merge_weights(params)

        for _, client in sc_clients.items():
            client.center_model.load_state_dict(w_glob)

        params = []

        # In the personalisation phase merging of weights of the back layers is stopped.
        if(epoch <=args.checkpoint):
            for _, client in clients.items():
                params.append(copy.deepcopy(client.back_model.state_dict()))
            w_glob_cb = merge_weights(params)
            del params

            for _, client in clients.items():
                client.back_model.load_state_dict(w_glob_cb)

        # ---- Testing every epoch ------------------------------------------
        if (epoch%1 == 0 ):
            if(epoch==args.checkpoint):
                for _, s_client in sc_clients.items():
                    s_client.center_model.freeze(epoch, pretrained=True)
            with torch.no_grad():
                test_acc = 0
                overall_test_acc.append(0)

                for _, client in clients.items():
                    client.test_acc.append(0)
                    client.iterator = iter(client.test_DataLoader)
                    client.pred=[]
                    client.y=[]

                # For every batch in the testing phase (forward passes only).
                for iteration in range(num_test_iterations):

                    for _, client in clients.items():
                        client.forward_front()

                    for client_id, client in sc_clients.items():
                        client.remote_activations1 = clients[client_id].remote_activations1
                        client.forward_center()

                    for client_id, client in clients.items():
                        client.remote_activations2 = sc_clients[client_id].remote_activations2
                        client.forward_back()

                    for _, client in clients.items():
                        client.test_acc[-1] += client.calculate_test_acc()

                for _, client in clients.items():
                    client.test_acc[-1] /= num_test_iterations
                    overall_test_acc[-1] += client.test_acc[-1]

                overall_test_acc[-1] /= len(clients) # average test accuracy of all the clients in the current epoch

                print(f' Personalized Average Test Acc: {overall_test_acc[-1]} ')

    timestamp = int(datetime.now().timestamp())
    plot_config = f'''dataset: {args.dataset},
    model: {args.model},
    batch_size: {args.batch_size}, lr: {args.lr},
    '''

    et = time.time()
    print(f"Time taken for this run {(et - st)/60} mins")

    # ---- Post-run statistics and plots ------------------------------------
    # Standard deviation across clients at each epoch, for the bands below.
    X = range(args.epochs)
    all_clients_stacked_train = np.array([client.train_acc for _,client in clients.items()])
    all_clients_stacked_test = np.array([client.test_acc for _,client in clients.items()])
    epochs_train_std = np.std(all_clients_stacked_train,axis = 0, dtype = np.float64)
    epochs_test_std = np.std(all_clients_stacked_test,axis = 0, dtype = np.float64)

    # Y_train is the average client train accuracy at each epoch;
    # epochs_train_std is the std of client train accuracies at each epoch.
    Y_train = overall_train_acc
    Y_train_lower = Y_train - (1.65 * epochs_train_std) # 1.65*std confidence band
    Y_train_upper = Y_train + (1.65 * epochs_train_std)

    Y_test = overall_test_acc
    Y_test_lower = Y_test - (1.65 * epochs_test_std) # 1.65*std confidence band
    Y_test_upper = Y_test + (1.65 * epochs_test_std)

    # Coefficient of variation (std / mean) per epoch.
    Y_train_cv = epochs_train_std / Y_train
    Y_test_cv = epochs_test_std / Y_test

    # Mean train accuracy with shaded confidence band.
    plt.figure(0)
    plt.plot(X, Y_train)
    plt.fill_between(X,Y_train_lower , Y_train_upper, color='blue', alpha=0.25)
    # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
    plt.show()

    # Mean test accuracy with shaded confidence band.
    plt.figure(1)
    plt.plot(X, Y_test)
    plt.fill_between(X,Y_test_lower , Y_test_upper, color='blue', alpha=0.25)
    # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
    plt.show()

    # Coefficient-of-variation curves.
    plt.figure(2)
    plt.plot(X, Y_train_cv)
    plt.show()

    plt.figure(3)
    plt.plot(X, Y_test_cv)
    plt.show()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PFSL: Personalized & Fair Split Learning with Data & Label Privacy for thin clients
2 |
3 | ## 1) Please cite as below if you use this repository:
4 | `@software{Manas_Wadhwa_and_Gagan_Gupta_and_Ashutosh_Sahu_and_Rahul_Saini_and_Vidhi_Mittal_PFSL_2023,
5 | author = {Manas Wadhwa and Gagan Gupta and Ashutosh Sahu and Rahul Saini and Vidhi Mittal},
6 | month = {2},
7 | title = {{PFSL}},
8 | url = {https://github.com/mnswdhw/PFSL},
9 | version = {1.0.0},
10 | year = {2023}
11 | }`
12 |
13 |
14 | ## 2) Credits
15 |
16 | To reproduce the results of the paper `Thapa, C., Chamikara, M. A., Camtepe, S., & Sun, L. (2020). SplitFed: When Federated Learning Meets Split Learning. ArXiv. https://doi.org/10.48550/arXiv.2004.12088`, we use their official source code which can be found here: https://github.com/chandra2thapa/SplitFed-When-Federated-Learning-Meets-Split-Learning
17 |
18 | For finding the FLOPs of our Pytorch split model at the client side, we use the profiler: https://github.com/Lyken17/pytorch-OpCounter
19 |
20 |
21 | ## 3) Build requirements:
22 | * Python3 (3.8)
23 | * pip3
24 | * Nvidia GPU (>=12GB)
25 | * conda
26 |
27 |
28 | ## 4) Installation
29 | Use the following steps to install the required libraries:
30 | * Change Directory into the project folder
31 | * Create a conda environment using the command
32 | `conda create --name {env_name} python=3.8`
33 | Eg- `conda create --name pfsl python=3.8`
34 | * Activate conda environment using the command
35 | `conda activate {env_name}`
36 | Eg- `conda activate pfsl`
37 | * Then use the command: `pip install -r requirements.txt`
38 |
39 | ## 5) Test Run
40 |
41 | ### Parameters
42 | The parameter options for a particular file can be checked by adding the `--help` argument.
43 |
Optional arguments available for PFSL are:
44 | * -h, --help show this help message and exit
45 | * -c, -–number of clients Number of Clients (default: 10)
46 | * -b, -–batch_size Batch size (default: 128)
47 | * –-test_batch_size Input batch size for testing (default: 128)
48 | * -n , –-epochs Total number of epochs to train (default: 10)
49 | * –-lr Learning rate (default: 0.001)
50 | * -–save model Save the trained model (default: False)
51 | * –-dataset States dataset to be used (default: cifar10)
52 | * –-seed Random seed (default: 1234)
53 | * –-model Model you would like to train (default: resnet18)
54 | * –-epoch_batch Number of epochs after which the next batch of clients should join (default: 5)
55 | * –-opt_iden optional identifier of experiment (default: )
56 | * –-pretrained Use transfer learning using a pretrained model (default: False)
57 | * –-datapoints Number of samples of training data allotted to each client (default: 500)
58 | * –-setting Setting you would like to run for, i.e, setting1 ,setting2 or setting4 (default: setting1)
59 | * –-checkpoint Epoch at which personalisation phase will start (default: 50)
60 | * --rate This arguments specifies the fraction of clients dropped off in every epoch (used in setting 5)(default: 0.5)
61 |
62 | For reproducing the results, always add argument –-pretrained while running the PFSL script.
63 |
64 | Create a results directory in the project folder to store all the resulting plots using the below commands.
65 | * `mkdir results`
66 | * `mkdir results/FL`
67 | * `mkdir results/SL`
68 | * `mkdir results/SFLv1`
69 | * `mkdir results/SFLv2`
70 |
71 | ### Commands for all the scenarios
72 |
73 | Below we state the commands for running PFSL, SL, FL, SFLv1 and SFLv2 for all the experimental scenarios.
74 |
75 | Setting 1: Small Sample Size (Equal), i.i.d.
76 | In this scenario, each client has a very small number of labelled data points ranging from 50 to 500, and all these samples are distributed identically across clients. There is no class imbalance in training data of each client. To run all the algorithms for setting 1 argument –-setting setting1 and –-datapoints [number of sample per client] has to be added.
77 | Rest of the arguments can be selected as per choice. Number of data samples can be chosen from 50, 150, 250, 350 and 500 to reproduce the results. When total data sample size was
78 | 50, batch size was chosen to be 32 and for other data samples
79 | greater than 50 batch size was kept at 64. Test batch size was
80 | always taken to be 512. For data sample 150, commands are
81 | given below.
82 |
83 | * `python PFSL_Setting124.py --dataset cifar10 --setting setting1 --datapoints 150 --pretrained --model resnet18 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
84 | * `python FL.py --dataset cifar10 --setting setting1 --datapoints 150 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
85 | * `python SL.py --dataset cifar10 --setting setting1 --datapoints 150 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
86 | * `python SFLv1.py --dataset cifar10 --setting setting1 --datapoints 150 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
87 | * `python SFLv2.py --dataset cifar10 --setting setting1 --datapoints 150 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
88 |
89 |
90 |
91 |
92 |
93 | Setting 2: Small Sample Size (Equal), non-i.i.d.
94 | In this setting, we model a situation where every client has more labelled data points from a subset of classes (prominent
95 | classes) and less from the remaining classes. We chose to experiment with heavy label imbalance and diversity. Sample size is small and each client has equal number of training samples. To run all the algorithms for setting 2 argument --setting setting2 has to be added. For PFSL, to enable personalisation phase
96 | from xth epoch, argument --checkpoint [x] has to be added.
97 | Rest of the arguments can be selected as per choice.
98 |
99 | * `python PFSL_Setting124.py --dataset cifar10 --model resnet18 --pretrained --setting setting2 --batch_size 64 --test_batch_size 512 --checkpoint 25 --epochs 30`
100 | * `python FL.py --dataset cifar10 --setting setting2 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
101 | * `python SL.py --dataset cifar10 --setting setting2 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
102 | * `python SFLv1.py --dataset cifar10 --setting setting2 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
103 | * `python SFLv2.py --dataset cifar10 --setting setting2 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
104 |
105 |
106 |
107 |
108 |
109 |
110 | Setting 3: Small Sample Size (Unequal), i.i.d.
111 | In this setting, we consider 11 clients where the Large client has 2000 labelled data points
112 | while the other ten small clients have 150 labelled data points,
113 | each distributed identically. The class distributions
114 | among all the clients are the same. For evaluation purposes,
115 | we consider a test set having 2000 data points with an identical
116 | distribution of classes as the train set.
117 |
118 | To reproduce Table IV of the paper, run setting 1 with
119 | datapoints as 150 as illustrated above. To reproduce Table V
120 | of the paper follow the below commands. In all the commands argument --datapoints that denotes the number of datapoints of the large client has to be added. In our case it was 2000.
121 |
122 | * `python PFSL_Setting3.py --datapoints 2000 --dataset cifar10 --pretrained --model resnet18 -c 11 --epochs 50`
123 | * `python SFLv1_Setting3.py --datapoints 2000 --dataset cifar10_setting3 -c 11 --epochs 100`
124 | * `python SFLv2_Setting3.py --datapoints 2000 --dataset cifar10_setting3 -c 11 --epochs 100`
125 | * `python FL_Setting3.py --datapoints 2000 --dataset cifar10_setting3 -c 11 --epochs 100`
126 | * `python SL_Setting3.py --datapoints 2000 --dataset cifar10_setting3 -c 11 --epochs 100`
127 |
128 |
129 |
130 |
131 |
132 |
133 | Setting 4: A large number of data samples
134 | Here, all clients have large number of samples. This experiment was done with three different image classification datasets:
135 | MNIST, FMNIST, and CIFAR-10. To run all the algorithms for setting 4 argument --setting setting4 has
136 | to be added. Rest of the arguments can be selected as per choice. Dataset argument has 3 options: cifar10, mnist and fmnist.
137 |
138 | * `python PFSL_Setting124.py --dataset cifar10 --setting setting4 --pretrained --model resnet18 -c 5 --epochs 20`
139 | * `python FL.py --dataset cifar10 --setting setting4 -c 5 --epochs 20`
140 | * `python SL.py --dataset cifar10 --setting setting4 -c 5 --epochs 20`
141 | * `python SFLv1.py --dataset cifar10 --setting setting4 -c 5 --epochs 20`
142 | * `python SFLv2.py --dataset cifar10 --setting setting4 -c 5 --epochs 20`
143 |
144 |
145 |
146 |
147 |
148 | Setting 5: System simulation with 1000 clients
149 | In this setting we try to simulate an environment with 1000 clients. Each client stays in the system only for 1 round which lasts only 1 epoch.
150 | Thus, we evaluate our system for the worst possible scenario when every client cannot stay in the system for long and can only afford to make a minimal effort to participate. We assume that each client has 50 labeled data points sampled randomly but unique to the client. Within each round, we
151 | simulate a dropout, where clients begin training but are not able to complete the weight averaging. We keep the dropout probability at 50%.
152 |
153 |
154 | Use the following command to reproduce the results: Here rate argument specifies the dropoff rate which is the number of clients that will be dropped randomly in every epoch
155 |
156 | * `python system_simulation_e2.py -c 10 --batch_size 16 --dataset cifar10 --model resnet18 --pretrained --epochs 100 --rate 0.3`
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 | Setting 6: Different Diabetic Retinopathy Datasets:
167 | This experiment describes the realistic scenario when healthcare centers have different sets of raw patient data for the
168 | same disease. We have used two datasets EyePACS and APTOS whose references are given below.
169 |
170 |
171 | Dataset Sources:
172 | * Source of Dataset 1, https://www.kaggle.com/competitions/aptos2019-blindness-detection/data
173 | * Source of Dataset 2, https://www.kaggle.com/datasets/mariaherrerot/eyepacspreprocess
174 |
175 | To preprocess the dataset download and store the unzipped files in data/eye dataset1 folder and data/eye dataset2 folder.
176 | For this create directories using the command:
177 | * `mkdir data/eye_dataset1`
178 | * `mkdir data/eye_dataset2`
179 |
180 |
181 |
182 | The directory structure of data is as follows:
183 | * `data/eye_dataset1/train_images`
184 | * `data/eye_dataset1/test_images`
185 | * `data/eye_dataset1/test.csv`
186 | * `data/eye_dataset1/train.csv`
187 | * `data/eye_dataset2/eyepacs_preprocess/eyepacs_preprocess/`
188 | * `data/eye_dataset2/trainLabels.csv`
189 |
190 | Once verify the path of the unzipped folders in the load data function of preprocess_eye_dataset_1.py and preprocess_eye_dataset_2.py files.
191 |
192 | For Data preprocessing, run the commands mentioned below
193 | for both the datasets
194 | `python utils/preprocess_eye_dataset_1.py`
195 | `python utils/preprocess_eye_dataset_2.py`
196 |
197 | * `python PFSL_DR.py --pretrained --model resnet18 -c 10 --batch_size 64 --test_batch_size 512 --epochs 50`
198 | * `python FL_DR.py -c 10 --batch_size 64 --test_batch_size 512 --epochs 50`
199 | * `python SL_DR.py --batch_size 64 --test_batch_size 512 --epochs 50`
200 | * `python SFLv1_DR.py --batch_size 64 --test_batch_size 512 --epochs 50`
201 | * `python SFLv2_DR.py --batch_size 64 --test_batch_size 512 --epochs 50`
202 |
203 |
204 |
205 |
206 | ## (6) Test Example Outputs for different Settings
207 |
208 | ### Setting 1
209 |
210 | Command: `python PFSL_Setting124.py --dataset cifar10 --setting setting1 --datapoints 150 --pretrained --model resnet18 -c 10 --batch_size 64 --test_batch_size 512 --epochs 100`
211 |
212 | Maximum test accuracy and time taken for the run is noted in this setting.
213 |
214 | Final Output of the above command is as follows:
215 | * Epoch: 100, Iteration: 3/3
216 | * Training Accuracy: 100.0
217 | * Maximum Test Accuracy: 82.52188846982759
218 | * Time taken for this run 49.36871874332428 mins
219 |
220 | ### Setting 2
221 |
222 | Command: `python PFSL_Setting124.py --dataset cifar10 --model resnet18 --pretrained --setting setting2 --batch_size 64 --test_batch_size 512 --checkpoint 25 --epochs 30`
223 |
224 | After the 25th epoch, the personalization phase begins since the checkpoint is specified as 25 in the above command. It outputs F1 Score just before the start of the personalization phase and the maximum F1 Score achieved in that phase.
225 |
226 | Final Output:
227 | * Epoch: 25, Iteration: 8/8freezing the center model
228 | * Epoch: 26, Iteration: 8/8freezing the center model
229 | * F1 Score at epoch 25 : 0.8151361976947273
230 | * Epoch: 30, Iteration: 8/8
231 | * Training Accuracy: 100.0
232 | * Maximum F1 Score: 0.9509444261471961
233 | * Time taken for this run 11.245620834827424 mins
234 |
235 |
236 | ### Setting 3
237 |
238 | Command: `python PFSL_Setting3.py --datapoints 2000 --dataset cifar10 --pretrained --model resnet18 -c 11 --epochs 50`
239 |
240 |
241 | This command will print statements of the form as below every epoch
242 | * Large client train/Test accuracy xx.xx
243 | * Epoch 19 C1-C10 Average Train/Test Acc: xx.xx
244 |
245 | Final Output will be the maximum test accuracy of the large client and the maximum average test accuracy of the remaining clients which for the above command is
246 | * Average C1 - C10 test accuracy: 86.1162109375
247 | * Large Client test Accuracy: 87.109375
248 | * Time taken for this run 7.857871949672699 mins
249 |
250 |
251 | ### Setting 4
252 |
253 | Command: `python PFSL_Setting124.py --dataset cifar10 --setting setting4 --pretrained --model resnet18 -c 5 --epochs 20`
254 |
255 |
256 | Final Output of the above command is as follows
257 | * Epoch: 20, Iteration: 79/79
258 | * Training Accuracy: 98.90427215189872
259 | * Maximum Test Accuracy: 94.1484375
260 | * Time taken for this run 36.06000682512919 mins
261 |
262 | ### Setting 5
263 |
264 | Command: `python system_simulation_e2.py -c 10 --batch_size 16 --dataset cifar10 --model resnet18 --pretrained --epochs 40 --rate 0.3`
265 |
266 | For every epoch it prints the average train accuracy and the number of clients that are dropped off. Next, it prints the ids of the clients that are not dropped off.
267 |
268 |
269 | Final Output of the above command is as follows
270 | * Personalized Average Test Acc: 88.19791666666666
271 | * Time taken for this run 18.778587651252746 mins
272 |
273 |
274 |
275 |
276 | ### Setting 6
277 |
278 | Command: `python PFSL_DR.py --pretrained --model resnet18 -c 10 --batch_size 64 --test_batch_size 512 --epochs 50`
279 |
280 | Average test accuracy for clients having the APTOS dataset and clients having the EyePACS dataset is noted separately. Also, F1 Score for one representative client from each group is noted. The command above outputs these metrics for the epoch in which the maximum average test accuracy of all the clients is achieved.
281 |
282 | Final Output:
283 | * Epoch: 50, Iteration: 8/8
284 | * Time taken for this run 75.75171089967093 mins
285 | * Time taken for this run 75.75171089967093 mins
286 | * Maximum Personalized Average Test Acc: 79.24868724385246
287 | * Maximum Personalized Average Train Acc: 97.85606971153845
288 | * Client0 F1 Scores: 0.772644561137067
289 | * Client5 F1 Scores:0.5906352306590767
290 | * Personalized Average Test Accuracy for Clients 0 to 4 ": 85.21932633196721
291 | * Personalized Average Test Accuracy for Clients 5 to 9": 73.27804815573771
292 |
293 |
294 |
295 |
296 | ## (7) Quick Validation of Environment
297 |
298 | Command: `python PFSL_Setting124.py --dataset cifar10 --setting setting1 --datapoints 50 --pretrained --model resnet18 -c 5 --batch_size 64 --test_batch_size 512 --epochs 2`
299 |
300 | Output
301 |
302 | * Training Accuracy: 82.0
303 | * Maximum Test Accuracy: 57.88355334051723
304 | * Time taken for this run 0.4888111670811971 mins
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
--------------------------------------------------------------------------------
/PFSL_DR.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import string
4 | import socket
5 | import requests
6 | import sys
7 | import threading
8 | import time
9 | import torch
10 | from math import ceil
11 | from torchvision import transforms
12 | from torch.utils.data import DataLoader, Dataset
13 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp
14 | from utils.client_simulation import generate_random_clients
15 | from utils.connections import send_object
16 | from utils.arg_parser import parse_arguments
17 | import matplotlib.pyplot as plt
18 | import time
19 | import multiprocessing
20 | from sklearn.metrics import classification_report
21 | import torch.optim as optim
22 | import copy
23 | from datetime import datetime
24 | from scipy.interpolate import make_interp_spline
25 | import numpy as np
26 | from ConnectedClient import ConnectedClient
27 | import importlib
28 | from utils.merge import merge_grads, merge_weights
29 | import pandas as pd
30 | import time
31 | from utils import get_eye_dataset
32 | import torch.nn.functional as F
33 | # Clients Side Program
34 | #==============================================================================================================
class DatasetSplit(Dataset):
    """A view over `dataset` restricted to the sample indices in `idxs`.

    Lets each simulated client train on its own subset of one shared
    dataset without copying any data: position ``i`` of this wrapper is
    sample ``idxs[i]`` of the wrapped dataset.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        mapped_index = self.idxs[item]
        image, label = self.dataset[mapped_index]
        return image, label
46 |
47 |
def initialize_client(client, dataset, batch_size, test_batch_size, tranform):
    """Load a client's train/test data and build its DataLoaders.

    Args:
        client: simulated client exposing ``load_data`` and
            ``create_DataLoader``.
        dataset: name of the dataset the client should load.
        batch_size: training batch size.
        test_batch_size: evaluation batch size.
        tranform: (sic — name kept for call compatibility) transform
            forwarded to ``load_data``; may be None.

    Bug fixed: the original ignored the ``dataset`` and ``tranform``
    parameters and silently read the module-level globals ``args.dataset``
    and ``transform`` instead.
    """
    client.load_data(dataset, tranform)
    print(f'Length of train dataset client {client.id}: {len(client.train_dataset)}')
    client.create_DataLoader(batch_size, test_batch_size)
52 |
53 |
def select_random_clients(clients):
    """Select one client uniformly at random.

    Returns a dict mapping the chosen client id to its client object, so
    the result can be iterated exactly like the full ``clients`` dict.

    Bug fixed: the original assigned a *single* id string to
    ``random_client_ids`` and then iterated over it, which loops over the
    characters of that id and raises KeyError for any multi-character
    client id.
    """
    random_clients = {}
    client_ids = list(clients.keys())
    random_index = random.randint(0, len(client_ids) - 1)
    chosen_id = client_ids[random_index]

    print(chosen_id)
    print(clients)

    random_clients[chosen_id] = clients[chosen_id]
    return random_clients
66 |
67 |
def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids):
    """Plot per-client label histograms and frequency line graphs of train data.

    For at most 20 clients (randomly sampled when there are more), draws a
    5-column grid of label-frequency bar charts, saves the figure to
    'plot_setting_DR_exp.png', then overlays per-client frequency line
    graphs on a second figure.

    Args:
        clients: dict mapping client id -> client object whose
            ``train_dataset`` yields (image, label) pairs.
        dataset, batch_size, epochs, opt: only referenced by the
            commented-out savefig filename; kept for interface compatibility.
        client_ids: list of client ids eligible for plotting.

    Returns:
        dict mapping client id -> pandas Series of label counts (sorted by
        label) for every plotted client.
    """
    class_distribution=dict()
    number_of_clients=len(client_ids)
    # Plot every client when there are 20 or fewer; otherwise sample 20.
    if(len(clients)<=20):
        plot_for_clients=client_ids
    else:
        plot_for_clients=random.sample(client_ids, 20)

    fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10))
    # (i, j) indexes the subplot grid: i is the row, j the column.
    j=0
    i=0

    #plot histogram
    for client_id in plot_for_clients:
        # Materialising the dataset lets pandas count label occurrences.
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        class_distribution[client_id]=df['labels'].value_counts().sort_index()
        df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id)
        j+=1
        # Wrap to the next subplot row after every 5 columns.
        if(j==5 or j==10 or j==15):
            i+=1
            j=0
    fig.tight_layout()
    plt.show()
    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png')
    plt.savefig('plot_setting_DR_exp.png')

    # max_len tracks the largest single-label count seen, used as the y-limit.
    max_len=0
    #plot line graphs
    for client_id in plot_for_clients:
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id)
        max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0])
    plt.xticks(np.arange(0,10))
    plt.ylim(0, max_len)
    plt.legend()
    plt.show()
    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png')

    return class_distribution
107 |
108 | if __name__ == "__main__":
109 |
110 |
111 | args = parse_arguments()
112 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
113 | print("Arguments provided", args)
114 |
115 | random.seed(args.seed)
116 | torch.manual_seed(args.seed)
117 |
118 | overall_test_acc = []
119 | overall_test_acc1 = []
120 | overall_test_acc2 = []
121 | overall_train_acc = []
122 |
123 | print('Generating random clients...', end='')
124 | clients = generate_random_clients(args.number_of_clients)
125 | client_ids = list(clients.keys())
126 | print('Done')
127 |
128 | # train_dataset_size, input_channels = split_dataset_cifar10tl_exp(client_ids, args.datapoints)
129 |
130 | print(f'Random client ids:{str(client_ids)}')
131 | transform=None
132 | max_epoch=0
133 | max_f1=0
134 | max_accuracy=0
135 | max_train_accuracy=0
136 | max_c0_4_test=0
137 | max_c0_f1=0
138 | max_c5_9_test=0
139 | max_c5_f1=0
140 |
141 |
142 | print('Initializing clients...')
143 |
144 |
145 |
146 | d1, d2=get_eye_dataset.get_idxs()
147 |
148 |
149 |
150 |
151 | input_channels=3
152 |
153 |
154 |
155 |
156 | i=0
157 | dict_user_train=dict()
158 | dict_user_test=dict()
159 | client_idxs=dict()
160 |
161 | for _, client in clients.items():
162 | dict_user_train[_]=d1[i]
163 | dict_user_test[_]=d2[i]
164 | # dict_user_test_generalized[_]=dict_users_test_equal[i%2]
165 | client_idxs[_]=i
166 | i+=1
167 | for _, client in clients.items():
168 | # client.train_dataset=DatasetSplit(train_full_dataset, dict_user_train[_])
169 | client_idx=client_idxs[_]
170 | if(client_idx>=0 and client_idx<=4 ):
171 | train_data_id=1
172 | test_data_id=3
173 | elif(client_idx>=5 and client_idx<=9):
174 | train_data_id=2
175 | test_data_id=4
176 |
177 | client.train_dataset=get_eye_dataset.get_eye_data(dict_user_train[_], train_data_id)
178 | client.test_dataset=get_eye_dataset.get_eye_data(dict_user_test[_], test_data_id)
179 | client.create_DataLoader(args.batch_size, args.test_batch_size)
180 |
181 |
182 |
183 |
184 | #class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids)
185 | print('Client Intialization complete.')
186 | model = importlib.import_module(f'models.{args.model}')
187 | count=0
188 | for _, client in clients.items():
189 | client.front_model = model.front(input_channels, pretrained=args.pretrained)
190 | client.back_model = model.back(pretrained=args.pretrained)
191 |
192 | print('Done')
193 |
194 |
195 |
196 |
197 |
198 | for _, client in clients.items():
199 |
200 | # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9)
201 | # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9)
202 | client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr)
203 | client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr)
204 |
205 |
206 |
207 | first_client = clients[client_ids[0]]
208 | num_iterations = ceil(len(first_client.train_DataLoader.dataset)/args.batch_size)
209 | num_test_iterations_personalization = ceil(len(first_client.test_DataLoader.dataset)/args.test_batch_size)
210 | sc_clients = {} #server copy clients
211 |
212 | for iden in client_ids:
213 | sc_clients[iden] = ConnectedClient(iden, None)
214 |
215 | for _,s_client in sc_clients.items():
216 | s_client.center_model = model.center(pretrained=args.pretrained)
217 |
218 | s_client.center_model.to(device)
219 | # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9)
220 | s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr)
221 |
222 |
223 | st = time.time()
224 |
225 | macro_avg_f1_3classes=[]
226 | macro_avg_f1_dict={}
227 |
228 | criterion=F.cross_entropy
229 |
230 |
231 | for epoch in range(args.epochs):
232 |
233 |
234 | overall_train_acc.append(0)
235 | for _, client in clients.items():
236 | client.train_acc.append(0)
237 | client.iterator = iter(client.train_DataLoader)
238 |
239 |
240 | for iteration in range(num_iterations):
241 | print(f'\rEpoch: {epoch+1}, Iteration: {iteration+1}/{num_iterations}', end='')
242 |
243 | for _, client in clients.items():
244 | client.forward_front()
245 |
246 | for client_id, client in sc_clients.items():
247 | client.remote_activations1 = clients[client_id].remote_activations1
248 | client.forward_center()
249 |
250 | for client_id, client in clients.items():
251 | client.remote_activations2 = sc_clients[client_id].remote_activations2
252 | client.forward_back()
253 |
254 | for _, client in clients.items():
255 | client.calculate_loss()
256 |
257 | for _, client in clients.items():
258 | client.backward_back()
259 |
260 | for client_id, client in sc_clients.items():
261 | client.remote_activations2 = clients[client_id].remote_activations2
262 | client.backward_center()
263 |
264 |
265 | for _, client in clients.items():
266 | client.step_back()
267 | client.zero_grad_back()
268 |
269 | for _, client in sc_clients.items():
270 | client.center_optimizer.step()
271 | client.center_optimizer.zero_grad()
272 |
273 | for _, client in clients.items():
274 | client.train_acc[-1] += client.calculate_train_acc()
275 |
276 | for c_id, client in clients.items():
277 | client.train_acc[-1] /= num_iterations
278 | overall_train_acc[-1] += client.train_acc[-1]
279 |
280 | overall_train_acc[-1] /= len(clients)
281 |
282 |
283 | # merge weights below uncomment
284 | params = []
285 | for _, client in sc_clients.items():
286 | params.append(copy.deepcopy(client.center_model.state_dict()))
287 | w_glob = merge_weights(params)
288 |
289 | for _, client in sc_clients.items():
290 | client.center_model.load_state_dict(w_glob)
291 |
292 | params = []
293 | # if(epoch <=args.checkpoint):
294 | for _, client in clients.items():
295 | params.append(copy.deepcopy(client.back_model.state_dict()))
296 | w_glob_cb = merge_weights(params)
297 |
298 | for _, client in clients.items():
299 | client.back_model.load_state_dict(w_glob_cb)
300 |
301 |
302 |
303 | # Testing on every 5th epoch
304 |
305 | if (epoch%1 == 0 ):
306 |
307 | with torch.no_grad():
308 | test_acc = 0
309 | overall_test_acc.append(0)
310 | overall_test_acc1.append(0)
311 | overall_test_acc2.append(0)
312 |
313 | # for
314 | for _, client in clients.items():
315 | client.test_acc.append(0)
316 | client.iterator = iter(client.test_DataLoader)
317 | client.pred=[]
318 | client.y=[]
319 | for iteration in range(num_test_iterations_personalization):
320 |
321 | for _, client in clients.items():
322 | client.forward_front()
323 |
324 | for client_id, client in sc_clients.items():
325 | client.remote_activations1 = clients[client_id].remote_activations1
326 | client.forward_center()
327 |
328 | for client_id, client in clients.items():
329 | client.remote_activations2 = sc_clients[client_id].remote_activations2
330 | client.forward_back()
331 |
332 | for _, client in clients.items():
333 | client.test_acc[-1] += client.calculate_test_acc()
334 |
335 | for _, client in clients.items():
336 | client.test_acc[-1] /= num_test_iterations_personalization
337 | overall_test_acc[-1] += client.test_acc[-1]
338 | idx=client_idxs[_]
339 | if(idx>=0 and idx<5):
340 | overall_test_acc1[-1] += client.test_acc[-1]
341 | elif(idx>=5 and idx<10):
342 | overall_test_acc2[-1] += client.test_acc[-1]
343 |
344 | clr=classification_report(np.array(client.y), np.array(client.pred), output_dict=True, zero_division=0)
345 |
346 |
347 | curr_f1=(clr['0']['f1-score']+clr['1']['f1-score']+clr['2']['f1-score'])/3
348 | macro_avg_f1_3classes.append(curr_f1)
349 | macro_avg_f1_dict[idx]=curr_f1
350 |
351 | overall_test_acc[-1] /= len(clients)
352 | overall_test_acc1[-1] /=5
353 | overall_test_acc2[-1] /= 5
354 |
355 | f1_avg_all_user=sum(macro_avg_f1_3classes)/len(macro_avg_f1_3classes)
356 | macro_avg_f1_3classes=[]
357 |
358 | if(overall_test_acc[-1] > max_accuracy):
359 | max_accuracy=overall_test_acc[-1]
360 | max_train_accuracy=overall_train_acc[-1]
361 | max_epoch=epoch
362 | max_c0_f1=macro_avg_f1_dict[0]
363 | max_c5_f1=macro_avg_f1_dict[5]
364 | max_c0_4_test=overall_test_acc1[-1]
365 | max_c5_9_test=overall_test_acc2[-1]
366 |
367 |
368 |
369 | macro_avg_f1_dict={}
370 |
371 | timestamp = int(datetime.now().timestamp())
372 | plot_config = f'''dataset: {args.dataset},
373 | model: {args.model},
374 | batch_size: {args.batch_size}, lr: {args.lr},
375 | '''
376 |
377 | et = time.time()
378 | print(f"\nTime taken for this run {(et - st)/60} mins")
379 | print(f"Time taken for this run {(et - st)/60} mins")
380 | print(f'Maximum Personalized Average Test Acc: {max_accuracy} ')
381 | print(f'Maximum Personalized Average Train Acc: {max_train_accuracy} ')
382 | print(f'Client0 F1 Scores: {max_c0_f1}')
383 | print(f'Client5 F1 Scores:{max_c5_f1}')
384 | print(f'Personalized Average Test Accuracy for Clients 0 to 4 ": {max_c0_4_test}')
385 | print(f'Personalized Average Test Accuracy for Clients 5 to 9": {max_c5_9_test}')
386 |
387 |
388 | X = range(args.epochs)
389 | all_clients_stacked_train = np.array([client.train_acc for _,client in clients.items()])
390 | all_clients_stacked_test = np.array([client.test_acc for _,client in clients.items()])
391 | epochs_train_std = np.std(all_clients_stacked_train,axis = 0, dtype = np.float64)
392 | epochs_test_std = np.std(all_clients_stacked_test,axis = 0, dtype = np.float64)
393 |
394 | #Y_train is the average client train accuracies at each epoch
395 | #epoch_train_std is the standard deviation of clients train accuracies at each epoch
396 | Y_train = overall_train_acc
397 | Y_train_lower = Y_train - (1.65 * epochs_train_std) #95% of the values lie between 1.65*std
398 | Y_train_upper = Y_train + (1.65 * epochs_train_std)
399 |
400 | Y_test = overall_test_acc
401 | Y_test_lower = Y_test - (1.65 * epochs_test_std) #95% of the values lie between 1.65*std
402 | Y_test_upper = Y_test + (1.65 * epochs_test_std)
403 |
404 | Y_train_cv = epochs_train_std / Y_train
405 | Y_test_cv = epochs_test_std / Y_test
406 |
407 | plt.figure(0)
408 | plt.plot(X, Y_train)
409 | plt.fill_between(X,Y_train_lower , Y_train_upper, color='blue', alpha=0.25)
410 | # plt.savefig(f'./results/train_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
411 | plt.show()
412 |
413 |
414 | plt.figure(1)
415 | plt.plot(X, Y_test)
416 | plt.fill_between(X,Y_test_lower , Y_test_upper, color='blue', alpha=0.25)
417 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
418 | plt.show()
419 |
420 | plt.figure(2)
421 | plt.plot(X, Y_train_cv)
422 | plt.show()
423 |
424 | plt.figure(3)
425 | plt.plot(X, Y_test_cv)
426 | plt.show()
427 |
428 |
429 |
430 |
--------------------------------------------------------------------------------
/PFSL_Setting124.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import string
4 | import socket
5 | import requests
6 | import sys
7 | import threading
8 | import time
9 | import torch
10 | from math import ceil
11 | from torchvision import transforms
12 | from torch.utils.data import DataLoader, Dataset
13 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp
14 | from utils.client_simulation import generate_random_clients
15 | from utils.connections import send_object
16 | from utils.arg_parser import parse_arguments
17 | import matplotlib.pyplot as plt
18 | import time
19 | import multiprocessing
20 | from sklearn.metrics import classification_report
21 | import torch.optim as optim
22 | import copy
23 | from datetime import datetime
24 | from scipy.interpolate import make_interp_spline
25 | import numpy as np
26 | from ConnectedClient import ConnectedClient
27 | import importlib
28 | from utils.merge import merge_grads, merge_weights
29 | import pandas as pd
30 | import time
31 | from utils import dataset_settings, datasets
32 | import torch.nn.functional as F
33 |
34 |
35 | #To load train and test data for each client for setting 1 and setting 2
class DatasetSplit(Dataset):
    """Expose only the samples of `dataset` listed in `idxs`.

    Gives each client a disjoint slice of the full train/test set for
    settings 1 and 2 without duplicating the underlying data.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        sample = self.dataset[self.idxs[item]]
        image, label = sample
        return image, label
47 |
48 |
49 | #To intialize every client with their train and test data for setting 4
def initialize_client(client, dataset, batch_size, test_batch_size, tranform):
    """Initialize one client with its train/test data (setting 4 path).

    Args:
        client: simulated client exposing ``load_data`` and
            ``create_DataLoader``.
        dataset: name of the dataset the client should load.
        batch_size: training batch size.
        test_batch_size: evaluation batch size.
        tranform: (sic — name kept for call compatibility) transform
            forwarded to ``load_data``; may be None.

    Bug fixed: the original passed the module-level global ``transform``
    to ``load_data`` instead of the ``tranform`` parameter it was given.
    """
    client.load_data(dataset, tranform)
    print(f'Length of train dataset client {client.id}: {len(client.train_dataset)}')
    client.create_DataLoader(batch_size, test_batch_size)
55 |
56 |
57 | #Plots class distribution of train data available to each client
def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids):
    """Plot per-client label histograms and frequency line graphs of train data.

    For at most 20 clients (randomly sampled when there are more), draws a
    5-column grid of label-frequency bar charts, saves the figure to
    'plot_setting3_exp.png', then overlays per-client frequency line graphs
    on a second figure.

    Args:
        clients: dict mapping client id -> client object whose
            ``train_dataset`` yields (image, label) pairs.
        dataset, batch_size, epochs, opt: only referenced by the
            commented-out savefig filename; kept for interface compatibility.
        client_ids: list of client ids eligible for plotting.

    Returns:
        dict mapping client id -> pandas Series of label counts (sorted by
        label) for every plotted client.
    """
    class_distribution=dict()
    number_of_clients=len(client_ids)
    # Plot every client when there are 20 or fewer; otherwise sample 20.
    if(len(clients)<=20):
        plot_for_clients=client_ids
    else:
        plot_for_clients=random.sample(client_ids, 20)

    fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10))
    # (i, j) indexes the subplot grid: i is the row, j the column.
    j=0
    i=0

    #plot histogram
    for client_id in plot_for_clients:
        # Materialising the dataset lets pandas count label occurrences.
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        class_distribution[client_id]=df['labels'].value_counts().sort_index()
        df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id)
        j+=1
        # Wrap to the next subplot row after every 5 columns.
        if(j==5 or j==10 or j==15):
            i+=1
            j=0
    fig.tight_layout()
    plt.show()

    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png')
    plt.savefig('plot_setting3_exp.png')

    # max_len tracks the largest single-label count seen, used as the y-limit.
    max_len=0
    #plot line graphs
    for client_id in plot_for_clients:
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id)
        max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0])
    plt.xticks(np.arange(0,10))
    plt.ylim(0, max_len)
    plt.legend()
    plt.show()

    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png')

    return class_distribution
99 |
100 |
101 |
102 | if __name__ == "__main__":
103 |
104 |
105 | args = parse_arguments()
106 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107 | print("Arguments provided", args)
108 |
109 |
110 |
111 | random.seed(args.seed)
112 | torch.manual_seed(args.seed)
113 |
114 | overall_test_acc = []
115 | overall_train_acc = []
116 |
117 | print('Generating random clients...', end='')
118 | clients = generate_random_clients(args.number_of_clients)
119 | client_ids = list(clients.keys())
120 | print('Done')
121 |
122 | train_dataset_size, input_channels = split_dataset(args.dataset, client_ids, pretrained=args.pretrained)
123 |
124 | print(f'Random client ids:{str(client_ids)}')
125 | transform=None
126 | max_epoch=0
127 | max_f1=0
128 | max_accuracy=0
129 |
130 | #Assigning train and test data to each client depending for each client
131 | print('Initializing clients...')
132 |
133 | if(args.setting=="setting4"):
134 | for _, client in clients.items():
135 | (initialize_client(client, args.dataset, args.batch_size, args.test_batch_size, transform))
136 | else:
137 |
# ---------------------------------------------------------------------------
# Dataset loading and per-client partitioning.
# NOTE(review): load_full_dataset is called here with `args.pretrained` as the
# last positional argument, while FL_Setting3.py passes transforms instead —
# confirm utils.datasets supports both call styles.
train_full_dataset, test_full_dataset, input_channels = datasets.load_full_dataset(args.dataset, "data", args.number_of_clients, args.datapoints, args.pretrained)
#----------------------------------------------------------------
# dict_users / dict_users2 map numeric client index -> train / test sample
# indices for the chosen data-distribution setting.
dict_users , dict_users2 = dataset_settings.get_dicts(train_full_dataset, test_full_dataset, args.number_of_clients, args.setting, args.datapoints)

# Equal-sized per-client test split (not referenced again in this excerpt).
dict_users_test_equal=dataset_settings.get_test_dict(test_full_dataset, args.number_of_clients)

client_idx=0
dict_user_train=dict()   # client id -> train sample indices
dict_user_test=dict()    # client id -> test sample indices
client_idxs=dict()       # client id -> numeric client index

# `clients` is an id -> client-object dict built earlier in this script
# (outside this excerpt).
for _, client in clients.items():
    dict_user_train[_]=dict_users[client_idx]
    dict_user_test[_]=dict_users2[client_idx]
    client_idxs[_]=client_idx
    client_idx+=1
# Wrap each client's index subset into a dataset and build its DataLoaders.
for _, client in clients.items():
    client.train_dataset=DatasetSplit(train_full_dataset, dict_user_train[_])
    client.test_dataset=DatasetSplit(test_full_dataset, dict_user_test[_])
    client.create_DataLoader(args.batch_size, args.test_batch_size)
print('Client Intialization complete.')
# Train and test data intialisation complete

#Setting the start of personalisation phase
# Outside setting2 the checkpoint is pushed past the last epoch, so the
# personalisation (center-model freezing) phase never starts.
if(args.setting!='setting2'):
    args.checkpoint=args.epochs+10


# class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids)


#Assigning front, center and back models and their optimizers for all the clients
model = importlib.import_module(f'models.{args.model}')

for _, client in clients.items():
    client.front_model = model.front(input_channels, pretrained=args.pretrained)
    client.back_model = model.back(pretrained=args.pretrained)
print('Done')

for _, client in clients.items():
    # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9)
    # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9)
    client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr)
    client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr)

# Iteration counts are derived from the first client's loaders — assumes all
# clients hold equally sized splits (TODO confirm for every setting).
first_client = clients[client_ids[0]]
num_iterations = ceil(len(first_client.train_DataLoader.dataset)/args.batch_size)

num_test_iterations= ceil(len(first_client.test_DataLoader.dataset)/args.test_batch_size)
sc_clients = {} #server copy clients

for iden in client_ids:
    sc_clients[iden] = ConnectedClient(iden, None)

# Each server-side copy holds the shared "center" segment of the split model.
for _,s_client in sc_clients.items():
    s_client.center_model = model.center(pretrained=args.pretrained)
    s_client.center_model.to(device)
    # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9)
    s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr)

st = time.time()

macro_avg_f1_2classes=[]

criterion=F.cross_entropy
203 |
204 |
205 |
#Starting the training process
for epoch in range(args.epochs):
    if(epoch==args.checkpoint): # When starting epoch of the perosnalisation is reached, freeze all the layers of the center model
        print("freezing the center model")
        for _, s_client in sc_clients.items():
            s_client.center_model.freeze(epoch, pretrained=True)

    overall_train_acc.append(0)


    for _, client in clients.items():
        client.train_acc.append(0)
        client.iterator = iter(client.train_DataLoader)

    #For every batch in the current epoch
    for iteration in range(num_iterations):
        print(f'\rEpoch: {epoch+1}, Iteration: {iteration+1}/{num_iterations}', end='')

        # Split-learning forward pass: client front -> server center -> client back.
        for _, client in clients.items():
            client.forward_front()

        for client_id, client in sc_clients.items():
            client.remote_activations1 = clients[client_id].remote_activations1
            client.forward_center()

        for client_id, client in clients.items():
            client.remote_activations2 = sc_clients[client_id].remote_activations2
            client.forward_back()

        for _, client in clients.items():
            client.calculate_loss()

        # Backward pass in reverse order: back layers first, then center.
        for _, client in clients.items():
            client.backward_back()

        for client_id, client in sc_clients.items():
            client.remote_activations2 = clients[client_id].remote_activations2
            client.backward_center()

        for _, client in clients.items():
            client.step_back()
            client.zero_grad_back()

        #merge grads uncomment below

        # if epoch%2 == 0:
        #     params = []
        #     normalized_data_sizes = []
        #     for iden, client in clients.items():
        #         params.append(sc_clients[iden].center_model.parameters())
        #         normalized_data_sizes.append(len(client.train_dataset) / train_dataset_size)
        #     merge_grads(normalized_data_sizes, params)

        for _, client in sc_clients.items():
            client.center_optimizer.step()
            client.center_optimizer.zero_grad()

        for _, client in clients.items():
            client.train_acc[-1] += client.calculate_train_acc() #train accuracy of every client in the current epoch in the current batch

    for c_id, client in clients.items():
        client.train_acc[-1] /= num_iterations # train accuracy of every client of all the batches in the current epoch
        overall_train_acc[-1] += client.train_acc[-1]

    overall_train_acc[-1] /= len(clients) #avg train accuracy of all the clients in the current epoch


    # merge weights below uncomment
    # Federated averaging of the server-side center models after every epoch.
    params = []
    for _, client in sc_clients.items():
        params.append(copy.deepcopy(client.center_model.state_dict()))
    w_glob = merge_weights(params)

    for _, client in sc_clients.items():
        client.center_model.load_state_dict(w_glob)

    params = []

    #In the personalisation phase merging of weights of the back layers is stopped
    if(epoch <=args.checkpoint):
        for _, client in clients.items():
            params.append(copy.deepcopy(client.back_model.state_dict()))
        w_glob_cb = merge_weights(params)
        del params

        for _, client in clients.items():
            client.back_model.load_state_dict(w_glob_cb)

    #Testing every epoch
    if (epoch%1 == 0 ):
        if(epoch==args.checkpoint):
            print("freezing the center model")
            # NOTE(review): f1_avg_all_user is only assigned in the setting2
            # branch further below; if the checkpoint epoch arrives before the
            # first test pass this print raises NameError — confirm intended.
            print("F1 Score at epoch ", epoch, " : ",f1_avg_all_user)

            for _, s_client in sc_clients.items():
                s_client.center_model.freeze(epoch, pretrained=True)
        with torch.no_grad():
            test_acc = 0
            overall_test_acc.append(0)

            for _, client in clients.items():
                client.test_acc.append(0)
                client.iterator = iter(client.test_DataLoader)
                client.pred=[]
                client.y=[]

            #For every batch in the testing phase
            for iteration in range(num_test_iterations):

                for _, client in clients.items():
                    client.forward_front()

                for client_id, client in sc_clients.items():
                    client.remote_activations1 = clients[client_id].remote_activations1
                    client.forward_center()

                for client_id, client in clients.items():
                    client.remote_activations2 = sc_clients[client_id].remote_activations2
                    client.forward_back()

                for _, client in clients.items():
                    client.test_acc[-1] += client.calculate_test_acc()


            for _, client in clients.items():
                client.test_acc[-1] /= num_test_iterations
                overall_test_acc[-1] += client.test_acc[-1]
                #Calculating the F1 scores using the classification report from sklearn metrics
                if(args.setting=='setting2'):
                    clr=classification_report(np.array(client.y), np.array(client.pred), output_dict=True, zero_division=0)
                    idx=client_idxs[_]

                    macro_avg_f1_2classes.append((clr[str(idx)]['f1-score']+clr[str((idx+1)%10)]['f1-score'])/2) #macro f1 score of the 2 prominent classes in setting2


            overall_test_acc[-1] /= len(clients) #average test accuracy of all the clients in the current epoch

            if(args.setting=='setting2'):
                f1_avg_all_user=sum(macro_avg_f1_2classes)/len(macro_avg_f1_2classes) #average f1 scores of the clients for the prominent 2 classes in the current epoch
                macro_avg_f1_2classes=[]



                #Noting the maximum f1 score
                if(f1_avg_all_user> max_f1):
                    max_f1=f1_avg_all_user
                    max_epoch=epoch

            else:
                if(overall_test_acc[-1]> max_accuracy):
                    max_accuracy=overall_test_acc[-1]
                    max_epoch=epoch
358 |
359 |
360 |
# ---------------------------------------------------------------------------
# Post-training reporting: timing, best accuracy/F1, and confidence-interval
# plots of per-epoch train/test accuracy across clients.
timestamp = int(datetime.now().timestamp())
plot_config = f'''dataset: {args.dataset},
model: {args.model},
batch_size: {args.batch_size}, lr: {args.lr},
'''

et = time.time()
print("\nTraining Accuracy: ", overall_train_acc[max_epoch])
if(args.setting=='setting2'):
    print("Maximum F1 Score: ", max_f1)
else:
    print("Maximum Test Accuracy: ", max_accuracy)
print(f"Time taken for this run {(et - st)/60} mins")



# calculating the train and test standarad deviation and teh confidence intervals
X = range(args.epochs)
# Shape (num_clients, num_epochs): per-client accuracy histories stacked.
all_clients_stacked_train = np.array([client.train_acc for _,client in clients.items()])
all_clients_stacked_test = np.array([client.test_acc for _,client in clients.items()])
epochs_train_std = np.std(all_clients_stacked_train,axis = 0, dtype = np.float64)
epochs_test_std = np.std(all_clients_stacked_test,axis = 0, dtype = np.float64)

#Y_train is the average client train accuracies at each epoch
#epoch_train_std is the standard deviation of clients train accuracies at each epoch
# NOTE: Y_train / Y_test are plain Python lists; the arithmetic below relies on
# NumPy broadcasting to coerce them to arrays.
Y_train = overall_train_acc
Y_train_lower = Y_train - (1.65 * epochs_train_std) #95% of the values lie between 1.65*std
Y_train_upper = Y_train + (1.65 * epochs_train_std)

Y_test = overall_test_acc
Y_test_lower = Y_test - (1.65 * epochs_test_std) #95% of the values lie between 1.65*std
Y_test_upper = Y_test + (1.65 * epochs_test_std)

# Coefficient of variation (std / mean) per epoch.
# NOTE(review): divides by the accuracy — would emit inf/nan for a 0% epoch.
Y_train_cv = epochs_train_std / Y_train
Y_test_cv = epochs_test_std / Y_test

plt.figure(0)
plt.plot(X, Y_train)
plt.fill_between(X,Y_train_lower , Y_train_upper, color='blue', alpha=0.25)
# plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
plt.show()


plt.figure(1)
plt.plot(X, Y_test)
plt.fill_between(X,Y_test_lower , Y_test_upper, color='blue', alpha=0.25)
# plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
plt.show()


plt.figure(2)
plt.plot(X, Y_train_cv)
plt.show()


plt.figure(3)
plt.plot(X, Y_test_cv)
plt.show()
419 |
420 |
421 |
422 |
--------------------------------------------------------------------------------
/FL_Setting3.py:
--------------------------------------------------------------------------------
1 | #===========================================================
2 | # Federated learning: ResNet18
3 | # ===========================================================
4 | import torch
5 | from torch import nn
6 | from torchvision import transforms
7 | from torch.utils.data import DataLoader, Dataset
8 | from pandas import DataFrame
9 | import pandas as pd
10 | from sklearn.model_selection import train_test_split
11 | from PIL import Image
12 | from glob import glob
13 | import math
14 | import random
15 | import numpy as np
16 | import os
17 | import matplotlib
18 | matplotlib.use('Agg')
19 | import matplotlib.pyplot as plt
20 | import copy
21 | import argparse
22 | from utils import datasets,dataset_settings
23 | import time
24 |
25 |
26 |
27 | ## ARGPARSER
28 |
def parse_arguments(argv=None):
    """Parse command-line arguments for the FL ResNet18 run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) reads
            ``sys.argv[1:]``, preserving the original no-argument call style.

    Returns:
        argparse.Namespace holding the training configuration.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description="Splitfed V1 configurations",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="Random seed",
    )
    parser.add_argument(
        "-c",
        "--number_of_clients",
        type=int,
        default=5,
        metavar="C",
        help="Number of Clients",
    )
    parser.add_argument(
        "-n",
        "--epochs",
        type=int,
        default=50,
        metavar="N",
        help="Total number of epochs to train",
    )
    parser.add_argument(
        "--fac",
        type=float,
        default=1.0,
        metavar="N",
        help="fraction of active/participating clients, if 1 then all clients participate in SFLV1",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        metavar="LR",
        help="Learning rate",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        help="States dataset to be used",
    )
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=128,
        metavar="B",
        help="Batch size",
    )
    parser.add_argument(
        "--setting",
        type=str,
        default="setting1",
        help="Data-distribution setting",
    )
    parser.add_argument(
        "--test_batch_size",
        type=int,
        default=128,
        help="Batch size used during evaluation",
    )
    parser.add_argument(
        "--datapoints",
        type=int,
        default=500,
        help="Number of datapoints per client",
    )
    # Bug fix: parse_args() was previously called twice; once is enough.
    return parser.parse_args(argv)
114 |
115 |
116 |
117 | #==============================================================================================================
118 | # Client Side Program
119 | #==============================================================================================================
class DatasetSplit(Dataset):
    """Expose only the samples of ``dataset`` selected by ``idxs``."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        # Translate the local position into an index of the wrapped dataset.
        sample, target = self.dataset[self.idxs[item]]
        return sample, target
131 |
132 | # Client-side functions associated with Training and Testing
class LocalUpdate(object):
    """One federated client: owns its data loaders and runs local train/eval.

    Relies on module-level globals defined in this script: ``args`` (batch
    sizes), ``calculate_accuracy``, ``prRed``/``prGreen`` colored printers, and
    the ``unique_test`` list (client 0 appends its test accuracy there).
    """

    def __init__(self, idx, lr, device, dataset_train = None, dataset_test = None, idxs = None, idxs_test = None):
        self.idx = idx                      # client index (0 is the "large"/unique client)
        self.device = device
        self.lr = lr
        self.local_ep = 1                   # local epochs per federation round
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset_train, idxs), batch_size = args.batch_size, shuffle = True)
        self.ldr_test = DataLoader(DatasetSplit(dataset_test, idxs_test), batch_size = args.test_batch_size, shuffle = True)

    def train(self, net):
        """Train ``net`` on this client's data for ``self.local_ep`` epochs.

        Returns:
            (state_dict, mean epoch loss, mean epoch accuracy).
        """
        net.train()
        # train and update
        #optimizer = torch.optim.SGD(net.parameters(), lr = self.lr, momentum = 0.5)
        optimizer = torch.optim.Adam(net.parameters(), lr = self.lr)

        epoch_acc = []
        epoch_loss = []
        # Bug fix: the loop variable was named `iter`, shadowing the builtin.
        for ep in range(self.local_ep):
            batch_acc = []
            batch_loss = []

            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                #---------forward prop-------------
                fx = net(images)

                # calculate loss
                loss = self.loss_func(fx, labels)
                # calculate accuracy
                acc = calculate_accuracy(fx, labels)

                #--------backward prop--------------
                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())
                batch_acc.append(acc.item())

            # Only client 0 logs; `acc`/`loss` hold the last batch's values.
            if self.idx == 0:
                prRed('Client{} Train => Local Epoch: {} \tAcc: {:.3f} \tLoss: {:.4f}'.format(self.idx,
                    ep, acc.item(), loss.item()))

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
            epoch_acc.append(sum(batch_acc)/len(batch_acc))

        return net.state_dict(), sum(epoch_loss) / len(epoch_loss), sum(epoch_acc) / len(epoch_acc)

    def evaluate(self, net):
        """Evaluate ``net`` on this client's test loader.

        Returns:
            (mean loss, mean accuracy) over the test batches.
        """
        net.eval()

        epoch_acc = []
        epoch_loss = []
        with torch.no_grad():
            batch_acc = []
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_test):
                images, labels = images.to(self.device), labels.to(self.device)
                #---------forward prop-------------
                fx = net(images)

                # calculate loss
                loss = self.loss_func(fx, labels)
                # calculate accuracy
                acc = calculate_accuracy(fx, labels)

                batch_loss.append(loss.item())
                batch_acc.append(acc.item())

            prGreen('Client{} Test => \tLoss: {:.4f} \tAcc: {:.3f}'.format(self.idx, loss.item(), acc.item()))
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
            epoch_acc.append(sum(batch_acc)/len(batch_acc))
            # Track the large client's test accuracy separately.
            if self.idx == 0:
                unique_test.append(sum(batch_acc)/len(batch_acc))
        return sum(epoch_loss) / len(epoch_loss), sum(epoch_acc) / len(epoch_acc)
210 |
211 |
212 |
213 | #=====================================================================================================
214 | # dataset_iid() will create a dictionary to collect the indices of the data samples randomly for each client
215 |
216 |
def dataset_iid(dataset, num_users):
    """Randomly partition dataset indices into ``num_users`` equal IID shards.

    Returns a dict mapping user index -> set of sample indices; shards are
    disjoint and drawn without replacement.
    """
    shard_size = int(len(dataset) / num_users)
    remaining = [i for i in range(len(dataset))]
    dict_users = {}
    for user in range(num_users):
        dict_users[user] = set(np.random.choice(remaining, shard_size, replace = False))
        # Remove the assigned indices from the pool for the next user.
        remaining = list(set(remaining) - dict_users[user])
    return dict_users
225 |
def dataset_iid_setting3(u_dataset, c_dataset, num_users):
    """Setting-3 split: client 0 gets the whole unique dataset, the remaining
    ``num_users - 1`` clients split the common dataset IID.

    Returns a dict mapping user index -> set of sample indices (indices of
    ``u_dataset`` for user 0, of ``c_dataset`` for all others).
    """
    small_clients = num_users - 1
    shard_size = int(len(c_dataset) / small_clients)  # e.g. 150 per client

    dict_users = {}
    dict_users[0] = set(i for i in range(len(u_dataset)))
    pool = [i for i in range(len(c_dataset))]
    for user in range(1, small_clients + 1):
        dict_users[user] = set(np.random.choice(pool, shard_size, replace = False))
        pool = list(set(pool) - dict_users[user])
    return dict_users
239 |
240 | #====================================================================================================
241 | # Server Side Program
242 | #====================================================================================================
def calculate_accuracy(fx, y):
    """Return batch accuracy (%) of logits ``fx`` against integer labels ``y``."""
    predictions = fx.max(1, keepdim=True)[1]     # argmax over the class axis
    n_correct = predictions.eq(y.view_as(predictions)).sum()
    return 100.00 * n_correct.float() / predictions.shape[0]
248 |
249 | #=============================================================================
250 | # Model definition: ResNet18
251 | #=============================================================================
252 | # building a ResNet18 Architecture
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (bias-free, as usual before BatchNorm)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
257 |
258 |
class BasicBlock(nn.Module):
    """Standard two-convolution residual block (ResNet v1)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Module creation order (conv1, bn1, relu, conv2, bn2) is kept exactly
        # as before so seeded weight initialization stays reproducible.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """conv-bn-relu, conv-bn, add the (possibly downsampled) identity, relu."""
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)
289 |
class ResNet18(nn.Module):
    """ResNet-18 style network assembled from residual blocks.

    Args:
        block: residual block class (here BasicBlock, expansion 1).
        layers: blocks per stage, e.g. [2, 2, 2, 2] for ResNet-18.
        input_channels: channels of the input images (1 grayscale, 3 RGB).
        num_classes: size of the final fully-connected output.
    """

    def __init__(self, block, layers, input_channels, num_classes=1000):
        self.inplanes = 64
        super(ResNet18, self).__init__()
        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; channels double and spatial size halves from
        # stage 2 onward (stride=2).
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-style init for convs; BatchNorm starts as identity (w=1, b=0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; add a 1x1-conv downsample on the
        identity path whenever stride or channel count changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem -> 4 residual stages -> global average pool -> classifier."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        x = self.fc(x)

        return x
348 |
349 |
350 | #===========================================================================================
351 | # Federated averaging: FedAvg
def FedAvg(w):
    """Element-wise average of a list of model state_dicts (plain FedAvg).

    The inputs are not modified; the result starts as a deep copy of the
    first client's weights.
    """
    avg = copy.deepcopy(w[0])
    n_models = len(w)
    for key in avg.keys():
        for client_w in w[1:]:
            avg[key] += client_w[key]
        avg[key] = torch.div(avg[key], n_models)
    return avg
359 | #====================================================
360 |
361 |
if __name__ == "__main__":


    # Test-accuracy history of client 0 (the "unique"/large client).
    unique_test = []
    #===================================================================
    program = "FL ResNet18"
    print(f"---------{program}----------")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    args = parse_arguments()
    print(args)

    SEED = args.seed
    num_users = args.number_of_clients
    epochs = args.epochs
    frac = args.fac          # fraction of clients sampled each round
    lr = args.lr
    dataset = args.dataset

    # Grayscale datasets have a single input channel.
    if args.dataset == "mnist" or args.dataset == "fmnist":
        input_channels = 1
    else:
        input_channels = 3

    if args.dataset == "ham10k":
        no_classes = 7
    else:
        no_classes = 10

    # Seed every RNG in play for reproducibility.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)


    # To print in color during test/train
    def prRed(skk): print("\033[91m {}\033[00m" .format(skk))
    def prGreen(skk): print("\033[92m {}\033[00m" .format(skk))


    # CIFAR-style augmentation; normalization stats are the CIFAR-10 mean/std.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # NOTE(review): RandomCrop in the *test* transform makes evaluation
    # stochastic — confirm this is intended.
    transform_test = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # transform_train = None
    # transform_test = None

    # NOTE(review): the datasets and index dicts are only defined in this
    # branch; any other --dataset value raises NameError below — confirm only
    # cifar10_setting3 is supported by this script.
    if dataset == "cifar10_setting3":
        u_train_dataset,c_train_full_dataset, test_full_dataset, input_channels = datasets.load_full_dataset(dataset, "data", num_users, args.datapoints, transform_train, transform_test)
        dict_users = dataset_iid_setting3(u_train_dataset, c_train_full_dataset, num_users)
        dict_users_test = dataset_iid(test_full_dataset, num_users)


    # Global model shared by all clients.
    net_glob = ResNet18(BasicBlock, [2, 2, 2, 2],input_channels, no_classes)
    net_glob.to(device)
    # print(net_glob)

    net_glob.train()
    w_glob = net_glob.state_dict()

    loss_train_collect = []
    acc_train_collect = []
    loss_test_collect = []
    acc_test_collect = []

    st = time.time()


    # One federation round per outer iteration.
    # NOTE(review): the loop variable `iter` shadows the builtin.
    for iter in range(epochs):
        w_locals, loss_locals_train, acc_locals_train, loss_locals_test, acc_locals_test = [], [], [], [], []
        m = max(int(frac * num_users), 1)
        idxs_users = np.random.choice(range(num_users), m, replace = False)

        # Training/Testing simulation
        for idx in idxs_users: # each client


            # Client 0 trains on the unique dataset; all others share the
            # common dataset.
            if idx == 0:
                local = LocalUpdate(idx, lr, device, dataset_train = u_train_dataset, dataset_test = test_full_dataset, idxs = dict_users[idx], idxs_test = dict_users_test[idx])
            else:
                local = LocalUpdate(idx, lr, device, dataset_train = c_train_full_dataset, dataset_test = test_full_dataset, idxs = dict_users[idx], idxs_test = dict_users_test[idx])
            # Training ------------------
            w, loss_train, acc_train = local.train(net = copy.deepcopy(net_glob).to(device))
            w_locals.append(copy.deepcopy(w))
            loss_locals_train.append(copy.deepcopy(loss_train))
            acc_locals_train.append(copy.deepcopy(acc_train))
            # Testing -------------------
            loss_test, acc_test = local.evaluate(net = copy.deepcopy(net_glob).to(device))
            loss_locals_test.append(copy.deepcopy(loss_test))
            acc_locals_test.append(copy.deepcopy(acc_test))



        # Federation process
        w_glob = FedAvg(w_locals)
        print("------------------------------------------------")
        print("------ Federation process at Server-Side -------")
        print("------------------------------------------------")

        # update global model --- copy weight to net_glob -- distributed the model to all users
        net_glob.load_state_dict(w_glob)

        # Train/Test accuracy
        # NOTE(review): the [1:] slice drops the first *sampled* client from
        # the averages (presumably to exclude the large client 0), but
        # idxs_users is randomly ordered — confirm this matches the intent.
        acc_avg_train = sum(acc_locals_train[1:]) / len(acc_locals_train[1:])
        acc_train_collect.append(acc_avg_train)
        acc_avg_test = sum(acc_locals_test[1:]) / len(acc_locals_test[1:])
        acc_test_collect.append(acc_avg_test)

        # Train/Test loss
        loss_avg_train = sum(loss_locals_train) / len(loss_locals_train)
        loss_train_collect.append(loss_avg_train)
        loss_avg_test = sum(loss_locals_test) / len(loss_locals_test)
        loss_test_collect.append(loss_avg_test)


        print('------------------- SERVER ----------------------------------------------')
        print('Train: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(iter, acc_avg_train, loss_avg_train))
        print('Test: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(iter, acc_avg_test, loss_avg_test))
        print('-------------------------------------------------------------------------')


    #===================================================================================

    print("Training and Evaluation completed!")
    et = time.time()
    print(f"Total time taken is {(et-st)/60} mins")
    print("Average C1 - C10 test accuracy: ", max(acc_test_collect))
    print("Large Client Test Accuracy: ", max(unique_test))

    #===============================================================================
    # Save output data to .excel file (we use for comparision plots)
    round_process = [i for i in range(1, len(acc_train_collect)+1)]
    df = DataFrame({'round': round_process,'acc_train':acc_train_collect, 'acc_test':acc_test_collect})
    # NOTE(review): the filename says "setting2" although this script is
    # FL_Setting3 — confirm before comparing result files.
    file_name = f"results/FL/{program}_{args.batch_size}_{args.dataset}_{args.lr}_{args.epochs}_setting2"+".xlsx"
    df.to_excel(file_name, sheet_name= "v1_test", index = False)
507 |
508 |
509 |
510 | #=============================================================================
511 | # Program Completed
512 | #=============================================================================
513 |
--------------------------------------------------------------------------------
/PFSL_Setting3.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import string
4 | import socket
5 | import requests
6 | import sys
7 | import threading
8 | import time
9 | import torch
10 | from math import ceil
11 | from torchvision import transforms
12 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp, split_dataset_cifar_setting2
13 | from utils.client_simulation import generate_random_clients
14 | from utils.connections import send_object
15 | from utils.arg_parser import parse_arguments
16 | import matplotlib.pyplot as plt
17 | import time
18 | import multiprocessing
19 | import torch.optim as optim
20 | import copy
21 | from datetime import datetime
22 | from scipy.interpolate import make_interp_spline
23 | import numpy as np
24 | from ConnectedClient import ConnectedClient
25 | import importlib
26 | from utils.merge import merge_grads, merge_weights
27 | import pandas as pd
28 | import time
29 | from utils.split_dataset import DatasetFromSubset
30 | from utils import datasets,dataset_settings
31 | import torch.nn.functional as F
32 |
33 | #Helper function to load train and test data for each client
class DatasetSplit(torch.utils.data.Dataset):
    """Helper to load train and test data for each client: presents only the
    samples of ``dataset`` selected by ``idxs``."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        # Map the local position onto the wrapped dataset's index space.
        sample, target = self.dataset[self.idxs[item]]
        return sample, target
45 |
46 |
47 | class Client_try():
48 | def __init__(self, id, *args, **kwargs):
49 | super(Client_try, self).__init__(*args, **kwargs)
50 | self.id = id
51 | self.front_model = []
52 | self.back_model = []
53 | self.losses = []
54 | self.train_dataset = None
55 | self.test_dataset = None
56 | self.train_DataLoader = None
57 | self.test_DataLoader = None
58 | self.socket = None
59 | self.server_socket = None
60 | self.train_batch_size = None
61 | self.test_batch_size = None
62 | self.train_iterator = None
63 | self.test_iterator = None
64 | self.activations1 = None
65 | self.remote_activations1 = None
66 | self.outputs = None
67 | self.loss = None
68 | self.criterion = None
69 | self.data = None
70 | self.targets = None
71 | self.n_correct = 0
72 | self.n_samples = 0
73 | self.front_optimizer = None
74 | self.back_optimizer = None
75 | self.train_acc = []
76 | self.test_acc = []
77 | self.front_epsilons = []
78 | self.front_best_alphas = []
79 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
80 | # self.device = torch.device('cpu')
81 |
82 |
83 | def backward_back(self):
84 | self.loss.backward()
85 |
86 |
87 | def backward_front(self):
88 | print(self.remote_activations1.grad)
89 | self.activations1.backward(self.remote_activations1.grad)
90 |
91 |
92 | def calculate_loss(self):
93 | self.criterion = F.cross_entropy
94 | self.loss = self.criterion(self.outputs, self.targets)
95 |
96 |
97 | def calculate_test_acc(self):
98 | with torch.no_grad():
99 | _, self.predicted = torch.max(self.outputs.data, 1)
100 | self.n_correct = (self.predicted == self.targets).sum().item()
101 | self.n_samples = self.targets.size(0)
102 | # self.test_acc.append(100.0 * self.n_correct/self.n_samples)
103 | return 100.0 * self.n_correct/self.n_samples
104 | # print(f'Acc: {self.test_acc[-1]}')
105 |
106 |
107 | def calculate_train_acc(self):
108 | with torch.no_grad():
109 | _, self.predicted = torch.max(self.outputs.data, 1)
110 | self.n_correct = (self.predicted == self.targets).sum().item()
111 | self.n_samples = self.targets.size(0)
112 | # self.train_acc.append(100.0 * self.n_correct/self.n_samples)
113 | return 100.0 * self.n_correct/self.n_samples
114 | # print(f'Acc: {self.train_acc[-1]}')
115 |
116 | def create_DataLoader(self, dataset_train,dataset_test,idxs,idxs_test, batch_size, test_batch_size):
117 | self.train_DataLoader = torch.utils.data.DataLoader(DatasetSplit(dataset_train, idxs), batch_size = batch_size, shuffle = True)
118 | self.test_DataLoader = torch.utils.data.DataLoader(DatasetSplit(dataset_test, idxs_test), batch_size = test_batch_size, shuffle = True)
119 |
120 | def forward_back(self):
121 | self.back_model.to(self.device)
122 | self.outputs = self.back_model(self.remote_activations2)
123 |
124 |
125 | def forward_front(self, type, u_id = None):
126 |
127 | if type == "train":
128 |
129 | if self.id == u_id:
130 |
131 | try:
132 | self.data, self.targets = next(self.train_iterator)
133 | except StopIteration:
134 | self.train_iterator = iter(self.train_DataLoader)
135 | self.data, self.targets = next(self.train_iterator)
136 | else:
137 | self.data, self.targets = next(self.train_iterator)
138 |
139 | else:
140 | self.data, self.targets = next(self.test_iterator)
141 | self.data, self.targets = self.data.to(self.device), self.targets.to(self.device)
142 | self.front_model.to(self.device)
143 | self.activations1 = self.front_model(self.data)
144 | self.remote_activations1 = self.activations1.detach().requires_grad_(True)
145 |
146 |
147 | def get_model(self):
148 | model = get_object(self.socket)
149 | self.front_model = model['front']
150 | self.back_model = model['back']
151 |
152 | def idle(self):
153 | pass
154 |
155 |
156 | def load_data(self, dataset, transform):
157 | try:
158 | dataset_path = os.path.join(f'data/{dataset}/{self.id}')
159 | except:
160 | raise Exception(f'Dataset not found for client {self.id}')
161 | self.train_dataset = torch.load(f'{dataset_path}/train/{self.id}.pt')
162 | self.test_dataset = torch.load('data/cifar10_setting2/test/common_test.pt')
163 |
164 | self.train_dataset = DatasetFromSubset(
165 | self.train_dataset, transform=transform
166 | )
167 | self.test_dataset = DatasetFromSubset(
168 | self.test_dataset, transform=transform
169 | )
170 |
171 |
    def step_front(self):
        # Apply one optimizer step to the front (client-side input) model.
        self.front_optimizer.step()
174 |
175 |
    def step_back(self):
        # Apply one optimizer step to the back (client-side output) model.
        self.back_optimizer.step()
178 |
179 |
    def zero_grad_front(self):
        # Clear accumulated gradients of the front model's optimizer.
        self.front_optimizer.zero_grad()
182 |
183 |
    def zero_grad_back(self):
        # Clear accumulated gradients of the back model's optimizer.
        self.back_optimizer.zero_grad()
186 |
187 |
188 |
189 |
def generate_random_client_ids_try(num_clients, id_len=4) -> list:
    """Generate `num_clients` *distinct* random alphanumeric client ids.

    Each id is `id_len` characters sampled without replacement (within an id)
    from lowercase letters and digits.

    BUG FIX: the original could return duplicate ids; since the ids are later
    used as dict keys, duplicates silently collapsed clients. Candidates are
    now rejected until all ids are unique.

    Args:
        num_clients: number of ids to produce.
        id_len: length of each id (default 4).

    Returns:
        A list of `num_clients` unique id strings, in generation order.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz1234567890"
    client_ids = []
    seen = set()
    while len(client_ids) < num_clients:
        candidate = ''.join(random.sample(alphabet, id_len))
        if candidate not in seen:
            seen.add(candidate)
            client_ids.append(candidate)
    return client_ids
195 |
def generate_random_clients_try(num_clients) -> dict:
    """Create `num_clients` Client_try objects keyed by freshly generated random ids."""
    return {cid: Client_try(cid) for cid in generate_random_client_ids_try(num_clients)}
202 |
203 |
204 |
def initialize_client(client, dataset_train,dataset_test,idxs,idxs_test, batch_size, test_batch_size):
    # Thin wrapper so per-client setup can be invoked uniformly in a loop;
    # builds the client's train/test DataLoaders over its index subsets.
    client.create_DataLoader(dataset_train,dataset_test, idxs, idxs_test, batch_size, test_batch_size)
207 |
#Plots class distribution of train data available to each client
def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids):
    """Plot per-client histograms and overlaid line graphs of training-label
    frequencies; return {client_id: label value_counts} for the plotted clients.

    At most 20 clients are plotted (a random sample when there are more).
    NOTE(review): `ax[i,j]` assumes plt.subplots returned a 2-D axes grid,
    i.e. nrows >= 2 (more than 5 client ids) — confirm callers never pass
    5 or fewer ids. The x-ticks assume 10 classes (CIFAR-10 style).
    """
    class_distribution=dict()
    number_of_clients=len(client_ids)
    if(len(clients)<=20):
        plot_for_clients=client_ids
    else:
        plot_for_clients=random.sample(client_ids, 20)

    fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10))
    j=0
    i=0

    #plot histogram
    for client_id in plot_for_clients:
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        class_distribution[client_id]=df['labels'].value_counts().sort_index()
        df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id)
        j+=1
        # Advance to the next subplot row every 5 columns. Since j resets to 0
        # at 5, the j==10 / j==15 comparisons can never be true (harmless).
        if(j==5 or j==10 or j==15):
            i+=1
            j=0
    fig.tight_layout()
    plt.show()
    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png')

    max_len=0
    #plot line graphs
    for client_id in plot_for_clients:
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id)
        # Track the largest frequency of any client's modal label for the y-limit.
        max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0])
    plt.xticks(np.arange(0,10))
    plt.ylim(0, max_len)
    plt.legend()
    plt.show()

    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png')

    return class_distribution
248 |
if __name__ == "__main__":


    args = parse_arguments()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Arguments provided", args)


    # Seed Python and torch RNGs for reproducible client ids / splits / init.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    overall_test_acc = []
    overall_train_acc = []

    print('Generating random clients...', end='')
    clients = generate_random_clients_try(args.number_of_clients)
    client_ids = list(clients.keys())
    print('Done')

    train_full_dataset, test_full_dataset, input_channels = datasets.load_full_dataset(args.dataset, "data", args.number_of_clients, args.datapoints, args.pretrained)
    dict_users_train , dict_users_test = split_dataset_cifar_setting2(client_ids, train_full_dataset, test_full_dataset)

    transform=None

    #Assigning train and test data to each client depending for each client
    print('Initializing clients...')
    for i,(_, client) in enumerate(clients.items()):
        initialize_client(client, train_full_dataset, test_full_dataset, dict_users_train[i], dict_users_test[i], args.batch_size, args.test_batch_size)

    print('Client Intialization complete.')
    # Train and test data intialisation complete

    # class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids)

    #Assigning front, center and back models and their optimizers for all the clients
    model = importlib.import_module(f'models.{args.model}')

    for _, client in clients.items():
        client.front_model = model.front(input_channels, pretrained=args.pretrained)
        client.back_model = model.back(pretrained=args.pretrained)
    print('Done')


    for _, client in clients.items():
        # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9)
        # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9)
        client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr)
        client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr)

    # Client 0 is the "unique"/large client; client 1 is representative of the
    # small ("common") clients and sizes the per-epoch iteration counts.
    common_client = clients[client_ids[1]]
    unique_client_id = client_ids[0]
    unique_client = clients[unique_client_id]
    num_iterations_common = ceil(len(common_client.train_DataLoader.dataset)/args.batch_size)
    num_test_iterations = ceil(len(common_client.test_DataLoader.dataset)/args.test_batch_size)
    unique_client_iterations = ceil(len(unique_client.train_DataLoader.dataset)/args.batch_size)
    # The unique client's iterator persists across epochs (it cycles inside
    # forward_front), so it is created once here.
    unique_client.train_iterator = iter(unique_client.train_DataLoader)

    sc_clients = {} #server copy clients

    for iden in client_ids:
        sc_clients[iden] = ConnectedClient(iden, None)

    for _,s_client in sc_clients.items():
        s_client.center_model = model.center(pretrained=args.pretrained)
        s_client.center_model.to(device)
        # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9)
        s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr)


    st = time.time()

    # Best-so-far accuracy trackers, updated each epoch / test pass.
    max_train_small_clients = 0
    max_test_small_clients = 0
    max_train_large_client = 0
    max_test_large_client = 0

    #Starting the training process
326 | for epoch in range(args.epochs):
327 |
328 | overall_train_acc.append(0)
329 | for i,(_, client) in enumerate(clients.items()):
330 | client.train_acc.append(0)
331 | if i != 0:
332 | client.train_iterator = iter(client.train_DataLoader)
333 |
334 | for c_id, client in clients.items():
335 | #For every batch in the current epoch
336 | for iteration in range(num_iterations_common):
337 |
338 |
339 | if c_id == unique_client_id:
340 | client.forward_front("train", u_id = c_id)
341 | else:
342 | client.forward_front("train")
343 |
344 | sc_clients[c_id].remote_activations1 = clients[c_id].remote_activations1
345 | sc_clients[c_id].forward_center()
346 | client.remote_activations2 = sc_clients[c_id].remote_activations2
347 | client.forward_back()
348 | client.calculate_loss()
349 | client.backward_back()
350 | sc_clients[c_id].remote_activations2 = clients[c_id].remote_activations2
351 | sc_clients[c_id].backward_center()
352 |
353 | # client.remote_activations1 = copy.deepcopy(sc_clients[client_id].remote_activations1)
354 | # client.backward_front()
355 |
356 | client.step_back()
357 | client.zero_grad_back()
358 | sc_clients[c_id].center_optimizer.step()
359 | sc_clients[c_id].center_optimizer.zero_grad()
360 | client.train_acc[-1] += client.calculate_train_acc()
361 |
362 | client.train_acc[-1] /= num_iterations_common
363 | if c_id == unique_client_id:
364 | print("Large client train accuracy", client.train_acc[-1])
365 | if client.train_acc[-1] >= max_train_large_client:
366 | max_train_large_client = client.train_acc[-1]
367 | else:
368 | overall_train_acc[-1] += client.train_acc[-1]
369 |
370 | overall_train_acc[-1] /= (len(clients)-1)
371 | if overall_train_acc[-1] >= max_train_small_clients:
372 | max_train_small_clients = overall_train_acc[-1]
373 | print(f'Epoch {epoch} C1-C10 Average Train Acc: {overall_train_acc[-1]}\n')
374 |
375 | # merge weights below uncomment
376 | params = []
377 | for _, client in sc_clients.items():
378 | params.append(copy.deepcopy(client.center_model.state_dict()))
379 | w_glob = merge_weights(params)
380 |
381 | for _, client in sc_clients.items():
382 | client.center_model.load_state_dict(w_glob)
383 |
384 | params = []
385 | for _, client in clients.items():
386 | params.append(copy.deepcopy(client.back_model.state_dict()))
387 | w_glob_cb = merge_weights(params)
388 |
389 | for _, client in clients.items():
390 | client.back_model.load_state_dict(w_glob_cb)
391 |
392 |
393 |
394 |
395 |
396 | # Testing on every 5th epoch
397 | if epoch%5 == 0:
398 | with torch.no_grad():
399 | test_acc = 0
400 | overall_test_acc.append(0)
401 | for _, client in clients.items():
402 | client.test_acc.append(0)
403 | client.test_iterator = iter(client.test_DataLoader)
404 |
405 | for client_id, client in clients.items():
406 | for iteration in range(num_test_iterations):
407 |
408 | client.forward_front("test")
409 | sc_clients[client_id].remote_activations1 = clients[client_id].remote_activations1
410 | sc_clients[client_id].forward_center()
411 | client.remote_activations2 = sc_clients[client_id].remote_activations2
412 | client.forward_back()
413 | client.test_acc[-1] += client.calculate_test_acc()
414 |
415 | client.test_acc[-1] /= num_test_iterations
416 | if client_id == unique_client_id:
417 | print("Large client test accuracy", client.test_acc[-1])
418 | if client.test_acc[-1] >= max_test_large_client:
419 | max_test_large_client = client.train_acc[-1]
420 | else:
421 | overall_test_acc[-1] += client.test_acc[-1] #not including test accuracy of unique client
422 |
423 | overall_test_acc[-1] /= (len(clients)-1)
424 |
425 | if overall_test_acc[-1] >= max_test_small_clients:
426 | max_test_small_clients = overall_test_acc[-1]
427 | print(f' Epoch {epoch} C1-C10 Average Test Acc: {overall_test_acc[-1]}\n')
428 |
    # print("Average C1 - C10 Train accuracy: ", max_train_small_clients)
    # print("Large Client train Accuracy: ", max_train_large_client)
    # Best (max-over-epochs) test accuracies observed during the run.
    print("Average C1 - C10 test accuracy: ", max_test_small_clients)
    print("Large Client test Accuracy: ", max_test_large_client)




    # timestamp/plot_config label saved figures; the savefig calls below are
    # commented out, so these are currently unused.
    timestamp = int(datetime.now().timestamp())
    plot_config = f'''dataset: {args.dataset},
    model: {args.model},
    batch_size: {args.batch_size}, lr: {args.lr},
    '''

    et = time.time()
    print(f"Time taken for this run {(et - st)/60} mins")
445 |
446 |
447 | #BELOW CODE TO PLOT MULTIPLE LINES ON A SINGLE PLOT ONE LINE FOR EACH CLIENT
448 | # for client_id, client in clients.items():
449 | # plt.plot(list(range(args.epochs)), client.train_acc, label=f'{client_id} (Max:{max(client.train_acc):.4f})')
450 | # plt.plot(list(range(args.epochs)), overall_train_acc, label=f'Average (Max:{max(overall_train_acc):.4f})')
451 | # plt.title(f'{args.number_of_clients} Clients: Train Accuracy vs. Epochs')
452 | # plt.ylabel('Train Accuracy')
453 | # plt.xlabel('Epochs')
454 | # plt.legend()
455 | # plt.ioff()
456 | # plt.savefig(f'./results/train_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
457 | # plt.show()
458 |
459 | # for client_id, client in clients.items():
460 | # plt.plot(list(range(args.epochs)), client.test_acc, label=f'{client_id} (Max:{max(client.test_acc):.4f})')
461 | # plt.plot(list(range(args.epochs)), overall_test_acc, label=f'Average (Max:{max(overall_test_acc):.4f})')
462 | # plt.title(f'{args.number_of_clients} Clients: Test Accuracy vs. Epochs')
463 | # plt.ylabel('Test Accuracy')
464 | # plt.xlabel('Epochs')
465 | # plt.legend()
466 | # plt.ioff()
467 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
468 | # plt.show()
469 |
--------------------------------------------------------------------------------
/system_simulation_e2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import string
4 | import socket
5 | import requests
6 | import sys
7 | import threading
8 | import time
9 | import torch
10 | from math import ceil
11 | from torchvision import transforms
12 | from utils.split_dataset import split_dataset, split_dataset_cifar10tl_exp, split_dataset_cifar_setting2
13 | from utils.client_simulation import generate_random_clients
14 | from utils.connections import send_object
15 | from utils.arg_parser import parse_arguments
16 | import matplotlib.pyplot as plt
17 | import time
18 | import multiprocessing
19 | import torch.optim as optim
20 | import copy
21 | from datetime import datetime
22 | from scipy.interpolate import make_interp_spline
23 | import numpy as np
24 | from ConnectedClient import ConnectedClient
25 | import importlib
26 | from utils.merge import merge_grads, merge_weights
27 | import pandas as pd
28 | import time
29 | from utils.split_dataset import DatasetFromSubset
30 | from utils import datasets,dataset_settings
31 | import torch.nn.functional as F
32 |
class DatasetSplit(torch.utils.data.Dataset):
    """A view over `dataset` restricted to the positions listed in `idxs`.

    Item `i` of this wrapper is item `idxs[i]` of the wrapped dataset; each
    item is expected to be an (image, label) pair.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise the indices so generators/sets are accepted and len() works.
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        pair = self.dataset[self.idxs[item]]
        image, label = pair
        return image, label
44 |
45 |
class Client_try():
    """Simulated split-learning client.

    Holds the client-side front and back model halves, the client's private
    train/test DataLoaders, and scratch state (current batch, activations,
    loss) that the training driver reads and writes between method calls.
    """

    def __init__(self, id, *args, **kwargs):
        super(Client_try, self).__init__(*args, **kwargs)
        self.id = id
        # Model halves are assigned by the driver after construction.
        self.front_model = []
        self.back_model = []
        self.losses = []
        self.train_dataset = None
        self.test_dataset = None
        self.train_DataLoader = None
        self.test_DataLoader = None
        self.socket = None
        self.server_socket = None
        self.train_batch_size = None
        self.test_batch_size = None
        self.train_iterator = None
        self.test_iterator = None
        # Scratch state populated during forward/backward passes.
        self.activations1 = None
        self.remote_activations1 = None
        self.outputs = None
        self.loss = None
        self.criterion = None
        self.data = None
        self.targets = None
        self.n_correct = 0
        self.n_samples = 0
        self.front_optimizer = None
        self.back_optimizer = None
        self.train_acc = []
        self.test_acc = []
        self.front_epsilons = []
        self.front_best_alphas = []
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.device = torch.device('cpu')


    def backward_back(self):
        # Backpropagate the current loss through the back model (stops at the
        # detached split-layer activations).
        self.loss.backward()


    def backward_front(self):
        # Continue backprop through the front model using the split-layer
        # gradient received from the server.
        print(self.remote_activations1.grad)
        self.activations1.backward(self.remote_activations1.grad)


    def calculate_loss(self):
        # Cross-entropy between the back model's outputs and current targets.
        self.criterion = F.cross_entropy
        self.loss = self.criterion(self.outputs, self.targets)


    def calculate_test_acc(self):
        """Return the test accuracy (%) of the current batch's predictions."""
        with torch.no_grad():
            _, self.predicted = torch.max(self.outputs.data, 1)
            self.n_correct = (self.predicted == self.targets).sum().item()
            self.n_samples = self.targets.size(0)
            # self.test_acc.append(100.0 * self.n_correct/self.n_samples)
            return 100.0 * self.n_correct/self.n_samples
            # print(f'Acc: {self.test_acc[-1]}')


    def calculate_train_acc(self):
        """Return the train accuracy (%) of the current batch's predictions."""
        with torch.no_grad():
            _, self.predicted = torch.max(self.outputs.data, 1)
            self.n_correct = (self.predicted == self.targets).sum().item()
            self.n_samples = self.targets.size(0)
            # self.train_acc.append(100.0 * self.n_correct/self.n_samples)
            return 100.0 * self.n_correct/self.n_samples
            # print(f'Acc: {self.train_acc[-1]}')

    def create_DataLoader(self, dataset_train,dataset_test,idxs,idxs_test, batch_size, test_batch_size):
        # Build shuffled train/test loaders over this client's index subsets.
        self.train_DataLoader = torch.utils.data.DataLoader(DatasetSplit(dataset_train, idxs), batch_size = batch_size, shuffle = True)
        self.test_DataLoader = torch.utils.data.DataLoader(DatasetSplit(dataset_test, idxs_test), batch_size = test_batch_size, shuffle = True)

    def forward_back(self):
        # Run the back model on the activations produced by the server copy.
        self.back_model.to(self.device)
        self.outputs = self.back_model(self.remote_activations2)


    def forward_front(self, type):
        """Fetch the next train/test batch and run the front model on it.

        `type == "train"` draws from the train iterator, anything else from
        the test iterator. Raises StopIteration when the iterator is
        exhausted; the driver is responsible for re-creating iterators.
        """

        if type == "train":
            self.data, self.targets = next(self.train_iterator)
        else:
            self.data, self.targets = next(self.test_iterator)
        self.data, self.targets = self.data.to(self.device), self.targets.to(self.device)
        self.front_model.to(self.device)
        self.activations1 = self.front_model(self.data)
        # Detach so the client-side graph ends at the split layer; the server
        # computes gradients w.r.t. this tensor.
        self.remote_activations1 = self.activations1.detach().requires_grad_(True)


    def get_model(self):
        # Receive the {front, back} model halves over this client's socket.
        model = get_object(self.socket)
        self.front_model = model['front']
        self.back_model = model['back']

    def idle(self):
        # Intentional no-op: placeholder for rounds with no work.
        pass


    def load_data(self, dataset, transform):
        """Load this client's train split and the shared test split from disk.

        NOTE(review): os.path.join on a plain string cannot raise, so the
        bare except / "Dataset not found" branch is dead code; an explicit
        directory-existence check would realise the evident intent. The test
        path is hard-coded to the cifar10 setting-2 common split regardless
        of `dataset` — confirm intended.
        """
        try:
            dataset_path = os.path.join(f'data/{dataset}/{self.id}')
        except:
            raise Exception(f'Dataset not found for client {self.id}')
        self.train_dataset = torch.load(f'{dataset_path}/train/{self.id}.pt')
        self.test_dataset = torch.load('data/cifar10_setting2/test/common_test.pt')

        self.train_dataset = DatasetFromSubset(
            self.train_dataset, transform=transform
        )
        self.test_dataset = DatasetFromSubset(
            self.test_dataset, transform=transform
        )


    def step_front(self):
        # One optimizer step for the front model.
        self.front_optimizer.step()


    def step_back(self):
        # One optimizer step for the back model.
        self.back_optimizer.step()


    def zero_grad_front(self):
        # Clear accumulated gradients of the front optimizer.
        self.front_optimizer.zero_grad()


    def zero_grad_back(self):
        # Clear accumulated gradients of the back optimizer.
        self.back_optimizer.zero_grad()
175 |
176 |
177 |
178 |
def generate_random_client_ids_try(num_clients, id_len=4) -> list:
    """Generate `num_clients` *distinct* random alphanumeric client ids.

    BUG FIX: the original could emit duplicate ids; since the ids become dict
    keys, duplicates silently collapsed clients. Candidates are now rejected
    until all ids are unique.

    Args:
        num_clients: number of ids to produce.
        id_len: length of each id (default 4).

    Returns:
        A list of `num_clients` unique id strings, in generation order.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz1234567890"
    client_ids = []
    seen = set()
    while len(client_ids) < num_clients:
        candidate = ''.join(random.sample(alphabet, id_len))
        if candidate not in seen:
            seen.add(candidate)
            client_ids.append(candidate)
    return client_ids
184 |
def generate_random_clients_try(num_clients) -> dict:
    """Create `num_clients` Client_try objects keyed by freshly generated random ids."""
    return {cid: Client_try(cid) for cid in generate_random_client_ids_try(num_clients)}
191 |
192 |
193 |
def initialize_client(client, dataset_train,dataset_test,idxs,idxs_test, batch_size, test_batch_size):
    # Thin wrapper: builds the client's train/test DataLoaders over its index subsets.
    client.create_DataLoader(dataset_train,dataset_test, idxs, idxs_test, batch_size, test_batch_size)
196 |
197 |
def select_random_clients(clients):
    """Pick one client uniformly at random and return it as an {id: client} dict.

    BUG FIX: the original assigned a single id *string* to `random_client_ids`
    and then iterated over it, which loops over the id's characters and raises
    KeyError on the first lookup. The chosen id is now wrapped in a list so
    the loop sees whole ids.
    """
    random_clients = {}
    client_ids = list(clients.keys())
    random_index = random.randint(0,len(client_ids)-1)
    random_client_ids = [client_ids[random_index]]

    print(random_client_ids)
    print(clients)

    for random_client_id in random_client_ids:
        random_clients[random_client_id] = clients[random_client_id]
    return random_clients
210 |
211 |
def plot_class_distribution(clients, dataset, batch_size, epochs, opt, client_ids):
    """Plot per-client histograms and overlaid line graphs of training-label
    frequencies; return {client_id: label value_counts} for the plotted clients.

    At most 20 clients are plotted (a random sample when there are more).
    NOTE(review): `ax[i,j]` assumes plt.subplots returned a 2-D axes grid,
    i.e. nrows >= 2 (more than 5 client ids) — confirm callers never pass
    5 or fewer ids. The x-ticks assume 10 classes (CIFAR-10 style).
    """
    class_distribution=dict()
    number_of_clients=len(client_ids)
    if(len(clients)<=20):
        plot_for_clients=client_ids
    else:
        plot_for_clients=random.sample(client_ids, 20)

    fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10))
    j=0
    i=0

    #plot histogram
    for client_id in plot_for_clients:
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        class_distribution[client_id]=df['labels'].value_counts().sort_index()
        df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id)
        j+=1
        # Advance to the next subplot row every 5 columns. Since j resets to 0
        # at 5, the j==10 / j==15 comparisons can never be true (harmless).
        if(j==5 or j==10 or j==15):
            i+=1
            j=0
    fig.tight_layout()
    plt.show()
    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_histogram.png')

    max_len=0
    #plot line graphs
    for client_id in plot_for_clients:
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id)
        # Track the largest frequency of any client's modal label for the y-limit.
        max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0])
    plt.xticks(np.arange(0,10))
    plt.ylim(0, max_len)
    plt.legend()
    plt.show()
    # plt.savefig(f'./results/class_vs_freq/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png')

    return class_distribution
250 |
if __name__ == "__main__":


    args = parse_arguments()
    print(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Arguments provided", args)




    # Seed Python and torch RNGs for reproducible client ids / splits / init.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    overall_test_acc = []
    overall_train_acc = []

    print('Generating random clients...', end='')
    clients = generate_random_clients_try(args.number_of_clients)
    client_ids = list(clients.keys())
    print('Done')

    train_full_dataset, test_full_dataset, input_channels = datasets.load_full_dataset(args.dataset, "data", args.number_of_clients, args.datapoints, args.pretrained)
    dict_users_train , dict_users_test = split_dataset_cifar_setting2(client_ids, train_full_dataset, test_full_dataset,5000,5000)

    # print(f'Random client ids:{str(client_ids)}')
    transform=None


    print('Initializing clients...')
    for i,(_, client) in enumerate(clients.items()):
        initialize_client(client, train_full_dataset, test_full_dataset, dict_users_train[i], dict_users_test[i], args.batch_size, args.test_batch_size)
    # if(args.dataset!='ham10000')
    # class_distribution=plot_class_distribution(clients, args.dataset, args.batch_size, args.epochs, args.opt_iden, client_ids)
    print('Client Intialization complete.')
    # Resolve the model module (front/center/back factory functions) by name.
    model = importlib.import_module(f'models.{args.model}')

    for _, client in clients.items():
        client.front_model = model.front(input_channels, pretrained=args.pretrained)
        client.back_model = model.back(pretrained=args.pretrained)
    print('Done')

    for _, client in clients.items():

        # client.front_optimizer = optim.SGD(client.front_model.parameters(), lr=args.lr, momentum=0.9)
        # client.back_optimizer = optim.SGD(client.back_model.parameters(), lr=args.lr, momentum=0.9)
        client.front_optimizer = optim.Adam(client.front_model.parameters(), lr=args.lr)
        client.back_optimizer = optim.Adam(client.back_model.parameters(), lr=args.lr)

301 | sample_client = clients[client_ids[0]]
302 | #number of iterations will be (di)/b.
303 | num_iterations = ceil(50 // args.batch_size)
304 | num_test_iterations = ceil(len(sample_client.test_DataLoader.dataset)//args.test_batch_size)
305 | # print(num_iterations)
306 | # print(num_test_iterations)
307 |
308 |
    #define train iterators for all clients
    for _,client in clients.items():
        client.train_iterator = iter(client.train_DataLoader)


    sc_clients = {} #server copy clients

    # One server-side ConnectedClient (holding a center model) per client.
    for iden in client_ids:
        sc_clients[iden] = ConnectedClient(iden, None)

    for _,s_client in sc_clients.items():
        s_client.center_model = model.center(pretrained=args.pretrained)
        s_client.center_model.to(device)
        # s_client.center_optimizer = optim.SGD(s_client.center_model.parameters(), lr=args.lr, momentum=0.9)
        s_client.center_optimizer = optim.Adam(s_client.center_model.parameters(), args.lr)

    st = time.time()
326 |
327 | for epoch in range(args.epochs):
328 |
329 | overall_train_acc.append(0)
330 | for i,(_, client) in enumerate(clients.items()):
331 | client.train_acc.append(0)
332 |
333 | for c_id, client in clients.items():
334 | for iteration in range(num_iterations):
335 |
336 | client.forward_front("train")
337 | sc_clients[c_id].remote_activations1 = clients[c_id].remote_activations1
338 | sc_clients[c_id].forward_center()
339 | client.remote_activations2 = sc_clients[c_id].remote_activations2
340 | client.forward_back()
341 | client.calculate_loss()
342 | client.backward_back()
343 | sc_clients[c_id].remote_activations2 = clients[c_id].remote_activations2
344 | sc_clients[c_id].backward_center()
345 |
346 | # client.remote_activations1 = copy.deepcopy(sc_clients[client_id].remote_activations1)
347 | # client.backward_front()
348 |
349 | client.step_back()
350 | client.zero_grad_back()
351 | sc_clients[c_id].center_optimizer.step()
352 | sc_clients[c_id].center_optimizer.zero_grad()
353 | client.train_acc[-1] += client.calculate_train_acc()
354 |
355 | client.train_acc[-1] /= num_iterations
356 | overall_train_acc[-1] += client.train_acc[-1]
357 |
358 | overall_train_acc[-1] /= len(clients)
359 | print(f'Epoch {epoch} Personalized Average Train Acc: {overall_train_acc[-1]}')
360 |
361 | num_clients = len(client_ids)
362 | drop_clients_ids = []
363 | rate = args.rate
364 | num_dropoff = int(rate * num_clients)
365 | print("number of clients dropped off", num_dropoff)
366 |
367 | for _ in range(num_dropoff):
368 | drop_clients_ids.append(int(random.uniform(0,(num_clients-1))))
369 |
370 |
371 | # merge weights below uncomment
372 | params = []
373 | for i,(_, client) in enumerate(sc_clients.items()):
374 | if i not in drop_clients_ids:
375 | print(i, "ids of clients that are considered for weight merging")
376 | params.append(copy.deepcopy(client.center_model.state_dict()))
377 | w_glob = merge_weights(params)
378 |
379 | for _, client in sc_clients.items():
380 | client.center_model.load_state_dict(w_glob)
381 |
382 | params = []
383 | for i,(_, client) in enumerate(clients.items()):
384 | if i not in drop_clients_ids:
385 | params.append(copy.deepcopy(client.back_model.state_dict()))
386 | w_glob_cb = merge_weights(params)
387 |
388 | for _, client in clients.items():
389 | client.back_model.load_state_dict(w_glob_cb)
390 |
391 |
392 |
393 | if epoch%1 == 0:
394 | with torch.no_grad():
395 | test_acc = 0
396 | overall_test_acc.append(0)
397 | for _, client in clients.items():
398 | client.test_acc.append(0)
399 | client.test_iterator = iter(client.test_DataLoader)
400 |
401 | for client_id, client in clients.items():
402 | for iteration in range(num_test_iterations):
403 |
404 | client.forward_front("test")
405 | sc_clients[client_id].remote_activations1 = clients[client_id].remote_activations1
406 | sc_clients[client_id].forward_center()
407 | client.remote_activations2 = sc_clients[client_id].remote_activations2
408 | client.forward_back()
409 | client.test_acc[-1] += client.calculate_test_acc()
410 |
411 | client.test_acc[-1] /= num_test_iterations
412 | overall_test_acc[-1] += client.test_acc[-1] #not including test accuracy of unique client
413 |
414 | overall_test_acc[-1] /= len(clients)
415 | print(f' Personalized Average Test Acc: {overall_test_acc[-1]}')
416 |
417 |
418 |
419 |
    timestamp = int(datetime.now().timestamp())
    # plot_config labels saved figures; the savefig calls below are commented
    # out, so it is currently unused.
    plot_config = f'''dataset: {args.dataset},
    model: {args.model},
    batch_size: {args.batch_size}, lr: {args.lr},
    '''

    et = time.time()
    print(f"Time taken for this run {(et - st)/60} mins")


    # Stack every client's per-epoch accuracies into (num_clients, num_epochs)
    # arrays so the mean can be plotted with a ~90% spread band (1.65 * std).
    X = range(args.epochs)
    all_clients_stacked_train = np.array([client.train_acc for _,client in clients.items()])
    all_clients_stacked_test = np.array([client.test_acc for _,client in clients.items()])
    epochs_train_std = np.std(all_clients_stacked_train,axis = 0, dtype = np.float64)
    epochs_test_std = np.std(all_clients_stacked_test,axis = 0, dtype = np.float64)

    #Y_train is the average client train accuracies at each epoch
    #epoch_train_std is the standard deviation of clients train accuracies at each epoch
    Y_train = overall_train_acc
    Y_train_lower = Y_train - (1.65 * epochs_train_std) #95% of the values lie between 1.65*std
    Y_train_upper = Y_train + (1.65 * epochs_train_std)

    Y_test = overall_test_acc
    Y_test_lower = Y_test - (1.65 * epochs_test_std) #95% of the values lie between 1.65*std
    Y_test_upper = Y_test + (1.65 * epochs_test_std)

    # Coefficient of variation (std / mean) per epoch, for train and test.
    Y_train_cv = epochs_train_std / Y_train
    Y_test_cv = epochs_test_std / Y_test

    plt.figure(0)
    plt.plot(X, Y_train)
    plt.fill_between(X,Y_train_lower , Y_train_upper, color='blue', alpha=0.25)
    # plt.savefig(f'./results/train_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
    plt.show()

    plt.figure(1)
    plt.plot(X, Y_test)
    plt.fill_between(X,Y_test_lower , Y_test_upper, color='blue', alpha=0.25)
    # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
    plt.show()


    plt.figure(2)
    plt.plot(X, Y_train_cv)
    plt.show()


    plt.figure(3)
    plt.plot(X, Y_test_cv)
    plt.show()
470 |
471 |
472 |
473 |
474 |
475 | #BELOW CODE TO PLOT MULTIPLE LINES ON A SINGLE PLOT ONE LINE FOR EACH CLIENT
476 | # for client_id, client in clients.items():
477 | # plt.plot(list(range(args.epochs)), client.train_acc, label=f'{client_id} (Max:{max(client.train_acc):.4f})')
478 | # plt.plot(list(range(args.epochs)), overall_train_acc, label=f'Average (Max:{max(overall_train_acc):.4f})')
479 | # plt.title(f'{args.number_of_clients} Clients: Train Accuracy vs. Epochs')
480 | # plt.ylabel('Train Accuracy')
481 | # plt.xlabel('Epochs')
482 | # plt.legend()
483 | # plt.ioff()
484 | # plt.savefig(f'./results/train_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
485 | # plt.show()
486 |
487 | # for client_id, client in clients.items():
488 | # plt.plot(list(range(args.epochs)), client.test_acc, label=f'{client_id} (Max:{max(client.test_acc):.4f})')
489 | # plt.plot(list(range(args.epochs)), overall_test_acc, label=f'Average (Max:{max(overall_test_acc):.4f})')
490 | # plt.title(f'{args.number_of_clients} Clients: Test Accuracy vs. Epochs')
491 | # plt.ylabel('Test Accuracy')
492 | # plt.xlabel('Epochs')
493 | # plt.legend()
494 | # plt.ioff()
495 | # plt.savefig(f'./results/test_acc_vs_epoch/{args.dataset}_{args.number_of_clients}clients_{args.epochs}epochs_{args.batch_size}batch_{args.opt}.png', bbox_inches='tight')
496 | # plt.show()
497 |
498 |
499 |
--------------------------------------------------------------------------------
/FL_DR.py:
--------------------------------------------------------------------------------
1 | #===========================================================
2 | # Federated learning: ResNet18
3 | # ===========================================================
4 | import torch
5 | from torch import nn
6 | from torchvision import transforms
7 | from torch.utils.data import DataLoader, Dataset
8 | from pandas import DataFrame
9 | import pandas as pd
10 | from sklearn.model_selection import train_test_split
11 | from PIL import Image
12 | from glob import glob
13 | import math
14 | import random
15 | import numpy as np
16 | import os
17 | import matplotlib
18 | matplotlib.use('Agg')
19 | import matplotlib.pyplot as plt
20 | import copy
21 | import argparse
22 | # from utils import datasets, dataset_settings
23 | from utils import get_eye_dataset
24 | import time
25 | from math import ceil
26 | from sklearn.metrics import classification_report
27 | from sklearn.preprocessing import LabelBinarizer
28 | from sklearn.metrics import roc_curve, auc, roc_auc_score
29 |
30 |
31 |
32 | ## ARGPARSER
33 |
def parse_arguments(argv=None):
    """Parse the training configuration for the federated-learning run.

    Args:
        argv: Optional list of argument strings. When None (the default,
            which preserves the behaviour of existing call sites) argparse
            reads sys.argv[1:]. Passing an explicit list makes the function
            testable and reusable from other scripts.

    Returns:
        argparse.Namespace with the run settings.
    """
    parser = argparse.ArgumentParser(
        description="Splitfed V1 configurations",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="Random seed",
    )
    parser.add_argument(
        "-c",
        "--number_of_clients",
        type=int,
        default=5,
        metavar="C",
        help="Number of Clients",
    )
    parser.add_argument(
        "-n",
        "--epochs",
        type=int,
        default=20,
        metavar="N",
        help="Total number of epochs to train",
    )
    parser.add_argument(
        "--fac",
        type=float,
        default=1.0,
        metavar="N",
        help="fraction of active/participating clients, if 1 then all clients participate in SFLV1",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        metavar="LR",
        help="Learning rate",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        help="States dataset to be used",
    )
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=1024,
        metavar="B",
        help="Batch size",
    )
    parser.add_argument(
        "--test_batch_size",
        type=int,
        default=512,
        metavar="B",
        help="Batch size",
    )
    # Experimental-setting selector (e.g. "setting1"); no help text upstream.
    parser.add_argument(
        "--setting",
        type=str,
        default="setting1",
    )
    # Number of datapoints per client for the small-data settings.
    parser.add_argument(
        "--datapoints",
        type=int,
        default=500,
    )
    # Optimizer identifier; kept as an untyped optional string flag as before.
    parser.add_argument(
        "--opt_iden",
        type=str,
        # default=False
    )

    return parser.parse_args(argv)
125 |
126 | #==============================================================================================================
127 | # Client Side Program
128 | #==============================================================================================================
class DatasetSplit(Dataset):
    """A view over `dataset` restricted to the sample positions in `idxs`."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        # Size of the view, not of the underlying dataset.
        return len(self.idxs)

    def __getitem__(self, item):
        # Translate the view index to the underlying dataset index.
        sample, target = self.dataset[self.idxs[item]]
        return sample, target
140 |
141 | # Client-side functions associated with Training and Testing
class LocalUpdate(object):
    """One simulated federated client.

    Loads this client's train/test split of the eye dataset, trains a copy
    of the global model locally, and evaluates it.

    Side effects: registers itself in the module-global ``clients`` dict;
    ``evaluate()`` mutates the module-globals ``targets``, ``outputs``,
    ``macro_avg_f1_3classes`` and ``macro_avg_f1_dict``.
    """

    def __init__(self, idx, lr, device, client_idx, dataset_train = None, dataset_test = None, idxs = None, idxs_test = None):
        # idx: client id, also the key under which this instance is
        #   registered in the module-global `clients` dict.
        # client_idx: selects which dataset partition this client reads.
        # dataset_train / dataset_test are legacy parameters, unused here
        #   (see the commented-out loaders at the bottom of __init__).
        self.idx = idx
        self.device = device
        self.lr = lr
        self.local_ep = 1  # local epochs per federation round
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []

        # Clients 0-4 and 5-9 use different train/test partitions
        # (ids 1/3 vs 2/4) — presumably two data sources; see
        # utils/get_eye_dataset for the meaning of these ids.
        if(client_idx>=0 and client_idx<=4 ):
            train_data_id=1
            test_data_id=3
        elif(client_idx>=5 and client_idx<=9):
            train_data_id=2
            test_data_id=4

        # NOTE(review): if client_idx falls outside 0..9 the two ids above
        # are never bound and the next line raises NameError.
        self.train_dataset=get_eye_dataset.get_eye_data(idxs, train_data_id)
        self.test_dataset=get_eye_dataset.get_eye_data(idxs_test, test_data_id)
        # `args` is the module-global namespace parsed at startup.
        self.ldr_train = DataLoader(self.train_dataset, batch_size = args.batch_size, shuffle = True)
        self.ldr_test = DataLoader(self.test_dataset, batch_size = args.test_batch_size, shuffle = True)
        clients[idx]=self
        # # if idx == 0:
        # self.ldr_train = DataLoader(dataset_train, batch_size = args.batch_size, shuffle = True)
        # else:
        #     self.ldr_train = DataLoader(DatasetSplit(dataset_train, idxs), batch_size = args.batch_size, shuffle = True)

        # self.ldr_test = DataLoader(dataset_test, batch_size = args.batch_size, shuffle = True)

    def train(self, net):
        """Train `net` on this client's data for self.local_ep epochs.

        Returns:
            (state_dict, mean loss across local epochs,
             mean accuracy across local epochs).
        """
        net.train()
        # train and update
        #optimizer = torch.optim.SGD(net.parameters(), lr = self.lr, momentum = 0.5)
        optimizer = torch.optim.Adam(net.parameters(), lr = self.lr)

        epoch_acc = []
        epoch_loss = []
        for iter in range(self.local_ep):
            batch_acc = []
            batch_loss = []

            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                #---------forward prop-------------
                fx = net(images)

                # calculate loss
                loss = self.loss_func(fx, labels)
                # calculate accuracy
                acc = calculate_accuracy(fx, labels)

                #--------backward prop--------------
                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())
                batch_acc.append(acc.item())

            # Per-epoch averages over the batches.
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
            epoch_acc.append(sum(batch_acc)/len(batch_acc))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss), sum(epoch_acc) / len(epoch_acc)

    def evaluate(self, net, ell):
        """Evaluate `net` on this client's test loader.

        Also appends this client's macro-F1 (averaged over class labels
        '0', '1', '2') to the module-global `macro_avg_f1_3classes`,
        records it in `macro_avg_f1_dict`, and then resets the shared
        `targets`/`outputs` accumulators. `ell` (the round index) is
        accepted but unused in the body.

        Returns:
            (mean loss, mean accuracy) over the test set.
        """
        global targets, outputs
        net.eval()

        epoch_acc = []
        epoch_loss = []
        with torch.no_grad():
            batch_acc = []
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_test):
                images, labels = images.to(self.device), labels.to(self.device)
                #---------forward prop-------------
                fx = net(images)
                # Accumulate hard predictions and ground truth for the
                # classification report computed after the loop.
                _,pred_t = torch.max(fx, dim=1)
                outputs.extend(pred_t.cpu().detach().numpy().tolist())
                targets.extend(labels.cpu().detach().numpy().tolist())

                # calculate loss
                loss = self.loss_func(fx, labels)
                # calculate accuracy
                acc = calculate_accuracy(fx, labels)

                batch_loss.append(loss.item())
                batch_acc.append(acc.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
            epoch_acc.append(sum(batch_acc)/len(batch_acc))

            # Macro-average F1 over the three class labels.
            clr=classification_report(np.array(targets), np.array(outputs), output_dict=True, zero_division=0)
            curr_f1=(clr['0']['f1-score']+clr['1']['f1-score']+clr['2']['f1-score'])/3
            macro_avg_f1_3classes.append(curr_f1)
            # NOTE(review): `idx` here resolves to the bare module-global
            # loop variable of the main round loop, not self.idx — they
            # happen to coincide at the call site; confirm this aliasing
            # is intended.
            macro_avg_f1_dict[idx]=curr_f1

            # Reset shared accumulators for the next client/round.
            targets=[]
            outputs=[]

        return sum(epoch_loss) / len(epoch_loss), sum(epoch_acc) / len(epoch_acc)
246 |
247 |
248 |
249 | #=====================================================================================================
250 | # dataset_iid() will create a dictionary to collect the indices of the data samples randomly for each client
251 |
252 |
def dataset_iid(dataset, num_users):
    """Randomly partition the indices of `dataset` into `num_users` IID shards.

    Each user receives len(dataset) // num_users distinct sample indices,
    drawn without replacement from the not-yet-assigned pool. Uses the
    global NumPy RNG state, so seed beforehand for reproducibility.

    Returns:
        dict mapping user id (0..num_users-1) -> set of sample indices.
    """
    shard_size = int(len(dataset) / num_users)
    remaining = [i for i in range(len(dataset))]
    shards = {}
    for user in range(num_users):
        picked = set(np.random.choice(remaining, shard_size, replace=False))
        shards[user] = picked
        remaining = list(set(remaining) - picked)
    return shards
261 |
262 | #====================================================================================================
263 | # Server Side Program
264 | #====================================================================================================
def calculate_accuracy(fx, y):
    """Return the percentage of rows in `fx` whose argmax equals the label in `y`."""
    predicted = torch.argmax(fx, dim=1, keepdim=True)
    hits = (predicted == y.view_as(predicted)).sum().float()
    return 100.00 * hits / predicted.shape[0]
270 |
271 | #=============================================================================
272 | # Model definition: ResNet18
273 | #=============================================================================
274 | # building a ResNet18 Architecture
def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution with padding=1 (spatial size preserved at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
279 |
280 |
def plot_class_distribution(clients, client_ids):
    """Plot per-client label histograms (a 5-column grid of bar charts) and
    one combined line plot of label frequencies.

    Reads each client's `train_dataset` (iterable of (image, label) pairs).
    At most 20 clients are plotted; with more, 20 are sampled at random.
    The bar-chart figure is saved to 'plot_fl.png'.

    Returns:
        dict mapping client_id -> pandas Series of label counts (sorted by label).
    """
    class_distribution=dict()
    number_of_clients=len(client_ids)
    if(len(clients)<=20):
        plot_for_clients=client_ids
    else:
        plot_for_clients=random.sample(client_ids, 20)

    # One subplot per client, 5 columns, enough rows to fit them all.
    # NOTE(review): with <=5 client_ids plt.subplots returns a 1-D axes
    # array and ax[i,j] below would raise — presumably always called with
    # more than 5 clients; confirm.
    fig, ax = plt.subplots(nrows=(int(ceil(len(client_ids)/5))), ncols=5, figsize=(15, 10))
    j=0
    i=0

    #plot histogram
    for client_id in plot_for_clients:
        # Materialising the whole dataset just to count labels can be slow.
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        class_distribution[client_id]=df['labels'].value_counts().sort_index()
        df['labels'].value_counts().sort_index().plot(ax = ax[i,j], kind = 'bar', ylabel = 'frequency', xlabel=client_id)
        j+=1
        # Wrap to the next row after filling 5 columns.
        # NOTE(review): j is reset to 0 as soon as it reaches 5, so the
        # j==10 and j==15 conditions can never be true (dead code).
        if(j==5 or j==10 or j==15):
            i+=1
            j=0
    fig.tight_layout()
    plt.show()  # no-op for display under the 'Agg' backend set at import time

    plt.savefig('plot_fl.png')
    # plt.savefig(f'./results/classvsfreq/settin3{dataset}.png')

    max_len=0
    #plot line graphs
    for client_id in plot_for_clients:
        df=pd.DataFrame(list(clients[client_id].train_dataset), columns=['images', 'labels'])
        df['labels'].value_counts().sort_index().plot(kind = 'line', ylabel = 'frequency', label=client_id)
        # Track the highest frequency of any client's modal label to scale the y-axis.
        max_len=max(max_len, list(df['labels'].value_counts(sort=False)[df.labels.mode()])[0])
    plt.xticks(np.arange(0,10))
    plt.ylim(0, max_len)
    plt.legend()
    plt.show()

    # plt.savefig(f'./results/class_vs_fre/q/{dataset}_{number_of_clients}clients_{epochs}epochs_{batch_size}batch_{opt}_line_graph.png')

    return class_distribution
322 |
323 |
class BasicBlock(nn.Module):
    """Standard two-convolution ResNet basic block with an identity shortcut.

    When `downsample` is given it is applied to the shortcut path so that
    shapes match when the main path changes resolution or width.
    """

    expansion = 1  # output width = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # First conv may downsample spatially (stride); second preserves size.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or the projection when shapes differ.
        identity = x if self.downsample is None else self.downsample(x)

        # Main path: conv-BN-ReLU, then conv-BN.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Residual addition followed by the final activation.
        out += identity
        return self.relu(out)
354 |
class ResNet18(nn.Module):
    """ResNet-18 style CNN assembled from residual blocks.

    Args:
        block: residual block class (e.g. BasicBlock); must define `expansion`.
        layers: blocks per stage, e.g. [2, 2, 2, 2] for ResNet-18.
        input_channels: number of channels in the input images.
        num_classes: output dimension of the final linear classifier.
    """

    def __init__(self, block, layers, input_channels, num_classes=1000):
        self.inplanes = 64  # running channel count; mutated by _make_layer
        super(ResNet18, self).__init__()
        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; stages 2-4 halve the spatial resolution.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Initialisation: zero-mean normal with std sqrt(2/fan_out) for
        # convs; BatchNorm starts as the identity (gamma=1, beta=0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of `blocks` residual blocks.

        The first block may change resolution/width; in that case a
        1x1-conv + BN projection is attached to its shortcut. Mutates
        self.inplanes to the stage's output width, so stages must be
        constructed in order.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of images x of shape (N, C, H, W)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pool -> flatten -> linear classifier.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
413 |
414 | #===========================================================================================
415 | # Federated averaging: FedAvg
def FedAvg(w):
    """Return the element-wise (equal-weight) average of a list of state_dicts.

    The inputs are left untouched: the first dict is deep-copied and the
    remaining dicts are accumulated into the copy before dividing.
    """
    total = len(w)
    averaged = copy.deepcopy(w[0])
    for key in averaged.keys():
        for state in w[1:]:
            averaged[key] += state[key]
        averaged[key] = torch.div(averaged[key], total)
    return averaged
423 | #====================================================
424 |
425 |
if __name__ == "__main__":

    #===================================================================
    # Entry point: plain FedAvg training of a ResNet-18 across 10
    # simulated clients on the eye (diabetic-retinopathy) dataset.
    program = "FL ResNet18"
    print(f"---------{program}----------")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    args = parse_arguments()

    SEED = args.seed
    num_users = 10  # fixed at 10 clients here; args.number_of_clients is not used
    epochs = args.epochs
    frac = args.fac  # fraction of clients sampled each round
    lr = args.lr
    dataset = args.dataset

    input_channels=3
    no_classes=3  # 3-class problem (evaluate() averages F1 over classes '0','1','2')

    # Seed every RNG in play for reproducibility.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    # NOTE(review): `global` at module top level is a no-op; these names
    # are module globals regardless and are mutated by LocalUpdate.evaluate().
    global outputs, targets
    targets=[]
    outputs=[]
    # Best-round bookkeeping, updated whenever the average test accuracy improves.
    max_accuracy=0
    max_train_accuracy=0
    max_c0_4_test=0
    max_c0_f1=0
    max_c5_9_test=0
    max_c5_f1=0
    global clients
    clients={}  # idx -> LocalUpdate; populated by LocalUpdate.__init__


    # Pre-computed per-client train/test sample indices for the eye dataset.
    d1, d2=get_eye_dataset.get_idxs()
    dict_users, dict_users_test=d1, d2

    # The global (server-side) model shared by every client.
    net_glob = ResNet18(BasicBlock, [2, 2, 2, 2],input_channels, no_classes)
    net_glob.to(device)


    net_glob.train()
    w_glob = net_glob.state_dict()

    # Per-round history; the F1 containers are filled by LocalUpdate.evaluate().
    loss_train_collect = []
    acc_train_collect = []
    loss_test_collect = []
    acc_test_collect = []
    macro_avg_f1_3classes=[]
    macro_avg_f1_dict={}

    st = time.time()
    max_epoch, max_f1=0,0

    for iter in range(epochs):  # one federation round per "epoch"
        # Rebinding macro_avg_f1_3classes here clears the module-global
        # list that evaluate() appends to during this round.
        w_locals, loss_locals_train, acc_locals_train, loss_locals_test, acc_locals_test, acc_locals_test1, acc_locals_test2, macro_avg_f1_3classes = [], [], [], [], [], [], [], []
        # Sample the participating clients for this round.
        m = max(int(frac * num_users), 1)
        idxs_users = np.random.choice(range(num_users), m, replace = False)

        # Training/Testing simulation
        for idx in idxs_users: # each client
            # A fresh LocalUpdate is built every round (reloads the client's data).
            local = LocalUpdate(idx, lr, device, client_idx=idx, idxs = dict_users[idx], idxs_test = dict_users_test[idx])
            # Training ------------------
            w, loss_train, acc_train = local.train(net = copy.deepcopy(net_glob).to(device))
            w_locals.append(copy.deepcopy(w))
            loss_locals_train.append(copy.deepcopy(loss_train))
            acc_locals_train.append(copy.deepcopy(acc_train))
            # Testing -------------------
            # NOTE(review): evaluation runs on a copy of the pre-aggregation
            # global model, not on the just-trained local weights `w`.
            loss_test, acc_test = local.evaluate(net = copy.deepcopy(net_glob).to(device), ell=iter)
            loss_locals_test.append(copy.deepcopy(loss_test))
            acc_locals_test.append(copy.deepcopy(acc_test))
            # Split test accuracy by client group: 0-4 and 5-9 use
            # different data partitions (see LocalUpdate.__init__).
            if(idx>=0 and idx<5):
                acc_locals_test1.append(copy.deepcopy(acc_test))
            elif(idx>=5 and idx<10):
                acc_locals_test2.append(copy.deepcopy(acc_test))

        # Federation process
        w_glob = FedAvg(w_locals)

        # update global model --- copy weight to net_glob -- distributed the model to all users
        net_glob.load_state_dict(w_glob)

        # Train/Test accuracy
        acc_avg_train = sum(acc_locals_train) / len(acc_locals_train)
        acc_train_collect.append(acc_avg_train)
        acc_avg_test = sum(acc_locals_test) / len(acc_locals_test)
        acc_test_collect.append(acc_avg_test)
        # NOTE(review): with frac < 1 a group may be empty and these two
        # divisions would raise ZeroDivisionError.
        acc_avg_test1= sum(acc_locals_test1) / len(acc_locals_test1)

        acc_avg_test2 = sum(acc_locals_test2) / len(acc_locals_test2)

        # Average macro-F1 over the clients evaluated this round.
        f1_avg_all_user=sum(macro_avg_f1_3classes)/ len(macro_avg_f1_3classes)

        # Train/Test loss
        loss_avg_train = sum(loss_locals_train) / len(loss_locals_train)
        loss_train_collect.append(loss_avg_train)
        loss_avg_test = sum(loss_locals_test) / len(loss_locals_test)
        loss_test_collect.append(loss_avg_test)


        # print('------------------- SERVER ----------------------------------------------')
        # print('Train: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(iter, acc_avg_train, loss_avg_train))
        # print('Test: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f} | F1 Score {:.3f}'.format(iter, acc_avg_test, loss_avg_test, f1_avg_all_user))
        # print('-------------------------------------------------------------------------')
        print(f'\rEpoch: {iter}', end='')


        # Track the best round by average test accuracy.
        if(acc_avg_test> max_accuracy):
            max_accuracy=acc_avg_test
            max_train_accuracy=acc_avg_train
            max_epoch=iter
            # NOTE(review): assumes clients 0 and 5 were both sampled this
            # round (guaranteed only when frac == 1); KeyError otherwise.
            max_c0_f1=macro_avg_f1_dict[0]
            max_c5_f1=macro_avg_f1_dict[5]
            max_c0_4_test=acc_avg_test1
            max_c5_9_test=acc_avg_test2

        # Reset per-round F1 records for the next round.
        macro_avg_f1_dict={}


    #===================================================================================


    et = time.time()
    print("\nTraining and Evaluation completed!")
    print(f"Time taken for this run {(et - st)/60} mins")
    print(f'Maximum Personalized Average Test Acc: {max_accuracy} ')
    print(f'Maximum Personalized Average Train Acc: {max_train_accuracy} ')
    print(f'Client0 F1 Scores: {max_c0_f1}')
    print(f'Client5 F1 Scores:{max_c5_f1}')
    print(f'Personalized Average Test Accuracy for Clients 0 to 4 ": {max_c0_4_test}')
    print(f'Personalized Average Test Accuracy for Clients 5 to 9": {max_c5_9_test}')
    #===============================================================================
    # Save output data to .excel file (we use for comparision plots)
    # NOTE(review): assumes the results/FL/ directory already exists.
    round_process = [i for i in range(1, len(acc_train_collect)+1)]
    df = DataFrame({'round': round_process,'acc_train':acc_train_collect, 'acc_test':acc_test_collect})
    file_name = f"results/FL/{program}_{args.batch_size}_{args.dataset}_{args.lr}_{args.epochs}"+".xlsx"
    df.to_excel(file_name, sheet_name= "v1_test", index = False)



    #=============================================================================
    # Program Completed
    #=============================================================================
573 |
--------------------------------------------------------------------------------