├── .gitignore ├── LICENSE ├── README.md ├── fedlab_benchmarks ├── README.md ├── __init__.py ├── cfl │ ├── README.md │ ├── __init__.py │ ├── cfl_trainer.py │ ├── datasets.py │ ├── helper.py │ └── standalone.py ├── compressors │ ├── __init__.py │ ├── base_compressor.py │ ├── compress_example │ │ ├── __init__.py │ │ ├── client.py │ │ ├── quick_start.sh │ │ └── server.py │ ├── dgc.py │ ├── qsgd.py │ └── topk.py ├── datasets │ ├── README.md │ ├── __init__.py │ ├── adult │ │ ├── __init__.py │ │ ├── adult.py │ │ └── adult_tutorial.ipynb │ ├── celeba │ │ ├── README.md │ │ ├── preprocess.sh │ │ ├── preprocess │ │ │ └── metadata_to_json.py │ │ └── stats.sh │ ├── cifar10 │ │ └── cifar10_partitioner.ipynb │ ├── cifar100 │ │ ├── data_partitioner.ipynb │ │ ├── imgs │ │ │ ├── cifar100_balance_dir_alpha_0.3_100clients.png │ │ │ ├── cifar100_balance_iid_100clients.png │ │ │ ├── cifar100_hetero_dir_0.3_100clients.png │ │ │ ├── cifar100_hetero_dir_0.3_100clients_dist.png │ │ │ ├── cifar100_shards_200_100clients.png │ │ │ ├── cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png │ │ │ ├── cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png │ │ │ ├── cifar100_unbalance_iid_unbalance_sgm_0.3_100clients.png │ │ │ └── cifar100_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png │ │ └── partition-reports │ │ │ ├── cifar100_balance_dir_alpha_0.3_100clients.csv │ │ │ ├── cifar100_balance_iid_100clients.csv │ │ │ ├── cifar100_hetero_dir_0.3_100clients.csv │ │ │ ├── cifar100_shards_200_100clients.csv │ │ │ ├── cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.csv │ │ │ └── cifar100_unbalance_iid_unbalance_sgm_0.3_100clients.csv │ ├── covtype │ │ ├── __init__.py │ │ ├── covtype.py │ │ └── covtype_tutorial.ipynb │ ├── fcube │ │ ├── __init__.py │ │ ├── fcube.py │ │ └── fcube_tutorial.ipynb │ ├── femnist │ │ ├── README.md │ │ ├── preprocess.sh │ │ ├── preprocess │ │ │ ├── data_to_json.py │ │ │ ├── data_to_json.sh │ │ │ ├── get_data.sh │ │ │ ├── get_file_dirs.py │ │ │ ├── get_hashes.py │ │ │ ├── group_by_writer.py │ │ │ └── match_hashes.py │ │ └── stats.sh │ ├── fmnist │ │ └── fmnist_tutorial.ipynb │ ├── imgs │ │ ├── adult_iid_10clients.png │ │ ├── adult_noniid-label1_10clients.png │ │ ├── adult_noniid_labeldir_10clients.png │ │ ├── adult_unbalance_10clients.png │ │ ├── cifar10_balance_dir_alpha_0.3_100clients.png │ │ ├── cifar10_balance_iid_100clients.png │ │ ├── cifar10_hetero_dir_0.3_100clients.png │ │ ├── cifar10_hetero_dir_0.3_100clients_dist.png │ │ ├── cifar10_shards_200_100clients.png │ │ ├── cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png │ │ ├── cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png │ │ ├── cifar10_unbalance_iid_unbalance_sgm_0.3_100clients.png │ │ ├── cifar10_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png │ │ ├── fcube_class_dist.png │ │ ├── fcube_iid.png │ │ ├── fcube_iid_part.png │ │ ├── fcube_synthetic.png │ │ ├── fcube_synthetic_original_paper.png │ │ ├── fcube_synthetic_part.png │ │ ├── fcube_test_dist_vis.png │ │ ├── fcube_train_dist_vis.png │ │ ├── fmnist_feature_skew_vis.png │ │ ├── fmnist_iid_clients_10.png │ │ ├── fmnist_noniid-label_1_clients_10.png │ │ ├── fmnist_noniid-label_2_clients_10.png │ │ ├── fmnist_noniid-label_3_clients_10.png │ │ ├── fmnist_noniid_labeldir_clients_10.png │ │ ├── fmnist_unbalance_clients_10.png │ │ ├── fmnist_vis.png │ │ ├── svhn_feature_skew_vis.png │ │ └── svhn_vis.png │ ├── mnist │ │ ├── download_mnist.py │ │ ├── mnist_iid.pkl │ │ ├── mnist_noniid.pkl │ │ └── 
mnist_partition.py │ ├── partition-reports │ │ ├── adult_iid_10clients.csv │ │ ├── adult_noniid-label1_10clients.csv │ │ ├── adult_noniid_labeldir_10clients.csv │ │ ├── adult_unbalance_10clients.csv │ │ ├── cifar10_balance_dir_alpha_0.3_100clients.csv │ │ ├── cifar10_balance_iid_100clients.csv │ │ ├── cifar10_hetero_dir_0.3_100clients.csv │ │ ├── cifar10_shards_200_100clients.csv │ │ ├── cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.csv │ │ ├── cifar10_unbalance_iid_unbalance_sgm_0.3_100clients.csv │ │ ├── fcube_iid.csv │ │ ├── fcube_synthetic.csv │ │ ├── fmnist_iid_clients_10.csv │ │ ├── fmnist_noniid-label_1_clients_10.csv │ │ ├── fmnist_noniid-label_2_clients_10.csv │ │ ├── fmnist_noniid-label_3_clients_10.csv │ │ ├── fmnist_noniid_labeldir_clients_10.csv │ │ └── fmnist_unbalance_clients_10.csv │ ├── rcv1 │ │ ├── __init__.py │ │ ├── rcv1.py │ │ └── rcv1_tutorial.ipynb │ ├── reddit │ │ ├── README.md │ │ ├── build_vocab.py │ │ └── source │ │ │ ├── clean_raw.py │ │ │ ├── delete_small_users.py │ │ │ ├── get_json.py │ │ │ ├── get_raw_users.py │ │ │ ├── merge_raw_users.py │ │ │ ├── preprocess.py │ │ │ ├── reddit_utils.py │ │ │ └── run_reddit.sh │ ├── sent140 │ │ ├── README.md │ │ ├── preprocess.sh │ │ ├── preprocess │ │ │ ├── combine_data.py │ │ │ ├── data_to_json.py │ │ │ ├── data_to_json.sh │ │ │ └── get_data.sh │ │ └── stats.sh │ ├── shakespeare │ │ ├── README.md │ │ ├── preprocess.sh │ │ ├── preprocess │ │ │ ├── data_to_json.sh │ │ │ ├── gen_all_data.py │ │ │ ├── get_data.sh │ │ │ ├── preprocess_shakespeare.py │ │ │ └── shake_utils.py │ │ └── stats.sh │ ├── svhn │ │ └── svhn_tutorial.ipynb │ ├── synthetic │ │ ├── README.md │ │ ├── data_generator.py │ │ ├── main.py │ │ ├── preprocess.sh │ │ └── stats.sh │ └── utils │ │ ├── README.md │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── preprocess.sh │ │ ├── remove_users.py │ │ ├── sample.py │ │ ├── split_data.py │ │ ├── stats.py │ │ └── util.py ├── feature-skew-fedavg │ ├── README.md │ ├── __init__.py │ ├── client.py │ ├── config.py │ ├── data_partition.py │ ├── models.py │ ├── server.py │ ├── start_clt.sh │ └── start_server.sh ├── fedasync │ ├── README.md │ ├── __init__.py │ ├── cross_process │ │ ├── client.py │ │ ├── quick_start.sh │ │ └── server.py │ └── standalone │ │ ├── __init__.py │ │ ├── cifar10_iid.pkl │ │ ├── cifar10_noniid.pkl │ │ └── standalone.py ├── fedavg_v1.1.2 │ ├── README.md │ ├── __init__.py │ ├── cross_process │ │ ├── LEAF_test.sh │ │ ├── __init__.py │ │ ├── client.py │ │ ├── quick_start.sh │ │ ├── server.py │ │ ├── setting.py │ │ └── start_clients.sh │ ├── scale │ │ ├── __init__.py │ │ ├── cifar10-cnn │ │ │ ├── cifar10_iid.pkl │ │ │ ├── cifar10_noniid.pkl │ │ │ ├── cifar10_partition.py │ │ │ ├── client.py │ │ │ ├── config.py │ │ │ ├── server.py │ │ │ └── start_clt.sh │ │ ├── femnist-cnn │ │ │ ├── client.py │ │ │ ├── server.py │ │ │ └── start_clt.sh │ │ ├── mnist-cnn │ │ │ ├── __init__.py │ │ │ ├── client.py │ │ │ ├── mnist_iid.pkl │ │ │ ├── mnist_noniid.pkl │ │ │ ├── mnist_partition.py │ │ │ ├── server.py │ │ │ └── start_clt.sh │ │ └── shakespeare-rnn │ │ │ ├── client.py │ │ │ ├── server.py │ │ │ └── start_clt.sh │ └── standalone │ │ ├── __init__.py │ │ ├── mnist_iid.pkl │ │ ├── mnist_noniid.pkl │ │ └── standalone.py ├── fedavg_v1.2.0 │ ├── client.py │ ├── mnist_partition.pkl │ ├── run.sh │ ├── server.py │ ├── setting.py │ └── standalone.py ├── feddyn │ ├── Output │ │ └── CIFAR10_100_iid_plots.png │ ├── README.md │ ├── __init__.py │ ├── client.py │ ├── client_starter.py │ ├── config.py │ ├── data_partition.py 
│ ├── models.py │ ├── results-plot.ipynb │ ├── server.py │ ├── server_starter.py │ ├── start_clt.sh │ ├── start_server.sh │ └── utils.py ├── fedmgda+ │ ├── README.md │ ├── __init__.py │ ├── client.py │ ├── mnist_iid_100.pkl │ ├── mnist_noniid.pkl │ ├── mnist_noniid_200_100.pkl │ ├── run.sh │ ├── server.py │ ├── setting.py │ ├── standalone.py │ └── start_clients.sh ├── fedprox │ ├── README.md │ ├── __init__.py │ ├── cross_process │ │ ├── __init__.py │ │ ├── client.py │ │ ├── server.py │ │ └── setting.py │ ├── fedprox_trainer.py │ └── standalone.py ├── leaf │ ├── README.md │ ├── README_tmp.md │ ├── README_zh_cn.md │ ├── __init__.py │ ├── dataloader.py │ ├── dataset │ │ ├── __init__.py │ │ ├── celeba_dataset.py │ │ ├── femnist_dataset.py │ │ ├── reddit_dataset.py │ │ ├── sent140_dataset.py │ │ └── shakespeare_dataset.py │ ├── gen_pickle_dataset.sh │ ├── nlp_utils │ │ ├── README.md │ │ ├── __init__.py │ │ ├── download_glove.sh │ │ └── util.py │ └── pickle_dataset.py ├── models │ ├── __init__.py │ ├── cnn.py │ ├── mlp.py │ └── rnn.py ├── perfedavg │ ├── README.md │ ├── __init__.py │ ├── fine_tuner.py │ ├── image │ │ └── paper_exp_res.png │ ├── models.py │ ├── multi_process │ │ ├── __init__.py │ │ ├── client.py │ │ ├── client_manager.py │ │ ├── handler.py │ │ ├── quick_start.sh │ │ ├── server.py │ │ └── server_manager.py │ ├── single_process.py │ ├── trainer.py │ └── utils.py ├── qfedavg │ ├── README.md │ ├── __init__.py │ ├── client.py │ ├── mnist_iid_10.pkl │ ├── mnist_noniid_200_10.pkl │ ├── run.sh │ └── server.py └── report_TEMPLATE.md └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # System file 2 | *.DS_Store 3 | 4 | # Pycharm 5 | .idea 6 | 7 | # VS Code 8 | .vscode 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | # docs/_build/ 82 | docs/build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | 150 | # datasets 151 | fedlab_benchmarks/datasets/cifar10/* 152 | !fedlab_benchmarks/datasets/cifar10/*.py 153 | !fedlab_benchmarks/datasets/cifar10/*.ipynb 154 | fedlab_benchmarks/datasets/*/data 155 | fedlab_benchmarks/datasets/*/meta 156 | fedlab_benchmarks/datasets/*/MNIST 157 | fedlab_benchmarks/datasets/*/*.pkl 158 | 159 | fedlab_benchmarks/leaf/pickle_datasets/* 160 | fedlab_benchmarks/leaf/pickle_datasets/* 161 | 162 | actions-runner 163 | 164 | # tests 165 | fedlab_benchmarks/fedavg/standalone/*.txt 166 | fedlab_benchmarks/fedavg/standalone/*.ipynb 167 | fedlab_benchmarks/fedavg/standalone/*.sh 168 | 169 | .script.ipynb 170 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 
3 | # FedLab-benchmarks
4 | 
5 | **For easier future maintenance and organization, this repo is deprecated. All code has been merged into the main repo [FedLab](https://github.com/SMILELab-FL/FedLab); please refer there for the latest version.**
6 | 
--------------------------------------------------------------------------------
/fedlab_benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | import os
17 | 
18 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))  # absolute path of this package directory
19 | 
--------------------------------------------------------------------------------
/fedlab_benchmarks/cfl/README.md:
--------------------------------------------------------------------------------
1 | # CFL
2 | 
3 | [Clustered Federated Learning: Model-Agnostic Distributed Multi-Task Optimization under Privacy Constraints](https://arxiv.org/abs/1910.01991)
4 | 
5 | 
6 | ## Requirements
7 | 
8 | fedlab==1.2.1
9 | 
10 | ## Run
11 | 
12 | We implement shifted and rotated data generation, and we also accept real-world data such as FEMNIST.
13 | 
14 | For shifted and rotated augmented data, the `mnist` and `cifar10` configurations are currently supported.
15 | 
16 | We also provide the 0- and 180-degree rotated `emnist` used in the [original paper's code implementation](https://github.com/felisat/clustered-federated-learning#clustered-federated-learning-model-agnostic-distributed-multi-task-optimization-under-privacy-constraints).
17 | 
18 | If you don't have the augmented data yet, set the config `process_data=1` to generate it first.
19 | The augmented data will be saved in `args.save_dir`, which defaults to `./datasets`.
20 | You can inspect and modify the two parameters `save_dir` and `root`, which are the augmented-data storage path and the original-data read path respectively.
21 | `save_dir` defaults to `./datasets`, and `root` defaults to `../datasets/{dataset_name}`.
22 | 
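23 | A minimal sketch of the generation step. The flag names below mirror the config parameters described above (`process_data`, `save_dir`, `root`) but the exact CLI is an assumption; check `standalone.py` for the authoritative argument list:
24 | 
25 | ```bash
26 | # Hypothetical invocation -- flag names assumed from the config described above.
27 | python standalone.py --dataset mnist --process_data 1 \
28 |     --save_dir ./datasets --root ../datasets/mnist
29 | ```
30 | 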
31 | For real-world data, we follow [LEAF](https://github.com/TalwalkarLab/leaf) for simulation; our implementation lives in `fedlab_benchmarks/leaf/`.
32 | You can read the docs in the `leaf` folder to generate data; CFL then reads it from the default LEAF data path `../leaf/pickle_datasets/`.
33 | 
34 | ## Performance
35 | 
36 | No results reported yet.
37 | 
38 | ## References
39 | 
40 | [1] Sattler, Felix, Klaus-Robert Müller, and Wojciech Samek. "Clustered Federated Learning: Model-Agnostic Distributed Multi-Task Optimization under Privacy Constraints." arXiv preprint arXiv:1910.01991 (2019).
41 | [2] [Source code](https://github.com/felisat/clustered-federated-learning#clustered-federated-learning-model-agnostic-distributed-multi-task-optimization-under-privacy-constraints)
42 | 
--------------------------------------------------------------------------------
/fedlab_benchmarks/cfl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/cfl/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/cfl/helper.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is copied from [clustered-federated-learning/helper.py]
3 | https://github.com/felisat/clustered-federated-learning/blob/master/helper.py
4 | """
5 | 
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | 
9 | 
10 | class ExperimentLogger:
11 |     def log(self, values):
12 |         for k, v in values.items():
13 |             if k not in self.__dict__:
14 |                 self.__dict__[k] = [v]
15 |             else:
16 |                 self.__dict__[k] += [v]
17 | 
18 | 
19 | def display_train_stats(cfl_stats, eps_1, eps_2, communication_rounds):
20 |     plt.figure(figsize=(12, 4))
21 | 
22 |     plt.subplot(1, 2, 1)
23 |     acc_mean = np.mean(cfl_stats.acc_clusters, axis=1)
24 |     acc_std = np.std(cfl_stats.acc_clusters, axis=1)
25 |     plt.fill_between(cfl_stats.rounds, acc_mean - acc_std, acc_mean + acc_std, alpha=0.5, color="C0")
26 |     plt.plot(cfl_stats.rounds, acc_mean, color="C0")
27 | 
28 |     if "split" in cfl_stats.__dict__:
29 |         for s in cfl_stats.split:
30 |             plt.axvline(x=s, linestyle="-", color="k", label=r"Split")
31 | 
32 |     plt.text(x=communication_rounds, y=1, ha="right", va="top",
33 |              s=f"Clusters: {cfl_stats.clusters[-1]}")
34 | 
35 |     plt.xlabel("Communication Rounds")
36 |     plt.ylabel("Accuracy")
37 | 
38 |     plt.xlim(0, communication_rounds)
39 |     plt.ylim(0, 1)
40 | 
41 |     plt.subplot(1, 2, 2)
42 | 
43 |     plt.plot(cfl_stats.rounds, cfl_stats.mean_norm, color="C1", label=r"$\|\sum_i\Delta W_i \|$")
44 |     plt.plot(cfl_stats.rounds, cfl_stats.max_norm, color="C2", label=r"$\max_i\|\Delta W_i \|$")
45 | 
46 |     plt.axhline(y=eps_1, linestyle="--", color="k", label=r"$\varepsilon_1$")
47 |     plt.axhline(y=eps_2, linestyle=":", color="k", label=r"$\varepsilon_2$")
48 | 
49 |     if "split" in cfl_stats.__dict__:
50 |         for s in cfl_stats.split:
51 |             plt.axvline(x=s, linestyle="-", color="k", label=r"Split")
52 | 
53 |     plt.xlabel("Communication Rounds")
54 |     plt.legend()
55 | 
56 |     plt.xlim(0, communication_rounds)
57 |     # plt.ylim(0, 2)
58 | 
59 |     plt.show()
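60 | 
61 | # --- Added usage sketch (not part of the original file): shows how
62 | # ExperimentLogger and display_train_stats fit together. The numbers
63 | # below are invented purely for illustration.
64 | if __name__ == "__main__":
65 |     stats = ExperimentLogger()
66 |     for rnd in range(1, 11):
67 |         stats.log({"rounds": rnd,
68 |                    "acc_clusters": [0.50 + 0.03 * rnd, 0.45 + 0.03 * rnd],
69 |                    "mean_norm": 1.0 / rnd,
70 |                    "max_norm": 2.0 / rnd,
71 |                    "clusters": [[0, 1], [2, 3]]})
72 |     display_train_stats(stats, eps_1=0.4, eps_2=1.6, communication_rounds=10)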
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from .topk import TopkCompressor
16 | from .qsgd import QSGDCompressor
17 | 
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/base_compressor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from abc import ABC
16 | 
17 | 
18 | class Compressor(ABC):
19 |     def __init__(self) -> None:
20 |         super().__init__()
21 | 
22 |     def compress(self, *args, **kwargs):
23 |         raise NotImplementedError()
24 | 
25 |     def decompress(self, *args, **kwargs):
26 |         raise NotImplementedError()
27 | 
28 | 
29 | class Memory(ABC):
30 |     def __init__(self) -> None:
31 |         super().__init__()
32 | 
33 |     def initialize(self, *args, **kwargs):
34 |         raise NotImplementedError()
35 | 
36 |     def compensate(self, tensor, *args, **kwargs):
37 |         raise NotImplementedError()
38 | 
39 |     def update(self, *args, **kwargs):
40 |         raise NotImplementedError()
41 | 
42 |     def state_dict(self):
43 |         raise NotImplementedError()
44 | 
45 |     def load_state_dict(self, state_dict):
46 |         raise NotImplementedError()
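47 | 
48 | # --- Added illustrative sketch (not part of the original file): a minimal
49 | # concrete Compressor. The shipped TopkCompressor/QSGDCompressor follow the
50 | # same compress()/decompress() contract; this sign-based toy is only an
51 | # example, not one of the package's compressors.
52 | import torch
53 | 
54 | 
55 | class SignCompressor(Compressor):
56 |     """Keep only the sign of each entry plus one global scale factor."""
57 | 
58 |     def compress(self, tensor):
59 |         scale = tensor.abs().mean()  # one float preserves the average magnitude
60 |         return torch.sign(tensor), scale
61 | 
62 |     def decompress(self, signs, scale):
63 |         return signs * scale
64 | 
65 | 
66 | if __name__ == "__main__":
67 |     t = torch.randn(8)
68 |     signs, scale = SignCompressor().compress(t)
69 |     print(t, SignCompressor().decompress(signs, scale))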
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/compress_example/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/compressors/compress_example/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/compress_example/quick_start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 3 --dataset mnist --round 3 &
4 | 
5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 3 --rank 1 --dataset mnist &
6 | 
7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 3 --rank 2 --dataset mnist &
8 | 
9 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/compress_example/server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | 
4 | from fedlab.core import communicator
5 | from fedlab.utils.logger import Logger
6 | from fedlab.core.server.handler import SyncParameterServerHandler
7 | from fedlab.core.server.manager import ServerSynchronousManager
8 | from fedlab.core.network import DistNetwork
9 | from fedlab.compressor.topk import TopkCompressor
10 | from fedlab.utils.message_code import MessageCode
11 | from fedlab.core.communicator.processor import PackageProcessor
12 | 
13 | sys.path.append('../')
14 | from models.cnn import CNN_MNIST
15 | 
16 | class CompressServerManager(ServerSynchronousManager):
17 |     def __init__(self, network, handler, logger=None):
18 |         super().__init__(network, handler, logger=logger)
19 |         self.tpkc = TopkCompressor(compress_ratio=0.5)
20 | 
21 |     def on_receive(self, sender, message_code, payload):
22 |         if message_code == MessageCode.ParameterUpdate:
23 |             print(sender, message_code, payload[0].shape)
24 |             #_, _, payload = PackageProcessor.recv_package(src=sender)
25 |             #print("------", len(payload))
26 |             return True
27 |         else:
28 |             raise Exception("Unexpected message code {}".format(message_code))
29 | 
30 | 
31 | if __name__ == "__main__":
32 |     parser = argparse.ArgumentParser(description='FL server example')
33 | 
34 |     parser.add_argument('--ip', type=str)
35 |     parser.add_argument('--port', type=str)
36 |     parser.add_argument('--world_size', type=int)
37 | 
38 |     parser.add_argument('--round', type=int, default=1)
39 |     parser.add_argument('--dataset', type=str)
40 |     parser.add_argument('--ethernet', type=str, default=None)
41 |     parser.add_argument('--sample', type=float, default=1)
42 | 
43 |     args = parser.parse_args()
44 | 
45 |     model = CNN_MNIST()
46 |     LOGGER = Logger(log_name="server")
47 |     handler = SyncParameterServerHandler(model,
48 |                                          global_round=args.round,
49 |                                          logger=LOGGER,
50 |                                          sample_ratio=args.sample)
51 |     network = DistNetwork(address=(args.ip, args.port),
52 |                           world_size=args.world_size,
53 |                           rank=0,
54 |                           ethernet=args.ethernet)
55 | 
56 |     manager_ = CompressServerManager(handler=handler,
57 |                                      network=network,
58 |                                      logger=LOGGER)
59 | 
60 |     manager_.run()
61 | 
62 | 
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/README.md:
--------------------------------------------------------------------------------
1 | # DATASETS README
2 | 
3 | This folder contains download and preprocessing scripts for commonly used datasets, and provides the LEAF dataset interface.
4 | 
5 | 
6 | - `data` folder: contains download and preprocessing scripts for commonly used datasets; each subfolder is named after its dataset.
7 | 
8 | For LEAF, it contains `celeba`, `femnist`, `reddit`, `sent140`, `shakespeare`, and `synthetic`, whose download and preprocessing scripts are copied from the [LEAF GitHub](https://github.com/TalwalkarLab/leaf). We also copied and modified the [Flower leaf script](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/scripts/leaf/femnist) to store each processed `Dataset` as a pickle file.
9 | 
10 | For the LEAF dataset folders, run `create_datasets_and_save.sh` to get partitioned data. You can also edit the `preprocess.sh` command parameters to obtain a different partition.
11 | 
12 | - `leaf_data_process` folder: contains processing methods to read LEAF data and build dataloaders for users.
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .fcube import FCUBE 17 | from .adult import Adult 18 | from .rcv1 import RCV1 19 | from .covtype import Covtype 20 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/adult/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .adult import Adult 16 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/adult/adult.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | from sklearn.datasets import load_svmlight_file 17 | import random 18 | import os 19 | from urllib.request import urlretrieve 20 | 21 | from torch.utils.data import Dataset 22 | 23 | 24 | class Adult(Dataset): 25 | url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/" 26 | train_file_name = "a9a" 27 | test_file_name = "a9a.t" 28 | num_classes = 2 29 | num_features = 123 30 | 31 | def __init__(self, root, train=True, 32 | transform=None, 33 | target_transform=None, 34 | download=False): 35 | self.root = root 36 | self.train = train 37 | self.transform = transform 38 | self.target_transform = target_transform 39 | 40 | if not os.path.exists(root): 41 | os.mkdir(root) 42 | 43 | if self.train: 44 | self.full_file_path = os.path.join(self.root, self.train_file_name) 45 | else: 46 | self.full_file_path = os.path.join(self.root, self.test_file_name) 47 | 48 | if download: 49 | self.download() 50 | 51 | if not self._local_file_existence(): 52 | raise RuntimeError( 53 | f"Adult-a9a source data file not found. 
You can use download=True to " 54 | f"download it.") 55 | 56 | # now load from source file 57 | X, y = load_svmlight_file(self.full_file_path) 58 | X = X.todense() # transform 59 | 60 | if not self.train: 61 | X = np.c_[X, np.zeros((len(y), 1))] # append a zero vector at the end of X_test 62 | 63 | X = np.array(X, dtype=np.float32) 64 | y = np.array((y + 1) / 2, dtype=np.int32) # map elements of y from {-1, 1} to {0, 1} 65 | print(f"Local file {self.full_file_path} loaded.") 66 | self.data, self.targets = X, y 67 | 68 | def download(self): 69 | if self._local_file_existence(): 70 | print(f"Source file already downloaded.") 71 | return 72 | 73 | if self.train: 74 | download_url = self.url + self.train_file_name 75 | else: 76 | download_url = self.url + self.test_file_name 77 | 78 | urlretrieve(download_url, self.full_file_path) 79 | 80 | def _local_file_existence(self): 81 | return os.path.exists(self.full_file_path) 82 | 83 | def __getitem__(self, index): 84 | data, label = self.data[index], self.targets[index] 85 | 86 | if self.transform is not None: 87 | data = self.transform(data) 88 | 89 | if self.target_transform is not None: 90 | label = self.target_transform(label) 91 | 92 | return data, label 93 | 94 | def __len__(self): 95 | return len(self.targets) 96 | 97 | def extra_repr(self) -> str: 98 | return "Split: {}".format("Train" if self.train is True else "Test") 99 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/celeba/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # download data and convert to .json format 4 | 5 | if [ ! -d "data/raw/img_align_celeba" ] || [ ! "$(ls -A data/raw/img_align_celeba)" ] || [ ! -f "data/raw/list_attr_celeba.txt" ]; then 6 | echo "Please download the celebrity faces dataset and attributes file from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" 7 | exit 1 8 | fi 9 | 10 | if [ ! -f "data/raw/identity_CelebA.txt" ]; then 11 | echo "Please request the celebrity identities file from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" 12 | exit 1 13 | fi 14 | 15 | if [ ! -d "data/all_data" ] || [ ! 
"$(ls -A data/all_data)" ]; then 16 | echo "Preprocessing raw data" 17 | python preprocess/metadata_to_json.py 18 | fi 19 | 20 | NAME="celeba" # name of the dataset, equivalent to directory name 21 | 22 | 23 | cd ../utils 24 | 25 | bash preprocess.sh --name $NAME $@ 26 | 27 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/celeba/preprocess/metadata_to_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | 5 | 6 | TARGET_NAME = 'Smiling' 7 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 8 | 9 | 10 | def get_metadata(): 11 | f_identities = open(os.path.join( 12 | parent_path, 'data', 'raw', 'identity_CelebA.txt'), 'r') 13 | identities = f_identities.read().split('\n') 14 | 15 | f_attributes = open(os.path.join( 16 | parent_path, 'data', 'raw', 'list_attr_celeba.txt'), 'r') 17 | attributes = f_attributes.read().split('\n') 18 | 19 | return identities, attributes 20 | 21 | 22 | def get_celebrities_and_images(identities): 23 | all_celebs = {} 24 | 25 | for line in identities: 26 | info = line.split() 27 | if len(info) < 2: 28 | continue 29 | image, celeb = info[0], info[1] 30 | if celeb not in all_celebs: 31 | all_celebs[celeb] = [] 32 | all_celebs[celeb].append(image) 33 | 34 | good_celebs = {c: all_celebs[c] for c in all_celebs if len(all_celebs[c]) >= 5} 35 | return good_celebs 36 | 37 | 38 | def _get_celebrities_by_image(identities): 39 | good_images = {} 40 | for c in identities: 41 | images = identities[c] 42 | for img in images: 43 | good_images[img] = c 44 | return good_images 45 | 46 | 47 | def get_celebrities_and_target(celebrities, attributes, attribute_name=TARGET_NAME): 48 | col_names = attributes[1] 49 | col_idx = col_names.split().index(attribute_name) 50 | 51 | celeb_attributes = {} 52 | good_images = _get_celebrities_by_image(celebrities) 53 | 54 | for line in attributes[2:]: 55 | info = line.split() 56 | if len(info) == 0: 57 | continue 58 | 59 | image = info[0] 60 | if image not in good_images: 61 | continue 62 | 63 | celeb = good_images[image] 64 | att = (int(info[1:][col_idx]) + 1) / 2 65 | 66 | if celeb not in celeb_attributes: 67 | celeb_attributes[celeb] = [] 68 | 69 | celeb_attributes[celeb].append(att) 70 | 71 | return celeb_attributes 72 | 73 | 74 | def build_json_format(celebrities, targets): 75 | all_data = {} 76 | 77 | celeb_keys = [c for c in celebrities] 78 | num_samples = [len(celebrities[c]) for c in celeb_keys] 79 | data = {c: {'x': celebrities[c], 'y': targets[c]} for c in celebrities} 80 | 81 | all_data['users'] = celeb_keys 82 | all_data['num_samples'] = num_samples 83 | all_data['user_data'] = data 84 | return all_data 85 | 86 | 87 | def write_json(json_data): 88 | file_name = 'all_data.json' 89 | dir_path = os.path.join(parent_path, 'data', 'all_data') 90 | 91 | if not os.path.exists(dir_path): 92 | os.mkdir(dir_path) 93 | 94 | file_path = os.path.join(dir_path, file_name) 95 | 96 | print('writing {}'.format(file_name)) 97 | with open(file_path, 'w') as outfile: 98 | json.dump(json_data, outfile) 99 | 100 | 101 | def main(): 102 | identities, attributes = get_metadata() 103 | celebrities = get_celebrities_and_images(identities) 104 | targets = get_celebrities_and_target(celebrities, attributes) 105 | 106 | json_data = build_json_format(celebrities, targets) 107 | write_json(json_data) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | 113 | 114 | 
-------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/celeba/stats.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME="celeba" 4 | 5 | cd ../utils 6 | 7 | python3 stats.py --name $NAME 8 | 9 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/cifar100/imgs/cifar100_balance_dir_alpha_0.3_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_balance_dir_alpha_0.3_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/cifar100/imgs/cifar100_balance_iid_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_balance_iid_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/cifar100/imgs/cifar100_hetero_dir_0.3_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_hetero_dir_0.3_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/cifar100/imgs/cifar100_hetero_dir_0.3_100clients_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_hetero_dir_0.3_100clients_dist.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/cifar100/imgs/cifar100_shards_200_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_shards_200_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png -------------------------------------------------------------------------------- 
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_iid_unbalance_sgm_0.3_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_iid_unbalance_sgm_0.3_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/covtype/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .covtype import Covtype 16 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/fcube/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .fcube import FCUBE 17 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/README.md: -------------------------------------------------------------------------------- 1 | # FEMNIST Dataset 2 | 3 | ## Setup Instructions 4 | - pip3 install numpy 5 | - pip3 install pillow 6 | - Run ```./preprocess.sh``` with a choice of the following tags: 7 | - ```-s``` := 'iid' to sample in an i.i.d. manner, or 'niid' to sample in a non-i.i.d. manner; more information on i.i.d. versus non-i.i.d. 
is included in the 'Notes' section
8 | - ```--iu``` := number of users, if iid sampling; expressed as a fraction of the total number of users; default is 0.01
9 | - ```--sf``` := fraction of data to sample, written as a decimal; default is 0.1
10 | - ```-k``` := minimum number of samples per user
11 | - ```-t``` := 'user' to partition users into train-test groups, or 'sample' to partition each user's samples into train-test groups
12 | - ```--tf``` := fraction of data in training set, written as a decimal; default is 0.9
13 | - ```--smplseed``` := seed to be used before random sampling of data
14 | - ```--spltseed``` := seed to be used before random split of data
15 | 
16 | e.g.
17 | - ```./preprocess.sh -s niid --sf 1.0 -k 0 -t sample``` (full-sized dataset)
18 | - ```./preprocess.sh -s niid --sf 0.05 -k 0 -t sample``` (small-sized dataset) 19 | 20 | Make sure to delete the rem_user_data, sampled_data, test, and train subfolders in the data directory before re-running preprocess.sh 21 | 22 | ## Notes 23 | - More details on i.i.d. versus non-i.i.d.: 24 | - In the i.i.d. sampling scenario, each datapoint is equally likely to be sampled. Thus, all users have the same underlying distribution of data. 25 | - In the non-i.i.d. sampling scenario, the underlying distribution of data for each user is consistent with the raw data. Since we assume that data distributions vary between user in the raw data, we refer to this sampling process as non-i.i.d. 26 | - More details on ```preprocess.sh```: 27 | - The order in which ```preprocess.sh``` processes data is 1. generating all_data, 2. sampling, 3. removing users, and 4. creating train-test split. The script will look at the data in the last generated directory and continue preprocessing from that point. For example, if the ```all_data``` directory has already been generated and the user decides to skip sampling and only remove users with the ```-k``` tag (i.e. running ```preprocess.sh -k 50```), the script will effectively apply a remove user filter to data in ```all_data``` and place the resulting data in the ```rem_user_data``` directory. 28 | - File names provide information about the preprocessing steps taken to generate them. For example, the ```all_data_niid_1_keep_64.json``` file was generated by first sampling 10 percent (.1) of the data ```all_data.json``` in a non-i.i.d. manner and then applying the ```-k 64``` argument to the resulting data. 29 | - Each .json file is an object with 3 keys: 30 | 1. 'users', a list of users 31 | 2. 'num_samples', a list of the number of samples for each user, and 32 | 3. 'user_data', an object with user names as keys and their respective data as values; for each user, data is represented as a list of images, with each image represented as a size-784 integer list (flattened from 28 by 28) 33 | - Run ```./stats.sh``` to get statistics of data (data/all_data/all_data.json must have been generated already) 34 | - In order to run reference implementations in ```../models``` directory, the ```-t sample``` tag must be used when running ```./preprocess.sh``` 35 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # download data and convert to .json format 4 | 5 | if [ ! -d "data/all_data" ] || [ ! "$(ls -A data/all_data)" ]; then 6 | cd preprocess 7 | bash data_to_json.sh 8 | cd .. 
9 | else 10 | echo "using existing data/all_data data folder to preprocess" 11 | fi 12 | 13 | NAME="femnist" # name of the dataset, equivalent to directory name 14 | 15 | cd ../utils 16 | 17 | bash preprocess.sh --name $NAME $@ 18 | 19 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/preprocess/data_to_json.py: -------------------------------------------------------------------------------- 1 | # Converts a list of (writer, [list of (file,class)]) tuples into a json object 2 | # of the form: 3 | # {users: [bob, etc], num_samples: [124, etc.], 4 | # user_data: {bob : {x:[img1,img2,etc], y:[class1,class2,etc]}, etc}} 5 | # where 'img_' is a vectorized representation of the corresponding image 6 | 7 | from __future__ import division 8 | import json 9 | import math 10 | import numpy as np 11 | import os 12 | import sys 13 | 14 | from PIL import Image 15 | 16 | utils_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 17 | utils_dir = os.path.join(utils_dir, 'utils') 18 | sys.path.append(utils_dir) 19 | 20 | import util 21 | 22 | MAX_WRITERS = 100 # max number of writers per json file. 23 | 24 | 25 | def relabel_class(c): 26 | ''' 27 | maps hexadecimal class value (string) to a decimal number 28 | returns: 29 | - 0 through 9 for classes representing respective numbers 30 | - 10 through 35 for classes representing respective uppercase letters 31 | - 36 through 61 for classes representing respective lowercase letters 32 | ''' 33 | if c.isdigit() and int(c) < 40: 34 | return (int(c) - 30) 35 | elif int(c, 16) <= 90: # uppercase 36 | return (int(c, 16) - 55) 37 | else: 38 | return (int(c, 16) - 61) 39 | 40 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 41 | 42 | by_writer_dir = os.path.join(parent_path, 'data', 'intermediate', 'images_by_writer') 43 | writers = util.load_obj(by_writer_dir) 44 | 45 | num_json = int(math.ceil(len(writers) / MAX_WRITERS)) 46 | 47 | users = [] 48 | num_samples = [] 49 | user_data = {} 50 | 51 | writer_count, all_writers = 0, 0 52 | json_index = 0 53 | for (w, l) in writers: 54 | 55 | users.append(w) 56 | num_samples.append(len(l)) 57 | user_data[w] = {'x': [], 'y': []} 58 | 59 | size = 28, 28 # original image size is 128, 128 60 | for (f, c) in l: 61 | file_path = os.path.join(parent_path, f) 62 | img = Image.open(file_path) 63 | gray = img.convert('L') 64 | gray.thumbnail(size, Image.ANTIALIAS) 65 | arr = np.asarray(gray).copy() 66 | vec = arr.flatten() 67 | vec = vec / 255 # scale all pixel values to between 0 and 1 68 | vec = vec.tolist() 69 | 70 | nc = relabel_class(c) 71 | 72 | user_data[w]['x'].append(vec) 73 | user_data[w]['y'].append(nc) 74 | 75 | writer_count += 1 76 | all_writers += 1 77 | 78 | if writer_count == MAX_WRITERS or all_writers == len(writers): 79 | 80 | all_data = {} 81 | all_data['users'] = users 82 | all_data['num_samples'] = num_samples 83 | all_data['user_data'] = user_data 84 | 85 | file_name = 'all_data_%d.json' % json_index 86 | file_path = os.path.join(parent_path, 'data', 'all_data', file_name) 87 | 88 | print('writing %s' % file_name) 89 | 90 | with open(file_path, 'w') as outfile: 91 | json.dump(all_data, outfile) 92 | 93 | writer_count = 0 94 | json_index += 1 95 | 96 | users[:] = [] 97 | num_samples[:] = [] 98 | user_data.clear() 99 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/preprocess/data_to_json.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # assumes that the script is run in the preprocess folder 4 | 5 | if [ ! -d "../data" ]; then 6 | mkdir ../data 7 | fi 8 | if [ ! -d "../data/raw_data" ]; then 9 | mkdir ../data/raw_data 10 | fi 11 | 12 | # download and unzip 13 | if [ ! -d "../data/raw_data/by_class" ] || [ ! -d "../data/raw_data/by_write" ]; then 14 | echo "------------------------------" 15 | echo "downloading and unzipping raw data" 16 | bash get_data.sh 17 | echo "finished downloading and unzipping raw data" 18 | else 19 | echo "using existing unzipped raw data folders (by_class and by_write)" 20 | fi 21 | 22 | 23 | echo "generating intermediate data" 24 | if [ ! -d "../data/intermediate" ]; then # stores .pkl files during preprocessing 25 | mkdir ../data/intermediate 26 | fi 27 | 28 | if [ ! -f ../data/intermediate/class_file_dirs.pkl ]; then 29 | echo "------------------------------" 30 | echo "extracting file directories of images" 31 | python3 get_file_dirs.py 32 | echo "finished extracting file directories of images" 33 | else 34 | echo "using existing data/intermediate/class_file_dirs.pkl" 35 | fi 36 | 37 | if [ ! -f ../data/intermediate/class_file_hashes.pkl ]; then 38 | echo "------------------------------" 39 | echo "calculating image hashes" 40 | python3 get_hashes.py 41 | echo "finished calculating image hashes" 42 | else 43 | echo "using existing data/intermediate/class_file_hashes.pkl" 44 | fi 45 | 46 | if [ ! -f ../data/intermediate/write_with_class.pkl ]; then 47 | echo "------------------------------" 48 | echo "assigning class labels to write images" 49 | python3 match_hashes.py 50 | echo "finished assigning class labels to write images" 51 | else 52 | echo "using existing data/intermediate/write_with_class.pkl" 53 | fi 54 | 55 | if [ ! -f ../data/intermediate/images_by_writer.pkl ]; then 56 | echo "------------------------------" 57 | echo "grouping images by writer" 58 | python3 group_by_writer.py 59 | echo "finished grouping images by writer" 60 | else 61 | echo "using existing data/intermediate/images_by_writer.pkl" 62 | fi 63 | 64 | if [ ! -d "../data/all_data" ]; then 65 | mkdir ../data/all_data 66 | fi 67 | if [ ! "$(ls -A ../data/all_data)" ]; then 68 | echo "------------------------------" 69 | echo "converting data to .json format in data/all_data" 70 | python3 data_to_json.py 71 | echo "finished converting data to .json format" 72 | fi 73 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/preprocess/get_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # assumes that the script is run in the preprocess folder 4 | 5 | cd ../data/raw_data 6 | if [ ! -f by_class.zip ]; then 7 | echo "downloading by_class.zip" 8 | wget https://s3.amazonaws.com/nist-srd/SD19/by_class.zip 9 | else 10 | echo "using existing by_class.zip" 11 | fi 12 | 13 | if [ ! -f by_write.zip ]; then 14 | echo "downloading by_write.zip" 15 | wget https://s3.amazonaws.com/nist-srd/SD19/by_write.zip 16 | else 17 | echo "using existing by_write.zip" 18 | fi 19 | 20 | if [ ! -d "by_class" ]; then 21 | echo "unzipping by_class.zip" 22 | unzip by_class.zip 23 | #rm by_class.zip 24 | else 25 | echo "using existing unzipped folder by_class" 26 | fi 27 | 28 | if [ ! 
-d "by_write" ]; then 29 | echo "unzipping by_write.zip" 30 | unzip by_write.zip 31 | #rm by_write.zip 32 | else 33 | echo "using existing unzipped folder by_write" 34 | fi 35 | 36 | cd ../../preprocess 37 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/preprocess/get_file_dirs.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Creates .pkl files for: 3 | 1. list of directories of every image in 'by_class' 4 | 2. list of directories of every image in 'by_write' 5 | the hierarchal structure of the data is as follows: 6 | - by_class -> classes -> folders containing images -> images 7 | - by_write -> folders containing writers -> writer -> types of images -> images 8 | the directories written into the files are of the form 'raw_data/...' 9 | ''' 10 | 11 | import os 12 | import sys 13 | 14 | utils_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 15 | utils_dir = os.path.join(utils_dir, 'utils') 16 | sys.path.append(utils_dir) 17 | 18 | print(utils_dir) 19 | 20 | import util 21 | 22 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 23 | 24 | print(parent_path) 25 | class_files = [] # (class, file directory) 26 | write_files = [] # (writer, file directory) 27 | 28 | class_dir = os.path.join(parent_path, 'data', 'raw_data', 'by_class') 29 | rel_class_dir = os.path.join('data', 'raw_data', 'by_class') 30 | classes = os.listdir(class_dir) 31 | classes = [c for c in classes if len(c) == 2] 32 | 33 | for cl in classes: 34 | cldir = os.path.join(class_dir, cl) 35 | rel_cldir = os.path.join(rel_class_dir, cl) 36 | subcls = os.listdir(cldir) 37 | 38 | subcls = [s for s in subcls if (('hsf' in s) and ('mit' not in s))] 39 | 40 | for subcl in subcls: 41 | subcldir = os.path.join(cldir, subcl) 42 | rel_subcldir = os.path.join(rel_cldir, subcl) 43 | images = os.listdir(subcldir) 44 | image_dirs = [os.path.join(rel_subcldir, i) for i in images] 45 | 46 | for image_dir in image_dirs: 47 | class_files.append((cl, image_dir)) 48 | 49 | write_dir = os.path.join(parent_path, 'data', 'raw_data', 'by_write') 50 | rel_write_dir = os.path.join('data', 'raw_data', 'by_write') 51 | write_parts = os.listdir(write_dir) 52 | 53 | for write_part in write_parts: 54 | writers_dir = os.path.join(write_dir, write_part) 55 | rel_writers_dir = os.path.join(rel_write_dir, write_part) 56 | writers = os.listdir(writers_dir) 57 | 58 | for writer in writers: 59 | writer_dir = os.path.join(writers_dir, writer) 60 | rel_writer_dir = os.path.join(rel_writers_dir, writer) 61 | wtypes = os.listdir(writer_dir) 62 | 63 | for wtype in wtypes: 64 | type_dir = os.path.join(writer_dir, wtype) 65 | rel_type_dir = os.path.join(rel_writer_dir, wtype) 66 | images = os.listdir(type_dir) 67 | image_dirs = [os.path.join(rel_type_dir, i) for i in images] 68 | 69 | for image_dir in image_dirs: 70 | write_files.append((writer, image_dir)) 71 | 72 | util.save_obj( 73 | class_files, 74 | os.path.join(parent_path, 'data', 'intermediate', 'class_file_dirs')) 75 | util.save_obj( 76 | write_files, 77 | os.path.join(parent_path, 'data', 'intermediate', 'write_file_dirs')) 78 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/preprocess/get_hashes.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import sys 4 | 5 | utils_dir = 
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | utils_dir = os.path.join(utils_dir, 'utils') 7 | sys.path.append(utils_dir) 8 | 9 | import util 10 | 11 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 12 | 13 | cfd = os.path.join(parent_path, 'data', 'intermediate', 'class_file_dirs') 14 | wfd = os.path.join(parent_path, 'data', 'intermediate', 'write_file_dirs') 15 | class_file_dirs = util.load_obj(cfd) 16 | write_file_dirs = util.load_obj(wfd) 17 | 18 | class_file_hashes = [] 19 | write_file_hashes = [] 20 | 21 | count = 0 22 | for tup in class_file_dirs: 23 | if (count % 100000 == 0): 24 | print('hashed %d class images' % count) 25 | 26 | (cclass, cfile) = tup 27 | file_path = os.path.join(parent_path, cfile) 28 | 29 | chash = hashlib.md5(open(file_path, 'rb').read()).hexdigest() 30 | 31 | class_file_hashes.append((cclass, cfile, chash)) 32 | 33 | count += 1 34 | 35 | cfhd = os.path.join(parent_path, 'data', 'intermediate', 'class_file_hashes') 36 | util.save_obj(class_file_hashes, cfhd) 37 | 38 | count = 0 39 | for tup in write_file_dirs: 40 | if (count % 100000 == 0): 41 | print('hashed %d write images' % count) 42 | 43 | (cclass, cfile) = tup 44 | file_path = os.path.join(parent_path, cfile) 45 | 46 | chash = hashlib.md5(open(file_path, 'rb').read()).hexdigest() 47 | 48 | write_file_hashes.append((cclass, cfile, chash)) 49 | 50 | count += 1 51 | 52 | wfhd = os.path.join(parent_path, 'data', 'intermediate', 'write_file_hashes') 53 | util.save_obj(write_file_hashes, wfhd) 54 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/preprocess/group_by_writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | utils_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | utils_dir = os.path.join(utils_dir, 'utils') 6 | sys.path.append(utils_dir) 7 | 8 | import util 9 | 10 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 11 | 12 | wwcd = os.path.join(parent_path, 'data', 'intermediate', 'write_with_class') 13 | write_class = util.load_obj(wwcd) 14 | 15 | writers = [] # each entry is a (writer, [list of (file, class)]) tuple 16 | cimages = [] 17 | (cw, _, _) = write_class[0] 18 | for (w, f, c) in write_class: 19 | if w != cw: 20 | writers.append((cw, cimages)) 21 | cw = w 22 | cimages = [(f, c)] 23 | cimages.append((f, c)) 24 | writers.append((cw, cimages)) 25 | 26 | ibwd = os.path.join(parent_path, 'data', 'intermediate', 'images_by_writer') 27 | util.save_obj(writers, ibwd) 28 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/preprocess/match_hashes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | utils_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | utils_dir = os.path.join(utils_dir, 'utils') 6 | sys.path.append(utils_dir) 7 | 8 | import util 9 | 10 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 11 | 12 | cfhd = os.path.join(parent_path, 'data', 'intermediate', 'class_file_hashes') 13 | wfhd = os.path.join(parent_path, 'data', 'intermediate', 'write_file_hashes') 14 | class_file_hashes = util.load_obj(cfhd) # each elem is (class, file dir, hash) 15 | write_file_hashes = util.load_obj(wfhd) # each elem is (writer, file dir, hash) 16 
| 17 | class_hash_dict = {} 18 | for i in range(len(class_file_hashes)): 19 | (c, f, h) = class_file_hashes[len(class_file_hashes)-i-1] 20 | class_hash_dict[h] = (c, f) 21 | 22 | write_classes = [] 23 | for tup in write_file_hashes: 24 | (w, f, h) = tup 25 | write_classes.append((w, f, class_hash_dict[h][0])) 26 | 27 | wwcd = os.path.join(parent_path, 'data', 'intermediate', 'write_with_class') 28 | util.save_obj(write_classes, wwcd) 29 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/femnist/stats.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | NAME="femnist" 5 | 6 | cd ../utils 7 | 8 | python3 stats.py --name $NAME 9 | 10 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/adult_iid_10clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/adult_iid_10clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/adult_noniid-label1_10clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/adult_noniid-label1_10clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/adult_noniid_labeldir_10clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/adult_noniid_labeldir_10clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/adult_unbalance_10clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/adult_unbalance_10clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/cifar10_balance_dir_alpha_0.3_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_balance_dir_alpha_0.3_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/cifar10_balance_iid_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_balance_iid_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/cifar10_hetero_dir_0.3_100clients.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_hetero_dir_0.3_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/cifar10_hetero_dir_0.3_100clients_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_hetero_dir_0.3_100clients_dist.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/cifar10_shards_200_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_shards_200_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/cifar10_unbalance_iid_unbalance_sgm_0.3_100clients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_iid_unbalance_sgm_0.3_100clients.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/cifar10_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fcube_class_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_class_dist.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fcube_iid.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_iid.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fcube_iid_part.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_iid_part.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fcube_synthetic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_synthetic.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fcube_synthetic_original_paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_synthetic_original_paper.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fcube_synthetic_part.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_synthetic_part.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fcube_test_dist_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_test_dist_vis.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fcube_train_dist_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_train_dist_vis.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fmnist_feature_skew_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_feature_skew_vis.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fmnist_iid_clients_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_iid_clients_10.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_1_clients_10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_1_clients_10.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_2_clients_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_2_clients_10.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_3_clients_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_3_clients_10.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fmnist_noniid_labeldir_clients_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_noniid_labeldir_clients_10.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fmnist_unbalance_clients_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_unbalance_clients_10.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/fmnist_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_vis.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/svhn_feature_skew_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/svhn_feature_skew_vis.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/imgs/svhn_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/svhn_vis.png -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/mnist/download_mnist.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | if __name__ == "__main__": 4 | root = './' 5 | trainset = torchvision.datasets.MNIST( 6 | root=root, 7 | train=True, 8 | download=True, 9 | ) 10 | 11 | testset = torchvision.datasets.MNIST( 12 | root=root, 13 | train=False, 14 | download=True, 15 | ) 16 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/mnist/mnist_iid.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/mnist/mnist_iid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/mnist/mnist_noniid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/mnist/mnist_noniid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/mnist/mnist_partition.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from fedlab.utils.functional import save_dict 4 | from fedlab.utils.dataset.slicing import noniid_slicing, random_slicing 5 | 6 | import torchvision 7 | 8 | trainset = torchvision.datasets.MNIST(root="./", train=True, download=False) 9 | 10 | num_clients=100 11 | num_shards=200 12 | 13 | data_indices = noniid_slicing(trainset, num_clients=num_clients, num_shards=num_shards) 14 | save_dict(data_indices, "mnist_noniid_{}_{}.pkl".format(num_shards, num_clients)) 15 | 16 | data_indices = random_slicing(trainset, num_clients=num_clients) 17 | save_dict(data_indices, "mnist_iid_{}.pkl".format(num_clients)) 18 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/adult_iid_10clients.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,Amount 3 | Client 0,0.751,0.249,3256 4 | Client 1,0.766,0.234,3256 5 | Client 2,0.764,0.236,3256 6 | Client 3,0.758,0.242,3256 7 | Client 4,0.757,0.243,3256 8 | Client 5,0.757,0.243,3256 9 | Client 6,0.756,0.244,3256 10 | Client 7,0.768,0.232,3256 11 | Client 8,0.763,0.237,3256 12 | Client 9,0.753,0.247,3256 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/adult_noniid-label1_10clients.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,Amount 3 | Client 0,1.000,0.00,4944 4 | Client 1,0.00,1.000,1569 5 | Client 2,1.000,0.00,4944 6 | Client 3,0.00,1.000,1568 7 | Client 4,1.000,0.00,4944 8 | Client 5,0.00,1.000,1568 9 | Client 6,1.000,0.00,4944 10 | Client 7,0.00,1.000,1568 11 | Client 8,1.000,0.00,4944 12 | Client 9,0.00,1.000,1568 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/adult_noniid_labeldir_10clients.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,Amount 3 | Client 0,1.000,0.00,3305 4 | Client 1,0.833,0.167,126 5 | Client 2,0.222,0.778,8300 6 | Client 3,0.768,0.232,777 7 | Client 4,0.552,0.448,747 8 | Client 5,0.709,0.291,629 9 | Client 6,0.750,0.250,1546 10 | Client 7,1.000,0.00,5688 11 | Client 8,1.000,0.00,9929 12 | Client 9,0.815,0.185,1514 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/adult_unbalance_10clients.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,Amount 3 | Client 
0,0.751,0.249,2449 4 | Client 1,0.683,0.317,123 5 | Client 2,0.725,0.275,204 6 | Client 3,0.764,0.236,3095 7 | Client 4,0.763,0.237,59 8 | Client 5,0.761,0.239,20562 9 | Client 6,0.723,0.277,47 10 | Client 7,0.757,0.243,2557 11 | Client 8,0.759,0.241,2153 12 | Client 9,0.747,0.253,1306 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/fcube_iid.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,Amount 3 | Client 0,0.472,0.528,250 4 | Client 1,0.532,0.468,250 5 | Client 2,0.476,0.524,250 6 | Client 3,0.520,0.480,250 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/fcube_synthetic.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,Amount 3 | Client 0,0.500,0.500,250 4 | Client 1,0.500,0.500,250 5 | Client 2,0.500,0.500,250 6 | Client 3,0.500,0.500,250 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/fmnist_iid_clients_10.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount 3 | Client 0,0.096,0.104,0.101,0.108,0.104,0.098,0.094,0.098,0.102,0.097,6000 4 | Client 1,0.105,0.100,0.101,0.097,0.090,0.099,0.103,0.100,0.097,0.108,6000 5 | Client 2,0.102,0.100,0.102,0.099,0.097,0.100,0.098,0.097,0.106,0.101,6000 6 | Client 3,0.096,0.105,0.101,0.096,0.100,0.102,0.102,0.102,0.099,0.099,6000 7 | Client 4,0.092,0.101,0.097,0.102,0.101,0.100,0.103,0.102,0.102,0.101,6000 8 | Client 5,0.103,0.098,0.106,0.100,0.098,0.097,0.098,0.105,0.096,0.101,6000 9 | Client 6,0.107,0.104,0.102,0.087,0.104,0.105,0.097,0.102,0.096,0.097,6000 10 | Client 7,0.101,0.097,0.096,0.100,0.102,0.099,0.101,0.104,0.102,0.098,6000 11 | Client 8,0.100,0.097,0.098,0.097,0.101,0.101,0.104,0.098,0.100,0.105,6000 12 | Client 9,0.099,0.096,0.098,0.114,0.103,0.101,0.101,0.093,0.100,0.095,6000 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/fmnist_noniid-label_1_clients_10.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount 3 | Client 0,1.000,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,6000 4 | Client 1,0.00,1.000,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,6000 5 | Client 2,0.00,0.00,1.000,0.00,0.00,0.00,0.00,0.00,0.00,0.00,6000 6 | Client 3,0.00,0.00,0.00,1.000,0.00,0.00,0.00,0.00,0.00,0.00,6000 7 | Client 4,0.00,0.00,0.00,0.00,1.000,0.00,0.00,0.00,0.00,0.00,6000 8 | Client 5,0.00,0.00,0.00,0.00,0.00,1.000,0.00,0.00,0.00,0.00,6000 9 | Client 6,0.00,0.00,0.00,0.00,0.00,0.00,1.000,0.00,0.00,0.00,6000 10 | Client 7,0.00,0.00,0.00,0.00,0.00,0.00,0.00,1.000,0.00,0.00,6000 11 | Client 8,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,1.000,0.00,6000 12 | Client 9,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,1.000,6000 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/fmnist_noniid-label_2_clients_10.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | 
client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount 3 | Client 0,0.500,0.00,0.00,0.00,0.500,0.00,0.00,0.00,0.00,0.00,6000 4 | Client 1,0.00,0.500,0.00,0.00,0.00,0.500,0.00,0.00,0.00,0.00,4000 5 | Client 2,0.00,0.00,0.667,0.00,0.00,0.00,0.00,0.00,0.00,0.333,9000 6 | Client 3,0.333,0.00,0.00,0.667,0.00,0.00,0.00,0.00,0.00,0.00,9000 7 | Client 4,0.00,0.00,0.00,0.00,0.500,0.00,0.500,0.00,0.00,0.00,6000 8 | Client 5,0.00,0.00,0.00,0.00,0.00,0.400,0.00,0.00,0.600,0.00,5000 9 | Client 6,0.00,0.400,0.00,0.00,0.00,0.00,0.600,0.00,0.00,0.00,5000 10 | Client 7,0.00,0.00,0.00,0.00,0.00,0.400,0.00,0.600,0.00,0.00,5000 11 | Client 8,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.500,0.500,0.00,6000 12 | Client 9,0.00,0.400,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.600,5000 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/fmnist_noniid-label_3_clients_10.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount 3 | Client 0,0.250,0.00,0.00,0.00,0.500,0.250,0.00,0.00,0.00,0.00,6000 4 | Client 1,0.263,0.211,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.526,5700 5 | Client 2,0.00,0.00,0.400,0.00,0.00,0.300,0.300,0.00,0.00,0.00,5000 6 | Client 3,0.00,0.00,0.00,0.400,0.00,0.00,0.200,0.00,0.400,0.00,7500 7 | Client 4,0.00,0.211,0.00,0.00,0.526,0.00,0.263,0.00,0.00,0.00,5700 8 | Client 5,0.00,0.211,0.00,0.00,0.00,0.263,0.00,0.526,0.00,0.00,5700 9 | Client 6,0.00,0.286,0.00,0.00,0.00,0.357,0.357,0.00,0.00,0.00,4200 10 | Client 7,0.231,0.00,0.308,0.00,0.00,0.00,0.00,0.462,0.00,0.00,6500 11 | Client 8,0.00,0.167,0.00,0.417,0.00,0.00,0.00,0.00,0.417,0.00,7200 12 | Client 9,0.231,0.00,0.308,0.00,0.00,0.00,0.00,0.00,0.00,0.462,6500 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/fmnist_noniid_labeldir_clients_10.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount 3 | Client 0,0.000,0.040,0.00,0.007,0.036,0.092,0.008,0.454,0.362,0.00,6968 4 | Client 1,0.201,0.081,0.002,0.302,0.193,0.221,0.00,0.00,0.00,0.00,6300 5 | Client 2,0.002,0.090,0.172,0.178,0.080,0.048,0.430,0.00,0.00,0.00,6482 6 | Client 3,0.00,0.002,0.147,0.142,0.033,0.058,0.000,0.058,0.280,0.279,4687 7 | Client 4,0.040,0.177,0.014,0.00,0.200,0.022,0.154,0.204,0.189,0.00,6530 8 | Client 5,0.343,0.256,0.011,0.043,0.019,0.040,0.094,0.068,0.011,0.116,4562 9 | Client 6,0.065,0.001,0.078,0.146,0.301,0.033,0.000,0.098,0.003,0.274,6781 10 | Client 7,0.322,0.294,0.217,0.004,0.050,0.113,0.00,0.00,0.00,0.00,6637 11 | Client 8,0.039,0.016,0.120,0.060,0.002,0.283,0.161,0.015,0.051,0.252,5810 12 | Client 9,0.016,0.047,0.261,0.128,0.014,0.084,0.150,0.033,0.106,0.160,5243 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/partition-reports/fmnist_unbalance_clients_10.csv: -------------------------------------------------------------------------------- 1 | Class frequencies: 2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount 3 | Client 0,0.104,0.097,0.096,0.100,0.102,0.104,0.099,0.099,0.103,0.097,4513 4 | Client 1,0.097,0.128,0.093,0.101,0.093,0.101,0.097,0.110,0.079,0.101,227 5 | Client 
2,0.106,0.101,0.106,0.082,0.088,0.106,0.114,0.095,0.095,0.106,377 6 | Client 3,0.095,0.100,0.098,0.101,0.098,0.097,0.108,0.098,0.102,0.102,5704 7 | Client 4,0.064,0.109,0.055,0.100,0.109,0.055,0.118,0.118,0.100,0.173,110 8 | Client 5,0.101,0.101,0.102,0.098,0.099,0.100,0.098,0.101,0.099,0.101,37889 9 | Client 6,0.080,0.125,0.080,0.102,0.080,0.148,0.148,0.068,0.068,0.102,88 10 | Client 7,0.099,0.095,0.097,0.096,0.107,0.101,0.102,0.098,0.105,0.100,4713 11 | Client 8,0.098,0.095,0.097,0.113,0.106,0.101,0.098,0.097,0.098,0.098,3967 12 | Client 9,0.096,0.102,0.099,0.114,0.101,0.100,0.106,0.093,0.098,0.091,2407 -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/rcv1/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .rcv1 import RCV1 16 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/reddit/README.md: -------------------------------------------------------------------------------- 1 | # Reddit Dataset 2 | 3 | We preprocess the Reddit data released by [pushshift.io](https://files.pushshift.io/reddit/) corresponding to December 2017. We perform the following operations: 4 | 5 | 1. Unescape HTML symbols. 2 | 6 | 2. Remove extraneous whitespace. 7 | 3. Remove non-ASCII symbols. 8 | 4. Replace URLs, Reddit usernames, and subreddit names with special tokens. 9 | 5. Lowercase the text. 10 | 6. Tokenize the text (using nltk's TweetTokenizer). 11 | 12 | We also remove users and comments that simple heuristics or preliminary inspections mark as bots, and remove users with fewer than 5 or more than 1000 comments (which account for less than 0.01% of users). We include the code for this preprocessing in the ```source``` folder for reference, but host the preprocessed dataset [here](https://drive.google.com/file/d/1CXufUKXNpR7Pn8gUbIerZ1-qHz1KatHH/view?usp=sharing). We further preprocess the data to make it ready for our reference model (by splitting it into train/val/test sets and by creating sequences of 10 tokens for the LSTM) [here](https://drive.google.com/file/d/1lT1Z0N1weG-oA2PgC1Jak_WQ6h3bu7V_/view?usp=sharing). The vocabulary of the 10 thousand most common tokens in the data can be found [here](https://drive.google.com/file/d/1I-CRlfAeiriLmAyICrmlpPE5zWJX4TOY/view?usp=sharing). 13 | 14 | ## Setup Instructions 15 | 16 | 1. To use our reference model, download the data [here](https://drive.google.com/file/d/1PwBpAEMYKNpnv64cQ2TIQfSc_vPbq3OQ/view?usp=sharing) into a ```data``` subfolder in this directory. This is a subsampled version of the complete data.
Our reference implementation doesn't yet support training on the [complete dataset](https://drive.google.com/file/d/1lT1Z0N1weG-oA2PgC1Jak_WQ6h3bu7V_/view?usp=sharing), as it loads all given clients into memory. 17 | 2. With the data in the appropriate directory, build the training vocabulary by running ```python build_vocab.py --data-dir ./data/train --target-dir vocab```. 18 | 3. With the data and the training vocabulary, you can now run our reference implementation in the ```models``` directory using a command like the following: 19 | - ```python3 main.py -dataset reddit -model stacked_lstm --eval-every 10 --num-rounds 100 --clients-per-round 10 --batch-size 5 -lr 5.65 --metrics-name reddit_experiment``` 20 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/reddit/build_vocab.py: -------------------------------------------------------------------------------- 1 | """Builds vocabulary file from data.""" 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import os 7 | import pickle 8 | import re 9 | 10 | 11 | def build_counter(train_data, initial_counter=None): 12 | train_tokens = [] 13 | for u in train_data: 14 | for c in train_data[u]['x']: 15 | train_tokens.extend([s for s in c]) 16 | 17 | all_tokens = [] 18 | for i in train_tokens: 19 | all_tokens.extend(i) 20 | train_tokens = [] 21 | 22 | if initial_counter is None: 23 | counter = collections.Counter() 24 | else: 25 | counter = initial_counter 26 | 27 | counter.update(all_tokens) 28 | all_tokens = [] 29 | 30 | return counter 31 | 32 | 33 | def build_vocab(counter, vocab_size=10000): 34 | pad_symbol, unk_symbol = 0, 1 35 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 36 | count_pairs = count_pairs[:(vocab_size - 2)] # -2 to account for the unknown and pad symbols 37 | 38 | words, _ = list(zip(*count_pairs)) 39 | 40 | vocab = {} 41 | vocab['<PAD>'] = pad_symbol 42 | vocab['<UNK>'] = unk_symbol 43 | 44 | for i, w in enumerate(words): 45 | # ids 0 and 1 are reserved for the pad and unknown symbols 46 | vocab[w] = i + 2 47 | 48 | return {'vocab': vocab, 'size': vocab_size, 'unk_symbol': unk_symbol, 'pad_symbol': pad_symbol} 49 | 50 | 51 | def load_leaf_data(file_path): 52 | with open(file_path) as json_file: 53 | data = json.load(json_file) 54 | to_ret = data['user_data'] 55 | data = None 56 | return to_ret 57 | 58 | 59 | def save_vocab(vocab, target_dir): 60 | os.makedirs(target_dir, exist_ok=True) 61 | pickle.dump(vocab, open(os.path.join(target_dir, 'reddit_vocab.pck'), 'wb')) 62 | 63 | 64 | def main(): 65 | args = parse_args() 66 | 67 | json_files = [f for f in os.listdir(args.data_dir) if f.endswith('.json')] 68 | json_files.sort() 69 | 70 | counter = None 71 | train_data = {} 72 | for f in json_files: 73 | print('loading {}'.format(f)) 74 | train_data = load_leaf_data(os.path.join(args.data_dir, f)) 75 | print('counting {}'.format(f)) 76 | counter = build_counter(train_data, initial_counter=counter) 77 | print() 78 | train_data = {} 79 | 80 | if counter is not None: 81 | vocab = build_vocab(counter, vocab_size=args.vocab_size) 82 | save_vocab(vocab, args.target_dir) 83 | else: 84 | print('No files to process.') 85 | 86 | 87 | def parse_args(): 88 | parser = argparse.ArgumentParser() 89 | 90 | parser.add_argument('--data-dir', 91 | help='dir with training file;', 92 | type=str, 93 | required=True) 94 | parser.add_argument('--vocab-size', 95 | help='size of the vocabulary;', 96 | type=int, 97 | default=10000, 98 | required=False) 99 | parser.add_argument('--target-dir', 100 | help='dir to
save the vocabulary file;', 101 | type=str, 102 | default='./', 103 | required=False) 104 | 105 | return parser.parse_args() 106 | 107 | 108 | if __name__ == '__main__': 109 | main() 110 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/reddit/source/clean_raw.py: -------------------------------------------------------------------------------- 1 | import json 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | import pickle 6 | import re 7 | import string 8 | import html 9 | 10 | from nltk.tokenize import TweetTokenizer 11 | 12 | 13 | DIR = os.path.join('data', 'reddit_merged') 14 | FINAL_DIR = os.path.join('data', 'reddit_clean') 15 | 16 | PHRASES_TO_AVOID = [ 17 | '[ deleted ]', 18 | '[ removed ]', 19 | '[deleted]', 20 | '[removed]', 21 | 'bot', 22 | 'thank you for participating', 23 | 'thank you for your submission', 24 | 'thanks for your submission', 25 | 'your submission has been removed', 26 | 'your comment has been removed', 27 | 'downvote this comment if this is', 28 | 'your post has been removed', 29 | ] 30 | 31 | def clean_file(f, tknzr): 32 | reddit = pickle.load(open(os.path.join(DIR, f), 'rb')) 33 | 34 | clean_reddit = {} 35 | for u, comments in reddit.items(): 36 | 37 | clean_comments = [] 38 | for c in comments: 39 | c.clean_body(tknzr) 40 | if len(c.body) > 0 and not any([p in c.body for p in PHRASES_TO_AVOID]): 41 | clean_comments.append(c) 42 | 43 | if len(clean_comments) > 0: 44 | clean_reddit[u] = clean_comments 45 | 46 | pickle.dump( 47 | clean_reddit, 48 | open(os.path.join(FINAL_DIR, f.replace('merged', 'cleaned')), 'wb')) 49 | 50 | def main(): 51 | tknzr = TweetTokenizer() 52 | 53 | if not os.path.exists(FINAL_DIR): 54 | os.makedirs(FINAL_DIR) 55 | 56 | files = [f for f in os.listdir(DIR) if f.endswith('.pck')] 57 | files.sort() 58 | 59 | num_files = len(files) 60 | for i, f in enumerate(files): 61 | clean_file(f, tknzr) 62 | print('Done with {} of {}'.format(i + 1, num_files)) 63 | 64 | if __name__ == '__main__': 65 | main() 66 | 67 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/reddit/source/delete_small_users.py: -------------------------------------------------------------------------------- 1 | import json 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | import pickle 6 | 7 | 8 | DIR = os.path.join('data', 'reddit_clean') 9 | FINAL_DIR = os.path.join('data', 'reddit_subsampled') 10 | 11 | 12 | def subsample_file(f): 13 | reddit = pickle.load(open(os.path.join(DIR, f), 'rb')) 14 | 15 | subsampled_reddit = {} 16 | for u, comments in reddit.items(): 17 | 18 | subsampled_comments = [c for c in comments if len(c.body.split()) >= 5] 19 | 20 | if 5 <= len(subsampled_comments) <= 1000: 21 | subsampled_reddit[u] = subsampled_comments 22 | 23 | pickle.dump( 24 | subsampled_reddit, 25 | open(os.path.join(FINAL_DIR, f.replace('cleaned', 'subsampled')), 'wb')) 26 | 27 | 28 | def main(): 29 | if not os.path.exists(FINAL_DIR): 30 | os.makedirs(FINAL_DIR) 31 | 32 | files = [f for f in os.listdir(DIR) if f.endswith('.pck')] 33 | files.sort() 34 | 35 | num_files = len(files) 36 | for i, f in enumerate(files): 37 | subsample_file(f) 38 | print('Done with {} of {}'.format(i + 1, num_files)) 39 | 40 | if __name__ == '__main__': 41 | main() 42 | 43 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/reddit/source/get_json.py:
-------------------------------------------------------------------------------- 1 | import math 2 | import json 3 | import os 4 | import pickle 5 | 6 | DIR = os.path.join('data', 'reddit_subsampled') 7 | FINAL_DIR = os.path.join('data', 'reddit_json') 8 | FILES_PER_JSON = 10 9 | 10 | 11 | def merge_dicts(x, y): 12 | z = x.copy() 13 | z.update(y) 14 | return z 15 | 16 | 17 | def to_leaf_format(some_json, start_idx=0): 18 | leaf_json = {'users': [], 'num_samples': [], 'user_data': {}} 19 | new_idx = start_idx 20 | for u, comments in some_json.items(): 21 | new_idx += 1 22 | leaf_json['users'].append(str(new_idx)) 23 | leaf_json['num_samples'].append(len(comments)) 24 | 25 | x = [] 26 | y = [] 27 | for c in comments: 28 | assert c.author == u 29 | 30 | c_x = c.body 31 | c_y = { 32 | 'subreddit': c.subreddit, 33 | 'created_utc': c.created_utc, 34 | 'score': c.score, 35 | } 36 | 37 | x.append(c_x) 38 | y.append(c_y) 39 | 40 | user_data = {'x': x, 'y': y} 41 | leaf_json['user_data'][str(new_idx)] = user_data 42 | 43 | return leaf_json, new_idx 44 | 45 | 46 | def files_to_json(files, json_name, start_user_idx=0): 47 | all_users = {} 48 | 49 | for f in files: 50 | f_dir = os.path.join(DIR, f) 51 | f_users = pickle.load(open(f_dir, 'rb')) 52 | 53 | all_users = merge_dicts(all_users, f_users) 54 | 55 | all_users, last_user_idx = to_leaf_format(all_users, start_user_idx) 56 | 57 | with open(os.path.join(FINAL_DIR, json_name), 'w') as outfile: 58 | json.dump(all_users, outfile) 59 | 60 | return last_user_idx 61 | 62 | 63 | def main(): 64 | if not os.path.exists(FINAL_DIR): 65 | os.makedirs(FINAL_DIR) 66 | 67 | files = [f for f in os.listdir(DIR) if f.endswith('.pck')] 68 | files.sort() 69 | 70 | num_files = len(files) 71 | num_json = math.ceil(num_files / FILES_PER_JSON) 72 | 73 | last_user_idx = 0 74 | for i in range(num_json): 75 | cur_files = files[i * FILES_PER_JSON : (i+1) * FILES_PER_JSON] 76 | print('processing until', (i+1) * FILES_PER_JSON) 77 | last_user_idx = files_to_json( 78 | cur_files, 79 | 'reddit_{}.json'.format(i), 80 | last_user_idx) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/reddit/source/merge_raw_users.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pickle 4 | import random 5 | 6 | from reddit_utils import RedditComment 7 | 8 | 9 | DIR = os.path.join('data', 'reddit_raw') 10 | USERS_PER_REPEAT = 200000 11 | USERS_PER_FILE = 20000 12 | FINAL_DIR = os.path.join('data', 'reddit_merged') 13 | 14 | if not os.path.exists(FINAL_DIR): 15 | os.makedirs(FINAL_DIR) 16 | 17 | files = [f for f in os.listdir(DIR) if f.endswith('.pck')] 18 | files.sort() 19 | 20 | all_users = {} 21 | 22 | for f in files: 23 | f_path = os.path.join(DIR, f) 24 | users = pickle.load(open(f_path, 'rb')) 25 | users = list(users.keys()) 26 | 27 | for u in users: 28 | if u not in all_users: 29 | all_users[u] = [] 30 | all_users[u].append(f) 31 | 32 | 33 | user_keys = list(all_users.keys()) 34 | random.seed(3760145) 35 | random.shuffle(user_keys) 36 | 37 | num_users = len(all_users) 38 | 39 | num_lots = (num_users // USERS_PER_REPEAT) + 1 40 | print('num_lots', num_lots) 41 | 42 | cur_file = 1 43 | for l in range(num_lots): 44 | min_idx, max_idx = l * USERS_PER_REPEAT, min((l + 1) * USERS_PER_REPEAT, num_users) 45 | 46 | cur_user_keys = user_keys[min_idx:max_idx] 47 | num_cur_users = len(cur_user_keys) 
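# second pass: re-open every raw shard and pull in this lot's comments;
# all_users[u] lists the shards that contain user u, so unrelated shards are skipped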
48 | cur_users = {u: [] for u in cur_user_keys} 49 | 50 | for f in files: 51 | f_path = os.path.join(DIR, f) 52 | users = pickle.load(open(f_path, 'rb')) 53 | 54 | for u in cur_users: 55 | if f not in all_users[u]: 56 | continue 57 | 58 | cur_users[u].extend([c for c in users[u] if len(c.body) > 0]) 59 | 60 | written_users = 0 61 | while written_users < num_cur_users: 62 | low_bound, high_bound = written_users, min(written_users + USERS_PER_FILE, num_cur_users) 63 | file_keys = cur_user_keys[low_bound:high_bound] 64 | file_users = {u: cur_users[u] for u in file_keys if len(cur_users[u]) >= 5} 65 | 66 | pickle.dump(file_users, open(os.path.join(FINAL_DIR, 'reddit_users_merged_{}.pck'.format(cur_file)), 'wb')) 67 | 68 | written_users += USERS_PER_FILE 69 | cur_file += 1 70 | 71 | print(l + 1) 72 | 73 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/reddit/source/reddit_utils.py: -------------------------------------------------------------------------------- 1 | import html 2 | import re 3 | from nltk.tokenize import TweetTokenizer 4 | 5 | 6 | URL_TOKEN = '<URL>' 7 | USER_TOKEN = '<USER>' 8 | SUBREDDIT_TOKEN = '<SUBREDDIT>' 9 | 10 | URL_REGEX = r'http\S+' 11 | USER_REGEX = r'(?:\/?u\/\w+)' 12 | SUBREDDIT_REGEX = r'(?:\/?r\/\w+)' 13 | 14 | 15 | class RedditComment: 16 | 17 | def __init__(self, reddit_dict): 18 | self.body = reddit_dict['body'] 19 | self.author = reddit_dict['author'] 20 | self.subreddit = reddit_dict['subreddit'] 21 | self.subreddit_id = reddit_dict['subreddit_id'] 22 | self.created_utc = reddit_dict['created_utc'] 23 | self.score = reddit_dict['score'] 24 | 25 | def clean_body(self, tknzr=None): 26 | if tknzr is None: 27 | tknzr = TweetTokenizer() 28 | 29 | # unescape html symbols. 30 | new_body = html.unescape(self.body) 31 | 32 | # remove extraneous whitespace. 33 | new_body = new_body.replace('\n', ' ') 34 | new_body = new_body.replace('\t', ' ') 35 | new_body = re.sub(r'\s+', ' ', new_body).strip() 36 | 37 | # remove non-ascii symbols. 38 | new_body = new_body.encode('ascii', errors='ignore').decode() 39 | 40 | # replace URLs with a special token.
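# the three substitutions below turn, e.g., 'see https://example.com by /u/someone on /r/python'
# into 'see <URL> by <USER> on <SUBREDDIT>'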
41 | new_body = re.sub(URL_REGEX, URL_TOKEN, new_body) 42 | 43 | # replace reddit user with a token 44 | new_body = re.sub(USER_REGEX, USER_TOKEN, new_body) 45 | 46 | # replace subreddit names with a token 47 | new_body = re.sub(SUBREDDIT_REGEX, SUBREDDIT_TOKEN, new_body) 48 | 49 | # lowercase the text 50 | new_body = new_body.casefold() 51 | 52 | # Could be done in addition: 53 | # get rid of comments with quotes 54 | 55 | # tokenize the text 56 | new_body = tknzr.tokenize(new_body) 57 | 58 | self.body = ' '.join(new_body) 59 | 60 | def __str__(self): 61 | return str(vars(self)) 62 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/reddit/source/run_reddit.sh: -------------------------------------------------------------------------------- 1 | python get_raw_users.py 2 | 3 | echo 'Good job with raw' 4 | 5 | python merge_raw_users.py 6 | 7 | echo 'Good job with merging' 8 | 9 | python clean_raw.py 10 | 11 | echo 'Good job with cleaning' 12 | 13 | python delete_small_users.py 14 | 15 | echo 'Good job subsampling' 16 | 17 | python get_json.py 18 | 19 | echo 'Good job creating json' 20 | 21 | python preprocess.py 22 | 23 | echo 'Good job preprocessing' 24 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/sent140/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # download data and convert to .json format 4 | 5 | if [ ! -d "data/all_data" ] || [ ! "$(ls -A data/all_data)" ]; then 6 | cd preprocess 7 | bash data_to_json.sh 8 | cd .. 9 | else 10 | echo "using existing data/all_data data folder to preprocess" 11 | fi 12 | 13 | NAME="sent140" # name of the dataset, equivalent to directory name 14 | 15 | cd ../utils 16 | 17 | bash preprocess.sh --name $NAME $@ 18 | 19 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/sent140/preprocess/combine_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | each row of created .csv file is of the form: 3 | polarity, id, date, query, user, comment, test_or_training 4 | ''' 5 | 6 | import csv 7 | import os 8 | 9 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 10 | 11 | train_file_name = os.path.join(parent_path, 'data', 'raw_data', 'training.csv') 12 | 13 | training = [] 14 | with open(train_file_name, 'rt', encoding='ISO-8859-1') as f: 15 | reader = csv.reader(f) 16 | training = list(reader) 17 | 18 | test_file_name = os.path.join(parent_path, 'data', 'raw_data', 'test.csv') 19 | 20 | test = [] 21 | with open(test_file_name, 'rt', encoding='ISO-8859-1') as f: 22 | reader = csv.reader(f) 23 | test = list(reader) 24 | 25 | out_file_name = os.path.join(parent_path, 'data', 'intermediate', 'all_data.csv') 26 | 27 | with open(out_file_name, 'w') as f: 28 | writer = csv.writer(f) 29 | 30 | for row in training: 31 | row.append('training') 32 | writer.writerow(row) 33 | 34 | for row in test: 35 | row.append('test') 36 | writer.writerow(row) 37 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/sent140/preprocess/data_to_json.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | 5 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 6 | 7 | data_dir = 
os.path.join(parent_path, 'data', 'intermediate', 'all_data.csv') 8 | 9 | data = [] 10 | with open(data_dir, 'rt', encoding='ISO-8859-1') as f: 11 | reader = csv.reader(f) 12 | data = list(reader) 13 | 14 | data = sorted(data, key=lambda x: x[4]) 15 | 16 | # ------------ 17 | # get # of users in data, and list of users (data was sorted by user above) 18 | 19 | num_users = 1 20 | cuser = data[0][4] 21 | users = [cuser] 22 | 23 | for i in range(len(data)): 24 | row = data[i] 25 | tuser = row[4] 26 | if tuser != cuser: 27 | num_users += 1 28 | cuser = tuser 29 | users.append(tuser) 30 | 31 | # ------------ 32 | # get # of samples for each user 33 | 34 | num_samples = [0 for _ in range(num_users)] 35 | cuser = data[0][4] 36 | user_i = 0 37 | 38 | for i in range(len(data)): 39 | row = data[i] 40 | tuser = row[4] 41 | if tuser != cuser: 42 | cuser = tuser 43 | user_i += 1 44 | num_samples[user_i] += 1 45 | 46 | # ------------ 47 | # create user_data 48 | 49 | user_data = {} 50 | row_i = 0 51 | 52 | for u in users: 53 | user_data[u] = {'x': [], 'y': []} 54 | 55 | while row_i < len(data) and data[row_i][4] == u: 56 | row = data[row_i] 57 | y = 1 if row[0] == "4" else 0 58 | user_data[u]['x'].append(row[1:]) 59 | user_data[u]['y'].append(y) 60 | 61 | row_i += 1 62 | 63 | # ------------ 64 | # create .json file 65 | 66 | all_data = {} 67 | all_data['users'] = users 68 | all_data['num_samples'] = num_samples 69 | all_data['user_data'] = user_data 70 | 71 | file_path = os.path.join(parent_path, 'data', 'all_data', 'all_data.json') 72 | 73 | with open(file_path, 'w') as outfile: 74 | json.dump(all_data, outfile) 75 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/sent140/preprocess/data_to_json.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -d "../data" ]; then 4 | mkdir ../data 5 | fi 6 | if [ ! -d "../data/raw_data" ]; then 7 | mkdir ../data/raw_data 8 | fi 9 | if [ ! -f ../data/raw_data/training.csv ] || [ ! -f ../data/raw_data/test.csv ]; then 10 | echo "------------------------------" 11 | echo "retrieving raw data" 12 | 13 | bash get_data.sh 14 | echo "finished retrieving raw data" 15 | else 16 | echo "using existing retrieved raw data" 17 | fi 18 | 19 | echo "generating intermediate data" 20 | if [ ! -d "../data/intermediate" ]; then 21 | mkdir ../data/intermediate 22 | fi 23 | 24 | if [ ! "$(ls -A ../data/intermediate)" ]; then 25 | echo "------------------------------" 26 | echo "combining raw_data .csv files" 27 | python3 combine_data.py 28 | echo "finished combining raw_data .csv files" 29 | else 30 | echo "using existing data/intermediate" 31 | fi 32 | 33 | if [ ! -d "../data/all_data" ]; then 34 | mkdir ../data/all_data 35 | fi 36 | 37 | if [ ! "$(ls -A ../data/all_data)" ]; then 38 | echo "------------------------------" 39 | echo "converting data to .json format" 40 | python3 data_to_json.py 41 | echo "finished converting data to .json format" 42 | else 43 | echo "using existing data/all_data" 44 | fi 45 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/sent140/preprocess/get_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../data/raw_data 4 | 5 | if [ !
-f trainingandtestdata.zip ]; then 6 | echo "downloading trainingandtestdata.zip" 7 | wget --no-check-certificate http://cs.stanford.edu/people/alecmgo/trainingandtestdata.zip 8 | else 9 | echo "using existing trainingandtestdata.zip" 10 | fi 11 | 12 | if [ ! -f training.csv ] || [ ! -f test.csv ]; then 13 | echo "unzipping trainingandtestdata.zip" 14 | unzip trainingandtestdata.zip 15 | mv training.1600000.processed.noemoticon.csv training.csv 16 | mv testdata.manual.2009.06.14.csv test.csv 17 | else 18 | echo "using existing training.csv and test.csv in data/raw_data" 19 | fi 20 | #rm trainingandtestdata.zip 21 | 22 | cd ../../preprocess -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/sent140/stats.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME="sent140" 4 | 5 | cd ../utils 6 | 7 | python3 stats.py --name $NAME 8 | 9 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/shakespeare/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # download data and convert to .json format 4 | 5 | RAWTAG="" 6 | if [[ $@ = *"--raw"* ]]; then 7 | RAWTAG="--raw" 8 | fi 9 | 10 | if [ ! -d "data/all_data" ] || [ ! "$(ls -A data/all_data)" ]; then 11 | cd preprocess 12 | bash data_to_json.sh $RAWTAG 13 | cd .. 14 | else 15 | echo "using existing data/all_data data folder to preprocess" 16 | fi 17 | 18 | NAME="shakespeare" 19 | 20 | cd ../utils 21 | 22 | bash preprocess.sh --name $NAME $@ 23 | 24 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/shakespeare/preprocess/data_to_json.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -d "../data" ]; then 4 | mkdir ../data 5 | fi 6 | 7 | if [ ! -d "../data/raw_data" ]; then 8 | mkdir ../data/raw_data 9 | fi 10 | 11 | if [ ! -f ../data/raw_data/raw_data.txt ]; then 12 | bash get_data.sh 13 | else 14 | echo "using existing data/raw_data/raw_data.txt" 15 | fi 16 | 17 | if [ ! -d "../data/raw_data/by_play_and_character" ] || [ ! -f ../data/raw_data/users_and_plays.json ]; then 18 | echo "dividing txt data between users" 19 | python3 preprocess_shakespeare.py ../data/raw_data/raw_data.txt ../data/raw_data/ 20 | else 21 | echo "using existing divided files of txt data" 22 | fi 23 | 24 | RAWTAG="" 25 | if [[ $@ = *"--raw"* ]]; then 26 | RAWTAG="--raw" 27 | fi 28 | if [ ! -d "../data/all_data" ]; then 29 | mkdir ../data/all_data 30 | fi 31 | if [ ! 
"$(ls -A ../data/all_data)" ]; then 32 | echo "generating all_data.json in data/all_data" 33 | python3 gen_all_data.py $RAWTAG 34 | fi -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/shakespeare/preprocess/gen_all_data.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import json 4 | import os 5 | 6 | from shake_utils import parse_data_in 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--raw', 11 | help='include users\' raw .txt data in respective .json files', 12 | action="store_true") 13 | 14 | parser.set_defaults(raw=False) 15 | 16 | args = parser.parse_args() 17 | 18 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 19 | 20 | users_and_plays_path = os.path.join(parent_path, 'data', 'raw_data', 'users_and_plays.json') 21 | txt_dir = os.path.join(parent_path, 'data', 'raw_data', 'by_play_and_character') 22 | json_data = parse_data_in(txt_dir, users_and_plays_path, args.raw) 23 | json_path = os.path.join(parent_path, 'data', 'all_data', 'all_data.json') 24 | with open(json_path, 'w') as outfile: 25 | json.dump(json_data, outfile) 26 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/shakespeare/preprocess/get_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../data/raw_data 4 | 5 | if [ ! -f 1994-01-100.zip ]; then 6 | wget http://www.gutenberg.org/files/100/old/1994-01-100.zip 7 | else 8 | echo "using existing 1994-01-100.zip" 9 | fi 10 | 11 | echo "unzipping 1994-01-100.zip" 12 | unzip 1994-01-100.zip 13 | #rm 1994-01-100.zip 14 | mv 100.txt raw_data.txt 15 | 16 | #wget --adjust-extension http://www.gutenberg.org/files/100/100-0.txt 17 | #mv 100-0.txt raw_data.txt 18 | 19 | cd ../../preprocess -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/shakespeare/preprocess/shake_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | helper functions for preprocessing shakespeare data 3 | ''' 4 | 5 | import json 6 | import os 7 | import re 8 | 9 | def __txt_to_data(txt_dir, seq_length=80): 10 | """Parses text file in given directory into data for next-character model. 
11 | 12 | Args: 13 | txt_dir: path to text file 14 | seq_length: length of strings in X 15 | """ 16 | raw_text = "" 17 | with open(txt_dir, 'r') as inf: 18 | raw_text = inf.read() 19 | raw_text = raw_text.replace('\n', ' ') 20 | raw_text = re.sub(r" +", r' ', raw_text) # collapse runs of spaces into one 21 | dataX = [] 22 | dataY = [] 23 | for i in range(0, len(raw_text) - seq_length, 1): 24 | seq_in = raw_text[i:i + seq_length] 25 | seq_out = raw_text[i + seq_length] 26 | dataX.append(seq_in) 27 | dataY.append(seq_out) 28 | return dataX, dataY 29 | 30 | def parse_data_in(data_dir, users_and_plays_path, raw=False): 31 | ''' 32 | returns dictionary with keys: users, num_samples, user_data 33 | raw := bool representing whether to include raw text in all_data 34 | if raw is True, each user_data entry also stores the original passage under the 'raw' key 35 | removes users with no data 36 | ''' 37 | with open(users_and_plays_path, 'r') as inf: 38 | users_and_plays = json.load(inf) 39 | files = os.listdir(data_dir) 40 | users = [] 41 | hierarchies = [] 42 | num_samples = [] 43 | user_data = {} 44 | for f in files: 45 | user = f[:-4] 46 | passage = '' 47 | filename = os.path.join(data_dir, f) 48 | with open(filename, 'r') as inf: 49 | passage = inf.read() 50 | dataX, dataY = __txt_to_data(filename) 51 | if len(dataX) > 0: 52 | users.append(user) 53 | if raw: 54 | user_data[user] = {'raw': passage} 55 | else: 56 | user_data[user] = {} 57 | user_data[user]['x'] = dataX 58 | user_data[user]['y'] = dataY 59 | hierarchies.append(users_and_plays[user]) 60 | num_samples.append(len(dataY)) 61 | all_data = {} 62 | all_data['users'] = users 63 | all_data['hierarchies'] = hierarchies 64 | all_data['num_samples'] = num_samples 65 | all_data['user_data'] = user_data 66 | return all_data 67 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/shakespeare/stats.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME="shakespeare" 4 | 5 | cd ../utils 6 | 7 | python3 stats.py --name $NAME 8 | 9 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/synthetic/README.md: -------------------------------------------------------------------------------- 1 | # Synthetic Dataset 2 | 3 | We propose a process to generate synthetic federated datasets. The dataset is inspired by the one presented by [Li et al.](https://arxiv.org/abs/1905.10497), but allows for additional heterogeneity designed to make current meta-learning methods (such as [Reptile](https://openai.com/blog/reptile/)) struggle. The high-level goal is to create tasks whose true models are (1) task-dependent, and (2) clustered around more than just one center. To see a description of the whole generative process, please refer to the LEAF paper. 4 | 5 | We note that, at the moment, we default to one cluster of models in our code. This can be easily changed by modifying the PROB_CLUSTERS constant in ```main.py```. 6 | 7 | ## Setup Instructions 8 | - pip3 install numpy 9 | - pip3 install pillow 10 | - Run ```python main.py -num-tasks 1000 -num-classes 5 -num-dim 60``` to generate the initial data. 11 | - Run ```./preprocess.sh``` (as with the other LEAF datasets) to produce the final data splits. We suggest using the following tags: 12 | - ```--sf``` := fraction of data to sample, written as a decimal; set it to 1.0 in order to keep the number of tasks/users specified earlier. 13 | - ```-k``` := minimum number of samples per user; set it to 5.
14 | - ```-t``` := 'user' to partition users into train-test groups, or 'sample' to partition each user's samples into train-test groups. 15 | - ```--tf``` := fraction of data in training set, written as a decimal; default is 0.9. 16 | - ```--smplseed``` := seed to be used before random sampling of data. 17 | - ```--spltseed``` := seed to be used before random split of data. 18 | 19 | e.g. 20 | - ```./preprocess.sh -s niid --sf 1.0 -k 5 -t sample --tf 0.6``` 21 | 22 | Make sure to delete the rem_user_data, sampled_data, test, and train subfolders in the data directory before re-running preprocess.sh. 23 | 24 | ## Notes 25 | - More details on ```preprocess.sh```: 26 | - The order in which ```preprocess.sh``` processes data is 1. generating all_data (done here by the ```main.py``` script), 2. sampling, 3. removing users, and 4. creating train-test split. The script will look at the data in the last generated directory and continue preprocessing from that point. For example, if the ```all_data``` directory has already been generated and the user decides to skip sampling and only remove users with the ```-k``` tag (i.e. running ```preprocess.sh -k 50```), the script will effectively apply a remove-user filter to data in ```all_data``` and place the resulting data in the ```rem_user_data``` directory. 27 | - File names provide information about the preprocessing steps taken to generate them. For example, the ```all_data_niid_1_keep_64.json``` file was generated by first sampling 10 percent (.1) of the data in ```all_data.json``` in a non-i.i.d. manner and then applying the ```-k 64``` argument to the resulting data. 28 | - Each .json file is an object with 3 keys: 29 | 1. 'users', a list of users 30 | 2. 'num_samples', a list of the number of samples for each user, and 31 | 3. 'user_data', an object with user names as keys and
32 | - Run ```./stats.sh``` to get statistics of the data (```data/all_data/data.json``` must have been generated already) 33 | - In order to run reference implementations in the ```../models``` directory, the ```-t sample``` tag must be used when running ```./preprocess.sh``` 34 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/synthetic/data_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from scipy.special import softmax 4 | 5 | 6 | NUM_DIM = 10 7 | 8 | class SyntheticDataset: 9 | 10 |     def __init__( 11 |             self, 12 |             num_classes=2, 13 |             seed=931231, 14 |             num_dim=NUM_DIM, 15 |             prob_clusters=[0.5, 0.5]): 16 | 17 |         np.random.seed(seed) 18 | 19 |         self.num_classes = num_classes 20 |         self.num_dim = num_dim 21 |         self.num_clusters = len(prob_clusters) 22 |         self.prob_clusters = prob_clusters 23 | 24 |         self.side_info_dim = self.num_clusters 25 | 26 |         self.Q = np.random.normal( 27 |             loc=0.0, scale=1.0, size=(self.num_dim + 1, self.num_classes, self.side_info_dim)) 28 | 29 |         self.Sigma = np.zeros((self.num_dim, self.num_dim)) 30 |         for i in range(self.num_dim): 31 |             self.Sigma[i, i] = (i + 1)**(-1.2) 32 | 33 |         self.means = self._generate_clusters() 34 | 35 |     def get_task(self, num_samples): 36 |         cluster_idx = np.random.choice( 37 |             range(self.num_clusters), size=None, replace=True, p=self.prob_clusters) 38 |         new_task = self._generate_task(self.means[cluster_idx], cluster_idx, num_samples) 39 |         return new_task 40 | 41 |     def _generate_clusters(self): 42 |         means = [] 43 |         for i in range(self.num_clusters): 44 |             loc = np.random.normal(loc=0, scale=1., size=None) 45 |             mu = np.random.normal(loc=loc, scale=1., size=self.side_info_dim) 46 |             means.append(mu) 47 |         return means 48 | 49 |     def _generate_x(self, num_samples): 50 |         B = np.random.normal(loc=0.0, scale=1.0, size=None) 51 |         loc = np.random.normal(loc=B, scale=1.0, size=self.num_dim) 52 | 53 |         samples = np.ones((num_samples, self.num_dim + 1)) 54 |         samples[:, 1:] = np.random.multivariate_normal( 55 |             mean=loc, cov=self.Sigma, size=num_samples) 56 | 57 |         return samples 58 | 59 |     def _generate_y(self, x, cluster_mean): 60 |         model_info = np.random.normal(loc=cluster_mean, scale=0.1, size=cluster_mean.shape) 61 |         w = np.matmul(self.Q, model_info) 62 | 63 |         num_samples = x.shape[0] 64 |         prob = softmax(np.matmul(x, w) + np.random.normal(loc=0., scale=0.1, size=(num_samples, self.num_classes)), axis=1) 65 | 66 |         y = np.argmax(prob, axis=1) 67 |         return y, w, model_info 68 | 69 |     def _generate_task(self, cluster_mean, cluster_id, num_samples): 70 |         x = self._generate_x(num_samples) 71 |         y, w, model_info = self._generate_y(x, cluster_mean) 72 | 73 |         # now that we have y, we can remove the bias coeff 74 |         x = x[:, 1:] 75 | 76 |         return {'x': x, 'y': y, 'w': w, 'model_info': model_info, 'cluster': cluster_id} -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/synthetic/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | import data_generator as generator 7 | 8 | 9 | PROB_CLUSTERS = [1.0] 10 | 11 | 12 | def main(): 13 |     args = parse_args() 14 |     np.random.seed(args.seed) 15 | 16 |     print('Generating dataset') 17 |     num_samples = get_num_samples(args.num_tasks) 18 |     dataset = generator.SyntheticDataset( 19 |         num_classes=args.num_classes, prob_clusters=PROB_CLUSTERS,
num_dim=args.num_dim, seed=args.seed) 20 | tasks = [dataset.get_task(s) for s in num_samples] 21 | users, num_samples, user_data = to_leaf_format(tasks) 22 | save_json('data/all_data', 'data.json', users, num_samples, user_data) 23 | print('Done :D') 24 | 25 | 26 | def get_num_samples(num_tasks, min_num_samples=5, max_num_samples=1000): 27 | num_samples = np.random.lognormal(3, 2, (num_tasks)).astype(int) 28 | num_samples = [min(s + min_num_samples, max_num_samples) for s in num_samples] 29 | return num_samples 30 | 31 | 32 | def to_leaf_format(tasks): 33 | users, num_samples, user_data = [], [], {} 34 | 35 | for i, t in enumerate(tasks): 36 | x, y = t['x'].tolist(), t['y'].tolist() 37 | u_id = str(i) 38 | 39 | users.append(u_id) 40 | num_samples.append(len(y)) 41 | user_data[u_id] = {'x': x, 'y': y} 42 | 43 | return users, num_samples, user_data 44 | 45 | 46 | def save_json(json_dir, json_name, users, num_samples, user_data): 47 | if not os.path.exists(json_dir): 48 | os.makedirs(json_dir) 49 | 50 | json_file = { 51 | 'users': users, 52 | 'num_samples': num_samples, 53 | 'user_data': user_data, 54 | } 55 | 56 | with open(os.path.join(json_dir, json_name), 'w') as outfile: 57 | json.dump(json_file, outfile) 58 | 59 | 60 | def parse_args(): 61 | parser = argparse.ArgumentParser() 62 | 63 | parser.add_argument( 64 | '-num-tasks', 65 | help='number of devices;', 66 | type=int, 67 | required=True) 68 | parser.add_argument( 69 | '-num-classes', 70 | help='number of classes;', 71 | type=int, 72 | required=True) 73 | parser.add_argument( 74 | '-num-dim', 75 | help='number of dimensions;', 76 | type=int, 77 | required=True) 78 | parser.add_argument( 79 | '-seed', 80 | help='seed for the random processes;', 81 | type=int, 82 | default=931231, 83 | required=False) 84 | return parser.parse_args() 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/synthetic/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # download data and convert to .json format 4 | 5 | if [ ! -f "data/all_data/data.json" ]; then 6 | echo "Please run the main.py script to generate the initial data." 
7 |     exit 1 8 | fi 9 | 10 | 11 | NAME="synthetic" # name of the dataset, equivalent to directory name 12 | 13 | 14 | cd ../utils 15 | 16 | bash preprocess.sh --name $NAME $@ 17 | 18 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/synthetic/stats.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME="synthetic" 4 | 5 | cd ../utils 6 | 7 | python3 stats.py --name $NAME 8 | 9 | cd ../$NAME -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/utils/README.md: -------------------------------------------------------------------------------- 1 | # UTILS README 2 | 3 | This folder contains the LEAF data preprocessing scripts shared by the datasets above. 4 | 5 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/utils/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/utils/constants.py: -------------------------------------------------------------------------------- 1 | DATASETS = ['sent140', 'femnist', 'shakespeare', 'celeba', 'synthetic'] 2 | SEED_FILES = { 'sampling': 'sampling_seed.txt', 'split': 'split_seed.txt' } 3 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/utils/remove_users.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | removes users with fewer than the given number of samples 4 | ''' 5 | 6 | import argparse 7 | import json 8 | import os 9 | 10 | from constants import DATASETS 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--name', 15 |                 help='name of dataset to parse; default: sent140;', 16 |                 type=str, 17 |                 choices=DATASETS, 18 |                 default='sent140') 19 | 20 | parser.add_argument('--min_samples', 21 |                 help='users with fewer than x samples are discarded; default: 10;', 22 |                 type=int, 23 |                 default=10) 24 | 25 | args = parser.parse_args() 26 | 27 | print('------------------------------') 28 | print('removing users with fewer than %d samples' % args.min_samples) 29 | 30 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 31 | dir = os.path.join(parent_path, args.name, 'data') 32 | subdir = os.path.join(dir, 'sampled_data') 33 | files = [] 34 | if os.path.exists(subdir): 35 |     files = os.listdir(subdir) 36 | if len(files) == 0: 37 |     subdir = os.path.join(dir, 'all_data') 38 |     files = os.listdir(subdir) 39 | files = [f for f in files if f.endswith('.json')] 40 | 41 | for f in files: 42 |     users = [] 43 |     hierarchies = [] 44 |     num_samples = [] 45 |     user_data = {} 46 | 47 |     file_dir = os.path.join(subdir, f) 48 |     with open(file_dir, 'r') as inf: 49 |         data = json.load(inf) 50 | 51 |     num_users = len(data['users']) 52 |     for i in range(num_users): 53 |         curr_user = data['users'][i] 54 |         curr_hierarchy = None 55 |         if 'hierarchies' in data: 56 |             curr_hierarchy = data['hierarchies'][i] 57 |         curr_num_samples = data['num_samples'][i] 58 |         if (curr_num_samples >= args.min_samples): 59 |             user_data[curr_user] = data['user_data'][curr_user] 60 |             users.append(curr_user) 61 |             if curr_hierarchy is not None: 62 |                 hierarchies.append(curr_hierarchy)
63 |             num_samples.append(data['num_samples'][i]) 64 | 65 |     all_data = {} 66 |     all_data['users'] = users 67 |     if len(hierarchies) == len(users): 68 |         all_data['hierarchies'] = hierarchies 69 |     all_data['num_samples'] = num_samples 70 |     all_data['user_data'] = user_data 71 | 72 |     file_name = '%s_keep_%d.json' % ((f[:-5]), args.min_samples) 73 |     out_dir = os.path.join(dir, 'rem_user_data', file_name) 74 | 75 |     print('writing %s' % file_name) 76 |     with open(out_dir, 'w') as outfile: 77 |         json.dump(all_data, outfile) 78 | 79 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/utils/stats.py: -------------------------------------------------------------------------------- 1 | ''' 2 | assumes that the user has already generated .json file(s) containing data 3 | ''' 4 | 5 | import argparse 6 | import json 7 | import matplotlib.pyplot as plt 8 | import math 9 | import numpy as np 10 | import os 11 | 12 | from scipy import io 13 | from scipy import stats 14 | 15 | from constants import DATASETS 16 | 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument('--name', 20 |                 help='name of dataset to parse; default: sent140;', 21 |                 type=str, 22 |                 choices=DATASETS, 23 |                 default='sent140') 24 | 25 | args = parser.parse_args() 26 | 27 | 28 | def load_data(name): 29 | 30 |     users = [] 31 |     num_samples = [] 32 | 33 |     parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 34 |     data_dir = os.path.join(parent_path, name, 'data') 35 |     subdir = os.path.join(data_dir, 'all_data') 36 | 37 |     files = os.listdir(subdir) 38 |     files = [f for f in files if f.endswith('.json')] 39 | 40 |     for f in files: 41 |         file_dir = os.path.join(subdir, f) 42 | 43 |         with open(file_dir) as inf: 44 |             data = json.load(inf) 45 | 46 |         users.extend(data['users']) 47 |         num_samples.extend(data['num_samples']) 48 | 49 |     return users, num_samples 50 | 51 | def print_dataset_stats(name): 52 |     users, num_samples = load_data(name) 53 |     num_users = len(users) 54 | 55 |     print('####################################') 56 |     print('DATASET: %s' % name) 57 |     print('%d users' % num_users) 58 |     print('%d samples (total)' % np.sum(num_samples)) 59 |     print('%.2f samples per user (mean)' % np.mean(num_samples)) 60 |     print('num_samples (std): %.2f' % np.std(num_samples)) 61 |     print('num_samples (std/mean): %.2f' % (np.std(num_samples)/np.mean(num_samples))) 62 |     print('num_samples (skewness): %.2f' % stats.skew(num_samples)) 63 | 64 |     bins = [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200] 65 |     if name == 'shakespeare': 66 |         bins = [0, 2000, 4000, 6000, 8000, 10000, 12000, 14000, 16000, 18000, 20000] 67 |     if name == 'femnist': 68 |         bins = [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 320, 340, 360, 380, 400, 420, 440, 460, 480, 500] 69 |     if name == 'celeba': 70 |         bins = [2 * i for i in range(20)] 71 |     if name == 'sent140': 72 |         bins = [i for i in range(16)] 73 | 74 |     hist, edges = np.histogram(num_samples, bins=bins) 75 |     for e, h in zip(edges, hist): 76 |         print(e, "\t", h) 77 | 78 |     parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 79 |     data_dir = os.path.join(parent_path, name, 'data') 80 | 81 |     plt.hist(num_samples, bins = bins) 82 |     fig_name = "%s_hist_nolabel.png" % name 83 |     fig_dir = os.path.join(data_dir, fig_name) 84 |     plt.savefig(fig_dir) 85 |     plt.title(name) 86 |     plt.xlabel('number of samples') 87 |     plt.ylabel("number of users") 88 |     fig_name = "%s_hist.png" % name 89 |     fig_dir =
os.path.join(data_dir, fig_name) 90 |     plt.savefig(fig_dir) 91 | 92 | print_dataset_stats(args.name) 93 | -------------------------------------------------------------------------------- /fedlab_benchmarks/datasets/utils/util.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | def save_obj(obj, name): 5 |     with open(name + '.pkl', 'wb') as f: 6 |         pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 7 | 8 | 9 | def load_obj(name): 10 |     with open(name + '.pkl', 'rb') as f: 11 |         return pickle.load(f) 12 | 13 | 14 | def iid_divide(l, g): 15 |     ''' 16 |     divide list l among g groups 17 |     each group has either int(len(l)/g) or int(len(l)/g)+1 elements 18 |     returns a list of groups 19 |     ''' 20 |     num_elems = len(l) 21 |     group_size = int(len(l)/g) 22 |     num_big_groups = num_elems - g * group_size 23 |     num_small_groups = g - num_big_groups 24 |     glist = [] 25 |     for i in range(num_small_groups): 26 |         glist.append(l[group_size * i : group_size * (i + 1)]) 27 |     bi = group_size*num_small_groups 28 |     group_size += 1 29 |     for i in range(num_big_groups): 30 |         glist.append(l[bi + group_size * i:bi + group_size * (i + 1)]) 31 |     return glist 32 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feature-skew-fedavg/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | Demo for the NIID-bench "feature distribution skew"-"noise-based feature imbalance" setting on FedAvg. 4 | 5 | __Setting:__ 6 | - `noise=0.1` (Gaussian noise) 7 | - `partition='homo'` (IID partition) 8 | - `n_parties=10` 9 | - `lr=0.01` 10 | - `momentum=0.9` 11 | - `weight_decay=1e-5` (L2-norm weight decay) 12 | - `comm_round=50` 13 | - `sample=1` (client sampling ratio) 14 | - `alg='fedavg'` 15 | - `epochs=10` 16 | - `dataset='fmnist'` 17 | - `model='simple-cnn'` 18 | 19 | Top-1 accuracy for FMNIST in the paper: $89.1\% \pm 0.3\%$. 20 | 21 | Top-1 accuracy for FMNIST in this demo: $89.37\% \pm 0.14\%$ (5 runs). 22 | 23 | ## Requirements 24 | 25 | fedlab==1.1.2 26 | 27 | ## How to Run? 28 | 29 | `start_server.sh` launches the server process, and `start_clt.sh` launches the client processes. 30 | 31 | 1. Run this command in terminal window 1 to launch the server: 32 | 33 |    ```bash 34 |    bash start_server.sh 35 |    ``` 36 | 37 | 2. Run this command in terminal window 2 to launch the clients: 38 | 39 |    ```bash 40 |    bash start_clt.sh 41 |    ``` 42 | 43 | > The random seed for the data partition over clients can be set using `--seed` in `start_clt.sh`: 44 | > 45 | > ```shell 46 | > python data_partition.py --num-clients 10 --seed 1 47 | > ``` 48 | > 49 | > And the noise level can be set with `--noise`: 50 | > 51 | > ```bash 52 | > python client.py --world_size 2 --rank 1 --noise 0.1 53 | > ``` 54 | 55 | 56 | 57 | We highly recommend launching the clients after the server is up to avoid connection conflicts. 58 | 59 | 60 | 61 | ## References 62 | 63 | - Li, Q., Diao, Y., Chen, Q., & He, B. (2021). Federated learning on non-iid data silos: An experimental study. *arXiv preprint arXiv:2102.02079*.
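To make the noise-based setting above concrete: in NIID-bench, client $i$ of $N$ adds Gaussian noise whose standard deviation grows with the client index, roughly $\hat{x} = x + \text{Gau}(\sigma \cdot i / N)$. A minimal sketch of that transform (a hypothetical helper for illustration, not this demo's actual API):

```python
import torch

def add_client_noise(x, client_id, num_clients, noise=0.1):
    """x_hat = x + Gauss(0, noise * client_id / num_clients): higher-index
    clients get noisier inputs, producing feature distribution skew."""
    std = noise * client_id / num_clients
    return x + torch.randn_like(x) * std
```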
64 | 65 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feature-skew-fedavg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/feature-skew-fedavg/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/feature-skew-fedavg/config.py: -------------------------------------------------------------------------------- 1 | fmnist_noise_baseline_config = { 2 | "partition": "homo", 3 | "round": 50, 4 | "network": "simple-cnn", 5 | "sample_ratio": 1, 6 | "dataset": "fmnist", 7 | "total_client_num": 10, 8 | "lr": 0.01, 9 | "momentum": 0.9, 10 | "weight_decay": 1e-5, 11 | "batch_size": 64, 12 | "test_batch_size": 32, 13 | "epochs": 10 14 | } 15 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feature-skew-fedavg/data_partition.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import sys 4 | 5 | from fedlab.utils.dataset import FMNISTPartitioner 6 | from fedlab.utils.functional import save_dict 7 | 8 | from torchvision.datasets import FashionMNIST 9 | 10 | 11 | # python data_partition.py --num-clients 10 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser(description='Data partition') 14 | 15 | parser.add_argument('--num-clients', type=int, default=10) 16 | parser.add_argument('--seed', type=int, default=2021) 17 | args = parser.parse_args() 18 | 19 | root = "../../../datasets/FMNIST" 20 | trainset = FashionMNIST(root=root, train=True, download=True) 21 | 22 | # perform partition 23 | partition = FMNISTPartitioner(trainset.targets, 24 | num_clients=args.num_clients, 25 | partition="iid", 26 | seed=args.seed) 27 | save_dict(partition.client_dict, "fmnist_iid.pkl") 28 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feature-skew-fedavg/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | Code below is from NIID-bench official code: 7 | https://github.com/Xtra-Computing/NIID-Bench 8 | """ 9 | class SimpleCNNMNIST(nn.Module): 10 | def __init__(self, input_dim, hidden_dims, output_dim=10): 11 | super(SimpleCNNMNIST, self).__init__() 12 | self.conv1 = nn.Conv2d(1, 6, 5) 13 | self.pool = nn.MaxPool2d(2, 2) 14 | self.conv2 = nn.Conv2d(6, 16, 5) 15 | 16 | # for now, we hard coded this network 17 | # i.e. we fix the number of hidden layers i.e. 
2 layers 18 | self.fc1 = nn.Linear(input_dim, hidden_dims[0]) 19 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1]) 20 | self.fc3 = nn.Linear(hidden_dims[1], output_dim) 21 | 22 | def forward(self, x): 23 | x = self.pool(F.relu(self.conv1(x))) 24 | x = self.pool(F.relu(self.conv2(x))) 25 | x = x.view(-1, 16 * 4 * 4) 26 | 27 | x = F.relu(self.fc1(x)) 28 | x = F.relu(self.fc2(x)) 29 | x = self.fc3(x) 30 | return x 31 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feature-skew-fedavg/start_clt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # perform data partition over clients 3 | python data_partition.py --num-clients 10 --seed 1 4 | echo -e "New data partition done \n\n" 5 | # launch 10 clients in single serial trainer 6 | python client.py --world_size 2 --rank 1 --noise 0.1 7 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feature-skew-fedavg/start_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python server.py --world_size 2 --run 0 3 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedasync/README.md: -------------------------------------------------------------------------------- 1 | # Fedasync 2 | 3 | The demo of [asynchronous federated optimization](https://arxiv.org/abs/1903.03934). 4 | 5 | ## Requirements 6 | 7 | fedlab==1.1.2 8 | 9 | ## Run 10 | 11 | standalone: 12 | 13 | `$ cd standalone` 14 | 15 | `$ python standalone.py` 16 | 17 | 18 | cross_process: 19 | 20 | 21 | `$ cd cross_process` 22 | 23 | `$ bash quick_start.sh` 24 | 25 | ## Performance 26 | 27 | Null 28 | 29 | ## References 30 | 31 | Xie, Cong, Sanmi Koyejo, and Indranil Gupta. "Asynchronous federated optimization." arXiv preprint arXiv:1903.03934 (2019). 
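For orientation, the server update at the heart of FedAsync is a staleness-weighted moving average between the current server model and an arriving client model. A minimal sketch of that rule under the paper's polynomial staleness function (an illustrative helper, not FedLab's `AsyncParameterServerHandler` internals):

```python
import torch

def fedasync_update(server_params, client_params, alpha, staleness, a=0.5):
    """x_t = (1 - alpha_t) * x_{t-1} + alpha_t * x_client, where the base
    mixing rate alpha is damped polynomially by the update's staleness:
    alpha_t = alpha * (staleness + 1) ** (-a)."""
    alpha_t = alpha * (staleness + 1) ** (-a)
    return (1 - alpha_t) * server_params + alpha_t * client_params
```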
32 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedasync/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedasync/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/fedasync/cross_process/client.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torchvision.transforms as transforms 3 | import torch 4 | import argparse 5 | import sys 6 | import os 7 | 8 | from torch import nn 9 | 10 | from fedlab.core.client.manager import ClientActiveManager 11 | from fedlab.core.client.trainer import ClientSGDTrainer 12 | from fedlab.utils.dataset.sampler import RawPartitionSampler 13 | from fedlab.core.network import DistNetwork 14 | 15 | sys.path.append("../../") 16 | from models.cnn import CNN_MNIST 17 | 18 | def get_dataset(args): 19 |     """Build the MNIST train/test dataloaders for this client. 20 | 21 |     The train split is selected by args.rank via RawPartitionSampler; 22 |     the test loader covers the full test set. 23 |     :return: iterators for the dataset 24 |     """ 25 |     train_transform = transforms.Compose([ 26 |         transforms.ToTensor(), 27 |     ]) 28 | 29 |     test_transform = transforms.Compose([ 30 |         transforms.ToTensor(), 31 |     ]) 32 |     trainset = torchvision.datasets.MNIST(root=args.root, 33 |                                           train=True, 34 |                                           download=True, 35 |                                           transform=train_transform) 36 |     testset = torchvision.datasets.MNIST(root=args.root, 37 |                                          train=False, 38 |                                          download=True, 39 |                                          transform=test_transform) 40 | 41 |     trainloader = torch.utils.data.DataLoader( 42 |         trainset, 43 |         sampler=RawPartitionSampler(trainset, 44 |                                     client_id=args.rank, 45 |                                     num_replicas=args.world_size - 1), 46 |         batch_size=128, 47 |         drop_last=True, 48 |         num_workers=2) 49 |     testloader = torch.utils.data.DataLoader(testset, 50 |                                              batch_size=len(testset), 51 |                                              drop_last=False, 52 |                                              num_workers=2, 53 |                                              shuffle=False) 54 |     return trainloader, testloader 55 | 56 | 57 | if __name__ == "__main__": 58 |     parser = argparse.ArgumentParser(description='Distbelief training example') 59 |     parser.add_argument('--ip', type=str, default='127.0.0.1') 60 |     parser.add_argument('--port', type=str, default='3002') 61 |     parser.add_argument('--world_size', type=int) 62 |     parser.add_argument('--rank', type=int) 63 | 64 |     parser.add_argument("--epoch", type=int, default=2) 65 |     parser.add_argument("--lr", type=float, default=0.1) 66 |     parser.add_argument("--wd", type=float, default=0) 67 |     parser.add_argument("--cuda", type=bool, default=True) 68 |     args = parser.parse_args() 69 |     args.root = '../../datasets/mnist/' 70 |     args.cuda = True 71 | 72 |     model = CNN_MNIST() 73 |     trainloader, testloader = get_dataset(args) 74 |     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) 75 |     criterion = nn.CrossEntropyLoss() 76 |     handler = ClientSGDTrainer(model, 77 |                                trainloader, 78 |                                epochs=args.epoch, 79 |                                optimizer=optimizer, 80 |                                criterion=criterion, 81 |                                cuda=args.cuda) 82 | 83 |     network = DistNetwork(address=(args.ip, args.port), 84 |                           world_size=args.world_size, 85 |                           rank=args.rank) 86 |     Manager = ClientActiveManager(trainer=handler, network=network) 87 |     Manager.run() 88 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedasync/cross_process/quick_start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | python
server.py --world_size 3 & 5 | 6 | python client.py --world_size 3 --rank 1 & 7 | python client.py --world_size 3 --rank 2 & 8 | 9 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/fedasync/cross_process/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from fedlab.core.network import DistNetwork 5 | from fedlab.core.server.handler import AsyncParameterServerHandler 6 | from fedlab.core.server.manager import ServerAsynchronousManager 7 | sys.path.append("../../") 8 | from models.cnn import CNN_MNIST 9 | import argparse 10 | 11 | if __name__ == "__main__": 12 |     parser = argparse.ArgumentParser(description='Distbelief training example') 13 | 14 |     parser.add_argument('--ip', type=str, default='127.0.0.1') 15 |     parser.add_argument('--port', type=str, default='3002') 16 |     parser.add_argument('--world_size', type=int) 17 |     args = parser.parse_args() 18 | 19 |     model = CNN_MNIST().cpu() 20 |     ps = AsyncParameterServerHandler(model) 21 | 22 |     network = DistNetwork(address=(args.ip, args.port), 23 |                           world_size=args.world_size, 24 |                           rank=0) 25 |     Manager = ServerAsynchronousManager(handler=ps, network=network) 26 | 27 |     Manager.run() 28 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedasync/standalone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedasync/standalone/cifar10_iid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedasync/standalone/cifar10_iid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedasync/standalone/cifar10_noniid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedasync/standalone/cifar10_noniid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/README.md: -------------------------------------------------------------------------------- 1 | # FedAvg 2 | 3 | [FedAvg](http://proceedings.mlr.press/v54/mcmahan17a.html) is the baseline synchronous federated learning algorithm; FedLab implements its algorithm flow for both standalone and Cross Process scenarios.
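The aggregation step this benchmark relies on (exposed later in these demos as `Aggregators.fedavg_aggregate`) is a sample-count-weighted average of the clients' serialized models. A minimal sketch of the rule itself, not FedLab's actual implementation:

```python
import torch

def fedavg_aggregate(serialized_params, num_samples):
    """Sample-count-weighted average: w = sum_k (n_k / n) * w_k."""
    weights = torch.tensor(num_samples, dtype=torch.float32)
    weights = weights / weights.sum()
    stacked = torch.stack(serialized_params, dim=0)   # (K, num_model_params)
    return (stacked * weights.unsqueeze(1)).sum(dim=0)

# e.g. two clients with 60 and 40 samples:
# fedavg_aggregate([w1, w2], [60, 40]) == 0.6 * w1 + 0.4 * w2
```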
4 | 5 | ## Requirements 6 | 7 | fedlab==1.1.2 8 | 9 | ## Run 10 | 11 | ### Standalone 12 | 13 | The `SerialTrainer` module is for FL system simulation on a single machine; its source code can be found in `fedlab/core/client/trainer.py`. 14 | 15 | Executable scripts are in `fedlab_benchmarks/algorithm/fedavg/standalone/`. 16 | 17 | ### Cross Process 18 | 19 | The federated simulation of **multi-machine** and **single-machine multi-process** scenarios is the core module of FedLab. It is composed of various modules in `core/client` and `core/server`; please refer to the overview for details. 20 | 21 | The executable script is in `fedlab_benchmarks/algorithm/fedavg/cross_process/`. 22 | 23 | ## Performance 24 | 25 | Null 26 | 27 | ## References 28 | 29 | McMahan, Brendan, et al. "Communication-efficient learning of deep networks from decentralized data." Artificial intelligence and statistics. PMLR, 2017. -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/cross_process/LEAF_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # get world size dicts for all datasets and run all client processes for each dataset 4 | declare -A dataset_world_size=(['femnist']=5 ['shakespeare']=3)  # each process represents one client/server 5 | 6 | for key in ${!dataset_world_size[*]} 7 | do 8 |   echo "${key} client_num is ${dataset_world_size[${key}]}" 9 | 10 |   echo "server started" 11 |   python server.py --ip 127.0.0.1 --port 3002 --world_size ${dataset_world_size[${key}]} --dataset ${key} & 12 | 13 |   for ((i=1; i<${dataset_world_size[$key]}; i++)) 14 |   do 15 |   { 16 |     echo "client ${i} started" 17 |     python client.py --ip 127.0.0.1 --port 3002 --world_size ${dataset_world_size[${key}]} --rank ${i} --dataset ${key} --epoch 2 18 |   } & 19 |   done 20 |   wait 21 | 22 |   echo "${key} experiment end" 23 | done -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/cross_process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/cross_process/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/cross_process/client.py: -------------------------------------------------------------------------------- 1 | from logging import log 2 | import torch 3 | import argparse 4 | import sys 5 | import os 6 | 7 | from torch import nn 8 | from fedlab.core.client.manager import ClientPassiveManager 9 | from fedlab.core.client.trainer import ClientSGDTrainer 10 | from fedlab.core.network import DistNetwork 11 | from fedlab.utils.logger import Logger 12 | 13 | from setting import get_model, get_dataset 14 | 15 | if __name__ == "__main__": 16 | 17 |     parser = argparse.ArgumentParser(description="Distbelief training example") 18 | 19 |     parser.add_argument("--ip", type=str) 20 |     parser.add_argument("--port", type=str) 21 |
parser.add_argument("--world_size", type=int) 22 |     parser.add_argument("--rank", type=int) 23 | 24 |     parser.add_argument("--lr", type=float, default=0.01) 25 |     parser.add_argument("--epoch", type=int, default=5) 26 |     parser.add_argument("--dataset", type=str) 27 |     parser.add_argument("--batch_size", type=int, default=100) 28 | 29 |     parser.add_argument("--gpu", type=str, default="0,1,2,3") 30 |     parser.add_argument("--ethernet", type=str, default=None) 31 |     args = parser.parse_args() 32 | 33 |     if args.gpu != "-1": 34 |         args.cuda = True 35 |         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 36 |     else: 37 |         args.cuda = False 38 | 39 |     model = get_model(args) 40 |     trainloader, testloader = get_dataset(args) 41 |     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) 42 |     criterion = nn.CrossEntropyLoss() 43 | 44 |     network = DistNetwork( 45 |         address=(args.ip, args.port), 46 |         world_size=args.world_size, 47 |         rank=args.rank, 48 |         ethernet=args.ethernet, 49 |     ) 50 | 51 |     LOGGER = Logger(log_name="client " + str(args.rank)) 52 | 53 |     trainer = ClientSGDTrainer( 54 |         model, 55 |         trainloader, 56 |         epochs=args.epoch, 57 |         optimizer=optimizer, 58 |         criterion=criterion, 59 |         cuda=args.cuda, 60 |         logger=LOGGER, 61 |     ) 62 | 63 |     manager_ = ClientPassiveManager(trainer=trainer, 64 |                                     network=network, 65 |                                     logger=LOGGER) 66 |     manager_.run() 67 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/cross_process/quick_start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 3 --dataset mnist --round 3 & 4 | 5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 3 --rank 1 --dataset mnist & 6 | 7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 3 --rank 2 --dataset mnist & 8 | 9 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/cross_process/server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from fedlab.utils.logger import Logger 4 | from fedlab.core.server.handler import SyncParameterServerHandler 5 | from fedlab.core.server.manager import ServerSynchronousManager 6 | from fedlab.core.network import DistNetwork 7 | from setting import get_model, get_dataset 8 | 9 | if __name__ == "__main__": 10 |     parser = argparse.ArgumentParser(description='FL server example') 11 | 12 |     parser.add_argument('--ip', type=str) 13 |     parser.add_argument('--port', type=str) 14 |     parser.add_argument('--world_size', type=int) 15 | 16 |     parser.add_argument('--round', type=int, default=5) 17 |     parser.add_argument('--dataset', type=str) 18 |     parser.add_argument('--ethernet', type=str, default=None) 19 |     parser.add_argument('--sample', type=float, default=1) 20 | 21 |     args = parser.parse_args() 22 | 23 |     model = get_model(args) 24 |     LOGGER = Logger(log_name="server") 25 |     handler = SyncParameterServerHandler(model, 26 |                                          global_round=args.round, 27 |                                          logger=LOGGER, 28 |                                          sample_ratio=args.sample) 29 |     network = DistNetwork(address=(args.ip, args.port), 30 |                           world_size=args.world_size, 31 |                           rank=0, 32 |                           ethernet=args.ethernet) 33 |     manager_ = ServerSynchronousManager(handler=handler, 34 |                                         network=network, 35 |                                         logger=LOGGER) 36 |     manager_.run() 37 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/cross_process/start_clients.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # start a group of clients; continuous ranks are required. 4 | # example: bash start_clients.sh ip port world_size 1 5 dataset ethernet epoch 5 | # starts clients with rank 1-5 6 | 7 | echo "Connecting server:($1:$2), world_size $3, rank $4-$5, dataset $6, ethernet $7, local epoch $8" 8 | 9 | for ((i=$4; i<=$5; i++)) 10 | do 11 | { 12 |     echo "client ${i} started" 13 |     python client.py --ip $1 --port $2 --world_size $3 --rank ${i} --dataset $6 --ethernet $7 --epoch $8 & 14 |     sleep 2s # wait for gpu resource allocation 15 | } 16 | done 17 | wait 18 | 19 | 20 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_iid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_iid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_noniid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_noniid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_partition.py: -------------------------------------------------------------------------------- 1 | from fedlab.utils.functional import save_dict 2 | from fedlab.utils.dataset.slicing import noniid_slicing, random_slicing 3 | 4 | import torchvision 5 | 6 | root = '../../../../datasets/data/cifar10/' 7 | trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=False) 8 | 9 | data_indices = noniid_slicing(trainset, num_clients=100, num_shards=200) 10 | save_dict(data_indices, "cifar10_noniid.pkl") 11 | 12 | data_indices = random_slicing(trainset, num_clients=100) 13 | save_dict(data_indices, "cifar10_iid.pkl") 14 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/client.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | 10 | torch.manual_seed(0) 11 | 12 | from fedlab.core.client.scale.trainer import SubsetSerialTrainer 13 | from fedlab.core.client.scale.manager import ScaleClientPassiveManager 14 | from fedlab.core.network import DistNetwork 15 | 16 | from fedlab.utils.serialization import SerializationTool 17 | from fedlab.utils.logger import Logger 18 | from fedlab.utils.aggregator import Aggregators 19 | from fedlab.utils.functional import load_dict 20 | 21 | import sys 22 | 23 | sys.path.append("../../../") 24 |
from models.cnn import AlexNet_CIFAR10, CNN_CIFAR10 25 | 26 | from config import cifar10_noniid_baseline_config, cifar10_iid_baseline_config 27 | 28 | if __name__ == "__main__": 29 |     parser = argparse.ArgumentParser(description="Distbelief training example") 30 | 31 |     parser.add_argument("--ip", type=str, default="127.0.0.1") 32 |     parser.add_argument("--port", type=str, default="3003") 33 |     parser.add_argument("--world_size", type=int) 34 |     parser.add_argument("--rank", type=int) 35 |     parser.add_argument("--ethernet", type=str, default=None) 36 | 37 |     parser.add_argument("--setting", type=str) 38 |     args = parser.parse_args() 39 | 40 |     if args.setting == 'iid': 41 |         config = cifar10_iid_baseline_config 42 |     else: 43 |         config = cifar10_noniid_baseline_config 44 | 45 |     os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" 46 | 47 |     transform_train = transforms.Compose([ 48 |         transforms.RandomCrop(32, padding=4), 49 |         transforms.RandomHorizontalFlip(), 50 |         transforms.ToTensor(), 51 |         transforms.Normalize((0.4914, 0.4822, 0.4465), 52 |                              (0.2023, 0.1994, 0.2010)) 53 |     ]) 54 | 55 |     trainset = torchvision.datasets.CIFAR10(root='../../../datasets/cifar10/', 56 |                                             train=True, 57 |                                             download=True, 58 |                                             transform=transform_train) 59 | 60 |     if config['partition'] == "noniid": 61 |         data_indices = load_dict("cifar10_noniid.pkl") 62 |     if config['partition'] == "iid": 63 |         data_indices = load_dict("cifar10_iid.pkl") 64 | 65 |     # Process rank x represents client ids from (x-1)*10 to (x-1)*10 + 10 66 |     # e.g. rank 5 <--> clients 40-49 67 |     client_id_list = [ 68 |         i for i in range((args.rank - 1) * 10, (args.rank - 1) * 10 + 10) 69 |     ] 70 | 71 |     # get corresponding data partition indices 72 |     sub_data_indices = { 73 |         idx: data_indices[cid] 74 |         for idx, cid in enumerate(client_id_list) 75 |     } 76 | 77 |     # model = CNN_CIFAR10() 78 |     model = AlexNet_CIFAR10() 79 | 80 |     aggregator = Aggregators.fedavg_aggregate 81 | 82 |     network = DistNetwork(address=(args.ip, args.port), 83 |                           world_size=args.world_size, 84 |                           rank=args.rank, 85 |                           ethernet=args.ethernet) 86 | 87 |     trainer = SubsetSerialTrainer(model=model, 88 |                                   dataset=trainset, 89 |                                   data_slices=sub_data_indices, 90 |                                   aggregator=aggregator, 91 |                                   args=config) 92 | 93 |     manager_ = ScaleClientPassiveManager(trainer=trainer, network=network) 94 | 95 |     manager_.run() -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/config.py: -------------------------------------------------------------------------------- 1 | cifar10_iid_baseline_config = { 2 |     "partition": "iid", 3 |     "round": 4000, 4 |     "network": "alexnet", 5 |     "sample_ratio": 0.1, 6 |     "dataset": "cifar10", 7 |     "total_client_num": 100, 8 |     "lr": 0.1, 9 |     "batch_size": 100, 10 |     "epochs": 5 11 | } 12 | 13 | cifar10_noniid_baseline_config = { 14 |     "partition": "noniid", 15 |     "round": 4000, 16 |     "network": "alexnet", 17 |     "sample_ratio": 0.1, 18 |     "dataset": "cifar10", 19 |     "total_client_num": 100, 20 |     "lr": 0.1, 21 |     "batch_size": 100, 22 |     "epochs": 5 23 | } -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/start_clt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # TODO: try to add auto client assignment script 3 | # usage: bash start_clt.sh [world_size] [start_rank] [end_rank] [iid/noniid] 4 | # where world_size = 1 + client_ranks_num 5 | # bash start_clt.sh 11 1 10 [iid/noniid] 6 | for ((i=$2; i<=$3; i++)) 7 | do 8 |
{ 9 | echo "client ${i} started" 10 | python client.py --world_size $1 --rank ${i} --setting $4& 11 | sleep 2s 12 | } 13 | done 14 | wait 15 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/femnist-cnn/server.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torchvision 8 | from torchvision import transforms 9 | 10 | torch.manual_seed(0) 11 | 12 | from fedlab.core.server.handler import SyncParameterServerHandler 13 | from fedlab.core.server.scale.manager import ScaleSynchronousManager 14 | from fedlab.core.network import DistNetwork 15 | from fedlab.utils.functional import AverageMeter 16 | 17 | sys.path.append("../../../") 18 | from models.cnn import CNN_FEMNIST 19 | 20 | # python server.py --world_size 11 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser(description='FL server example') 23 | 24 | parser.add_argument('--ip', type=str, default="127.0.0.1") 25 | parser.add_argument('--port', type=str, default="3002") 26 | parser.add_argument('--world_size', type=int) 27 | 28 | parser.add_argument('--round', type=int, default=1000) 29 | parser.add_argument('--ethernet', type=str, default=None) 30 | parser.add_argument('--sample', type=float, default=0.01) 31 | 32 | args = parser.parse_args() 33 | 34 | model = CNN_FEMNIST() 35 | 36 | handler = SyncParameterServerHandler(model, 37 | global_round=args.round, 38 | sample_ratio=args.sample, 39 | cuda=True) 40 | 41 | network = DistNetwork(address=(args.ip, args.port), 42 | world_size=args.world_size, 43 | rank=0) 44 | 45 | manager_ = ScaleSynchronousManager(network=network, handler=handler) 46 | manager_.run() 47 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/femnist-cnn/start_clt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for ((i=$2; i<=$3; i++)) 4 | do 5 | { 6 | echo "client ${i} started" 7 | python client.py --world_size $1 --rank ${i} & 8 | sleep 2s 9 | } 10 | done 11 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/client.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import sys 4 | import os 5 | 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | 9 | from fedlab.core.client.scale.trainer import SubsetSerialTrainer 10 | from fedlab.core.client.scale.manager import ScaleClientPassiveManager 11 | from fedlab.core.network import DistNetwork 12 | 13 | from fedlab.utils.logger import Logger 14 | from fedlab.utils.aggregator import Aggregators 15 | from fedlab.utils.functional import load_dict 16 | 17 | sys.path.append("../../../") 18 | from models.cnn import CNN_MNIST 19 | 20 | if __name__ == "__main__": 21 | 22 | parser = argparse.ArgumentParser(description="Distbelief training example") 
23 | 24 | parser.add_argument("--ip", type=str, default="127.0.0.1") 25 | parser.add_argument("--port", type=str, default="3002") 26 | parser.add_argument("--world_size", type=int) 27 | parser.add_argument("--rank", type=int) 28 | 29 | parser.add_argument("--partition", type=str, default="noniid") 30 | 31 | parser.add_argument("--gpu", type=str, default="0,1,2,3") 32 | parser.add_argument("--ethernet", type=str, default=None) 33 | 34 | args = parser.parse_args() 35 | 36 | if args.gpu != "-1": 37 | args.cuda = True 38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 39 | else: 40 | args.cuda = False 41 | 42 | trainset = torchvision.datasets.MNIST( 43 | root='../../../datasets/mnist/', 44 | train=True, 45 | download=True, 46 | transform=transforms.ToTensor()) 47 | 48 | if args.partition == "noniid": 49 | data_indices = load_dict("mnist_noniid.pkl") 50 | elif args.partition == "iid": 51 | data_indices = load_dict("mnist_iid.pkl") 52 | else: 53 | raise ValueError("invalid partition type ", args.partition) 54 | 55 | # Process rank x represent client id from (x-1)*10 - (x-1)*10 +10 56 | # e.g. rank 5 <--> client 40-50 57 | client_id_list = [ 58 | i for i in range((args.rank - 1) * 10, (args.rank - 1) * 10 + 10) 59 | ] 60 | 61 | # get corresponding data partition indices 62 | sub_data_indices = { 63 | idx: data_indices[cid] 64 | for idx, cid in enumerate(client_id_list) 65 | } 66 | 67 | model = CNN_MNIST() 68 | 69 | aggregator = Aggregators.fedavg_aggregate 70 | 71 | network = DistNetwork(address=(args.ip, args.port), 72 | world_size=args.world_size, 73 | rank=args.rank, 74 | ethernet=args.ethernet) 75 | 76 | trainer = SubsetSerialTrainer(model=model, 77 | dataset=trainset, 78 | data_slices=sub_data_indices, 79 | aggregator=aggregator, 80 | args={ 81 | "batch_size": 100, 82 | "lr": 0.02, 83 | "epochs": 5 84 | }) 85 | 86 | manager_ = ScaleClientPassiveManager(trainer=trainer, network=network) 87 | 88 | manager_.run() -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_iid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_iid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_noniid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_noniid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_partition.py: -------------------------------------------------------------------------------- 1 | from fedlab.utils.functional import save_dict 2 | from fedlab.utils.dataset.slicing import noniid_slicing, random_slicing 3 | 4 | import torchvision 5 | 6 | root = '../../../../datasets/data/mnist/' 7 | trainset = torchvision.datasets.MNIST(root=root, train=True, download=True) 8 | 9 | data_indices = noniid_slicing(trainset, num_clients=100, num_shards=200) 10 | save_dict(data_indices, "mnist_noniid.pkl") 11 | 12 | data_indices = random_slicing(trainset, num_clients=100) 13 | save_dict(data_indices, "mnist_iid.pkl") 14 | 
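As a reference for how the pickles written above are consumed, here is a plain-PyTorch sketch equivalent to what `client.py` above does with `load_dict` and `data_slices` (the client id and paths are illustrative):

```python
import torch
import torchvision
import torchvision.transforms as transforms

from fedlab.utils.functional import load_dict  # same helper used by client.py

# The index map written by mnist_partition.py: {client_id: [sample indices]}
data_indices = load_dict("mnist_noniid.pkl")

cid = 7  # illustrative client id in [0, 100)
trainset = torchvision.datasets.MNIST(root='../../../datasets/mnist/',
                                      train=True, download=True,
                                      transform=transforms.ToTensor())
# One client's local dataset is simply a Subset over its index list:
subset = torch.utils.data.Subset(trainset, data_indices[cid])
loader = torch.utils.data.DataLoader(subset, batch_size=100, shuffle=True)
```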
-------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/start_clt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for ((i=$2; i<=$3; i++)) 4 | do 5 | { 6 | echo "client ${i} started" 7 | python client.py --ip 127.0.0.1 --port 3002 --world_size $1 --rank ${i} & 8 | sleep 2s 9 | } 10 | done 11 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/shakespeare-rnn/server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torchvision import transforms 7 | 8 | torch.manual_seed(0) 9 | 10 | from fedlab.core.server.handler import SyncParameterServerHandler 11 | from fedlab.core.server.scale.manager import ScaleSynchronousManager 12 | from fedlab.core.network import DistNetwork 13 | from fedlab.utils.functional import evaluate 14 | 15 | import sys 16 | 17 | sys.path.append('../../../') 18 | from models.rnn import RNN_Shakespeare 19 | from leaf.dataloader import get_LEAF_all_test_dataloader 20 | 21 | # python server.py --world_size 11 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser(description='FL server example') 24 | 25 | parser.add_argument('--ip', type=str, default="127.0.0.1") 26 | parser.add_argument('--port', type=str, default="3002") 27 | parser.add_argument('--world_size', type=int) 28 | 29 | parser.add_argument('--round', type=int, default=2) 30 | parser.add_argument('--ethernet', type=str, default=None) 31 | parser.add_argument('--sample', type=float, default=0.05) 32 | 33 | args = parser.parse_args() 34 | 35 | model = RNN_Shakespeare() 36 | 37 | handler = SyncParameterServerHandler(model, 38 | global_round=args.round, 39 | sample_ratio=args.sample, 40 | cuda=True) 41 | 42 | network = DistNetwork(address=(args.ip, args.port), 43 | world_size=args.world_size, 44 | rank=0) 45 | 46 | manager_ = ScaleSynchronousManager(network=network, handler=handler) 47 | manager_.run() 48 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/scale/shakespeare-rnn/start_clt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for ((i=$2; i<=$3; i++)) 4 | do 5 | { 6 | echo "client ${i} started" 7 | python client.py --world_size $1 --rank ${i} & 8 | sleep 2s 9 | } 10 | done 11 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/standalone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/standalone/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/standalone/mnist_iid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/standalone/mnist_iid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.1.2/standalone/mnist_noniid.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/standalone/mnist_noniid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.2.0/client.py: -------------------------------------------------------------------------------- 1 | from logging import log 2 | import torch 3 | import argparse 4 | import sys 5 | import os 6 | 7 | from torch import nn 8 | from fedlab.core.client.manager import PassiveClientManager 9 | from fedlab.core.client.trainer import SGDClientTrainer 10 | from fedlab.core.network import DistNetwork 11 | from fedlab.utils.logger import Logger 12 | 13 | from setting import get_model, get_dataset 14 | 15 | if __name__ == "__main__": 16 | 17 |     parser = argparse.ArgumentParser(description="Distbelief training example") 18 | 19 |     parser.add_argument("--ip", type=str) 20 |     parser.add_argument("--port", type=str) 21 |     parser.add_argument("--world_size", type=int) 22 |     parser.add_argument("--rank", type=int) 23 | 24 |     parser.add_argument("--lr", type=float, default=0.01) 25 |     parser.add_argument("--epoch", type=int, default=5) 26 |     parser.add_argument("--dataset", type=str, default="mnist") 27 |     parser.add_argument("--batch_size", type=int, default=100) 28 | 29 |     parser.add_argument("--gpu", type=str, default="0,1,2,3") 30 |     parser.add_argument("--ethernet", type=str, default=None) 31 |     args = parser.parse_args() 32 | 33 |     if args.gpu != "-1": 34 |         args.cuda = True 35 |         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 36 |     else: 37 |         args.cuda = False 38 | 39 |     model = get_model(args) 40 |     trainloader, testloader = get_dataset(args) 41 |     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) 42 |     criterion = nn.CrossEntropyLoss() 43 | 44 |     network = DistNetwork( 45 |         address=(args.ip, args.port), 46 |         world_size=args.world_size, 47 |         rank=args.rank, 48 |         ethernet=args.ethernet, 49 |     ) 50 | 51 |     LOGGER = Logger(log_name="client " + str(args.rank)) 52 | 53 |     trainer = SGDClientTrainer( 54 |         model, 55 |         trainloader, 56 |         epochs=args.epoch, 57 |         optimizer=optimizer, 58 |         criterion=criterion, 59 |         cuda=args.cuda, 60 |         logger=LOGGER, 61 |     ) 62 | 63 |     manager_ = PassiveClientManager(trainer=trainer, 64 |                                     network=network, 65 |                                     logger=LOGGER) 66 |     manager_.run() 67 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.2.0/mnist_partition.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.2.0/mnist_partition.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.2.0/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 11 --dataset mnist & 4 | 5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 1 --dataset mnist & 6 | 7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 2 --dataset mnist & 8 | 9 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 3 --dataset mnist & 10 | 11 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 4 --dataset mnist & 12 | 13 | python client.py --ip 127.0.0.1 --port 3002
--world_size 11 --rank 5 --dataset mnist & 14 | 15 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 6 --dataset mnist & 16 | 17 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 7 --dataset mnist & 18 | 19 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 8 --dataset mnist & 20 | 21 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 9 --dataset mnist & 22 | 23 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 10 --dataset mnist & 24 | 25 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/fedavg_v1.2.0/setting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import sys 5 | 6 | from fedlab.utils.dataset.sampler import RawPartitionSampler 7 | 8 | sys.path.append('../') 9 | 10 | from models.cnn import CNN_CIFAR10, CNN_FEMNIST, CNN_MNIST 11 | from models.rnn import RNN_Shakespeare 12 | from models.mlp import MLP_CelebA 13 | from leaf.dataloader import get_LEAF_dataloader 14 | 15 | def get_dataset(args): 16 | if args.dataset == 'mnist': 17 | root = '../datasets/mnist/' 18 | train_transform = transforms.Compose([ 19 | transforms.ToTensor(), 20 | ]) 21 | test_transform = transforms.Compose([ 22 | transforms.ToTensor(), 23 | ]) 24 | trainset = torchvision.datasets.MNIST(root=root, 25 | train=True, 26 | download=True, 27 | transform=train_transform) 28 | 29 | testset = torchvision.datasets.MNIST(root=root, 30 | train=False, 31 | download=True, 32 | transform=test_transform) 33 | 34 | trainloader = torch.utils.data.DataLoader( 35 | trainset, 36 | sampler=RawPartitionSampler(trainset, 37 | client_id=args.rank, 38 | num_replicas=args.world_size - 1), 39 | batch_size=args.batch_size, 40 | drop_last=True, 41 | num_workers=args.world_size) 42 | 43 | testloader = torch.utils.data.DataLoader(testset, 44 | batch_size=int( 45 | len(testset) / 10), 46 | drop_last=False, 47 | shuffle=False) 48 | elif args.dataset == 'femnist': 49 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset, 50 | client_id=args.rank) 51 | elif args.dataset == 'shakespeare': 52 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset, 53 | client_id=args.rank) 54 | elif args.dataset == 'celeba': 55 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset, 56 | client_id=args.rank) 57 | else: 58 | raise ValueError("Invalid dataset:", args.dataset) 59 | 60 | return trainloader, testloader 61 | 62 | 63 | def get_model(args): 64 | if args.dataset == "mnist": 65 | model = CNN_MNIST() 66 | elif args.dataset == 'femnist': 67 | model = CNN_FEMNIST() 68 | elif args.dataset == 'shakespeare': 69 | model = RNN_Shakespeare() 70 | elif args.dataset == 'celeba': 71 | model = MLP_CelebA() 72 | else: 73 | raise ValueError("Invalid dataset:", args.dataset) 74 | 75 | return model -------------------------------------------------------------------------------- /fedlab_benchmarks/feddyn/Output/CIFAR10_100_iid_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/feddyn/Output/CIFAR10_100_iid_plots.png -------------------------------------------------------------------------------- /fedlab_benchmarks/feddyn/README.md: 
-------------------------------------------------------------------------------- 1 | ## FedDyn 2 | 3 | A demo of FedDyn implemented with FedLab in scale mode. 4 | 5 | ### Setting 6 | 7 | - `dataset`: CIFAR10 8 | - `partition`: iid 9 | - `balance`: `True` 10 | - `batch_size`: 50 11 | - `num_clients`: 100 12 | - `round`: 1000 13 | - `epochs`: 5 14 | - `lr`: 0.1 15 | - `alpha_coef`: 1e-2 16 | - `weight_decay`: 1e-3 17 | - `max_norm`: 10 18 | - `sample_ratio`: 1.0 19 | 20 | ### Requirements 21 | 22 | fedlab==1.1.2 23 | 24 | ### How to run? 25 | 26 | `start_server.sh` launches the server process, and `start_clt.sh` launches the client processes. 27 | 28 | 1. Run this command in terminal window 1 to launch the server: 29 | 30 | ```bash 31 | bash start_server.sh 32 | ``` 33 | 34 | 2. Run this command in terminal window 2 to launch the clients: 35 | 36 | ```bash 37 | bash start_clt.sh 38 | ``` 39 | 40 | > The random seed for the data partition over clients can be set using `--seed` in `start_server.sh`: 41 | > 42 | > ```bash 43 | > python data_partition.py --out-dir ./Output/FedDyn/run1 --partition iid --balance True --dataset cifar10 --num-clients ${ClientNum} --seed 1 44 | > ``` 45 | 46 | 47 | 48 | We highly recommend launching the clients after the server is up, to avoid conflicts. 49 | 50 | 51 | 52 | ### One-run Result 53 | 54 | | | FedDyn (Paper) | FedDyn (FedDyn code) | FedDyn (FedLab) | FedAvg (FedDyn code) | FedAvg (FedLab code) | 55 | | ------------------------ | :------------: | :------------------: | :-------------: | :------------------: | :------------------: | 56 | | Round for $acc>81.40\%$ | 67 | 64 | 65 | 491 | 423 | 57 | | Round for $acc>85.00\%$ | 198 | 185 | 195 | > 1000 | > 1000 | 58 | 59 | 60 | 61 | ### Duration 62 | 63 | | | FedAvg | FedDyn | 64 | | -------------- | :------------: | :------------: | 65 | | FedDyn code | 474.98 Min | 537.13 Min | 66 | | FedLab (scale) | __143.60 Min__ | __253.17 Min__ | 67 | 68 | ### Environment 69 | 70 | - CPU: 128 GB memory, 32 cores, Intel(R) Core(TM) i9-9960X CPU @ 3.10GHz 71 | - 4 * NVIDIA GeForce RTX 2080 Ti 72 | 73 | 74 | 75 | ### Reference 76 | 77 | - Acar, D. A. E., Zhao, Y., Matas, R., Mattina, M., Whatmough, P., & Saligrama, V. (2020, September). Federated learning based on dynamic regularization. In *International Conference on Learning Representations*.
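To make the `alpha_coef` setting above concrete, here is a minimal sketch of the dynamically regularized local objective that each FedDyn client minimizes (a reading of the Acar et al. paper, not this repo's actual trainer code; the function and argument names are illustrative, with `local_grad_vector` playing the role of the accumulated gradient that `config.py` stores via `local_grad_vector_file_pattern`):

```python
import torch

def feddyn_local_loss(task_loss, model, server_params, local_grad_vector, alpha_coef):
    """Sketch of the FedDyn client objective; sign conventions vary across writeups.

    task_loss: ordinary loss (e.g. cross-entropy) on the local batch.
    server_params: flattened parameters of the current global model.
    local_grad_vector: flattened accumulated local gradient from earlier rounds.
    """
    theta = torch.cat([p.reshape(-1) for p in model.parameters()])
    # Linear term driven by the accumulated local gradient (the "dynamic" part).
    lin_penalty = torch.sum(theta * local_grad_vector)
    # Quadratic term pulling the local model toward the current server model.
    quad_penalty = (alpha_coef / 2.0) * torch.sum((theta - server_params) ** 2)
    return task_loss - lin_penalty + quad_penalty
```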
78 | 79 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feddyn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 9/27/21 12:12 AM 3 | # @Author : Siqi Liang 4 | # @Contact : zszxlsq@gmail.com 5 | # @File : __init__.py 6 | # @Software: PyCharm 7 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feddyn/config.py: -------------------------------------------------------------------------------- 1 | cifar10_config = { 2 | 'num_clients': 100, 3 | 'model_name': 'Cifar10Net', # Model type 4 | 'round': 1000, 5 | 'save_period': 200, 6 | 'weight_decay': 1e-3, 7 | 'batch_size': 50, 8 | 'test_batch_size': 256, # this param does not exist in the official code 9 | 'lr_decay_per_round': 1, 10 | 'epochs': 5, 11 | 'lr': 0.1, 12 | 'print_freq': 5, 13 | 'alpha_coef': 1e-2, 14 | 'max_norm': 10, 15 | 'sample_ratio': 1, 16 | 'partition': 'iid', 17 | 'dataset': 'cifar10', 18 | } 19 | 20 | 21 | debug_config = { 22 | 'num_clients': 30, 23 | 'model_name': 'Cifar10Net', # Model type 24 | 'round': 5, 25 | 'save_period': 2, 26 | 'weight_decay': 1e-3, 27 | 'batch_size': 50, 28 | 'test_batch_size': 50, 29 | 'act_prob': 1, 30 | 'lr_decay_per_round': 1, 31 | 'epochs': 5, 32 | 'lr': 0.1, 33 | 'print_freq': 1, 34 | 'alpha_coef': 1e-2, 35 | 'max_norm': 10, 36 | 'sample_ratio': 1, 37 | 'partition': 'iid', 38 | 'dataset': 'cifar10' 39 | } 40 | 41 | # usage: local_params_file_pattern.format(cid=cid) 42 | local_grad_vector_file_pattern = "client_{cid:03d}_local_grad_vector.pt" # accumulated model gradient 43 | clnt_params_file_pattern = "client_{cid:03d}_clnt_params.pt" # latest model param 44 | 45 | local_grad_vector_list_file_pattern = "client_rank_{rank:02d}_local_grad_vector_list.pt" # accumulated model gradient for clients in one client process 46 | clnt_params_list_file_pattern = "client_rank_{rank:02d}_clnt_params_list.pt" # latest model param for clients in one client process 47 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feddyn/data_partition.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | from torchvision.datasets import CIFAR10, CIFAR100 6 | 7 | import sys 8 | sys.path.append("../../../FedLab/") 9 | 10 | from fedlab.utils.dataset import CIFAR10Partitioner, CIFAR100Partitioner 11 | from fedlab.utils.functional import partition_report, save_dict, load_dict 12 | 13 | 14 | def get_exp_name(args): 15 | exp_name = "" 16 | args_dict = vars(args) 17 | exclude_keys = ["out_dir", "data_dir"] 18 | 19 | for key in sorted(args_dict.keys()): 20 | exp_name += f"{key}_" 21 | if key not in exclude_keys: 22 | value = args_dict[key] 23 | if isinstance(value, float): 24 | exp_name += f"{value:.3f}_" 25 | else: 26 | exp_name += f"{value}_" 27 | 28 | return exp_name[:-1] 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser(description="FedDyn implementation: Client scale mode") 33 | 34 | parser.add_argument("--data-dir", type=str, default="../../../datasets/") 35 | parser.add_argument("--out-dir", type=str, default="./Output/") 36 | parser.add_argument("--dataset", type=str, default="cifar10", 37 | help="Currently only 'cifar10' is supported; 'cifar100' is planned") 38 | parser.add_argument("--num-clients", type=int, default=100) 39 | parser.add_argument("--partition", type=str,
default="iid") 40 | parser.add_argument("--balance", type=bool, default=None) 41 | parser.add_argument("--unbalance-sgm", type=float, default=0) 42 | parser.add_argument("--dir-alpha", type=float, default=None) 43 | parser.add_argument("--num-shards", type=int, default=None) 44 | parser.add_argument("--seed", type=int, default=0) 45 | 46 | args = parser.parse_args() 47 | 48 | Path(args.data_dir).mkdir(parents=True, exist_ok=True) 49 | Path(args.out_dir).mkdir(parents=True, exist_ok=True) 50 | 51 | if args.dataset == "cifar10": 52 | trainset = CIFAR10(root=os.path.join(args.data_dir, 'CIFAR10'), train=True, download=True) 53 | partitioner = CIFAR10Partitioner 54 | # elif args.dataset == "cifar100": 55 | # trainset = CIFAR100(root=os.path.join(args.data_dir, 'CIFAR100'), train=True, download=True) 56 | # partitioner = CIFAR100Partitoner 57 | else: 58 | raise ValueError(f"{args.dataset} is not supported yet.") 59 | 60 | partition = partitioner(targets=trainset.targets, 61 | num_clients=args.num_clients, 62 | balance=args.balance, 63 | partition=args.partition, 64 | unbalance_sgm=args.unbalance_sgm, 65 | num_shards=args.num_shards, 66 | dir_alpha=args.dir_alpha, 67 | seed=args.seed, 68 | verbose=True) 69 | file_name = f"{args.dataset}_{args.partition}.pkl" # get_exp_name(args) + ".pkl" 70 | save_dict(partition.client_dict, 71 | path=os.path.join(args.out_dir, file_name)) 72 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feddyn/start_clt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ========================================== 3 | # =============== EXPERIMENT =============== 4 | # ========================================== 5 | ClientRankNum=10 6 | ClientNumPerRank=10 7 | ClientNum=$(($ClientNumPerRank * $ClientRankNum)) 8 | WorldSize=$(($ClientRankNum + 1)) 9 | 10 | # ------ FedAvg 11 | # for ((i = 1; i <= ${ClientRankNum}; i++)); do 12 | # { 13 | # echo "client ${i} started" 14 | # python client_starter.py --world_size ${WorldSize} --rank ${i} --client-num-per-rank ${ClientNumPerRank} --alg FedAvg --out-dir ./Output/FedAvg/run1 & 15 | # sleep 2s 16 | # } 17 | # done 18 | 19 | # wait 20 | 21 | # ------ FedDyn 22 | for ((i = 1; i <= ${ClientRankNum}; i++)); do 23 | { 24 | echo "client ${i} started" 25 | python client_starter.py --world_size ${WorldSize} --rank ${i} --client-num-per-rank ${ClientNumPerRank} --alg FedDyn --out-dir ./Output/FedDyn/run1 & 26 | sleep 2s 27 | } 28 | done 29 | 30 | wait 31 | 32 | 33 | -------------------------------------------------------------------------------- /fedlab_benchmarks/feddyn/start_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ========================================== 3 | # =============== EXPERIMENT =============== 4 | # ========================================== 5 | ClientRankNum=10 6 | ClientNumPerRank=10 7 | ClientNum=$(($ClientNumPerRank * $ClientRankNum)) 8 | WorldSize=$(($ClientRankNum + 1)) 9 | # balance iid cifar10 for 100 clients, check config.py for other setting 10 | python data_partition.py --out-dir ./Output/FedDyn/run1 --partition iid --balance True --dataset cifar10 --num-clients ${ClientNum} --seed 1 11 | echo -e "Data partition DONE.\n\n" 12 | sleep 4s 13 | 14 | # # ----- FedAvg 15 | # SECONDS=0 16 | 17 | # python server_starter.py --world_size ${WorldSize} --partition iid --alg FedAvg --out-dir ./Output/FedAvg/run1 18 | 19 | # ELAPSED="Elapsed: $(($SECONDS 
/ 3600))hrs $((($SECONDS / 60) % 60))min $(($SECONDS % 60))sec" 20 | # echo $ELAPSED 21 | 22 | # ------- FedDyn 23 | SECONDS=0 24 | 25 | python server_starter.py --world_size ${WorldSize} --partition iid --alg FedDyn --out-dir ./Output/FedDyn/run1 26 | 27 | ELAPSED="Elapsed: $(($SECONDS / 60))min $(($SECONDS % 60))sec" 28 | echo $ELAPSED 29 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedmgda+/README.md: -------------------------------------------------------------------------------- 1 | # FedMGDA+: Federated Learning meets Multi-objective Optimization 2 | 3 | A reproduction of the paper ["Fedmgda+: Federated learning meets multi-objective optimization"](https://arxiv.org/abs/2006.11489) via FedLab. 4 | 5 | Thanks to the GitHub repo https://github.com/WwZzz/easyFL. 6 | 7 | ## Requirements 8 | 9 | fedlab==1.2.0 10 | 11 | ## Run 12 | 13 | $ bash run.sh 14 | 15 | ## Performance 16 | 17 | Null. 18 | 19 | ## References 20 | 21 | Hu, Zeou, et al. "Fedmgda+: Federated learning meets multi-objective optimization." arXiv preprint arXiv:2006.11489 (2020). 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedmgda+/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedmgda+/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/fedmgda+/mnist_iid_100.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedmgda+/mnist_iid_100.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedmgda+/mnist_noniid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedmgda+/mnist_noniid.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedmgda+/mnist_noniid_200_100.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedmgda+/mnist_noniid_200_100.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/fedmgda+/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 6 --dataset mnist --round 100 & 4 | 5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 1 --dataset mnist & 6 | 7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 2 --dataset mnist & 8 | 9 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 3 --dataset mnist & 10 | 11 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 4 --dataset mnist & 12 | 13 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 5 --dataset mnist & 14 | 15 | 16 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/fedmgda+/setting.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import sys 5 | 6 | from fedlab.utils.dataset.sampler import RawPartitionSampler 7 | 8 | sys.path.append('../') 9 | 10 | from models.cnn import CNN_CIFAR10, CNN_FEMNIST, CNN_MNIST 11 | from models.rnn import RNN_Shakespeare 12 | from models.mlp import MLP_CelebA 13 | from leaf.dataloader import get_LEAF_dataloader 14 | 15 | def get_dataloader(args): 16 | if args.dataset == 'mnist': 17 | root = '../datasets/mnist/' 18 | trainset = torchvision.datasets.MNIST(root=root, 19 | train=True, 20 | download=True, 21 | transform=transforms.ToTensor()) 22 | 23 | testset = torchvision.datasets.MNIST(root=root, 24 | train=False, 25 | download=True, 26 | transform=transforms.ToTensor()) 27 | 28 | trainloader = torch.utils.data.DataLoader( 29 | trainset, 30 | sampler=RawPartitionSampler(trainset, 31 | client_id=args.rank, 32 | num_replicas=args.world_size - 1), 33 | batch_size=args.batch_size, 34 | drop_last=True, 35 | num_workers=args.world_size) 36 | 37 | testloader = torch.utils.data.DataLoader(testset, 38 | batch_size=int( 39 | len(testset) / 10), 40 | drop_last=False, 41 | shuffle=False) 42 | elif args.dataset == 'femnist': 43 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset, 44 | client_id=args.rank) 45 | elif args.dataset == 'shakespeare': 46 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset, 47 | client_id=args.rank) 48 | elif args.dataset == 'celeba': 49 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset, 50 | client_id=args.rank) 51 | else: 52 | raise ValueError("Invalid dataset:", args.dataset) 53 | 54 | return trainloader, testloader 55 | 56 | 57 | def get_model(args): 58 | if args.dataset == "mnist": 59 | model = CNN_MNIST() 60 | elif args.dataset == 'femnist': 61 | model = CNN_FEMNIST() 62 | elif args.dataset == 'shakespeare': 63 | model = RNN_Shakespeare() 64 | elif args.dataset == 'celeba': 65 | model = MLP_CelebA() 66 | else: 67 | raise ValueError("Invalid dataset:", args.dataset) 68 | 69 | return model -------------------------------------------------------------------------------- /fedlab_benchmarks/fedmgda+/start_clients.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for ((i=$2; i<=$3; i++)) 4 | do 5 | { 6 | echo "client ${i} started" 7 | python client.py --ip 127.0.0.1 --port 3002 --world_size $1 --rank ${i} --scale True & 8 | sleep 1s 9 | } 10 | done 11 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/fedprox/README.md: -------------------------------------------------------------------------------- 1 | # FedProx 2 | 3 | [Federated optimization in heterogeneous networks](https://proceedings.mlsys.org/papers/2020/176) 4 | 5 | ## Requirements 6 | 7 | fedlab==1.1.2 8 | 9 | ## Run 10 | 11 | You are welcome to read the source code. It is similar to the FedAvg demos. 12 | 13 | ## Performance 14 | 15 | Null 16 | 17 | ## References 18 | 19 | Li, Tian, et al. "Federated optimization in heterogeneous networks." Proceedings of Machine Learning and Systems 2 (2020): 429-450.
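For readers heading into the source code: the core difference from FedAvg is a proximal term added to each client's local loss. Below is a minimal sketch of that objective; the repo's actual implementation is `FedProxTrainer.train()` in `fedprox_trainer.py`, and `global_model` here stands for a frozen copy of the server model (the `frz_model` in the trainer):

```python
import torch

def fedprox_loss(task_loss, model, global_model, mu):
    """FedProx local objective: task loss plus (mu / 2) * ||w - w_global||^2."""
    prox_term = 0.0
    for w, w0 in zip(model.parameters(), global_model.parameters()):
        prox_term += torch.sum((w - w0) ** 2)  # squared L2 distance to global weights
    return task_loss + 0.5 * mu * prox_term
```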
20 | 21 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedprox/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedprox/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/fedprox/cross_process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedprox/cross_process/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/fedprox/cross_process/client.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | import os 5 | 6 | from setting import get_model, get_dataset 7 | from torch import nn, optim 8 | from fedlab.core.client.manager import ClientPassiveManager 9 | from fedlab.core.network import DistNetwork 10 | from fedlab.utils.logger import Logger 11 | 12 | sys.path.append("../") 13 | from fedprox_trainer import FedProxTrainer 14 | 15 | if __name__ == "__main__": 16 | 17 | parser = argparse.ArgumentParser(description="FedProx client example") 18 | 19 | parser.add_argument("--ip", type=str) 20 | 21 | parser.add_argument("--port", type=str) 22 | 23 | parser.add_argument("--world_size", type=int) 24 | 25 | parser.add_argument("--rank", type=int) 26 | 27 | parser.add_argument("--lr", type=float, default=0.01) 28 | 29 | parser.add_argument("--epoch", type=int, default=5) 30 | 31 | parser.add_argument("--dataset", type=str) 32 | 33 | parser.add_argument("--batch_size", type=int, default=100) 34 | 35 | parser.add_argument("--gpu", type=str, default="0,1,2,3") 36 | 37 | parser.add_argument("--ethernet", type=str, default=None) 38 | 39 | parser.add_argument( 40 | "--straggler", type=float, default=0.0 41 | ) # valid value should be in the range [0, 1] and a multiple of 0.1 42 | 43 | parser.add_argument( 44 | "--optimizer", type=str, default="sgd" 45 | ) # valid value: {"sgd", "adam", "rmsprop"} 46 | 47 | parser.add_argument( 48 | "--mu", type=float, default=0.0 49 | ) # recommended value: {0.001, 0.01, 0.1, 1.0} 50 | 51 | args = parser.parse_args() 52 | 53 | if args.gpu != "-1": 54 | args.cuda = True 55 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 56 | device = torch.device("cuda") # CUDA_VISIBLE_DEVICES above restricts visible GPUs; torch.device cannot parse "0,1,2,3" 57 | else: 58 | args.cuda = False 59 | device = torch.device("cpu") 60 | 61 | model = get_model(args).to(device) 62 | trainloader, testloader = get_dataset(args) 63 | optimizer_dict = dict(sgd=optim.SGD, adam=optim.Adam, rmsprop=optim.RMSprop) 64 | optimizer = optimizer_dict[args.optimizer](model.parameters(), lr=args.lr) 65 | criterion = nn.CrossEntropyLoss() 66 | 67 | network = DistNetwork( 68 | address=(args.ip, args.port), 69 | world_size=args.world_size, 70 | rank=args.rank, 71 | ethernet=args.ethernet, 72 | ) 73 | 74 | LOGGER = Logger(log_name="client " + str(args.rank)) 75 | 76 | trainer = FedProxTrainer( 77 | model=model, 78 | data_loader=trainloader, 79 | epochs=args.epoch, 80 | optimizer=optimizer, 81 | criterion=criterion, 82 | mu=args.mu, 83 | cuda=args.cuda, 84 | logger=LOGGER, 85 | ) 86 | 87 | manager_ = ClientPassiveManager(trainer=trainer, network=network, logger=LOGGER) 88 | manager_.run() 89 |
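The `--straggler` flag above is only declared in this script. As one illustration of how such a fraction is typically used in FedProx-style experiments (an assumption on my part, not code from this repo), stragglers can be simulated by letting a random subset of clients run fewer local epochs while the proximal term keeps their updates close to the global model:

```python
import random

def local_epochs_for_client(base_epochs: int, straggler_frac: float) -> int:
    # Hypothetical helper, not part of this repo: with probability straggler_frac,
    # a client becomes a straggler and performs a reduced number of local epochs.
    if random.random() < straggler_frac:
        return random.randint(1, base_epochs)
    return base_epochs
```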
-------------------------------------------------------------------------------- /fedlab_benchmarks/fedprox/cross_process/server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from fedlab.utils.logger import Logger 4 | from fedlab.core.server.handler import SyncParameterServerHandler 5 | from fedlab.core.server.manager import ServerSynchronousManager 6 | from fedlab.core.network import DistNetwork 7 | from setting import get_model, get_dataset 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser(description="FL server example") 11 | 12 | parser.add_argument("--ip", type=str) 13 | 14 | parser.add_argument("--port", type=str) 15 | 16 | parser.add_argument("--world_size", type=int) 17 | 18 | parser.add_argument("--round", type=int, default=5) 19 | 20 | parser.add_argument("--dataset", type=str) 21 | 22 | parser.add_argument("--ethernet", type=str, default=None) 23 | 24 | parser.add_argument("--sample", type=float, default=1) 25 | 26 | args = parser.parse_args() 27 | 28 | model = get_model(args) 29 | LOGGER = Logger(log_name="server") 30 | handler = SyncParameterServerHandler(model, 31 | global_round=args.round, 32 | logger=LOGGER, 33 | sample_ratio=args.sample) 34 | network = DistNetwork( 35 | address=(args.ip, args.port), 36 | world_size=args.world_size, 37 | rank=0, 38 | ethernet=args.ethernet, 39 | ) 40 | manager_ = ServerSynchronousManager(handler=handler, 41 | network=network, 42 | logger=LOGGER) 43 | manager_.run() 44 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedprox/cross_process/setting.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../../") 3 | 4 | import argparse 5 | from models.cnn import CNN_CIFAR10, CNN_FEMNIST, CNN_MNIST 6 | from models.rnn import RNN_Shakespeare 7 | from models.mlp import MLP_CelebA 8 | from leaf.dataloader import get_LEAF_dataloader 9 | from torchvision import transforms, datasets 10 | from torch.utils.data import DataLoader 11 | from fedlab.utils.dataset.sampler import RawPartitionSampler 12 | 13 | 14 | def get_model(args): 15 | if args.dataset == "mnist": 16 | model = CNN_MNIST() 17 | elif args.dataset == "femnist": 18 | model = CNN_FEMNIST() 19 | elif args.dataset == "shakespeare": 20 | model = RNN_Shakespeare() 21 | elif args.dataset == "celeba": 22 | model = MLP_CelebA() 23 | else: 24 | raise ValueError("Invalid dataset:", args.dataset) 25 | 26 | return model 27 | 28 | 29 | def get_dataset(args): 30 | if args.dataset == "mnist": 31 | root = "../../datasets/mnist/" 32 | train_transform = transforms.Compose([transforms.ToTensor(),]) 33 | test_transform = transforms.Compose([transforms.ToTensor(),]) 34 | trainset = datasets.MNIST( 35 | root=root, train=True, download=True, transform=train_transform 36 | ) 37 | 38 | testset = datasets.MNIST( 39 | root=root, train=False, download=True, transform=test_transform 40 | ) 41 | 42 | trainloader = DataLoader( 43 | trainset, 44 | sampler=RawPartitionSampler( 45 | trainset, client_id=args.rank, num_replicas=args.world_size - 1 46 | ), 47 | batch_size=args.batch_size, 48 | drop_last=True, 49 | num_workers=args.world_size, 50 | ) 51 | 52 | testloader = DataLoader( 53 | testset, batch_size=int(len(testset) / 10), drop_last=False, shuffle=False 54 | ) 55 | elif args.dataset == "femnist": 56 | trainloader, testloader = get_LEAF_dataloader( 57 | dataset=args.dataset, client_id=args.rank 58 | ) 59 | elif 
args.dataset == "shakespeare": 60 | trainloader, testloader = get_LEAF_dataloader( 61 | dataset=args.dataset, client_id=args.rank 62 | ) 63 | elif args.dataset == "celeba": 64 | trainloader, testloader = get_LEAF_dataloader( 65 | dataset=args.dataset, client_id=args.rank 66 | ) 67 | else: 68 | raise ValueError("Invalid dataset:", args.dataset) 69 | 70 | return trainloader, testloader 71 | 72 | -------------------------------------------------------------------------------- /fedlab_benchmarks/fedprox/fedprox_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | from fedlab.core.client import ClientSGDTrainer 5 | from fedlab.utils.serialization import SerializationTool 6 | from fedlab.utils import Logger 7 | from tqdm import tqdm 8 | 9 | 10 | class FedProxTrainer(ClientSGDTrainer): 11 | """FedProxTrainer. 12 | 13 | Details of FedProx are available in paper: https://arxiv.org/abs/1812.06127 14 | 15 | Args: 16 | model (torch.nn.Module): PyTorch model. 17 | data_loader (torch.utils.data.DataLoader): :class:`torch.utils.data.DataLoader` for this client. 18 | epochs (int): the number of local epochs. 19 | optimizer (torch.optim.Optimizer, optional): optimizer for this client's model. 20 | criterion (torch.nn.Loss, optional): loss function used in local training process. 21 | cuda (bool, optional): use GPUs or not. Default: ``True``. 22 | logger (Logger, optional): object of :class:`Logger`. 23 | mu (float): hyper-parameter of FedProx. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | model, 29 | data_loader, 30 | epochs, 31 | optimizer, 32 | criterion, 33 | mu, 34 | cuda=True, 35 | logger=Logger(), 36 | ): 37 | super().__init__( 38 | model, data_loader, epochs, optimizer, criterion, cuda=cuda, logger=logger 39 | ) 40 | 41 | self.mu = mu 42 | 43 | def train(self, model_parameters): 44 | """Client trains its local model on local dataset. 45 | 46 | Args: 47 | model_parameters (torch.Tensor): Serialized model parameters. 48 | """ 49 | frz_model = deepcopy(self._model) 50 | SerializationTool.deserialize_model(frz_model, model_parameters) 51 | SerializationTool.deserialize_model( 52 | self._model, model_parameters 53 | ) # load parameters 54 | self._LOGGER.info("Local train procedure is running") 55 | for ep in range(self.epochs): 56 | self._model.train() 57 | for inputs, labels in tqdm( 58 | self._data_loader, desc="{}, Epoch {}".format(self._LOGGER.name, ep) 59 | ): 60 | if self.cuda: 61 | inputs, labels = inputs.cuda(self.gpu), labels.cuda(self.gpu) 62 | 63 | outputs = self._model(inputs) 64 | l1 = self.criterion(outputs, labels) 65 | l2 = 0.0 66 | 67 | for w0, w in zip(frz_model.parameters(), self._model.parameters()): 68 | l2 += torch.sum(torch.pow(w - w0, 2)) 69 | 70 | loss = l1 + 0.5 * self.mu * l2 71 | 72 | self.optimizer.zero_grad() 73 | loss.backward() 74 | self.optimizer.step() 75 | self._LOGGER.info("Local train procedure is finished") 76 | 77 | return self.model_parameters 78 | -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/README_tmp.md: -------------------------------------------------------------------------------- 1 | # PROCESS_DATA README 2 | 3 | This folder contains processed dataset pickle files for LEAF datasets. 4 | 5 | You can run the `create_pickle_dataset.py` script to generate processed LEAF data. 6 | 7 | Notice: 8 | 1. Please make sure the LEAF dataset is downloaded and processed by LEAF. (LEAF code in `fedlab_benchmarks/datasets/data`) 9 | 2.
Please make sure the `fedlab_benchmarks/datasets/data/{dataset_name}/{train,test}` paths exist for train data and test data. 10 | 3. example script: 11 | `python create_pickle_dataset.py --data_root "../../datasets/data" --save_root "pickle_dataset" --dataset_name "shakespeare"` 12 | 4. usage example: 13 | ``` 14 | pdataset = PickleDataset(pickle_root="pickle_datasets", dataset_name="shakespeare") 15 | pdataset.create_pickle_dataset(data_root="../datasets") 16 | dataset = pdataset.get_dataset_pickle(dataset_type="test", client_id="2") 17 | ``` 18 | -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/leaf/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # from .femnist_dataset import FemnistDataset 3 | # from .shakespeare_dataset import ShakespeareDataset 4 | # from .celeba_dataset import CelebADataset 5 | # from sent140_dataset import Sent140Dataset -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/dataset/celeba_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import torch 16 | from PIL import Image 17 | from pathlib import Path 18 | from torch.utils.data import Dataset 19 | 20 | 21 | class CelebADataset(Dataset): 22 | def __init__(self, client_id: int, client_str: str, data: list, targets: list, image_root: str, transform=None): 23 | """get `Dataset` for CelebA dataset 24 | 25 | Args: 26 | client_id (int): client id 27 | client_str (str): client name string 28 | data (list): input image name list data 29 | targets (list): output label list 30 | """ 31 | self.client_id = client_id 32 | self.client_str = client_str 33 | self.image_root = Path(__file__).parent.resolve() / image_root 34 | self.transform = transform 35 | self.data = data 36 | self.targets = targets 37 | self._process_data_target() 38 | 39 | def _process_data_target(self): 40 | """process client's data and target 41 | """ 42 | data = [] 43 | targets = [] 44 | for idx in range(len(self.data)): 45 | image_path = self.image_root / self.data[idx] 46 | image = Image.open(image_path).convert('RGB') 47 | data.append(image) 48 | targets.append(torch.tensor(self.targets[idx], dtype=torch.long)) 49 | self.data = data 50 | self.targets = targets 51 | 52 | def __len__(self): 53 | return len(self.targets) 54 | 55 | def __getitem__(self, index): 56 | data = self.data[index] 57 | if self.transform: 58 | data = self.transform(data) 59 | target = self.targets[index] 60 | return data, target 61 | -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/dataset/femnist_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import torch 17 | from torch.utils.data import Dataset 18 | 19 | 20 | class FemnistDataset(Dataset): 21 | 22 | def __init__(self, client_id: int, client_str: str, data: list, targets: list): 23 | """get `Dataset` for femnist dataset 24 | 25 | Args: 26 | client_id (int): client id 27 | client_str (str): client name string 28 | data (list): image data list 29 | targets (list): image class target list 30 | """ 31 | self.client_id = client_id 32 | self.client_str = client_str 33 | self.data = data 34 | self.targets = targets 35 | self._process_data_target() 36 | 37 | def _process_data_target(self): 38 | """process client's data and target 39 | 40 | """ 41 | self.data = torch.tensor(self.data, dtype=torch.float32).reshape(-1, 1, 28, 28) 42 | self.targets = torch.tensor(self.targets, dtype=torch.long) 43 | 44 | def __len__(self): 45 | return len(self.targets) 46 | 47 | def __getitem__(self, index): 48 | return self.data[index], self.targets[index] 49 | -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/dataset/reddit_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/dataset/shakespeare_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import torch 17 | from torch.utils.data import Dataset 18 | 19 | 20 | class ShakespeareDataset(Dataset): 21 | 22 | def __init__(self, client_id: int, client_str: str, data: list, targets: list): 23 | """get `Dataset` for shakespeare dataset 24 | 25 | Args: 26 | client_id (int): client id 27 | client_str (str): client name string 28 | data (list): sentence list data 29 | targets (list): next-character target list 30 | """ 31 | self.client_id = client_id 32 | self.client_str = client_str 33 | self.ALL_LETTERS, self.VOCAB_SIZE = self._build_vocab() 34 | self.data = data 35 | self.targets = targets 36 | self._process_data_target() 37 | 38 | def _build_vocab(self): 39 | """build vocab from all letters 40 | 41 | Vocabulary re-used from the Federated Learning for Text Generation tutorial. 42 | https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation 43 | 44 | Returns: 45 | the vocabulary string of all letters and its length 46 | """ 47 | ALL_LETTERS = "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz}" 48 | VOCAB_SIZE = len(ALL_LETTERS) 49 | return ALL_LETTERS, VOCAB_SIZE 50 | 51 | def _process_data_target(self): 52 | """process client's data and target 53 | """ 54 | self.data = torch.tensor( 55 | [self.__sentence_to_indices(sentence) for sentence in self.data]) 56 | self.targets = torch.tensor( 57 | [self.__letter_to_index(letter) for letter in self.targets]) 58 | 59 | def __sentence_to_indices(self, sentence: str): 60 | """Returns a list of integers for character indices in ALL_LETTERS 61 | 62 | Args: 63 | sentence (str): input sentence 64 | 65 | Returns: an integer list of character indices 66 | """ 67 | indices = [] 68 | for c in sentence: 69 | indices.append(self.ALL_LETTERS.find(c)) 70 | return indices 71 | 72 | def __letter_to_index(self, letter: str): 73 | """Returns the index in ALL_LETTERS of the given letter 74 | 75 | Args: 76 | letter (char/str[0]): input letter 77 | 78 | Returns: the int index of the input letter 79 | """ 80 | index = self.ALL_LETTERS.find(letter) 81 | return index 82 | 83 | def __len__(self): 84 | return len(self.targets) 85 | 86 | def __getitem__(self, index): 87 | return self.data[index], self.targets[index] 88 | -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/gen_pickle_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Example: bash gen_pickle_dataset.sh "shakespeare" "../datasets" "./pickle_datasets" 4 | # Example: bash gen_pickle_dataset.sh "sent140" "../datasets" "./pickle_datasets" 1 5 | dataset=$1 6 | data_root=${2:-'../datasets'} 7 | pickle_root=${3:-'./pickle_datasets'} 8 | # for nlp datasets 9 | build_vocab=${4:-'0'} 10 | vocab_save_root=${5:-'./nlp_utils/dataset_vocab'} 11 | vector_save_root=${6:-'./nlp_utils/glove'} 12 | vocab_limit_size=${7:-'50000'} 13 | 14 | 15 | python pickle_dataset.py \ 16 | --dataset ${dataset} \ 17 | --data_root ${data_root} \ 18 | --pickle_root ${pickle_root} \ 19 | --build_vocab ${build_vocab} \ 20 | --vocab_save_root ${vocab_save_root} \ 21 | --vector_save_root ${vector_save_root} \ 22 | --vocab_limit_size ${vocab_limit_size} 23 | -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/nlp_utils/README.md: -------------------------------------------------------------------------------- 1 | # NLP UTILS 2 | 3 | This folder contains some lightweight utils for NLP processing.
4 | 5 | - get_glove.sh: GloVe download script; fetches http://nlp.stanford.edu/data/glove.6B.zip to get glove.6B.300d.txt. 6 | 7 | - build_vocab.sh: provides a way to sample some clients' train data for building a vocabulary for federated NLP tasks, 8 | which is a simple alternative to using all the data directly, as in other implementations. 9 | 10 | - sample_build_vocab.py: provides a way to sample some clients' train data for building a vocabulary for federated NLP tasks. 11 | 12 | - tokenizer.py: provides `class Tokenizer`, splitting an entire text into smaller units called tokens, such as individual words or terms. 13 | 14 | - vocab.py: provides `class Vocab`, which encapsulates vocabulary operations in NLP, 15 | such as getting word2idx for tokenized input data and getting the vector list from a pretrained word_vec_file, such as glove.6B.300d.txt. -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/nlp_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/leaf/nlp_utils/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/leaf/nlp_utils/download_glove.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This is modified from [LEAF/models/sent140/get_embs.sh] 3 | # https://github.com/TalwalkarLab/leaf/blob/master/models/sent140/get_embs.sh 4 | 5 | if [ ! -f 'glove.6B.300d.txt' ]; then 6 | wget http://nlp.stanford.edu/data/glove.6B.zip 7 | unzip glove.6B.zip 8 | rm glove.6B.50d.txt glove.6B.100d.txt glove.6B.200d.txt glove.6B.zip 9 | 10 | if [ ! -d ./glove ];then 11 | mkdir glove 12 | fi 13 | mv glove.6B.300d.txt ./glove 14 | echo "downloaded glove.6B.300d.txt successfully" 15 | fi -------------------------------------------------------------------------------- /fedlab_benchmarks/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/models/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/models/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import torch.nn as nn 16 | 17 | 18 | class MLP_CelebA(nn.Module): 19 | """Used for celeba experiment""" 20 | 21 | def __init__(self): 22 | super(MLP_CelebA, self).__init__() 23 | self.fc1 = nn.Linear(12288, 2048) # image_size=64, 64*64*3 24 | self.relu1 = nn.ReLU() 25 | self.fc2 = nn.Linear(2048, 500) 26 | self.relu2 = nn.ReLU() 27 | self.fc3 = nn.Linear(500, 100) 28 | self.relu3 = nn.ReLU() 29 | self.fc4 = nn.Linear(100, 2) 30 | 31 | def forward(self, x): 32 | x = x.view(x.shape[0], -1) 33 | x = self.relu1(self.fc1(x)) 34 | x = self.relu2(self.fc2(x)) 35 | x = self.relu3(self.fc3(x)) 36 | x = self.fc4(x) 37 | return x 38 | 39 | 40 | class MLP(nn.Module): 41 | def __init__(self, input_size, output_size): 42 | super(MLP, self).__init__() 43 | self.fc1 = nn.Linear(input_size, 200) 44 | self.fc2 = nn.Linear(200, 200) 45 | self.fc3 = nn.Linear(200, output_size) 46 | self.relu = nn.ReLU() 47 | 48 | def forward(self, x): 49 | x = x.view(x.shape[0], -1) 50 | x = self.relu(self.fc1(x)) 51 | x = self.relu(self.fc2(x)) 52 | x = self.fc3(x) 53 | return x -------------------------------------------------------------------------------- /fedlab_benchmarks/perfedavg/README.md: -------------------------------------------------------------------------------- 1 | # Personalized-FedAvg 2 | 3 | Personalized-FedAvg: [Improving Federated Learning Personalization via Model Agnostic Meta Learning](https://arxiv.org/abs/1909.12488) 4 | 5 | ## Further reading 6 | 7 | - MAML: [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) 8 | - Reptile: [On First-Order Meta-Learning Algorithms](https://arxiv.org/abs/1803.02999) 9 | - LEAF: [LEAF: A Benchmark for Federated Settings](https://arxiv.org/abs/1812.01097) 10 | ## Dataset 11 | 12 | - Experiments are run on FEMNIST (derived from [LEAF](https://github.com/TalwalkarLab/leaf)) 13 | 14 | - Download and preprocess FEMNIST: 15 | 16 | - `cd leaf/datasets/femnist ; sh ./preprocess.sh -s niid --iu 1.0 --sf 1.0 -k 0 -t user --tf 0.8` 17 | 18 | You can get more details about this command in `/leaf/datasets/femnist/README.md`. 19 | 20 | ☝ Make sure you have downloaded and preprocessed FEMNIST before running the experiment. 21 | 22 | 23 | 24 | ## Requirements 25 | 26 | fedlab==1.1.2 27 | 28 | ## Run 29 | 30 | There are two ways to run the experiment in **Linux**. I have already set all hyperparameters according to the paper; of course, they can be modified. You can check `utils.get_args()` for more details about all hyperparameters. 31 | 32 | ### Single-process 33 | 34 | ```bash 35 | python single_process.py 36 | ``` 37 | 38 | ### Multi-process (needs more computational power) 39 | 40 | I have set 3 worker processes to handle all training tasks. 41 | 42 | ```bash 43 | cd multi_process/ ; sh quick_start.sh 44 | ``` 45 | 46 | ## Performance 47 | 48 | Evaluation results after fine-tuning are shown below.
49 | 50 | Communication round: `500` 51 | 52 | Fine-tune: outer loop: `100`; inner loop: `10` 53 | 54 | Personalization round: `5` 55 | 56 | | FedAvg local training epochs (5 clients) | Initial loss | Initial Acc | Personalized loss | Personalized Acc | 57 | | ---------------------------------------- | ------------ | ----------- | ----------------- | ---------------- | 58 | | 20 | 2.3022 | 79.35% | 1.5766 | 84.86% | 59 | | 10 | 1.8387 | 80.53% | 1.1231 | 87.22% | 60 | | 5 | 1.4899 | **83.19%** | 0.9809 | **88.97%** | 61 | | 2 | 1.4613 | 81.70% | 0.9359 | 88.49% | 62 | 63 | | FedAvg local training epochs (20 clients) | Initial loss | Initial Acc | Personalized loss | Personalized Acc | 64 | | ----------------------------------------- | ------------ | ----------- | ----------------- | ---------------- | 65 | | 20 | 2.2398 | 82.40% | 0.9756 | 90.29% | 66 | | 10 | 1.6560 | **83.23%** | 0.8488 | 90.72% | 67 | | 5 | 1.5485 | 81.48% | 0.7452 | **90.77%** | 68 | | 2 | 1.2707 | 82.48% | 0.7139 | 90.48% | 69 | 70 | The experiment result from the [paper](https://arxiv.org/abs/1909.12488) is shown below: 71 | 72 | ![paper_exp_res](image/paper_exp_res.png) 73 | 74 | I ascribe the gap between my results and the paper's to differences in the hyperparameter settings, and I think there is no big mistake in the algorithm implementation. If there actually is one, please open an issue at this repo or [PerFedAvg](https://github.com/KarhouTam/PerFedAvg). 🙏 75 | 76 | -------------------------------------------------------------------------------- /fedlab_benchmarks/perfedavg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/perfedavg/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/perfedavg/fine_tuner.py: -------------------------------------------------------------------------------- 1 | from sys import path 2 | 3 | path.append("../") 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from fedlab.utils.serialization import SerializationTool 8 | from fedlab.utils.logger import Logger 9 | from fedlab.utils.functional import get_best_gpu 10 | from fedlab.core.client.trainer import ClientTrainer 11 | from utils import get_optimizer 12 | from tqdm import trange 13 | from leaf.pickle_dataset import PickleDataset 14 | 15 | 16 | class LocalFineTuner(ClientTrainer): 17 | """ 18 | Args: 19 | model (torch.nn.Module): Global model's architecture 20 | optimizer_type (str): Local optimizer. 21 | optimizer_args (dict): Provides necessary args for building the local optimizer. 22 | criterion (torch.nn.CrossEntropyLoss / torch.nn.MSELoss()): Local loss function. 23 | epochs (int): Num of local training epochs. Personalization's local epochs may differ from others. 24 | batch_size (int): Batch size of local training. 25 | cuda (bool): True for using GPUs. 26 | logger (fedlab.utils.Logger): Object of Logger.
27 | """ 28 | 29 | def __init__( 30 | self, 31 | model, 32 | optimizer_type, 33 | optimizer_args, 34 | criterion, 35 | epochs, 36 | batch_size, 37 | cuda, 38 | logger=Logger(), 39 | ): 40 | super(LocalFineTuner, self).__init__(model, cuda) 41 | if torch.cuda.is_available() and cuda: 42 | self.device = get_best_gpu() 43 | else: 44 | self.device = torch.device("cpu") 45 | self.epochs = epochs 46 | self._criterion = criterion 47 | self._optimizer = get_optimizer(self._model, optimizer_type, optimizer_args) 48 | self._logger = logger 49 | self.batch_size = batch_size 50 | self.dataset = PickleDataset("femnist") 51 | 52 | def train(self, client_id, model_parameters): 53 | trainloader = DataLoader( 54 | self.dataset.get_dataset_pickle("train", client_id), self.batch_size 55 | ) 56 | SerializationTool.deserialize_model(self._model, model_parameters) 57 | gradients = [] 58 | for param in self._model.parameters(): 59 | gradients.append( 60 | torch.zeros(param.size(), requires_grad=True, device=param.device) 61 | ) 62 | for _ in trange(self.epochs, desc="client [{}]".format(client_id)): 63 | 64 | for x, y in trainloader: 65 | x, y = x.to(self.device), y.to(self.device) 66 | 67 | logit = self._model(x) 68 | loss = self._criterion(logit, y) 69 | 70 | self._optimizer.zero_grad() 71 | loss.backward() 72 | self._optimizer.step() 73 | 74 | for idx, param in enumerate(self._model.parameters()): 75 | gradients[idx].data.add_(param.grad.data) 76 | return gradients 77 | -------------------------------------------------------------------------------- /fedlab_benchmarks/perfedavg/image/paper_exp_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/perfedavg/image/paper_exp_res.png -------------------------------------------------------------------------------- /fedlab_benchmarks/perfedavg/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | # Model's architecture refers to https://github.com/tensorflow/federated/blob/main/tensorflow_federated/python/simulation/models/mnist.py 4 | class EmnistCNN(nn.Module): 5 | def __init__(self): 6 | super(EmnistCNN, self).__init__() 7 | self.net = nn.Sequential( 8 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2), 9 | nn.ReLU(inplace=True), 10 | nn.MaxPool2d(2, stride=2, padding=0), 11 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2), 12 | nn.ReLU(inplace=True), 13 | nn.MaxPool2d(2, stride=2, padding=0), 14 | nn.Flatten(), 15 | nn.Linear(7 * 7 * 64, out_features=1024), 16 | nn.ReLU(inplace=True), 17 | nn.Dropout(0.6), 18 | nn.Linear(in_features=1024, out_features=62), 19 | ) 20 | 21 | def forward(self, x): 22 | return self.net(x) 23 | -------------------------------------------------------------------------------- /fedlab_benchmarks/perfedavg/multi_process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/perfedavg/multi_process/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/perfedavg/multi_process/client.py: -------------------------------------------------------------------------------- 1 | from sys import path 2 | 3 | path.append("../") 4 | path.append("../../") 5 | 6 | import 
argparse 7 | from torch import nn 8 | from fedlab.core.network import DistNetwork 9 | from fedlab.utils.logger import Logger 10 | from client_manager import PerFedAvgClientManager 11 | from trainer import PerFedAvgTrainer 12 | from fine_tuner import LocalFineTuner 13 | from models import EmnistCNN 14 | from utils import get_args 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | args = get_args(parser) 19 | model = EmnistCNN() 20 | criterion = nn.CrossEntropyLoss() 21 | network = DistNetwork( 22 | address=(args.ip, args.port), 23 | world_size=args.world_size, 24 | rank=args.rank, 25 | ethernet=args.ethernet, 26 | ) 27 | 28 | LOGGER = Logger(log_name="client process " + str(args.rank)) 29 | 30 | perfedavg_trainer = PerFedAvgTrainer( 31 | model=model, 32 | optimizer_type="sgd", 33 | optimizer_args=dict(lr=args.local_lr), 34 | criterion=criterion, 35 | epochs=args.inner_loops, 36 | batch_size=args.batch_size, 37 | pers_round=args.pers_round, 38 | cuda=args.cuda, 39 | logger=Logger(log_name="node {}".format(args.rank)), 40 | ) 41 | 42 | finetuner = LocalFineTuner( 43 | model=model, 44 | optimizer_type="adam", 45 | optimizer_args=dict(lr=args.fine_tune_local_lr, betas=(0, 0.999)), 46 | criterion=criterion, 47 | epochs=args.fine_tune_inner_loops, 48 | batch_size=args.batch_size, 49 | cuda=args.cuda, 50 | logger=Logger(log_name="node {}".format(args.rank)), 51 | ) 52 | 53 | manager_ = PerFedAvgClientManager( 54 | network=network, fedavg_trainer=perfedavg_trainer, fine_tuner=finetuner, 55 | ) 56 | manager_.run() 57 | -------------------------------------------------------------------------------- /fedlab_benchmarks/perfedavg/multi_process/quick_start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python server.py --ip 127.0.0.1 --port 3002 --world_size 4 --rank 0 & 3 | 4 | python client.py --ip 127.0.0.1 --port 3002 --world_size 4 --rank 1 & 5 | 6 | python client.py --ip 127.0.0.1 --port 3002 --world_size 4 --rank 2 & 7 | 8 | python client.py --ip 127.0.0.1 --port 3002 --world_size 4 --rank 3 & 9 | 10 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/perfedavg/multi_process/server.py: -------------------------------------------------------------------------------- 1 | from sys import path 2 | 3 | path.append("../") 4 | 5 | import argparse 6 | from handler import PersonalizaitonHandler, FedAvgHandler, FineTuneHandler 7 | from fedlab.core.network import DistNetwork 8 | from fedlab.utils.logger import Logger 9 | from server_manager import PerFedAvgSyncServerManager 10 | from utils import get_args 11 | from models import EmnistCNN 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | args = get_args(parser) 16 | 17 | model = EmnistCNN() 18 | fedavg_handler = FedAvgHandler( 19 | model=model, 20 | global_round=args.epochs, 21 | client_num_in_total=int(0.8 * args.client_num_in_total), 22 | client_num_per_round=args.client_num_per_round, 23 | optimizer_type="momentum_sgd", 24 | optimizer_args=dict(lr=args.server_lr, momentum=0.9), 25 | cuda=args.cuda, 26 | logger=Logger(log_name="fedavg"), 27 | ) 28 | 29 | finetune_handler = ( 30 | FineTuneHandler( 31 | model=model, 32 | global_round=args.fine_tune_outer_loops, 33 | client_num_in_total=int(0.8 * args.client_num_in_total), 34 | client_num_per_round=args.client_num_per_round, 35 | optimizer_type="sgd", 36 | optimizer_args=dict(lr=args.fine_tune_server_lr), 37 | cuda=args.cuda, 38 |
logger=Logger(log_name="fine-tune"), 39 | ) 40 | if args.fine_tune 41 | else None 42 | ) 43 | 44 | personalization_handler = PersonalizaitonHandler( 45 | model=model, 46 | global_round=args.test_round, 47 | client_num_in_total=args.client_num_in_total 48 | - int(0.8 * args.client_num_in_total), 49 | client_num_per_round=args.client_num_per_round, 50 | cuda=args.cuda, 51 | logger=Logger(log_name="personalization"), 52 | ) 53 | network = DistNetwork( 54 | address=(args.ip, args.port), 55 | world_size=args.world_size, 56 | rank=0, 57 | ethernet=args.ethernet, 58 | ) 59 | manager_ = PerFedAvgSyncServerManager( 60 | network=network, 61 | fedavg_handler=fedavg_handler, 62 | finetune_handler=finetune_handler, 63 | personalization_handler=personalization_handler, 64 | logger=Logger(log_name="manager_server"), 65 | ) 66 | 67 | manager_.run() 68 | -------------------------------------------------------------------------------- /fedlab_benchmarks/qfedavg/README.md: -------------------------------------------------------------------------------- 1 | # qFedAvg 2 | 3 | An reproduction of paper ["Fair resource allocation in federated learning"](https://arxiv.org/abs/1905.10497) via fedlab. 4 | 5 | ## Requirements 6 | 7 | fedlab==1.2.0 8 | 9 | ## Run 10 | 11 | $ bash run.sh 12 | 13 | ## Performance 14 | 15 | Null. 16 | 17 | ## References 18 | 19 | Li, Tian, et al. "Fair resource allocation in federated learning." arXiv preprint arXiv:1905.10497 (2019). 20 | -------------------------------------------------------------------------------- /fedlab_benchmarks/qfedavg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/qfedavg/__init__.py -------------------------------------------------------------------------------- /fedlab_benchmarks/qfedavg/mnist_iid_10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/qfedavg/mnist_iid_10.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/qfedavg/mnist_noniid_200_10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/qfedavg/mnist_noniid_200_10.pkl -------------------------------------------------------------------------------- /fedlab_benchmarks/qfedavg/run.sh: -------------------------------------------------------------------------------- 1 | #!bin/bash 2 | 3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 11 --dataset mnist & 4 | 5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 1 --dataset mnist & 6 | 7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 2 --dataset mnist & 8 | 9 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 3 --dataset mnist & 10 | 11 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 4 --dataset mnist & 12 | 13 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 5 --dataset mnist & 14 | 15 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 6 --dataset mnist & 16 | 17 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 7 --dataset mnist & 18 | 19 | python client.py --ip 127.0.0.1 --port 
3002 --world_size 11 --rank 8 --dataset mnist & 20 | 21 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 9 --dataset mnist & 22 | 23 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 10 --dataset mnist & 24 | 25 | wait -------------------------------------------------------------------------------- /fedlab_benchmarks/report_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Algorithm 2 | 3 | Algorithm description and introduction. 4 | 5 | ## Requirements 6 | 7 | Clarify requirements, especially the version of fedlab (our version is updated frequently). 8 | 9 | ## Run 10 | 11 | Explain how to start your code. 12 | 13 | ## Performance 14 | 15 | Please report the performance. 16 | 17 | ## References 18 | 19 | References here 20 | 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fedlab 2 | scikit-learn 3 | spacy 4 | matplotlib 5 | scipy --------------------------------------------------------------------------------