├── Figures └── method.PNG ├── LICENSE ├── requirements.txt ├── .gitignore ├── centralized.py ├── local.py ├── SLViT.py ├── README.md ├── dataset.py ├── FeSViBS.py ├── utils.py └── models.py /Figures/method.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/faresmalik/FeSViBS/HEAD/Figures/method.PNG -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Faris_Malik 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asttokens==2.0.5 2 | backcall==0.2.0 3 | certifi==2022.12.7 4 | charset-normalizer==3.1.0 5 | cmake==3.26.0 6 | comm==0.1.2 7 | contourpy==1.0.7 8 | cycler==0.11.0 9 | debugpy==1.5.1 10 | decorator==5.1.1 11 | entrypoints==0.4 12 | executing==0.8.3 13 | filelock==3.10.0 14 | fire==0.5.0 15 | fonttools==4.39.0 16 | idna==3.4 17 | imageio==2.26.0 18 | ipykernel==6.19.2 19 | ipython==8.10.0 20 | jedi==0.18.1 21 | Jinja2==3.1.2 22 | joblib==1.2.0 23 | jupyter_client==7.4.9 24 | jupyter_core==5.2.0 25 | kiwisolver==1.4.4 26 | lazy_loader==0.1 27 | lit==15.0.7 28 | MarkupSafe==2.1.2 29 | matplotlib==3.7.1 30 | matplotlib-inline==0.1.6 31 | medmnist==2.1.0 32 | mpmath==1.3.0 33 | nest-asyncio==1.5.6 34 | networkx==3.0 35 | numpy==1.24.2 36 | nvidia-cublas-cu11==11.10.3.66 37 | nvidia-cuda-cupti-cu11==11.7.101 38 | nvidia-cuda-nvrtc-cu11==11.7.99 39 | nvidia-cuda-runtime-cu11==11.7.99 40 | nvidia-cudnn-cu11==8.5.0.96 41 | nvidia-cufft-cu11==10.9.0.58 42 | nvidia-curand-cu11==10.2.10.91 43 | nvidia-cusolver-cu11==11.4.0.1 44 | nvidia-cusparse-cu11==11.7.4.91 45 | nvidia-nccl-cu11==2.14.3 46 | nvidia-nvtx-cu11==11.7.91 47 | packaging==22.0 48 | pandas==1.5.3 49 | parso==0.8.3 50 | pexpect==4.8.0 51 | pickleshare==0.7.5 52 | Pillow==9.0.1 53 | pip==23.0.1 54 | platformdirs==2.5.2 55 | prompt-toolkit==3.0.36 56 | psutil==5.9.0 57 | ptyprocess==0.7.0 58 | pure-eval==0.2.2 59 | Pygments==2.11.2 60 | pyparsing==3.0.9 61 | python-dateutil==2.8.2 62 | pytz==2022.7.1 63 | PyWavelets==1.4.1 64 | pyzmq==23.2.0 65 | requests==2.28.2 66 | scikit-image==0.20.0 67 | scikit-learn==1.2.2 68 | scipy==1.10.1 69 | setuptools==65.6.3 70 | six==1.16.0 71 | 
stack-data==0.2.0 72 | sympy==1.11.1 73 | termcolor==2.2.0 74 | threadpoolctl==3.1.0 75 | tifffile==2023.3.15 76 | timm==0.5.4 77 | torch==2.0.0 78 | torchvision==0.15.1 79 | tornado==6.2 80 | tqdm==4.65.0 81 | traitlets==5.7.1 82 | triton==2.0.0 83 | typing_extensions==4.5.0 84 | urllib3==1.26.15 85 | wcwidth==0.2.5 86 | wheel==0.38.4 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | vit_base_r50_s16_224_0.0001lr_bloodmnist_200rounds_Centralized/ 132 | vit_base_r50_s16_224_0.0001lr_isic2019_200rounds_Centralized/ 133 | vit_base_r50_s16_224_0.0001lr_isic2019_200rounds_Local/ 134 | vit_base_r50_s16_224_0.0001lr_HAM_200rounds_Local/ 135 | vit_base_r50_s16_224_0.0001lr_bloodmnist_200rounds_Local/ 136 | vit_base_r50_s16_224_0.0001lr_HAM_200rounds_Centralized/ 137 | vit_base_r50_s16_224_0.0001lr_HAM_6Clients_FalseDP_32Batch_SViBS/ 138 | vit_base_r50_s16_224_0.0001lr_bloodmnist_6Clients_FalseDP_32Batch_SViBS/ 139 | vit_base_r50_s16_224_0.0001lr_HAM_6Clients_(4.0, 1e-05)DP_32Batch_SViBS/ 140 | vit_base_r50_s16_224_0.0001lr_bloodmnist_6Clients_(4.0, 1e-05)DP_32Batch_SViBS/ 141 | vit_base_r50_s16_224_0.0001lr_isic2019_6Clients_FalseDP_32Batch_SViBS/ 142 | vit_base_r50_s16_224_0.0001lr_isic2019_6Clients_(4.0, 1e-05)DP_32Batch_SViBS/ 143 | vit_base_r50_s16_224_0.0001lr_HAM_6Clients_1to6Blocks_32Batch__(5.0, 1e-05)DP_FeSViBS/ 144 | vit_base_r50_s16_224_0.0001lr_HAM_6Clients_1to6Blocks_32Batch_FeSViBS/ 145 | vit_base_r50_s16_224_0.0001lr_isic2019_6Clients_1to6Blocks_32Batch__(5.0, 1e-05)DP_FeSViBS/ 146 | vit_base_r50_s16_224_0.0001lr_isic2019_6Clients_1to6Blocks_32Batch_FeSViBS/ 147 | 148 | -------------------------------------------------------------------------------- /centralized.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | import dataset 4 | import os 5 | import random 6 | import numpy as np 7 | 8 | from models import CentralizedFashion 9 | from torch import nn 10 | import argparse 11 | 12 | from dataset import skinCancer, bloodmnisit, isic2019 13 | 14 | 15 | def centralized(dataset_name, lr, batch_size, Epochs, input_size, num_workers, save_every_epochs, model_name, pretrained, opt_name, seed, base_dir, root_dir, csv_file_path): 16 | 17 | torch.manual_seed(seed) 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | 21 | 22 | if torch.cuda.is_available(): 23 | device = 'cuda' 24 | else: 25 | device = 'cpu' 26 | 27 | 28 | print('Creating Logging Directory!') 29 | save_dir = f'{model_name}_{lr}lr_{dataset_name}_{Epochs}rounds_Centralized' 30 | os.mkdir(save_dir) 31 | 32 | print('Getting the Dataset and Dataloader!') 33 | if dataset_name == 'HAM': 34 | num_classes = 7 35 | train_loader, test_loader,_,_ = skinCancer(input_size= input_size, batch_size = batch_size, base_dir= base_dir, num_workers=num_workers) 36 | num_channels = 3 37 | 38 | elif dataset_name == 'bloodmnist': 39 | num_classes = 8 40 | train_loader, test_loader,_,_ = bloodmnisit(input_size= input_size, batch_size = batch_size, download= True, num_workers=num_workers) 41 | num_channels = 3 42 | 43 | elif dataset_name == 'isic2019': 44 | num_classes = 8 45 | _, _, train_loader, _, _, test_loader = isic2019(input_size= input_size, batch_size = batch_size, root_dir=root_dir, csv_file_path=csv_file_path, num_workers=num_workers) 46
| num_channels = 3 47 | 48 | print('Getting the model from timm library!') 49 | model = timm.create_model( 50 | model_name= model_name, pretrained= pretrained, 51 | num_classes = num_classes, in_chans=num_channels 52 | ).to(device) 53 | 54 | 55 | criterion = torch.nn.CrossEntropyLoss() 56 | 57 | centralized_network = CentralizedFashion( 58 | device= device, network=model, criterion= criterion, 59 | base_dir=save_dir 60 | ) 61 | 62 | #Instantiate metrics and set optimizer 63 | centralized_network.init_logs() 64 | centralized_network.set_optimizer(name=opt_name, lr = lr) 65 | 66 | print(f'Train Centralized Fashion:\n model: {model_name}\n dataset: {dataset_name}\n LR: {lr}\n Number of Epochs: {Epochs}\n Loggings: {save_dir}\n') 67 | print('Start Training! \n') 68 | 69 | #Training and Evaluation Loop 70 | for r in range(Epochs): 71 | print(f"Round {r+1} / {Epochs}") 72 | centralized_network.train_round(train_loader) 73 | centralized_network.eval_round(test_loader) 74 | print('---------') 75 | if (r+1) % save_every_epochs == 0 and r != 0: 76 | centralized_network.save_pickles(save_dir) 77 | print('============================================') 78 | 79 | 80 | if __name__ == "__main__": 81 | 82 | parser = argparse.ArgumentParser(description='Run Centralized Experiments') 83 | 84 | parser.add_argument('--dataset_name', type=str, choices=['HAM', 'bloodmnist', 'isic2019'], help='Dataset Name') 85 | parser.add_argument('--input_size', type=int, default= 224, help='Input size --> (input_size, input_size), default : 224') 86 | parser.add_argument('--num_workers', type=int, default= 8, help='Number of workers for dataloaders, default : 8') 87 | parser.add_argument('--model_name', type=str, default= 'vit_base_r50_s16_224', help='Model name from timm library, default: vit_base_r50_s16_224') 88 | parser.add_argument('--pretrained', type=bool, default= False, help='Pretrained weights flag, default: False') 89 | parser.add_argument('--batch_size', type=int, default= 32, help='Batch size, default : 32') 90 | parser.add_argument('--Epochs', type=int, default= 200, help='Number of Epochs, default : 200') 91 | parser.add_argument('--opt_name', type=str, choices=['Adam'], default = 'Adam', help='Optimizer name, only ADAM optimizer is available') 92 | parser.add_argument('--lr', type=float, default= 1e-4, help='Learning rate, default : 1e-4') 93 | parser.add_argument('--save_every_epochs', type=int, default= 10, help='Save metrics every this number of epochs, default: 10') 94 | parser.add_argument('--seed', type=int, default= 105, help='Seed, default: 105') 95 | parser.add_argument('--base_dir', type=str, default= None, help='Path to the directory with the HAM10000 images and metadata (HAM dataset)') 96 | parser.add_argument('--root_dir', type=str, default= None, help='Path to the ISIC_2019_Training_Input_preprocessed directory (isic2019 dataset)') 97 | parser.add_argument('--csv_file_path', type=str, default=None, help='Path to the train_test_split CSV file (isic2019 dataset)') 98 | 99 | args = parser.parse_args() 100 | 101 | centralized( 102 | dataset_name = args.dataset_name, input_size= args.input_size, 103 | num_workers= args.num_workers, model_name= args.model_name, 104 | pretrained= args.pretrained, batch_size= args.batch_size, 105 | Epochs= args.Epochs, opt_name= args.opt_name, lr= args.lr, 106 | save_every_epochs= args.save_every_epochs, seed= args.seed, 107 | base_dir= args.base_dir, root_dir= args.root_dir, csv_file_path= args.csv_file_path 108 | ) -------------------------------------------------------------------------------- /local.py: -------------------------------------------------------------------------------- 1 | import os 2 | import timm 3 | import torch 4 | import numpy as np 5 | from torch
import nn 6 | import os 7 | import random 8 | import argparse 9 | 10 | from models import CentralizedFashion 11 | from dataset import skinCancer, bloodmnisit, isic2019, distribute_images 12 | 13 | 14 | def local(dataset_name, lr, batch_size, Epochs, input_size, num_workers, save_every_epochs, model_name, pretrained, opt_name, seed, base_dir, root_dir, csv_file_path, num_clients, local_arg): 15 | 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | random.seed(seed) 19 | 20 | if torch.cuda.is_available(): 21 | device = 'cuda' 22 | else: 23 | device = 'cpu' 24 | 25 | print('Load Dataset and DataLoader!') 26 | if dataset_name == 'HAM': 27 | num_classes = 7 28 | train_loader, test_loader, train_data, test_data = skinCancer(input_size= input_size, batch_size = batch_size, base_dir= base_dir, num_workers=num_workers) 29 | num_channels = 3 30 | 31 | elif dataset_name == 'bloodmnist': 32 | num_classes = 8 33 | train_loader, test_loader, train_data, test_data = bloodmnisit(input_size= input_size, batch_size = batch_size, download= True, num_workers=num_workers) 34 | num_channels = 3 35 | 36 | elif dataset_name == 'isic2019': 37 | num_classes = 8 38 | DATALOADERS, _, _, _, _, test_loader = isic2019(input_size= input_size, batch_size = batch_size, root_dir=root_dir, csv_file_path=csv_file_path, num_workers=num_workers) 39 | num_channels = 3 40 | 41 | 42 | 43 | print('Create Directory for metrics loggings!') 44 | save_dir = f'{model_name}_{lr}lr_{dataset_name}_{Epochs}rounds_Local' 45 | os.mkdir(save_dir) 46 | 47 | print(f'Train Local Fashion:\n Number of Clients :{num_clients}\n model: {model_name}\n dataset: {dataset_name}\n LR: {lr}\n Number of Epochs: {Epochs}\n Loggings: {save_dir}\n') 48 | 49 | if dataset_name in ['HAM', 'bloodmnist']: 50 | print(f'Distribute Dataset Among {num_clients} Clients') 51 | 52 | DATALOADERS, test_loader = distribute_images( 53 | dataset_name = dataset_name, train_data = train_data, num_clients= num_clients, 54 | test_data = test_data, batch_size = batch_size, num_workers= num_workers 55 | ) 56 | 57 | print('Loading Model form timm Library for All clients!') 58 | model = [timm.create_model( 59 | model_name= model_name, 60 | num_classes= num_classes, 61 | in_chans = num_channels, 62 | pretrained= pretrained, 63 | ).to(device) for i in range(num_clients)] 64 | 65 | criterion = nn.CrossEntropyLoss() 66 | 67 | local = [CentralizedFashion( 68 | device = device, 69 | network = model[i], criterion = criterion, 70 | base_dir = save_dir 71 | ) for i in range(num_clients)] 72 | 73 | 74 | for i in range(num_clients): 75 | local[i].set_optimizer(opt_name, lr = lr) 76 | local[i].init_logs() 77 | 78 | for r in range(Epochs): 79 | print(f"Round {r+1} / {Epochs}") 80 | for client_i in range(num_clients): 81 | print(f'Client {client_i+1} / {num_clients}') 82 | local[client_i].train_round(DATALOADERS[client_i]) 83 | local[client_i].eval_round(test_loader) 84 | print('---------') 85 | if (r+1) % save_every_epochs == 0 and r != 0: 86 | local[client_i].save_pickles(save_dir,local= local_arg, client_id=client_i+1) 87 | print('============================================') 88 | 89 | 90 | if __name__ == "__main__": 91 | 92 | parser = argparse.ArgumentParser(description='Run Centralized Experiments') 93 | 94 | parser.add_argument('--dataset_name', type=str, choices=['HAM', 'bloodmnist', 'isic2019'], help='Dataset Name') 95 | parser.add_argument('--num_clients', type=int, default= 6, help='Number of clients, default : 6') 96 | parser.add_argument('--local_arg', type=bool, default= True, 
help='Local Argument, default: True') 97 | parser.add_argument('--input_size', type=int, default= 224, help='Input size --> (input_size, input_size), default : 224') 98 | parser.add_argument('--num_workers', type=int, default= 8, help='Number of workers for dataloaders, default : 8') 99 | parser.add_argument('--model_name', type=str, default= 'vit_base_r50_s16_224', help='Model name from timm library, default: vit_base_r50_s16_224') 100 | parser.add_argument('--pretrained', type=bool, default= False, help='Pretrained weights flag, default: False') 101 | parser.add_argument('--batch_size', type=int, default= 32, help='Batch size, default : 32') 102 | parser.add_argument('--Epochs', type=int, default= 200, help='Number of Epochs, default : 200') 103 | parser.add_argument('--opt_name', type=str, choices=['Adam'], default = 'Adam', help='Optimizer name, only ADAM optimizer is available') 104 | parser.add_argument('--lr', type=float, default= 1e-4, help='Learning rate, default : 1e-4') 105 | parser.add_argument('--save_every_epochs', type=int, default= 10, help='Save metrics every this number of epochs, default: 10') 106 | parser.add_argument('--seed', type=int, default= 105, help='Seed, default: 105') 107 | parser.add_argument('--base_dir', type=str, default= None, help='') 108 | parser.add_argument('--root_dir', type=str, default= None, help='') 109 | parser.add_argument('--csv_file_path', type=str, default=None, help='') 110 | 111 | args = parser.parse_args() 112 | 113 | local( 114 | dataset_name = args.dataset_name, num_clients= args.num_clients, 115 | input_size= args.input_size, local_arg= args.local_arg, 116 | num_workers= args.num_workers, model_name= args.model_name, 117 | pretrained= args.pretrained, batch_size= args.batch_size, 118 | Epochs= args.Epochs, opt_name= args.opt_name, lr= args.lr, 119 | save_every_epochs= args.save_every_epochs, seed= args.seed, 120 | base_dir= args.base_dir, root_dir= args.root_dir, csv_file_path= args.csv_file_path 121 | ) -------------------------------------------------------------------------------- /SLViT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch import nn 5 | import random 6 | from models import SLViT, SplitNetwork 7 | from dataset import skinCancer, bloodmnisit, isic2019 8 | import argparse 9 | from utils import weight_dec_global 10 | 11 | def slvit(dataset_name, lr, batch_size, Epochs, input_size, num_workers, save_every_epochs, model_name, pretrained, opt_name, seed , base_dir, root_dir, csv_file_path, num_clients, DP, epsilon, delta): 12 | 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | random.seed(seed) 16 | 17 | if torch.cuda.is_available(): 18 | device = 'cuda' 19 | else: 20 | device = 'cpu' 21 | 22 | mean = 0 23 | std = 1 24 | if DP: 25 | std = np.sqrt(2 * np.math.log(1.25/delta)) / epsilon 26 | 27 | save_dir = f'{model_name}_{lr}lr_{dataset_name}_{num_clients}Clients_{DP}DP_{batch_size}Batch_SLViT' 28 | 29 | if DP: 30 | save_dir = f'{model_name}_{lr}lr_{dataset_name}_{num_clients}Clients_({epsilon}, {delta})DP_{batch_size}Batch_SLViT' 31 | 32 | os.mkdir(save_dir) 33 | 34 | print('Getting the Dataset and Dataloader!') 35 | if dataset_name == 'HAM': 36 | num_classes = 7 37 | _, _, traindataset, testdataset = skinCancer(input_size= input_size, batch_size = batch_size, base_dir= base_dir, num_workers=num_workers) 38 | num_channels = 3 39 | 40 | elif dataset_name == 'bloodmnist': 41 | num_classes = 8 42 | _, _, 
traindataset, testdataset = bloodmnisit(input_size= input_size, batch_size = batch_size, download= True, num_workers=num_workers) 43 | num_channels = 3 44 | 45 | elif dataset_name == 'isic2019': 46 | num_classes = 8 47 | DATALOADERS, _, _, _, _, test_loader = isic2019(input_size= input_size, batch_size = batch_size, root_dir=root_dir, csv_file_path=csv_file_path, num_workers=num_workers) 48 | num_channels = 3 49 | 50 | slvit = SLViT( 51 | ViT_name= model_name, num_classes=num_classes, 52 | num_clients=num_clients, in_channels=num_channels, 53 | ViT_pretrained = pretrained, 54 | diff_privacy=DP, mean=mean, std = std 55 | ).to(device) 56 | 57 | criterion = nn.CrossEntropyLoss() 58 | 59 | Split = SplitNetwork( 60 | num_clients=num_clients, device = device, 61 | network = slvit, criterion = criterion, base_dir=save_dir, 62 | ) 63 | 64 | print('Distribute Data') 65 | if dataset_name != 'isic2019': 66 | Split.distribute_images(dataset_name=dataset_name, train_data=traindataset, test_data=testdataset , batch_size = batch_size) 67 | else: 68 | Split.CLIENTS_DATALOADERS = DATALOADERS 69 | Split.testloader = test_loader 70 | 71 | Split.set_optimizer(opt_name, lr = lr) 72 | Split.init_logs() 73 | 74 | for r in range(Epochs): 75 | print(f"Round {r+1} / {Epochs}") 76 | agg_weights = None 77 | for client_i in range(num_clients): 78 | weight_dict = Split.train_round(client_i) 79 | if client_i ==0: 80 | agg_weights = weight_dict 81 | else: 82 | agg_weights['blocks'] += weight_dict['blocks'] 83 | agg_weights['cls'] += weight_dict['cls'] 84 | agg_weights['pos_embed'] += weight_dict['pos_embed'] 85 | 86 | agg_weights['blocks'] /= num_clients 87 | agg_weights['cls'] /= num_clients 88 | agg_weights['pos_embed'] /= num_clients 89 | 90 | Split.network.vit.blocks = weight_dec_global( 91 | Split.network.vit.blocks, 92 | agg_weights['blocks'].to(device) 93 | ) 94 | 95 | Split.network.vit.cls_token.data = agg_weights['cls'].to(device) + 0.0 96 | Split.network.vit.pos_embed.data = agg_weights['pos_embed'].to(device) + 0.0 97 | 98 | for client_i in range(num_clients): 99 | Split.eval_round(client_i) 100 | 101 | print('---------') 102 | 103 | if (r+1) % save_every_epochs == 0 and r != 0: 104 | Split.save_pickles(save_dir) 105 | 106 | print('============================================') 107 | 108 | if __name__ == "__main__": 109 | 110 | parser = argparse.ArgumentParser(description='Run Centralized Experiments') 111 | 112 | parser.add_argument('--dataset_name', type=str, choices=['HAM', 'bloodmnist', 'isic2019'], help='Dataset Name') 113 | parser.add_argument('--input_size', type=int, default= 224, help='Input size --> (input_size, input_size), default : 224') 114 | parser.add_argument('--num_workers', type=int, default= 8, help='Number of workers for dataloaders, default : 8') 115 | parser.add_argument('--num_clients', type=int, default= 6, help='Number of Clients, default : 6') 116 | parser.add_argument('--model_name', type=str, default= 'vit_base_r50_s16_224', help='Model name from timm library, default: vit_base_r50_s16_224') 117 | parser.add_argument('--pretrained', type=bool, default= False, help='Pretrained weights flag, default: False') 118 | parser.add_argument('--batch_size', type=int, default= 32, help='Batch size, default : 32') 119 | parser.add_argument('--Epochs', type=int, default= 200, help='Number of Epochs, default : 200') 120 | parser.add_argument('--opt_name', type=str, choices=['Adam'], default = 'Adam', help='Optimizer name, only ADAM optimizer is available') 121 | parser.add_argument('--lr', 
type=float, default= 1e-4, help='Learning rate, default : 1e-4') 122 | parser.add_argument('--save_every_epochs', type=int, default= 10, help='Save metrics every this number of epochs, default: 10') 123 | parser.add_argument('--seed', type=int, default= 105, help='Seed, default: 105') 124 | parser.add_argument('--base_dir', type=str, default= None, help='') 125 | parser.add_argument('--root_dir', type=str, default= None, help='') 126 | parser.add_argument('--csv_file_path', type=str, default=None, help='') 127 | parser.add_argument('--DP', type=bool, default= False, help='Differential Privacy , default: False') 128 | parser.add_argument('--epsilon', type=float, default= 0, help='Epsilon Value for differential privacy') 129 | parser.add_argument('--delta', type=float, default= 0.00001, help='Delta Value for differential privacy') 130 | 131 | 132 | args = parser.parse_args() 133 | 134 | slvit( 135 | dataset_name = args.dataset_name, input_size= args.input_size, 136 | num_workers= args.num_workers, model_name= args.model_name, 137 | pretrained= args.pretrained, batch_size= args.batch_size, 138 | Epochs= args.Epochs, opt_name= args.opt_name, lr= args.lr, 139 | save_every_epochs= args.save_every_epochs, seed= args.seed, 140 | base_dir= args.base_dir, root_dir= args.root_dir, csv_file_path= args.csv_file_path, num_clients = args.num_clients, 141 | DP = args.DP, epsilon = args.epsilon, delta = args.delta 142 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FeSViBS 2 | Source code for MICCAI 2023 paper entitled: 'FeSViBS: Federated Split Learning of Vision Transformer with Block Sampling' 3 | 4 | 5 |
6 | 7 | ![Method](Figures/method.PNG) 8 | 9 | ## Abstract 10 | Data scarcity is a significant obstacle hindering the learning of powerful machine learning models in critical healthcare applications. Data-sharing mechanisms among multiple entities (e.g., hospitals) can accelerate model training and yield more accurate predictions. Recently, approaches such as Federated Learning (FL) and Split Learning (SL) have facilitated collaboration without the need to exchange private data. In this work, we propose a framework for medical imaging classification tasks called Federated Split learning of Vision transformer with Block Sampling (FeSViBS). The FeSViBS framework builds upon the existing federated split vision transformer and introduces a *block sampling* module, which leverages intermediate features extracted by the Vision Transformer (ViT) at the server. This is achieved by sampling features (patch tokens) from an intermediate transformer block and distilling their information content into a pseudo class token before passing them back to the client. These pseudo class tokens serve as an effective feature augmentation strategy and enhance the generalizability of the learned model. We demonstrate the utility of our proposed method compared to other SL and FL approaches on three publicly available medical imaging datasets: HAM10000, BloodMNIST, and Fed-ISIC2019, under both IID and non-IID settings. 11 | 12 | ## Install Dependencies 13 | Install all dependencies by running the following command: 14 | 15 | ``` 16 | pip install -r requirements.txt 17 | 18 | ``` 19 | 20 | ## Datasets 21 | 22 | We conduct all experiments on **three** datasets: 23 | 24 | 1. HAM10000 [3] -- Can be downloaded from [here](https://www.kaggle.com/datasets/kmader/skin-cancer-mnist-ham10000?select=HAM10000_images_part_2) 25 | 2. Blood cells (BloodMNIST) -- MedMNIST library [1] 26 | 3. Federated version of the ISIC2019 dataset -- FLamby library [2]
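The loaders used by all training scripts can also be built directly from `dataset.py`; a minimal sketch (the BloodMNIST call downloads the data through the `medmnist` package on first use):

```
from dataset import bloodmnisit, skinCancer

# BloodMNIST: 8 classes, fetched automatically by medmnist
train_loader, test_loader, train_set, test_set = bloodmnisit(
    input_size=224, batch_size=32, num_workers=8, download=True
)

# HAM10000: 7 classes; base_dir must contain the *.jpg images
# and HAM10000_metadata.csv
# train_loader, test_loader, train_set, test_set = skinCancer(
#     input_size=224, batch_size=32, base_dir='./data', num_workers=8
# )

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([32, 3, 224, 224])
```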
27 | 28 | For the Federated ISIC2019 dataset, the paths to the __ISIC_2019_Training_Input_preprocessed__ directory and the __train_test_split__ CSV file are required to run the different methods on this dataset. 29 | 30 | ## Running Centralized Training/Testing 31 | To run **Centralized Training**, run the following command: 32 | 33 | ``` 34 | python centralized.py --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --Epochs [Number of Epochs] 35 | 36 | ``` 37 | 38 | 39 | ## Running Local Training/Testing for Each Client 40 | To run **Local Training/Testing**, run the following command: 41 | 42 | ``` 43 | python local.py --local_arg True --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs] 44 | 45 | ``` 46 | 47 | ## Running Vanilla Split Learning with Vision Transformers (SLViT) 48 | To run **SLViT without** Differential Privacy (DP), run the following command: 49 | 50 | ``` 51 | python SLViT.py --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs] 52 | 53 | ``` 54 | 55 | To run **SLViT with** Differential Privacy (DP), run the following command: 56 | 57 | ``` 58 | python SLViT.py --DP True --epsilon [epsilon value] --delta [delta value] --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs] 59 | 60 | ``` 61 |
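For reference, `--epsilon` and `--delta` set the standard deviation of the Gaussian noise added to the features exchanged with the server (see the Gaussian-mechanism computation in `SLViT.py`):

```
std = sqrt(2 * ln(1.25 / delta)) / epsilon
# e.g. epsilon = 4.0, delta = 1e-05  ->  std = sqrt(2 * ln(125000)) / 4.0 ≈ 1.21
```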
62 | ## Running Split Vision Transformer with Block Sampling (SViBS) 63 | To run **SViBS**, run the following command: 64 | 65 | ``` 66 | python FeSViBS.py --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs] --initial_block 1 --final_block 6 67 | 68 | ``` 69 | 70 | ## Running Federated Split Vision Transformer with Block Sampling (FeSViBS) 71 | To run **FeSViBS without** Differential Privacy (DP), run the following command: 72 | 73 | ``` 74 | python FeSViBS.py --fesvibs_arg True --local_round [number of local rounds before federation] --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs] --initial_block 1 --final_block 6 75 | 76 | ``` 77 | 78 | To run **FeSViBS with** Differential Privacy (DP), run the following command: 79 | 80 | ``` 81 | python FeSViBS.py --fesvibs_arg True --DP True --epsilon [epsilon value] --delta [delta value] --local_round [number of local rounds before federation] --dataset_name [choose the dataset name] --opt_name [default is Adam] --lr [learning rate] --seed [seed number] --base_dir [path data folder for HAM] --save_every_epochs [Save pickle files] --root_dir [Path to ISIC_2019_Training_Input_preprocessed for ISIC2019] --csv_file_path [Path to train_test_split csv for ISIC2019] --num_clients [Number of clients] --Epochs [Number of Epochs] --initial_block 1 --final_block 6 82 | 83 | ```
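For example, a FeSViBS run on BloodMNIST using the script defaults (6 clients, blocks 1 to 6, federation every 2 local rounds) would be:

```
python FeSViBS.py --fesvibs_arg True --dataset_name bloodmnist --num_clients 6 --Epochs 200 --lr 1e-4 --batch_size 32 --local_round 2 --initial_block 1 --final_block 6
```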
84 | ## Citation 85 | ``` 86 | @misc{almalik2023fesvibs, 87 | title={FeSViBS: Federated Split Learning of Vision Transformer with Block Sampling}, 88 | author={Faris Almalik and Naif Alkhunaizi and Ibrahim Almakky and Karthik Nandakumar}, 89 | year={2023}, 90 | eprint={2306.14638}, 91 | archivePrefix={arXiv}, 92 | primaryClass={cs.CV} 93 | } 94 | 95 | ``` 96 | ## References 97 | 98 | [1] Yang, J., Shi, R., Ni, B.: MedMNIST classification decathlon: A lightweight AutoML benchmark for medical image analysis. In: IEEE 18th International Symposium on Biomedical Imaging (ISBI). pp. 191–195 (2021) 99 | 100 | [2] du Terrail, J.O., Ayed, S.S., Cyffers, E., Grimberg, F., He, C., Loeb, R., Mangold, P., Marchand, T., Marfoq, O., Mushtaq, E., Muzellec, B., Philippenko, C., Silva, S., Teleńczuk, M., Albarqouni, S., Avestimehr, S., Bellet, A., Dieuleveut, A., Jaggi, M., Karimireddy, S.P., Lorenzi, M., Neglia, G., Tommasi, M., Andreux, M.: FLamby: Datasets and benchmarks for cross-silo federated learning in realistic healthcare settings. In: Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (2022) 101 | 102 | [3] Tschandl, P., Rosendahl, C., Kittler, H.: The HAM10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. Scientific Data 5(11), 180161 (Aug 2018). 103 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | 5 | import numpy as np 6 | 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | import torch.utils.data as data 11 | import torch 12 | 13 | import medmnist 14 | from medmnist import INFO 15 | 16 | from utils import get_data, CustomDataset, ISIC2019, blood_noniid, distribute_data 17 | 18 | import random 19 | 20 | seed = 105 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | random.seed(seed) 24 | 25 | 26 | def distribute_images(dataset_name, train_data, num_clients, test_data, batch_size, num_workers = 8): 27 | """ 28 | Split the training set among `num_clients` clients and build their dataloaders. 29 | train_data: train dataset 30 | test_data: test dataset (served from a single shared test loader) 31 | batch_size: batch size 32 | return: list of per-client train dataloaders, shared test dataloader 33 | """ 34 | if dataset_name == 'HAM': 35 | CLIENTS_DATALOADERS = distribute_data(num_clients, train_data, batch_size) 36 | testloader = torch.utils.data.DataLoader(test_data,batch_size=batch_size, num_workers= num_workers) 37 | 38 | elif dataset_name == 'bloodmnist': 39 | _, testloader, train_dataset, _ = bloodmnisit(batch_size= batch_size) 40 | _, CLIENTS_DATALOADERS, _ = blood_noniid(num_clients, train_dataset, batch_size =batch_size) 41 | 42 | return CLIENTS_DATALOADERS, testloader
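# Illustrative usage (not part of the original file): building the federated
# loaders consumed by local.py and the split-learning scripts, assuming the HAM
# datasets were created with skinCancer() defined below:
#
#   _, _, train_set, test_set = skinCancer(base_dir='./data')
#   client_loaders, test_loader = distribute_images(
#       dataset_name='HAM', train_data=train_set, num_clients=6,
#       test_data=test_set, batch_size=32,
#   )
#   # client_loaders[i] is the private training loader of client i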
43 | 44 | def bloodmnisit(input_size =224, batch_size = 32, num_workers= 8, download = True): 45 | """ 46 | Get train/test loaders and sets for bloodmnist from the medmnist library. 47 | 48 | Input: 49 | input_size (int): width of the input image, which is similar to the height 50 | batch_size (int) 51 | num_workers (int): Number of workers used to create the loaders 52 | download (bool): Whether to download the dataset or not 53 | 54 | return: 55 | train_loader, test_loader, train_dataset, test_dataset 56 | """ 57 | 58 | data_flag = 'bloodmnist' 59 | info = INFO[data_flag] 60 | DataClass = getattr(medmnist, info['python_class']) 61 | 62 | data_transform_train = transforms.Compose([ 63 | transforms.RandomVerticalFlip(), 64 | transforms.RandomHorizontalFlip(), 65 | transforms.RandomAffine(degrees= 10, translate=(0.1,0.1)), 66 | transforms.RandomResizedCrop(input_size, (0.75,1), (0.9,1)), 67 | transforms.ToTensor(), 68 | ]) 69 | 70 | data_transform_test = transforms.Compose([ 71 | transforms.Resize(input_size), 72 | transforms.ToTensor(), 73 | ]) 74 | 75 | train_dataset = DataClass(split='train', transform=data_transform_train, download=download) 76 | test_dataset = DataClass(split='test', transform=data_transform_test, download=download) 77 | 78 | train_loader = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 79 | test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*batch_size, shuffle=False, num_workers=num_workers) 80 | 81 | return train_loader, test_loader, train_dataset, test_dataset 82 | 83 | def skinCancer(input_size = 224, batch_size = 32, base_dir = './data', num_workers = 8): 84 | """ 85 | Get the SkinCancer datasets and dataloaders. 86 | 87 | Input: 88 | input_size (int): width of the input image 89 | batch_size (int) 90 | base_dir (str): Path to the directory which includes the SkinCancer images 91 | num_workers (int): for dataloaders 92 | 93 | return: 94 | train_loader, testing_loader, train_dataset, test_dataset 95 | 96 | """ 97 | all_image_path = glob.glob(os.path.join(base_dir, '*.jpg')) 98 | imageid_path_dict = {os.path.splitext(os.path.basename(x))[0]: x for x in all_image_path} 99 | df_train, df_val = get_data(base_dir, imageid_path_dict) 100 | 101 | normMean = [0.76303697, 0.54564005, 0.57004493] 102 | normStd = [0.14092775, 0.15261292, 0.16997] 103 | 104 | train_transform = transforms.Compose([transforms.RandomResizedCrop((input_size,input_size), scale=(0.9,1.1)), 105 | transforms.ColorJitter(brightness=0.1, contrast=0.1, hue=0.1), 106 | transforms.RandomRotation(10), 107 | transforms.RandomHorizontalFlip(), 108 | transforms.ToTensor(), 109 | transforms.Normalize(normMean, normStd)]) 110 | 111 | # define the transformation of the val images 112 | val_transform = transforms.Compose([transforms.Resize((input_size,input_size)), 113 | transforms.ToTensor(), 114 | transforms.Normalize(normMean, normStd)]) 115 | 116 | training_set = CustomDataset(df_train.drop_duplicates('image_id'), transform=train_transform) 117 | train_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=num_workers) 118 | 119 | # Same for the validation set: 120 | validation_set = CustomDataset(df_val.drop_duplicates('image_id'), transform=val_transform) 121 | val_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False, num_workers=num_workers) 122 | 123 | return train_loader, val_loader, training_set, validation_set 124 | 125 | def isic2019(input_size = 224, root_dir = './ISIC_2019_Training_Input_preprocessed', csv_file_path = './train_test_split', batch_size = 32, num_workers=8): 126 | 127 | """ 128 | Function that returns train and test dataloaders and datasets for centralized training and federated settings. 129 | 130 | Input: 131 | root_dir (str): path to the directory that has the preprocessed images from the FLamby library 132 | csv_file_path (str): Path to the csv file that has train_test_split as per the FLamby Library 133 | 134 | Return: 135 | Clients train dataloaders (federated), Clients test loaders, Train dataloader (centralized), 136 | Clients train datasets (Federated), Clients test datasets (Federated), Test dataloader (All testing images in one loader) 137 | """ 138 | clients_datasets_train = [ 139 | ISIC2019( 140 | csv_file_path= csv_file_path, 141 | root_dir=root_dir,client_id=i,train=True, centralized=False, input_size= input_size) for i in range(6) 142 | ] 143 | 144 | test_datasets = [ 145 | ISIC2019( 146 | csv_file_path= csv_file_path, 147 | root_dir=root_dir, client_id=i, train=False, centralized=False, input_size= input_size) for i in range(6) 148 | 149 | ] 150 | 151 | centralized_dataset_train = ISIC2019( 152 | csv_file_path= csv_file_path, 153 | root_dir=root_dir, client_id=None ,train=True, centralized=True, input_size= input_size 154 | ) 155 | 156 | clients_dataloader_train = [ 157 | DataLoader( 158 | dataset=clients_datasets_train[i],batch_size= batch_size, shuffle=True, num_workers=num_workers 159 | ) for i in range(6) 160 | ] 161 | 162 | test_dataloaders = [ 163 | DataLoader(dataset=test_datasets[i],batch_size= batch_size, shuffle=False, num_workers=num_workers) 164 | for i in range(6) 165 | ] 166 | 167 | test_centralized_dataset = ISIC2019( 168 | csv_file_path= csv_file_path, 169 | root_dir=root_dir, client_id=None , train=False, centralized=True, input_size= input_size 170 | ) 171 | 172 | test_dataloader_centralized = DataLoader(dataset=test_centralized_dataset,batch_size= batch_size, shuffle=False, num_workers=num_workers) 173 | 174 | 175 | centralized_dataloader_train = DataLoader(dataset=centralized_dataset_train,batch_size= batch_size, shuffle=True, num_workers=num_workers) 176 | 177 | return clients_dataloader_train, test_dataloaders, centralized_dataloader_train, clients_datasets_train, test_datasets, test_dataloader_centralized -------------------------------------------------------------------------------- /FeSViBS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import models 4 | import random 5 | from dataset import skinCancer, bloodmnisit, isic2019 6 | from utils import weight_dec_global, weight_vec 7 | import argparse 8 | import torch 9 | from torch import nn 10 | 11 | 12 | 13 | 14 | def fesvibs( 15
dataset_name, lr, batch_size, Epochs, input_size, num_workers, save_every_epochs, 16 | model_name, pretrained, opt_name, seed, base_dir, root_dir, csv_file_path, num_clients, DP, 17 | epsilon, delta, resnet_dropout, initial_block, final_block, fesvibs_arg, local_round 18 | ): 19 | 20 | torch.manual_seed(seed) 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | 24 | if fesvibs_arg: 25 | method_flag = 'FeSViBS' 26 | else: 27 | method_flag = 'SViBS' 28 | 29 | if torch.cuda.is_available(): 30 | device = 'cuda' 31 | else: 32 | device = 'cpu' 33 | 34 | if DP: 35 | std = np.sqrt(2 * np.math.log(1.25/delta)) / epsilon 36 | mean=0 37 | dir_name = f"{model_name}_{lr}lr_{dataset_name}_{num_clients}Clients_{initial_block}to{final_block}Blocks_{batch_size}Batch__{epsilon,delta}DP_{method_flag}" 38 | else: 39 | mean = 0 40 | std = 0 41 | dir_name = f"{model_name}_{lr}lr_{dataset_name}_{num_clients}Clients_{initial_block}to{final_block}Blocks_{batch_size}Batch_{method_flag}" 42 | 43 | save_dir = f'{dir_name}' 44 | os.mkdir(save_dir) 45 | 46 | print(f"Logging to: {dir_name}") 47 | 48 | print('Getting the Dataset and Dataloader!') 49 | if dataset_name == 'HAM': 50 | num_classes = 7 51 | _, _, traindataset, testdataset = skinCancer(input_size= input_size, batch_size = batch_size, base_dir= base_dir, num_workers=num_workers) 52 | num_channels = 3 53 | 54 | elif dataset_name == 'bloodmnist': 55 | num_classes = 8 56 | _, _, traindataset, testdataset = bloodmnisit(input_size= input_size, batch_size = batch_size, download= True, num_workers=num_workers) 57 | num_channels = 3 58 | 59 | elif dataset_name == 'isic2019': 60 | num_classes = 8 61 | DATALOADERS, _, _, _, _, test_loader = isic2019(input_size= input_size, batch_size = batch_size, root_dir=root_dir, csv_file_path=csv_file_path, num_workers=num_workers) 62 | num_channels = 3 63 | 64 | criterion = nn.CrossEntropyLoss() 65 | 66 | fesvibs_network = models.FeSVBiS( 67 | ViT_name= model_name, num_classes= num_classes, 68 | num_clients = num_clients, in_channels = num_channels, 69 | ViT_pretrained= pretrained, 70 | initial_block= initial_block, final_block= final_block, 71 | resnet_dropout= resnet_dropout, DP=DP, mean= mean, std= std 72 | ).to(device) 73 | 74 | Split = models.SplitFeSViBS( 75 | num_clients=num_clients, device = device, network = fesvibs_network, 76 | criterion = criterion, base_dir=save_dir, 77 | initial_block= initial_block, final_block= final_block, 78 | ) 79 | 80 | 81 | if dataset_name != 'isic2019': 82 | print('Distribute Images Among Clients') 83 | Split.distribute_images(dataset_name=dataset_name, train_data= traindataset,test_data= testdataset ,batch_size = batch_size) 84 | else: 85 | Split.CLIENTS_DATALOADERS = DATALOADERS 86 | Split.testloader = test_loader 87 | 88 | Split.set_optimizer(opt_name, lr = lr) 89 | Split.init_logs() 90 | 91 | print('Start Training! 
\n') 92 | 93 | for r in range(Epochs): 94 | print(f"Round {r+1} / {Epochs}") 95 | agg_weights = None 96 | for client_i in range(num_clients): 97 | weight_dict = Split.train_round(client_i) 98 | if client_i == 0: 99 | agg_weights = weight_dict 100 | else: 101 | agg_weights['blocks'] += weight_dict['blocks'] 102 | agg_weights['cls'] += weight_dict['cls'] 103 | agg_weights['pos_embed'] += weight_dict['pos_embed'] 104 | 105 | agg_weights['blocks'] /= num_clients 106 | agg_weights['cls'] /= num_clients 107 | agg_weights['pos_embed'] /= num_clients 108 | 109 | 110 | Split.network.vit.blocks = weight_dec_global( 111 | Split.network.vit.blocks, 112 | agg_weights['blocks'].to(device) 113 | ) 114 | 115 | Split.network.vit.cls_token.data = agg_weights['cls'].to(device) + 0.0 116 | Split.network.vit.pos_embed.data = agg_weights['pos_embed'].to(device) + 0.0 117 | 118 | if fesvibs_arg and ((r+1) % local_round == 0 and r!= 0): 119 | print('========================== \t \t Federation \t \t ==========================') 120 | tails_weights = [] 121 | head_weights = [] 122 | for head, tail in zip(Split.network.resnet50_clients, Split.network.mlp_clients_tail): 123 | head_weights.append(weight_vec(head).detach().cpu()) 124 | tails_weights.append(weight_vec(tail).detach().cpu()) 125 | 126 | mean_avg_tail = torch.mean(torch.stack(tails_weights), axis = 0) 127 | mean_avg_head = torch.mean(torch.stack(head_weights), axis = 0) 128 | 129 | for i in range(num_clients): 130 | Split.network.mlp_clients_tail[i] = weight_dec_global(Split.network.mlp_clients_tail[i], 131 | mean_avg_tail.to(device)) 132 | Split.network.resnet50_clients[i] = weight_dec_global(Split.network.resnet50_clients[i], 133 | mean_avg_head.to(device)) 134 | 135 | for client_i in range(num_clients): 136 | Split.eval_round(client_i) 137 | 138 | print('---------') 139 | 140 | if (r+1) % save_every_epochs == 0 and r != 0: 141 | Split.save_pickles(save_dir) 142 | print('============================================') 143 | 144 | if __name__ == "__main__": 145 | 146 | parser = argparse.ArgumentParser(description='Run Centralized Experiments') 147 | parser.add_argument('--dataset_name', type=str, choices=['HAM', 'bloodmnist', 'isic2019'], help='Dataset Name') 148 | parser.add_argument('--input_size', type=int, default= 224, help='Input size --> (input_size, input_size), default : 224') 149 | parser.add_argument('--local_round', type=int, default= 2, help='Local round before federation in FeSViBS, default : 2') 150 | parser.add_argument('--num_workers', type=int, default= 8, help='Number of workers for dataloaders, default : 8') 151 | parser.add_argument('--initial_block', type=int, default= 1, help='Initial Block, default : 1') 152 | parser.add_argument('--final_block', type=int, default= 6, help='Final Block, default : 6') 153 | parser.add_argument('--num_clients', type=int, default= 6, help='Number of Clients, default : 6') 154 | parser.add_argument('--model_name', type=str, default= 'vit_base_r50_s16_224', help='Model name from timm library, default: vit_base_r50_s16_224') 155 | parser.add_argument('--pretrained', type=bool, default= False, help='Pretrained weights flag, default: False') 156 | parser.add_argument('--fesvibs_arg', type=bool, default= False, help='Flag to indicate whether SViBS or FeSViBS, default: False') 157 | parser.add_argument('--batch_size', type=int, default= 32, help='Batch size, default : 32') 158 | parser.add_argument('--Epochs', type=int, default= 200, help='Number of Epochs, default : 200') 159 | 
parser.add_argument('--opt_name', type=str, choices=['Adam'], default = 'Adam', help='Optimizer name, only ADAM optimizer is available') 160 | parser.add_argument('--lr', type=float, default= 1e-4, help='Learning rate, default : 1e-4') 161 | parser.add_argument('--save_every_epochs', type=int, default= 10, help='Save metrics every this number of epochs, default: 10') 162 | parser.add_argument('--seed', type=int, default= 105, help='Seed, default: 105') 163 | parser.add_argument('--base_dir', type=str, default= None, help='') 164 | parser.add_argument('--root_dir', type=str, default= None, help='') 165 | parser.add_argument('--csv_file_path', type=str, default=None, help='') 166 | parser.add_argument('--DP', type=bool, default= False, help='Differential Privacy , default: False') 167 | parser.add_argument('--epsilon', type=float, default= 0, help='Epsilon Value for differential privacy') 168 | parser.add_argument('--delta', type=float, default= 0.00001, help='Delta Value for differential privacy') 169 | parser.add_argument('--resnet_dropout', type=float, default= 0.5, help='ResNet Dropout, Default: 0.5') 170 | args = parser.parse_args() 171 | 172 | fesvibs( 173 | dataset_name = args.dataset_name, input_size= args.input_size, 174 | num_workers= args.num_workers, model_name= args.model_name, 175 | pretrained= args.pretrained, batch_size= args.batch_size, 176 | Epochs= args.Epochs, opt_name= args.opt_name, lr= args.lr, 177 | save_every_epochs= args.save_every_epochs, seed= args.seed, 178 | base_dir= args.base_dir, root_dir= args.root_dir, csv_file_path= args.csv_file_path, num_clients = args.num_clients, 179 | DP = args.DP, epsilon = args.epsilon, delta = args.delta, initial_block= args.initial_block, final_block=args.final_block, 180 | resnet_dropout = args.resnet_dropout, fesvibs_arg = args.fesvibs_arg, local_round = args.local_round 181 | ) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from PIL import Image 7 | 8 | from sklearn.model_selection import train_test_split 9 | from torch.utils.data import Dataset, DataLoader 10 | from torchvision import transforms 11 | 12 | 13 | def weight_vec(network): 14 | A = [] 15 | for w in network.parameters(): 16 | A.append(torch.flatten(w)) 17 | return torch.cat(A) 18 | 19 | 20 | def weight_dec_global(pyModel, weight_vec): 21 | """ 22 | Reshape the weight back to its original shape in pytorch and then 23 | plug it to the model 24 | """ 25 | c = 0 26 | for w in pyModel.parameters(): 27 | m = w.numel() 28 | D = weight_vec[c:m+c].reshape(w.data.shape) 29 | c+=m 30 | if w.data is None: 31 | w.data = D+0 32 | else: 33 | with torch.no_grad(): 34 | w.set_( D+0 ) 35 | return pyModel 36 | 37 | 38 | def distribute_data(numOfClients, train_dataset, batch_size): 39 | """ 40 | numOfClients: int 41 | train_dataset: train_dataset (torchvision.datasets class) 42 | return distributed dataloaders for each client 43 | """ 44 | # distribution list to fill the number of samples in each entry for each client 45 | distribution = [] 46 | # rounding the number to get the number of dataset each client will get 47 | p = round(1/numOfClients * len(train_dataset)) 48 | 49 | # the remainder data that won't be able to split if it's not an even number 50 | remainder_data = len(train_dataset) - numOfClients * p 51 | # if the remainder data is 0 ---> all clients will get the same 
number of samples 52 | if remainder_data == 0: 53 | distribution = [p for i in range(numOfClients)] 54 | else: 55 | distribution = [p for i in range(numOfClients-1)] 56 | distribution.append(p+remainder_data) 57 | 58 | # splitting the data into different dataloaders 59 | data_split = torch.utils.data.random_split(train_dataset, distribution) 60 | # CLIENTS DATALOADERS 61 | CLIENTS_DATALOADERS = [torch.utils.data.DataLoader(data_split[i], batch_size=batch_size,shuffle=True, num_workers=32) for i in range(numOfClients)] 62 | 63 | print(f"Length of the training dataset: {len(train_dataset)} samples") 64 | return CLIENTS_DATALOADERS 65 |
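# Illustrative sketch (not part of the original file): the flatten / average /
# unflatten pattern with which the training scripts federate client models,
# using weight_vec() and weight_dec_global() defined above:
#
#   flat = [weight_vec(m).detach().cpu() for m in client_models]
#   mean_w = torch.mean(torch.stack(flat), axis=0)
#   for i in range(len(client_models)):
#       client_models[i] = weight_dec_global(client_models[i], mean_w.to(device))
#
# `client_models` and `device` are placeholders for the caller's state.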
66 | def get_data(base_dir, imageid_path_dict): 67 | 68 | """ 69 | Preprocessing for the SkinCancer dataset. 70 | Input: 71 | base_dir (str): path of the directory that includes the SkinCancer images 72 | imageid_path_dict (dict): dictionary with image ids as keys and image paths as values 73 | 74 | Return: 75 | df_train: Dataframe for training 76 | df_val: Dataframe for testing 77 | 78 | """ 79 | 80 | lesion_type_dict = { 81 | 'nv': 'Melanocytic nevi', 82 | 'mel': 'dermatofibroma', 83 | 'bkl': 'Benign keratosis-like lesions ', 84 | 'bcc': 'Basal cell carcinoma', 85 | 'akiec': 'Actinic keratoses', 86 | 'vasc': 'Vascular lesions', 87 | 'df': 'Dermatofibroma' 88 | } 89 | 90 | df_original = pd.read_csv(os.path.join(base_dir, 'HAM10000_metadata.csv')) 91 | df_original['path'] = df_original['image_id'].map(imageid_path_dict.get) 92 | df_original['cell_type'] = df_original['dx'].map(lesion_type_dict.get) 93 | df_original['cell_type_idx'] = pd.Categorical(df_original['cell_type']).codes 94 | 95 | df_original[['cell_type_idx', 'cell_type']].sort_values('cell_type_idx').drop_duplicates() 96 | 97 | # Get number of images associated with each lesion_id 98 | df_undup = df_original.groupby('lesion_id').count() 99 | # Filter out lesion_id's that have only one image associated with it 100 | df_undup = df_undup[df_undup['image_id'] == 1] 101 | df_undup.reset_index(inplace=True) 102 | 103 | # Identify lesion_id's that have duplicate images and those that have only one image. 104 | def get_duplicates(x): 105 | unique_list = list(df_undup['lesion_id']) 106 | if x in unique_list: 107 | return 'unduplicated' 108 | else: 109 | return 'duplicated' 110 | 111 | # create a new column that is a copy of the lesion_id column 112 | df_original['duplicates'] = df_original['lesion_id'] 113 | 114 | # apply the function to this new column 115 | df_original['duplicates'] = df_original['duplicates'].apply(get_duplicates) 116 | 117 | # Filter out images that don't have duplicates 118 | df_undup = df_original[df_original['duplicates'] == 'unduplicated'] 119 | 120 | # Create a val set using df_undup because we are sure that none of these images have augmented duplicates in the train set 121 | y = df_undup['cell_type_idx'] 122 | _, df_val = train_test_split(df_undup, test_size=0.2, random_state=101, stratify=y) 123 | 124 | 125 | # This set will be df_original excluding all rows that are in the val set 126 | # This function identifies if an image is part of the train or val set. 127 | def get_val_rows(x): 128 | # create a list of all the image_id's in the val set 129 | val_list = list(df_val['image_id']) 130 | if str(x) in val_list: 131 | return 'val' 132 | else: 133 | return 'train' 134 | 135 | # Identify train and val rows 136 | # Create a new column that is a copy of the image_id column 137 | df_original['train_or_val'] = df_original['image_id'] 138 | # Apply the function to this new column 139 | df_original['train_or_val'] = df_original['train_or_val'].apply(get_val_rows) 140 | # Filter out train rows 141 | df_train = df_original[df_original['train_or_val'] == 'train'] 142 | 143 | # Oversample minority classes to balance the 7 classes 144 | data_aug_rate = [15,10,5,50,0,40,5] 145 | for i in range(7): 146 | if data_aug_rate[i]: 147 | df_train=df_train.append([df_train.loc[df_train['cell_type_idx'] == i,:]]*(data_aug_rate[i]-1), ignore_index=True) 148 | df_train['cell_type'].value_counts() 149 | 150 | df_train = df_train.reset_index() 151 | df_val = df_val.reset_index() 152 | 153 | return df_train, df_val 154 | 155 | class CustomDataset(Dataset): 156 | """ 157 | Custom dataset for the SkinCancer dataset 158 | """ 159 | def __init__(self, df, transform=None): 160 | self.df = df 161 | self.transform = transform 162 | 163 | def __len__(self): 164 | return len(self.df) 165 | 166 | def __getitem__(self, index): 167 | # Load data and get label 168 | X = Image.open(self.df['path'][index]) 169 | y = torch.tensor(int(self.df['cell_type_idx'][index])) 170 | 171 | if self.transform: 172 | X = self.transform(X) 173 | return X, y 174 | 175 | class ISIC2019(Dataset): 176 | 177 | 178 | TO_REPLACE_TRAIN = [None, [4,5,6], None, None,[4], [4,5,6]] 179 | VALUES_TRAIN = [None, [3,4,5], None, None,[2], [3,4,5]] 180 | 181 | def __init__(self, csv_file_path, root_dir, client_id, train = True, centralized = False, input_size = 224) -> None: 182 | super().__init__() 183 | self.image_root = root_dir 184 | self.train = train 185 | csv_file = pd.read_csv(csv_file_path) 186 | self.centralized = centralized 187 | 188 | if train: 189 | if centralized: 190 | self.csv = csv_file[csv_file['fold'] == 'train'].reset_index() 191 | else: 192 | self.csv = csv_file[csv_file['fold2'] == f'train_{client_id}'].reset_index() 193 | 194 | elif train == False: 195 | if centralized: 196 | self.csv = csv_file[csv_file['fold'] == 'test'].reset_index() 197 | else: 198 | self.csv = csv_file[csv_file['fold2'] == f'test_{client_id}'].reset_index() 199 | 200 | if train: 201 | self.transform = transforms.Compose([ 202 | transforms.RandomRotation(10), 203 | transforms.RandomHorizontalFlip(0.5), 204 | transforms.RandomVerticalFlip(0.5), 205 | transforms.RandomAffine(degrees = 0, shear=0.05), 206 | transforms.RandomResizedCrop((input_size, input_size), scale=(0.85,1.1)), 207 | transforms.ToTensor(), 208 | ]) 209 | 210 | elif train == False: 211 | self.transform = transforms.Compose([ 212 | transforms.Resize((input_size, input_size)), 213 | transforms.ToTensor(), 214 | ]) 215 | def __len__(self): 216 | return self.csv.shape[0] 217 | 218 | def __getitem__(self, idx): 219 | if torch.is_tensor(idx): 220 | idx = idx.tolist() 221 | 222 | img_name = os.path.join(self.image_root, 223 | self.csv['image'][idx]+'.jpg') 224 | sample = Image.open(img_name) 225 | target = self.csv['target'][idx] 226 | 227 | sample = self.transform(sample) 228 | 229 | return sample, target 230 | 231 | def blood_noniid(numOfAgents, data, batch_size): 232 | """ 233 | Function to divide the bloodmnist dataset among clients 234 | 235 | Input: 236 | numOfAgents (int): Number of Agents (Clients)
def blood_noniid(numOfAgents, data, batch_size):
    """
    Function to divide the BloodMNIST dataset among clients.

    Input:
        numOfAgents (int): Number of Agents (Clients)
        data: dataset to be divided
        batch_size (int)

    Return:
        datasets for agents, Loaders for agents, datasets for visualization

    """
    # Static way of creating non-IID data; to change the distribution,
    # change which entries of p are boosted for each label below.
    # Note: the hard-coded indices assume numOfAgents >= 6.
    nonIID_tensors = [[] for i in range(numOfAgents)]
    nonIID_labels = [[] for i in range(numOfAgents)]
    agents = np.arange(0, numOfAgents)
    for i in data:
        # start from uniform weights and boost the clients allowed to receive this label
        p = np.ones((numOfAgents))
        label = float(i[1])
        if label == 0:
            p[0] = numOfAgents
            p[1] = numOfAgents
            p[2] = numOfAgents
        elif label == 1:
            p[0] = numOfAgents
            p[1] = numOfAgents
            p[2] = numOfAgents
        elif label == 2:
            p[3] = numOfAgents
            p[5] = numOfAgents
            p[0] = numOfAgents
        elif label == 3:
            p[0] = numOfAgents
            p[4] = numOfAgents
            p[5] = numOfAgents
        elif label == 4:
            p[3] = numOfAgents
            p[4] = numOfAgents
            p[5] = numOfAgents
        elif label == 5:
            p[3] = numOfAgents
            p[4] = numOfAgents
            p[5] = numOfAgents
        elif label == 6:
            p[4] = numOfAgents
            p[5] = numOfAgents
        elif label == 7:
            p[0] = numOfAgents
            p[1] = numOfAgents
            p[2] = numOfAgents
        p = p / np.sum(p)
        # sample the receiving client according to the label-dependent weights
        j = np.random.choice(agents, p=p)
        nonIID_tensors[j].append(i[0])
        nonIID_labels[j].append(torch.tensor(i[1]).reshape(1))

    dataset_vis = [[] for i in range(numOfAgents)]
    for i in range(numOfAgents):
        dataset_vis[i].append((torch.stack(nonIID_tensors[i]), torch.cat(nonIID_labels[i])))

    dataset_agents = [[] for i in range(numOfAgents)]
    for agent in range(numOfAgents):
        im_ = dataset_vis[agent][0][0]
        lab_ = dataset_vis[agent][0][1]
        for im, lab in zip(im_, lab_):
            dataset_agents[agent].append((im, lab))

    dataset_loaders = [DataLoader(dataset_agents[i], batch_size=batch_size, shuffle=True, num_workers=8) for i in range(numOfAgents)]

    return dataset_agents, dataset_loaders, dataset_vis
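# Usage sketch (mirrors the call made in models.py; bloodmnisit is defined earlier
# in this module and returns loaders plus the raw train/test datasets):
# _, test_loader, train_dataset, _ = bloodmnisit(batch_size=32)
# agents_data, agents_loaders, vis = blood_noniid(6, train_dataset, batch_size=32)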
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
from tqdm import tqdm
import pickle as pkl
import os
import timm
import copy
import numpy as np

import torch.nn as nn
import torch
from sklearn.metrics import balanced_accuracy_score

from dataset import blood_noniid, bloodmnisit, distribute_data
from utils import weight_vec

class CentralizedFashion():
    def __init__(self, device, network, criterion, base_dir):
        """
        Class for the Centralized Paradigm.
        args:
            device: cuda vs cpu
            network: ViT model
            criterion: loss function to be used
            base_dir: where to save metrics as pickles
        return:
            None
        """
        self.device = device
        self.network = network
        self.criterion = criterion
        self.base_dir = base_dir

    def set_optimizer(self, name, lr):
        """
        name: Optimizer name, e.g. Adam
        lr: learning rate

        """
        if name == 'Adam':
            self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

    def init_logs(self):
        """
        A method to initialize dictionaries for the metrics.
        return: None
        args: None
        """
        self.losses = {'train': [], 'test': []}
        self.balanced_accs = {'train': [], 'test': []}

    def train_round(self, train_loader):
        """
        Training loop.

        """
        running_loss = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        for imgs, labels in tqdm(train_loader):
            self.optimizer.zero_grad()
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            output = self.network(imgs)
            labels = labels.reshape(labels.shape[0])
            loss = self.criterion(output, labels)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(output, 1)
            whole_probs.append(torch.nn.Softmax(dim=-1)(output).detach().cpu())
            whole_labels.append(labels.detach().cpu())
            whole_preds.append(predicted.detach().cpu())
        self.metrics(whole_labels, whole_preds, running_loss, len(train_loader), whole_probs, train=True)

    def eval_round(self, test_loader):
        """
        Evaluation loop.

        """
        running_loss = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        with torch.no_grad():
            for imgs, labels in tqdm(test_loader):
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                output = self.network(imgs)
                labels = labels.reshape(labels.shape[0])
                loss = self.criterion(output, labels)
                running_loss += loss.item()
                _, predicted = torch.max(output, 1)
                whole_probs.append(torch.nn.Softmax(dim=-1)(output).detach().cpu())
                whole_labels.append(labels.detach().cpu())
                whole_preds.append(predicted.detach().cpu())
        self.metrics(whole_labels, whole_preds, running_loss, len(test_loader), whole_probs, train=False)
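    # Usage sketch (assumption: the actual driver lives in centralized.py):
    #   central = CentralizedFashion(device, network, nn.CrossEntropyLoss(), base_dir='./runs')
    #   central.set_optimizer('Adam', lr=1e-4)
    #   central.init_logs()
    #   for epoch in range(num_epochs):
    #       central.train_round(train_loader)
    #       central.eval_round(test_loader)
    #   central.save_pickles('./runs')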
    def metrics(self, whole_labels, whole_preds, running_loss, len_loader, whole_probs, train):
        """
        Compute the epoch loss and balanced accuracy, append them to the logs, and print them.

        """
        whole_labels = torch.cat(whole_labels)
        whole_preds = torch.cat(whole_preds)
        loss_epoch = running_loss / len_loader
        balanced_acc = balanced_accuracy_score(whole_labels.detach().cpu(), whole_preds.detach().cpu())
        if train:
            eval_name = 'train'
        else:
            eval_name = 'test'

        self.losses[eval_name].append(loss_epoch)
        self.balanced_accs[eval_name].append(balanced_acc)

        print(f"{eval_name}:")
        print(f"{eval_name}_loss :{loss_epoch:.3f}")
        print(f"{eval_name}_balanced_acc :{balanced_acc:.3f}")


    def save_pickles(self, base_dir, local=None, client_id=None):
        # check `client_id is not None`, so that client 0 is not treated as "no client"
        if local and client_id is not None:
            with open(os.path.join(base_dir, f'loss_epoch_Client{client_id}'), 'wb') as handle:
                pkl.dump(self.losses, handle)
            with open(os.path.join(base_dir, f'balanced_accs{client_id}'), 'wb') as handle:
                pkl.dump(self.balanced_accs, handle)
        else:
            with open(os.path.join(base_dir, 'loss_epoch'), 'wb') as handle:
                pkl.dump(self.losses, handle)
            with open(os.path.join(base_dir, 'balanced_accs'), 'wb') as handle:
                pkl.dump(self.balanced_accs, handle)

class SLViT(nn.Module):
    def __init__(
        self, ViT_name, num_classes, num_clients=6,
        in_channels=3, ViT_pretrained=False,
        diff_privacy=False, mean=0, std=1
    ) -> None:

        super().__init__()

        self.vit = timm.create_model(
            model_name=ViT_name,
            pretrained=ViT_pretrained,
            num_classes=num_classes,
            in_chans=in_channels
        )
        client_tail = MLP_cls_classes(num_classes=num_classes)
        self.mlp_clients_tail = nn.ModuleList([copy.deepcopy(client_tail) for i in range(num_clients)])
        # one patch-embedding stem per client (kept on the client side of the split)
        self.resnet50_clients = nn.ModuleList([copy.deepcopy(self.vit.patch_embed) for i in range(num_clients)])

        self.diff_privacy = diff_privacy
        self.mean = mean
        self.std = std

    def forward(self, x, client_idx):
        x = self.resnet50_clients[client_idx](x)
        if self.diff_privacy:
            # additive Gaussian noise on the smashed representation
            # (torch.randn_like keeps the noise on the input's device)
            noise = torch.randn_like(x) * self.std + self.mean
            x = x + noise
        x = torch.cat((self.vit.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.vit.pos_drop(x + self.vit.pos_embed)
        for block_num in range(12):  # the 12 transformer blocks of ViT-Base
            x = self.vit.blocks[block_num](x)
        x = self.vit.norm(x)
        cls = self.vit.pre_logits(x)[:, 0, :]
        x = self.mlp_clients_tail[client_idx](cls)
        return x, cls

class MLP_cls_classes(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        self.identity = nn.Identity()
        self.fc = nn.Linear(in_features=768, out_features=num_classes, bias=True)

    def forward(self, x):
        x = self.norm(x)
        x = self.identity(x)
        x = self.fc(x)
        return x
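# Forward-pass sketch for SLViT (hypothetical shapes; 'vit_base_patch16_224' is one
# valid timm model name matching the 768-dim tail above):
# model = SLViT('vit_base_patch16_224', num_classes=7, num_clients=6)
# logits, cls = model(torch.randn(8, 3, 224, 224), client_idx=0)  # logits: (8, 7), cls: (8, 768)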
class SplitNetwork():
    def __init__(
        self, num_clients, device, network,
        criterion, base_dir,
    ):
        """
        args:
            num_clients
            device: cuda vs cpu
            network: ViT model
            criterion: loss function to be used
            base_dir: where to save pickles/model files
        """

        self.device = device
        self.num_clients = num_clients
        self.criterion = criterion
        self.network = network
        self.base_dir = base_dir

    def init_logs(self):
        """
        This method initializes per-client dictionaries for the metrics.

        """
        self.losses = {'train': [[] for i in range(self.num_clients)], 'test': [[] for i in range(self.num_clients)]}
        self.balanced_accs = {'train': [[] for i in range(self.num_clients)], 'test': [[] for i in range(self.num_clients)]}

    def set_optimizer(self, name, lr):
        """
        name: Optimizer name, e.g. Adam
        lr: learning rate

        """
        if name == 'Adam':
            self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

    def distribute_images(self, dataset_name, train_data, test_data, batch_size):
        """
        This method splits the dataset among clients.
        dataset_name: 'HAM' or 'bloodmnist'
        train_data: train dataset
        test_data: test dataset
        batch_size: batch size

        """
        if dataset_name == 'HAM':
            self.CLIENTS_DATALOADERS = distribute_data(self.num_clients, train_data, batch_size)
            self.testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=8)

        elif dataset_name == 'bloodmnist':
            _, self.testloader, train_dataset, _ = bloodmnisit(batch_size=batch_size)
            _, self.CLIENTS_DATALOADERS, _ = blood_noniid(self.num_clients, train_dataset, batch_size=batch_size)

    def train_round(self, client_i):
        """
        Training loop.

        client_i: Client index.

        """
        running_loss_client_i = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        # keep a copy of the pre-round weights so the shared body can be reset below
        copy_network = copy.deepcopy(self.network)
        weight_dic = {'blocks': None, 'cls': None, 'pos_embed': None}
        self.network.train()
        for data in tqdm(self.CLIENTS_DATALOADERS[client_i]):
            self.optimizer.zero_grad()
            imgs, labels = data[0].to(self.device), data[1].to(self.device)
            labels = labels.reshape(labels.shape[0])
            tail_output = self.network(imgs, client_i)
            loss = self.criterion(tail_output[0], labels)
            loss.backward()
            self.optimizer.step()
            running_loss_client_i += loss.item()
            _, predicted = torch.max(tail_output[0], 1)
            whole_probs.append(torch.nn.Softmax(dim=-1)(tail_output[0]).detach().cpu())
            whole_labels.append(labels.detach().cpu())
            whole_preds.append(predicted.detach().cpu())
        self.metrics(client_i, whole_labels, whole_preds, running_loss_client_i, len(self.CLIENTS_DATALOADERS[client_i]), whole_probs, train=True)

        # snapshot this client's update of the shared body, to be aggregated outside this method
        weight_dic['blocks'] = weight_vec(self.network.vit.blocks).detach().cpu()
        weight_dic['cls'] = self.network.vit.cls_token.detach().cpu()
        weight_dic['pos_embed'] = self.network.vit.pos_embed.detach().cpu()

        # reset the shared body to its pre-round weights for the next client
        self.network.vit.blocks = copy.deepcopy(copy_network.vit.blocks)
        self.network.vit.cls_token = copy.deepcopy(copy_network.vit.cls_token)
        self.network.vit.pos_embed = copy.deepcopy(copy_network.vit.pos_embed)
        return weight_dic
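    # Aggregation sketch (assumption: aggregation happens in the training script, not
    # in this class): after collecting one weight_dic per client,
    #   avg_blocks = torch.stack([d['blocks'] for d in weight_dics]).mean(dim=0)
    # and the averaged vector is loaded back into the shared ViT body before the next round.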
    def eval_round(self, client_i):
        """
        Evaluation loop.

        client_i: Client index.

        """
        running_loss_client_i = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        self.network.eval()
        with torch.no_grad():
            for data in tqdm(self.testloader):
                imgs, labels = data[0].to(self.device), data[1].to(self.device)
                tail_output = self.network(imgs, client_i)[0]
                labels = labels.reshape(labels.shape[0])
                loss = self.criterion(tail_output, labels)
                running_loss_client_i += loss.item()
                _, predicted = torch.max(tail_output, 1)
                whole_probs.append(torch.nn.Softmax(dim=-1)(tail_output).detach().cpu())
                whole_labels.append(labels.detach().cpu())
                whole_preds.append(predicted.detach().cpu())
        self.metrics(client_i, whole_labels, whole_preds, running_loss_client_i, len(self.testloader), whole_probs, train=False)

    def metrics(self, client_i, whole_labels, whole_preds, running_loss_client_i, len_loader, whole_probs, train):
        """
        Compute the epoch loss and balanced accuracy for a client, append them to the logs, and print them.

        """
        whole_labels = torch.cat(whole_labels)
        whole_preds = torch.cat(whole_preds)
        loss_epoch = running_loss_client_i / len_loader
        balanced_acc = balanced_accuracy_score(whole_labels.detach().cpu(), whole_preds.detach().cpu())

        if train:
            eval_name = 'train'
        else:
            eval_name = 'test'

        self.losses[eval_name][client_i].append(loss_epoch)
        self.balanced_accs[eval_name][client_i].append(balanced_acc)

        print(f"client{client_i}_{eval_name}:")
        print(f"loss {eval_name}: {loss_epoch:.3f}")
        print(f"balanced accuracy {eval_name}: {balanced_acc:.3f}")

    def save_pickles(self, base_dir):
        with open(os.path.join(base_dir, 'loss_epoch'), 'wb') as handle:
            pkl.dump(self.losses, handle)
        with open(os.path.join(base_dir, 'balanced_accs'), 'wb') as handle:
            pkl.dump(self.balanced_accs, handle)

class FeSVBiS(nn.Module):
    def __init__(
        self, ViT_name, num_classes,
        num_clients=6, in_channels=3, ViT_pretrained=False,
        initial_block=1, final_block=6, resnet_dropout=None, DP=False, mean=None, std=None
    ) -> None:
        super().__init__()

        self.initial_block = initial_block
        self.final_block = final_block

        self.vit = timm.create_model(
            model_name=ViT_name,
            pretrained=ViT_pretrained,
            num_classes=num_classes,
            in_chans=in_channels
        )

        self.resnet50 = self.vit.patch_embed
        self.resnet50_clients = nn.ModuleList([copy.deepcopy(self.resnet50) for i in range(num_clients)])
        self.common_network = ResidualBlock(drop_out=resnet_dropout)
        client_tail = MLP_cls_classes(num_classes=num_classes)
        self.mlp_clients_tail = nn.ModuleList([copy.deepcopy(client_tail) for i in range(num_clients)])
        self.DP = DP
        self.mean = mean
        self.std = std

    def forward(self, x, chosen_block, client_idx):
        x = self.resnet50_clients[client_idx](x)
        if self.DP:
            noise = torch.randn_like(x) * self.std + self.mean
            x = x + noise
        # run only the first `chosen_block` ViT blocks; the shared residual block
        # below compensates for the remaining depth
        for block_num in range(chosen_block):
            x = self.vit.blocks[block_num](x)
        x = self.common_network(x)
        x = self.mlp_clients_tail[client_idx](x)
        return x
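# Forward-pass sketch for FeSVBiS (hypothetical shapes): block depth 3 for client 1.
# model = FeSVBiS('vit_base_patch16_224', num_classes=8, num_clients=6)
# logits = model(torch.randn(4, 3, 224, 224), chosen_block=3, client_idx=1)  # logits: (4, 8)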
class ResidualBlock(nn.Module):
    def __init__(self, in_channels=768, out_channels=768, stride=1, downsample=None, drop_out=None):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels
        self.pool = nn.AvgPool2d(14, stride=1)
        # nn.Dropout2d(p=None) raises, so only build the dropout layer when a rate is given
        self.dropout = nn.Dropout2d(p=drop_out) if drop_out is not None else None
        self.drop_out = drop_out

    def forward(self, x):
        if len(x.shape) == 3:
            # (B, N, C) token sequence -> (B, C, 14, 14) feature map
            x = torch.permute(x, (0, -1, 1))
            x = x.reshape(x.shape[0], x.shape[1], 14, 14)
        residual = x
        out = self.conv1(x)
        if self.drop_out is not None:
            out = self.dropout(out)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        out = self.pool(out)
        return out.reshape(-1, 768)

class SplitFeSViBS(SplitNetwork):
    def __init__(
        self, num_clients, device,
        network, criterion, base_dir,
        initial_block, final_block,
    ):
        super().__init__(num_clients, device, network, criterion, base_dir)
        self.initial_block = initial_block
        self.final_block = final_block
        self.train_chosen_blocks = [0] * num_clients

    def set_optimizer_mel(self, name, lr):
        # assumes a `self.mel_body` list of per-client modules has been set externally
        if name == 'Adam':
            self.optimizer_mel = [torch.optim.Adam(self.mel_body[i].parameters(), lr=lr) for i in range(self.num_clients)]

    def train_round(self, client_i):
        """
        Training loop.

        client_i: Client index.

        """
        running_loss_client_i = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        # sample a block depth uniformly from [initial_block, final_block] (inclusive)
        self.chosen_block = np.random.randint(low=self.initial_block, high=self.final_block + 1)
        self.train_chosen_blocks[client_i] = self.chosen_block
        copy_network = copy.deepcopy(self.network)
        weight_dic = {'blocks': None, 'cls': None, 'pos_embed': None, 'resnet': None}
        print(f"Chosen Block: {self.chosen_block} for client {client_i}")
        self.network.train()
        for data in tqdm(self.CLIENTS_DATALOADERS[client_i]):
            self.optimizer.zero_grad()
            imgs, labels = data[0].to(self.device), data[1].to(self.device)
            labels = labels.reshape(labels.shape[0])
            tail_output = self.network(x=imgs, chosen_block=self.chosen_block, client_idx=client_i)
            loss = self.criterion(tail_output, labels)
            loss.backward()
            self.optimizer.step()
            running_loss_client_i += loss.item()
            _, predicted = torch.max(tail_output, 1)
            whole_probs.append(torch.nn.Softmax(dim=-1)(tail_output).detach().cpu())
            whole_labels.append(labels.detach().cpu())
            whole_preds.append(predicted.detach().cpu())
        self.metrics(client_i, whole_labels, whole_preds, running_loss_client_i, len(self.CLIENTS_DATALOADERS[client_i]), whole_probs, train=True)

        # snapshot this client's update of the shared body for aggregation
        weight_dic['blocks'] = weight_vec(self.network.vit.blocks).detach().cpu()
        weight_dic['cls'] = self.network.vit.cls_token.detach().cpu()
        weight_dic['pos_embed'] = self.network.vit.pos_embed.detach().cpu()

        # reset the shared body to its pre-round weights for the next client
        self.network.vit.blocks = copy.deepcopy(copy_network.vit.blocks)
        self.network.vit.cls_token = copy.deepcopy(copy_network.vit.cls_token)
        self.network.vit.pos_embed = copy.deepcopy(copy_network.vit.pos_embed)
        return weight_dic
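    # Round-driver sketch (assumption: the actual federated loop lives in FeSViBS.py):
    #   weight_dics = [split_net.train_round(c) for c in range(split_net.num_clients)]
    #   ... aggregate weight_dics into the shared body ...
    #   for c in range(split_net.num_clients):
    #       split_net.eval_round(c)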
    def eval_round(self, client_i):
        """
        Evaluation loop.

        client_i: Client index.

        """
        running_loss_client_i = 0
        whole_labels = []
        whole_preds = []
        whole_probs = []
        # evaluate with the same block depth the client last trained with
        num_b = self.train_chosen_blocks[client_i]
        print(f"Chosen block for testing: {num_b}")
        self.network.eval()
        with torch.no_grad():
            for data in tqdm(self.testloader):
                imgs, labels = data[0].to(self.device), data[1].to(self.device)
                labels = labels.reshape(labels.shape[0])
                tail_output = self.network(x=imgs, chosen_block=num_b, client_idx=client_i)
                loss = self.criterion(tail_output, labels)
                running_loss_client_i += loss.item()
                _, predicted = torch.max(tail_output, 1)
                whole_probs.append(torch.nn.Softmax(dim=-1)(tail_output).detach().cpu())
                whole_labels.append(labels.detach().cpu())
                whole_preds.append(predicted.detach().cpu())
        self.metrics(client_i, whole_labels, whole_preds, running_loss_client_i, len(self.testloader), whole_probs, train=False)
--------------------------------------------------------------------------------