├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── fedbase ├── __init__.py ├── baselines │ ├── __init__.py │ ├── central.py │ ├── ditto.py │ ├── fedavg.py │ ├── fedavg_ensemble.py │ ├── fedprox.py │ ├── fedprox_ensemble.py │ ├── fesem.py │ ├── fesem_cam.py │ ├── fesem_con.py │ ├── ifca.py │ ├── ifca_cam.py │ ├── ifca_con.py │ ├── local.py │ ├── wecfl.py │ ├── wecfl_cam.py │ └── wecfl_con.py ├── model │ ├── __init__.py │ ├── model.py │ └── resnet.py ├── nodes │ ├── __init__.py │ └── node.py ├── server │ ├── __init__.py │ └── server.py └── utils │ ├── __init__.py │ ├── cfl_utils │ ├── data_loader.py │ ├── femnist.py │ ├── model_utils.py │ ├── tools.py │ └── visualize.py ├── setup.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | command.txt 2 | .pypirc 3 | publish.sh 4 | result_analysis.py 5 | ablation.py 6 | visual.ipynb 7 | 8 | legacy/ 9 | data/ 10 | example/data/ 11 | log/ 12 | vis/ 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 JieMA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedBase 2 | An easy, modularized, DIY Federated Learning framework with many baselines for individual researchers. 3 | 4 | ## Installation 5 | [fedbase @ pypi](https://pypi.org/project/fedbase/) 6 | ```python 7 | pip install --upgrade fedbase 8 | ``` 9 | 10 | ## Baselines 11 | 1. Centralized training 12 | 2. Local training 13 | 3. FedAvg, [Communication-Efficient Learning of Deep Networksfrom Decentralized Data](https://arxiv.org/abs/1602.05629) 14 | 4. FedAvg + Finetune 15 | 5. Fedprox, [Federated Optimization in Heterogeneous Networks](https://arxiv.org/abs/1812.06127) 16 | 5. Ditto, [Ditto: Fair and Robust Federated Learning Through Personalization](https://arxiv.org/abs/2012.04221) 17 | 6. WeCFL, [On the Convergence of Clustered Federated Learning](https://arxiv.org/abs/2202.06187) 18 | 7. IFCA, [An Efficient Framework for Clustered Federated Learning](https://arxiv.org/abs/2006.04088) 19 | 8. FeSEM, [Multi-Center Federated Learning](https://arxiv.org/abs/2005.01026) 20 | 8. To be continued... 21 | 22 | ## Three steps to achieve FedAvg! 23 | 1. Data partition 24 | 2. Nodes and server simulation 25 | 3. Train and test 26 | 27 | ## Design philosophy 28 | 1. Dataset 29 | 1. Dataset 30 | 1. MNIST 31 | 2. CIFAR-10 32 | 3. Fashion-MNIST 33 | 4. ... 34 | 2. Dataset partition 35 | 1. IID 36 | 2. Non-IID 37 | 1. Dirichlet distribution 38 | 2. N-class 39 | 3. ... 40 | 3. Fake data 41 | 4. ... 42 | 43 | 2. Node 44 | 1. Local dataset 45 | 2. Model 46 | 3. Objective 47 | 4. Optimizer 48 | 5. Local update 49 | 6. Test 50 | 3. Server 51 | 1. Model 52 | 2. Aggregate 53 | 3. Distribute 54 | 4. Server & Node 55 | 1. Topology 56 | 2. Client sampling 57 | 3. Exchange message 58 | 5. Baselines 59 | 1. Global 60 | 2. Local 61 | 3. FedAvg 62 | 6. 
Visualization 63 | 64 | ## How to develop your own FL with fedbase? -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | name = "fedbase" -------------------------------------------------------------------------------- /fedbase/__init__.py: -------------------------------------------------------------------------------- 1 | name = "fedbase" -------------------------------------------------------------------------------- /fedbase/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | modules = glob.glob(join(dirname(__file__), "*.py")) 4 | __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')] -------------------------------------------------------------------------------- /fedbase/baselines/central.py: -------------------------------------------------------------------------------- 1 | from fedbase.utils.data_loader import data_process, log 2 | from fedbase.utils.tools import add_ 3 | from fedbase.nodes.node import node 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import torch.optim as optim 7 | import os 8 | 9 | 10 | def run(dataset, batch_size, model, objective, optimizer, global_rounds, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 11 | dt = dataset 12 | trainset,testset = dt.train_dataset, dt.test_dataset 13 | trainloader = DataLoader(trainset, batch_size=batch_size, 14 | shuffle=True) 15 | testloader = DataLoader(testset, batch_size=batch_size, 16 | shuffle=False) 17 | 18 | nodes0 = node(0, device) 19 | nodes0.assign_train(trainloader) 20 | nodes0.assign_test(testloader) 21 | nodes0.assign_model(model()) 22 | nodes0.assign_objective(objective()) 23 | nodes0.assign_optim(optimizer(nodes0.model.parameters())) 24 | 25 | 
print('-------------------start-------------------') 26 | for i in range(global_rounds): 27 | nodes0.local_update_epochs(1) 28 | nodes0.local_test() 29 | 30 | # log 31 | log(os.path.basename(__file__)[:-3]+ add_(dt.dataset_name), [nodes0], server={}) 32 | 33 | return nodes0.model -------------------------------------------------------------------------------- /fedbase/baselines/ditto.py: -------------------------------------------------------------------------------- 1 | from fedbase.utils.data_loader import data_process, log 2 | from fedbase.utils.tools import add_ 3 | from fedbase.nodes.node import node 4 | from fedbase.server.server import server_class 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import torch.optim as optim 8 | import os 9 | from functools import partial 10 | 11 | def run(dataset_splited, batch_size, num_nodes, model, objective, optimizer, global_rounds, local_steps, reg, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 12 | # dt = data_process(dataset) 13 | # train_splited, test_splited = dt.split_dataset(num_nodes, split['split_para'], split['split_method']) 14 | train_splited, test_splited, split_para = dataset_splited 15 | 16 | server = server_class(device) 17 | server.assign_model(model()) 18 | 19 | nodes = [node(i, device) for i in range(num_nodes)] 20 | # local_models = [model() for i in range(num_nodes)] 21 | # local_loss = [objective() for i in range(num_nodes)] 22 | 23 | for i in range(num_nodes): 24 | # data 25 | # print(len(train_splited[i]), len(test_splited[i])) 26 | nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True)) 27 | nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False)) 28 | # model 29 | nodes[i].assign_model(model()) 30 | # objective 31 | nodes[i].assign_objective(objective()) 32 | # optim 33 | nodes[i].assign_optim(optimizer(nodes[i].model.parameters())) 34 | 35 | del train_splited, test_splited 36 | 37 | # 
initialize parameters to nodes 38 | weight_list = [nodes[i].data_size/sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)] 39 | server.distribute([nodes[i].model for i in range(num_nodes)]) 40 | 41 | # train! 42 | for i in range(global_rounds): 43 | print('-------------------Global round %d start-------------------' % (i)) 44 | # single-processing! 45 | for j in range(num_nodes): 46 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_fedprox, reg_model = server.model, lam= reg)) 47 | nodes[j].local_test() 48 | # server aggregation 49 | server.model.load_state_dict(server.aggregate([nodes[i].model for i in range(num_nodes)], weight_list)) 50 | # test accuracy 51 | server.acc(nodes, weight_list) 52 | 53 | # log 54 | log(os.path.basename(__file__)[:-3] + add_(reg) + add_(split_para) , nodes, server) 55 | 56 | return [nodes[i].model for i in range(num_nodes)] -------------------------------------------------------------------------------- /fedbase/baselines/fedavg.py: -------------------------------------------------------------------------------- 1 | from fedbase.utils.data_loader import data_process, log 2 | from fedbase.utils.tools import add_ 3 | from fedbase.nodes.node import node 4 | from fedbase.server.server import server_class 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import torch.optim as optim 8 | import os 9 | from functools import partial 10 | 11 | def run(dataset_splited, batch_size, num_nodes, model, objective, optimizer, global_rounds, local_steps,\ 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'), log_file=True, finetune=False, finetune_steps = None): 13 | # dt = data_process(dataset) 14 | # train_splited, test_splited = dt.split_dataset(num_nodes, split['split_para'], split['split_method']) 15 | train_splited, test_splited, split_para = dataset_splited 16 | 17 | server = server_class(device) 18 | server.assign_model(model()) 19 | 20 | nodes = [node(i, 
device) for i in range(num_nodes)] 21 | # local_models = [model() for i in range(num_nodes)] 22 | # local_loss = [objective() for i in range(num_nodes)] 23 | 24 | for i in range(num_nodes): 25 | # data 26 | # print(len(train_splited[i]), len(test_splited[i])) 27 | nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True)) 28 | nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False)) 29 | # model 30 | nodes[i].assign_model(model()) 31 | # objective 32 | nodes[i].assign_objective(objective()) 33 | # optim 34 | nodes[i].assign_optim(optimizer(nodes[i].model.parameters())) 35 | 36 | del train_splited, test_splited 37 | 38 | # initialize parameters to nodes 39 | server.distribute([nodes[i].model for i in range(num_nodes)]) 40 | weight_list = [nodes[i].data_size/sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)] 41 | 42 | # train! 43 | for t in range(global_rounds): 44 | print('-------------------Global round %d start-------------------' % (t)) 45 | # single-processing! 
46 | for j in range(num_nodes): 47 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step)) 48 | # server aggregation and distribution 49 | server.model.load_state_dict(server.aggregate([nodes[i].model for i in range(num_nodes)], weight_list)) 50 | server.distribute([nodes[i].model for i in range(num_nodes)]) 51 | # test accuracy 52 | for j in range(num_nodes): 53 | nodes[j].local_test() 54 | server.acc(nodes, weight_list) 55 | 56 | if not finetune: 57 | # log 58 | if log_file: 59 | log(os.path.basename(__file__)[:-3] + add_(split_para), nodes, server) 60 | return server.model 61 | else: 62 | if not finetune_steps: 63 | finetune_steps = local_steps 64 | # fine tune 65 | for j in range(num_nodes): 66 | nodes[j].local_update_steps(finetune_steps, partial(nodes[j].train_single_step)) 67 | nodes[j].local_test() 68 | server.acc(nodes, weight_list) 69 | # log 70 | log(os.path.basename(__file__)[:-3] + add_('finetune') + add_(split_para), nodes, server) 71 | return [nodes[i].model for i in range(num_nodes)] -------------------------------------------------------------------------------- /fedbase/baselines/fedavg_ensemble.py: -------------------------------------------------------------------------------- 1 | from fedbase.utils.data_loader import data_process, log 2 | from fedbase.utils.tools import add_ 3 | from fedbase.nodes.node import node 4 | from fedbase.server.server import server_class 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import torch.optim as optim 8 | import os 9 | from functools import partial 10 | 11 | def run(dataset_splited, batch_size, num_nodes, model, objective, optimizer, global_rounds, local_steps, n_ensemble, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 12 | # dt = data_process(dataset) 13 | # train_splited, test_splited = dt.split_dataset(num_nodes, split['split_para'], split['split_method']) 14 | train_splited, test_splited, split_para = dataset_splited 15 | 16 | models = [] 
17 | for _ in range(n_ensemble): 18 | server = server_class(device) 19 | server.assign_model(model()) 20 | 21 | nodes = [node(i, device) for i in range(num_nodes)] 22 | # local_models = [model() for i in range(num_nodes)] 23 | # local_loss = [objective() for i in range(num_nodes)] 24 | 25 | for i in range(num_nodes): 26 | # data 27 | # print(len(train_splited[i]), len(test_splited[i])) 28 | nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True)) 29 | nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False)) 30 | # model 31 | nodes[i].assign_model(model()) 32 | # objective 33 | nodes[i].assign_objective(objective()) 34 | # optim 35 | nodes[i].assign_optim(optimizer(nodes[i].model.parameters())) 36 | 37 | # initialize parameters to nodes 38 | weight_list = [nodes[i].data_size/sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)] 39 | server.distribute([nodes[i].model for i in range(num_nodes)]) 40 | 41 | # train! 42 | for i in range(global_rounds): 43 | print('-------------------Global round %d start-------------------' % (i)) 44 | # single-processing! 
45 | for j in range(num_nodes): 46 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step)) 47 | # server aggregation and distribution 48 | server.model.load_state_dict(server.aggregate([nodes[i].model for i in range(num_nodes)], weight_list)) 49 | server.distribute([nodes[i].model for i in range(num_nodes)]) 50 | # test accuracy 51 | for j in range(num_nodes): 52 | nodes[j].local_test() 53 | server.acc(nodes, weight_list) 54 | 55 | # ensemble 56 | models.append(server.model) 57 | 58 | # test ensemble 59 | print('test ensemble\n') 60 | for j in range(num_nodes): 61 | nodes[j].local_ensemble_test(models, voting = 'soft') 62 | server.acc(nodes, list(range(num_nodes))) 63 | 64 | # log 65 | log(os.path.basename(__file__)[:-3] + add_(n_ensemble) + add_(split_para), nodes, server) 66 | 67 | return models 68 | 69 | -------------------------------------------------------------------------------- /fedbase/baselines/fedprox.py: -------------------------------------------------------------------------------- 1 | from fedbase.utils.data_loader import data_process, log 2 | from fedbase.nodes.node import node 3 | from fedbase.utils.tools import add_ 4 | from fedbase.server.server import server_class 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import torch.optim as optim 8 | import os 9 | from functools import partial 10 | 11 | def run(dataset_splited, batch_size, num_nodes, model, objective, optimizer, global_rounds, local_steps, reg, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 12 | # dt = data_process(dataset) 13 | # train_splited, test_splited = dt.split_dataset(num_nodes, split['split_para'], split['split_method']) 14 | train_splited, test_splited, split_para = dataset_splited 15 | print('data splited') 16 | server = server_class(device) 17 | server.assign_model(model()) 18 | 19 | nodes = [node(i, device) for i in range(num_nodes)] 20 | # local_models = [model() for i in range(num_nodes)] 21 | # 
local_loss = [objective() for i in range(num_nodes)] 22 | 23 | for i in range(num_nodes): 24 | # data 25 | # print(len(train_splited[i]), len(test_splited[i])) 26 | nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True)) 27 | nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False)) 28 | # model 29 | nodes[i].assign_model(model()) 30 | # objective 31 | nodes[i].assign_objective(objective()) 32 | # optim 33 | nodes[i].assign_optim(optimizer(nodes[i].model.parameters())) 34 | 35 | del train_splited, test_splited 36 | 37 | # initialize parameters to nodes 38 | weight_list = [nodes[i].data_size/sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)] 39 | server.distribute([nodes[i].model for i in range(num_nodes)]) 40 | 41 | # train! 42 | for i in range(global_rounds): 43 | print('-------------------Global round %d start-------------------' % (i)) 44 | # single-processing! 45 | for j in range(num_nodes): 46 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_fedprox, reg_model = server.model, reg_lam= reg)) 47 | # server aggregation and distribution 48 | server.model.load_state_dict(server.aggregate([nodes[i].model for i in range(num_nodes)], weight_list)) 49 | server.distribute([nodes[i].model for i in range(num_nodes)]) 50 | # test accuracy 51 | for j in range(num_nodes): 52 | nodes[j].local_test() 53 | server.acc(nodes, weight_list) 54 | 55 | # log 56 | log(os.path.basename(__file__)[:-3] + add_(reg) + add_(split_para) , nodes, server) 57 | 58 | return server.model -------------------------------------------------------------------------------- /fedbase/baselines/fedprox_ensemble.py: -------------------------------------------------------------------------------- 1 | from fedbase.utils.data_loader import data_process, log 2 | from fedbase.nodes.node import node 3 | from fedbase.utils.tools import add_ 4 | from fedbase.server.server import server_class 5 | 
import torch 6 | from torch.utils.data import DataLoader 7 | import torch.optim as optim 8 | import os 9 | from functools import partial 10 | 11 | def run(dataset_splited, batch_size, num_nodes, model, objective, optimizer, global_rounds, local_steps, reg, n_ensemble, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 12 | # dt = data_process(dataset) 13 | # train_splited, test_splited = dt.split_dataset(num_nodes, split['split_para'], split['split_method']) 14 | train_splited, test_splited, split_para = dataset_splited 15 | print('data splited') 16 | 17 | models = [] 18 | for _ in range(n_ensemble): 19 | server = server_class(device) 20 | server.assign_model(model()) 21 | 22 | nodes = [node(i, device) for i in range(num_nodes)] 23 | # local_models = [model() for i in range(num_nodes)] 24 | # local_loss = [objective() for i in range(num_nodes)] 25 | 26 | for i in range(num_nodes): 27 | # data 28 | # print(len(train_splited[i]), len(test_splited[i])) 29 | nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True)) 30 | nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False)) 31 | # model 32 | nodes[i].assign_model(model()) 33 | # objective 34 | nodes[i].assign_objective(objective()) 35 | # optim 36 | nodes[i].assign_optim(optimizer(nodes[i].model.parameters())) 37 | 38 | # initialize parameters to nodes 39 | weight_list = [nodes[i].data_size/sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)] 40 | server.distribute([nodes[i].model for i in range(num_nodes)]) 41 | 42 | # train! 43 | for i in range(global_rounds): 44 | print('-------------------Global round %d start-------------------' % (i)) 45 | # single-processing! 
46 | for j in range(num_nodes): 47 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_fedprox, reg_model = server.model, lam= reg)) 48 | # server aggregation and distribution 49 | server.model.load_state_dict(server.aggregate([nodes[i].model for i in range(num_nodes)], weight_list)) 50 | server.distribute([nodes[i].model for i in range(num_nodes)]) 51 | # test accuracy 52 | for j in range(num_nodes): 53 | nodes[j].local_test() 54 | server.acc(nodes, weight_list) 55 | 56 | # ensemble 57 | models.append(server.model) 58 | 59 | # test ensemble 60 | print('test ensemble\n') 61 | for j in range(num_nodes): 62 | nodes[j].local_ensemble_test(models, voting = 'soft') 63 | server.acc(nodes, list(range(num_nodes))) 64 | 65 | # log 66 | log(os.path.basename(__file__)[:-3] + add_(n_ensemble) + add_(reg) + add_(split_para), nodes, server) 67 | 68 | return models 69 | -------------------------------------------------------------------------------- /fedbase/baselines/fesem.py: -------------------------------------------------------------------------------- 1 | from fedbase.utils.data_loader import data_process, log 2 | from fedbase.nodes.node import node 3 | from fedbase.utils.tools import add_ 4 | from fedbase.server.server import server_class 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import torch.optim as optim 8 | from fedbase.model.model import CNNCifar, CNNMnist 9 | import os 10 | import sys 11 | import inspect 12 | from functools import partial 13 | 14 | def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, global_rounds, local_steps, \ 15 | reg_lam = None, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'), finetune=False, finetune_steps = None): 16 | # dt = data_process(dataset) 17 | # train_splited, test_splited = dt.split_dataset(num_nodes, split['split_para'], split['split_method']) 18 | train_splited, test_splited, split_para = dataset_splited 19 | server = 
server_class(device) 20 | server.assign_model(model()) 21 | 22 | nodes = [node(i, device) for i in range(num_nodes)] 23 | # local_models = [model() for i in range(num_nodes)] 24 | # local_loss = [objective() for i in range(num_nodes)] 25 | 26 | for i in range(num_nodes): 27 | # data 28 | # print(len(train_splited[i]), len(test_splited[i])) 29 | nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True)) 30 | nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False)) 31 | # model 32 | nodes[i].assign_model(model()) 33 | # objective 34 | nodes[i].assign_objective(objective()) 35 | # optim 36 | nodes[i].assign_optim(optimizer(nodes[i].model.parameters())) 37 | 38 | del train_splited, test_splited 39 | 40 | # initialize parameters to nodes 41 | server.distribute([nodes[i].model for i in range(num_nodes)]) 42 | weight_list = [nodes[i].data_size/sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)] 43 | 44 | # initialize K cluster model 45 | cluster_models = [model() for i in range(K)] 46 | 47 | # train! 
48 | for t in range(global_rounds): 49 | print('-------------------Global round %d start-------------------' % (t)) 50 | # local update 51 | for j in range(num_nodes): 52 | if not reg_lam or t == 0: 53 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step)) 54 | else: 55 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_fedprox, reg_model = cluster_models[nodes[j].label], reg_lam= reg_lam)) 56 | # server clustering 57 | server.weighted_clustering(nodes, list(range(num_nodes)), K, weight_type= 'equal') 58 | 59 | # server aggregation and distribution by cluster 60 | for j in range(K): 61 | assign_ls = [i for i in list(range(num_nodes)) if nodes[i].label==j] 62 | weight_ls = [nodes[i].data_size/sum([nodes[i].data_size for i in assign_ls]) for i in assign_ls] 63 | model_k = server.aggregate([nodes[i].model for i in assign_ls], weight_ls) 64 | server.distribute([nodes[i].model for i in assign_ls], model_k) 65 | cluster_models[j].load_state_dict(model_k) 66 | 67 | # test accuracy 68 | for j in range(num_nodes): 69 | nodes[j].local_test() 70 | server.acc(nodes, weight_list) 71 | 72 | if not finetune: 73 | assign = [[i for i in range(num_nodes) if nodes[i].label == k] for k in range(K)] 74 | # log 75 | log(os.path.basename(__file__)[:-3] + add_(K) + add_(reg_lam) + add_(split_para), nodes, server) 76 | return cluster_models, assign 77 | else: 78 | if not finetune_steps: 79 | finetune_steps = local_steps 80 | # fine tune 81 | for j in range(num_nodes): 82 | if not reg_lam: 83 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step)) 84 | else: 85 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_fedprox, reg_model = cluster_models[nodes[j].label], reg_lam= reg_lam)) 86 | nodes[j].local_test() 87 | server.acc(nodes, weight_list) 88 | # log 89 | log(os.path.basename(__file__)[:-3] + add_('finetune') + add_(K) + add_(reg_lam) + add_(split_para), nodes, server) 90 | return 
[nodes[i].model for i in range(num_nodes)] -------------------------------------------------------------------------------- /fedbase/baselines/fesem_cam.py: -------------------------------------------------------------------------------- 1 | from fedbase.utils.data_loader import data_process, log 2 | from fedbase.utils.visualize import dimension_reduction 3 | from fedbase.utils.tools import add_ 4 | from fedbase.nodes.node import node 5 | from fedbase.server.server import server_class 6 | from fedbase.baselines import local 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torch.optim as optim 10 | from fedbase.model.model import CNNCifar, CNNMnist 11 | import os 12 | import sys 13 | import inspect 14 | from functools import partial 15 | import numpy as np 16 | 17 | def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, warmup_rounds, global_rounds, local_steps, \ 18 | reg_lam = None, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'), finetune=False, finetune_steps = None): 19 | train_splited, test_splited, split_para = dataset_splited 20 | # warmup 21 | local_models_warmup = local.run(dataset_splited, batch_size, num_nodes, model, objective, optimizer, warmup_rounds, local_steps, device = device, log_file=False) 22 | 23 | # initialize 24 | server = server_class(device) 25 | server.assign_model(model()) 26 | server.model_g = model() 27 | 28 | nodes = [node(i, device) for i in range(num_nodes)] 29 | 30 | for i in range(num_nodes): 31 | # data 32 | # print(len(train_splited[i]), len(test_splited[i])) 33 | nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True)) 34 | nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False)) 35 | # model 36 | nodes[i].assign_model(local_models_warmup[i]) 37 | nodes[i].model_g = model() 38 | nodes[i].model_g.to(device) 39 | # objective 40 | nodes[i].assign_objective(objective()) 41 | # optim 42 | 
nodes[i].assign_optim({'local_0': optimizer(nodes[i].model.parameters()),\ 43 | 'local_1': optimizer(nodes[i].model_g.parameters()),\ 44 | 'all': optimizer(list(nodes[i].model.parameters())+list(nodes[i].model_g.parameters()))}) 45 | 46 | del train_splited, test_splited 47 | 48 | # initialize K cluster model 49 | cluster_models = [model() for i in range(K)] 50 | 51 | # initialize clustering and distribute 52 | server.weighted_clustering(nodes, list(range(num_nodes)), K, weight_type= 'equal') 53 | for j in range(K): 54 | assign_ls = [i for i in list(range(num_nodes)) if nodes[i].label==j] 55 | weight_ls = [nodes[i].data_size/sum([nodes[i].data_size for i in assign_ls]) for i in assign_ls] 56 | model_k = server.aggregate([nodes[i].model for i in assign_ls], weight_ls) 57 | server.distribute([nodes[i].model for i in assign_ls], model_k) 58 | cluster_models[j].load_state_dict(model_k) 59 | 60 | weight_list = [nodes[i].data_size/sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)] 61 | # train! 
62 | for i in range(global_rounds - warmup_rounds): 63 | print('-------------------Global round %d start-------------------' % (i)) 64 | # update model_g 65 | for j in range(num_nodes): 66 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_res, optimizer = nodes[j].optim['local_1'],\ 67 | model_opt = nodes[j].model_g, model_fix = nodes[j].model)) 68 | 69 | # update local model 70 | for j in range(num_nodes): 71 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_res, optimizer = nodes[j].optim['local_0'], \ 72 | model_opt = nodes[j].model, model_fix = nodes[j].model_g, reg_lam = reg_lam, reg_model = cluster_models[nodes[j].label])) 73 | 74 | # aggregate and distribute model_g 75 | weight_all = [nodes[i].data_size/sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)] 76 | server.model.load_state_dict(server.aggregate([nodes[i].model_g for i in range(num_nodes)], weight_all)) 77 | server.distribute([nodes[i].model_g for i in range(num_nodes)]) 78 | 79 | # server clustering 80 | server.weighted_clustering(nodes, list(range(num_nodes)), K, weight_type= 'equal') 81 | 82 | # server aggregation and distribution by cluster 83 | for j in range(K): 84 | assign_ls = [i for i in list(range(num_nodes)) if nodes[i].label==j] 85 | weight_ls = [nodes[i].data_size/sum([nodes[i].data_size for i in assign_ls]) for i in assign_ls] 86 | model_k = server.aggregate([nodes[i].model for i in assign_ls], weight_ls) 87 | server.distribute([nodes[i].model for i in assign_ls], model_k) 88 | cluster_models[j].load_state_dict(model_k) 89 | 90 | # test accuracy 91 | for j in range(num_nodes): 92 | nodes[j].local_test(model_res = nodes[j].model_g) 93 | server.acc(nodes, weight_list) 94 | 95 | if not finetune: 96 | assign = [[i for i in range(num_nodes) if nodes[i].label == k] for k in range(K)] 97 | # log 98 | log(os.path.basename(__file__)[:-3] + add_(K) + add_(reg_lam) + add_(split_para), nodes, server) 99 | return 
cluster_models, assign 100 | else: 101 | if not finetune_steps: 102 | finetune_steps = local_steps 103 | # fine tune 104 | # update model_g 105 | for j in range(num_nodes): 106 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_res, optimizer = nodes[j].optim['local_1'],\ 107 | model_opt = nodes[j].model_g, model_fix = nodes[j].model)) 108 | 109 | # update local model 110 | for j in range(num_nodes): 111 | nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_res, optimizer = nodes[j].optim['local_0'], \ 112 | model_opt = nodes[j].model, model_fix = nodes[j].model_g, reg_lam = reg_lam, reg_model = cluster_models[nodes[j].label])) 113 | nodes[j].local_test() 114 | server.acc(nodes, weight_list) 115 | # log 116 | log(os.path.basename(__file__)[:-3] + add_('finetune') + add_(K) + add_(reg_lam) + add_(split_para), nodes, server) 117 | return [nodes[i].model for i in range(num_nodes)] 118 | -------------------------------------------------------------------------------- /fedbase/baselines/fesem_con.py: -------------------------------------------------------------------------------- 1 | from fedbase.utils.data_loader import data_process, log 2 | from fedbase.utils.tools import add_ 3 | from fedbase.utils.visualize import dimension_reduction 4 | from fedbase.nodes.node import node 5 | from fedbase.server.server import server_class 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import torch.optim as optim 9 | from fedbase.model.model import CNNCifar, CNNMnist 10 | import os 11 | import sys 12 | import inspect 13 | from functools import partial 14 | import numpy as np 15 | 16 | def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, global_rounds, local_steps, warmup_rounds, tmp, mu, base, reg_lam = None, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 17 | # dt = data_process(dataset) 18 | # train_splited, test_splited = dt.split_dataset(num_nodes, 
def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, global_rounds, local_steps, warmup_rounds, tmp, mu, base, reg_lam=None, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    """FeSEM with contrastive local training.

    Each global round: every node runs contrastive local updates against the K
    cluster models, the server re-clusters all nodes, then aggregates and
    redistributes one model per cluster.

    Args:
        dataset_splited: (train_splited, test_splited, split_para) tuple.
        K: number of clusters.
        tmp / mu / base: contrastive temperature, loss weight, and base-model
            switch passed through to ``train_single_step_con``.
        reg_lam: optional proximal weight towards the running global model.

    Returns:
        (cluster_models, assign) where assign[k] lists node ids in cluster k.
    """
    train_splited, test_splited, split_para = dataset_splited
    server = server_class(device)
    server.assign_model(model())

    nodes = [node(i, device) for i in range(num_nodes)]

    for i in range(num_nodes):
        # data
        nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True))
        nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False))
        # model / objective / optimizer
        nodes[i].assign_model(model())
        nodes[i].assign_objective(objective())
        nodes[i].assign_optim(optimizer(nodes[i].model.parameters()))

    del train_splited, test_splited

    # initialize parameters to nodes
    server.distribute([nodes[i].model for i in range(num_nodes)])

    # initialize K cluster models
    cluster_models = [model().to(device) for _ in range(K)]

    # data sizes are fixed once loaders are assigned, so the aggregation
    # weights are loop-invariant — hoisted out of the round loop.
    total_size = sum(nodes[i].data_size for i in range(num_nodes))
    weight_list = [nodes[i].data_size / total_size for i in range(num_nodes)]

    # train!
    for i in range(global_rounds):
        print('-------------------Global round %d start-------------------' % (i))

        # maintain a running global model: used as the proximal target below
        server.model.load_state_dict(server.aggregate([nodes[i].model for i in range(num_nodes)], weight_list))

        # local update
        for j in range(num_nodes):
            if i == 0:
                # no cluster labels exist yet: plain local SGD
                nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step))
            else:
                # during warmup the contrastive base model is disabled (base=None)
                nodes[j].local_update_steps(local_steps, partial(
                    nodes[j].train_single_step_con,
                    model_sim=cluster_models[nodes[j].label], model_all=cluster_models,
                    tmp=tmp, mu=mu, base=None if i < warmup_rounds else base,
                    reg_lam=reg_lam, reg_model=server.model))

        # server clustering
        server.weighted_clustering(nodes, list(range(num_nodes)), K, weight_type='equal')

        # server aggregation and distribution by cluster
        for j in range(K):
            assign_ls = [i for i in list(range(num_nodes)) if nodes[i].label == j]
            weight_ls = [nodes[i].data_size / sum([nodes[i].data_size for i in assign_ls]) for i in assign_ls]
            model_k = server.aggregate([nodes[i].model for i in assign_ls], weight_ls)
            server.distribute([nodes[i].model for i in assign_ls], model_k)
            cluster_models[j].load_state_dict(model_k)

        # test accuracy
        for j in range(num_nodes):
            nodes[j].local_test()
        server.acc(nodes, weight_list)

    assign = [[i for i in range(num_nodes) if nodes[i].label == k] for k in range(K)]
    # log
    log(os.path.basename(__file__)[:-3] + add_(K) + add_(base) + add_(tmp) + add_(mu) + add_(reg_lam) + add_(split_para), nodes, server)

    return cluster_models, assign
def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, global_rounds, local_steps, reg=None, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), finetune=False, finetune_steps=None):
    """IFCA baseline.

    Each round every client picks the cluster model with the lowest local
    training loss, trains it locally (optionally with a FedProx term when
    ``reg`` is set), and the server aggregates per cluster.

    Returns:
        (cluster_models, assign) normally, or the per-node models after an
        extra fine-tuning pass when ``finetune`` is True.
    """
    train_splited, test_splited, split_para = dataset_splited
    server = server_class(device)
    server.assign_model(model())

    nodes = [node(i, device) for i in range(num_nodes)]

    for i in range(num_nodes):
        # data
        nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True))
        nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False))
        # objective (models/optimizers are (re)assigned per-round from the chosen cluster)
        nodes[i].assign_objective(objective())

    del train_splited, test_splited

    # initialize K cluster models
    cluster_models = [model() for _ in range(K)]

    # train!
    weight_list = [nodes[i].data_size / sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)]
    for t in range(global_rounds):
        print('-------------------Global round %d start-------------------' % (t))
        # assign each client to the cluster whose model yields the lowest train loss
        assignment = [[] for _ in range(K)]
        for i in range(num_nodes):
            m = 0
            for k in range(1, K):
                if nodes[i].local_train_loss(cluster_models[m]) >= nodes[i].local_train_loss(cluster_models[k]):
                    m = k
            assignment[m].append(i)
            nodes[i].label = m
            nodes[i].assign_model(cluster_models[m])
            nodes[i].assign_optim(optimizer(nodes[i].model.parameters()))
        server.clustering['label'].append(assignment)
        print(assignment)
        print([len(assignment[i]) for i in range(len(assignment))])

        # local update
        for j in range(num_nodes):
            if not reg:
                nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step))
            else:
                nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_fedprox, reg_model=server.aggregate(nodes, list(range(num_nodes))), lam=reg))

        # server aggregation and distribution by cluster (skip empty clusters)
        for k in range(K):
            if len(assignment[k]) > 0:
                weight_ls = [nodes[i].data_size / sum([nodes[i].data_size for i in assignment[k]]) for i in assignment[k]]
                model_k = server.aggregate([nodes[i].model for i in assignment[k]], weight_ls)
                server.distribute([nodes[i].model for i in assignment[k]], model_k)
                cluster_models[k].load_state_dict(model_k)

        # test accuracy
        for i in range(num_nodes):
            nodes[i].local_test()
        server.acc(nodes, weight_list)

    if not finetune:
        assign = [[i for i in range(num_nodes) if nodes[i].label == k] for k in range(K)]
        # log
        log(os.path.basename(__file__)[:-3] + add_(K) + add_(reg) + add_(split_para), nodes, server)
        return cluster_models, assign
    else:
        if not finetune_steps:
            finetune_steps = local_steps
        # fine tune — BUGFIX: this loop previously ran `local_steps`,
        # silently ignoring the `finetune_steps` argument.
        for j in range(num_nodes):
            if not reg:
                nodes[j].local_update_steps(finetune_steps, partial(nodes[j].train_single_step))
            else:
                nodes[j].local_update_steps(finetune_steps, partial(nodes[j].train_single_step_fedprox, reg_model=server.aggregate(nodes, list(range(num_nodes))), lam=reg))
            nodes[j].local_test()
        server.acc(nodes, weight_list)
        # log
        log(os.path.basename(__file__)[:-3] + add_('finetune') + add_(K) + add_(reg) + add_(split_para), nodes, server)
        return [nodes[i].model for i in range(num_nodes)]
def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, warmup_rounds, global_rounds, local_steps, \
    reg=None, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), finetune=False, finetune_steps=None):
    """IFCA with cluster-and-model (CAM) decomposition.

    Warms up a shared global model with FedAvg, then alternates residual
    training of a per-node cluster model (``model``) and a shared global
    component (``model_g``); cluster membership is chosen by best local
    training accuracy.

    Returns:
        (cluster_models, assign), or per-node models after fine-tuning
        when ``finetune`` is True.
    """
    train_splited, test_splited, split_para = dataset_splited
    # warm up a global model with plain FedAvg
    model_g = fedavg.run(dataset_splited, batch_size, num_nodes, model, objective, optimizer, warmup_rounds, local_steps, device, log_file=False)

    # initialize
    server = server_class(device)
    server.assign_model(model_g)

    nodes = [node(i, device) for i in range(num_nodes)]

    for i in range(num_nodes):
        # data
        nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True))
        nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False))
        # objective
        nodes[i].assign_objective(objective())
        # two model components per node: cluster-specific and global
        nodes[i].assign_model(model())
        nodes[i].model_g = model().to(device)
        nodes[i].assign_optim({'local_0': optimizer(nodes[i].model.parameters()),
                               'local_1': optimizer(nodes[i].model_g.parameters()),
                               'all': optimizer(list(nodes[i].model.parameters()) + list(nodes[i].model_g.parameters()))})

    del train_splited, test_splited

    # initialize K cluster models
    cluster_models = [model().to(device) for _ in range(K)]

    weight_list = [nodes[i].data_size / sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)]

    # distribute global model to model_g
    server.distribute([nodes[i].model_g for i in range(num_nodes)])

    # train!
    for t in range(global_rounds - warmup_rounds):
        print('-------------------Global round %d start-------------------' % (t))

        # local update model (cluster component), residual against model_g
        for j in range(num_nodes):
            nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_res, optimizer=nodes[j].optim['local_0'],
                                                             model_opt=nodes[j].model, model_fix=nodes[j].model_g))

        # local update model_g (global component), residual against model
        for j in range(num_nodes):
            nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_res, optimizer=nodes[j].optim['local_1'],
                                                             model_opt=nodes[j].model_g, model_fix=nodes[j].model))

        # server clustering: pick the cluster model with the best local train accuracy
        # (loop variable renamed from `i` to avoid shadowing the round counter)
        assignment = [[] for _ in range(K)]
        for n_i in range(num_nodes):
            m = 0
            for k in range(1, K):
                if nodes[n_i].local_train_acc(cluster_models[m]) <= nodes[n_i].local_train_acc(cluster_models[k]):
                    m = k
            assignment[m].append(n_i)
            nodes[n_i].label = m

        server.clustering['label'].append(assignment)
        print([len(assignment[i]) for i in range(len(assignment))])

        # server aggregation and distribution by cluster (skip empty clusters)
        for k in range(K):
            if len(assignment[k]) > 0:
                model_k = server.aggregate([nodes[i].model for i in assignment[k]],
                                           [nodes[i].data_size / sum([nodes[i].data_size for i in assignment[k]]) for i in assignment[k]])
                server.distribute([nodes[i].model for i in assignment[k]], model_k)
                cluster_models[k].load_state_dict(model_k)

        # aggregate and redistribute model_g
        server.model.load_state_dict(server.aggregate([nodes[i].model_g for i in range(num_nodes)], weight_list))
        server.distribute([nodes[i].model_g for i in range(num_nodes)])

        # test accuracy (cluster model + global residual)
        for j in range(num_nodes):
            nodes[j].local_test(model_res=nodes[j].model_g)
        server.acc(nodes, weight_list)

    if not finetune:
        assign = [[i for i in range(num_nodes) if nodes[i].label == k] for k in range(K)]
        # log
        log(os.path.basename(__file__)[:-3] + add_(K) + add_(reg) + add_(split_para), nodes, server)
        return cluster_models, assign
    else:
        if not finetune_steps:
            finetune_steps = local_steps
        # fine tune — BUGFIX: these loops previously ran `local_steps`,
        # silently ignoring the `finetune_steps` argument.
        for j in range(num_nodes):
            nodes[j].local_update_steps(finetune_steps, partial(nodes[j].train_single_step_res, optimizer=nodes[j].optim['local_0'],
                                                                model_opt=nodes[j].model, model_fix=nodes[j].model_g))
        for j in range(num_nodes):
            nodes[j].local_update_steps(finetune_steps, partial(nodes[j].train_single_step_res, optimizer=nodes[j].optim['local_1'],
                                                                model_opt=nodes[j].model_g, model_fix=nodes[j].model))
            nodes[j].local_test()
        server.acc(nodes, weight_list)
        # log
        log(os.path.basename(__file__)[:-3] + add_('finetune') + add_(K) + add_(reg) + add_(split_para), nodes, server)
        return [nodes[i].model for i in range(num_nodes)]
def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, global_rounds, local_steps, warmup_rounds, tmp, mu, base, reg_lam=None, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    """IFCA with contrastive local training.

    BUGFIX: the original used ``i`` both for the global-round loop and for the
    inner per-node cluster-assignment loop, so after the first inner loop the
    warmup checks ``if i == 0`` / ``elif i < warmup_rounds`` compared the last
    *node index* instead of the round number. The round variable is now ``t``.

    Returns:
        (cluster_models, assign) where assign[k] lists node ids in cluster k.
    """
    train_splited, test_splited, split_para = dataset_splited
    server = server_class(device)
    server.assign_model(model())

    nodes = [node(i, device) for i in range(num_nodes)]

    for i in range(num_nodes):
        # data
        nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True))
        nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False))
        # objective (models/optimizers are (re)assigned per-round from the chosen cluster)
        nodes[i].assign_objective(objective())

    del train_splited, test_splited

    # initialize K cluster models
    cluster_models = [model() for _ in range(K)]

    # train!
    for t in range(global_rounds):
        print('-------------------Global round %d start-------------------' % (t))
        # assign each client to the cluster whose model yields the lowest train loss
        assignment = [[] for _ in range(K)]
        for i in range(num_nodes):
            m = 0
            for k in range(1, K):
                if nodes[i].local_train_loss(cluster_models[m]) >= nodes[i].local_train_loss(cluster_models[k]):
                    m = k
            assignment[m].append(i)
            nodes[i].label = m
            nodes[i].assign_model(cluster_models[m])
            nodes[i].assign_optim(optimizer(nodes[i].model.parameters()))
        server.clustering['label'].append(assignment)
        print(assignment)
        print([len(assignment[i]) for i in range(len(assignment))])

        # local update: keep a running global model for the proximal term
        weight_list = [nodes[i].data_size / sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)]
        server.model.load_state_dict(server.aggregate([nodes[i].model for i in range(num_nodes)], weight_list))
        for j in range(num_nodes):
            if t == 0:
                # no meaningful cluster models yet: plain local SGD
                nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step))
            elif t < warmup_rounds:
                # warmup: contrastive base model disabled
                nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_con,
                    model_sim=cluster_models[nodes[j].label], model_all=cluster_models, tmp=tmp, mu=mu, base=None,
                    reg_lam=reg_lam, reg_model=server.model))
            else:
                nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_con,
                    model_sim=cluster_models[nodes[j].label], model_all=cluster_models, tmp=tmp, mu=mu, base=base,
                    reg_lam=reg_lam, reg_model=server.model))

        # server aggregation and distribution by cluster (skip empty clusters)
        for k in range(K):
            if len(assignment[k]) > 0:
                weight_ls = [nodes[i].data_size / sum([nodes[i].data_size for i in assignment[k]]) for i in assignment[k]]
                model_k = server.aggregate([nodes[i].model for i in assignment[k]], weight_ls)
                server.distribute([nodes[i].model for i in assignment[k]], model_k)
                cluster_models[k].load_state_dict(model_k)

        # test accuracy
        for i in range(num_nodes):
            nodes[i].local_test()
        server.acc(nodes, weight_list)

    assign = [[i for i in range(num_nodes) if nodes[i].label == k] for k in range(K)]
    # log
    log(os.path.basename(__file__)[:-3] + add_(K) + add_(base) + add_(tmp) + add_(mu) + add_(reg_lam) + add_(split_para), nodes, server)

    return cluster_models, assign
def run(dataset_splited, batch_size, num_nodes, model, objective, optimizer, global_rounds, local_steps, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), log_file=True):
    """Purely-local baseline: every node trains its own copy of the model.

    There is no aggregation between rounds; the server only records the
    weighted test accuracy each round and (optionally) writes the log.

    Returns:
        The list of per-node trained models.
    """
    train_splited, test_splited, split_para = dataset_splited

    server = server_class(device)
    server.assign_model(model())

    # build one node per data shard and wire up its loaders, model, loss, optimizer
    nodes = [node(idx, device) for idx in range(num_nodes)]
    for idx, nd in enumerate(nodes):
        nd.assign_train(DataLoader(train_splited[idx], batch_size=batch_size, shuffle=True))
        nd.assign_test(DataLoader(test_splited[idx], batch_size=batch_size, shuffle=False))
        nd.assign_model(model())
        nd.assign_objective(objective())
        nd.assign_optim(optimizer(nd.model.parameters()))

    del train_splited, test_splited

    # seed every node with the server's initial parameters
    server.distribute([nd.model for nd in nodes])

    # fixed data-size weights for the accuracy average
    total = sum(nd.data_size for nd in nodes)
    weight_all = [nd.data_size / total for nd in nodes]

    for rnd in range(global_rounds):
        print('-------------------Global round %d start-------------------' % (rnd))
        # single-processing: each node trains then evaluates on its own split
        for nd in nodes:
            nd.local_update_steps(local_steps, partial(nd.train_single_step))
            nd.local_test()
        server.acc(nodes, weight_all)

    if log_file:
        log(os.path.basename(__file__)[:-3] + add_(split_para), nodes, server)

    return [nd.model for nd in nodes]
def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, global_rounds, local_steps, \
    reg_lam=None, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), finetune=False, finetune_steps=None):
    """WeCFL baseline: weighted clustered federated learning.

    Each round: local updates (with an optional FedProx term towards the
    node's cluster model when ``reg_lam`` is set), server-side weighted
    clustering, then per-cluster aggregation and redistribution.

    Returns:
        (cluster_models, assign) normally, or the per-node models after an
        extra fine-tuning pass when ``finetune`` is True.
    """
    train_splited, test_splited, split_para = dataset_splited
    server = server_class(device)
    server.assign_model(model())

    nodes = [node(i, device) for i in range(num_nodes)]

    for i in range(num_nodes):
        # data
        nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True))
        nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False))
        # model / objective / optimizer
        nodes[i].assign_model(model())
        nodes[i].assign_objective(objective())
        nodes[i].assign_optim(optimizer(nodes[i].model.parameters()))

    del train_splited, test_splited

    # initialize parameters to nodes
    server.distribute([nodes[i].model for i in range(num_nodes)])
    weight_list = [nodes[i].data_size / sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)]

    # initialize K cluster models
    cluster_models = [model() for _ in range(K)]

    # train!
    for t in range(global_rounds):
        print('-------------------Global round %d start-------------------' % (t))
        # local update (no proximal term in round 0: labels don't exist yet)
        for j in range(num_nodes):
            if not reg_lam or t == 0:
                nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step))
            else:
                nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_fedprox, reg_model=cluster_models[nodes[j].label], reg_lam=reg_lam))
        # server clustering
        server.weighted_clustering(nodes, list(range(num_nodes)), K)

        # server aggregation and distribution by cluster
        for j in range(K):
            assign_ls = [i for i in list(range(num_nodes)) if nodes[i].label == j]
            weight_ls = [nodes[i].data_size / sum([nodes[i].data_size for i in assign_ls]) for i in assign_ls]
            model_k = server.aggregate([nodes[i].model for i in assign_ls], weight_ls)
            server.distribute([nodes[i].model for i in assign_ls], model_k)
            cluster_models[j].load_state_dict(model_k)

        # test accuracy
        for j in range(num_nodes):
            nodes[j].local_test()
        server.acc(nodes, weight_list)

    if not finetune:
        assign = [[i for i in range(num_nodes) if nodes[i].label == k] for k in range(K)]
        # log
        log(os.path.basename(__file__)[:-3] + add_(K) + add_(reg_lam) + add_(split_para), nodes, server)
        return cluster_models, assign
    else:
        if not finetune_steps:
            finetune_steps = local_steps
        # fine tune — BUGFIX: this loop previously ran `local_steps`,
        # silently ignoring the `finetune_steps` argument.
        for j in range(num_nodes):
            if not reg_lam:
                nodes[j].local_update_steps(finetune_steps, partial(nodes[j].train_single_step))
            else:
                nodes[j].local_update_steps(finetune_steps, partial(nodes[j].train_single_step_fedprox, reg_model=cluster_models[nodes[j].label], reg_lam=reg_lam))
            nodes[j].local_test()
        server.acc(nodes, weight_list)
        # log
        log(os.path.basename(__file__)[:-3] + add_('finetune') + add_(K) + add_(reg_lam) + add_(split_para), nodes, server)
        return [nodes[i].model for i in range(num_nodes)]
def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, warmup_rounds, global_rounds, local_steps, \
    reg_lam=None, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), finetune=False, finetune_steps=None):
    """WeCFL with cluster-and-model (CAM) decomposition.

    Warms up per-node models with purely-local training, then alternates
    residual training of a shared global component (``model_g``) and a
    cluster-specific component (``model``), re-clustering every round.

    BUGFIX: the ``log(...)`` calls referenced an undefined name ``reg``
    (the parameter is ``reg_lam``), which raised NameError at logging time;
    the fine-tune branch also ignored ``finetune_steps``.

    Returns:
        (cluster_models, assign), or per-node models after fine-tuning
        when ``finetune`` is True.
    """
    train_splited, test_splited, split_para = dataset_splited
    # warmup: purely-local training to seed the per-node cluster components
    local_models_warmup = local.run(dataset_splited, batch_size, num_nodes, model, objective, optimizer, warmup_rounds, local_steps, device=device, log_file=False)

    # initialize
    server = server_class(device)
    server.assign_model(model())
    server.model_g = model()

    nodes = [node(i, device) for i in range(num_nodes)]

    for i in range(num_nodes):
        # data
        nodes[i].assign_train(DataLoader(train_splited[i], batch_size=batch_size, shuffle=True))
        nodes[i].assign_test(DataLoader(test_splited[i], batch_size=batch_size, shuffle=False))
        # model: warmed-up cluster component plus a fresh global component
        nodes[i].assign_model(local_models_warmup[i])
        nodes[i].model_g = model()
        nodes[i].model_g.to(device)
        # objective
        nodes[i].assign_objective(objective())
        # optim: separate optimizers per component plus a joint one
        nodes[i].assign_optim({'local_0': optimizer(nodes[i].model.parameters()),
                               'local_1': optimizer(nodes[i].model_g.parameters()),
                               'all': optimizer(list(nodes[i].model.parameters()) + list(nodes[i].model_g.parameters()))})

    del train_splited, test_splited

    # initialize K cluster models
    cluster_models = [model() for _ in range(K)]

    # initialize clustering and distribute per-cluster aggregates
    server.weighted_clustering(nodes, list(range(num_nodes)), K)
    for j in range(K):
        assign_ls = [i for i in list(range(num_nodes)) if nodes[i].label == j]
        weight_ls = [nodes[i].data_size / sum([nodes[i].data_size for i in assign_ls]) for i in assign_ls]
        model_k = server.aggregate([nodes[i].model for i in assign_ls], weight_ls)
        server.distribute([nodes[i].model for i in assign_ls], model_k)
        cluster_models[j].load_state_dict(model_k)

    weight_list = [nodes[i].data_size / sum([nodes[i].data_size for i in range(num_nodes)]) for i in range(num_nodes)]
    # train!
    for t in range(global_rounds - warmup_rounds):
        print('-------------------Global round %d start-------------------' % (t))
        # update model_g (global component), residual against the cluster component
        for j in range(num_nodes):
            nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_res, optimizer=nodes[j].optim['local_1'],
                                                             model_opt=nodes[j].model_g, model_fix=nodes[j].model))

        # update local (cluster) model, proximal to its cluster aggregate
        for j in range(num_nodes):
            nodes[j].local_update_steps(local_steps, partial(nodes[j].train_single_step_res, optimizer=nodes[j].optim['local_0'],
                                                             model_opt=nodes[j].model, model_fix=nodes[j].model_g,
                                                             reg_lam=reg_lam, reg_model=cluster_models[nodes[j].label]))

        # aggregate and redistribute model_g
        server.model.load_state_dict(server.aggregate([nodes[i].model_g for i in range(num_nodes)], weight_list))
        server.distribute([nodes[i].model_g for i in range(num_nodes)])

        # server clustering
        server.weighted_clustering(nodes, list(range(num_nodes)), K)

        # server aggregation and distribution by cluster
        for j in range(K):
            assign_ls = [i for i in list(range(num_nodes)) if nodes[i].label == j]
            weight_ls = [nodes[i].data_size / sum([nodes[i].data_size for i in assign_ls]) for i in assign_ls]
            model_k = server.aggregate([nodes[i].model for i in assign_ls], weight_ls)
            server.distribute([nodes[i].model for i in assign_ls], model_k)
            cluster_models[j].load_state_dict(model_k)

        # test accuracy (cluster model + global residual)
        for j in range(num_nodes):
            nodes[j].local_test(model_res=nodes[j].model_g)
        server.acc(nodes, weight_list)

    if not finetune:
        assign = [[i for i in range(num_nodes) if nodes[i].label == k] for k in range(K)]
        # log (fixed: was `add_(reg)` — NameError)
        log(os.path.basename(__file__)[:-3] + add_(K) + add_(reg_lam) + add_(split_para), nodes, server)
        return cluster_models, assign
    else:
        if not finetune_steps:
            finetune_steps = local_steps
        # fine tune (fixed: previously ran `local_steps`, ignoring `finetune_steps`)
        for j in range(num_nodes):
            nodes[j].local_update_steps(finetune_steps, partial(nodes[j].train_single_step_res, optimizer=nodes[j].optim['local_1'],
                                                                model_opt=nodes[j].model_g, model_fix=nodes[j].model))
        for j in range(num_nodes):
            nodes[j].local_update_steps(finetune_steps, partial(nodes[j].train_single_step_res, optimizer=nodes[j].optim['local_0'],
                                                                model_opt=nodes[j].model, model_fix=nodes[j].model_g,
                                                                reg_lam=reg_lam, reg_model=cluster_models[nodes[j].label]))
            nodes[j].local_test()
        server.acc(nodes, weight_list)
        # log (fixed: was `add_(reg)` — NameError)
        log(os.path.basename(__file__)[:-3] + add_('finetune') + add_(K) + add_(reg_lam) + add_(split_para), nodes, server)
        return [nodes[i].model for i in range(num_nodes)]
def run(dataset_splited, batch_size, K, num_nodes, model, objective, optimizer, global_rounds, local_steps, warmup_rounds, tmp, mu, base, reg_lam = None, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    """Run the WeCFL-con clustered federated learning baseline.

    Trains `num_nodes` clients that are grouped into `K` clusters by weighted
    clustering each round; local steps use a model-contrastive objective
    against the cluster models (disabled during warmup rounds).

    Returns:
        (cluster_models, assign): the K trained cluster models and, for each
        cluster, the list of node indices assigned to it after the last round.
    """
    train_splited, test_splited, split_para = dataset_splited

    server = server_class(device)
    server.assign_model(model())

    # build and provision every client node
    nodes = [node(idx, device) for idx in range(num_nodes)]
    for idx, nd in enumerate(nodes):
        nd.assign_train(DataLoader(train_splited[idx], batch_size=batch_size, shuffle=True))
        nd.assign_test(DataLoader(test_splited[idx], batch_size=batch_size, shuffle=False))
        nd.assign_model(model())
        nd.assign_objective(objective())
        nd.assign_optim(optimizer(nd.model.parameters()))
    del train_splited, test_splited

    # push the initial global weights to every client
    server.distribute([nd.model for nd in nodes])

    # one model per cluster
    cluster_models = [model().to(device) for _ in range(K)]

    # data-size-proportional aggregation weights (fixed across rounds)
    total_size = sum(nd.data_size for nd in nodes)
    weight_list = [nd.data_size / total_size for nd in nodes]

    for rnd in range(global_rounds):
        print('-------------------Global round %d start-------------------' % (rnd))

        # refresh the global model used as the regularization target
        server.model.load_state_dict(server.aggregate([nd.model for nd in nodes], weight_list))

        # local updates; plain SGD on round 0 (labels do not exist yet),
        # contrastive updates afterwards (base disabled during warmup)
        for nd in nodes:
            if rnd == 0:
                nd.local_update_steps(local_steps, partial(nd.train_single_step))
            else:
                nd.local_update_steps(local_steps, partial(
                    nd.train_single_step_con,
                    model_sim=cluster_models[nd.label],
                    model_all=cluster_models,
                    tmp=tmp,
                    mu=mu,
                    base=None if rnd < warmup_rounds else base,
                    reg_lam=reg_lam,
                    reg_model=server.model))

        # re-cluster the nodes on the server
        server.weighted_clustering(nodes, list(range(num_nodes)), K)

        # per-cluster aggregation and redistribution
        for j in range(K):
            members = [idx for idx in range(num_nodes) if nodes[idx].label == j]
            cluster_size = sum(nodes[idx].data_size for idx in members)
            member_weights = [nodes[idx].data_size / cluster_size for idx in members]
            model_k = server.aggregate([nodes[idx].model for idx in members], member_weights)
            server.distribute([nodes[idx].model for idx in members], model_k)
            cluster_models[j].load_state_dict(model_k)

        # evaluate every client and record the weighted accuracy
        for nd in nodes:
            nd.local_test()
        server.acc(nodes, weight_list)

    assign = [[idx for idx in range(num_nodes) if nodes[idx].label == k] for k in range(K)]
    # log
    log(os.path.basename(__file__)[:-3] + add_(K) + add_(base) + add_(tmp) + add_(mu) + add_(reg_lam) + add_(split_para), nodes, server)

    return cluster_models, assign
class CNNMnist(nn.Module):
    """Small two-conv CNN classifier for 1x28x28 MNIST-style inputs.

    Outputs raw logits over 10 classes (no softmax applied; see the
    commented-out log_softmax in the original file).
    """

    def __init__(self):
        super(CNNMnist, self).__init__()
        # conv stack: 1 -> 10 -> 20 channels, 5x5 kernels
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # classifier head: 20 channels * 4 * 4 spatial = 320 features
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        out = F.relu(F.max_pool2d(self.conv1(x), 2))
        out = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(out)), 2))
        out = out.view(-1, out.shape[1] * out.shape[2] * out.shape[3])
        out = F.relu(self.fc1(out))
        out = F.dropout(out, training=self.training)
        # return F.log_softmax(out, dim=1)
        return self.fc2(out)
class CNNCifar(nn.Module):
    """LeNet-style CNN for 3x32x32 CIFAR-10 inputs; emits 10 raw logits."""

    def __init__(self):
        super(CNNCifar, self).__init__()
        # feature extractor: two 5x5 convs, each followed by 2x2 max-pooling
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # classifier head on the 16 * 5 * 5 flattened feature map
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        # x = x.view(x.size(0), -1)
        out = torch.flatten(out, 1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)
class Lenet(nn.Module):
    """LeNet-style CNN for 3x32x32 inputs.

    Unlike the other CIFAR models in this file, forward returns a pair
    (log_probs, hidden): log-softmax class scores plus the fc1 activation
    (the hidden representation, used e.g. for similarity computations).
    """

    def __init__(self, args=None):
        """Build the network.

        Args:
            args: accepted for signature compatibility with the other
                models in this file (e.g. CNNFemnist) but never used.
                FIX: made optional (was a required-but-unused positional
                parameter), so `Lenet()` now works; existing callers that
                pass `args` are unaffected.
        """
        super(Lenet, self).__init__()
        self.n_cls = 10
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 5 * 5, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, self.n_cls)

    def forward(self, x):
        """Return (log-softmax scores (batch, 10), fc1 hidden (batch, 384))."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x1 = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x1))
        x = self.fc3(x)

        return F.log_softmax(x, dim=1), x1
class BasicBlock(nn.Module):
    """Standard two-conv ResNet basic block (expansion factor 1).

    Computes relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)), where the
    shortcut is the identity or the supplied `downsample` module.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (and downsample, when given) perform the spatial downsampling
        # whenever stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """ResNet backbone adapted for CIFAR-sized (32x32) inputs.

    The stem uses a 3x3 stride-1 convolution instead of the ImageNet
    7x7 stride-2 one; everything else follows the torchvision layout.
    """

    def __init__(
        self,
        block,
        layers,
        num_classes=10,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ):
        super(ResNet, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each entry says whether the corresponding stage trades its
            # 2x2 stride for a dilated convolution
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group

        # CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3 -> 1
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
        )
        # END

        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch so every block
        # starts as an identity mapping; improves accuracy by ~0.2-0.3%
        # according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack `blocks` residual blocks; only the first one downsamples
        (or dilates, when `dilate` is set)."""
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # projection shortcut to match spatial size / channel count
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        stage = [
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        ]
        self.inplanes = planes * block.expansion
        stage.extend(
            block(
                self.inplanes,
                planes,
                groups=self.groups,
                base_width=self.base_width,
                dilation=self.dilation,
                norm_layer=norm_layer,
            )
            for _ in range(1, blocks)
        )

        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        out = self.layer4(self.layer3(self.layer2(self.layer1(out))))

        out = self.avgpool(out)
        out = out.reshape(out.size(0), -1)
        return self.fc(out)
297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | """ 301 | return _resnet( 302 | "resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, **kwargs 303 | ) -------------------------------------------------------------------------------- /fedbase/nodes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsmjie/FedBase/5c386c8e5b64591f435f6821b339f8aa2d867db9/fedbase/nodes/__init__.py -------------------------------------------------------------------------------- /fedbase/nodes/node.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import linalg as LA 3 | from fedbase.utils.model_utils import save_checkpoint, load_checkpoint 4 | from fedbase.model.model import CNNMnist, MLP 5 | from sklearn.metrics import accuracy_score, f1_score 6 | from fedbase.utils.tools import unpack_args 7 | from functools import partial 8 | from statistics import mode 9 | import torch.nn.functional as F 10 | 11 | class node(): 12 | def __init__(self, id, device): 13 | self.id = id 14 | self.test_metrics = [] 15 | self.step = 0 16 | self.device = device 17 | self.grads = [] 18 | # self.train_steps = 0 19 | 20 | def assign_train(self, data): 21 | self.train = data 22 | self.data_size = len(data.dataset) 23 | 24 | def assign_test(self,data): 25 | self.test = data 26 | 27 | def assign_model(self, model): 28 | try: 29 | self.model.load_state_dict(model.state_dict()) 30 | except: 31 | self.model = model 32 | self.model.to(self.device) 33 | # try: 34 | # self.model = torch.compile(self.model) 35 | # except: 36 | # pass 37 | 38 | def assign_objective(self, objective): 39 | self.objective = objective 40 | 41 | def assign_optim(self, optim): 42 | self.optim = optim 43 | 44 | def local_update_steps(self, local_steps, 
train_single_step_func): 45 | # print(len(self.train), self.step) 46 | if len(self.train) - self.step > local_steps: 47 | for k, (inputs, labels) in enumerate(self.train): 48 | if k < self.step or k >= self.step + local_steps: 49 | continue 50 | train_single_step_func(inputs, labels) 51 | self.step = self.step + local_steps 52 | else: 53 | for k, (inputs, labels) in enumerate(self.train): 54 | if k < self.step: 55 | continue 56 | train_single_step_func(inputs, labels) 57 | for j in range((local_steps-len(self.train)+self.step)//len(self.train)): 58 | for k, (inputs, labels) in enumerate(self.train): 59 | train_single_step_func(inputs, labels) 60 | for k, (inputs, labels) in enumerate(self.train): 61 | if k >=(local_steps-len(self.train)+self.step)%len(self.train): 62 | continue 63 | train_single_step_func(inputs, labels) 64 | self.step = (local_steps-len(self.train)+self.step)%len(self.train) 65 | # torch.cuda.empty_cache() 66 | 67 | def local_update_epochs(self, local_epochs, train_single_step_func): 68 | # local_steps may be better!! 
69 | running_loss = 0 70 | for j in range(local_epochs): 71 | for k, (inputs, labels) in enumerate(self.train): 72 | train_single_step_func(inputs, labels) 73 | # torch.cuda.empty_cache() 74 | 75 | def train_single_step(self, inputs, labels): 76 | inputs = inputs.to(self.device) 77 | labels = torch.flatten(labels) 78 | labels = labels.to(self.device, dtype = torch.long) 79 | # print(labels) 80 | # zero the parameter gradients 81 | # self.model.zero_grad(set_to_none=True) 82 | self.optim.zero_grad() 83 | # forward + backward + optimize 84 | outputs = self.model(inputs) 85 | # optim 86 | loss = self.objective(outputs, F.one_hot(labels, outputs.shape[1]).float()) 87 | loss.backward() 88 | 89 | # calculate accumulate gradients 90 | # grads = torch.tensor([]) 91 | # for index, param in enumerate(self.model.parameters()): 92 | # # param.grad = torch.tensor(grads[index]) 93 | # grads= torch.cat((grads, torch.flatten(param.grad).cpu()),0) 94 | # self.grads.append(grads) 95 | 96 | self.optim.step() 97 | # self.train_steps+=1 98 | 99 | # for fedprox and ditto 100 | def train_single_step_fedprox(self, inputs, labels, reg_lam = None, reg_model = None): 101 | inputs = inputs.to(self.device) 102 | labels = torch.flatten(labels) 103 | labels = labels.to(self.device, dtype = torch.long) 104 | # zero the parameter gradients 105 | # self.model.zero_grad(set_to_none=True) 106 | self.optim.zero_grad() 107 | # forward + backward + optimize 108 | outputs = self.model(inputs) 109 | # optim 110 | if reg_lam: 111 | reg_model.to(self.device) 112 | reg = torch.square(torch.norm(torch.cat(tuple([torch.flatten(self.model.state_dict()[k] - reg_model.state_dict()[k])\ 113 | for k in self.model.state_dict().keys()]),0),2)) 114 | else: 115 | reg, reg_lam = 0, 0 116 | self.loss = self.objective(outputs, labels) + reg_lam*reg/2 117 | # print(self.objective(outputs, labels)) 118 | self.loss.backward() 119 | self.optim.step() 120 | # print('after', self.objective(self.model(inputs), labels)) 121 | 122 
| def train_single_step_res(self, inputs, labels, optimizer, model_opt, model_fix, reg_lam = None, reg_model = None): 123 | inputs = inputs.to(self.device) 124 | labels = torch.flatten(labels) 125 | labels = labels.to(self.device, dtype = torch.long) 126 | # zero the parameter gradients 127 | # model_opt.zero_grad(set_to_none=True) 128 | # model_fix.zero_grad(set_to_none=True) 129 | optimizer.zero_grad() 130 | # model_2.zero_grad(set_to_none=True) 131 | # forward + backward + optimize 132 | # loss 1 133 | # m = torch.nn.LogSoftmax(dim=1) 134 | # ls = torch.nn.NLLLoss() 135 | # outputs = (m(model_opt(inputs)) + m(model_fix(inputs)))/2 136 | # loss = ls(outputs, labels) 137 | 138 | # loss 2 139 | # reg = 0 140 | # for p,q in zip(model_opt.parameters(), model_fix.parameters()): 141 | # reg += torch.norm((p-q),2) 142 | 143 | outputs = model_opt(inputs) + model_fix(inputs) 144 | 145 | if reg_lam: 146 | reg_model.to(self.device) 147 | reg = torch.square(torch.norm(torch.cat(tuple([torch.flatten(model_opt.state_dict()[k] - reg_model.state_dict()[k])\ 148 | for k in model_opt.state_dict().keys()]),0),2)) 149 | else: 150 | reg, reg_lam = 0, 0 151 | 152 | loss = self.objective(outputs, labels) + reg_lam/2 * reg 153 | # loss 3 154 | # reg = 0 155 | # for p,q in zip(model_opt.parameters(), model_fix.parameters()): 156 | # reg += torch.square(torch.norm((p-q),2)) 157 | # # print(reg) 158 | # outputs = model_opt(inputs) 159 | # loss = self.objective(outputs, labels) + 0.001*reg 160 | # # optim 161 | # outputs = torch.norm(model_opt(inputs) - model_fix(inputs), p = 2) 162 | # self.loss = outputs 163 | loss.backward() 164 | optimizer.step() 165 | 166 | def train_single_step_con(self, inputs, labels, model_sim, model_all, tmp, mu, base = 'representation', reg_lam = None, reg_model = None): 167 | inputs = inputs.to(self.device) 168 | labels = torch.flatten(labels) 169 | labels = labels.to(self.device, dtype = torch.long) 170 | # zero the parameter gradients 171 | # 
model_opt.zero_grad(set_to_none=True) 172 | self.optim.zero_grad() 173 | # forward + backward + optimize 174 | 175 | # contrastive loss 176 | output_con_dn = 0 177 | if base == 'representation': 178 | # for i in model_all: 179 | # i.to(self.device) 180 | # output_con_dn += torch.exp(F.cosine_similarity(self.intermediate_output(inputs, self.model, self.model.conv2, 'conv2')\ 181 | # , self.intermediate_output(inputs, i, i.conv2, 'conv2'), dim = -1)/tmp) 182 | # output_con_n = torch.exp(F.cosine_similarity(self.intermediate_output(inputs, self.model, self.model.conv2, 'conv2')\ 183 | # , self.intermediate_output(inputs, model_sim, model_sim.conv2, 'conv2'), dim = -1)/tmp) 184 | for i in model_all: 185 | output_con_dn += torch.exp(F.cosine_similarity(self.model(inputs), i(inputs))/tmp) 186 | output_con_n = torch.exp(F.cosine_similarity(self.model(inputs), model_sim(inputs))/tmp) 187 | con_loss = torch.mean(-torch.log(output_con_n/output_con_dn)) 188 | 189 | elif base == 'parameter': 190 | negative = [torch.cat(tuple([torch.flatten(i.state_dict()[k]) for k in i.state_dict().keys() if 'fc' in k]),0) for i in model_all] 191 | positive = torch.cat(tuple([torch.flatten(self.model.state_dict()[k]) for k in self.model.state_dict().keys() if 'fc' in k]),0) 192 | 193 | for i in range(len(model_all)): 194 | tmp = torch.exp(F.cosine_similarity(positive, negative[i], dim=0)/tmp) 195 | output_con_dn += tmp 196 | if i == self.label: 197 | output_con_n = tmp 198 | con_loss = -torch.log(output_con_n/output_con_dn) 199 | else: 200 | con_loss = 0 201 | 202 | # knowledge sharing 203 | if reg_lam: 204 | reg = torch.square(torch.norm(torch.cat(tuple([torch.flatten(self.model.state_dict()[k] - reg_model.state_dict()[k])\ 205 | for k in self.model.state_dict().keys() if 'fc' not in k]),0),2)) 206 | else: 207 | reg, reg_lam = 0, 0 208 | 209 | loss = self.objective(self.model(inputs), labels) + con_loss * mu + reg_lam * reg 210 | # if self.id == 0: 211 | # # 
print(self.model.state_dict()['fc1.bias']) 212 | # # print(self.label) 213 | # # for i in range(len(model_all)): 214 | # # print(i, model_all[i].state_dict()['fc1.bias']) 215 | # print(self.intermediate_output(inputs, self.model, self.model.conv2, 'conv2').shape) 216 | # # print(self.intermediate_output(inputs, model_sim, model_sim.conv2, 'conv2')) 217 | # print(output_con_n, output_con_dn) 218 | 219 | loss.backward() 220 | self.optim.step() 221 | 222 | def intermediate_output(self, inputs, model, model_layer, layer_name): 223 | activation = {} 224 | def get_activation(name): 225 | def hook(model, input, output): 226 | activation[name] = output.detach() 227 | return hook 228 | 229 | model_layer.register_forward_hook(get_activation(layer_name)) 230 | out = model(inputs) 231 | return torch.flatten(activation[layer_name],1) 232 | 233 | # for IFCA 234 | def local_train_loss(self, model): 235 | model.to(self.device) 236 | train_loss = 0 237 | i = 0 238 | with torch.no_grad(): 239 | for data in self.train: 240 | inputs, labels = data 241 | inputs = inputs.to(self.device) 242 | labels = torch.flatten(labels) 243 | labels = labels.to(self.device, dtype = torch.long) 244 | # forward 245 | outputs = model(inputs) 246 | train_loss += self.objective(outputs, labels) 247 | i+=1 248 | if i>=10: 249 | break 250 | # return train_loss/len(self.train) 251 | return train_loss/i 252 | 253 | def local_train_acc(self, model): 254 | model.to(self.device) 255 | predict_ts = torch.empty(0).to(self.device) 256 | label_ts = torch.empty(0).to(self.device) 257 | i = 0 258 | with torch.no_grad(): 259 | for data in self.train: 260 | inputs, labels = data 261 | inputs = inputs.to(self.device) 262 | labels = torch.flatten(labels) 263 | labels = labels.to(self.device, dtype = torch.long) 264 | outputs = model(inputs) 265 | _, predicted = torch.max(outputs.data, 1) 266 | predict_ts = torch.cat([predict_ts, predicted], 0) 267 | label_ts = torch.cat([label_ts, labels], 0) 268 | i+=1 269 | if i>=10: 
270 | break 271 | acc = accuracy_score(label_ts.cpu(), predict_ts.cpu()) 272 | return acc 273 | 274 | def local_test(self, model_res = None): 275 | predict_ts = torch.empty(0).to(self.device) 276 | label_ts = torch.empty(0).to(self.device) 277 | with torch.no_grad(): 278 | for data in self.test: 279 | inputs, labels = data 280 | inputs = inputs.to(self.device) 281 | labels = torch.flatten(labels) 282 | labels = labels.to(self.device, dtype = torch.long) 283 | if model_res: 284 | model_res.to(self.device) 285 | outputs = model_res(inputs) + self.model(inputs) 286 | else: 287 | outputs = self.model(inputs) 288 | # print(outputs.data.dtype) 289 | _, predicted = torch.max(outputs.data, 1) 290 | predict_ts = torch.cat([predict_ts, predicted], 0) 291 | label_ts = torch.cat([label_ts, labels], 0) 292 | acc = accuracy_score(label_ts.cpu(), predict_ts.cpu()) 293 | macro_f1 = f1_score(label_ts.cpu(), predict_ts.cpu(), average='macro') 294 | # micro_f1 = f1_score(label_ts.cpu(), predict_ts.cpu(), average='micro') 295 | # print('Accuracy, Macro F1, Micro F1 of Device %d on the %d test cases: %.2f %%, %.2f, %.2f' % (self.id, len(label_ts), acc*100, macro_f1, micro_f1)) 296 | print('Accuracy, Macro F1 of Device %d on the %d test cases: %.2f %%, %.2f %%' % (self.id, len(label_ts), acc*100, macro_f1*100)) 297 | self.test_metrics.append([acc, macro_f1]) 298 | 299 | 300 | def local_ensemble_test(self, model_list, voting = 'soft'): 301 | predict_ts = torch.empty(0).to(self.device) 302 | label_ts = torch.empty(0).to(self.device) 303 | with torch.no_grad(): 304 | for data in self.test: 305 | inputs, labels = data 306 | inputs = inputs.to(self.device) 307 | labels = torch.flatten(labels) 308 | labels = labels.to(self.device, dtype = torch.long) 309 | out_hard = [] 310 | if voting == 'soft': 311 | out = torch.zeros(self.model(inputs).data.shape).to(self.device) 312 | for model in model_list: 313 | outputs = model(inputs) 314 | out = out + outputs.data/len(model_list) 315 | _, predicted = 
torch.max(out, 1) 316 | elif voting == 'hard': 317 | out_hard = [] 318 | for model in model_list: 319 | outputs = model(inputs) 320 | _, predicted = torch.max(outputs.data, 1) 321 | out_hard.append(predicted) 322 | predicted = torch.tensor([mode([out_hard[i][j] for i in range(len(out_hard))]) for j in range(len(out_hard[0]))]).to(self.device) 323 | 324 | predict_ts = torch.cat([predict_ts, predicted], 0) 325 | label_ts = torch.cat([label_ts, labels], 0) 326 | acc = accuracy_score(label_ts.cpu(), predict_ts.cpu()) 327 | macro_f1 = f1_score(label_ts.cpu(), predict_ts.cpu(), average='macro') 328 | # micro_f1 = f1_score(label_ts.cpu(), predict_ts.cpu(), average='micro') 329 | # print('Accuracy, Macro F1, Micro F1 of Device %d on the %d test cases: %.2f %%, %.2f, %.2f' % (self.id, len(label_ts), acc*100, macro_f1, micro_f1)) 330 | print('Accuracy, Macro F1 of Device %d on the %d test cases: %.2f %%, %.2f %%' % (self.id, len(label_ts), acc*100, macro_f1*100)) 331 | self.test_metrics.append([acc, macro_f1]) 332 | 333 | 334 | # def model_representation(self, test_set, repr='output'): 335 | # self.model_repr = self.model(test_set)/len(test_set) 336 | # return self.model_repr 337 | 338 | # def log(self): 339 | # pass 340 | 341 | 342 | # class nodes_control(node): 343 | # def __init__(self, id_list): 344 | # pass 345 | 346 | # def assign_one_model(self, model): 347 | # pass 348 | -------------------------------------------------------------------------------- /fedbase/server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsmjie/FedBase/5c386c8e5b64591f435f6821b339f8aa2d867db9/fedbase/server/__init__.py -------------------------------------------------------------------------------- /fedbase/server/server.py: -------------------------------------------------------------------------------- 1 | # from nodes.node import node 2 | import torch 3 | from sklearn.cluster import KMeans 4 | import numpy as 
class server_class():
    """Central server of the federated-learning loop.

    Holds the global model, aggregates client models, distributes weights,
    and records global test metrics and the clustering history.
    """

    def __init__(self, device):
        # device: torch.device that server-side models/tensors are moved to
        self.device = device
        self.test_metrics = []  # per-round [accuracy, macro_f1]
        self.clustering = {'label': [], 'raw': [], 'center': []}

    def assign_model(self, model):
        """Adopt `model` as the global model (copy weights into the existing one if any)."""
        try:
            self.model.load_state_dict(model.state_dict())
        except AttributeError:
            # First assignment: no global model yet. (Was a bare `except`,
            # which also swallowed genuine state-dict mismatches.)
            self.model = model
        self.model.to(self.device)

    # NOTE: a dead `aggregate(node_id_list, model_list, weight_list)` overload
    # used to precede this definition; Python only ever kept the later one,
    # so the shadowed duplicate has been removed.
    def aggregate(self, model_list, weight_list):
        """Return the weighted average of the models' state dicts.

        Args:
            model_list: models whose parameters are averaged.
            weight_list: one weight per model; assumed to sum to 1.
        """
        aggregated_weights = model_list[0].state_dict()
        for key in aggregated_weights.keys():
            aggregated_weights[key] = torch.zeros(aggregated_weights[key].shape).to(self.device)
        for idx in range(len(model_list)):
            state = model_list[idx].state_dict()
            for key in state.keys():
                aggregated_weights[key] += state[key] * weight_list[idx]
        return aggregated_weights

    def distribute(self, model_in_list, model_dis_dict = None):
        """Load `model_dis_dict` (default: the global model's weights) into every model."""
        if not model_dis_dict:
            model_dis_dict = self.model.state_dict()
        for m in model_in_list:
            m.load_state_dict(model_dis_dict)

    def acc(self, nodes, weight_list):
        """Aggregate the nodes' latest local [acc, macro_f1] into weighted global metrics."""
        global_test_metrics = [0]*2
        for i in range(len(weight_list)):
            for j in range(len(global_test_metrics)):
                global_test_metrics[j] += weight_list[i]*nodes[i].test_metrics[-1][j]
        print('GLOBAL Accuracy, Macro F1 is %.2f %%, %.2f %%' % (100*global_test_metrics[0], 100*global_test_metrics[1]))
        self.test_metrics.append(global_test_metrics)

    def client_sampling(self, frac, distribution):
        # Placeholder: client sampling is not implemented yet.
        pass

    def test(self, test_loader):
        """Print top-1 accuracy of the global model on `test_loader`."""
        correct = 0
        total = 0
        with torch.no_grad():
            for data in test_loader:
                inputs, labels = data
                inputs = inputs.to(self.device)
                labels = torch.flatten(labels)
                labels = labels.to(self.device, dtype = torch.long)
                outputs = self.model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print('Accuracy on the %d test cases: %.2f %%' % (total, 100*correct / total))

    def model_similarity(self, model_repr_1, model_repr_2, repr='output'):
        # Fixed: was defined inside the class without `self`.
        # NOTE(review): `log` is undefined in this module; this helper looks
        # unused/unfinished -- confirm intent before calling it.
        if repr == 'output':
            self.similarity = (log(model_repr_1)-log(model_repr_2)).sum(axis=1).abs()

    def weighted_clustering(self, nodes, idlist, K, weight_type='data_size'):
        """K-means-cluster the listed nodes by their second-to-last state-dict tensor.

        Args:
            nodes: all client nodes.
            idlist: indices of the nodes to cluster.
            K: number of clusters.
            weight_type: 'equal' or 'data_size' sample weighting for K-means.

        Side effects: writes nodes[i].label and appends the label list to
        self.clustering['label'].
        """
        weight = []
        X = []
        sum_size = sum([nodes[i].data_size for i in idlist])
        for i in idlist:
            if weight_type == 'equal':
                weight.append(1/len(idlist))
            elif weight_type == 'data_size':
                weight.append(nodes[i].data_size/sum_size)
            # feature: flattened second-to-last tensor (typically the fc weight)
            X.append(np.array(torch.flatten(nodes[i].model.state_dict()[list(nodes[i].model.state_dict().keys())[-2]]).cpu()))
        kmeans = KMeans(n_clusters=K, n_init = 5).fit(np.asarray(X), sample_weight= weight)
        labels = kmeans.labels_
        print(labels)
        print([list(labels).count(i) for i in range(K)])
        for i in idlist:
            nodes[i].label = labels[i]
        self.clustering['label'].append(list(labels.astype(int)))

    def calculate_B(self, nodes, idlist):
        """Gradient-dissimilarity statistics vs the weighted average gradient.

        Returns:
            (B_list, u_list): per-node relative and absolute L2 deviation of
            the node's summed gradient from the data-size-weighted average.
        """
        sum_size = sum([nodes[i].data_size for i in idlist])
        avg = sum([sum(nodes[i].grads)*(nodes[i].data_size)/sum_size for i in idlist])
        B_list = []
        u_list = []
        for i in idlist:
            u_list.append(float(torch.norm((sum(nodes[i].grads) - avg), 2)))
            B_list.append(float(torch.norm((sum(nodes[i].grads) - avg), 2)/torch.norm(avg, 2)))
        return B_list, u_list

    def clustering_plot(self):
        """Parallel-coordinates plot of the cluster-label history.

        Fixed: the original appended to / indexed `self.clustering` directly,
        but `self.clustering` is a dict, so `.append(...)` raised
        AttributeError. Plots the recorded per-round label lists instead.
        """
        history = [list(l) for l in self.clustering['label']]
        col = [str(i) for i in range(len(history))] + ['id']
        history.append(list(range(len(history[0]))))
        data = pd.DataFrame(np.array(history).T, columns=col)
        for c in data.columns:
            data[c] = data[c].apply(lambda x: str(x))
        parallel_coordinates(data, 'id', colormap=plt.get_cmap("Set2"))
        plt.show()
def pairwise_angles(sources):
    """Pairwise cosine similarities between flattened client updates, as a numpy matrix."""
    # NOTE(review): `flatten` is defined elsewhere in the package; it is
    # assumed to map a client update (dW) to a 1-D tensor -- confirm.
    count = len(sources)
    angles = torch.zeros([count, count])
    for i in range(count):
        for j in range(count):
            u = flatten(sources[i])
            v = flatten(sources[j])
            angles[i, j] = torch.sum(u * v) / (torch.norm(u) * torch.norm(v) + 1e-12)
    return angles.numpy()

def compute_pairwise_similarities(clients):
    """Similarity matrix of the clients' weight deltas (client.dW)."""
    return pairwise_angles([client.dW for client in clients])

def cluster_clients(S):
    """Bi-partition the clients by complete-linkage clustering on -S.

    Returns:
        Index arrays of the two resulting clusters.
    """
    clustering = AgglomerativeClustering(affinity="precomputed", linkage="complete").fit(-S)
    first = np.argwhere(clustering.labels_ == 0).flatten()
    second = np.argwhere(clustering.labels_ == 1).flatten()
    return first, second

def compute_max_update_norm(self, cluster):
    # NOTE(review): `self` is unused -- these helpers look copied out of a
    # class; the parameter is kept so existing call sites stay valid.
    return np.max([torch.norm(flatten(client.dW)).item() for client in cluster])


def compute_mean_update_norm(self, cluster):
    # NOTE(review): `self` is unused; kept for call-site compatibility.
    stacked = torch.stack([flatten(client.dW) for client in cluster])
    return torch.norm(torch.mean(stacked, dim=0)).item()
(0.3081,))]) 29 | self.train_dataset = datasets.MNIST( 30 | dir+dataset_name, train=True, download=True, transform=apply_transform) 31 | self.test_dataset = datasets.MNIST( 32 | dir+dataset_name, train=False, download=True, transform=apply_transform) 33 | elif dataset_name == 'cifar10': 34 | transform = transforms.Compose( 35 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 36 | self.train_dataset = datasets.CIFAR10( 37 | dir+dataset_name, train=True, download=True, transform=transform) 38 | self.test_dataset = datasets.CIFAR10( 39 | dir+dataset_name, train=False, download=True, transform=transform) 40 | elif dataset_name == 'femnist': 41 | apply_transform = transforms.Compose( 42 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 43 | self.train_dataset = femnist.FEMNIST(dir+dataset_name, train=True, download=False, 44 | transform=apply_transform) 45 | self.test_dataset = femnist.FEMNIST(dir+dataset_name, train=False, download=False, 46 | transform=apply_transform) 47 | elif dataset_name == 'fashion_mnist': 48 | apply_transform = transforms.Compose([ 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.5,), (0.5,))]) 51 | self.train_dataset = datasets.FashionMNIST( 52 | dir+dataset_name, train=True, download=True, transform=apply_transform) 53 | self.test_dataset = datasets.FashionMNIST( 54 | dir+dataset_name, train=False, download=True, transform=apply_transform) 55 | elif 'medmnist' in dataset_name: 56 | data_flag = dataset_name[9:] 57 | DataClass = getattr(medmnist, medmnist.INFO[data_flag]['python_class']) 58 | # preprocessing 59 | data_transform = transforms.Compose([ 60 | transforms.ToTensor(), 61 | transforms.Normalize(mean=[.5], std=[.5]) 62 | ]) 63 | # load the data 64 | dir = dir + dataset_name +'/' 65 | if not os.path.exists(dir): 66 | os.mkdir(dir) 67 | self.train_dataset = DataClass(split='train', transform=data_transform, download=True, root = dir) 68 | # self.train_dataset.labels = 
torch.tensor(self.train_dataset.labels, dtype = torch.long) 69 | self.val_dataset = DataClass(split='val', transform=data_transform, download=True, root = dir) 70 | self.test_dataset = DataClass(split='test', transform=data_transform, download=True, root = dir) 71 | 72 | self.test_dataset = ConcatDataset([self.val_dataset, self.test_dataset]) 73 | # print(len(self.val_dataset), len(self.test_dataset)) 74 | # print(self.train_dataset) 75 | 76 | sample = next(iter(self.train_dataset)) 77 | image, label = sample 78 | print(image.shape) 79 | 80 | # show image 81 | # batch_size = 4 82 | # trainloader = DataLoader(self.train_dataset, batch_size=batch_size, 83 | # shuffle=True, num_workers=2) 84 | # classes = ('plane', 'car', 'bird', 'cat', 85 | # 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 86 | # def imshow(img): 87 | # img = img / 2 + 0.5 # unnormalize 88 | # npimg = img.numpy() 89 | # plt.imshow(np.transpose(npimg, (1, 2, 0))) 90 | # plt.show() 91 | 92 | # # get some random training images 93 | # dataiter = iter(trainloader) 94 | # images, labels = dataiter.next() 95 | 96 | # # show images 97 | # imshow(torchvision.utils.make_grid(images)) 98 | # # print labels 99 | # print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size))) 100 | 101 | def split_dataset(self, num_nodes, alpha, method='dirichlet', train_dataset = None, test_dataset = None, plot_show = False): 102 | train_dataset = self.train_dataset if train_dataset is None else train_dataset 103 | test_dataset = self.test_dataset if test_dataset is None else test_dataset 104 | train_targets, test_targets = get_targets(train_dataset), get_targets(test_dataset) 105 | if num_nodes == 1: 106 | return train_dataset, test_dataset 107 | else: 108 | if method == 'iid': 109 | train_lens_list = [int(len(train_dataset)/num_nodes) for i in range(num_nodes)] 110 | test_lens_list = [int(len(test_dataset)/num_nodes) for i in range(num_nodes)] 111 | train_splited, test_splited = random_split(Subset(train_dataset, 
torch.arange(sum(train_lens_list))), train_lens_list), random_split(Subset(test_dataset, torch.arange(sum(test_lens_list))), test_lens_list) 112 | # plot 113 | labels = torch.unique(train_targets) 114 | self.plot_split(labels, train_splited) 115 | return train_splited, test_splited 116 | else: 117 | labels, train_label_size = torch.unique(train_targets, return_counts=True) 118 | _, test_label_size = torch.unique(test_targets, return_counts=True) 119 | # print(train_label_size, test_label_size) 120 | l_train = train_label_size.reshape( 121 | len(labels), 1).repeat(1, num_nodes) 122 | l_test = test_label_size.reshape( 123 | len(labels), 1).repeat(1, num_nodes) 124 | 125 | train_splited = [] 126 | test_splited = [] 127 | while len(test_splited) <= num_nodes//2: 128 | # print(l_test) 129 | if method == 'dirichlet': 130 | # print(len(test_dataset), min(test_label_size)) 131 | # dec_round = round(math.log(len(test_dataset)/len(labels),10)) 132 | dec_round = 2 133 | # p = torch.tensor(np.round(np.random.dirichlet(np.repeat(alpha, num_nodes), len(labels)), round(math.log(len(test_dataset)/len(labels),10)))) 134 | p = torch.tensor(np.floor(np.random.dirichlet(np.repeat(alpha, num_nodes), len(labels))*10**dec_round)/10**dec_round) 135 | # print(torch.sum(p,axis=1)) 136 | # print(p) 137 | elif method == 'class': 138 | p = np.zeros((len(labels), 1)) 139 | J = np.random.choice(len(labels), alpha, replace=False) 140 | p[J] = 1 141 | for k in range(1, num_nodes): 142 | x = np.zeros((len(labels), 1)) 143 | J = np.random.choice(len(labels), alpha, replace=False) 144 | x[J] = 1 145 | p = np.concatenate((p, x), axis=1) 146 | p = p / np.repeat((p.sum(axis=1)+10**-10).reshape(len(labels), 1), num_nodes, axis=1) 147 | # print(p.sum(axis=1),p) 148 | train_size = torch.round(l_train*p).int() 149 | test_size = torch.round(l_test*p).int() 150 | # print(train_size, test_size) 151 | train_label_index = [] 152 | test_label_index = [] 153 | for j in range(len(labels)): 154 | 
train_label_index.append([(train_targets== labels[j]).nonzero(as_tuple=True)[ 155 | 0][offset-length:offset] for offset, length in zip(_accumulate(train_size[j, :]), train_size[j, :])]) 156 | test_label_index.append([(test_targets== labels[j]).nonzero(as_tuple=True)[ 157 | 0][offset-length:offset] for offset, length in zip(_accumulate(test_size[j, :]), test_size[j, :])]) 158 | # how to deal with 0? 159 | for i in range(num_nodes): 160 | if len(ConcatDataset([Subset(test_dataset, test_label_index[j][i]) for j in range(len(labels))]))>5: # 0-10, to control the minimun length 161 | train_splited.append(ConcatDataset( 162 | [Subset(train_dataset, train_label_index[j][i]) for j in range(len(labels))])) 163 | test_splited.append(ConcatDataset( 164 | [Subset(test_dataset, test_label_index[j][i]) for j in range(len(labels))])) 165 | while len(test_splited)`_ Dataset. 28 | 29 | Args: 30 | root (string): Root directory of dataset where ``processed/training.pt`` 31 | and ``processed/test.pt`` exist. 32 | train (bool, optional): If True, creates dataset from ``training.pt``, 33 | otherwise from ``test.pt``. 34 | download (bool, optional): If true, downloads the dataset from the internet and 35 | puts it in root directory. If dataset is already downloaded, it is not 36 | downloaded again. 37 | transform (callable, optional): A function/transform that takes in an PIL image 38 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 39 | target_transform (callable, optional): A function/transform that takes in the 40 | target and transforms it. 
41 | """ 42 | 43 | training_file = 'training.pt' 44 | test_file = 'test.pt' 45 | classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', 46 | '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] 47 | 48 | @property 49 | def train_labels(self): 50 | warnings.warn("train_labels has been renamed targets") 51 | return self.targets 52 | 53 | @property 54 | def test_labels(self): 55 | warnings.warn("test_labels has been renamed targets") 56 | return self.targets 57 | 58 | @property 59 | def train_data(self): 60 | warnings.warn("train_data has been renamed data") 61 | return self.data 62 | 63 | @property 64 | def test_data(self): 65 | warnings.warn("test_data has been renamed data") 66 | return self.data 67 | 68 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 69 | self.root = os.path.expanduser(root) 70 | self.transform = transform 71 | self.target_transform = target_transform 72 | self.train = train # training set or test set 73 | 74 | # s_list = random.sample(range(0, 7), num_users) 75 | if self.train: 76 | # data_file = self.training_file 77 | self.data, self.targets = self.generate_ds( self.root) 78 | # self.loader = self.generate_ds(args, self.root) 79 | else: 80 | # data_file = self.test_file 81 | self.data, self.targets = self.generate_ds_test(self.root) 82 | # self.loader = self.generate_ds_test(args, self.root) 83 | # self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) 84 | 85 | 86 | def __getitem__(self, index): 87 | """ 88 | Args: 89 | index (int): Index 90 | 91 | Returns: 92 | tuple: (image, target) where target is index of the target class. 
93 | """ 94 | img, target = self.data[index], int(self.targets[index]) 95 | 96 | # doing this so that it is consistent with all other datasets 97 | # to return a PIL Image 98 | img = Image.open(img).convert('L') 99 | # loader = transforms.Compose([transforms.ToTensor()]) 100 | # img = loader(img).unsqueeze(0)[0, 0, :, :] 101 | 102 | if self.transform is not None: 103 | img = self.transform(img) 104 | 105 | if self.target_transform is not None: 106 | target = self.target_transform(target) 107 | 108 | return img, target 109 | 110 | def __len__(self): 111 | return len(self.data) 112 | 113 | @property 114 | def raw_folder(self): 115 | return os.path.join(self.root, self.__class__.__name__, 'raw') 116 | 117 | @property 118 | def processed_folder(self): 119 | return os.path.join(self.root, 'data', 'processed') 120 | 121 | @property 122 | def class_to_idx(self): 123 | return {_class: i for i, _class in enumerate(self.classes)} 124 | 125 | def generate_ds(self, args, root): 126 | # read 100 images per class per style 127 | num_class = args.num_classes 128 | num_img = args.train_shots_max * args.num_users 129 | 130 | data = [] 131 | targets = torch.zeros([num_class * num_img]) 132 | # files = os.listdir(os.path.join(root, 'data', 'raw_data', 'by_class')) 133 | files = os.listdir(os.path.join(root, 'by_class')) 134 | 135 | for i in range(num_class): 136 | for k in range(num_img): 137 | # img = os.path.join(root, 'data', 'raw_data', 'by_class', files[i], 'train_' + files[i], 'train_' + files[i] + '_'+str("%05d"%k)+'.png') 138 | img = os.path.join(root, 'by_class', files[i], 'train_' + files[i],'train_' + files[i] + '_' + str("%05d" % k) + '.png') 139 | data.append(img) 140 | targets[i * num_img + k] = i 141 | 142 | targets = targets.reshape([num_class * num_img]) 143 | 144 | return data, targets 145 | 146 | def generate_ds_test(self, args, root): 147 | # read 100 images per classes per style 148 | 149 | num_class = args.num_classes 150 | # num_style = args.num_styles 151 | 
def save_checkpoint(model_structure, model_parameter, optimizer_parameter, path):
    """Persist model structure, weights and optimizer state as one checkpoint file."""
    torch.save({'model_structure': model_structure,
                'model_parameter': model_parameter,
                'optimizer_parameter': optimizer_parameter}, path)

def load_checkpoint(path):
    """Load a checkpoint written by save_checkpoint.

    Returns:
        (model, optimizer_parameter): the stored model structure with its
        saved weights loaded, plus the saved optimizer state.
    """
    checkpoint = torch.load(path)
    model = checkpoint['model_structure']
    model.load_state_dict(checkpoint['model_parameter'])
    return model, checkpoint['optimizer_parameter']


# KL divergence variant log p/q,abs(log p-log q)
def blackbox_mc(model, testset):
    """Average the model's per-class outputs over `testset` (black-box summary)."""
    outputs = model(testset)
    return outputs.sum(axis=1) / len(testset)

# model similarity
def similarity(model_1, model_2):
    # NOTE(review): `log` is undefined in this module and blackbox_mc requires
    # a testset argument -- this function cannot run as written; confirm intent.
    return abs(log(blackbox_mc(model_1)) - log(blackbox_mc(model_2)))
def unpack_args(func):
    """Decorator: let `func` be called with one dict (expanded as **kwargs)
    or one iterable (expanded as *args) -- handy for multiprocessing maps."""
    from functools import wraps
    @wraps(func)
    def wrapper(args):
        if isinstance(args, dict):
            return func(**args)
        else:
            return func(*args)
    return wrapper

def find_files(dir, *args):
    """Print the file paths matching the glob pattern `dir`."""
    files = glob.glob(dir)
    print(files)

def get_targets(dataset):
    """Get the targets of a dataset without any target transforms(!).

    Recurses through Subset/ConcatDataset wrappers; falls back to the
    `labels` attribute for datasets (e.g. SVHN, medmnist) that use it
    instead of `targets`. Always returns a tensor.
    """
    if isinstance(dataset, data.Subset):
        targets = get_targets(dataset.dataset)
        return torch.as_tensor(targets)[dataset.indices]
    if isinstance(dataset, data.ConcatDataset):
        return torch.cat([get_targets(sub_dataset) for sub_dataset in dataset.datasets])
    try:
        targets = dataset.targets
    except AttributeError:
        # Was a bare `except:` that hid unrelated errors; only the missing
        # attribute is expected here.
        targets = dataset.labels
    if torch.is_tensor(targets):
        return targets
    return torch.as_tensor(targets)

def add_(input_str):
    """Return '_<input_str>', or '' when input_str is None (log-name suffix helper)."""
    return f'_{str(input_str)}' if input_str is not None else ''
def dimension_reduction(data, label, method):
    """Project `data` to 2-D with t-SNE or PCA and scatter-plot it, colored by `label`.

    Args:
        data: array-like of shape (n_samples, n_features).
        label: per-sample labels used as the hue.
        method: 'tsne' or 'pca'.

    Raises:
        ValueError: for any other `method` (previously this fell through to an
        UnboundLocalError on `z`).
    """
    x = data
    y = np.array(label)

    print(x.shape)
    print(y.shape)

    if method == 'tsne':
        tsne = TSNE(n_components=2, verbose=1, random_state=123)
        z = tsne.fit_transform(x)
    elif method == 'pca':
        pca = decomposition.PCA(n_components=2)
        pca.fit(x)
        z = pca.transform(x)
    else:
        raise ValueError(f"unknown method: {method!r}; expected 'tsne' or 'pca'")

    df = pd.DataFrame()
    df["y"] = y
    df["v1"] = z[:, 0]
    df["v2"] = z[:, 1]

    sns.set_theme(style="darkgrid")
    sns.scatterplot(x="v1", y="v2", hue=df.y.tolist(),
                    palette=sns.color_palette("hls", len(set(y))),
                    data=df, s=12)
    plt.legend().remove()
    plt.xlabel('v1', fontsize=16)
    plt.ylabel('v2', fontsize=16)
    plt.show()
os.chdir(os.path.dirname(os.path.abspath(__file__)))  # set the current path as the working directory

# Shared experiment hyper-parameters for all baseline runs below.
global_rounds = 3
num_nodes = 200
local_steps = 10
batch_size = 32
optimizer = partial(optim.SGD, lr=0.001)
device = torch.device('cuda')  # Use GPU if available


@unpack_args
def main0(seeds, dataset_splited, model):
    """Run the centralized-training baseline for one random seed."""
    np.random.seed(seeds)
    central.run(dataset_splited, batch_size, model, nn.CrossEntropyLoss,
                optimizer, global_rounds, device=device)


@unpack_args
def main1(seeds, dataset_splited, model):
    """Run the non-clustered federated baselines (FedAvg, FedProx) for one seed."""
    np.random.seed(seeds)
    fedavg.run(dataset_splited, batch_size, num_nodes, model, nn.CrossEntropyLoss,
               optimizer, global_rounds, local_steps, device=device)
    fedprox.run(dataset_splited, batch_size, num_nodes, model, nn.CrossEntropyLoss,
                optimizer, global_rounds, local_steps, 0.1, device=device)


@unpack_args
def main2(seeds, dataset_splited, model, K):
    """Run the K-cluster ensemble baselines for one seed."""
    np.random.seed(seeds)
    fedavg_ensemble.run(dataset_splited, batch_size, num_nodes, model, nn.CrossEntropyLoss,
                        optimizer, global_rounds, local_steps, K, device=device)
    fedprox_ensemble.run(dataset_splited, batch_size, num_nodes, model, nn.CrossEntropyLoss,
                         optimizer, global_rounds, local_steps, K, device=device)
data_process('fashion_mnist').split_dataset_groupwise(10,3,'class',20,2,'class', plot_show=True) 65 | # data_process('medmnist_pathmnist').split_dataset_groupwise(10,0.1,'dirichlet',20,10,'dirichlet', plot_show=True) 66 | # data_process('medmnist_octmnist').split_dataset_groupwise(10,0.1,'dirichlet',20,10,'dirichlet', plot_show=True) 67 | # data_process('medmnist_tissuemnist').split_dataset_groupwise(10,3,'class',20,2,'class', plot_show=True) 68 | # data_process('fashion_mnist').split_dataset(200,0.1,'dirichlet', plot_show= True) 69 | # data_process('fashion_mnist').split_dataset(200,2,'class', plot_show= True) 70 | # print(a) 71 | # # data_process('fashion_mnist').split_dataset_groupwise(10,3,'class',20, 2,'class', plot_show=True) 72 | # # data_process('fashion_mnist').split_dataset(18,0.1,'dirichlet', plot_show= True) 73 | # # data_process('cifar10').split_dataset(200,2,'class', plot_show= True) 74 | # # data_process('medmnist_octmnist').split_dataset(200,2,'class', plot_show= True) 75 | # print(a) 76 | # data_process('medmnist_pathmnist').split_dataset(200,2,'class', plot_show= True) 77 | # ditto.run(data_process('fashion_mnist').split_dataset_groupwise(5,6,'class',10,5,'class'), 16, 10, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, 3, 10, 0.95) 78 | # fedprox.run(data_process('fashion_mnist').split_dataset_groupwise(10,6,'class',20,5,'class'), batch_size, num_nodes, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps, 1) 79 | # fedprox_ensemble.run(data_process('fashion_mnist').split_dataset_groupwise(10,6,'class',20,5,'class'), 16,10, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps, 1, 3) 80 | # fedavg.run(data_process('fashion_mnist').split_dataset_groupwise(5,6,'class',10,5,'class'), 16, 10, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, 3, 10,finetune=True) 81 | # fedavg_ensemble.run(data_process('fashion_mnist').split_dataset_groupwise(5,6,'class',10,5,'class'), 16,10, CNNFashion_Mnist, 
nn.CrossEntropyLoss, optimizer, 2, 10, 3) 82 | # ifca.run(data_process('cifar10').split_dataset_groupwise(10,0.1,'dirichlet',20,10,'dirichlet'), batch_size, 10, num_nodes, CNNCifar, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps) 83 | # wecfl.run(data_process('medmnist_tissuemnist').split_dataset_groupwise(10,0.1,'dirichlet',20,10,'dirichlet'), batch_size, 10, num_nodes, CNNTissue, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps, reg=0) 84 | # fesem.run(data_process('cifar10').split_dataset_groupwise(10, 0.1, 'dirichlet', 20, 10, 'dirichlet'), batch_size, 10, num_nodes, CNNCifar, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps, reg_lam=0.001, finetune=True) 85 | # wecfl.run(data_process('cifar10').split_dataset_groupwise(5, 3, 'class', 40, 2, 'class'), batch_size, 5, num_nodes, CNNCifar, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps) 86 | # cfl_res.run(data_process('cifar10').split_dataset_groupwise(5, 3, 'class', 40, 2, 'class'), batch_size, 5, num_nodes, CNNCifar, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps) 87 | fesem_cam.run(data_process('fashion_mnist').split_dataset(200, 2, 'class'), batch_size, 5, num_nodes, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, 2, global_rounds, local_steps, finetune =True, reg_lam=0.01) 88 | # wecfl_con.run(data_process('fashion_mnist').split_dataset(200, 2, 'class'), batch_size, 5, num_nodes, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps, warmup_rounds = 1, tmp = 0.1, mu =1, base = 'parameter', reg_lam = 0.01) 89 | # fesem_con.run(data_process('fashion_mnist').split_dataset(200, 2, 'class'), batch_size, 5, num_nodes, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps, warmup_rounds = 1, tmp = 0.1, mu =1, base = 'representation', reg_lam = 0.01) 90 | # wecfl_con.run(data_process('fashion_mnist').split_dataset(200, 2, 'class'), batch_size, 5, num_nodes, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, 
global_rounds, local_steps, warmup_rounds = 1, tmp = 0.1, mu =10, base = 'parameter', reg_lam = 0.01) 91 | # ifca_con.run(data_process('fashion_mnist').split_dataset(200, 2, 'class'), batch_size, 5, num_nodes, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps, warmup_rounds = 1, tmp = 0.1, mu =1, base = 'parameter', reg_lam = 0.01) 92 | # wecfl.run(data_process('cifar10').split_dataset(200, 0.1, 'dirichlet'), batch_size, 5, num_nodes, CNNCifar, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps) 93 | # wecfl_res.run(data_process('cifar10').split_dataset_groupwise(5, 3, 'class', 40, 2, 'class'), batch_size, 5, num_nodes, CNNCifar, nn.CrossEntropyLoss, optimizer, 2, global_rounds, local_steps) 94 | # ifca.run(data_process('fashion_mnist').split_dataset_groupwise(10, 0.1, 'dirichlet', 20, 5, 'dirichlet'), batch_size, 10, num_nodes, CNNFashion_Mnist, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps) 95 | # ifca.run(data_process('cifar10').split_dataset_groupwise(10, 0.1, 'dirichlet', 20, 5, 'dirichlet'), batch_size, 10, num_nodes, CNNCifar, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps) 96 | # ifca.run(data_process('medmnist_octmnist').split_dataset_groupwise(10, 0.1, 'dirichlet', 20, 5, 'dirichlet', plot_show= True), batch_size, 10, num_nodes, oct_net, nn.CrossEntropyLoss, optimizer, global_rounds, local_steps) 97 | print(a) 98 | multi_processes = 2 99 | seeds = 1 100 | # Run 101 | start = time.perf_counter() 102 | mp.set_start_method('spawn') 103 | with mp.Pool(multi_processes) as p: 104 | # group_wise 105 | # p.map(main4, [(i, data_process(dataset).split_dataset_groupwise(n0,j0,k0,n1,j1,k1), model) for i in range(27, 27+seeds) for dataset, model in zip(['cifar10', 'fashion_mnist'],[CNNCifar, CNNFashion_Mnist]) \ 106 | # for n0,n1 in zip([5, 10],[40, 20]) for j0, k0, j1, k1 in zip([6, 0.1], ['class', 'dirichlet'], [5, 10], ['class', 'dirichlet'])]) 107 | # p.map(main5, [(i, 
data_process(dataset).split_dataset_groupwise(n0,j0,k0,n1,j1,k1), model, K) for i in range(27, 27+seeds) for dataset, model in zip(['cifar10', 'fashion_mnist'],[CNNCifar, CNNFashion_Mnist]) \ 108 | # for K,n0,n1 in zip([5, 10], [5, 10],[40, 20]) for j0, k0, j1, k1 in zip([6, 0.1], ['class', 'dirichlet'], [5, 10], ['class', 'dirichlet'])]) 109 | # client_wise 110 | # p.map(main1, [(i, data_process(dataset).split_dataset(num_nodes, j, k), model) for i in range(27, 27+seeds) for dataset, model in zip(['cifar10', 'fashion_mnist'],[CNNCifar, CNNFashion_Mnist]) for j, k in zip([2, 0.1], ['class', 'dirichlet'])]) 111 | p.map(main2, [(i, data_process(dataset).split_dataset(num_nodes, j, k), model, K) for i in range(27, 27+seeds) for dataset, model in zip(['medmnist_octmnist'],[oct_net]) for j, k in zip([2, 0.1], ['class', 'dirichlet']) for K in [3,5,10]]) 112 | p.close() 113 | print(time.perf_counter()-start, "seconds") --------------------------------------------------------------------------------