├── Figures
└── method.PNG
├── LICENSE
├── requirements.txt
├── .gitignore
├── centralized.py
├── local.py
├── SLViT.py
├── README.md
├── dataset.py
├── FeSViBS.py
├── utils.py
└── models.py
/Figures/method.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/faresmalik/FeSViBS/HEAD/Figures/method.PNG
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Faris_Malik
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | asttokens==2.0.5
2 | backcall==0.2.0
3 | certifi==2022.12.7
4 | charset-normalizer==3.1.0
5 | cmake==3.26.0
6 | comm==0.1.2
7 | contourpy==1.0.7
8 | cycler==0.11.0
9 | debugpy==1.5.1
10 | decorator==5.1.1
11 | entrypoints==0.4
12 | executing==0.8.3
13 | filelock==3.10.0
14 | fire==0.5.0
15 | fonttools==4.39.0
16 | idna==3.4
17 | imageio==2.26.0
18 | ipykernel==6.19.2
19 | ipython==8.10.0
20 | jedi==0.18.1
21 | Jinja2==3.1.2
22 | joblib==1.2.0
23 | jupyter_client==7.4.9
24 | jupyter_core==5.2.0
25 | kiwisolver==1.4.4
26 | lazy_loader==0.1
27 | lit==15.0.7
28 | MarkupSafe==2.1.2
29 | matplotlib==3.7.1
30 | matplotlib-inline==0.1.6
31 | medmnist==2.1.0
32 | mpmath==1.3.0
33 | nest-asyncio==1.5.6
34 | networkx==3.0
35 | numpy==1.24.2
36 | nvidia-cublas-cu11==11.10.3.66
37 | nvidia-cuda-cupti-cu11==11.7.101
38 | nvidia-cuda-nvrtc-cu11==11.7.99
39 | nvidia-cuda-runtime-cu11==11.7.99
40 | nvidia-cudnn-cu11==8.5.0.96
41 | nvidia-cufft-cu11==10.9.0.58
42 | nvidia-curand-cu11==10.2.10.91
43 | nvidia-cusolver-cu11==11.4.0.1
44 | nvidia-cusparse-cu11==11.7.4.91
45 | nvidia-nccl-cu11==2.14.3
46 | nvidia-nvtx-cu11==11.7.91
47 | packaging==22.0
48 | pandas==1.5.3
49 | parso==0.8.3
50 | pexpect==4.8.0
51 | pickleshare==0.7.5
52 | Pillow==9.0.1
53 | pip==23.0.1
54 | platformdirs==2.5.2
55 | prompt-toolkit==3.0.36
56 | psutil==5.9.0
57 | ptyprocess==0.7.0
58 | pure-eval==0.2.2
59 | Pygments==2.11.2
60 | pyparsing==3.0.9
61 | python-dateutil==2.8.2
62 | pytz==2022.7.1
63 | PyWavelets==1.4.1
64 | pyzmq==23.2.0
65 | requests==2.28.2
66 | scikit-image==0.20.0
67 | scikit-learn==1.2.2
68 | scipy==1.10.1
69 | setuptools==65.6.3
70 | six==1.16.0
71 | stack-data==0.2.0
72 | sympy==1.11.1
73 | termcolor==2.2.0
74 | threadpoolctl==3.1.0
75 | tifffile==2023.3.15
76 | timm==0.5.4
77 | torch==2.0.0
78 | torchvision==0.15.1
79 | tornado==6.2
80 | tqdm==4.65.0
81 | traitlets==5.7.1
82 | triton==2.0.0
83 | typing_extensions==4.5.0
84 | urllib3==1.26.15
85 | wcwidth==0.2.5
86 | wheel==0.38.4
87 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | vit_base_r50_s16_224_0.0001lr_bloodmnist_200rounds_Centralized/
132 | vit_base_r50_s16_224_0.0001lr_isic2019_200rounds_Centralized/
133 | vit_base_r50_s16_224_0.0001lr_isic2019_200rounds_Local/
134 | vit_base_r50_s16_224_0.0001lr_HAM_200rounds_Local/
135 | vit_base_r50_s16_224_0.0001lr_bloodmnist_200rounds_Local/
136 | vit_base_r50_s16_224_0.0001lr_HAM_200rounds_Centralized/
137 | vit_base_r50_s16_224_0.0001lr_HAM_6Clients_FalseDP_32Batch_SViBS/
138 | vit_base_r50_s16_224_0.0001lr_bloodmnist_6Clients_FalseDP_32Batch_SViBS/
139 | vit_base_r50_s16_224_0.0001lr_HAM_6Clients_(4.0, 1e-05)DP_32Batch_SViBS/
140 | vit_base_r50_s16_224_0.0001lr_bloodmnist_6Clients_(4.0, 1e-05)DP_32Batch_SViBS/
141 | vit_base_r50_s16_224_0.0001lr_isic2019_6Clients_FalseDP_32Batch_SViBS/
142 | vit_base_r50_s16_224_0.0001lr_isic2019_6Clients_(4.0, 1e-05)DP_32Batch_SViBS/
143 | vit_base_r50_s16_224_0.0001lr_HAM_6Clients_1to6Blocks_32Batch__(5.0, 1e-05)DP_FeSViBS/
144 | vit_base_r50_s16_224_0.0001lr_HAM_6Clients_1to6Blocks_32Batch_FeSViBS/
145 | vit_base_r50_s16_224_0.0001lr_isic2019_6Clients_1to6Blocks_32Batch__(5.0, 1e-05)DP_FeSViBS/
146 | vit_base_r50_s16_224_0.0001lr_isic2019_6Clients_1to6Blocks_32Batch_FeSViBS/
147 |
148 |
--------------------------------------------------------------------------------
/centralized.py:
--------------------------------------------------------------------------------
1 | import timm
2 | import torch
3 | import dataset
4 | import os
5 | import random
6 | import numpy as np
7 | from curses.ascii import FF
8 | from models import CentralizedFashion
9 | from torch import nn
10 | import argparse
11 |
12 | from dataset import skinCancer, bloodmnisit, isic2019
13 |
14 |
def centralized(dataset_name, lr, batch_size, Epochs, input_size, num_workers, save_every_epochs, model_name, pretrained, opt_name, seed, base_dir, root_dir, csv_file_path):
    """Train and evaluate a single timm model on the whole (non-federated) dataset.

    Args:
        dataset_name: One of 'HAM', 'bloodmnist' or 'isic2019'.
        lr: Learning rate passed to ``CentralizedFashion.set_optimizer``.
        batch_size: Batch size used by the dataset loaders.
        Epochs: Number of rounds; each round runs one train and one eval pass.
        input_size: Side length of the square model input.
        num_workers: Number of dataloader workers.
        save_every_epochs: Metrics are pickled every this many epochs and after
            the final epoch. Must be >= 1.
        model_name: timm model identifier.
        pretrained: Whether timm should load pretrained weights.
        opt_name: Optimizer name passed to ``set_optimizer``.
        seed: Seed for torch, ``random`` and NumPy global RNGs.
        base_dir: HAM data folder (used only for 'HAM').
        root_dir: ISIC2019 preprocessed image folder (used only for 'isic2019').
        csv_file_path: ISIC2019 train/test split CSV (used only for 'isic2019').

    Raises:
        ValueError: If ``dataset_name`` is unknown or ``save_every_epochs < 1``.
            Raised before any directory is created.
        FileExistsError: If the logging directory already exists (``os.mkdir``).
    """
    # Validate before any side effect. Previously an unknown name created the
    # logging directory and then failed with UnboundLocalError, and
    # save_every_epochs == 0 failed with ZeroDivisionError after a full epoch.
    if dataset_name not in ('HAM', 'bloodmnist', 'isic2019'):
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected 'HAM', 'bloodmnist' or 'isic2019'"
        )
    if save_every_epochs < 1:
        raise ValueError(f'save_every_epochs must be >= 1, got {save_every_epochs}')

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('Creating Loggings Directory!')
    save_dir = f'{model_name}_{lr}lr_{dataset_name}_{Epochs}rounds_Centralized'
    os.mkdir(save_dir)

    print('Getting the Dataset and Dataloader!')
    # Every supported dataset is loaded with 3 input channels.
    num_channels = 3
    if dataset_name == 'HAM':
        num_classes = 7
        train_loader, test_loader, _, _ = skinCancer(input_size=input_size, batch_size=batch_size, base_dir=base_dir, num_workers=num_workers)
    elif dataset_name == 'bloodmnist':
        num_classes = 8
        train_loader, test_loader, _, _ = bloodmnisit(input_size=input_size, batch_size=batch_size, download=True, num_workers=num_workers)
    else:  # 'isic2019'
        num_classes = 8
        # isic2019() returns six values; only the 3rd and 6th loaders are used here.
        _, _, train_loader, _, _, test_loader = isic2019(input_size=input_size, batch_size=batch_size, root_dir=root_dir, csv_file_path=csv_file_path, num_workers=num_workers)

    print('Getting the model from timm library!')
    model = timm.create_model(
        model_name=model_name, pretrained=pretrained,
        num_classes=num_classes, in_chans=num_channels,
    ).to(device)

    criterion = torch.nn.CrossEntropyLoss()

    centralized_network = CentralizedFashion(
        device=device, network=model, criterion=criterion,
        base_dir=save_dir,
    )

    # Instantiate metrics and set optimizer
    centralized_network.init_logs()
    centralized_network.set_optimizer(name=opt_name, lr=lr)

    print(f'Train Centralized Fashion:\n model: {model_name}\n dataset: {dataset_name}\n LR: {lr}\n Number of Epochs: {Epochs}\n Loggings: {save_dir}\n')
    print('Start Training! \n')

    # Training and Evaluation Loop
    for r in range(Epochs):
        print(f"Round {r+1} / {Epochs}")
        centralized_network.train_round(train_loader)
        centralized_network.eval_round(test_loader)
        print('---------')
        # The old extra `r != 0` guard skipped epoch 1 when save_every_epochs == 1.
        # Also save after the last epoch so trailing epochs are not lost when
        # Epochs is not a multiple of save_every_epochs.
        if (r + 1) % save_every_epochs == 0 or r + 1 == Epochs:
            centralized_network.save_pickles(save_dir)
        print('============================================')
78 |
79 |
def str2bool(value):
    """Convert a command-line string such as 'True'/'false'/'1'/'no' to a bool.

    argparse's ``type=bool`` calls ``bool(text)``. That is True for every
    non-empty string, so ``--pretrained False`` used to *enable* the flag.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Run Centralized Experiments')

    parser.add_argument('--dataset_name', type=str, required=True, choices=['HAM', 'bloodmnist', 'isic2019'], help='Dataset Name')
    parser.add_argument('--input_size', type=int, default=224, help='Input size --> (input_size, input_size), default : 224')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for dataloaders, default : 8')
    parser.add_argument('--model_name', type=str, default='vit_base_r50_s16_224', help='Model name from timm library, default: vit_base_r50_s16_224')
    # str2bool instead of type=bool, because bool('False') is True.
    parser.add_argument('--pretrained', type=str2bool, default=False, help='Pretrained weights flag (True/False), default: False')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size, default : 32')
    parser.add_argument('--Epochs', type=int, default=200, help='Number of Epochs, default : 200')
    parser.add_argument('--opt_name', type=str, choices=['Adam'], default='Adam', help='Optimizer name, only ADAM optimizer is available')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate, default : 1e-4')
    parser.add_argument('--save_every_epochs', type=int, default=10, help='Save metrics every this number of epochs, default: 10')
    parser.add_argument('--seed', type=int, default=105, help='Seed, default: 105')
    parser.add_argument('--base_dir', type=str, default=None, help='Path to the HAM10000 data folder (HAM only)')
    parser.add_argument('--root_dir', type=str, default=None, help='Path to ISIC_2019_Training_Input_preprocessed (isic2019 only)')
    parser.add_argument('--csv_file_path', type=str, default=None, help='Path to the ISIC2019 train_test_split CSV file (isic2019 only)')

    args = parser.parse_args()

    centralized(
        dataset_name=args.dataset_name, input_size=args.input_size,
        num_workers=args.num_workers, model_name=args.model_name,
        pretrained=args.pretrained, batch_size=args.batch_size,
        Epochs=args.Epochs, opt_name=args.opt_name, lr=args.lr,
        save_every_epochs=args.save_every_epochs, seed=args.seed,
        base_dir=args.base_dir, root_dir=args.root_dir, csv_file_path=args.csv_file_path,
    )
--------------------------------------------------------------------------------
/local.py:
--------------------------------------------------------------------------------
1 | import os
2 | import timm
3 | import torch
4 | import numpy as np
5 | from torch import nn
6 | import os
7 | import random
8 | import argparse
9 |
10 | from models import CentralizedFashion
11 | from dataset import skinCancer, bloodmnisit, isic2019, distribute_images
12 |
13 |
def local(dataset_name, lr, batch_size, Epochs, input_size, num_workers, save_every_epochs, model_name, pretrained, opt_name, seed, base_dir, root_dir, csv_file_path, num_clients, local_arg):
    """Train one independent timm model per client, each on its own data only.

    For 'HAM' and 'bloodmnist' the training set is split among clients with
    ``distribute_images``. For 'isic2019' the per-client loaders come directly
    from ``isic2019()``. Every client is evaluated on the same test loader.

    Args:
        dataset_name: One of 'HAM', 'bloodmnist' or 'isic2019'.
        lr: Learning rate passed to each client's ``set_optimizer``.
        batch_size: Batch size for the dataloaders.
        Epochs: Number of rounds; each client trains and evaluates once per round.
        input_size: Side length of the square model input.
        num_workers: Number of dataloader workers.
        save_every_epochs: Metrics are pickled every this many rounds and after
            the final round. Must be >= 1.
        model_name: timm model identifier.
        pretrained: Whether timm should load pretrained weights.
        opt_name: Optimizer name passed to ``set_optimizer``.
        seed: Seed for NumPy, torch and ``random`` global RNGs.
        base_dir: HAM data folder (used only for 'HAM').
        root_dir: ISIC2019 preprocessed image folder (used only for 'isic2019').
        csv_file_path: ISIC2019 train/test split CSV (used only for 'isic2019').
        num_clients: Number of clients (>= 1).
        local_arg: Forwarded as ``local=`` to ``save_pickles``.

    Raises:
        ValueError: On an unknown dataset, ``num_clients < 1`` or
            ``save_every_epochs < 1``. Raised before any data is loaded or any
            directory is created.
        FileExistsError: If the logging directory already exists (``os.mkdir``).
    """
    # Validate before any side effect (data download, mkdir, model creation).
    if dataset_name not in ('HAM', 'bloodmnist', 'isic2019'):
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected 'HAM', 'bloodmnist' or 'isic2019'"
        )
    if num_clients < 1:
        raise ValueError(f'num_clients must be >= 1, got {num_clients}')
    if save_every_epochs < 1:
        raise ValueError(f'save_every_epochs must be >= 1, got {save_every_epochs}')

    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('Load Dataset and DataLoader!')
    # Every supported dataset is loaded with 3 input channels.
    num_channels = 3
    if dataset_name == 'HAM':
        num_classes = 7
        _, _, train_data, test_data = skinCancer(input_size=input_size, batch_size=batch_size, base_dir=base_dir, num_workers=num_workers)
    elif dataset_name == 'bloodmnist':
        num_classes = 8
        _, _, train_data, test_data = bloodmnisit(input_size=input_size, batch_size=batch_size, download=True, num_workers=num_workers)
    else:  # 'isic2019' already provides per-client training loaders.
        num_classes = 8
        DATALOADERS, _, _, _, _, test_loader = isic2019(input_size=input_size, batch_size=batch_size, root_dir=root_dir, csv_file_path=csv_file_path, num_workers=num_workers)

    print('Create Directory for metrics loggings!')
    save_dir = f'{model_name}_{lr}lr_{dataset_name}_{Epochs}rounds_Local'
    os.mkdir(save_dir)

    print(f'Train Local Fashion:\n Number of Clients :{num_clients}\n model: {model_name}\n dataset: {dataset_name}\n LR: {lr}\n Number of Epochs: {Epochs}\n Loggings: {save_dir}\n')

    if dataset_name in ['HAM', 'bloodmnist']:
        print(f'Distribute Dataset Among {num_clients} Clients')
        DATALOADERS, test_loader = distribute_images(
            dataset_name=dataset_name, train_data=train_data, num_clients=num_clients,
            test_data=test_data, batch_size=batch_size, num_workers=num_workers,
        )

    print('Loading Model form timm Library for All clients!')
    models = [
        timm.create_model(
            model_name=model_name,
            num_classes=num_classes,
            in_chans=num_channels,
            pretrained=pretrained,
        ).to(device)
        for _ in range(num_clients)
    ]

    criterion = nn.CrossEntropyLoss()

    # Named `clients` rather than `local` so this function's own name is not shadowed.
    clients = [
        CentralizedFashion(device=device, network=net, criterion=criterion, base_dir=save_dir)
        for net in models
    ]

    for client in clients:
        client.set_optimizer(opt_name, lr=lr)
        client.init_logs()

    for r in range(Epochs):
        print(f"Round {r+1} / {Epochs}")
        # The old extra `r != 0` guard skipped round 1 when save_every_epochs == 1.
        # Also save after the last round so trailing rounds are not lost.
        save_now = (r + 1) % save_every_epochs == 0 or r + 1 == Epochs
        for client_i, client in enumerate(clients):
            print(f'Client {client_i+1} / {num_clients}')
            client.train_round(DATALOADERS[client_i])
            client.eval_round(test_loader)
            print('---------')
            if save_now:
                client.save_pickles(save_dir, local=local_arg, client_id=client_i + 1)
        print('============================================')
88 |
89 |
def str2bool(value):
    """Convert a command-line string such as 'True'/'false'/'1'/'no' to a bool.

    argparse's ``type=bool`` calls ``bool(text)``. That is True for every
    non-empty string, so ``--local_arg False`` used to evaluate to True.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Run Local Experiments')

    parser.add_argument('--dataset_name', type=str, required=True, choices=['HAM', 'bloodmnist', 'isic2019'], help='Dataset Name')
    parser.add_argument('--num_clients', type=int, default=6, help='Number of clients, default : 6')
    # str2bool instead of type=bool, because bool('False') is True.
    parser.add_argument('--local_arg', type=str2bool, default=True, help='Local Argument (True/False), default: True')
    parser.add_argument('--input_size', type=int, default=224, help='Input size --> (input_size, input_size), default : 224')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for dataloaders, default : 8')
    parser.add_argument('--model_name', type=str, default='vit_base_r50_s16_224', help='Model name from timm library, default: vit_base_r50_s16_224')
    parser.add_argument('--pretrained', type=str2bool, default=False, help='Pretrained weights flag (True/False), default: False')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size, default : 32')
    parser.add_argument('--Epochs', type=int, default=200, help='Number of Epochs, default : 200')
    parser.add_argument('--opt_name', type=str, choices=['Adam'], default='Adam', help='Optimizer name, only ADAM optimizer is available')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate, default : 1e-4')
    parser.add_argument('--save_every_epochs', type=int, default=10, help='Save metrics every this number of epochs, default: 10')
    parser.add_argument('--seed', type=int, default=105, help='Seed, default: 105')
    parser.add_argument('--base_dir', type=str, default=None, help='Path to the HAM10000 data folder (HAM only)')
    parser.add_argument('--root_dir', type=str, default=None, help='Path to ISIC_2019_Training_Input_preprocessed (isic2019 only)')
    parser.add_argument('--csv_file_path', type=str, default=None, help='Path to the ISIC2019 train_test_split CSV file (isic2019 only)')

    args = parser.parse_args()

    local(
        dataset_name=args.dataset_name, num_clients=args.num_clients,
        input_size=args.input_size, local_arg=args.local_arg,
        num_workers=args.num_workers, model_name=args.model_name,
        pretrained=args.pretrained, batch_size=args.batch_size,
        Epochs=args.Epochs, opt_name=args.opt_name, lr=args.lr,
        save_every_epochs=args.save_every_epochs, seed=args.seed,
        base_dir=args.base_dir, root_dir=args.root_dir, csv_file_path=args.csv_file_path,
    )
--------------------------------------------------------------------------------
/SLViT.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch import nn
5 | import random
6 | from models import SLViT, SplitNetwork
7 | from dataset import skinCancer, bloodmnisit, isic2019
8 | import argparse
9 | from utils import weight_dec_global
10 |
def slvit(dataset_name, lr, batch_size, Epochs, input_size, num_workers, save_every_epochs, model_name, pretrained, opt_name, seed, base_dir, root_dir, csv_file_path, num_clients, DP, epsilon, delta):
    """Run split learning of a ViT (SLViT) across ``num_clients`` clients.

    Each round, every client is trained via ``SplitNetwork.train_round``. The
    ``'blocks'``, ``'cls'`` and ``'pos_embed'`` entries returned for each client
    are averaged over clients and loaded back into ``Split.network.vit``. Then
    every client is evaluated with ``eval_round``.

    Args:
        dataset_name: One of 'HAM', 'bloodmnist' or 'isic2019'.
        lr: Learning rate passed to ``Split.set_optimizer``.
        batch_size: Batch size for the dataloaders.
        Epochs: Number of communication rounds.
        input_size: Side length of the square model input.
        num_workers: Number of dataloader workers.
        save_every_epochs: Metrics are pickled every this many rounds and after
            the final round. Must be >= 1.
        model_name: timm ViT identifier forwarded to ``SLViT``.
        pretrained: Whether the ViT should use pretrained weights.
        opt_name: Optimizer name passed to ``set_optimizer``.
        seed: Seed for NumPy, torch and ``random`` global RNGs.
        base_dir: HAM data folder (used only for 'HAM').
        root_dir: ISIC2019 preprocessed image folder (used only for 'isic2019').
        csv_file_path: ISIC2019 train/test split CSV (used only for 'isic2019').
        num_clients: Number of clients (>= 1).
        DP: If truthy, ``SLViT`` receives ``diff_privacy=True``, ``mean=0`` and
            ``std = sqrt(2 ln(1.25 / delta)) / epsilon``. That std is the classical
            Gaussian-mechanism calibration; how the noise is applied is up to ``SLViT``.
        epsilon: Privacy budget; must be > 0 when ``DP`` is set.
        delta: Privacy slack; must be in (0, 1) when ``DP`` is set.

    Raises:
        ValueError: On an unknown dataset, ``num_clients < 1``,
            ``save_every_epochs < 1``, or invalid (epsilon, delta) with DP.
            Raised before any directory is created.
        FileExistsError: If the logging directory already exists (``os.mkdir``).
    """
    # Function-scope import: `np.math` is an alias of the stdlib `math` module
    # that NumPy deprecated and later removed.
    import math

    # Validate before any side effect. With DP, the CLI default epsilon=0 used
    # to raise ZeroDivisionError.
    if dataset_name not in ('HAM', 'bloodmnist', 'isic2019'):
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected 'HAM', 'bloodmnist' or 'isic2019'"
        )
    if num_clients < 1:
        raise ValueError(f'num_clients must be >= 1, got {num_clients}')
    if save_every_epochs < 1:
        raise ValueError(f'save_every_epochs must be >= 1, got {save_every_epochs}')
    if DP and not (epsilon > 0 and 0 < delta < 1):
        raise ValueError(
            f'With DP enabled, epsilon must be > 0 and delta in (0, 1); got epsilon={epsilon}, delta={delta}'
        )

    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    mean = 0
    std = 1
    if DP:
        std = math.sqrt(2 * math.log(1.25 / delta)) / epsilon

    if DP:
        save_dir = f'{model_name}_{lr}lr_{dataset_name}_{num_clients}Clients_({epsilon}, {delta})DP_{batch_size}Batch_SLViT'
    else:
        save_dir = f'{model_name}_{lr}lr_{dataset_name}_{num_clients}Clients_{DP}DP_{batch_size}Batch_SLViT'
    os.mkdir(save_dir)

    print('Getting the Dataset and Dataloader!')
    # Every supported dataset is loaded with 3 input channels.
    num_channels = 3
    if dataset_name == 'HAM':
        num_classes = 7
        _, _, traindataset, testdataset = skinCancer(input_size=input_size, batch_size=batch_size, base_dir=base_dir, num_workers=num_workers)
    elif dataset_name == 'bloodmnist':
        num_classes = 8
        _, _, traindataset, testdataset = bloodmnisit(input_size=input_size, batch_size=batch_size, download=True, num_workers=num_workers)
    else:  # 'isic2019' already provides per-client training loaders.
        num_classes = 8
        DATALOADERS, _, _, _, _, test_loader = isic2019(input_size=input_size, batch_size=batch_size, root_dir=root_dir, csv_file_path=csv_file_path, num_workers=num_workers)

    network = SLViT(
        ViT_name=model_name, num_classes=num_classes,
        num_clients=num_clients, in_channels=num_channels,
        ViT_pretrained=pretrained,
        diff_privacy=DP, mean=mean, std=std,
    ).to(device)

    criterion = nn.CrossEntropyLoss()

    Split = SplitNetwork(
        num_clients=num_clients, device=device,
        network=network, criterion=criterion, base_dir=save_dir,
    )

    print('Distribute Data')
    if dataset_name != 'isic2019':
        Split.distribute_images(dataset_name=dataset_name, train_data=traindataset, test_data=testdataset, batch_size=batch_size)
    else:
        Split.CLIENTS_DATALOADERS = DATALOADERS
        Split.testloader = test_loader

    Split.set_optimizer(opt_name, lr=lr)
    Split.init_logs()

    shared_keys = ('blocks', 'cls', 'pos_embed')
    for r in range(Epochs):
        print(f"Round {r+1} / {Epochs}")
        # Running sum of the shared weights. It is built out-of-place so the
        # tensors returned by train_round are never modified in place.
        agg_weights = None
        for client_i in range(num_clients):
            weight_dict = Split.train_round(client_i)
            if agg_weights is None:
                agg_weights = {key: weight_dict[key] for key in shared_keys}
            else:
                agg_weights = {key: agg_weights[key] + weight_dict[key] for key in shared_keys}
        agg_weights = {key: value / num_clients for key, value in agg_weights.items()}

        Split.network.vit.blocks = weight_dec_global(
            Split.network.vit.blocks,
            agg_weights['blocks'].to(device),
        )

        # `+ 0.0` yields a new tensor rather than reusing the aggregate's storage.
        Split.network.vit.cls_token.data = agg_weights['cls'].to(device) + 0.0
        Split.network.vit.pos_embed.data = agg_weights['pos_embed'].to(device) + 0.0

        for client_i in range(num_clients):
            Split.eval_round(client_i)

        print('---------')

        # The old extra `r != 0` guard skipped round 1 when save_every_epochs == 1.
        # Also save after the last round so trailing rounds are not lost.
        if (r + 1) % save_every_epochs == 0 or r + 1 == Epochs:
            Split.save_pickles(save_dir)

        print('============================================')
107 |
def str2bool(value):
    """Convert a command-line string such as 'True'/'false'/'1'/'no' to a bool.

    argparse's ``type=bool`` calls ``bool(text)``. That is True for every
    non-empty string, so ``--DP False`` used to *enable* differential privacy.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Run SLViT Experiments')

    parser.add_argument('--dataset_name', type=str, required=True, choices=['HAM', 'bloodmnist', 'isic2019'], help='Dataset Name')
    parser.add_argument('--input_size', type=int, default=224, help='Input size --> (input_size, input_size), default : 224')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for dataloaders, default : 8')
    parser.add_argument('--num_clients', type=int, default=6, help='Number of Clients, default : 6')
    parser.add_argument('--model_name', type=str, default='vit_base_r50_s16_224', help='Model name from timm library, default: vit_base_r50_s16_224')
    # str2bool instead of type=bool, because bool('False') is True.
    parser.add_argument('--pretrained', type=str2bool, default=False, help='Pretrained weights flag (True/False), default: False')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size, default : 32')
    parser.add_argument('--Epochs', type=int, default=200, help='Number of Epochs, default : 200')
    parser.add_argument('--opt_name', type=str, choices=['Adam'], default='Adam', help='Optimizer name, only ADAM optimizer is available')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate, default : 1e-4')
    parser.add_argument('--save_every_epochs', type=int, default=10, help='Save metrics every this number of epochs, default: 10')
    parser.add_argument('--seed', type=int, default=105, help='Seed, default: 105')
    parser.add_argument('--base_dir', type=str, default=None, help='Path to the HAM10000 data folder (HAM only)')
    parser.add_argument('--root_dir', type=str, default=None, help='Path to ISIC_2019_Training_Input_preprocessed (isic2019 only)')
    parser.add_argument('--csv_file_path', type=str, default=None, help='Path to the ISIC2019 train_test_split CSV file (isic2019 only)')
    parser.add_argument('--DP', type=str2bool, default=False, help='Differential Privacy (True/False), default: False')
    parser.add_argument('--epsilon', type=float, default=0, help='Epsilon Value for differential privacy (must be > 0 when --DP True)')
    parser.add_argument('--delta', type=float, default=0.00001, help='Delta Value for differential privacy')

    args = parser.parse_args()

    # Fail fast with a CLI error instead of a ZeroDivisionError inside slvit().
    if args.DP and args.epsilon <= 0:
        parser.error('--epsilon must be > 0 when --DP is enabled')
    if args.DP and not 0 < args.delta < 1:
        parser.error('--delta must be in (0, 1) when --DP is enabled')

    slvit(
        dataset_name=args.dataset_name, input_size=args.input_size,
        num_workers=args.num_workers, model_name=args.model_name,
        pretrained=args.pretrained, batch_size=args.batch_size,
        Epochs=args.Epochs, opt_name=args.opt_name, lr=args.lr,
        save_every_epochs=args.save_every_epochs, seed=args.seed,
        base_dir=args.base_dir, root_dir=args.root_dir, csv_file_path=args.csv_file_path, num_clients=args.num_clients,
        DP=args.DP, epsilon=args.epsilon, delta=args.delta,
    )
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FeSViBS
2 | Source code for MICCAI 2023 paper entitled: 'FeSViBS: Federated Split Learning of Vision Transformer with Block Sampling'
3 |
4 |
5 |
6 |
7 | 
8 |
9 | ## Abstract
10 | Data scarcity is a significant obstacle hindering the learning of powerful machine learning models in critical healthcare applications. Data-sharing mechanisms among multiple entities (e.g., hospitals) can accelerate model training and yield more accurate predictions. Recently, approaches such as Federated Learning (FL) and Split Learning (SL) have facilitated collaboration without the need to exchange private data. In this work, we propose a framework for medical imaging classification tasks called Federated Split learning of Vision transformer with Block Sampling (FeSViBS). The FeSViBS framework builds upon the existing federated split vision transformer and introduces a *block sampling* module, which leverages intermediate features extracted by the Vision Transformer (ViT) at the server. This is achieved by sampling features (patch tokens) from an intermediate transformer block and distilling their information content into a pseudo class token before passing them back to the client. These pseudo class tokens serve as an effective feature augmentation strategy and enhance the generalizability of the learned model. We demonstrate the utility of our proposed method compared to other SL and FL approaches on three publicly available medical imaging datasets: HAM10000, BloodMNIST, and Fed-ISIC2019, under both IID and non-IID settings.
11 |
12 | ## Install Dependencies
13 | Install all dependencies by running the following command:
14 |
15 | ```
16 | pip install -r requirements.txt
17 |
18 | ```
19 |
20 | ## Datasets
21 |
22 | We conduct all experiments on **three** datasets:
23 |
24 | 1. HAM10000 [3] -- Can be downloaded from [here](https://www.kaggle.com/datasets/kmader/skin-cancer-mnist-ham10000?select=HAM10000_images_part_2)
25 | 2. Blood cells (BloodMNIST) -- MedMnist library [1]
26 | 3. Federated version of ISIC2019 dataset -- FLamby library [2]
27 |
28 | For the Federated ISIC2019 dataset, the path to the __ISIC_2019_Training_Input_preprocessed__ directory and the path to the __train_test_split__ CSV file are required to run the different methods on this dataset.
29 |
30 | ## Running Centralized Training/Testing
31 | In order to run **Centralized Training** run the following command:
32 |
33 | ```
34 | python centralized.py --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --Epochs [Number of Epochs]
35 |
36 | ```
37 |
38 |
39 | ## Running Local Training/Testing for Each Client
40 | In order to run **Local Training/Testing** run the following command:
41 |
42 | ```
43 | python local.py --local_arg True --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs]
44 |
45 | ```
46 |
47 | ## Running Vanilla Split Learning with Vision Transformers (SLViT)
48 | In order to run **SLViT without** Differential Privacy (DP) run the following command:
49 |
50 | ```
51 | python SLViT.py --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs]
52 |
53 | ```
54 |
55 | In order to run **SLViT with** Differential Privacy (DP), run the following command:
56 |
57 | ```
58 | python SLViT.py --DP True --epsilon [epsilon value] --delta [delta value] --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs]
59 |
60 | ```
61 |
62 | ## Running Split Vision Transformer with Block Sampling (SViBS):
63 | In order to run **SViBS** run the following command:
64 |
65 | ```
66 | python FeSViBS.py --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs] --initial_block 1 --final_block 6
67 |
68 | ```
69 |
70 | ## Running Federated Split Vision Transformer with Block Sampling (FeSViBS):
71 | In order to run **FeSViBS without** Differential Privacy (DP) run the following command:
72 |
73 | ```
74 | python FeSViBS.py --fesvibs_arg True --local_round [number of local rounds before federation] --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs] --initial_block 1 --final_block 6
75 |
76 | ```
77 |
78 | In order to run **FeSViBS with** Differential Privacy (DP) run the following command:
79 |
80 | ```
81 | python FeSViBS.py --fesvibs_arg True --DP True --epsilon [epsilon value] --delta [delta value] --local_round [number of local rounds before federation] --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs] --initial_block 1 --final_block 6
82 |
83 | ```
84 | ## Citation
85 | ```
86 | @misc{almalik2023fesvibs,
87 | title={FeSViBS: Federated Split Learning of Vision Transformer with Block Sampling},
88 | author={Faris Almalik and Naif Alkhunaizi and Ibrahim Almakky and Karthik Nandakumar},
89 | year={2023},
90 | eprint={2306.14638},
91 | archivePrefix={arXiv},
92 | primaryClass={cs.CV}
93 | }
94 |
95 | ```
96 | ## References
97 |
98 | [1] Yang, J., Shi, R., Ni, B.: MedMNIST classification decathlon: A lightweight AutoML benchmark for medical image analysis. In: IEEE 18th International Symposium on Biomedical Imaging (ISBI). pp. 191–195 (2021)
99 |
100 | [2] du Terrail, J.O., Ayed, S.S., Cyffers, E., Grimberg, F., He, C., Loeb, R., Mangold, P., Marchand, T., Marfoq, O., Mushtaq, E., Muzellec, B., Philippenko, C., Silva, S., Teleńczuk, M., Albarqouni, S., Avestimehr, S., Bellet, A., Dieuleveut, A., Jaggi, M., Karimireddy, S.P., Lorenzi, M., Neglia, G., Tommasi, M., Andreux, M.: FLamby: Datasets and benchmarks for cross-silo federated learning in realistic healthcare settings. In: Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (2022)
101 |
102 | [3] Tschandl, P., Rosendahl, C., Kittler, H.: The HAM10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. Scientific Data 5(11), 180161 (Aug 2018).
103 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import glob
4 |
5 | import numpy as np
6 |
7 | from PIL import Image
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | import torch.utils.data as data
11 | import torch
12 |
13 | import medmnist
14 | from medmnist import INFO
15 |
16 | from utils import get_data, CustomDataset, ISIC2019, blood_noniid, distribute_data
17 |
18 | import random
19 |
# Seed NumPy, PyTorch and Python's `random` at import time, in that order.
# Presumably this makes the data splits below (torch `random_split`,
# `np.random.choice`) reproducible; callers may reseed afterwards.
seed = 105
for _seed_rng in (np.random.seed, torch.manual_seed, random.seed):
    _seed_rng(seed)
del _seed_rng
24 |
25 |
def distribute_images(dataset_name,train_data, num_clients, test_data, batch_size, num_workers = 8):
    """
    Split a dataset among clients and build the shared test loader.

    dataset_name (str): 'HAM' or 'bloodmnist'.
    train_data: training dataset to split (used for 'HAM' only; 'bloodmnist'
        reloads its own training split through bloodmnisit()).
    num_clients (int): number of clients.
    test_data: test dataset (used for 'HAM' only).
    batch_size (int): batch size for every loader.
    num_workers (int): worker processes for the loaders created here.

    return: (list of per-client train DataLoaders, test DataLoader)
    raises ValueError: if dataset_name is not supported (previously this
        surfaced as an UnboundLocalError at the return statement).
    """
    if dataset_name == 'HAM':
        CLIENTS_DATALOADERS = distribute_data(num_clients, train_data, batch_size)
        testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)

    elif dataset_name == 'bloodmnist':
        _, testloader, train_dataset, _ = bloodmnisit(batch_size=batch_size, num_workers=num_workers)
        _, CLIENTS_DATALOADERS, _ = blood_noniid(num_clients, train_dataset, batch_size=batch_size)

    else:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected 'HAM' or 'bloodmnist'"
        )

    return CLIENTS_DATALOADERS, testloader
43 |
def bloodmnisit(input_size =224, batch_size = 32, num_workers= 8, download = True):
    """
    Get train/test loaders and datasets for BloodMNIST from the medmnist library.

    Input:
        input_size (int): target size of the model input; used by both the
            training crop and the test resize
        batch_size (int): train batch size; the test loader uses 2 * batch_size
        num_workers (int): number of workers for both loaders
        download (bool): forwarded to the medmnist dataset class as `download`

    return:
        train_loader, test_loader, train_dataset, test_dataset
    """

    data_flag = 'bloodmnist'
    info = INFO[data_flag]
    DataClass = getattr(medmnist, info['python_class'])

    data_transform_train = transforms.Compose([
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees= 10, translate=(0.1,0.1)),
        transforms.RandomResizedCrop(input_size, (0.75,1), (0.9,1)),
        transforms.ToTensor(),
    ])

    # Resize to input_size (this was hard-coded to 224, so test images no
    # longer matched the training crops whenever input_size != 224).
    # An int argument resizes the shorter side.
    data_transform_test = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
    ])

    train_dataset = DataClass(split='train', transform=data_transform_train, download=download)
    test_dataset = DataClass(split='test', transform=data_transform_test, download=download)

    train_loader = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader, train_dataset, test_dataset
82 |
def skinCancer(input_size = 224, batch_size = 32, base_dir = './data', num_workers = 8):
    """
    Get the SkinCancer (HAM10000) datasets and dataloaders.

    Input:
        input_size (int): side length of the square model input
        batch_size (int)
        base_dir (str): directory containing the *.jpg images and HAM10000_metadata.csv
        num_workers (int): for dataloaders

    return:
        train_loader, testing_loader, train_dataset, test_dataset
    """
    all_image_path = glob.glob(os.path.join(base_dir, '*.jpg'))
    # image id (file name without extension) -> full image path
    imageid_path_dict = {os.path.splitext(os.path.basename(x))[0]: x for x in all_image_path}
    df_train, df_val = get_data(base_dir, imageid_path_dict)

    normMean = [0.76303697, 0.54564005, 0.57004493]
    normStd = [0.14092775, 0.15261292, 0.16997]

    train_transform = transforms.Compose([transforms.RandomResizedCrop((input_size,input_size), scale=(0.9,1.1)),
                                          transforms.ColorJitter(brightness=0.1, contrast=0.1, hue=0.1),
                                          transforms.RandomRotation(10),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize(normMean, normStd)])

    # define the transformation of the val images.
    val_transform = transforms.Compose([transforms.Resize((input_size,input_size)),
                                        transforms.ToTensor(),
                                        transforms.Normalize(normMean, normStd)])

    # NOTE(review): get_data oversamples minority classes by duplicating rows;
    # drop_duplicates('image_id') removes those copies again. Confirm this is intended.
    # reset_index(drop=True) guarantees a 0..n-1 index after deduplication,
    # so row lookups by dataset index always hit the intended row.
    training_set = CustomDataset(df_train.drop_duplicates('image_id').reset_index(drop=True), transform=train_transform)
    train_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # Same for the validation set:
    validation_set = CustomDataset(df_val.drop_duplicates('image_id').reset_index(drop=True), transform=val_transform)
    val_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, training_set, validation_set
124 |
def isic2019(input_size = 224, root_dir = './ISIC_2019_Training_Input_preprocessed', csv_file_path = './train_test_split', batch_size = 32, num_workers=8):
    """
    Build ISIC2019 datasets and dataloaders for centralized and federated training.

    Six client partitions (ids 0-5) are created, following the FLamby
    train/test split CSV.

    Input:
        input_size (int): image side length passed to every ISIC2019 dataset
        root_dir (str): directory with the images preprocessed by the FLamby library
        csv_file_path (str): path to the FLamby train_test_split csv file
        batch_size (int): batch size of every loader
        num_workers (int): worker processes of every loader

    Return (in this order):
        per-client train dataloaders, per-client test dataloaders,
        centralized train dataloader, per-client train datasets,
        per-client test datasets, centralized test dataloader
    """
    def make_dataset(client_id, train, centralized):
        # All datasets share the same CSV, image root and input size.
        return ISIC2019(
            csv_file_path=csv_file_path, root_dir=root_dir, client_id=client_id,
            train=train, centralized=centralized, input_size=input_size,
        )

    def make_loader(dataset, shuffle):
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    client_ids = range(6)

    clients_datasets_train = [make_dataset(cid, True, False) for cid in client_ids]
    test_datasets = [make_dataset(cid, False, False) for cid in client_ids]
    centralized_dataset_train = make_dataset(None, True, True)

    clients_dataloader_train = [make_loader(ds, True) for ds in clients_datasets_train]
    test_dataloaders = [make_loader(ds, False) for ds in test_datasets]

    test_centralized_dataset = make_dataset(None, False, True)
    test_dataloader_centralized = make_loader(test_centralized_dataset, False)

    centralized_dataloader_train = make_loader(centralized_dataset_train, True)

    return (
        clients_dataloader_train,
        test_dataloaders,
        centralized_dataloader_train,
        clients_datasets_train,
        test_datasets,
        test_dataloader_centralized,
    )
--------------------------------------------------------------------------------
/FeSViBS.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import models
4 | import random
5 | from dataset import skinCancer, bloodmnisit, isic2019
6 | from utils import weight_dec_global, weight_vec
7 | import argparse
8 | import torch as torch
9 | from torch import nn
10 |
11 |
12 |
13 |
def fesvibs(
    dataset_name, lr, batch_size, Epochs, input_size, num_workers, save_every_epochs,
    model_name, pretrained, opt_name, seed, base_dir, root_dir, csv_file_path, num_clients, DP,
    epsilon, delta, resnet_dropout, initial_block, final_block, fesvibs_arg, local_round
):
    """
    Train SViBS or, when ``fesvibs_arg`` is set, FeSViBS.

    Each round:
    - Every client runs ``Split.train_round``.
    - The returned server-side weights ('blocks', 'cls', 'pos_embed') are
      averaged uniformly over clients and loaded back into the shared ViT.
    - For FeSViBS, the client heads (``resnet50_clients``) and tails
      (``mlp_clients_tail``) are additionally averaged every ``local_round``
      rounds.
    - Every client is then evaluated.

    Metrics are pickled every ``save_every_epochs`` rounds into a new
    directory named after the run configuration.

    With ``DP`` set, ``std = sqrt(2 ln(1.25 / delta)) / epsilon`` and
    ``mean = 0`` are passed to ``models.FeSVBiS``.

    Raises:
        ValueError: unsupported ``dataset_name``; ``num_clients`` or
            ``save_every_epochs`` < 1; ``local_round`` < 1 when federating;
            or ``epsilon`` <= 0 / ``delta`` outside (0, 1) when ``DP`` is set.
        FileExistsError: if the output directory already exists, so earlier
            results are never overwritten.
    """
    import math  # np.math was only an alias of this module and is gone in NumPy 2.0

    # Validate up front so a bad configuration fails before any directory is
    # created or any data is loaded.
    if dataset_name not in ('HAM', 'bloodmnist', 'isic2019'):
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected 'HAM', 'bloodmnist' or 'isic2019'"
        )
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if save_every_epochs < 1:
        raise ValueError(f"save_every_epochs must be >= 1, got {save_every_epochs}")
    if fesvibs_arg and local_round < 1:
        raise ValueError(f"local_round must be >= 1, got {local_round}")
    if DP and (epsilon <= 0 or not 0 < delta < 1):
        raise ValueError(
            f"Differential privacy needs epsilon > 0 and 0 < delta < 1, got epsilon={epsilon}, delta={delta}"
        )

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    method_flag = 'FeSViBS' if fesvibs_arg else 'SViBS'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if DP:
        # Classic Gaussian-mechanism noise scale (no sensitivity factor, i.e. unit sensitivity).
        std = np.sqrt(2 * math.log(1.25/delta)) / epsilon
        mean = 0
        dir_name = f"{model_name}_{lr}lr_{dataset_name}_{num_clients}Clients_{initial_block}to{final_block}Blocks_{batch_size}Batch__{epsilon,delta}DP_{method_flag}"
    else:
        mean = 0
        std = 0
        dir_name = f"{model_name}_{lr}lr_{dataset_name}_{num_clients}Clients_{initial_block}to{final_block}Blocks_{batch_size}Batch_{method_flag}"

    save_dir = f'{dir_name}'
    os.mkdir(save_dir)

    print(f"Logging to: {dir_name}")

    print('Getting the Dataset and Dataloader!')
    if dataset_name == 'HAM':
        num_classes = 7
        _, _, traindataset, testdataset = skinCancer(input_size= input_size, batch_size = batch_size, base_dir= base_dir, num_workers=num_workers)
        num_channels = 3

    elif dataset_name == 'bloodmnist':
        num_classes = 8
        _, _, traindataset, testdataset = bloodmnisit(input_size= input_size, batch_size = batch_size, download= True, num_workers=num_workers)
        num_channels = 3

    elif dataset_name == 'isic2019':
        num_classes = 8
        DATALOADERS, _, _, _, _, test_loader = isic2019(input_size= input_size, batch_size = batch_size, root_dir=root_dir, csv_file_path=csv_file_path, num_workers=num_workers)
        num_channels = 3

    criterion = nn.CrossEntropyLoss()

    fesvibs_network = models.FeSVBiS(
        ViT_name= model_name, num_classes= num_classes,
        num_clients = num_clients, in_channels = num_channels,
        ViT_pretrained= pretrained,
        initial_block= initial_block, final_block= final_block,
        resnet_dropout= resnet_dropout, DP=DP, mean= mean, std= std
    ).to(device)

    Split = models.SplitFeSViBS(
        num_clients=num_clients, device = device, network = fesvibs_network,
        criterion = criterion, base_dir=save_dir,
        initial_block= initial_block, final_block= final_block,
    )

    if dataset_name != 'isic2019':
        print('Distribute Images Among Clients')
        Split.distribute_images(dataset_name=dataset_name, train_data= traindataset,test_data= testdataset ,batch_size = batch_size)
    else:
        # isic2019 already comes split per client (FLamby centers).
        Split.CLIENTS_DATALOADERS = DATALOADERS
        Split.testloader = test_loader

    Split.set_optimizer(opt_name, lr = lr)
    Split.init_logs()

    print('Start Training! \n')

    server_keys = ('blocks', 'cls', 'pos_embed')
    for r in range(Epochs):
        print(f"Round {r+1} / {Epochs}")

        # Uniform average of the server-side weights returned by each client.
        agg_weights = None
        for client_i in range(num_clients):
            weight_dict = Split.train_round(client_i)
            if agg_weights is None:
                agg_weights = weight_dict
            else:
                for key in server_keys:
                    agg_weights[key] += weight_dict[key]

        for key in server_keys:
            agg_weights[key] /= num_clients

        Split.network.vit.blocks = weight_dec_global(
            Split.network.vit.blocks,
            agg_weights['blocks'].to(device)
        )

        Split.network.vit.cls_token.data = agg_weights['cls'].to(device) + 0.0
        Split.network.vit.pos_embed.data = agg_weights['pos_embed'].to(device) + 0.0

        # Federate client heads/tails every `local_round` rounds. With
        # local_round == 1 this now includes the very first round.
        if fesvibs_arg and (r + 1) % local_round == 0:
            print('========================== \t \t Federation \t \t ==========================')
            tails_weights = []
            head_weights = []
            for head, tail in zip(Split.network.resnet50_clients, Split.network.mlp_clients_tail):
                head_weights.append(weight_vec(head).detach().cpu())
                tails_weights.append(weight_vec(tail).detach().cpu())

            mean_avg_tail = torch.mean(torch.stack(tails_weights), dim=0)
            mean_avg_head = torch.mean(torch.stack(head_weights), dim=0)

            for i in range(num_clients):
                Split.network.mlp_clients_tail[i] = weight_dec_global(Split.network.mlp_clients_tail[i],
                                                                      mean_avg_tail.to(device))
                Split.network.resnet50_clients[i] = weight_dec_global(Split.network.resnet50_clients[i],
                                                                      mean_avg_head.to(device))

        for client_i in range(num_clients):
            Split.eval_round(client_i)

        print('---------')

        if (r + 1) % save_every_epochs == 0:
            Split.save_pickles(save_dir)
        print('============================================')
143 |
def str2bool(value):
    """
    Parse a boolean command-line value.

    argparse's ``type=bool`` turns every non-empty string (including
    "False") into True, so boolean flags use this parser instead.

    Accepts bools unchanged, and true/false, t/f, yes/no, y/n and 1/0
    (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Run SViBS / FeSViBS Experiments')
    parser.add_argument('--dataset_name', type=str, required=True, choices=['HAM', 'bloodmnist', 'isic2019'], help='Dataset Name')
    parser.add_argument('--input_size', type=int, default= 224, help='Input size --> (input_size, input_size), default : 224')
    parser.add_argument('--local_round', type=int, default= 2, help='Local round before federation in FeSViBS, default : 2')
    parser.add_argument('--num_workers', type=int, default= 8, help='Number of workers for dataloaders, default : 8')
    parser.add_argument('--initial_block', type=int, default= 1, help='Initial Block, default : 1')
    parser.add_argument('--final_block', type=int, default= 6, help='Final Block, default : 6')
    parser.add_argument('--num_clients', type=int, default= 6, help='Number of Clients, default : 6')
    parser.add_argument('--model_name', type=str, default= 'vit_base_r50_s16_224', help='Model name from timm library, default: vit_base_r50_s16_224')
    parser.add_argument('--pretrained', type=str2bool, default= False, help='Pretrained weights flag (True/False), default: False')
    parser.add_argument('--fesvibs_arg', type=str2bool, default= False, help='Flag to indicate whether SViBS or FeSViBS (True/False), default: False')
    parser.add_argument('--batch_size', type=int, default= 32, help='Batch size, default : 32')
    parser.add_argument('--Epochs', type=int, default= 200, help='Number of Epochs, default : 200')
    parser.add_argument('--opt_name', type=str, choices=['Adam'], default = 'Adam', help='Optimizer name, only ADAM optimizer is available')
    parser.add_argument('--lr', type=float, default= 1e-4, help='Learning rate, default : 1e-4')
    parser.add_argument('--save_every_epochs', type=int, default= 10, help='Save metrics every this number of epochs, default: 10')
    parser.add_argument('--seed', type=int, default= 105, help='Seed, default: 105')
    parser.add_argument('--base_dir', type=str, default= None, help='Path to the HAM data folder (required for HAM)')
    parser.add_argument('--root_dir', type=str, default= None, help='Path to ISIC_2019_Training_Input_preprocessed (required for isic2019)')
    parser.add_argument('--csv_file_path', type=str, default=None, help='Path to the train_test_split csv (required for isic2019)')
    parser.add_argument('--DP', type=str2bool, default= False, help='Differential Privacy (True/False), default: False')
    parser.add_argument('--epsilon', type=float, default= 0, help='Epsilon Value for differential privacy')
    parser.add_argument('--delta', type=float, default= 0.00001, help='Delta Value for differential privacy')
    parser.add_argument('--resnet_dropout', type=float, default= 0.5, help='ResNet Dropout, Default: 0.5')
    args = parser.parse_args()

    fesvibs(
        dataset_name = args.dataset_name, input_size= args.input_size,
        num_workers= args.num_workers, model_name= args.model_name,
        pretrained= args.pretrained, batch_size= args.batch_size,
        Epochs= args.Epochs, opt_name= args.opt_name, lr= args.lr,
        save_every_epochs= args.save_every_epochs, seed= args.seed,
        base_dir= args.base_dir, root_dir= args.root_dir, csv_file_path= args.csv_file_path, num_clients = args.num_clients,
        DP = args.DP, epsilon = args.epsilon, delta = args.delta, initial_block= args.initial_block, final_block=args.final_block,
        resnet_dropout = args.resnet_dropout, fesvibs_arg = args.fesvibs_arg, local_round = args.local_round
    )
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from PIL import Image
7 |
8 | from sklearn.model_selection import train_test_split
9 | from torch.utils.data import Dataset, DataLoader
10 | from torchvision import transforms
11 |
12 |
def weight_vec(network):
    """
    Flatten every parameter of ``network`` into a single 1-D tensor.

    Parameters are concatenated in ``network.parameters()`` order. The result
    is still attached to the autograd graph; callers detach it when needed.
    """
    return torch.cat([torch.flatten(param) for param in network.parameters()])
18 |
19 |
def weight_dec_global(pyModel, weight_vec):
    """
    Load a flat weight vector back into ``pyModel``'s parameters, in place.

    This is the inverse of flattening the parameters in
    ``pyModel.parameters()`` order. Consecutive slices of ``weight_vec`` are
    reshaped to each parameter's shape and copied into it. Elements beyond
    the model's total parameter count are ignored.

    Args:
        pyModel: torch.nn.Module whose parameters are overwritten.
        weight_vec: 1-D tensor holding at least as many elements as the
            model has parameters.

    Returns:
        pyModel (the same object), so callers can reassign the result.

    Raises:
        ValueError: if ``weight_vec`` is too short. No parameter is modified
            in that case; previously the model was left partially updated.
    """
    params = list(pyModel.parameters())
    needed = sum(param.numel() for param in params)
    if weight_vec.numel() < needed:
        raise ValueError(
            f"weight_vec has {weight_vec.numel()} elements but the model needs {needed}"
        )

    offset = 0
    with torch.no_grad():
        for param in params:
            count = param.numel()
            # copy_ writes into the existing storage, keeping the parameter's
            # device and dtype and never aliasing weight_vec.
            param.copy_(weight_vec[offset:offset + count].reshape(param.shape))
            offset += count
    return pyModel
36 |
37 |
def distribute_data(numOfClients, train_dataset, batch_size, num_workers=32):
    """
    Randomly split ``train_dataset`` across clients and wrap each share in a DataLoader.

    Every client except the last gets round(len(train_dataset) / numOfClients)
    samples, and the last client gets whatever remains. If rounding up would
    leave the last client with no samples (or a negative count), the share
    falls back to the floor, so every client receives at least one sample.

    numOfClients (int): number of clients (>= 1)
    train_dataset: map-style dataset to split (e.g. a torchvision dataset)
    batch_size (int): batch size of every client loader
    num_workers (int): worker processes per client loader (default 32, as before)

    return: list of shuffling DataLoaders, one per client
    raises ValueError: if numOfClients < 1 or the dataset has fewer samples than clients
    """
    total = len(train_dataset)
    if numOfClients < 1:
        raise ValueError(f"numOfClients must be >= 1, got {numOfClients}")
    if total < numOfClients:
        raise ValueError(f"Cannot split {total} samples among {numOfClients} clients")

    # Same rounding expression as before, so existing splits are reproduced exactly.
    share = round(1/numOfClients * total)
    if total - share * (numOfClients - 1) < 1:
        # round() went up far enough to starve the last client.
        share = total // numOfClients
    distribution = [share] * (numOfClients - 1)
    distribution.append(total - share * (numOfClients - 1))

    # splitting the data to different dataloaders
    data_split = torch.utils.data.random_split(train_dataset, distribution)
    ClIENTS_DATALOADERS = [
        torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        for split in data_split
    ]

    print(f"Length of the training dataset: {len(train_dataset)} sample")
    return ClIENTS_DATALOADERS
65 |
def get_data(base_dir, imageid_path_dict):
    """
    Build the SkinCancer (HAM10000) train/validation dataframes.

    Steps:
    - Read HAM10000_metadata.csv from ``base_dir`` and attach image paths and
      integer class labels ('cell_type_idx').
    - Hold out 20% of the lesions that have exactly one image as the
      validation set (stratified by class, random_state=101).
    - Use every remaining image for training, oversampling minority classes
      by fixed per-class factors.

    Input:
        base_dir (str): directory containing HAM10000_metadata.csv
        imageid_path_dict (dict): image id -> image file path

    Return:
        df_train: Dataframe for training (oversampled)
        df_val: Dataframe for testing
    """
    # Human-readable class name per dx code ('mel' used to be mislabelled
    # 'dermatofibroma').
    lesion_type_dict = {
        'nv': 'Melanocytic nevi',
        'mel': 'Melanoma',
        'bkl': 'Benign keratosis-like lesions ',
        'bcc': 'Basal cell carcinoma',
        'akiec': 'Actinic keratoses',
        'vasc': 'Vascular lesions',
        'df': 'Dermatofibroma'
    }
    # Pinned label order: identical to the alphabetical order the old names
    # produced, so cell_type_idx and data_aug_rate keep their meaning:
    # 0 akiec, 1 bcc, 2 bkl, 3 df, 4 nv, 5 vasc, 6 mel.
    class_order = [
        'Actinic keratoses',
        'Basal cell carcinoma',
        'Benign keratosis-like lesions ',
        'Dermatofibroma',
        'Melanocytic nevi',
        'Vascular lesions',
        'Melanoma',
    ]

    df_original = pd.read_csv(os.path.join(base_dir, 'HAM10000_metadata.csv'))
    df_original['path'] = df_original['image_id'].map(imageid_path_dict.get)
    df_original['cell_type'] = df_original['dx'].map(lesion_type_dict.get)
    df_original['cell_type_idx'] = pd.Categorical(df_original['cell_type'], categories=class_order).codes

    # Lesions with exactly one image can go to the val set without their
    # other images leaking into the train set.
    images_per_lesion = df_original.groupby('lesion_id')['image_id'].count()
    single_image_lesions = set(images_per_lesion[images_per_lesion == 1].index)
    df_original['duplicates'] = (
        df_original['lesion_id'].isin(single_image_lesions).map({True: 'unduplicated', False: 'duplicated'})
    )
    df_undup = df_original[df_original['duplicates'] == 'unduplicated']

    # Create a val set from the unduplicated lesions only.
    y = df_undup['cell_type_idx']
    _, df_val = train_test_split(df_undup, test_size=0.2, random_state=101, stratify=y)

    # Every image not in the val set is a training image (set lookup instead
    # of rebuilding a list per row).
    val_ids = set(df_val['image_id'])
    df_original['train_or_val'] = (
        df_original['image_id'].astype(str).isin(val_ids).map({True: 'val', False: 'train'})
    )
    df_train = df_original[df_original['train_or_val'] == 'train']

    # Copy rows of the rarer classes to balance the 7 classes
    # (pd.concat replaces DataFrame.append, removed in pandas 2.0).
    data_aug_rate = [15, 10, 5, 50, 0, 40, 5]
    for class_idx, rate in enumerate(data_aug_rate):
        if rate:
            class_rows = df_train.loc[df_train['cell_type_idx'] == class_idx, :]
            df_train = pd.concat([df_train] + [class_rows] * (rate - 1), ignore_index=True)

    df_train = df_train.reset_index()
    df_val = df_val.reset_index()

    return df_train, df_val
154 |
class CustomDataset(Dataset):
    """
    Custom map-style dataset for the SkinCancer (HAM10000) dataframe.

    ``df`` must provide a 'path' column (image file path) and a
    'cell_type_idx' column (integer class label). Rows are addressed by
    position, so a filtered or deduplicated dataframe with a non-contiguous
    index is still read correctly.
    """
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Positional lookup: the label-based df['path'][index] returned the
        # wrong row (or raised KeyError) when the index was not 0..len-1.
        X = Image.open(self.df['path'].iloc[index])
        y = torch.tensor(int(self.df['cell_type_idx'].iloc[index]))

        if self.transform:
            X = self.transform(X)
        return X, y
174 |
class ISIC2019(Dataset):
    """
    ISIC2019 dataset read from the FLamby train/test split CSV.

    Expected CSV columns:
    - 'image': file stem under ``root_dir``; '.jpg' is appended.
    - 'target': the label.
    - 'fold': 'train' / 'test', used for the centralized split.
    - 'fold2': 'train_{client_id}' / 'test_{client_id}', used for per-client splits.
    """

    # NOTE(review): not referenced inside this class; kept in case other modules use them.
    TO_REPLACE_TRAIN = [None, [4,5,6], None, None,[4], [4,5,6]]
    VALUES_TRAIN = [None, [3,4,5], None, None,[2], [3,4,5]]

    def __init__(self, csv_file_path, root_dir, client_id, train = True, centralized = False, input_size = 224) -> None:
        super().__init__()
        self.image_root = root_dir
        self.train = train
        self.centralized = centralized
        csv_file = pd.read_csv(csv_file_path)

        # Any falsy `train` selects the test split (train=None used to leave
        # self.csv and self.transform unset).
        split = 'train' if train else 'test'
        if centralized:
            mask = csv_file['fold'] == split
        else:
            mask = csv_file['fold2'] == f'{split}_{client_id}'
        self.csv = csv_file[mask].reset_index()

        if train:
            self.transform = transforms.Compose([
                transforms.RandomRotation(10),
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomVerticalFlip(0.5),
                transforms.RandomAffine(degrees = 0, shear=0.05),
                transforms.RandomResizedCrop((input_size, input_size), scale=(0.85,1.1)),
                transforms.ToTensor(),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((input_size, input_size)),
                transforms.ToTensor(),
            ])

    def __len__(self):
        return self.csv.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.image_root,
                                self.csv['image'][idx]+'.jpg')
        sample = Image.open(img_name)
        target = self.csv['target'][idx]

        sample = self.transform(sample)

        return sample, target
230 |
def blood_noniid(numOfAgents, data, batch_size):
    """
    Split a BloodMNIST-style dataset among agents (clients) in a fixed non-IID way.

    Each sample is sent to one agent drawn with ``np.random.choice``. For
    labels 0-7, a fixed set of agents gets weight ``numOfAgents`` and every
    other agent gets weight 1, which skews each agent's label distribution.
    Any other label is assigned uniformly. Favored agent indices go up to 5,
    so fewer than 6 agents raises IndexError for those labels.

    Input:
        numOfAgents (int): Number of Agents (Clients)
        data: iterable of (image_tensor, label) pairs
        batch_size (int)

    Return:
        dataset_agents: per agent, a list of (image, label) tuples
        dataset_loaders: per agent, a shuffling DataLoader over dataset_agents
        dataset_vis: per agent, a one-element list holding (stacked images, labels)
    """
    # label -> agents weighted `numOfAgents` times higher for that label.
    # NOTE(review): label 6 favors only agents 4 and 5 because index 5 was
    # assigned twice; confirm whether a third agent was intended.
    favored_agents = {
        0: (0, 1, 2),
        1: (0, 1, 2),
        2: (3, 5, 0),
        3: (0, 4, 5),
        4: (3, 4, 5),
        5: (3, 4, 5),
        6: (4, 5),
        7: (0, 1, 2),
    }
    agent_ids = np.arange(0, numOfAgents)
    images_per_agent = [[] for _ in range(numOfAgents)]
    labels_per_agent = [[] for _ in range(numOfAgents)]

    for sample in data:
        weights = np.ones((numOfAgents))
        for agent in favored_agents.get(float(sample[1]), ()):
            weights[agent] = numOfAgents
        weights = weights / np.sum(weights)
        chosen = np.random.choice(agent_ids, p=weights)
        images_per_agent[chosen].append(sample[0])
        labels_per_agent[chosen].append(torch.tensor(sample[1]).reshape(1))

    dataset_vis = [
        [(torch.stack(images_per_agent[agent]), torch.cat(labels_per_agent[agent]))]
        for agent in range(numOfAgents)
    ]

    dataset_agents = []
    for entry in dataset_vis:
        stacked_images, stacked_labels = entry[0]
        dataset_agents.append(list(zip(stacked_images, stacked_labels)))

    dataset_loaders = [
        DataLoader(agent_data, batch_size=batch_size, shuffle=True, num_workers=8)
        for agent_data in dataset_agents
    ]

    return dataset_agents, dataset_loaders, dataset_vis
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import pickle as pkl
3 | import os
4 | import timm
5 | import copy
6 | import numpy as np
7 |
8 | import torch.nn as nn
9 | import torch
10 | from sklearn.metrics import balanced_accuracy_score
11 |
12 | from dataset import blood_noniid, bloodmnisit, distribute_data
13 | from utils import weight_vec
14 |
class CentralizedFashion():
    """
    Train and evaluate a single network on pooled (centralized) data,
    logging per-epoch loss and balanced accuracy.
    """
    def __init__(self, device, network, criterion, base_dir):
        """
        Class for Centralized Paradigm.
        args:
            device: cuda vs cpu
            network: ViT model
            criterion: loss function to be used
            base_dir: where to save metrics as pickles
        return:
            None
        """
        self.device = device
        self.network = network
        self.criterion = criterion
        self.base_dir = base_dir

    def set_optimizer(self, name, lr):
        """
        Create the optimizer over all network parameters.

        name: Optimizer name; only 'Adam' is supported
        lr: learning rate

        raises ValueError: for any other name. An unknown name used to be
            ignored silently, and training then failed on a missing attribute.
        """
        if name != 'Adam':
            raise ValueError(f"Unsupported optimizer {name!r}; only 'Adam' is available")
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr = lr)

    def init_logs(self):
        """
        A method to initialize dictionaries for the metrics
        return : None
        args: None
        """
        self.losses = {'train':[], 'test':[]}
        self.balanced_accs = {'train':[], 'test':[]}

    def train_round(self, train_loader):
        """
        Run one training epoch over ``train_loader`` and record its metrics.

        Labels are flattened to shape (batch,) before the loss, so (batch, 1)
        label tensors are accepted.
        """
        # Training mode (dropout etc. active); eval_round switches to eval mode.
        self.network.train()
        running_loss = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        for imgs, labels in tqdm(train_loader):
            self.optimizer.zero_grad()
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            output = self.network(imgs)
            labels = labels.reshape(labels.shape[0])
            loss = self.criterion(output, labels)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(output, 1)
            whole_probs.append(torch.nn.Softmax(dim = -1)(output).detach().cpu())
            whole_labels.append(labels.detach().cpu())
            whole_preds.append(predicted.detach().cpu())
        self.metrics(whole_labels, whole_preds, running_loss, len(train_loader), whole_probs, train = True)

    def eval_round(self, test_loader):
        """
        Evaluate on ``test_loader`` without gradients and record the metrics.
        """
        # Eval mode so dropout/normalization layers behave deterministically.
        self.network.eval()
        running_loss = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        with torch.no_grad():
            for imgs, labels in tqdm(test_loader):
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                output = self.network(imgs)
                labels = labels.reshape(labels.shape[0])
                loss = self.criterion(output, labels)
                running_loss += loss.item()
                _, predicted = torch.max(output, 1)
                whole_probs.append(torch.nn.Softmax(dim = -1)(output).detach().cpu())
                whole_labels.append(labels.detach().cpu())
                whole_preds.append(predicted.detach().cpu())
        self.metrics(whole_labels, whole_preds, running_loss, len(test_loader), whole_probs, train= False)

    def metrics(self, whole_labels, whole_preds, running_loss, len_loader, whole_probs, train):
        """
        Append the epoch's mean loss and balanced accuracy to the logs and print them.

        whole_labels / whole_preds: lists of per-batch CPU tensors
        running_loss: summed per-batch loss, divided by len_loader (number of batches)
        whole_probs: per-batch softmax outputs (currently unused)
        train: True logs under 'train', otherwise under 'test'
        """
        whole_labels = torch.cat(whole_labels)
        whole_preds = torch.cat(whole_preds)
        loss_epoch = running_loss/len_loader
        balanced_acc = balanced_accuracy_score(whole_labels.detach().cpu(),whole_preds.detach().cpu())
        eval_name = 'train' if train == True else 'test'

        self.losses[eval_name].append(loss_epoch)
        self.balanced_accs[eval_name].append(balanced_acc)

        print(f"{eval_name}:")
        print(f"{eval_name}_loss :{loss_epoch:.3f}")
        print(f"{eval_name}_balanced_acc :{balanced_acc:.3f}")

    def save_pickles(self, base_dir, local= None, client_id=None):
        """
        Pickle the loss and balanced-accuracy logs into ``base_dir``.

        With ``local`` set and a ``client_id`` given, the file names carry the
        client id; otherwise generic names are used. ``client_id=0`` counts as
        given (it used to fall through to the generic names).
        """
        if local and client_id is not None:
            loss_name = f'loss_epoch_Client{client_id}'
            acc_name = f'balanced_accs{client_id}'
        else:
            loss_name = 'loss_epoch'
            acc_name = 'balanced_accs'
        with open(os.path.join(base_dir, loss_name), 'wb') as handle:
            pkl.dump(self.losses, handle)
        with open(os.path.join(base_dir, acc_name), 'wb') as handle:
            pkl.dump(self.balanced_accs, handle)
131 |
class SLViT(nn.Module):
    """
    Split-learning ViT: per-client patch embeddings and classification tails
    around a shared timm Vision Transformer body.

    ViT_name / ViT_pretrained / in_channels / num_classes: forwarded to
        ``timm.create_model``.
    num_clients: number of per-client embedding/tail copies.
    diff_privacy: if truthy, noise ``randn * std + mean`` is added to the
        patch embeddings in ``forward``.
    """
    def __init__(
        self, ViT_name, num_classes, num_clients=6,
        in_channels=3, ViT_pretrained=False,
        diff_privacy=False, mean=0, std=1
    ) -> None:

        super().__init__()

        self.vit = timm.create_model(
            model_name=ViT_name,
            pretrained=ViT_pretrained,
            num_classes=num_classes,
            in_chans=in_channels
        )
        client_tail = MLP_cls_classes(num_classes=num_classes)
        self.mlp_clients_tail = nn.ModuleList([copy.deepcopy(client_tail) for _ in range(num_clients)])
        # Per-client copies of the ViT patch embedding. The attribute name is kept
        # because it determines the state_dict keys of existing checkpoints.
        self.resnet50_clients = nn.ModuleList([copy.deepcopy(self.vit.patch_embed) for _ in range(num_clients)])

        self.diff_privacy = diff_privacy
        self.mean = mean
        self.std = std

    def forward(self, x, client_idx):
        """
        Run client ``client_idx``'s embedding, the shared ViT body and the
        client's tail.

        Returns ``(logits, cls)``, where ``cls`` is the first token of the
        normalized ViT output after ``pre_logits``.
        """
        x = self.resnet50_clients[client_idx](x)
        if self.diff_privacy:
            # randn_like creates the noise on x's device/dtype instead of a
            # hard-coded .cuda(), so the model also runs on CPU.
            noise = torch.randn_like(x) * self.std + self.mean
            x = x + noise
        x = torch.cat((self.vit.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.vit.pos_drop(x + self.vit.pos_embed)
        # Apply every transformer block of the backbone (12 for ViT-Base)
        # rather than a hard-coded count.
        for block in self.vit.blocks:
            x = block(x)
        x = self.vit.norm(x)
        cls = self.vit.pre_logits(x)[:, 0, :]
        x = self.mlp_clients_tail[client_idx](cls)
        return x, cls
168 |
class MLP_cls_classes(nn.Module):
    """
    Classification head: LayerNorm -> Identity -> Linear.

    num_classes: number of output logits.
    embed_dim: width of the incoming feature vector (default 768, the width
        the head was originally hard-coded to).
    """
    def __init__(self, num_classes, embed_dim=768):
        super().__init__()
        self.norm = nn.LayerNorm((embed_dim,), eps=1e-06, elementwise_affine=True)
        # No-op placeholder; kept so the module structure is unchanged.
        self.identity = nn.Identity()
        self.fc = nn.Linear(in_features=embed_dim, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return logits of shape ``(..., num_classes)`` for features ``(..., embed_dim)``."""
        x = self.norm(x)
        x = self.identity(x)
        x = self.fc(x)
        return x
181 |
class SplitNetwork():
    """
    Split-learning trainer around one shared ``network``.

    The network is called as ``network(imgs, client_idx)`` and the first item
    of its output is used as the logits. Per-client loss and balanced accuracy
    are logged in ``self.losses`` / ``self.balanced_accs`` (see ``init_logs``).
    """
    def __init__(
        self, num_clients, device, network,
        criterion, base_dir,
    ):
        """
        args:
            num_clients: number of clients
            device: cuda vs cpu
            network: ViT-based split model
            criterion: loss function to be used
            base_dir: where to save pickles/model files
        """

        self.device = device
        self.num_clients = num_clients
        self.criterion = criterion
        self.network = network
        self.base_dir = base_dir

    def init_logs(self):
        """
        Initialize the metric logs as ``{'train': [...], 'test': [...]}`` with an
        independent list per client for both losses and balanced accuracies.
        """
        self.losses = {'train': [[] for _ in range(self.num_clients)], 'test': [[] for _ in range(self.num_clients)]}
        self.balanced_accs = {'train': [[] for _ in range(self.num_clients)], 'test': [[] for _ in range(self.num_clients)]}

    def set_optimizer(self, name, lr):
        """
        Create ``self.optimizer`` over all parameters of ``self.network``.

        name: optimizer name; only 'Adam' is supported.
        lr: learning rate

        Raises ValueError for unsupported names instead of silently leaving
        ``self.optimizer`` unset (which would only fail later in training).
        """
        if name != 'Adam':
            raise ValueError(f"Unsupported optimizer name: {name!r} (only 'Adam' is supported)")
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

    def distribute_images(self, dataset_name, train_data, test_data, batch_size):
        """
        Split the dataset among clients and build the test loader.

        dataset_name: 'HAM' (uses ``train_data``/``test_data``) or 'bloodmnist'
            (builds its own datasets; ``train_data``/``test_data`` are ignored).
            Any other name is a no-op, so ``CLIENTS_DATALOADERS`` and
            ``testloader`` must then be provided some other way.
        train_data: train dataset
        test_data: test dataset
        batch_size: batch size
        """
        if dataset_name == 'HAM':
            self.CLIENTS_DATALOADERS = distribute_data(self.num_clients, train_data, batch_size)
            self.testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=8)

        elif dataset_name == 'bloodmnist':
            _, self.testloader, train_dataset, _ = bloodmnisit(batch_size=batch_size)
            _, self.CLIENTS_DATALOADERS, _ = blood_noniid(self.num_clients, train_dataset, batch_size=batch_size)

    def train_round(self, client_i):
        """
        Train ``self.network`` for one pass over client ``client_i``'s loader.

        After training, the shared ViT parts (``vit.blocks``, ``vit.cls_token``,
        ``vit.pos_embed``) are snapshotted into the returned dict and then
        reset to their values from before this call; all other parameters keep
        their updates.

        client_i: Client index.

        Returns: dict with 'blocks' (``weight_vec`` of the blocks), 'cls' and
        'pos_embed', all detached and moved to CPU.
        """
        running_loss_client_i = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        # Pre-training snapshot used to reset the shared ViT parts afterwards.
        copy_network = copy.deepcopy(self.network)
        weight_dic = {'blocks': None, 'cls': None, 'pos_embed': None}
        self.network.train()
        for data in tqdm(self.CLIENTS_DATALOADERS[client_i]):
            self.optimizer.zero_grad()
            imgs, labels = data[0].to(self.device), data[1].to(self.device)
            labels = labels.reshape(labels.shape[0])
            tail_output = self.network(imgs, client_i)
            loss = self.criterion(tail_output[0], labels)
            loss.backward()
            self.optimizer.step()
            running_loss_client_i += loss.item()
            _, predicted = torch.max(tail_output[0], 1)
            whole_probs.append(torch.nn.Softmax(dim=-1)(tail_output[0]).detach().cpu())
            whole_labels.append(labels.detach().cpu())
            whole_preds.append(predicted.detach().cpu())
        self.metrics(client_i, whole_labels, whole_preds, running_loss_client_i, len(self.CLIENTS_DATALOADERS[client_i]), whole_probs, train=True)

        weight_dic['blocks'] = weight_vec(self.network.vit.blocks).detach().cpu()
        weight_dic['cls'] = self.network.vit.cls_token.detach().cpu()
        weight_dic['pos_embed'] = self.network.vit.pos_embed.detach().cpu()

        # NOTE(review): these assignments install NEW module/Parameter objects. An
        # optimizer created earlier by set_optimizer() still references the old
        # ones, so unless set_optimizer() is called again the restored shared
        # weights are not updated by self.optimizer.step(). Confirm with the caller.
        self.network.vit.blocks = copy.deepcopy(copy_network.vit.blocks)
        self.network.vit.cls_token = copy.deepcopy(copy_network.vit.cls_token)
        self.network.vit.pos_embed = copy.deepcopy(copy_network.vit.pos_embed)
        return weight_dic

    def eval_round(self, client_i):
        """
        Evaluate ``self.network`` on ``self.testloader`` using client
        ``client_i``'s client-specific modules, then log via ``metrics``.

        client_i: Client index.
        """
        running_loss_client_i = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        self.network.eval()
        with torch.no_grad():
            for data in tqdm(self.testloader):
                imgs, labels = data[0].to(self.device), data[1].to(self.device)
                tail_output = self.network(imgs, client_i)[0]
                labels = labels.reshape(labels.shape[0])
                loss = self.criterion(tail_output, labels)
                running_loss_client_i += loss.item()
                _, predicted = torch.max(tail_output, 1)
                whole_probs.append(torch.nn.Softmax(dim=-1)(tail_output).detach().cpu())
                whole_labels.append(labels.detach().cpu())
                whole_preds.append(predicted.detach().cpu())
        self.metrics(client_i, whole_labels, whole_preds, running_loss_client_i, len(self.testloader), whole_probs, train=False)

    def metrics(self, client_i, whole_labels, whole_preds, running_loss_client_i, len_loader, whole_probs, train):
        """
        Aggregate one epoch of client ``client_i``'s results, record and print them.

        whole_labels / whole_preds: lists of per-batch tensors (concatenated here).
        running_loss_client_i: sum of per-batch losses; divided by ``len_loader``.
        whole_probs: per-batch probabilities; not used by this method.
        train: truthy to record under 'train', otherwise under 'test'.

        Appends to ``self.losses[...][client_i]`` and
        ``self.balanced_accs[...][client_i]``. Nothing is written to disk here.
        """
        whole_labels = torch.cat(whole_labels)
        whole_preds = torch.cat(whole_preds)
        loss_epoch = running_loss_client_i / len_loader
        balanced_acc = balanced_accuracy_score(whole_labels.detach().cpu(), whole_preds.detach().cpu())

        eval_name = 'train' if train else 'test'

        self.losses[eval_name][client_i].append(loss_epoch)
        self.balanced_accs[eval_name][client_i].append(balanced_acc)

        print(f"client{client_i}_{eval_name}:")
        print(f" Loss {eval_name}:{loss_epoch:.3f}")
        print(f"balanced accuracy {eval_name}:{balanced_acc:.3f}")

    def save_pickles(self, base_dir):
        """Pickle ``self.losses`` to 'loss_epoch' and ``self.balanced_accs`` to 'balanced_accs' in ``base_dir``."""
        with open(os.path.join(base_dir, 'loss_epoch'), 'wb') as handle:
            pkl.dump(self.losses, handle)
        with open(os.path.join(base_dir, 'balanced_accs'), 'wb') as handle:
            pkl.dump(self.balanced_accs, handle)
327 |
class FeSVBiS(nn.Module):
    """
    FeSViBS network: per-client patch embedding, the first ``chosen_block``
    blocks of a shared timm ViT, a shared ``ResidualBlock`` and a per-client
    classification tail.

    ViT_name / ViT_pretrained / in_channels / num_classes: forwarded to
        ``timm.create_model``.
    num_clients: number of per-client embedding/tail copies.
    initial_block / final_block: stored on the module (the depth itself is
        passed to ``forward`` as ``chosen_block``).
    resnet_dropout: dropout rate for the shared ``ResidualBlock``.
    DP: if truthy, noise ``randn * std + mean`` is added to the patch
        embeddings; ``mean`` and ``std`` must then be numbers (the ``None``
        defaults would fail in ``forward``).
    """
    def __init__(
        self, ViT_name, num_classes,
        num_clients=6, in_channels=3, ViT_pretrained=False,
        initial_block=1, final_block=6, resnet_dropout=None, DP=False, mean=None, std=None
    ) -> None:
        super().__init__()

        self.initial_block = initial_block
        self.final_block = final_block

        self.vit = timm.create_model(
            model_name=ViT_name,
            pretrained=ViT_pretrained,
            num_classes=num_classes,
            in_chans=in_channels
        )

        # The ViT patch embedding, copied once per client. Attribute names are
        # kept because they determine the state_dict keys.
        self.resnet50 = self.vit.patch_embed
        self.resnet50_clients = nn.ModuleList([copy.deepcopy(self.resnet50) for _ in range(num_clients)])
        self.common_network = ResidualBlock(drop_out=resnet_dropout)
        client_tail = MLP_cls_classes(num_classes=num_classes)
        self.mlp_clients_tail = nn.ModuleList([copy.deepcopy(client_tail) for _ in range(num_clients)])
        self.DP = DP
        self.mean = mean
        self.std = std

    def forward(self, x, chosen_block, client_idx):
        """
        Run client ``client_idx``'s embedding, the first ``chosen_block`` shared
        ViT blocks, the shared residual block and the client's tail; return logits.
        """
        x = self.resnet50_clients[client_idx](x)
        if self.DP:
            # randn_like creates the noise on x's device/dtype instead of a
            # hard-coded .cuda(), so the model also runs on CPU.
            noise = torch.randn_like(x) * self.std + self.mean
            x = x + noise
        for block_num in range(chosen_block):
            x = self.vit.blocks[block_num](x)
        x = self.common_network(x)
        x = self.mlp_clients_tail[client_idx](x)
        return x
365 |
366 |
class ResidualBlock(nn.Module):
    """
    Two 3x3 convolutions with a residual connection, followed by average
    pooling to one ``out_channels``-wide vector per sample.

    A 3-D input ``(batch, tokens, in_channels)`` with ``tokens == grid_size**2``
    is rearranged into a ``(batch, in_channels, grid_size, grid_size)`` map
    first; a 4-D input is used as-is.

    drop_out: Dropout2d rate applied after the first convolution, or None to
        disable dropout.
    grid_size: side length of the square token grid and of the pooling
        window (default 14).
    """
    def __init__(self, in_channels=768, out_channels=768, stride=1, downsample=None, drop_out=None, grid_size=14):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels
        self.grid_size = grid_size
        self.pool = nn.AvgPool2d(grid_size, stride=1)
        # nn.Dropout2d rejects p=None, so the default drop_out=None used to crash
        # at construction. Neither module has parameters, so state_dict keys are unchanged.
        self.dropout = nn.Dropout2d(p=drop_out) if drop_out is not None else nn.Identity()
        self.drop_out = drop_out

    def forward(self, x):
        if len(x.shape) == 3:
            # (batch, tokens, channels) -> (batch, channels, grid, grid)
            x = torch.permute(x, (0, -1, 1))
            x = x.reshape(x.shape[0], x.shape[1], self.grid_size, self.grid_size)
        residual = x
        out = self.conv1(x)
        if self.drop_out is not None:
            out = self.dropout(out)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        out = self.pool(out)
        # Flatten to (batch, out_channels) using the configured width, not a literal 768.
        return out.reshape(-1, self.out_channels)
399 |
class SplitFeSViBS(SplitNetwork):
    """
    Split-learning trainer for ``FeSVBiS``-style networks.

    Each ``train_round`` samples a ViT depth uniformly from
    ``[initial_block, final_block]`` (inclusive) and calls the network as
    ``network(x=..., chosen_block=..., client_idx=...)``. ``eval_round``
    reuses the depth sampled in the client's latest training round.
    """
    def __init__(
        self, num_clients, device,
        network, criterion, base_dir,
        initial_block, final_block,
    ):
        # Base initializer stores num_clients, device, network, criterion, base_dir.
        super().__init__(
            num_clients=num_clients, device=device, network=network,
            criterion=criterion, base_dir=base_dir,
        )
        self.initial_block = initial_block
        self.final_block = final_block
        # Depth used in each client's most recent train_round (0 until trained).
        self.train_chosen_blocks = [0] * num_clients

    def set_optimizer_mel(self, name, lr):
        # NOTE(review): reads ``self.mel_body``, which is not assigned in this class
        # or in SplitNetwork; it must be set externally before calling this with
        # 'Adam'. Other names are silently ignored.
        if name == 'Adam':
            self.optimizer_mel = [torch.optim.Adam(self.mel_body[i].parameters(), lr=lr) for i in range(self.num_clients)]

    def train_round(self, client_i):
        """
        Training loop for client ``client_i`` with a freshly sampled ViT depth.

        After training, ``vit.blocks``, ``vit.cls_token`` and ``vit.pos_embed``
        are snapshotted into the returned dict and reset to their values from
        before this call; all other parameters keep their updates.

        client_i: Client index.

        Returns: dict with 'blocks' (``weight_vec`` of the blocks), 'cls' and
        'pos_embed' (detached CPU tensors) and 'resnet' (always None here).
        """
        running_loss_client_i = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        self.chosen_block = np.random.randint(low=self.initial_block, high=self.final_block + 1)
        self.train_chosen_blocks[client_i] = self.chosen_block
        # Pre-training snapshot used to reset the shared ViT parts afterwards.
        copy_network = copy.deepcopy(self.network)
        weight_dic = {'blocks': None, 'cls': None, 'pos_embed': None, 'resnet': None}
        print(f"Chosen Block:{self.chosen_block} for client {client_i}")
        self.network.train()
        for data in tqdm(self.CLIENTS_DATALOADERS[client_i]):
            self.optimizer.zero_grad()
            imgs, labels = data[0].to(self.device), data[1].to(self.device)
            labels = labels.reshape(labels.shape[0])
            tail_output = self.network(x=imgs, chosen_block=self.chosen_block, client_idx=client_i)
            loss = self.criterion(tail_output, labels)
            loss.backward()
            self.optimizer.step()
            running_loss_client_i += loss.item()
            _, predicted = torch.max(tail_output, 1)
            whole_probs.append(torch.nn.Softmax(dim=-1)(tail_output).detach().cpu())
            whole_labels.append(labels.detach().cpu())
            whole_preds.append(predicted.detach().cpu())
        self.metrics(client_i, whole_labels, whole_preds, running_loss_client_i, len(self.CLIENTS_DATALOADERS[client_i]), whole_probs, train=True)

        weight_dic['blocks'] = weight_vec(self.network.vit.blocks).detach().cpu()
        weight_dic['cls'] = self.network.vit.cls_token.detach().cpu()
        weight_dic['pos_embed'] = self.network.vit.pos_embed.detach().cpu()

        # NOTE(review): these assignments install new module/Parameter objects; an
        # optimizer created earlier still references the replaced ones unless
        # set_optimizer() is called again.
        self.network.vit.blocks = copy.deepcopy(copy_network.vit.blocks)
        self.network.vit.cls_token = copy.deepcopy(copy_network.vit.cls_token)
        self.network.vit.pos_embed = copy.deepcopy(copy_network.vit.pos_embed)
        return weight_dic

    def eval_round(self, client_i):
        """
        Evaluate client ``client_i`` on ``self.testloader`` using the depth
        recorded in its latest ``train_round``, then log via ``metrics``.

        client_i: Client index.
        """
        running_loss_client_i = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        num_b = self.train_chosen_blocks[client_i]
        print(f"Chosen block for testing: {num_b}")
        self.network.eval()
        with torch.no_grad():
            for data in tqdm(self.testloader):
                imgs, labels = data[0].to(self.device), data[1].to(self.device)
                labels = labels.reshape(labels.shape[0])
                tail_output = self.network(x=imgs, chosen_block=num_b, client_idx=client_i)
                loss = self.criterion(tail_output, labels)
                running_loss_client_i += loss.item()
                _, predicted = torch.max(tail_output, 1)
                whole_probs.append(torch.nn.Softmax(dim=-1)(tail_output).detach().cpu())
                whole_labels.append(labels.detach().cpu())
                whole_preds.append(predicted.detach().cpu())
        self.metrics(client_i, whole_labels, whole_preds, running_loss_client_i, len(self.testloader), whole_probs, train=False)
492 |
--------------------------------------------------------------------------------