├── .gitignore
├── LICENSE
├── README.md
├── fedlab_benchmarks
├── README.md
├── __init__.py
├── cfl
│ ├── README.md
│ ├── __init__.py
│ ├── cfl_trainer.py
│ ├── datasets.py
│ ├── helper.py
│ └── standalone.py
├── compressors
│ ├── __init__.py
│ ├── base_compressor.py
│ ├── compress_example
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── quick_start.sh
│ │ └── server.py
│ ├── dgc.py
│ ├── qsgd.py
│ └── topk.py
├── datasets
│ ├── README.md
│ ├── __init__.py
│ ├── adult
│ │ ├── __init__.py
│ │ ├── adult.py
│ │ └── adult_tutorial.ipynb
│ ├── celeba
│ │ ├── README.md
│ │ ├── preprocess.sh
│ │ ├── preprocess
│ │ │ └── metadata_to_json.py
│ │ └── stats.sh
│ ├── cifar10
│ │ └── cifar10_partitioner.ipynb
│ ├── cifar100
│ │ ├── data_partitioner.ipynb
│ │ ├── imgs
│ │ │ ├── cifar100_balance_dir_alpha_0.3_100clients.png
│ │ │ ├── cifar100_balance_iid_100clients.png
│ │ │ ├── cifar100_hetero_dir_0.3_100clients.png
│ │ │ ├── cifar100_hetero_dir_0.3_100clients_dist.png
│ │ │ ├── cifar100_shards_200_100clients.png
│ │ │ ├── cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png
│ │ │ ├── cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png
│ │ │ ├── cifar100_unbalance_iid_unbalance_sgm_0.3_100clients.png
│ │ │ └── cifar100_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png
│ │ └── partition-reports
│ │ │ ├── cifar100_balance_dir_alpha_0.3_100clients.csv
│ │ │ ├── cifar100_balance_iid_100clients.csv
│ │ │ ├── cifar100_hetero_dir_0.3_100clients.csv
│ │ │ ├── cifar100_shards_200_100clients.csv
│ │ │ ├── cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.csv
│ │ │ └── cifar100_unbalance_iid_unbalance_sgm_0.3_100clients.csv
│ ├── covtype
│ │ ├── __init__.py
│ │ ├── covtype.py
│ │ └── covtype_tutorial.ipynb
│ ├── fcube
│ │ ├── __init__.py
│ │ ├── fcube.py
│ │ └── fcube_tutorial.ipynb
│ ├── femnist
│ │ ├── README.md
│ │ ├── preprocess.sh
│ │ ├── preprocess
│ │ │ ├── data_to_json.py
│ │ │ ├── data_to_json.sh
│ │ │ ├── get_data.sh
│ │ │ ├── get_file_dirs.py
│ │ │ ├── get_hashes.py
│ │ │ ├── group_by_writer.py
│ │ │ └── match_hashes.py
│ │ └── stats.sh
│ ├── fmnist
│ │ └── fmnist_tutorial.ipynb
│ ├── imgs
│ │ ├── adult_iid_10clients.png
│ │ ├── adult_noniid-label1_10clients.png
│ │ ├── adult_noniid_labeldir_10clients.png
│ │ ├── adult_unbalance_10clients.png
│ │ ├── cifar10_balance_dir_alpha_0.3_100clients.png
│ │ ├── cifar10_balance_iid_100clients.png
│ │ ├── cifar10_hetero_dir_0.3_100clients.png
│ │ ├── cifar10_hetero_dir_0.3_100clients_dist.png
│ │ ├── cifar10_shards_200_100clients.png
│ │ ├── cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png
│ │ ├── cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png
│ │ ├── cifar10_unbalance_iid_unbalance_sgm_0.3_100clients.png
│ │ ├── cifar10_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png
│ │ ├── fcube_class_dist.png
│ │ ├── fcube_iid.png
│ │ ├── fcube_iid_part.png
│ │ ├── fcube_synthetic.png
│ │ ├── fcube_synthetic_original_paper.png
│ │ ├── fcube_synthetic_part.png
│ │ ├── fcube_test_dist_vis.png
│ │ ├── fcube_train_dist_vis.png
│ │ ├── fmnist_feature_skew_vis.png
│ │ ├── fmnist_iid_clients_10.png
│ │ ├── fmnist_noniid-label_1_clients_10.png
│ │ ├── fmnist_noniid-label_2_clients_10.png
│ │ ├── fmnist_noniid-label_3_clients_10.png
│ │ ├── fmnist_noniid_labeldir_clients_10.png
│ │ ├── fmnist_unbalance_clients_10.png
│ │ ├── fmnist_vis.png
│ │ ├── svhn_feature_skew_vis.png
│ │ └── svhn_vis.png
│ ├── mnist
│ │ ├── download_mnist.py
│ │ ├── mnist_iid.pkl
│ │ ├── mnist_noniid.pkl
│ │ └── mnist_partition.py
│ ├── partition-reports
│ │ ├── adult_iid_10clients.csv
│ │ ├── adult_noniid-label1_10clients.csv
│ │ ├── adult_noniid_labeldir_10clients.csv
│ │ ├── adult_unbalance_10clients.csv
│ │ ├── cifar10_balance_dir_alpha_0.3_100clients.csv
│ │ ├── cifar10_balance_iid_100clients.csv
│ │ ├── cifar10_hetero_dir_0.3_100clients.csv
│ │ ├── cifar10_shards_200_100clients.csv
│ │ ├── cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.csv
│ │ ├── cifar10_unbalance_iid_unbalance_sgm_0.3_100clients.csv
│ │ ├── fcube_iid.csv
│ │ ├── fcube_synthetic.csv
│ │ ├── fmnist_iid_clients_10.csv
│ │ ├── fmnist_noniid-label_1_clients_10.csv
│ │ ├── fmnist_noniid-label_2_clients_10.csv
│ │ ├── fmnist_noniid-label_3_clients_10.csv
│ │ ├── fmnist_noniid_labeldir_clients_10.csv
│ │ └── fmnist_unbalance_clients_10.csv
│ ├── rcv1
│ │ ├── __init__.py
│ │ ├── rcv1.py
│ │ └── rcv1_tutorial.ipynb
│ ├── reddit
│ │ ├── README.md
│ │ ├── build_vocab.py
│ │ └── source
│ │ │ ├── clean_raw.py
│ │ │ ├── delete_small_users.py
│ │ │ ├── get_json.py
│ │ │ ├── get_raw_users.py
│ │ │ ├── merge_raw_users.py
│ │ │ ├── preprocess.py
│ │ │ ├── reddit_utils.py
│ │ │ └── run_reddit.sh
│ ├── sent140
│ │ ├── README.md
│ │ ├── preprocess.sh
│ │ ├── preprocess
│ │ │ ├── combine_data.py
│ │ │ ├── data_to_json.py
│ │ │ ├── data_to_json.sh
│ │ │ └── get_data.sh
│ │ └── stats.sh
│ ├── shakespeare
│ │ ├── README.md
│ │ ├── preprocess.sh
│ │ ├── preprocess
│ │ │ ├── data_to_json.sh
│ │ │ ├── gen_all_data.py
│ │ │ ├── get_data.sh
│ │ │ ├── preprocess_shakespeare.py
│ │ │ └── shake_utils.py
│ │ └── stats.sh
│ ├── svhn
│ │ └── svhn_tutorial.ipynb
│ ├── synthetic
│ │ ├── README.md
│ │ ├── data_generator.py
│ │ ├── main.py
│ │ ├── preprocess.sh
│ │ └── stats.sh
│ └── utils
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── preprocess.sh
│ │ ├── remove_users.py
│ │ ├── sample.py
│ │ ├── split_data.py
│ │ ├── stats.py
│ │ └── util.py
├── feature-skew-fedavg
│ ├── README.md
│ ├── __init__.py
│ ├── client.py
│ ├── config.py
│ ├── data_partition.py
│ ├── models.py
│ ├── server.py
│ ├── start_clt.sh
│ └── start_server.sh
├── fedasync
│ ├── README.md
│ ├── __init__.py
│ ├── cross_process
│ │ ├── client.py
│ │ ├── quick_start.sh
│ │ └── server.py
│ └── standalone
│ │ ├── __init__.py
│ │ ├── cifar10_iid.pkl
│ │ ├── cifar10_noniid.pkl
│ │ └── standalone.py
├── fedavg_v1.1.2
│ ├── README.md
│ ├── __init__.py
│ ├── cross_process
│ │ ├── LEAF_test.sh
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── quick_start.sh
│ │ ├── server.py
│ │ ├── setting.py
│ │ └── start_clients.sh
│ ├── scale
│ │ ├── __init__.py
│ │ ├── cifar10-cnn
│ │ │ ├── cifar10_iid.pkl
│ │ │ ├── cifar10_noniid.pkl
│ │ │ ├── cifar10_partition.py
│ │ │ ├── client.py
│ │ │ ├── config.py
│ │ │ ├── server.py
│ │ │ └── start_clt.sh
│ │ ├── femnist-cnn
│ │ │ ├── client.py
│ │ │ ├── server.py
│ │ │ └── start_clt.sh
│ │ ├── mnist-cnn
│ │ │ ├── __init__.py
│ │ │ ├── client.py
│ │ │ ├── mnist_iid.pkl
│ │ │ ├── mnist_noniid.pkl
│ │ │ ├── mnist_partition.py
│ │ │ ├── server.py
│ │ │ └── start_clt.sh
│ │ └── shakespeare-rnn
│ │ │ ├── client.py
│ │ │ ├── server.py
│ │ │ └── start_clt.sh
│ └── standalone
│ │ ├── __init__.py
│ │ ├── mnist_iid.pkl
│ │ ├── mnist_noniid.pkl
│ │ └── standalone.py
├── fedavg_v1.2.0
│ ├── client.py
│ ├── mnist_partition.pkl
│ ├── run.sh
│ ├── server.py
│ ├── setting.py
│ └── standalone.py
├── feddyn
│ ├── Output
│ │ └── CIFAR10_100_iid_plots.png
│ ├── README.md
│ ├── __init__.py
│ ├── client.py
│ ├── client_starter.py
│ ├── config.py
│ ├── data_partition.py
│ ├── models.py
│ ├── results-plot.ipynb
│ ├── server.py
│ ├── server_starter.py
│ ├── start_clt.sh
│ ├── start_server.sh
│ └── utils.py
├── fedmgda+
│ ├── README.md
│ ├── __init__.py
│ ├── client.py
│ ├── mnist_iid_100.pkl
│ ├── mnist_noniid.pkl
│ ├── mnist_noniid_200_100.pkl
│ ├── run.sh
│ ├── server.py
│ ├── setting.py
│ ├── standalone.py
│ └── start_clients.sh
├── fedprox
│ ├── README.md
│ ├── __init__.py
│ ├── cross_process
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── server.py
│ │ └── setting.py
│ ├── fedprox_trainer.py
│ └── standalone.py
├── leaf
│ ├── README.md
│ ├── README_tmp.md
│ ├── README_zh_cn.md
│ ├── __init__.py
│ ├── dataloader.py
│ ├── dataset
│ │ ├── __init__.py
│ │ ├── celeba_dataset.py
│ │ ├── femnist_dataset.py
│ │ ├── reddit_dataset.py
│ │ ├── sent140_dataset.py
│ │ └── shakespeare_dataset.py
│ ├── gen_pickle_dataset.sh
│ ├── nlp_utils
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── download_glove.sh
│ │ └── util.py
│ └── pickle_dataset.py
├── models
│ ├── __init__.py
│ ├── cnn.py
│ ├── mlp.py
│ └── rnn.py
├── perfedavg
│ ├── README.md
│ ├── __init__.py
│ ├── fine_tuner.py
│ ├── image
│ │ └── paper_exp_res.png
│ ├── models.py
│ ├── multi_process
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── client_manager.py
│ │ ├── handler.py
│ │ ├── quick_start.sh
│ │ ├── server.py
│ │ └── server_manager.py
│ ├── single_process.py
│ ├── trainer.py
│ └── utils.py
├── qfedavg
│ ├── README.md
│ ├── __init__.py
│ ├── client.py
│ ├── mnist_iid_10.pkl
│ ├── mnist_noniid_200_10.pkl
│ ├── run.sh
│ └── server.py
└── report_TEMPLATE.md
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # System file
2 | *.DS_Store
3 |
4 | # Pycharm
5 | .idea
6 |
7 | # VS Code
8 | .vscode
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 | cover/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | # docs/_build/
82 | docs/build/
83 |
84 | # PyBuilder
85 | .pybuilder/
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | # For a library or package, you might want to ignore these files since the code is
97 | # intended to run in multiple environments; otherwise, check them in:
98 | # .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
144 | # pytype static type analyzer
145 | .pytype/
146 |
147 | # Cython debug symbols
148 | cython_debug/
149 |
150 | # datasets
151 | fedlab_benchmarks/datasets/cifar10/*
152 | !fedlab_benchmarks/datasets/cifar10/*.py
153 | !fedlab_benchmarks/datasets/cifar10/*.ipynb
154 | fedlab_benchmarks/datasets/*/data
155 | fedlab_benchmarks/datasets/*/meta
156 | fedlab_benchmarks/datasets/*/MNIST
157 | fedlab_benchmarks/datasets/*/*.pkl
158 |
159 | fedlab_benchmarks/leaf/pickle_datasets/*
160 | fedlab_benchmarks/leaf/pickle_datasets/*
161 |
162 | actions-runner
163 |
164 | # tests
165 | fedlab_benchmarks/fedavg/standalone/*.txt
166 | fedlab_benchmarks/fedavg/standalone/*.ipynb
167 | fedlab_benchmarks/fedavg/standalone/*.sh
168 |
169 | .script.ipynb
170 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |

2 |
3 | # FedLab-benchmarks
4 |
5 | **For future maintain and convenient organization, this repo is deprecated. We have merged all codes into the main repo [FedLab](https://github.com/SMILELab-FL/FedLab) for the latest version.**
6 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 |
18 | ROOT_DIR = os.path.join(os.path.abspath(__file__))
19 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/cfl/README.md:
--------------------------------------------------------------------------------
1 | # CFL
2 |
3 | [Clustered Federated Learning: Model-Agnostic Distributed Multi-Task Optimization under Privacy Constraints](https://arxiv.org/abs/1910.01991)
4 |
5 |
6 | ## Requirements
7 |
8 | fedlab==1.2.1
9 |
10 | ## Run
11 |
12 | We implement shifted and rotated data generation. And we accept the real-world data, e.g. femnist.
13 |
14 | For shifted and rotated augmented data, we accept `mnist` and `cifar10` configuration now.
15 |
16 | And we also provide the rotated 0 and 180 degree `emnist` used in the [original paper's code implementation](https://github.com/felisat/clustered-federated-learning#clustered-federated-learning-model-agnostic-distributed-multi-task-optimization-under-privacy-constraints)
17 |
18 | If you don't have the augmented data, you should set the config `process_data=1` to generate data firstly.
19 | And the augmented data will be saved in `args.save_dir`, which defaults to `./datasets`.
20 | You can see and modify the two parameters `save_dir` and `root`, which represent the augmented data storage path and origin data read path respectively.
21 | `save_dir` defaults to `./datasets`, and `root` defaults to `../datasets/{dataset_name}`
22 |
23 | For real-world data, we refer [LEAF](https://github.com/TalwalkarLab/leaf) to simulate. In our benchmark, we implement in `fedlab_benchmarks/leaf/`.
24 | You can read the docs in `leaf` folder to generate data. And cfl will use the defaulted leaf data saving path `../leaf/pickle_datasets/` to get data.
25 |
26 | ## Performance
27 |
28 | Null
29 |
30 | ## References
31 |
32 | [1] Sattler, Felix, Klaus-Robert Müller, and Wojciech Samek. "Clustered Federated Learning: Model-Agnostic Distributed Multi-Task Optimization under Privacy Constraints." arXiv preprint arXiv:1910.01991 (2019).
33 | [2] [Resource Code](https://github.com/felisat/clustered-federated-learning#clustered-federated-learning-model-agnostic-distributed-multi-task-optimization-under-privacy-constraints)
34 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/cfl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/cfl/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/cfl/helper.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is copied from [clustered-federated-learning/models.py]
3 | https://github.com/felisat/clustered-federated-learning/blob/master/helper.py
4 | """
5 |
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | class ExperimentLogger:
11 | def log(self, values):
12 | for k, v in values.items():
13 | if k not in self.__dict__:
14 | self.__dict__[k] = [v]
15 | else:
16 | self.__dict__[k] += [v]
17 |
18 |
19 | def display_train_stats(cfl_stats, eps_1, eps_2, communication_rounds):
20 | plt.figure(figsize=(12, 4))
21 |
22 | plt.subplot(1, 2, 1)
23 | acc_mean = np.mean(cfl_stats.acc_clusters, axis=1)
24 | acc_std = np.std(cfl_stats.acc_clusters, axis=1)
25 | plt.fill_between(cfl_stats.rounds, acc_mean - acc_std, acc_mean + acc_std, alpha=0.5, color="C0")
26 | plt.plot(cfl_stats.rounds, acc_mean, color="C0")
27 |
28 | if "split" in cfl_stats.__dict__:
29 | for s in cfl_stats.split:
30 | plt.axvline(x=s, linestyle="-", color="k", label=r"Split")
31 |
32 | plt.text(x=communication_rounds, y=1, ha="right", va="top",
33 | s=f"Clusters: {cfl_stats.clusters[-1]}")
34 |
35 | plt.xlabel("Communication Rounds")
36 | plt.ylabel("Accuracy")
37 |
38 | plt.xlim(0, communication_rounds)
39 | plt.ylim(0, 1)
40 |
41 | plt.subplot(1, 2, 2)
42 |
43 | plt.plot(cfl_stats.rounds, cfl_stats.mean_norm, color="C1", label=r"$\|\sum_i\Delta W_i \|$")
44 | plt.plot(cfl_stats.rounds, cfl_stats.max_norm, color="C2", label=r"$\max_i\|\Delta W_i \|$")
45 |
46 | plt.axhline(y=eps_1, linestyle="--", color="k", label=r"$\varepsilon_1$")
47 | plt.axhline(y=eps_2, linestyle=":", color="k", label=r"$\varepsilon_2$")
48 |
49 | if "split" in cfl_stats.__dict__:
50 | for s in cfl_stats.split:
51 | plt.axvline(x=s, linestyle="-", color="k", label=r"Split")
52 |
53 | plt.xlabel("Communication Rounds")
54 | plt.legend()
55 |
56 | plt.xlim(0, communication_rounds)
57 | # plt.ylim(0, 2)
58 |
59 | plt.show()
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .topk import TopkCompressor
16 | from .qsgd import QSGDCompressor
17 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/base_compressor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from abc import ABC
16 |
17 |
18 | class Compressor(ABC):
19 | def __init__(self) -> None:
20 | super().__init__()
21 |
22 | def compress(self, *args, **kwargs):
23 | raise NotImplementedError()
24 |
25 | def decompress(self, *args, **kwargs):
26 | raise NotImplementedError()
27 |
28 |
29 | class Memory(ABC):
30 | def __init__(self) -> None:
31 | super().__init__()
32 |
33 | def initialize(self, *args, **kwargs):
34 | raise NotImplementedError()
35 |
36 | def compensate(self, tensor, *args, **kwargs):
37 | raise NotImplementedError()
38 |
39 | def update(self, *args, **kwargs):
40 | raise NotImplementedError()
41 |
42 | def state_dict(self):
43 | raise NotImplementedError()
44 |
45 | def load_state_dict(self, state_dict):
46 | raise NotImplementedError()
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/compress_example/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/compressors/compress_example/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/compress_example/quick_start.sh:
--------------------------------------------------------------------------------
1 | #!bin/bash
2 |
3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 3 --dataset mnist --round 3 &
4 |
5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 3 --rank 1 --dataset mnist &
6 |
7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 3 --rank 2 --dataset mnist &
8 |
9 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/compressors/compress_example/server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 |
4 | from fedlab.core import communicator
5 | from fedlab.utils.logger import Logger
6 | from fedlab.core.server.handler import SyncParameterServerHandler
7 | from fedlab.core.server.manager import ServerSynchronousManager
8 | from fedlab.core.network import DistNetwork
9 | from fedlab.compressor.topk import TopkCompressor
10 | from fedlab.utils.message_code import MessageCode
11 | from fedlab.core.communicator.processor import PackageProcessor
12 |
13 | sys.path.append('../')
14 | from models.cnn import CNN_MNIST
15 |
16 | class CompressServerManager(ServerSynchronousManager):
17 | def __init__(self, network, handler, logger=None):
18 | super().__init__(network, handler, logger=logger)
19 | self.tpkc = TopkCompressor(compress_ratio=0.5)
20 |
21 | def on_receive(self, sender, message_code, payload):
22 | if message_code == MessageCode.ParameterUpdate:
23 | print(sender, message_code, payload[0].shape)
24 | #_, _, paylaod = PackageProcessor.recv_package(src=sender)
25 | #print("------", len(paylaod))
26 | return True
27 | else:
28 | raise Exception("Unexpected message code {}".format(message_code))
29 |
30 |
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser(description='FL server example')
33 |
34 | parser.add_argument('--ip', type=str)
35 | parser.add_argument('--port', type=str)
36 | parser.add_argument('--world_size', type=int)
37 |
38 | parser.add_argument('--round', type=int, default=1)
39 | parser.add_argument('--dataset', type=str)
40 | parser.add_argument('--ethernet', type=str, default=None)
41 | parser.add_argument('--sample', type=float, default=1)
42 |
43 | args = parser.parse_args()
44 |
45 | model = CNN_MNIST()
46 | LOGGER = Logger(log_name="server")
47 | handler = SyncParameterServerHandler(model,
48 | global_round=args.round,
49 | logger=LOGGER,
50 | sample_ratio=args.sample)
51 | network = DistNetwork(address=(args.ip, args.port),
52 | world_size=args.world_size,
53 | rank=0,
54 | ethernet=args.ethernet)
55 |
56 | manager_ = CompressServerManager(handler=handler,
57 | network=network,
58 | logger=LOGGER)
59 |
60 | manager_.run()
61 |
62 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/README.md:
--------------------------------------------------------------------------------
1 | # DATASETS README
2 |
3 | This folder contains download and preprocessed scripts for commonly used datasets, and provides leaf dataset interface.
4 |
5 |
6 | - `data` folder: contains download and preprocess scripts for commonly used datasets, each subfolder is named by dataset name.
7 |
8 | For LEAF dataset, it contain `celeba`, `femnist`, `reddit`, `sent140`, `shakespeare`, `synthetic`, whose download and preprocess scripts are copied by [LEAF-Github](https://github.com/TalwalkarLab/leaf). And we copy and modify from [Flower leaf script](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/scripts/leaf/femnist]) to store processed `Dataset` in the form of pickle files.
9 |
10 | For leaf dataset folders, run `create_datasets_and_save.sh` to get partitioned data. Also we can edit preprocess.sh command params to get a different partition way.
11 |
12 | - `leaf_data_process` folder: contains process method to read leaf data and get dataloader for users.
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .fcube import FCUBE
17 | from .adult import Adult
18 | from .rcv1 import RCV1
19 | from .covtype import Covtype
20 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/adult/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .adult import Adult
16 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/adult/adult.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | from sklearn.datasets import load_svmlight_file
17 | import random
18 | import os
19 | from urllib.request import urlretrieve
20 |
21 | from torch.utils.data import Dataset
22 |
23 |
24 | class Adult(Dataset):
25 | url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/"
26 | train_file_name = "a9a"
27 | test_file_name = "a9a.t"
28 | num_classes = 2
29 | num_features = 123
30 |
31 | def __init__(self, root, train=True,
32 | transform=None,
33 | target_transform=None,
34 | download=False):
35 | self.root = root
36 | self.train = train
37 | self.transform = transform
38 | self.target_transform = target_transform
39 |
40 | if not os.path.exists(root):
41 | os.mkdir(root)
42 |
43 | if self.train:
44 | self.full_file_path = os.path.join(self.root, self.train_file_name)
45 | else:
46 | self.full_file_path = os.path.join(self.root, self.test_file_name)
47 |
48 | if download:
49 | self.download()
50 |
51 | if not self._local_file_existence():
52 | raise RuntimeError(
53 | f"Adult-a9a source data file not found. You can use download=True to "
54 | f"download it.")
55 |
56 | # now load from source file
57 | X, y = load_svmlight_file(self.full_file_path)
58 | X = X.todense() # transform
59 |
60 | if not self.train:
61 | X = np.c_[X, np.zeros((len(y), 1))] # append a zero vector at the end of X_test
62 |
63 | X = np.array(X, dtype=np.float32)
64 | y = np.array((y + 1) / 2, dtype=np.int32) # map elements of y from {-1, 1} to {0, 1}
65 | print(f"Local file {self.full_file_path} loaded.")
66 | self.data, self.targets = X, y
67 |
68 | def download(self):
69 | if self._local_file_existence():
70 | print(f"Source file already downloaded.")
71 | return
72 |
73 | if self.train:
74 | download_url = self.url + self.train_file_name
75 | else:
76 | download_url = self.url + self.test_file_name
77 |
78 | urlretrieve(download_url, self.full_file_path)
79 |
80 | def _local_file_existence(self):
81 | return os.path.exists(self.full_file_path)
82 |
83 | def __getitem__(self, index):
84 | data, label = self.data[index], self.targets[index]
85 |
86 | if self.transform is not None:
87 | data = self.transform(data)
88 |
89 | if self.target_transform is not None:
90 | label = self.target_transform(label)
91 |
92 | return data, label
93 |
94 | def __len__(self):
95 | return len(self.targets)
96 |
97 | def extra_repr(self) -> str:
98 | return "Split: {}".format("Train" if self.train is True else "Test")
99 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/celeba/preprocess.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # download data and convert to .json format
4 |
5 | if [ ! -d "data/raw/img_align_celeba" ] || [ ! "$(ls -A data/raw/img_align_celeba)" ] || [ ! -f "data/raw/list_attr_celeba.txt" ]; then
6 | echo "Please download the celebrity faces dataset and attributes file from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html"
7 | exit 1
8 | fi
9 |
10 | if [ ! -f "data/raw/identity_CelebA.txt" ]; then
11 | echo "Please request the celebrity identities file from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html"
12 | exit 1
13 | fi
14 |
15 | if [ ! -d "data/all_data" ] || [ ! "$(ls -A data/all_data)" ]; then
16 | echo "Preprocessing raw data"
17 | python preprocess/metadata_to_json.py
18 | fi
19 |
20 | NAME="celeba" # name of the dataset, equivalent to directory name
21 |
22 |
23 | cd ../utils
24 |
25 | bash preprocess.sh --name $NAME $@
26 |
27 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/celeba/preprocess/metadata_to_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import os
4 |
5 |
6 | TARGET_NAME = 'Smiling'
7 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
8 |
9 |
10 | def get_metadata():
11 | f_identities = open(os.path.join(
12 | parent_path, 'data', 'raw', 'identity_CelebA.txt'), 'r')
13 | identities = f_identities.read().split('\n')
14 |
15 | f_attributes = open(os.path.join(
16 | parent_path, 'data', 'raw', 'list_attr_celeba.txt'), 'r')
17 | attributes = f_attributes.read().split('\n')
18 |
19 | return identities, attributes
20 |
21 |
22 | def get_celebrities_and_images(identities):
23 | all_celebs = {}
24 |
25 | for line in identities:
26 | info = line.split()
27 | if len(info) < 2:
28 | continue
29 | image, celeb = info[0], info[1]
30 | if celeb not in all_celebs:
31 | all_celebs[celeb] = []
32 | all_celebs[celeb].append(image)
33 |
34 | good_celebs = {c: all_celebs[c] for c in all_celebs if len(all_celebs[c]) >= 5}
35 | return good_celebs
36 |
37 |
38 | def _get_celebrities_by_image(identities):
39 | good_images = {}
40 | for c in identities:
41 | images = identities[c]
42 | for img in images:
43 | good_images[img] = c
44 | return good_images
45 |
46 |
47 | def get_celebrities_and_target(celebrities, attributes, attribute_name=TARGET_NAME):
48 | col_names = attributes[1]
49 | col_idx = col_names.split().index(attribute_name)
50 |
51 | celeb_attributes = {}
52 | good_images = _get_celebrities_by_image(celebrities)
53 |
54 | for line in attributes[2:]:
55 | info = line.split()
56 | if len(info) == 0:
57 | continue
58 |
59 | image = info[0]
60 | if image not in good_images:
61 | continue
62 |
63 | celeb = good_images[image]
64 | att = (int(info[1:][col_idx]) + 1) / 2
65 |
66 | if celeb not in celeb_attributes:
67 | celeb_attributes[celeb] = []
68 |
69 | celeb_attributes[celeb].append(att)
70 |
71 | return celeb_attributes
72 |
73 |
74 | def build_json_format(celebrities, targets):
75 | all_data = {}
76 |
77 | celeb_keys = [c for c in celebrities]
78 | num_samples = [len(celebrities[c]) for c in celeb_keys]
79 | data = {c: {'x': celebrities[c], 'y': targets[c]} for c in celebrities}
80 |
81 | all_data['users'] = celeb_keys
82 | all_data['num_samples'] = num_samples
83 | all_data['user_data'] = data
84 | return all_data
85 |
86 |
87 | def write_json(json_data):
88 | file_name = 'all_data.json'
89 | dir_path = os.path.join(parent_path, 'data', 'all_data')
90 |
91 | if not os.path.exists(dir_path):
92 | os.mkdir(dir_path)
93 |
94 | file_path = os.path.join(dir_path, file_name)
95 |
96 | print('writing {}'.format(file_name))
97 | with open(file_path, 'w') as outfile:
98 | json.dump(json_data, outfile)
99 |
100 |
101 | def main():
102 | identities, attributes = get_metadata()
103 | celebrities = get_celebrities_and_images(identities)
104 | targets = get_celebrities_and_target(celebrities, attributes)
105 |
106 | json_data = build_json_format(celebrities, targets)
107 | write_json(json_data)
108 |
109 |
110 | if __name__ == '__main__':
111 | main()
112 |
113 |
114 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/celeba/stats.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME="celeba"
4 |
5 | cd ../utils
6 |
7 | python3 stats.py --name $NAME
8 |
9 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_balance_dir_alpha_0.3_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_balance_dir_alpha_0.3_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_balance_iid_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_balance_iid_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_hetero_dir_0.3_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_hetero_dir_0.3_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_hetero_dir_0.3_100clients_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_hetero_dir_0.3_100clients_dist.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_shards_200_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_shards_200_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_iid_unbalance_sgm_0.3_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_iid_unbalance_sgm_0.3_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/cifar100/imgs/cifar100_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/covtype/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .covtype import Covtype
16 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/fcube/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .fcube import FCUBE
17 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/README.md:
--------------------------------------------------------------------------------
1 | # FEMNIST Dataset
2 |
3 | ## Setup Instructions
4 | - pip3 install numpy
5 | - pip3 install pillow
6 | - Run ```./preprocess.sh``` with a choice of the following tags:
7 | - ```-s``` := 'iid' to sample in an i.i.d. manner, or 'niid' to sample in a non-i.i.d. manner; more information on i.i.d. versus non-i.i.d. is included in the 'Notes' section
8 | - ```--iu``` := number of users, if iid sampling; expressed as a fraction of the total number of users; default is 0.01
9 | - ```--sf``` := fraction of data to sample, written as a decimal; default is 0.1
10 | - ```-k``` := minimum number of samples per user
11 | - ```-t``` := 'user' to partition users into train-test groups, or 'sample' to partition each user's samples into train-test groups
12 | - ```--tf``` := fraction of data in training set, written as a decimal; default is 0.9
13 | - ```--smplseed``` := seed to be used before random sampling of data
14 | - ```--spltseed``` := seed to be used before random split of data
15 |
16 | i.e.
17 | - ```./preprocess.sh -s niid --sf 1.0 -k 0 -t sample``` (full-sized dataset)
18 | - ```./preprocess.sh -s niid --sf 0.05 -k 0 -t sample``` (small-sized dataset)
19 |
20 | Make sure to delete the rem_user_data, sampled_data, test, and train subfolders in the data directory before re-running preprocess.sh
21 |
22 | ## Notes
23 | - More details on i.i.d. versus non-i.i.d.:
24 | - In the i.i.d. sampling scenario, each datapoint is equally likely to be sampled. Thus, all users have the same underlying distribution of data.
25 | - In the non-i.i.d. sampling scenario, the underlying distribution of data for each user is consistent with the raw data. Since we assume that data distributions vary between user in the raw data, we refer to this sampling process as non-i.i.d.
26 | - More details on ```preprocess.sh```:
27 | - The order in which ```preprocess.sh``` processes data is 1. generating all_data, 2. sampling, 3. removing users, and 4. creating train-test split. The script will look at the data in the last generated directory and continue preprocessing from that point. For example, if the ```all_data``` directory has already been generated and the user decides to skip sampling and only remove users with the ```-k``` tag (i.e. running ```preprocess.sh -k 50```), the script will effectively apply a remove user filter to data in ```all_data``` and place the resulting data in the ```rem_user_data``` directory.
28 | - File names provide information about the preprocessing steps taken to generate them. For example, the ```all_data_niid_1_keep_64.json``` file was generated by first sampling 10 percent (.1) of the data ```all_data.json``` in a non-i.i.d. manner and then applying the ```-k 64``` argument to the resulting data.
29 | - Each .json file is an object with 3 keys:
30 | 1. 'users', a list of users
31 | 2. 'num_samples', a list of the number of samples for each user, and
32 | 3. 'user_data', an object with user names as keys and their respective data as values; for each user, data is represented as a list of images, with each image represented as a size-784 integer list (flattened from 28 by 28)
33 | - Run ```./stats.sh``` to get statistics of data (data/all_data/all_data.json must have been generated already)
34 | - In order to run reference implementations in ```../models``` directory, the ```-t sample``` tag must be used when running ```./preprocess.sh```
35 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/preprocess.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # download data and convert to .json format
4 |
5 | if [ ! -d "data/all_data" ] || [ ! "$(ls -A data/all_data)" ]; then
6 | cd preprocess
7 | bash data_to_json.sh
8 | cd ..
9 | else
10 | echo "using existing data/all_data data folder to preprocess"
11 | fi
12 |
13 | NAME="femnist" # name of the dataset, equivalent to directory name
14 |
15 | cd ../utils
16 |
17 | bash preprocess.sh --name $NAME $@
18 |
19 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/preprocess/data_to_json.py:
--------------------------------------------------------------------------------
1 | # Converts a list of (writer, [list of (file,class)]) tuples into a json object
2 | # of the form:
3 | # {users: [bob, etc], num_samples: [124, etc.],
4 | # user_data: {bob : {x:[img1,img2,etc], y:[class1,class2,etc]}, etc}}
5 | # where 'img_' is a vectorized representation of the corresponding image
6 |
7 | from __future__ import division
8 | import json
9 | import math
10 | import numpy as np
11 | import os
12 | import sys
13 |
14 | from PIL import Image
15 |
16 | utils_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17 | utils_dir = os.path.join(utils_dir, 'utils')
18 | sys.path.append(utils_dir)
19 |
20 | import util
21 |
22 | MAX_WRITERS = 100 # max number of writers per json file.
23 |
24 |
25 | def relabel_class(c):
26 | '''
27 | maps hexadecimal class value (string) to a decimal number
28 | returns:
29 | - 0 through 9 for classes representing respective numbers
30 | - 10 through 35 for classes representing respective uppercase letters
31 | - 36 through 61 for classes representing respective lowercase letters
32 | '''
33 | if c.isdigit() and int(c) < 40:
34 | return (int(c) - 30)
35 | elif int(c, 16) <= 90: # uppercase
36 | return (int(c, 16) - 55)
37 | else:
38 | return (int(c, 16) - 61)
39 |
40 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
41 |
42 | by_writer_dir = os.path.join(parent_path, 'data', 'intermediate', 'images_by_writer')
43 | writers = util.load_obj(by_writer_dir)
44 |
45 | num_json = int(math.ceil(len(writers) / MAX_WRITERS))
46 |
47 | users = []
48 | num_samples = []
49 | user_data = {}
50 |
51 | writer_count, all_writers = 0, 0
52 | json_index = 0
53 | for (w, l) in writers:
54 |
55 | users.append(w)
56 | num_samples.append(len(l))
57 | user_data[w] = {'x': [], 'y': []}
58 |
59 | size = 28, 28 # original image size is 128, 128
60 | for (f, c) in l:
61 | file_path = os.path.join(parent_path, f)
62 | img = Image.open(file_path)
63 | gray = img.convert('L')
64 | gray.thumbnail(size, Image.ANTIALIAS)
65 | arr = np.asarray(gray).copy()
66 | vec = arr.flatten()
67 | vec = vec / 255 # scale all pixel values to between 0 and 1
68 | vec = vec.tolist()
69 |
70 | nc = relabel_class(c)
71 |
72 | user_data[w]['x'].append(vec)
73 | user_data[w]['y'].append(nc)
74 |
75 | writer_count += 1
76 | all_writers += 1
77 |
78 | if writer_count == MAX_WRITERS or all_writers == len(writers):
79 |
80 | all_data = {}
81 | all_data['users'] = users
82 | all_data['num_samples'] = num_samples
83 | all_data['user_data'] = user_data
84 |
85 | file_name = 'all_data_%d.json' % json_index
86 | file_path = os.path.join(parent_path, 'data', 'all_data', file_name)
87 |
88 | print('writing %s' % file_name)
89 |
90 | with open(file_path, 'w') as outfile:
91 | json.dump(all_data, outfile)
92 |
93 | writer_count = 0
94 | json_index += 1
95 |
96 | users[:] = []
97 | num_samples[:] = []
98 | user_data.clear()
99 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/preprocess/data_to_json.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # assumes that the script is run in the preprocess folder
4 |
5 | if [ ! -d "../data" ]; then
6 | mkdir ../data
7 | fi
8 | if [ ! -d "../data/raw_data" ]; then
9 | mkdir ../data/raw_data
10 | fi
11 |
12 | # download and unzip
13 | if [ ! -d "../data/raw_data/by_class" ] || [ ! -d "../data/raw_data/by_write" ]; then
14 | echo "------------------------------"
15 | echo "downloading and unzipping raw data"
16 | bash get_data.sh
17 | echo "finished downloading and unzipping raw data"
18 | else
19 | echo "using existing unzipped raw data folders (by_class and by_write)"
20 | fi
21 |
22 |
23 | echo "generating intermediate data"
24 | if [ ! -d "../data/intermediate" ]; then # stores .pkl files during preprocessing
25 | mkdir ../data/intermediate
26 | fi
27 |
28 | if [ ! -f ../data/intermediate/class_file_dirs.pkl ]; then
29 | echo "------------------------------"
30 | echo "extracting file directories of images"
31 | python3 get_file_dirs.py
32 | echo "finished extracting file directories of images"
33 | else
34 | echo "using existing data/intermediate/class_file_dirs.pkl"
35 | fi
36 |
37 | if [ ! -f ../data/intermediate/class_file_hashes.pkl ]; then
38 | echo "------------------------------"
39 | echo "calculating image hashes"
40 | python3 get_hashes.py
41 | echo "finished calculating image hashes"
42 | else
43 | echo "using existing data/intermediate/class_file_hashes.pkl"
44 | fi
45 |
46 | if [ ! -f ../data/intermediate/write_with_class.pkl ]; then
47 | echo "------------------------------"
48 | echo "assigning class labels to write images"
49 | python3 match_hashes.py
50 | echo "finished assigning class labels to write images"
51 | else
52 | echo "using existing data/intermediate/write_with_class.pkl"
53 | fi
54 |
55 | if [ ! -f ../data/intermediate/images_by_writer.pkl ]; then
56 | echo "------------------------------"
57 | echo "grouping images by writer"
58 | python3 group_by_writer.py
59 | echo "finished grouping images by writer"
60 | else
61 | echo "using existing data/intermediate/images_by_writer.pkl"
62 | fi
63 |
64 | if [ ! -d "../data/all_data" ]; then
65 | mkdir ../data/all_data
66 | fi
67 | if [ ! "$(ls -A ../data/all_data)" ]; then
68 | echo "------------------------------"
69 | echo "converting data to .json format in data/all_data"
70 | python3 data_to_json.py
71 | echo "finished converting data to .json format"
72 | fi
73 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/preprocess/get_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # assumes that the script is run in the preprocess folder
4 |
5 | cd ../data/raw_data
6 | if [ ! -f by_class.zip ]; then
7 | echo "downloading by_class.zip"
8 | wget https://s3.amazonaws.com/nist-srd/SD19/by_class.zip
9 | else
10 | echo "using existing by_class.zip"
11 | fi
12 |
13 | if [ ! -f by_write.zip ]; then
14 | echo "downloading by_write.zip"
15 | wget https://s3.amazonaws.com/nist-srd/SD19/by_write.zip
16 | else
17 | echo "using existing by_write.zip"
18 | fi
19 |
20 | if [ ! -d "by_class" ]; then
21 | echo "unzipping by_class.zip"
22 | unzip by_class.zip
23 | #rm by_class.zip
24 | else
25 | echo "using existing unzipped folder by_class"
26 | fi
27 |
28 | if [ ! -d "by_write" ]; then
29 | echo "unzipping by_write.zip"
30 | unzip by_write.zip
31 | #rm by_write.zip
32 | else
33 | echo "using existing unzipped folder by_write"
34 | fi
35 |
36 | cd ../../preprocess
37 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/preprocess/get_file_dirs.py:
--------------------------------------------------------------------------------
1 | '''
2 | Creates .pkl files for:
3 | 1. list of directories of every image in 'by_class'
4 | 2. list of directories of every image in 'by_write'
5 | the hierarchal structure of the data is as follows:
6 | - by_class -> classes -> folders containing images -> images
7 | - by_write -> folders containing writers -> writer -> types of images -> images
8 | the directories written into the files are of the form 'raw_data/...'
9 | '''
10 |
11 | import os
12 | import sys
13 |
14 | utils_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15 | utils_dir = os.path.join(utils_dir, 'utils')
16 | sys.path.append(utils_dir)
17 |
18 | print(utils_dir)
19 |
20 | import util
21 |
22 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
23 |
24 | print(parent_path)
25 | class_files = [] # (class, file directory)
26 | write_files = [] # (writer, file directory)
27 |
28 | class_dir = os.path.join(parent_path, 'data', 'raw_data', 'by_class')
29 | rel_class_dir = os.path.join('data', 'raw_data', 'by_class')
30 | classes = os.listdir(class_dir)
31 | classes = [c for c in classes if len(c) == 2]
32 |
33 | for cl in classes:
34 | cldir = os.path.join(class_dir, cl)
35 | rel_cldir = os.path.join(rel_class_dir, cl)
36 | subcls = os.listdir(cldir)
37 |
38 | subcls = [s for s in subcls if (('hsf' in s) and ('mit' not in s))]
39 |
40 | for subcl in subcls:
41 | subcldir = os.path.join(cldir, subcl)
42 | rel_subcldir = os.path.join(rel_cldir, subcl)
43 | images = os.listdir(subcldir)
44 | image_dirs = [os.path.join(rel_subcldir, i) for i in images]
45 |
46 | for image_dir in image_dirs:
47 | class_files.append((cl, image_dir))
48 |
49 | write_dir = os.path.join(parent_path, 'data', 'raw_data', 'by_write')
50 | rel_write_dir = os.path.join('data', 'raw_data', 'by_write')
51 | write_parts = os.listdir(write_dir)
52 |
53 | for write_part in write_parts:
54 | writers_dir = os.path.join(write_dir, write_part)
55 | rel_writers_dir = os.path.join(rel_write_dir, write_part)
56 | writers = os.listdir(writers_dir)
57 |
58 | for writer in writers:
59 | writer_dir = os.path.join(writers_dir, writer)
60 | rel_writer_dir = os.path.join(rel_writers_dir, writer)
61 | wtypes = os.listdir(writer_dir)
62 |
63 | for wtype in wtypes:
64 | type_dir = os.path.join(writer_dir, wtype)
65 | rel_type_dir = os.path.join(rel_writer_dir, wtype)
66 | images = os.listdir(type_dir)
67 | image_dirs = [os.path.join(rel_type_dir, i) for i in images]
68 |
69 | for image_dir in image_dirs:
70 | write_files.append((writer, image_dir))
71 |
72 | util.save_obj(
73 | class_files,
74 | os.path.join(parent_path, 'data', 'intermediate', 'class_file_dirs'))
75 | util.save_obj(
76 | write_files,
77 | os.path.join(parent_path, 'data', 'intermediate', 'write_file_dirs'))
78 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/preprocess/get_hashes.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import sys
4 |
5 | utils_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6 | utils_dir = os.path.join(utils_dir, 'utils')
7 | sys.path.append(utils_dir)
8 |
9 | import util
10 |
11 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
12 |
13 | cfd = os.path.join(parent_path, 'data', 'intermediate', 'class_file_dirs')
14 | wfd = os.path.join(parent_path, 'data', 'intermediate', 'write_file_dirs')
15 | class_file_dirs = util.load_obj(cfd)
16 | write_file_dirs = util.load_obj(wfd)
17 |
18 | class_file_hashes = []
19 | write_file_hashes = []
20 |
21 | count = 0
22 | for tup in class_file_dirs:
23 | if (count % 100000 == 0):
24 | print('hashed %d class images' % count)
25 |
26 | (cclass, cfile) = tup
27 | file_path = os.path.join(parent_path, cfile)
28 |
29 | chash = hashlib.md5(open(file_path, 'rb').read()).hexdigest()
30 |
31 | class_file_hashes.append((cclass, cfile, chash))
32 |
33 | count += 1
34 |
35 | cfhd = os.path.join(parent_path, 'data', 'intermediate', 'class_file_hashes')
36 | util.save_obj(class_file_hashes, cfhd)
37 |
38 | count = 0
39 | for tup in write_file_dirs:
40 | if (count % 100000 == 0):
41 | print('hashed %d write images' % count)
42 |
43 | (cclass, cfile) = tup
44 | file_path = os.path.join(parent_path, cfile)
45 |
46 | chash = hashlib.md5(open(file_path, 'rb').read()).hexdigest()
47 |
48 | write_file_hashes.append((cclass, cfile, chash))
49 |
50 | count += 1
51 |
52 | wfhd = os.path.join(parent_path, 'data', 'intermediate', 'write_file_hashes')
53 | util.save_obj(write_file_hashes, wfhd)
54 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/preprocess/group_by_writer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | utils_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 | utils_dir = os.path.join(utils_dir, 'utils')
6 | sys.path.append(utils_dir)
7 |
8 | import util
9 |
10 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
11 |
12 | wwcd = os.path.join(parent_path, 'data', 'intermediate', 'write_with_class')
13 | write_class = util.load_obj(wwcd)
14 |
15 | writers = [] # each entry is a (writer, [list of (file, class)]) tuple
16 | cimages = []
17 | (cw, _, _) = write_class[0]
18 | for (w, f, c) in write_class:
19 | if w != cw:
20 | writers.append((cw, cimages))
21 | cw = w
22 | cimages = [(f, c)]
23 | cimages.append((f, c))
24 | writers.append((cw, cimages))
25 |
26 | ibwd = os.path.join(parent_path, 'data', 'intermediate', 'images_by_writer')
27 | util.save_obj(writers, ibwd)
28 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/preprocess/match_hashes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | utils_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 | utils_dir = os.path.join(utils_dir, 'utils')
6 | sys.path.append(utils_dir)
7 |
8 | import util
9 |
10 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
11 |
12 | cfhd = os.path.join(parent_path, 'data', 'intermediate', 'class_file_hashes')
13 | wfhd = os.path.join(parent_path, 'data', 'intermediate', 'write_file_hashes')
14 | class_file_hashes = util.load_obj(cfhd) # each elem is (class, file dir, hash)
15 | write_file_hashes = util.load_obj(wfhd) # each elem is (writer, file dir, hash)
16 |
17 | class_hash_dict = {}
18 | for i in range(len(class_file_hashes)):
19 | (c, f, h) = class_file_hashes[len(class_file_hashes)-i-1]
20 | class_hash_dict[h] = (c, f)
21 |
22 | write_classes = []
23 | for tup in write_file_hashes:
24 | (w, f, h) = tup
25 | write_classes.append((w, f, class_hash_dict[h][0]))
26 |
27 | wwcd = os.path.join(parent_path, 'data', 'intermediate', 'write_with_class')
28 | util.save_obj(write_classes, wwcd)
29 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/femnist/stats.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | NAME="femnist"
5 |
6 | cd ../utils
7 |
8 | python3 stats.py --name $NAME
9 |
10 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/adult_iid_10clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/adult_iid_10clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/adult_noniid-label1_10clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/adult_noniid-label1_10clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/adult_noniid_labeldir_10clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/adult_noniid_labeldir_10clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/adult_unbalance_10clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/adult_unbalance_10clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/cifar10_balance_dir_alpha_0.3_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_balance_dir_alpha_0.3_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/cifar10_balance_iid_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_balance_iid_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/cifar10_hetero_dir_0.3_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_hetero_dir_0.3_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/cifar10_hetero_dir_0.3_100clients_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_hetero_dir_0.3_100clients_dist.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/cifar10_shards_200_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_shards_200_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_dir_alpha_0.3_unbalance_sgm_0.3_100clients_dist.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_iid_unbalance_sgm_0.3_100clients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_iid_unbalance_sgm_0.3_100clients.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/cifar10_unbalance_iid_unbalance_sgm_0.3_100clients_dist.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fcube_class_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_class_dist.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fcube_iid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_iid.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fcube_iid_part.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_iid_part.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fcube_synthetic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_synthetic.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fcube_synthetic_original_paper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_synthetic_original_paper.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fcube_synthetic_part.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_synthetic_part.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fcube_test_dist_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_test_dist_vis.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fcube_train_dist_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fcube_train_dist_vis.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fmnist_feature_skew_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_feature_skew_vis.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fmnist_iid_clients_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_iid_clients_10.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_1_clients_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_1_clients_10.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_2_clients_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_2_clients_10.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_3_clients_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_noniid-label_3_clients_10.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fmnist_noniid_labeldir_clients_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_noniid_labeldir_clients_10.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fmnist_unbalance_clients_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_unbalance_clients_10.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/fmnist_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/fmnist_vis.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/svhn_feature_skew_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/svhn_feature_skew_vis.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/imgs/svhn_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/imgs/svhn_vis.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/mnist/download_mnist.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 |
3 | if __name__ == "__main__":
4 | root = './'
5 | trainset = torchvision.datasets.MNIST(
6 | root=root,
7 | train=True,
8 | download=True,
9 | )
10 |
11 | testset = torchvision.datasets.MNIST(
12 | root=root,
13 | train=False,
14 | download=True,
15 | )
16 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/mnist/mnist_iid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/mnist/mnist_iid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/mnist/mnist_noniid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/mnist/mnist_noniid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/mnist/mnist_partition.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from fedlab.utils.functional import save_dict
4 | from fedlab.utils.dataset.slicing import noniid_slicing, random_slicing
5 |
6 | import torchvision
7 |
8 | trainset = torchvision.datasets.MNIST(root="./", train=True, download=False)
9 |
10 | num_clients=100
11 | num_shards=200
12 |
13 | data_indices = noniid_slicing(trainset, num_clients=num_clients, num_shards=num_shards)
14 | save_dict(data_indices, "mnist_noniid_{}_{}.pkl".format(num_shards, num_clients))
15 |
16 | data_indices = random_slicing(trainset, num_clients=num_clients)
17 | save_dict(data_indices, "mnist_iid_{}.pkl".format(num_clients))
18 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/adult_iid_10clients.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,Amount
3 | Client 0,0.751,0.249,3256
4 | Client 1,0.766,0.234,3256
5 | Client 2,0.764,0.236,3256
6 | Client 3,0.758,0.242,3256
7 | Client 4,0.757,0.243,3256
8 | Client 5,0.757,0.243,3256
9 | Client 6,0.756,0.244,3256
10 | Client 7,0.768,0.232,3256
11 | Client 8,0.763,0.237,3256
12 | Client 9,0.753,0.247,3256
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/adult_noniid-label1_10clients.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,Amount
3 | Client 0,1.000,0.00,4944
4 | Client 1,0.00,1.000,1569
5 | Client 2,1.000,0.00,4944
6 | Client 3,0.00,1.000,1568
7 | Client 4,1.000,0.00,4944
8 | Client 5,0.00,1.000,1568
9 | Client 6,1.000,0.00,4944
10 | Client 7,0.00,1.000,1568
11 | Client 8,1.000,0.00,4944
12 | Client 9,0.00,1.000,1568
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/adult_noniid_labeldir_10clients.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,Amount
3 | Client 0,1.000,0.00,3305
4 | Client 1,0.833,0.167,126
5 | Client 2,0.222,0.778,8300
6 | Client 3,0.768,0.232,777
7 | Client 4,0.552,0.448,747
8 | Client 5,0.709,0.291,629
9 | Client 6,0.750,0.250,1546
10 | Client 7,1.000,0.00,5688
11 | Client 8,1.000,0.00,9929
12 | Client 9,0.815,0.185,1514
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/adult_unbalance_10clients.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,Amount
3 | Client 0,0.751,0.249,2449
4 | Client 1,0.683,0.317,123
5 | Client 2,0.725,0.275,204
6 | Client 3,0.764,0.236,3095
7 | Client 4,0.763,0.237,59
8 | Client 5,0.761,0.239,20562
9 | Client 6,0.723,0.277,47
10 | Client 7,0.757,0.243,2557
11 | Client 8,0.759,0.241,2153
12 | Client 9,0.747,0.253,1306
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/fcube_iid.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,Amount
3 | Client 0,0.472,0.528,250
4 | Client 1,0.532,0.468,250
5 | Client 2,0.476,0.524,250
6 | Client 3,0.520,0.480,250
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/fcube_synthetic.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,Amount
3 | Client 0,0.500,0.500,250
4 | Client 1,0.500,0.500,250
5 | Client 2,0.500,0.500,250
6 | Client 3,0.500,0.500,250
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/fmnist_iid_clients_10.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount
3 | Client 0,0.096,0.104,0.101,0.108,0.104,0.098,0.094,0.098,0.102,0.097,6000
4 | Client 1,0.105,0.100,0.101,0.097,0.090,0.099,0.103,0.100,0.097,0.108,6000
5 | Client 2,0.102,0.100,0.102,0.099,0.097,0.100,0.098,0.097,0.106,0.101,6000
6 | Client 3,0.096,0.105,0.101,0.096,0.100,0.102,0.102,0.102,0.099,0.099,6000
7 | Client 4,0.092,0.101,0.097,0.102,0.101,0.100,0.103,0.102,0.102,0.101,6000
8 | Client 5,0.103,0.098,0.106,0.100,0.098,0.097,0.098,0.105,0.096,0.101,6000
9 | Client 6,0.107,0.104,0.102,0.087,0.104,0.105,0.097,0.102,0.096,0.097,6000
10 | Client 7,0.101,0.097,0.096,0.100,0.102,0.099,0.101,0.104,0.102,0.098,6000
11 | Client 8,0.100,0.097,0.098,0.097,0.101,0.101,0.104,0.098,0.100,0.105,6000
12 | Client 9,0.099,0.096,0.098,0.114,0.103,0.101,0.101,0.093,0.100,0.095,6000
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/fmnist_noniid-label_1_clients_10.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount
3 | Client 0,1.000,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,6000
4 | Client 1,0.00,1.000,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,6000
5 | Client 2,0.00,0.00,1.000,0.00,0.00,0.00,0.00,0.00,0.00,0.00,6000
6 | Client 3,0.00,0.00,0.00,1.000,0.00,0.00,0.00,0.00,0.00,0.00,6000
7 | Client 4,0.00,0.00,0.00,0.00,1.000,0.00,0.00,0.00,0.00,0.00,6000
8 | Client 5,0.00,0.00,0.00,0.00,0.00,1.000,0.00,0.00,0.00,0.00,6000
9 | Client 6,0.00,0.00,0.00,0.00,0.00,0.00,1.000,0.00,0.00,0.00,6000
10 | Client 7,0.00,0.00,0.00,0.00,0.00,0.00,0.00,1.000,0.00,0.00,6000
11 | Client 8,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,1.000,0.00,6000
12 | Client 9,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,1.000,6000
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/fmnist_noniid-label_2_clients_10.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount
3 | Client 0,0.500,0.00,0.00,0.00,0.500,0.00,0.00,0.00,0.00,0.00,6000
4 | Client 1,0.00,0.500,0.00,0.00,0.00,0.500,0.00,0.00,0.00,0.00,4000
5 | Client 2,0.00,0.00,0.667,0.00,0.00,0.00,0.00,0.00,0.00,0.333,9000
6 | Client 3,0.333,0.00,0.00,0.667,0.00,0.00,0.00,0.00,0.00,0.00,9000
7 | Client 4,0.00,0.00,0.00,0.00,0.500,0.00,0.500,0.00,0.00,0.00,6000
8 | Client 5,0.00,0.00,0.00,0.00,0.00,0.400,0.00,0.00,0.600,0.00,5000
9 | Client 6,0.00,0.400,0.00,0.00,0.00,0.00,0.600,0.00,0.00,0.00,5000
10 | Client 7,0.00,0.00,0.00,0.00,0.00,0.400,0.00,0.600,0.00,0.00,5000
11 | Client 8,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.500,0.500,0.00,6000
12 | Client 9,0.00,0.400,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.600,5000
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/fmnist_noniid-label_3_clients_10.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount
3 | Client 0,0.250,0.00,0.00,0.00,0.500,0.250,0.00,0.00,0.00,0.00,6000
4 | Client 1,0.263,0.211,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.526,5700
5 | Client 2,0.00,0.00,0.400,0.00,0.00,0.300,0.300,0.00,0.00,0.00,5000
6 | Client 3,0.00,0.00,0.00,0.400,0.00,0.00,0.200,0.00,0.400,0.00,7500
7 | Client 4,0.00,0.211,0.00,0.00,0.526,0.00,0.263,0.00,0.00,0.00,5700
8 | Client 5,0.00,0.211,0.00,0.00,0.00,0.263,0.00,0.526,0.00,0.00,5700
9 | Client 6,0.00,0.286,0.00,0.00,0.00,0.357,0.357,0.00,0.00,0.00,4200
10 | Client 7,0.231,0.00,0.308,0.00,0.00,0.00,0.00,0.462,0.00,0.00,6500
11 | Client 8,0.00,0.167,0.00,0.417,0.00,0.00,0.00,0.00,0.417,0.00,7200
12 | Client 9,0.231,0.00,0.308,0.00,0.00,0.00,0.00,0.00,0.00,0.462,6500
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/fmnist_noniid_labeldir_clients_10.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount
3 | Client 0,0.000,0.040,0.00,0.007,0.036,0.092,0.008,0.454,0.362,0.00,6968
4 | Client 1,0.201,0.081,0.002,0.302,0.193,0.221,0.00,0.00,0.00,0.00,6300
5 | Client 2,0.002,0.090,0.172,0.178,0.080,0.048,0.430,0.00,0.00,0.00,6482
6 | Client 3,0.00,0.002,0.147,0.142,0.033,0.058,0.000,0.058,0.280,0.279,4687
7 | Client 4,0.040,0.177,0.014,0.00,0.200,0.022,0.154,0.204,0.189,0.00,6530
8 | Client 5,0.343,0.256,0.011,0.043,0.019,0.040,0.094,0.068,0.011,0.116,4562
9 | Client 6,0.065,0.001,0.078,0.146,0.301,0.033,0.000,0.098,0.003,0.274,6781
10 | Client 7,0.322,0.294,0.217,0.004,0.050,0.113,0.00,0.00,0.00,0.00,6637
11 | Client 8,0.039,0.016,0.120,0.060,0.002,0.283,0.161,0.015,0.051,0.252,5810
12 | Client 9,0.016,0.047,0.261,0.128,0.014,0.084,0.150,0.033,0.106,0.160,5243
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/partition-reports/fmnist_unbalance_clients_10.csv:
--------------------------------------------------------------------------------
1 | Class frequencies:
2 | client,class0,class1,class2,class3,class4,class5,class6,class7,class8,class9,Amount
3 | Client 0,0.104,0.097,0.096,0.100,0.102,0.104,0.099,0.099,0.103,0.097,4513
4 | Client 1,0.097,0.128,0.093,0.101,0.093,0.101,0.097,0.110,0.079,0.101,227
5 | Client 2,0.106,0.101,0.106,0.082,0.088,0.106,0.114,0.095,0.095,0.106,377
6 | Client 3,0.095,0.100,0.098,0.101,0.098,0.097,0.108,0.098,0.102,0.102,5704
7 | Client 4,0.064,0.109,0.055,0.100,0.109,0.055,0.118,0.118,0.100,0.173,110
8 | Client 5,0.101,0.101,0.102,0.098,0.099,0.100,0.098,0.101,0.099,0.101,37889
9 | Client 6,0.080,0.125,0.080,0.102,0.080,0.148,0.148,0.068,0.068,0.102,88
10 | Client 7,0.099,0.095,0.097,0.096,0.107,0.101,0.102,0.098,0.105,0.100,4713
11 | Client 8,0.098,0.095,0.097,0.113,0.106,0.101,0.098,0.097,0.098,0.098,3967
12 | Client 9,0.096,0.102,0.099,0.114,0.101,0.100,0.106,0.093,0.098,0.091,2407
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/rcv1/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .rcv1 import RCV1
16 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/reddit/README.md:
--------------------------------------------------------------------------------
1 | # Reddit Dataset
2 |
3 | We preprocess the Reddit data released by [pushshift.io](https://files.pushshift.io/reddit/) corresponding to December 2017. We perform the following operations:
4 |
5 | 1. Unescape html symbols.
6 | 2. Remove extraneous whitespaces.
7 | 3. Remove non-ascii symbols.
8 | 4. Replace URLS, reddit usernames and subreddit names with special tokens.
9 | 5. Lowercase the text.
10 | 6. Tokenize the text (using nltk's TweetTokenizer).
11 |
12 | We also remove users and comments that simple heuristics or preliminary inspections mark as bots; and remove users with less than 5 or more than 1000 comments (which account for less than 0.01% of users). We include the code for this preprocessing in the ```source``` folder for reference, but host the preprocessed dataset [here](https://drive.google.com/file/d/1CXufUKXNpR7Pn8gUbIerZ1-qHz1KatHH/view?usp=sharing). We further preprocess the data to make it ready for our reference model (by splitting it into train/val/test sets and by creating sequences of 10 tokens for the LSTM) [here](https://drive.google.com/file/d/1lT1Z0N1weG-oA2PgC1Jak_WQ6h3bu7V_/view?usp=sharing). The vocabulary of the 10 thousand most common tokens in the data can be found [here](https://drive.google.com/file/d/1I-CRlfAeiriLmAyICrmlpPE5zWJX4TOY/view?usp=sharing).
13 |
14 | ## Setup Instructions
15 |
16 | 1. To use our reference model, download the data [here](https://drive.google.com/file/d/1PwBpAEMYKNpnv64cQ2TIQfSc_vPbq3OQ/view?usp=sharing) into a ```data``` subfolder in this directory. This is a subsampled version of the complete data. Our reference implementation doesn't yet support training on the [complete dataset](https://drive.google.com/file/d/1lT1Z0N1weG-oA2PgC1Jak_WQ6h3bu7V_/view?usp=sharing), as it loads all given clients into memory.
17 | 2. With the data in the appropriate directory, run build the training vocabulary by running ```python build_vocab.py --data-dir ./data/train --target-dir vocab```.
18 | 3. With the data and the training vocabulary, you can now run our reference implementation in the ```models``` directory using a command as the following:
19 | - ```python3 main.py -dataset reddit -model stacked_lstm --eval-every 10 --num-rounds 100 --clients-per-round 10 --batch-size 5 -lr 5.65 --metrics-name reddit_experiment```
20 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/reddit/build_vocab.py:
--------------------------------------------------------------------------------
1 | """Builds vocabulary file from data."""
2 |
3 | import argparse
4 | import collections
5 | import json
6 | import os
7 | import pickle
8 | import re
9 |
10 |
11 | def build_counter(train_data, initial_counter=None):
12 | train_tokens = []
13 | for u in train_data:
14 | for c in train_data[u]['x']:
15 | train_tokens.extend([s for s in c])
16 |
17 | all_tokens = []
18 | for i in train_tokens:
19 | all_tokens.extend(i)
20 | train_tokens = []
21 |
22 | if initial_counter is None:
23 | counter = collections.Counter()
24 | else:
25 | counter = initial_counter
26 |
27 | counter.update(all_tokens)
28 | all_tokens = []
29 |
30 | return counter
31 |
32 |
33 | def build_vocab(counter, vocab_size=10000):
34 | pad_symbol, unk_symbol = 0, 1
35 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
36 | count_pairs = count_pairs[:(vocab_size - 2)] # -2 to account for the unknown and pad symbols
37 |
38 | words, _ = list(zip(*count_pairs))
39 |
40 | vocab = {}
41 | vocab[''] = pad_symbol
42 | vocab[''] = unk_symbol
43 |
44 | for i, w in enumerate(words):
45 | if w != '':
46 | vocab[w] = i + 1
47 |
48 | return {'vocab': vocab, 'size': vocab_size, 'unk_symbol': unk_symbol, 'pad_symbol': pad_symbol}
49 |
50 |
51 | def load_leaf_data(file_path):
52 | with open(file_path) as json_file:
53 | data = json.load(json_file)
54 | to_ret = data['user_data']
55 | data = None
56 | return to_ret
57 |
58 |
59 | def save_vocab(vocab, target_dir):
60 | os.makedirs(target_dir, exist_ok=True)
61 | pickle.dump(vocab, open(os.path.join(target_dir, 'reddit_vocab.pck'), 'wb'))
62 |
63 |
64 | def main():
65 | args = parse_args()
66 |
67 | json_files = [f for f in os.listdir(args.data_dir) if f.endswith('.json')]
68 | json_files.sort()
69 |
70 | counter = None
71 | train_data = {}
72 | for f in json_files:
73 | print('loading {}'.format(f))
74 | train_data = load_leaf_data(os.path.join(args.data_dir, f))
75 | print('counting {}'.format(f))
76 | counter = build_counter(train_data, initial_counter=counter)
77 | print()
78 | train_data = {}
79 |
80 | if counter is not None:
81 | vocab = build_vocab(counter, vocab_size=args.vocab_size)
82 | save_vocab(vocab, args.target_dir)
83 | else:
84 | print('No files to process.')
85 |
86 |
87 | def parse_args():
88 | parser = argparse.ArgumentParser()
89 |
90 | parser.add_argument('--data-dir',
91 | help='dir with training file;',
92 | type=str,
93 | required=True)
94 | parser.add_argument('--vocab-size',
95 | help='size of the vocabulary;',
96 | type=int,
97 | default=10000,
98 | required=False)
99 | parser.add_argument('--target-dir',
100 | help='dir with training file;',
101 | type=str,
102 | default='./',
103 | required=False)
104 |
105 | return parser.parse_args()
106 |
107 |
108 | if __name__ == '__main__':
109 | main()
110 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/reddit/source/clean_raw.py:
--------------------------------------------------------------------------------
1 | import json
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import os
5 | import pickle
6 | import re
7 | import string
8 | import html
9 |
10 | from nltk.tokenize import TweetTokenizer
11 |
12 |
13 | DIR = os.path.join('data', 'reddit_merged')
14 | FINAL_DIR = os.path.join('data', 'reddit_clean')
15 |
16 | PHRASES_TO_AVOID = [
17 | '[ deleted ]',
18 | '[ removed ]',
19 | '[deleted]',
20 | '[removed]',
21 | 'bot',
22 | 'thank you for participating',
23 | 'thank you for your submission',
24 | 'thanks for your submission',
25 | 'your submission has been removed',
26 | 'your comment has been removed',
27 | 'downvote this comment if this is',
28 | 'your post has been removed',
29 | ]
30 |
31 | def clean_file(f, tknzr):
32 | reddit = pickle.load(open(os.path.join(DIR, f), 'rb'))
33 |
34 | clean_reddit = {}
35 | for u, comments in reddit.items():
36 |
37 | clean_comments = []
38 | for c in comments:
39 | c.clean_body(tknzr)
40 | if len(c.body) > 0 and not any([p in c.body for p in PHRASES_TO_AVOID]):
41 | clean_comments.append(c)
42 |
43 | if len(clean_comments) > 0:
44 | clean_reddit[u] = clean_comments
45 |
46 | pickle.dump(
47 | clean_reddit,
48 | open(os.path.join(FINAL_DIR, f.replace('merged', 'cleaned')), 'wb'))
49 |
50 | def main():
51 | tknzr = TweetTokenizer()
52 |
53 | if not os.path.exists(FINAL_DIR):
54 | os.makedirs(FINAL_DIR)
55 |
56 | files = [f for f in os.listdir(DIR) if f.endswith('.pck')]
57 | files.sort()
58 |
59 | num_files = len(files)
60 | for i, f in enumerate(files):
61 | clean_file(f, tknzr)
62 | print('Done with {} of {}'.format(i, num_files))
63 |
64 | if __name__ == '__main__':
65 | main()
66 |
67 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/reddit/source/delete_small_users.py:
--------------------------------------------------------------------------------
1 | import json
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import os
5 | import pickle
6 |
7 |
8 | DIR = os.path.join('data', 'reddit_clean')
9 | FINAL_DIR = os.path.join('data', 'reddit_subsampled')
10 |
11 |
12 | def subsample_file(f):
13 | reddit = pickle.load(open(os.path.join(DIR, f), 'rb'))
14 |
15 | subsampled_reddit = {}
16 | for u, comments in reddit.items():
17 |
18 | subsampled_comments = [c for c in comments if len(c.body.split()) >= 5]
19 |
20 | if len(subsampled_comments) >= 5 and len(subsampled_comments) <= 1000:
21 | subsampled_reddit[u] = subsampled_comments
22 |
23 | pickle.dump(
24 | subsampled_reddit,
25 | open(os.path.join(FINAL_DIR, f.replace('cleaned', 'subsampled')), 'wb'))
26 |
27 |
28 | def main():
29 | if not os.path.exists(FINAL_DIR):
30 | os.makedirs(FINAL_DIR)
31 |
32 | files = [f for f in os.listdir(DIR) if f.endswith('.pck')]
33 | files.sort()
34 |
35 | num_files = len(files)
36 | for i, f in enumerate(files):
37 | subsample_file(f)
38 | print('Done with {} of {}'.format(i, num_files))
39 |
40 | if __name__ == '__main__':
41 | main()
42 |
43 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/reddit/source/get_json.py:
--------------------------------------------------------------------------------
1 | import math
2 | import json
3 | import os
4 | import pickle
5 |
6 | DIR = os.path.join('data', 'reddit_subsampled')
7 | FINAL_DIR = os.path.join('data', 'reddit_json')
8 | FILES_PER_JSON = 10
9 |
10 |
11 | def merge_dicts(x, y):
12 | z = x.copy()
13 | z.update(y)
14 | return z
15 |
16 |
17 | def to_leaf_format(some_json, start_idx=0):
18 | leaf_json = {'users': [], 'num_samples': [], 'user_data': {}}
19 | new_idx = start_idx
20 | for u, comments in some_json.items():
21 | new_idx += 1
22 | leaf_json['users'].append(str(new_idx))
23 | leaf_json['num_samples'].append(len(comments))
24 |
25 | x = []
26 | y = []
27 | for c in comments:
28 | assert c.author == u
29 |
30 | c_x = c.body
31 | c_y = {
32 | 'subreddit': c.subreddit,
33 | 'created_utc': c.created_utc,
34 | 'score': c.score,
35 | }
36 |
37 | x.append(c_x)
38 | y.append(c_y)
39 |
40 | user_data = {'x': x, 'y': y}
41 | leaf_json['user_data'][str(new_idx)] = user_data
42 |
43 | return leaf_json, new_idx
44 |
45 |
46 | def files_to_json(files, json_name, start_user_idx=0):
47 | all_users = {}
48 |
49 | for f in files:
50 | f_dir = os.path.join(DIR, f)
51 | f_users = pickle.load(open(f_dir, 'rb'))
52 |
53 | all_users = merge_dicts(all_users, f_users)
54 |
55 | all_users, last_user_idx = to_leaf_format(all_users, start_user_idx)
56 |
57 | with open(os.path.join(FINAL_DIR, json_name), 'w') as outfile:
58 | json.dump(all_users, outfile)
59 |
60 | return last_user_idx
61 |
62 |
63 | def main():
64 | if not os.path.exists(FINAL_DIR):
65 | os.makedirs(FINAL_DIR)
66 |
67 | files = [f for f in os.listdir(DIR) if f.endswith('.pck')]
68 | files.sort()
69 |
70 | num_files = len(files)
71 | num_json = math.ceil(num_files / FILES_PER_JSON)
72 |
73 | last_user_idx = 0
74 | for i in range(num_json):
75 | cur_files = files[i * FILES_PER_JSON : (i+1) * FILES_PER_JSON]
76 | print('processing until', (i+1) * FILES_PER_JSON)
77 | last_user_idx = files_to_json(
78 | cur_files,
79 | 'reddit_{}.json'.format(i),
80 | last_user_idx)
81 |
82 |
83 | if __name__ == '__main__':
84 | main()
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/reddit/source/merge_raw_users.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import pickle
4 | import random
5 |
6 | from reddit_utils import RedditComment
7 |
8 |
9 | DIR = os.path.join('data', 'reddit_raw')
10 | USERS_PER_REPEAT = 200000
11 | USERS_PER_FILE = 20000
12 | FINAL_DIR = os.path.join('data', 'reddit_merged')
13 |
14 | if not os.path.exists(FINAL_DIR):
15 | os.makedirs(FINAL_DIR)
16 |
17 | files = [f for f in os.listdir(DIR) if f.endswith('.pck')]
18 | files.sort()
19 |
20 | all_users = {}
21 |
22 | for f in files:
23 | f_path = os.path.join(DIR, f)
24 | users = pickle.load(open(f_path, 'rb'))
25 | users = list(users.keys())
26 |
27 | for u in users:
28 | if u not in all_users:
29 | all_users[u] = []
30 | all_users[u].append(f)
31 |
32 |
33 | user_keys = list(all_users.keys())
34 | random.seed(3760145)
35 | random.shuffle(user_keys)
36 |
37 | num_users = len(all_users)
38 |
39 | num_lots = (num_users // USERS_PER_REPEAT) + 1
40 | print('num_lots', num_lots)
41 |
42 | cur_file = 1
43 | for l in range(num_lots):
44 | min_idx, max_idx = l * USERS_PER_REPEAT, min((l + 1) * USERS_PER_REPEAT, num_users)
45 |
46 | cur_user_keys = user_keys[min_idx:max_idx]
47 | num_cur_users = len(cur_user_keys)
48 | cur_users = {u: [] for u in cur_user_keys}
49 |
50 | for f in files:
51 | f_path = os.path.join(DIR, f)
52 | users = pickle.load(open(f_path, 'rb'))
53 |
54 | for u in cur_users:
55 | if f not in all_users[u]:
56 | continue
57 |
58 | cur_users[u].extend([c for c in users[u] if len(c.body) > 0])
59 |
60 | written_users = 0
61 | while written_users < num_cur_users:
62 | low_bound, high_bound = written_users, min(written_users + USERS_PER_FILE, num_cur_users)
63 | file_keys = cur_user_keys[low_bound:high_bound]
64 | file_users = {u: cur_users[u] for u in file_keys if len(cur_users[u]) >= 5}
65 |
66 | pickle.dump(file_users, open(os.path.join(FINAL_DIR, 'reddit_users_merged_{}.pck'.format(cur_file)), 'wb'))
67 |
68 | written_users += USERS_PER_FILE
69 | cur_file += 1
70 |
71 | print(l + 1)
72 |
73 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/reddit/source/reddit_utils.py:
--------------------------------------------------------------------------------
1 | import html
2 | import re
3 | from nltk.tokenize import TweetTokenizer
4 |
5 |
6 | URL_TOKEN = ''
7 | USER_TOKEN = ''
8 | SUBREDDIT_TOKEN = ''
9 |
10 | URL_REGEX = r'http\S+'
11 | USER_REGEX = r'(?:\/?u\/\w+)'
12 | SUBREDDIT_REGEX = r'(?:\/?r\/\w+)'
13 |
14 |
15 | class RedditComment:
16 |
17 | def __init__(self, reddit_dict):
18 | self.body = reddit_dict['body']
19 | self.author = reddit_dict['author']
20 | self.subreddit = reddit_dict['subreddit']
21 | self.subreddit_id = reddit_dict['subreddit_id']
22 | self.created_utc = reddit_dict['created_utc']
23 | self.score = reddit_dict['score']
24 |
25 | def clean_body(self, tknzr=None):
26 | if tknzr is None:
27 | tknzr = TweetTokenizer()
28 |
29 | # unescape html symbols.
30 | new_body = html.unescape(self.body)
31 |
32 | # remove extraneous whitespace.
33 | new_body = new_body.replace('\n', ' ')
34 | new_body = new_body.replace('\t', ' ')
35 | new_body = re.sub(r'\s+', ' ', new_body).strip()
36 |
37 | # remove non-ascii symbols.
38 | new_body = new_body.encode('ascii', errors='ignore').decode()
39 |
40 | # replace URLS with a special token.
41 | new_body = re.sub(URL_REGEX, URL_TOKEN, new_body)
42 |
43 | # replace reddit user with a token
44 | new_body = re.sub(USER_REGEX, USER_TOKEN, new_body)
45 |
46 | # replace subreddit names with a token
47 | new_body = re.sub(SUBREDDIT_REGEX, SUBREDDIT_TOKEN, new_body)
48 |
49 | # lowercase the text
50 | new_body = new_body.casefold()
51 |
52 | # Could be done in addition:
53 | # get rid of comments with quotes
54 |
55 | # tokenize the text
56 | new_body = tknzr.tokenize(new_body)
57 |
58 | self.body = ' '.join(new_body)
59 |
60 | def __str__(self):
61 | return str(vars(self))
62 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/reddit/source/run_reddit.sh:
--------------------------------------------------------------------------------
1 | python get_raw_users.py
2 |
3 | echo 'Good job with raw'
4 |
5 | python merge_raw_users.py
6 |
7 | echo 'Good job with merging'
8 |
9 | python clean_raw.py
10 |
11 | echo 'Good job with cleaning'
12 |
13 | python delete_small_users.py
14 |
15 | echo 'Good job subsampling'
16 |
17 | python get_json.py
18 |
19 | echo 'Good job creating json'
20 |
21 | python preprocess.py
22 |
23 | echo 'Good job preprocessing'
24 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/sent140/preprocess.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # download data and convert to .json format
4 |
5 | if [ ! -d "data/all_data" ] || [ ! "$(ls -A data/all_data)" ]; then
6 | cd preprocess
7 | bash data_to_json.sh
8 | cd ..
9 | else
10 | echo "using existing data/all_data data folder to preprocess"
11 | fi
12 |
13 | NAME="sent140" # name of the dataset, equivalent to directory name
14 |
15 | cd ../utils
16 |
17 | bash preprocess.sh --name $NAME $@
18 |
19 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/sent140/preprocess/combine_data.py:
--------------------------------------------------------------------------------
1 | '''
2 | each row of created .csv file is of the form:
3 | polarity, id, date, query, user, comment, test_or_training
4 | '''
5 |
6 | import csv
7 | import os
8 |
9 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
10 |
11 | train_file_name = os.path.join(parent_path, 'data', 'raw_data', 'training.csv')
12 |
13 | training = []
14 | with open(train_file_name, 'rt', encoding='ISO-8859-1') as f:
15 | reader = csv.reader(f)
16 | training = list(reader)
17 |
18 | test_file_name = os.path.join(parent_path, 'data', 'raw_data', 'test.csv')
19 |
20 | test = []
21 | with open(test_file_name, 'rt', encoding='ISO-8859-1') as f:
22 | reader = csv.reader(f)
23 | test = list(reader)
24 |
25 | out_file_name = os.path.join(parent_path, 'data', 'intermediate', 'all_data.csv')
26 |
27 | with open(out_file_name, 'w') as f:
28 | writer = csv.writer(f)
29 |
30 | for row in training:
31 | row.append('training')
32 | writer.writerow(row)
33 |
34 | for row in test:
35 | row.append('test')
36 | writer.writerow(row)
37 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/sent140/preprocess/data_to_json.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import os
4 |
5 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
6 |
7 | data_dir = os.path.join(parent_path, 'data', 'intermediate', 'all_data.csv')
8 |
9 | data = []
10 | with open(data_dir, 'rt', encoding='ISO-8859-1') as f:
11 | reader = csv.reader(f)
12 | data = list(reader)
13 |
14 | data = sorted(data, key=lambda x: x[4])
15 |
16 | # ------------
17 | # get # of users in data, and list of users (note automatically sorted)
18 |
19 | num_users = 1
20 | cuser = data[0][4]
21 | users = [cuser]
22 |
23 | for i in range(len(data)):
24 | row = data[i]
25 | tuser = row[4]
26 | if tuser != cuser:
27 | num_users += 1
28 | cuser = tuser
29 | users.append(tuser)
30 |
31 | # ------------
32 | # get # of samples for each user
33 |
34 | num_samples = [0 for _ in range(num_users)]
35 | cuser = data[0][4]
36 | user_i = 0
37 |
38 | for i in range(len(data)):
39 | row = data[i]
40 | tuser = row[4]
41 | if tuser != cuser:
42 | cuser = tuser
43 | user_i += 1
44 | num_samples[user_i] += 1
45 |
46 | # ------------
47 | # create user_data
48 |
49 | user_data = {}
50 | row_i = 0
51 |
52 | for u in users:
53 | user_data[u] = {'x': [], 'y': []}
54 |
55 | while ((row_i < len(data)) and (data[row_i][4] == u)):
56 | row = data[row_i]
57 | y = 1 if row[0] == "4" else 0
58 | user_data[u]['x'].append(row[1:])
59 | user_data[u]['y'].append(y)
60 |
61 | row_i += 1
62 |
63 | # ------------
64 | # create .json file
65 |
66 | all_data = {}
67 | all_data['users'] = users
68 | all_data['num_samples'] = num_samples
69 | all_data['user_data'] = user_data
70 |
71 | file_path = os.path.join(parent_path, 'data', 'all_data', 'all_data.json')
72 |
73 | with open(file_path, 'w') as outfile:
74 | json.dump(all_data, outfile)
75 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/sent140/preprocess/data_to_json.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ ! -d "../data" ]; then
4 | mkdir ../data
5 | fi
6 | if [ ! -d "../data/raw_data" ]; then
7 | mkdir ../data/raw_data
8 | fi
9 | if [ ! -f ../data/raw_data/training.csv ] || [ ! -f ../data/raw_data/test.csv ]; then
10 | echo "------------------------------"
11 | echo "retrieving raw data"
12 |
13 | bash get_data.sh
14 | echo "finished retrieving raw data"
15 | else
16 | echo "using existing retrieved raw data"
17 | fi
18 |
19 | echo "generating intermediate data"
20 | if [ ! -d "../data/intermediate" ]; then
21 | mkdir ../data/intermediate
22 | fi
23 |
24 | if [ ! "$(ls -A ../data/intermediate)" ]; then
25 | echo "------------------------------"
26 | echo "combining raw_data .csv files"
27 | python3 combine_data.py
28 | echo "finished combining raw_data .csv files"
29 | else
30 | echo "using existing retrieved raw data"
31 | fi
32 |
33 | if [ ! -d "../data/all_data" ]; then
34 | mkdir ../data/all_data
35 | fi
36 |
37 | if [ ! "$(ls -A ../data/all_data)" ]; then
38 | echo "------------------------------"
39 | echo "converting data to .json format"
40 | python3 data_to_json.py
41 | echo "finished converting data to .json format"
42 | else
43 | echo "using existing data/all_data"
44 | fi
45 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/sent140/preprocess/get_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../data/raw_data
4 |
5 | if [ ! -f trainingandtestdata.zip ]; then
6 | echo "downloading trainingandtestdata.zip"
7 | wget --no-check-certificate http://cs.stanford.edu/people/alecmgo/trainingandtestdata.zip
8 | else
9 | echo "using existing trainingandtestdata.zip"
10 | fi
11 |
12 | if [ ! -f training.csv ] || [ ! -f test.csv ]; then
13 | echo "unzipping trainingandtestdata.zip"
14 | unzip trainingandtestdata.zip
15 | mv training.1600000.processed.noemoticon.csv training.csv
16 | mv testdata.manual.2009.06.14.csv test.csv
17 | else
18 | echo "using existing training.csv and test.csv in data/raw_data"
19 | fi
20 | #rm trainingandtestdata.zip
21 |
22 | cd ../../preprocess
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/sent140/stats.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME="sent140"
4 |
5 | cd ../utils
6 |
7 | python3 stats.py --name $NAME
8 |
9 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/shakespeare/preprocess.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # download data and convert to .json format
4 |
5 | RAWTAG=""
6 | if [[ $@ = *"--raw"* ]]; then
7 | RAWTAG="--raw"
8 | fi
9 |
10 | if [ ! -d "data/all_data" ] || [ ! "$(ls -A data/all_data)" ]; then
11 | cd preprocess
12 | bash data_to_json.sh $RAWTAG
13 | cd ..
14 | else
15 | echo "using existing data/all_data data folder to preprocess"
16 | fi
17 |
18 | NAME="shakespeare"
19 |
20 | cd ../utils
21 |
22 | bash preprocess.sh --name $NAME $@
23 |
24 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/shakespeare/preprocess/data_to_json.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ ! -d "../data" ]; then
4 | mkdir ../data
5 | fi
6 |
7 | if [ ! -d "../data/raw_data" ]; then
8 | mkdir ../data/raw_data
9 | fi
10 |
11 | if [ ! -f ../data/raw_data/raw_data.txt ]; then
12 | bash get_data.sh
13 | else
14 | echo "using existing data/raw_data/raw_data.txt"
15 | fi
16 |
17 | if [ ! -d "../data/raw_data/by_play_and_character" ] || [ ! -f ../data/raw_data/users_and_plays.json ]; then
18 | echo "dividing txt data between users"
19 | python3 preprocess_shakespeare.py ../data/raw_data/raw_data.txt ../data/raw_data/
20 | else
21 | echo "using existing divided files of txt data"
22 | fi
23 |
24 | RAWTAG=""
25 | if [[ $@ = *"--raw"* ]]; then
26 | RAWTAG="--raw"
27 | fi
28 | if [ ! -d "../data/all_data" ]; then
29 | mkdir ../data/all_data
30 | fi
31 | if [ ! "$(ls -A ../data/all_data)" ]; then
32 | echo "generating all_data.json in data/all_data"
33 | python3 gen_all_data.py $RAWTAG
34 | fi
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/shakespeare/preprocess/gen_all_data.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import json
4 | import os
5 |
6 | from shake_utils import parse_data_in
7 |
8 | parser = argparse.ArgumentParser()
9 |
10 | parser.add_argument('--raw',
11 | help='include users\' raw .txt data in respective .json files',
12 | action="store_true")
13 |
14 | parser.set_defaults(raw=False)
15 |
16 | args = parser.parse_args()
17 |
18 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
19 |
20 | users_and_plays_path = os.path.join(parent_path, 'data', 'raw_data', 'users_and_plays.json')
21 | txt_dir = os.path.join(parent_path, 'data', 'raw_data', 'by_play_and_character')
22 | json_data = parse_data_in(txt_dir, users_and_plays_path, args.raw)
23 | json_path = os.path.join(parent_path, 'data', 'all_data', 'all_data.json')
24 | with open(json_path, 'w') as outfile:
25 | json.dump(json_data, outfile)
26 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/shakespeare/preprocess/get_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../data/raw_data
4 |
5 | if [ ! -f 1994-01-100.zip ]; then
6 | wget http://www.gutenberg.org/files/100/old/1994-01-100.zip
7 | else
8 | echo "using existing 1994-01-100.zip"
9 | fi
10 |
11 | echo "unzipping 1994-01-100.zip"
12 | unzip 1994-01-100.zip
13 | #rm 1994-01-100.zip
14 | mv 100.txt raw_data.txt
15 |
16 | #wget --adjust-extension http://www.gutenberg.org/files/100/100-0.txt
17 | #mv 100-0.txt raw_data.txt
18 |
19 | cd ../../preprocess
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/shakespeare/preprocess/shake_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | helper functions for preprocessing shakespeare data
3 | '''
4 |
5 | import json
6 | import os
7 | import re
8 |
9 | def __txt_to_data(txt_dir, seq_length=80):
10 | """Parses text file in given directory into data for next-character model.
11 |
12 | Args:
13 | txt_dir: path to text file
14 | seq_length: length of strings in X
15 | """
16 | raw_text = ""
17 | with open(txt_dir,'r') as inf:
18 | raw_text = inf.read()
19 | raw_text = raw_text.replace('\n', ' ')
20 | raw_text = re.sub(r" *", r' ', raw_text)
21 | dataX = []
22 | dataY = []
23 | for i in range(0, len(raw_text) - seq_length, 1):
24 | seq_in = raw_text[i:i + seq_length]
25 | seq_out = raw_text[i + seq_length]
26 | dataX.append(seq_in)
27 | dataY.append(seq_out)
28 | return dataX, dataY
29 |
30 | def parse_data_in(data_dir, users_and_plays_path, raw=False):
31 | '''
32 | returns dictionary with keys: users, num_samples, user_data
33 | raw := bool representing whether to include raw text in all_data
34 | if raw is True, then user_data key
35 | removes users with no data
36 | '''
37 | with open(users_and_plays_path, 'r') as inf:
38 | users_and_plays = json.load(inf)
39 | files = os.listdir(data_dir)
40 | users = []
41 | hierarchies = []
42 | num_samples = []
43 | user_data = {}
44 | for f in files:
45 | user = f[:-4]
46 | passage = ''
47 | filename = os.path.join(data_dir, f)
48 | with open(filename, 'r') as inf:
49 | passage = inf.read()
50 | dataX, dataY = __txt_to_data(filename)
51 | if(len(dataX) > 0):
52 | users.append(user)
53 | if raw:
54 | user_data[user] = {'raw': passage}
55 | else:
56 | user_data[user] = {}
57 | user_data[user]['x'] = dataX
58 | user_data[user]['y'] = dataY
59 | hierarchies.append(users_and_plays[user])
60 | num_samples.append(len(dataY))
61 | all_data = {}
62 | all_data['users'] = users
63 | all_data['hierarchies'] = hierarchies
64 | all_data['num_samples'] = num_samples
65 | all_data['user_data'] = user_data
66 | return all_data
67 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/shakespeare/stats.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME="shakespeare"
4 |
5 | cd ../utils
6 |
7 | python3 stats.py --name $NAME
8 |
9 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/synthetic/README.md:
--------------------------------------------------------------------------------
1 | # Synthetic Dataset
2 |
3 | We propose a process to generate synthetic federated datasets. The dataset is inspired by the one presented by [Li et al.](https://arxiv.org/abs/1905.10497), but has possible additional heterogeneity designed to make current meta-learning methods (such as [Reptile](https://openai.com/blog/reptile/)) struggle. The high-level goal is to create tasks whose true models are (1) task-dependant, and (2) clustered around more than just one center. To see a description of the whole generative process, please refer to the LEAF paper.
4 |
5 | We note that, at the moment, we default to one cluster of models in our code. This can be easily changed by modifying the PROB_CLUSTERS constant in ```main.py```.
6 |
7 | ## Setup Instructions
8 | - pip3 install numpy
9 | - pip3 install pillow
10 | - Run ```python main.py -num-tasks 1000 -num-classes 5 -num-dim 60``` to generate the initial data.
11 | - Run the ```./preprocess.sh``` (as with the other LEAF datasets) to produce the final data splits. We suggest using the following tags:
12 | - ```--sf``` := fraction of data to sample, written as a decimal; set it to 1.0 in order to keep the number of tasks/users specified earlier.
13 | - ```-k``` := minimum number of samples per user; set it to 5.
14 | - ```-t``` := 'user' to partition users into train-test groups, or 'sample' to partition each user's samples into train-test groups.
15 | - ```--tf``` := fraction of data in training set, written as a decimal; default is 0.9.
16 | - ```--smplseed``` := seed to be used before random sampling of data.
17 | - ```--spltseed``` := seed to be used before random split of data.
18 |
19 | i.e.
20 | - ```./preprocess.sh -s niid --sf 1.0 -k 5 -t sample --tf 0.6```
21 |
22 | Make sure to delete the rem_user_data, sampled_data, test, and train subfolders in the data directory before re-running preprocess.sh
23 |
24 | ## Notes
25 | - More details on ```preprocess.sh```:
26 | - The order in which ```preprocess.sh``` processes data is 1. generating all_data (done here by the ```main.py``` script), 2. sampling, 3. removing users, and 4. creating train-test split. The script will look at the data in the last generated directory and continue preprocessing from that point. For example, if the ```all_data``` directory has already been generated and the user decides to skip sampling and only remove users with the ```-k``` tag (i.e. running ```preprocess.sh -k 50```), the script will effectively apply a remove user filter to data in ```all_data``` and place the resulting data in the ```rem_user_data``` directory.
27 | - File names provide information about the preprocessing steps taken to generate them. For example, the ```all_data_niid_1_keep_64.json``` file was generated by first sampling 10 percent (.1) of the data ```all_data.json``` in a non-i.i.d. manner and then applying the ```-k 64``` argument to the resulting data.
28 | - Each .json file is an object with 3 keys:
29 | 1. 'users', a list of users
30 | 2. 'num_samples', a list of the number of samples for each user, and
31 | 3. 'user_data', an object with user names as keys and their respective data as values.
32 | - Run ```./stats.sh``` to get statistics of data (data/all_data/all_data.json must have been generated already)
33 | - In order to run reference implementations in ```../models``` directory, the ```-t sample``` tag must be used when running ```./preprocess.sh```
34 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/synthetic/data_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from scipy.special import softmax
4 |
5 |
6 | NUM_DIM = 10
7 |
8 | class SyntheticDataset:
9 |
10 | def __init__(
11 | self,
12 | num_classes=2,
13 | seed=931231,
14 | num_dim=NUM_DIM,
15 | prob_clusters=[0.5, 0.5]):
16 |
17 | np.random.seed(seed)
18 |
19 | self.num_classes = num_classes
20 | self.num_dim = num_dim
21 | self.num_clusters = len(prob_clusters)
22 | self.prob_clusters = prob_clusters
23 |
24 | self.side_info_dim = self.num_clusters
25 |
26 | self.Q = np.random.normal(
27 | loc=0.0, scale=1.0, size=(self.num_dim + 1, self.num_classes, self.side_info_dim))
28 |
29 | self.Sigma = np.zeros((self.num_dim, self.num_dim))
30 | for i in range(self.num_dim):
31 | self.Sigma[i, i] = (i + 1)**(-1.2)
32 |
33 | self.means = self._generate_clusters()
34 |
35 | def get_task(self, num_samples):
36 | cluster_idx = np.random.choice(
37 | range(self.num_clusters), size=None, replace=True, p=self.prob_clusters)
38 | new_task = self._generate_task(self.means[cluster_idx], cluster_idx, num_samples)
39 | return new_task
40 |
41 | def _generate_clusters(self):
42 | means = []
43 | for i in range(self.num_clusters):
44 | loc = np.random.normal(loc=0, scale=1., size=None)
45 | mu = np.random.normal(loc=loc, scale=1., size=self.side_info_dim)
46 | means.append(mu)
47 | return means
48 |
49 | def _generate_x(self, num_samples):
50 | B = np.random.normal(loc=0.0, scale=1.0, size=None)
51 | loc = np.random.normal(loc=B, scale=1.0, size=self.num_dim)
52 |
53 | samples = np.ones((num_samples, self.num_dim + 1))
54 | samples[:, 1:] = np.random.multivariate_normal(
55 | mean=loc, cov=self.Sigma, size=num_samples)
56 |
57 | return samples
58 |
59 | def _generate_y(self, x, cluster_mean):
60 | model_info = np.random.normal(loc=cluster_mean, scale=0.1, size=cluster_mean.shape)
61 | w = np.matmul(self.Q, model_info)
62 |
63 | num_samples = x.shape[0]
64 | prob = softmax(np.matmul(x, w) + np.random.normal(loc=0., scale=0.1, size=(num_samples, self.num_classes)), axis=1)
65 |
66 | y = np.argmax(prob, axis=1)
67 | return y, w, model_info
68 |
69 | def _generate_task(self, cluster_mean, cluster_id, num_samples):
70 | x = self._generate_x(num_samples)
71 | y, w, model_info = self._generate_y(x, cluster_mean)
72 |
73 | # now that we have y, we can remove the bias coeff
74 | x = x[:, 1:]
75 |
76 | return {'x': x, 'y': y, 'w': w, 'model_info': model_info, 'cluster': cluster_id}
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/synthetic/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import numpy as np
4 | import os
5 |
6 | import data_generator as generator
7 |
8 |
9 | PROB_CLUSTERS = [1.0]
10 |
11 |
12 | def main():
13 | args = parse_args()
14 | np.random.seed(args.seed)
15 |
16 | print('Generating dataset')
17 | num_samples = get_num_samples(args.num_tasks)
18 | dataset = generator.SyntheticDataset(
19 | num_classes=args.num_classes, prob_clusters=PROB_CLUSTERS, num_dim=args.num_dim, seed=args.seed)
20 | tasks = [dataset.get_task(s) for s in num_samples]
21 | users, num_samples, user_data = to_leaf_format(tasks)
22 | save_json('data/all_data', 'data.json', users, num_samples, user_data)
23 | print('Done :D')
24 |
25 |
26 | def get_num_samples(num_tasks, min_num_samples=5, max_num_samples=1000):
27 | num_samples = np.random.lognormal(3, 2, (num_tasks)).astype(int)
28 | num_samples = [min(s + min_num_samples, max_num_samples) for s in num_samples]
29 | return num_samples
30 |
31 |
32 | def to_leaf_format(tasks):
33 | users, num_samples, user_data = [], [], {}
34 |
35 | for i, t in enumerate(tasks):
36 | x, y = t['x'].tolist(), t['y'].tolist()
37 | u_id = str(i)
38 |
39 | users.append(u_id)
40 | num_samples.append(len(y))
41 | user_data[u_id] = {'x': x, 'y': y}
42 |
43 | return users, num_samples, user_data
44 |
45 |
46 | def save_json(json_dir, json_name, users, num_samples, user_data):
47 | if not os.path.exists(json_dir):
48 | os.makedirs(json_dir)
49 |
50 | json_file = {
51 | 'users': users,
52 | 'num_samples': num_samples,
53 | 'user_data': user_data,
54 | }
55 |
56 | with open(os.path.join(json_dir, json_name), 'w') as outfile:
57 | json.dump(json_file, outfile)
58 |
59 |
60 | def parse_args():
61 | parser = argparse.ArgumentParser()
62 |
63 | parser.add_argument(
64 | '-num-tasks',
65 | help='number of devices;',
66 | type=int,
67 | required=True)
68 | parser.add_argument(
69 | '-num-classes',
70 | help='number of classes;',
71 | type=int,
72 | required=True)
73 | parser.add_argument(
74 | '-num-dim',
75 | help='number of dimensions;',
76 | type=int,
77 | required=True)
78 | parser.add_argument(
79 | '-seed',
80 | help='seed for the random processes;',
81 | type=int,
82 | default=931231,
83 | required=False)
84 | return parser.parse_args()
85 |
86 |
87 | if __name__ == '__main__':
88 | main()
89 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/synthetic/preprocess.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # download data and convert to .json format
4 |
5 | if [ ! -f "data/all_data/data.json" ]; then
6 | echo "Please run the main.py script to generate the initial data."
7 | exit 1
8 | fi
9 |
10 |
11 | NAME="synthetic" # name of the dataset, equivalent to directory name
12 |
13 |
14 | cd ../utils
15 |
16 | bash preprocess.sh --name $NAME $@
17 |
18 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/synthetic/stats.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME="synthetic"
4 |
5 | cd ../utils
6 |
7 | python3 stats.py --name $NAME
8 |
9 | cd ../$NAME
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/utils/README.md:
--------------------------------------------------------------------------------
1 | # UTILS README
2 |
3 | This folder contains leaf preprocessed scripts
4 |
5 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/datasets/utils/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/utils/constants.py:
--------------------------------------------------------------------------------
1 | DATASETS = ['sent140', 'femnist', 'shakespeare', 'celeba', 'synthetic']
2 | SEED_FILES = { 'sampling': 'sampling_seed.txt', 'split': 'split_seed.txt' }
3 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/utils/remove_users.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | removes users with less than the given number of samples
4 | '''
5 |
6 | import argparse
7 | import json
8 | import os
9 |
10 | from constants import DATASETS
11 |
12 | parser = argparse.ArgumentParser()
13 |
14 | parser.add_argument('--name',
15 | help='name of dataset to parse; default: sent140;',
16 | type=str,
17 | choices=DATASETS,
18 | default='sent140')
19 |
20 | parser.add_argument('--min_samples',
21 | help='users with less than x samples are discarded; default: 10;',
22 | type=int,
23 | default=10)
24 |
25 | args = parser.parse_args()
26 |
27 | print('------------------------------')
28 | print('removing users with less than %d samples' % args.min_samples)
29 |
30 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
31 | dir = os.path.join(parent_path, args.name, 'data')
32 | subdir = os.path.join(dir, 'sampled_data')
33 | files = []
34 | if os.path.exists(subdir):
35 | files = os.listdir(subdir)
36 | if len(files) == 0:
37 | subdir = os.path.join(dir, 'all_data')
38 | files = os.listdir(subdir)
39 | files = [f for f in files if f.endswith('.json')]
40 |
41 | for f in files:
42 | users = []
43 | hierarchies = []
44 | num_samples = []
45 | user_data = {}
46 |
47 | file_dir = os.path.join(subdir, f)
48 | with open(file_dir, 'r') as inf:
49 | data = json.load(inf)
50 |
51 | num_users = len(data['users'])
52 | for i in range(num_users):
53 | curr_user = data['users'][i]
54 | curr_hierarchy = None
55 | if 'hierarchies' in data:
56 | curr_hierarchy = data['hierarchies'][i]
57 | curr_num_samples = data['num_samples'][i]
58 | if (curr_num_samples >= args.min_samples):
59 | user_data[curr_user] = data['user_data'][curr_user]
60 | users.append(curr_user)
61 | if curr_hierarchy is not None:
62 | hierarchies.append(curr_hierarchy)
63 | num_samples.append(data['num_samples'][i])
64 |
65 | all_data = {}
66 | all_data['users'] = users
67 | if len(hierarchies) == len(users):
68 | all_data['hierarchies'] = hierarchies
69 | all_data['num_samples'] = num_samples
70 | all_data['user_data'] = user_data
71 |
72 | file_name = '%s_keep_%d.json' % ((f[:-5]), args.min_samples)
73 | ouf_dir = os.path.join(dir, 'rem_user_data', file_name)
74 |
75 | print('writing %s' % file_name)
76 | with open(ouf_dir, 'w') as outfile:
77 | json.dump(all_data, outfile)
78 |
79 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/utils/stats.py:
--------------------------------------------------------------------------------
1 | '''
2 | assumes that the user has already generated .json file(s) containing data
3 | '''
4 |
5 | import argparse
6 | import json
7 | import matplotlib.pyplot as plt
8 | import math
9 | import numpy as np
10 | import os
11 |
12 | from scipy import io
13 | from scipy import stats
14 |
15 | from constants import DATASETS
16 |
17 | parser = argparse.ArgumentParser()
18 |
19 | parser.add_argument('--name',
20 | help='name of dataset to parse; default: sent140;',
21 | type=str,
22 | choices=DATASETS,
23 | default='sent140')
24 |
25 | args = parser.parse_args()
26 |
27 |
28 | def load_data(name):
29 |
30 | users = []
31 | num_samples = []
32 |
33 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
34 | data_dir = os.path.join(parent_path, name, 'data')
35 | subdir = os.path.join(data_dir, 'all_data')
36 |
37 | files = os.listdir(subdir)
38 | files = [f for f in files if f.endswith('.json')]
39 |
40 | for f in files:
41 | file_dir = os.path.join(subdir, f)
42 |
43 | with open(file_dir) as inf:
44 | data = json.load(inf)
45 |
46 | users.extend(data['users'])
47 | num_samples.extend(data['num_samples'])
48 |
49 | return users, num_samples
50 |
51 | def print_dataset_stats(name):
52 | users, num_samples = load_data(name)
53 | num_users = len(users)
54 |
55 | print('####################################')
56 | print('DATASET: %s' % name)
57 | print('%d users' % num_users)
58 | print('%d samples (total)' % np.sum(num_samples))
59 | print('%.2f samples per user (mean)' % np.mean(num_samples))
60 | print('num_samples (std): %.2f' % np.std(num_samples))
61 | print('num_samples (std/mean): %.2f' % (np.std(num_samples)/np.mean(num_samples)))
62 | print('num_samples (skewness): %.2f' % stats.skew(num_samples))
63 |
64 | bins = [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200]
65 | if args.name == 'shakespeare':
66 | bins = [0, 2000, 4000, 6000, 8000, 10000, 12000, 14000, 16000, 18000, 20000]
67 | if args.name == 'femnist':
68 | bins = [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 320, 340, 360, 380, 400, 420, 440, 460, 480, 500]
69 | if args.name == 'celeba':
70 | bins = [2 * i for i in range(20)]
71 | if args.name == 'sent140':
72 | bins = [i for i in range(16)]
73 |
74 | hist, edges = np.histogram(num_samples, bins=bins)
75 | for e, h in zip(edges, hist):
76 | print(e, "\t", h)
77 |
78 | parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
79 | data_dir = os.path.join(parent_path, name, 'data')
80 |
81 | plt.hist(num_samples, bins = bins)
82 | fig_name = "%s_hist_nolabel.png" % name
83 | fig_dir = os.path.join(data_dir, fig_name)
84 | plt.savefig(fig_dir)
85 | plt.title(name)
86 | plt.xlabel('number of samples')
87 | plt.ylabel("number of users")
88 | fig_name = "%s_hist.png" % name
89 | fig_dir = os.path.join(data_dir, fig_name)
90 | plt.savefig(fig_dir)
91 |
92 | print_dataset_stats(args.name)
93 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/datasets/utils/util.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 |
4 | def save_obj(obj, name):
5 | with open(name + '.pkl', 'wb') as f:
6 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
7 |
8 |
9 | def load_obj(name):
10 | with open(name + '.pkl', 'rb') as f:
11 | return pickle.load(f)
12 |
13 |
14 | def iid_divide(l, g):
15 | '''
16 | divide list l among g groups
17 | each group has either int(len(l)/g) or int(len(l)/g)+1 elements
18 | returns a list of groups
19 | '''
20 | num_elems = len(l)
21 | group_size = int(len(l)/g)
22 | num_big_groups = num_elems - g * group_size
23 | num_small_groups = g - num_big_groups
24 | glist = []
25 | for i in range(num_small_groups):
26 | glist.append(l[group_size * i : group_size * (i + 1)])
27 | bi = group_size*num_small_groups
28 | group_size += 1
29 | for i in range(num_big_groups):
30 | glist.append(l[bi + group_size * i:bi + group_size * (i + 1)])
31 | return glist
32 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feature-skew-fedavg/README.md:
--------------------------------------------------------------------------------
1 | # README
2 |
3 | demo for NIID-bench "feature distribution skew"-"noise-based feature imbalance" on FedAvg.
4 |
5 | __Setting:__
6 | - `noise=0.1` (Gaussian noise)
7 | - `partition='homo'` (IID partition)
8 | - `n_parties=10`
9 | - `lr=0.01`
10 | - `momoentum=0.9`
11 | - `weight_decay=1e-5` (L2-norm weight decay)
12 | - `comm_round=50`
13 | - `sample=1` (client sampling ratio)
14 | - `alg='fedavg'`
15 | - `epochs=10`
16 | - `dataset='fmnist'`
17 | - `model='simple-cnn'`
18 |
19 | Top-1 accuracy for FMNIST in paper: $89.1\% \pm 0.3\%$.
20 |
21 | Top-1 accuracy for FMNIST in this demo: $89.37\% \pm 0.14 \%$ (5 runs)
22 |
23 | ## Requirements
24 |
25 | fedlab==1.1.2
26 |
27 | ## How to Run?
28 |
29 | `start_server.sh` is for server process launch, and `start_clt.sh` is for client process launch.
30 |
31 | 1. run command in terminal window 1 to launch server:
32 |
33 | ```bash
34 | bash start_server.sh
35 | ```
36 |
37 | 2. run command in terminal window 2 to launch clients:
38 |
39 | ```bash
40 | bash start_client.sh
41 | ```
42 |
43 | > random seed for data partiiton over clients can be set using `--seed` in `start_clt.sh`:
44 | >
45 | > ```shell
46 | > python data_partition.py --num-clients 10 --seed 1
47 | > ```
48 | >
49 | > And the noise distribution can be send with `--noise`:
50 | >
51 | > ```bash
52 | > python client.py --world_size 2 --rank 1 --noise 0.1
53 | > ```
54 |
55 |
56 |
57 | We highly recommend to launch clients after server is launched to avoid some conficts.
58 |
59 |
60 |
61 | ## References
62 |
63 | - Li, Q., Diao, Y., Chen, Q., & He, B. (2021). Federated learning on non-iid data silos: An experimental study. *arXiv preprint arXiv:2102.02079*.
64 |
65 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feature-skew-fedavg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/feature-skew-fedavg/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/feature-skew-fedavg/config.py:
--------------------------------------------------------------------------------
1 | fmnist_noise_baseline_config = {
2 | "partition": "homo",
3 | "round": 50,
4 | "network": "simple-cnn",
5 | "sample_ratio": 1,
6 | "dataset": "fmnist",
7 | "total_client_num": 10,
8 | "lr": 0.01,
9 | "momentum": 0.9,
10 | "weight_decay": 1e-5,
11 | "batch_size": 64,
12 | "test_batch_size": 32,
13 | "epochs": 10
14 | }
15 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feature-skew-fedavg/data_partition.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import sys
4 |
5 | from fedlab.utils.dataset import FMNISTPartitioner
6 | from fedlab.utils.functional import save_dict
7 |
8 | from torchvision.datasets import FashionMNIST
9 |
10 |
11 | # python data_partition.py --num-clients 10
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser(description='Data partition')
14 |
15 | parser.add_argument('--num-clients', type=int, default=10)
16 | parser.add_argument('--seed', type=int, default=2021)
17 | args = parser.parse_args()
18 |
19 | root = "../../../datasets/FMNIST"
20 | trainset = FashionMNIST(root=root, train=True, download=True)
21 |
22 | # perform partition
23 | partition = FMNISTPartitioner(trainset.targets,
24 | num_clients=args.num_clients,
25 | partition="iid",
26 | seed=args.seed)
27 | save_dict(partition.client_dict, "fmnist_iid.pkl")
28 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feature-skew-fedavg/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | Code below is from NIID-bench official code:
7 | https://github.com/Xtra-Computing/NIID-Bench
8 | """
9 | class SimpleCNNMNIST(nn.Module):
10 | def __init__(self, input_dim, hidden_dims, output_dim=10):
11 | super(SimpleCNNMNIST, self).__init__()
12 | self.conv1 = nn.Conv2d(1, 6, 5)
13 | self.pool = nn.MaxPool2d(2, 2)
14 | self.conv2 = nn.Conv2d(6, 16, 5)
15 |
16 | # for now, we hard coded this network
17 | # i.e. we fix the number of hidden layers i.e. 2 layers
18 | self.fc1 = nn.Linear(input_dim, hidden_dims[0])
19 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
20 | self.fc3 = nn.Linear(hidden_dims[1], output_dim)
21 |
22 | def forward(self, x):
23 | x = self.pool(F.relu(self.conv1(x)))
24 | x = self.pool(F.relu(self.conv2(x)))
25 | x = x.view(-1, 16 * 4 * 4)
26 |
27 | x = F.relu(self.fc1(x))
28 | x = F.relu(self.fc2(x))
29 | x = self.fc3(x)
30 | return x
31 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feature-skew-fedavg/start_clt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # perform data partition over clients
3 | python data_partition.py --num-clients 10 --seed 1
4 | echo -e "New data partition done \n\n"
5 | # launch 10 clients in single serial trainer
6 | python client.py --world_size 2 --rank 1 --noise 0.1
7 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feature-skew-fedavg/start_server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python server.py --world_size 2 --run 0
3 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedasync/README.md:
--------------------------------------------------------------------------------
1 | # Fedasync
2 |
3 | The demo of [asynchronous federated optimization](https://arxiv.org/abs/1903.03934).
4 |
5 | ## Requirements
6 |
7 | fedlab==1.1.2
8 |
9 | ## Run
10 |
11 | standalone:
12 |
13 | `$ cd standalone`
14 |
15 | `$ python standalone.py`
16 |
17 |
18 | cross_process:
19 |
20 |
21 | `$ cd cross_process`
22 |
23 | `$ bash quick_start.sh`
24 |
25 | ## Performance
26 |
27 | Null
28 |
29 | ## References
30 |
31 | Xie, Cong, Sanmi Koyejo, and Indranil Gupta. "Asynchronous federated optimization." arXiv preprint arXiv:1903.03934 (2019).
32 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedasync/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedasync/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedasync/cross_process/client.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torchvision.transforms as transforms
3 | import torch
4 | import argparse
5 | import sys
6 | import os
7 |
8 | from torch import nn
9 |
10 | from fedlab.core.client.manager import ClientActiveManager
11 | from fedlab.core.client.trainer import ClientSGDTrainer
12 | from fedlab.utils.dataset.sampler import RawPartitionSampler
13 | from fedlab.core.network import DistNetwork
14 |
15 | sys.path.append("../../")
16 | from models.cnn import CNN_MNIST
17 |
18 | def get_dataset(args):
19 | """
20 | :param dataset_name:
21 | :param transform:
22 | :param batch_size:
23 | :return: iterators for the datasetaccuracy_score
24 | """
25 | train_transform = transforms.Compose([
26 | transforms.ToTensor(),
27 | ])
28 |
29 | test_transform = transforms.Compose([
30 | transforms.ToTensor(),
31 | ])
32 | trainset = torchvision.datasets.MNIST(root=args.root,
33 | train=True,
34 | download=True,
35 | transform=train_transform)
36 | testset = torchvision.datasets.MNIST(root=args.root,
37 | train=False,
38 | download=True,
39 | transform=test_transform)
40 |
41 | trainloader = torch.utils.data.DataLoader(
42 | trainset,
43 | sampler=RawPartitionSampler(trainset,
44 | client_id=args.rank,
45 | num_replicas=args.world_size - 1),
46 | batch_size=128,
47 | drop_last=True,
48 | num_workers=2)
49 | testloader = torch.utils.data.DataLoader(testset,
50 | batch_size=len(testset),
51 | drop_last=False,
52 | num_workers=2,
53 | shuffle=False)
54 | return trainloader, testloader
55 |
56 |
57 | if __name__ == "__main__":
58 | parser = argparse.ArgumentParser(description='Distbelief training example')
59 | parser.add_argument('--ip', type=str, default='127.0.0.1')
60 | parser.add_argument('--port', type=str, default='3002')
61 | parser.add_argument('--world_size', type=int)
62 | parser.add_argument('--rank', type=int)
63 |
64 | parser.add_argument("--epoch", type=int, default=2)
65 | parser.add_argument("--lr", type=float, default=0.1)
66 | parser.add_argument("--wd", type=float, default=0)
67 | parser.add_argument("--cuda", type=bool, default=True)
68 | args = parser.parse_args()
69 | args.root = '../../datasets/mnist/'
70 | args.cuda = True
71 |
72 | model = CNN_MNIST()
73 | trainloader, testloader = get_dataset(args)
74 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
75 | criterion = nn.CrossEntropyLoss()
76 | handler = ClientSGDTrainer(model,
77 | trainloader,
78 | epochs=args.epoch,
79 | optimizer=optimizer,
80 | criterion=criterion,
81 | cuda=args.cuda)
82 |
83 | network = DistNetwork(address=(args.ip, args.port),
84 | world_size=args.world_size,
85 | rank=args.rank)
86 | Manager = ClientActiveManager(trainer=handler, network=network)
87 | Manager.run()
88 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedasync/cross_process/quick_start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | python server.py --world_size 3 &
5 |
6 | python client.py --world_size 3 --rank 1 &
7 | python client.py --world_size 3 --rank 2 &
8 |
9 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedasync/cross_process/server.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from fedlab.core.network import DistNetwork
5 | from fedlab.core.server.handler import AsyncParameterServerHandler
6 | from fedlab.core.server.manager import ServerAsynchronousManager
7 | sys.path.append("../../")
8 | from models.cnn import CNN_MNIST
9 | import argparse
10 |
11 | if __name__ == "__main__":
12 | parser = argparse.ArgumentParser(description='Distbelief training example')
13 |
14 | parser.add_argument('--ip', type=str, default='127.0.0.1')
15 | parser.add_argument('--port', type=str, default='3002')
16 | parser.add_argument('--world_size', type=int)
17 | args = parser.parse_args()
18 |
19 | model = CNN_MNIST().cpu()
20 | ps = AsyncParameterServerHandler(model)
21 |
22 | network = DistNetwork(address=(args.ip, args.port),
23 | world_size=args.world_size,
24 | rank=0)
25 | Manager = ServerAsynchronousManager(handler=ps, network=network)
26 |
27 | Manager.run()
28 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedasync/standalone/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedasync/standalone/cifar10_iid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedasync/standalone/cifar10_iid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedasync/standalone/cifar10_noniid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedasync/standalone/cifar10_noniid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/README.md:
--------------------------------------------------------------------------------
1 | # FedAvg
2 |
3 | [FedAvg](http://proceedings.mlr.press/v54/mcmahan17a.html) is the baseline of synchronous federated learning algorithm, and FedLab implements the algorithm flow of FedAvg, including standalone and Cross Process scenarios.
4 |
5 | ## Requirements
6 |
7 | fedlab==1.1.2
8 |
9 | ## Run
10 |
11 | ### Standalone
12 |
13 | The` SerialTrainer` module is for the FL system simulation on a single machine, and its source code can be found in `fedlab/core fedlab/core/client/trainer.py`.
14 |
15 | Executable scripts is in ` fedlab_benchmarks/algorithm/fedavg/standalone/`.
16 |
17 | ### Cross Process
18 |
19 | The federated simulation of **multi-machine** and **single-machine multi-process** scenarios is the core module of FedLab, which is composed of various modules in `core/client` and `core/server`, please refer to overview for details .
20 |
21 | The executable script is in `fedlab_benchmarks/algorithm/fedavg/cross_process/`
22 |
23 | ## Performance
24 |
25 | Null
26 |
27 | ## References
28 |
29 | McMahan, Brendan, et al. "Communication-efficient learning of deep networks from decentralized data." Artificial intelligence and statistics. PMLR, 2017.
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/cross_process/LEAF_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # get world size dicts for all datasets and run all client processes for each dataset
4 | declare -A dataset_world_size=(['femnist']=5 ['shakespeare']=3) # each process represents one client/server
5 |
6 | for key in ${!dataset_world_size[*]}
7 | do
8 | echo "${key} client_num is ${dataset_world_size[${key}]}"
9 |
10 | echo "server started"
11 | python server.py --ip 127.0.0.1 --port 3002 --world_size ${dataset_world_size[${key}]} --dataset ${key} &
12 |
13 | for ((i=1; i<${dataset_world_size[$key]}; i++))
14 | do
15 | {
16 | echo "client ${i} started"
17 | python client.py --ip 127.0.0.1 --port 3002 --world_size ${dataset_world_size[${key}]} --rank ${i} --dataset ${key} --epoch 2
18 | } &
19 | done
20 | wait
21 |
22 | echo "${key} experiment end"
23 | done
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/cross_process/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/cross_process/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/cross_process/client.py:
--------------------------------------------------------------------------------
1 | from logging import log
2 | import torch
3 | import argparse
4 | import sys
5 | import os
6 |
7 | from torch import nn
8 | from fedlab.core.client.manager import ClientPassiveManager
9 | from fedlab.core.client.trainer import ClientSGDTrainer
10 | from fedlab.core.network import DistNetwork
11 | from fedlab.utils.logger import Logger
12 |
13 | from setting import get_model, get_dataset
14 |
15 | if __name__ == "__main__":
16 |
17 | parser = argparse.ArgumentParser(description="Distbelief training example")
18 |
19 | parser.add_argument("--ip", type=str)
20 | parser.add_argument("--port", type=str)
21 | parser.add_argument("--world_size", type=int)
22 | parser.add_argument("--rank", type=int)
23 |
24 | parser.add_argument("--lr", type=float, default=0.01)
25 | parser.add_argument("--epoch", type=int, default=5)
26 | parser.add_argument("--dataset", type=str)
27 | parser.add_argument("--batch_size", type=int, default=100)
28 |
29 | parser.add_argument("--gpu", type=str, default="0,1,2,3")
30 | parser.add_argument("--ethernet", type=str, default=None)
31 | args = parser.parse_args()
32 |
33 | if args.gpu != "-1":
34 | args.cuda = True
35 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
36 | else:
37 | args.cuda = False
38 |
39 | model = get_model(args)
40 | trainloader, testloader = get_dataset(args)
41 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
42 | criterion = nn.CrossEntropyLoss()
43 |
44 | network = DistNetwork(
45 | address=(args.ip, args.port),
46 | world_size=args.world_size,
47 | rank=args.rank,
48 | ethernet=args.ethernet,
49 | )
50 |
51 | LOGGER = Logger(log_name="client " + str(args.rank))
52 |
53 | trainer = ClientSGDTrainer(
54 | model,
55 | trainloader,
56 | epochs=args.epoch,
57 | optimizer=optimizer,
58 | criterion=criterion,
59 | cuda=args.cuda,
60 | logger=LOGGER,
61 | )
62 |
63 | manager_ = ClientPassiveManager(trainer=trainer,
64 | network=network,
65 | logger=LOGGER)
66 | manager_.run()
67 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/cross_process/quick_start.sh:
--------------------------------------------------------------------------------
1 | #!bin/bash
2 |
3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 3 --dataset mnist --round 3 &
4 |
5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 3 --rank 1 --dataset mnist &
6 |
7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 3 --rank 2 --dataset mnist &
8 |
9 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/cross_process/server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from fedlab.utils.logger import Logger
4 | from fedlab.core.server.handler import SyncParameterServerHandler
5 | from fedlab.core.server.manager import ServerSynchronousManager
6 | from fedlab.core.network import DistNetwork
7 | from setting import get_model, get_dataset
8 |
9 | if __name__ == "__main__":
10 | parser = argparse.ArgumentParser(description='FL server example')
11 |
12 | parser.add_argument('--ip', type=str)
13 | parser.add_argument('--port', type=str)
14 | parser.add_argument('--world_size', type=int)
15 |
16 | parser.add_argument('--round', type=int, default=5)
17 | parser.add_argument('--dataset', type=str)
18 | parser.add_argument('--ethernet', type=str, default=None)
19 | parser.add_argument('--sample', type=float, default=1)
20 |
21 | args = parser.parse_args()
22 |
23 | model = get_model(args)
24 | LOGGER = Logger(log_name="server")
25 | handler = SyncParameterServerHandler(model,
26 | global_round=args.round,
27 | logger=LOGGER,
28 | sample_ratio=args.sample)
29 | network = DistNetwork(address=(args.ip, args.port),
30 | world_size=args.world_size,
31 | rank=0,
32 | ethernet=args.ethernet)
33 | manager_ = ServerSynchronousManager(handler=handler,
34 | network=network,
35 | logger=LOGGER)
36 | manager_.run()
37 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/cross_process/start_clients.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # start a group of client. continuous rank is required.
4 | # example: bash start_clients.sh ip port wolrd_size 1 5 dataset ethernet epoch
5 | # start client from rank 1-5
6 |
7 | echo "Connecting server:($1:$2), world_size $3, rank $4-$5, dataset $6, ethernet $7, local epoch $8"
8 |
9 | for ((i=$4; i<=$5; i++))
10 | do
11 | {
12 | echo "client ${i} started"
13 | python client.py --server_ip $1 --server_port $2 --world_size $3 --rank ${i} --dataset $6 --ethernet $7 --epoch $8 &
14 | sleep 2s # waiting gpu resource allocation
15 | }
16 | done
17 | wait
18 |
19 |
20 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_iid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_iid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_noniid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_noniid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/cifar10_partition.py:
--------------------------------------------------------------------------------
1 | from fedlab.utils.functional import save_dict
2 | from fedlab.utils.dataset.slicing import noniid_slicing, random_slicing
3 |
4 | import torchvision
5 |
6 | root = '../../../../datasets/data/cifar10/'
7 | trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=False)
8 |
9 | data_indices = noniid_slicing(trainset, num_clients=100, num_shards=200)
10 | save_dict(data_indices, "cifar10_noniid.pkl")
11 |
12 | data_indices = random_slicing(trainset, num_clients=100)
13 | save_dict(data_indices, "cifar10_iid.pkl")
14 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/client.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 |
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torchvision
8 | import torchvision.transforms as transforms
9 |
10 | torch.manual_seed(0)
11 |
12 | from fedlab.core.client.scale.trainer import SubsetSerialTrainer
13 | from fedlab.core.client.scale.manager import ScaleClientPassiveManager
14 | from fedlab.core.network import DistNetwork
15 |
16 | from fedlab.utils.serialization import SerializationTool
17 | from fedlab.utils.logger import Logger
18 | from fedlab.utils.aggregator import Aggregators
19 | from fedlab.utils.functional import load_dict
20 |
21 | import sys
22 |
23 | sys.path.append("../../../")
24 | from models.cnn import AlexNet_CIFAR10, CNN_CIFAR10
25 |
26 | from config import cifar10_noniid_baseline_config, cifar10_iid_baseline_config
27 |
28 | if __name__ == "__main__":
29 | parser = argparse.ArgumentParser(description="Distbelief training example")
30 |
31 | parser.add_argument("--ip", type=str, default="127.0.0.1")
32 | parser.add_argument("--port", type=str, default="3003")
33 | parser.add_argument("--world_size", type=int)
34 | parser.add_argument("--rank", type=int)
35 | parser.add_argument("--ethernet", type=str, default=None)
36 |
37 | parser.add_argument("--setting", type=str)
38 | args = parser.parse_args()
39 |
40 | if args.setting == 'iid':
41 | config = cifar10_iid_baseline_config
42 | else:
43 | config = cifar10_noniid_baseline_config
44 |
45 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
46 |
47 | transform_train = transforms.Compose([
48 | transforms.RandomCrop(32, padding=4),
49 | transforms.RandomHorizontalFlip(),
50 | transforms.ToTensor(),
51 | transforms.Normalize((0.4914, 0.4822, 0.4465),
52 | (0.2023, 0.1994, 0.2010))
53 | ])
54 |
55 | trainset = torchvision.datasets.CIFAR10(root='../../../datasets/cifar10/',
56 | train=True,
57 | download=True,
58 | transform=transform_train)
59 |
60 | if config['partition'] == "noniid":
61 | data_indices = load_dict("cifar10_noniid.pkl")
62 | if config['partition'] == "iid":
63 | data_indices = load_dict("cifar10_iid.pkl")
64 |
65 | # Process rank x represent client id from (x-1)*10 - (x-1)*10 +10
66 | # e.g. rank 5 <--> client 40-50
67 | client_id_list = [
68 | i for i in range((args.rank - 1) * 10, (args.rank - 1) * 10 + 10)
69 | ]
70 |
71 | # get corresponding data partition indices
72 | sub_data_indices = {
73 | idx: data_indices[cid]
74 | for idx, cid in enumerate(client_id_list)
75 | }
76 |
77 | #model = CNN_Cifar10()
78 | model = AlexNet_CIFAR10()
79 |
80 | aggregator = Aggregators.fedavg_aggregate
81 |
82 | network = DistNetwork(address=(args.ip, args.port),
83 | world_size=args.world_size,
84 | rank=args.rank,
85 | ethernet=args.ethernet)
86 |
87 | trainer = SubsetSerialTrainer(model=model,
88 | dataset=trainset,
89 | data_slices=sub_data_indices,
90 | aggregator=aggregator,
91 | args=config)
92 |
93 | manager_ = ScaleClientPassiveManager(trainer=trainer, network=network)
94 |
95 | manager_.run()
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/config.py:
--------------------------------------------------------------------------------
1 | cifar10_iid_baseline_config = {
2 | "partition": "iid",
3 | "round": 4000,
4 | "network": "alexnet",
5 | "sample_ratio": 0.1,
6 | "dataset": "cifar10",
7 | "total_client_num": 100,
8 | "lr": 0.1,
9 | "batch_size": 100,
10 | "epochs": 5
11 | }
12 |
13 | cifar10_noniid_baseline_config = {
14 | "partition": "noniid",
15 | "round": 4000,
16 | "network": "alexnet",
17 | "sample_ratio": 0.1,
18 | "dataset": "cifar10",
19 | "total_client_num": 100,
20 | "lr": 0.1,
21 | "batch_size": 100,
22 | "epochs": 5
23 | }
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/cifar10-cnn/start_clt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # TODO: try to add auto client assignment script
3 | # TODO: bash start_clt.sh [world_size] [num_clients] [world_size]
4 | # TODO: where world_size = 1 + client_ranks_num
5 | # bash start_clt.sh 11 1 10 [iid/noniid]
6 | for ((i=$2; i<=$3; i++))
7 | do
8 | {
9 | echo "client ${i} started"
10 | python client.py --world_size $1 --rank ${i} --setting $4&
11 | sleep 2s
12 | }
13 | done
14 | wait
15 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/femnist-cnn/server.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torchvision
8 | from torchvision import transforms
9 |
10 | torch.manual_seed(0)
11 |
12 | from fedlab.core.server.handler import SyncParameterServerHandler
13 | from fedlab.core.server.scale.manager import ScaleSynchronousManager
14 | from fedlab.core.network import DistNetwork
15 | from fedlab.utils.functional import AverageMeter
16 |
17 | sys.path.append("../../../")
18 | from models.cnn import CNN_FEMNIST
19 |
20 | # python server.py --world_size 11
21 | if __name__ == "__main__":
22 | parser = argparse.ArgumentParser(description='FL server example')
23 |
24 | parser.add_argument('--ip', type=str, default="127.0.0.1")
25 | parser.add_argument('--port', type=str, default="3002")
26 | parser.add_argument('--world_size', type=int)
27 |
28 | parser.add_argument('--round', type=int, default=1000)
29 | parser.add_argument('--ethernet', type=str, default=None)
30 | parser.add_argument('--sample', type=float, default=0.01)
31 |
32 | args = parser.parse_args()
33 |
34 | model = CNN_FEMNIST()
35 |
36 | handler = SyncParameterServerHandler(model,
37 | global_round=args.round,
38 | sample_ratio=args.sample,
39 | cuda=True)
40 |
41 | network = DistNetwork(address=(args.ip, args.port),
42 | world_size=args.world_size,
43 | rank=0)
44 |
45 | manager_ = ScaleSynchronousManager(network=network, handler=handler)
46 | manager_.run()
47 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/femnist-cnn/start_clt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for ((i=$2; i<=$3; i++))
4 | do
5 | {
6 | echo "client ${i} started"
7 | python client.py --world_size $1 --rank ${i} &
8 | sleep 2s
9 | }
10 | done
11 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/client.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import sys
4 | import os
5 |
6 | import torchvision
7 | import torchvision.transforms as transforms
8 |
9 | from fedlab.core.client.scale.trainer import SubsetSerialTrainer
10 | from fedlab.core.client.scale.manager import ScaleClientPassiveManager
11 | from fedlab.core.network import DistNetwork
12 |
13 | from fedlab.utils.logger import Logger
14 | from fedlab.utils.aggregator import Aggregators
15 | from fedlab.utils.functional import load_dict
16 |
17 | sys.path.append("../../../")
18 | from models.cnn import CNN_MNIST
19 |
20 | if __name__ == "__main__":
21 |
22 | parser = argparse.ArgumentParser(description="Distbelief training example")
23 |
24 | parser.add_argument("--ip", type=str, default="127.0.0.1")
25 | parser.add_argument("--port", type=str, default="3002")
26 | parser.add_argument("--world_size", type=int)
27 | parser.add_argument("--rank", type=int)
28 |
29 | parser.add_argument("--partition", type=str, default="noniid")
30 |
31 | parser.add_argument("--gpu", type=str, default="0,1,2,3")
32 | parser.add_argument("--ethernet", type=str, default=None)
33 |
34 | args = parser.parse_args()
35 |
36 | if args.gpu != "-1":
37 | args.cuda = True
38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
39 | else:
40 | args.cuda = False
41 |
42 | trainset = torchvision.datasets.MNIST(
43 | root='../../../datasets/mnist/',
44 | train=True,
45 | download=True,
46 | transform=transforms.ToTensor())
47 |
48 | if args.partition == "noniid":
49 | data_indices = load_dict("mnist_noniid.pkl")
50 | elif args.partition == "iid":
51 | data_indices = load_dict("mnist_iid.pkl")
52 | else:
53 | raise ValueError("invalid partition type ", args.partition)
54 |
55 | # Process rank x represent client id from (x-1)*10 - (x-1)*10 +10
56 | # e.g. rank 5 <--> client 40-50
57 | client_id_list = [
58 | i for i in range((args.rank - 1) * 10, (args.rank - 1) * 10 + 10)
59 | ]
60 |
61 | # get corresponding data partition indices
62 | sub_data_indices = {
63 | idx: data_indices[cid]
64 | for idx, cid in enumerate(client_id_list)
65 | }
66 |
67 | model = CNN_MNIST()
68 |
69 | aggregator = Aggregators.fedavg_aggregate
70 |
71 | network = DistNetwork(address=(args.ip, args.port),
72 | world_size=args.world_size,
73 | rank=args.rank,
74 | ethernet=args.ethernet)
75 |
76 | trainer = SubsetSerialTrainer(model=model,
77 | dataset=trainset,
78 | data_slices=sub_data_indices,
79 | aggregator=aggregator,
80 | args={
81 | "batch_size": 100,
82 | "lr": 0.02,
83 | "epochs": 5
84 | })
85 |
86 | manager_ = ScaleClientPassiveManager(trainer=trainer, network=network)
87 |
88 | manager_.run()
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_iid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_iid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_noniid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_noniid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/mnist_partition.py:
--------------------------------------------------------------------------------
1 | from fedlab.utils.functional import save_dict
2 | from fedlab.utils.dataset.slicing import noniid_slicing, random_slicing
3 |
4 | import torchvision
5 |
6 | root = '../../../../datasets/data/mnist/'
7 | trainset = torchvision.datasets.MNIST(root=root, train=True, download=True)
8 |
9 | data_indices = noniid_slicing(trainset, num_clients=100, num_shards=200)
10 | save_dict(data_indices, "mnist_noniid.pkl")
11 |
12 | data_indices = random_slicing(trainset, num_clients=100)
13 | save_dict(data_indices, "mnist_iid.pkl")
14 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/mnist-cnn/start_clt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for ((i=$2; i<=$3; i++))
4 | do
5 | {
6 | echo "client ${i} started"
7 | python client.py --ip 127.0.0.1 --port 3002 --world_size $1 --rank ${i} &
8 | sleep 2s
9 | }
10 | done
11 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/shakespeare-rnn/server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torchvision import transforms
7 |
8 | torch.manual_seed(0)
9 |
10 | from fedlab.core.server.handler import SyncParameterServerHandler
11 | from fedlab.core.server.scale.manager import ScaleSynchronousManager
12 | from fedlab.core.network import DistNetwork
13 | from fedlab.utils.functional import evaluate
14 |
15 | import sys
16 |
17 | sys.path.append('../../../')
18 | from models.rnn import RNN_Shakespeare
19 | from leaf.dataloader import get_LEAF_all_test_dataloader
20 |
21 | # python server.py --world_size 11
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser(description='FL server example')
24 |
25 | parser.add_argument('--ip', type=str, default="127.0.0.1")
26 | parser.add_argument('--port', type=str, default="3002")
27 | parser.add_argument('--world_size', type=int)
28 |
29 | parser.add_argument('--round', type=int, default=2)
30 | parser.add_argument('--ethernet', type=str, default=None)
31 | parser.add_argument('--sample', type=float, default=0.05)
32 |
33 | args = parser.parse_args()
34 |
35 | model = RNN_Shakespeare()
36 |
37 | handler = SyncParameterServerHandler(model,
38 | global_round=args.round,
39 | sample_ratio=args.sample,
40 | cuda=True)
41 |
42 | network = DistNetwork(address=(args.ip, args.port),
43 | world_size=args.world_size,
44 | rank=0)
45 |
46 | manager_ = ScaleSynchronousManager(network=network, handler=handler)
47 | manager_.run()
48 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/scale/shakespeare-rnn/start_clt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for ((i=$2; i<=$3; i++))
4 | do
5 | {
6 | echo "client ${i} started"
7 | python client.py --world_size $1 --rank ${i} &
8 | sleep 2s
9 | }
10 | done
11 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/standalone/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/standalone/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/standalone/mnist_iid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/standalone/mnist_iid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.1.2/standalone/mnist_noniid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.1.2/standalone/mnist_noniid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.2.0/client.py:
--------------------------------------------------------------------------------
1 | from logging import log
2 | import torch
3 | import argparse
4 | import sys
5 | import os
6 |
7 | from torch import nn
8 | from fedlab.core.client.manager import PassiveClientManager
9 | from fedlab.core.client.trainer import SGDClientTrainer
10 | from fedlab.core.network import DistNetwork
11 | from fedlab.utils.logger import Logger
12 |
13 | from setting import get_model, get_dataset
14 |
15 | if __name__ == "__main__":
16 |
17 | parser = argparse.ArgumentParser(description="Distbelief training example")
18 |
19 | parser.add_argument("--ip", type=str)
20 | parser.add_argument("--port", type=str)
21 | parser.add_argument("--world_size", type=int)
22 | parser.add_argument("--rank", type=int)
23 |
24 | parser.add_argument("--lr", type=float, default=0.01)
25 | parser.add_argument("--epoch", type=int, default=5)
26 | parser.add_argument("--dataset", type=str, default="mnist")
27 | parser.add_argument("--batch_size", type=int, default=100)
28 |
29 | parser.add_argument("--gpu", type=str, default="0,1,2,3")
30 | parser.add_argument("--ethernet", type=str, default=None)
31 | args = parser.parse_args()
32 |
33 | if args.gpu != "-1":
34 | args.cuda = True
35 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
36 | else:
37 | args.cuda = False
38 |
39 | model = get_model(args)
40 | trainloader, testloader = get_dataset(args)
41 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
42 | criterion = nn.CrossEntropyLoss()
43 |
44 | network = DistNetwork(
45 | address=(args.ip, args.port),
46 | world_size=args.world_size,
47 | rank=args.rank,
48 | ethernet=args.ethernet,
49 | )
50 |
51 | LOGGER = Logger(log_name="client " + str(args.rank))
52 |
53 | trainer = SGDClientTrainer(
54 | model,
55 | trainloader,
56 | epochs=args.epoch,
57 | optimizer=optimizer,
58 | criterion=criterion,
59 | cuda=args.cuda,
60 | logger=LOGGER,
61 | )
62 |
63 | manager_ = PassiveClientManager(trainer=trainer,
64 | network=network,
65 | logger=LOGGER)
66 | manager_.run()
67 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.2.0/mnist_partition.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedavg_v1.2.0/mnist_partition.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.2.0/run.sh:
--------------------------------------------------------------------------------
1 | #!bin/bash
2 |
3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 11 --dataset mnist &
4 |
5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 1 --dataset mnist &
6 |
7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 2 --dataset mnist &
8 |
9 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 3 --dataset mnist &
10 |
11 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 4 --dataset mnist &
12 |
13 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 5 --dataset mnist &
14 |
15 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 6 --dataset mnist &
16 |
17 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 7 --dataset mnist &
18 |
19 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 8 --dataset mnist &
20 |
21 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 9 --dataset mnist &
22 |
23 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 10 --dataset mnist &
24 |
25 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedavg_v1.2.0/setting.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | import sys
5 |
6 | from fedlab.utils.dataset.sampler import RawPartitionSampler
7 |
8 | sys.path.append('../')
9 |
10 | from models.cnn import CNN_CIFAR10, CNN_FEMNIST, CNN_MNIST
11 | from models.rnn import RNN_Shakespeare
12 | from models.mlp import MLP_CelebA
13 | from leaf.dataloader import get_LEAF_dataloader
14 |
15 | def get_dataset(args):
16 | if args.dataset == 'mnist':
17 | root = '../datasets/mnist/'
18 | train_transform = transforms.Compose([
19 | transforms.ToTensor(),
20 | ])
21 | test_transform = transforms.Compose([
22 | transforms.ToTensor(),
23 | ])
24 | trainset = torchvision.datasets.MNIST(root=root,
25 | train=True,
26 | download=True,
27 | transform=train_transform)
28 |
29 | testset = torchvision.datasets.MNIST(root=root,
30 | train=False,
31 | download=True,
32 | transform=test_transform)
33 |
34 | trainloader = torch.utils.data.DataLoader(
35 | trainset,
36 | sampler=RawPartitionSampler(trainset,
37 | client_id=args.rank,
38 | num_replicas=args.world_size - 1),
39 | batch_size=args.batch_size,
40 | drop_last=True,
41 | num_workers=args.world_size)
42 |
43 | testloader = torch.utils.data.DataLoader(testset,
44 | batch_size=int(
45 | len(testset) / 10),
46 | drop_last=False,
47 | shuffle=False)
48 | elif args.dataset == 'femnist':
49 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset,
50 | client_id=args.rank)
51 | elif args.dataset == 'shakespeare':
52 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset,
53 | client_id=args.rank)
54 | elif args.dataset == 'celeba':
55 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset,
56 | client_id=args.rank)
57 | else:
58 | raise ValueError("Invalid dataset:", args.dataset)
59 |
60 | return trainloader, testloader
61 |
62 |
63 | def get_model(args):
64 | if args.dataset == "mnist":
65 | model = CNN_MNIST()
66 | elif args.dataset == 'femnist':
67 | model = CNN_FEMNIST()
68 | elif args.dataset == 'shakespeare':
69 | model = RNN_Shakespeare()
70 | elif args.dataset == 'celeba':
71 | model = MLP_CelebA()
72 | else:
73 | raise ValueError("Invalid dataset:", args.dataset)
74 |
75 | return model
--------------------------------------------------------------------------------
/fedlab_benchmarks/feddyn/Output/CIFAR10_100_iid_plots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/feddyn/Output/CIFAR10_100_iid_plots.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/feddyn/README.md:
--------------------------------------------------------------------------------
1 | ## FedDyn
2 |
3 | demo for FedDyn using FedLab in scale mode.
4 |
5 | ### Setting
6 |
7 | - ``dataset``: CIFAR10
8 | - `partition`: iid
9 | - `balance`: `True`
10 | - `batch_size`: 50
11 | - ``num_clients``: 100
12 | - `round`: 1000
13 | - `epochs`: 5
14 | - `lr`: 0.1
15 | - `alpha_coef`: 1e-2
16 | - `weight_decay`: 1e-3
17 | - `max_norm`: 10
18 | - `sample_ratio`: 1.0
19 |
20 | ### Requirements
21 |
22 | fedlab=1.1.2
23 |
24 | ### How to run?
25 |
26 | `start_server.sh` is for server process launch, and `start_clt.sh` is for client process launch.
27 |
28 | 1. run command in terminal window 1 to launch server:
29 |
30 | ```bash
31 | bash start_server.sh
32 | ```
33 |
34 | 2. run command in terminal window 2 to launch clients:
35 |
36 | ```bash
37 | bash start_clt.sh
38 | ```
39 |
40 | > random seed for data partiiton over clients can be set using `--seed` in `start_server.sh`:
41 | >
42 | > ```bash
43 | > python data_partition.py --out-dir ./Output/FedDyn/run1 --partition iid --balance True --dataset cifar10 --num-clients ${ClientNum} --seed 1
44 | > ```
45 |
46 |
47 |
48 | We highly recommend to launch clients after server is launched to avoid some conficts.
49 |
50 |
51 |
52 | ### One-run Result
53 |
54 | | | FedDyn (Paper) | FedDyn (FedDyn code) | FedDyn (FedLab) | FedAvg (FedDyn code) | FedAvg (FedLab code) |
55 | | ------------------------ | :------------: | :------------------: | :-------------: | :------------------: | :------------------: |
56 | | Round for $acc>81.40\%$ | 67 | 64 | 65 | 491 | 423 |
57 | | Round for $acc>85.00\%$ | 198 | 185 | 195 | > 1000 | > 1000 |
58 |
59 |
60 |
61 | ### Duration
62 |
63 | | | FedAvg | FedDyn |
64 | | -------------- | :------------: | :------------: |
65 | | FedDyn code | 474.98 Min | 537.13 Min |
66 | | FedLab (scale) | __143.60 Min__ | __253.17 Min__ |
67 |
68 | ### Environment
69 |
70 | - CPU: 128G Memory, 32 cores, Intel(R) Core(TM) i9-9960X CPU @ 3.10GHz,
71 | - 4 * NVDIA GEFORCE RTX 2080 Ti
72 |
73 |
74 |
75 | ### Reference
76 |
77 | - Acar, D. A. E., Zhao, Y., Matas, R., Mattina, M., Whatmough, P., & Saligrama, V. (2020, September). Federated learning based on dynamic regularization. In *International Conference on Learning Representations*.
78 |
79 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feddyn/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 9/27/21 12:12 AM
3 | # @Author : Siqi Liang
4 | # @Contact : zszxlsq@gmail.com
5 | # @File : __init__.py
6 | # @Software: PyCharm
7 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feddyn/config.py:
--------------------------------------------------------------------------------
1 | cifar10_config = {
2 | 'num_clients': 100,
3 | 'model_name': 'Cifar10Net', # Model type
4 | 'round': 1000,
5 | 'save_period': 200,
6 | 'weight_decay': 1e-3,
7 | 'batch_size': 50,
8 | 'test_batch_size': 256, # no this param in official code
9 | 'lr_decay_per_round': 1,
10 | 'epochs': 5,
11 | 'lr': 0.1,
12 | 'print_freq': 5,
13 | 'alpha_coef': 1e-2,
14 | 'max_norm': 10,
15 | 'sample_ratio': 1,
16 | 'partition': 'iid',
17 | 'dataset': 'cifar10',
18 | }
19 |
20 |
21 | debug_config = {
22 | 'num_clients': 30,
23 | 'model_name': 'Cifar10Net', # Model type
24 | 'round': 5,
25 | 'save_period': 2,
26 | 'weight_decay': 1e-3,
27 | 'batch_size': 50,
28 | 'test_batch_size': 50,
29 | 'act_prob': 1,
30 | 'lr_decay_per_round': 1,
31 | 'epochs': 5,
32 | 'lr': 0.1,
33 | 'print_freq': 1,
34 | 'alpha_coef': 1e-2,
35 | 'max_norm': 10,
36 | 'sample_ratio': 1,
37 | 'partition': 'iid',
38 | 'dataset': 'cifar10'
39 | }
40 |
41 | # usage: local_params_file_pattern.format(cid=cid)
42 | local_grad_vector_file_pattern = "client_{cid:03d}_local_grad_vector.pt" # accumulated model gradient
43 | clnt_params_file_pattern = "client_{cid:03d}_clnt_params.pt" # latest model param
44 |
45 | local_grad_vector_list_file_pattern = "client_rank_{rank:02d}_local_grad_vector_list.pt" # accumulated model gradient for clients in one client process
46 | clnt_params_list_file_pattern = "client_rank_{rank:02d}_clnt_params_list.pt" # latest model param for clients in one client process
47 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feddyn/data_partition.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 |
5 | from torchvision.datasets import CIFAR10, CIFAR100
6 |
7 | import sys
8 | sys.path.append("../../../FedLab/")
9 |
10 | from fedlab.utils.dataset import CIFAR10Partitioner, CIFAR100Partitioner
11 | from fedlab.utils.functional import partition_report, save_dict, load_dict
12 |
13 |
14 | def get_exp_name(args):
15 | exp_name = ""
16 | args_dict = vars(args)
17 | exclude_keys = ["out_dit", "data_dir"]
18 |
19 | for key in sorted(args_dict.keys()):
20 | exp_name += f"{key}_"
21 | if key not in exclude_keys:
22 | value = args_dict[key]
23 | if isinstance(value, float):
24 | exp_name += f"{value:.3f}_"
25 | else:
26 | exp_name += f"{value}_"
27 |
28 | return exp_name[:-1]
29 |
30 |
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser(description="FedDyn implementation: Client scale mode")
33 |
34 | parser.add_argument("--data-dir", type=str, default="../../../datasets/")
35 | parser.add_argument("--out-dir", type=str, default="./Output/")
36 | parser.add_argument("--dataset", type=str, default="cifar10",
37 | help="Currently only 'cifar10' and 'cifar100' are supported")
38 | parser.add_argument("--num-clients", type=int, default=100)
39 | parser.add_argument("--partition", type=str, default="iid")
40 | parser.add_argument("--balance", type=bool, default=None)
41 | parser.add_argument("--unbalance-sgm", type=float, default=0)
42 | parser.add_argument("--dir-alpha", type=float, default=None)
43 | parser.add_argument("--num-shards", type=int, default=None)
44 | parser.add_argument("--seed", type=int, default=0)
45 |
46 | args = parser.parse_args()
47 |
48 | Path(args.data_dir).mkdir(parents=True, exist_ok=True)
49 | Path(args.out_dir).mkdir(parents=True, exist_ok=True)
50 |
51 | if args.dataset == "cifar10":
52 | trainset = CIFAR10(root=os.path.join(args.data_dir, 'CIFAR10'), train=True, download=True)
53 | partitioner = CIFAR10Partitioner
54 | # elif args.dataset == "cifar100":
55 | # trainset = CIFAR100(root=os.path.join(args.data_dir, 'CIFAR100'), train=True, download=True)
56 | # partitioner = CIFAR100Partitoner
57 | else:
58 | raise ValueError(f"{args.dataset} is not supported yet.")
59 |
60 | partition = partitioner(targets=trainset.targets,
61 | num_clients=args.num_clients,
62 | balance=args.balance,
63 | partition=args.partition,
64 | unbalance_sgm=args.unbalance_sgm,
65 | num_shards=args.num_shards,
66 | dir_alpha=args.dir_alpha,
67 | seed=args.seed,
68 | verbose=True)
69 | file_name = f"{args.dataset}_{args.partition}.pkl" # get_exp_name(args) + ".pkl"
70 | save_dict(partition.client_dict,
71 | path=os.path.join(args.out_dir, file_name))
72 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feddyn/start_clt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # ==========================================
3 | # =============== EXPERIMENT ===============
4 | # ==========================================
5 | ClientRankNum=10
6 | ClientNumPerRank=10
7 | ClientNum=$(($ClientNumPerRank * $ClientRankNum))
8 | WorldSize=$(($ClientRankNum + 1))
9 |
10 | # ------ FedAvg
11 | # for ((i = 1; i <= ${ClientRankNum}; i++)); do
12 | # {
13 | # echo "client ${i} started"
14 | # python client_starter.py --world_size ${WorldSize} --rank ${i} --client-num-per-rank ${ClientNumPerRank} --alg FedAvg --out-dir ./Output/FedAvg/run1 &
15 | # sleep 2s
16 | # }
17 | # done
18 |
19 | # wait
20 |
21 | # ------ FedDyn
22 | for ((i = 1; i <= ${ClientRankNum}; i++)); do
23 | {
24 | echo "client ${i} started"
25 | python client_starter.py --world_size ${WorldSize} --rank ${i} --client-num-per-rank ${ClientNumPerRank} --alg FedDyn --out-dir ./Output/FedDyn/run1 &
26 | sleep 2s
27 | }
28 | done
29 |
30 | wait
31 |
32 |
33 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/feddyn/start_server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # ==========================================
3 | # =============== EXPERIMENT ===============
4 | # ==========================================
5 | ClientRankNum=10
6 | ClientNumPerRank=10
7 | ClientNum=$(($ClientNumPerRank * $ClientRankNum))
8 | WorldSize=$(($ClientRankNum + 1))
9 | # balance iid cifar10 for 100 clients, check config.py for other setting
10 | python data_partition.py --out-dir ./Output/FedDyn/run1 --partition iid --balance True --dataset cifar10 --num-clients ${ClientNum} --seed 1
11 | echo -e "Data partition DONE.\n\n"
12 | sleep 4s
13 |
14 | # # ----- FedAvg
15 | # SECONDS=0
16 |
17 | # python server_starter.py --world_size ${WorldSize} --partition iid --alg FedAvg --out-dir ./Output/FedAvg/run1
18 |
19 | # ELAPSED="Elapsed: $(($SECONDS / 3600))hrs $((($SECONDS / 60) % 60))min $(($SECONDS % 60))sec"
20 | # echo $ELAPSED
21 |
22 | # ------- FedDyn
23 | SECONDS=0
24 |
25 | python server_starter.py --world_size ${WorldSize} --partition iid --alg FedDyn --out-dir ./Output/FedDyn/run1
26 |
27 | ELAPSED="Elapsed: $(($SECONDS / 60))min $(($SECONDS % 60))sec"
28 | echo $ELAPSED
29 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedmgda+/README.md:
--------------------------------------------------------------------------------
1 | # FedMGDA+: Federated Learning meets Multi-objective Optimization
2 |
3 | An reproduction of paper ["Fedmgda+: Federated learning meets multi-objective optimization"](https://arxiv.org/abs/2006.11489) via fedlab.
4 |
5 | Thanks to GitHub repo https://github.com/WwZzz/easyFL.
6 |
7 | ## Requirements
8 |
9 | fedlab==1.2.0
10 |
11 | ## Run
12 |
13 | $ bash run.sh
14 |
15 | ## Performance
16 |
17 | Null.
18 |
19 | ## References
20 |
21 | Hu, Zeou, et al. "Fedmgda+: Federated learning meets multi-objective optimization." arXiv preprint arXiv:2006.11489 (2020).
22 |
23 |
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedmgda+/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedmgda+/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedmgda+/mnist_iid_100.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedmgda+/mnist_iid_100.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedmgda+/mnist_noniid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedmgda+/mnist_noniid.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedmgda+/mnist_noniid_200_100.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedmgda+/mnist_noniid_200_100.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedmgda+/run.sh:
--------------------------------------------------------------------------------
1 | #!bin/bash
2 |
3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 6 --dataset mnist --round 100 &
4 |
5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 1 --dataset mnist &
6 |
7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 2 --dataset mnist &
8 |
9 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 3 --dataset mnist &
10 |
11 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 4 --dataset mnist &
12 |
13 | python client.py --ip 127.0.0.1 --port 3002 --world_size 6 --rank 5 --dataset mnist &
14 |
15 |
16 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedmgda+/setting.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | import sys
5 |
6 | from fedlab.utils.dataset.sampler import RawPartitionSampler
7 |
8 | sys.path.append('../')
9 |
10 | from models.cnn import CNN_CIFAR10, CNN_FEMNIST, CNN_MNIST
11 | from models.rnn import RNN_Shakespeare
12 | from models.mlp import MLP_CelebA
13 | from leaf.dataloader import get_LEAF_dataloader
14 |
15 | def get_dataloader(args):
16 | if args.dataset == 'mnist':
17 | root = '../datasets/mnist/'
18 | trainset = torchvision.datasets.MNIST(root=root,
19 | train=True,
20 | download=True,
21 | transform=transforms.ToTensor())
22 |
23 | testset = torchvision.datasets.MNIST(root=root,
24 | train=False,
25 | download=True,
26 | transform=transforms.ToTensor())
27 |
28 | trainloader = torch.utils.data.DataLoader(
29 | trainset,
30 | sampler=RawPartitionSampler(trainset,
31 | client_id=args.rank,
32 | num_replicas=args.world_size - 1),
33 | batch_size=args.batch_size,
34 | drop_last=True,
35 | num_workers=args.world_size)
36 |
37 | testloader = torch.utils.data.DataLoader(testset,
38 | batch_size=int(
39 | len(testset) / 10),
40 | drop_last=False,
41 | shuffle=False)
42 | elif args.dataset == 'femnist':
43 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset,
44 | client_id=args.rank)
45 | elif args.dataset == 'shakespeare':
46 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset,
47 | client_id=args.rank)
48 | elif args.dataset == 'celeba':
49 | trainloader, testloader = get_LEAF_dataloader(dataset=args.dataset,
50 | client_id=args.rank)
51 | else:
52 | raise ValueError("Invalid dataset:", args.dataset)
53 |
54 | return trainloader, testloader
55 |
56 |
57 | def get_model(args):
58 | if args.dataset == "mnist":
59 | model = CNN_MNIST()
60 | elif args.dataset == 'femnist':
61 | model = CNN_FEMNIST()
62 | elif args.dataset == 'shakespeare':
63 | model = RNN_Shakespeare()
64 | elif args.dataset == 'celeba':
65 | model = MLP_CelebA()
66 | else:
67 | raise ValueError("Invalid dataset:", args.dataset)
68 |
69 | return model
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedmgda+/start_clients.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for ((i=$2; i<=$3; i++))
4 | do
5 | {
6 | echo "client ${i} started"
7 | python client.py --ip 127.0.0.1 --port 3002 --world_size $1 --rank ${i} --scale True &
8 | sleep 1s
9 | }
10 | done
11 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedprox/README.md:
--------------------------------------------------------------------------------
1 | # Fedprox
2 |
3 | [Federated optimization in heterogeneous networks](https://proceedings.mlsys.org/papers/2020/176)
4 |
5 | ## Requirements
6 |
7 | fedlab==1.1.2
8 |
9 | ## Run
10 |
11 | You are wellcome to read sourcecodes. There are similar with the demos of FedAvg.
12 |
13 | ## Performance
14 |
15 | Null
16 |
17 | ## References
18 |
19 | Li, Tian, et al. "Federated optimization in heterogeneous networks." Proceedings of Machine Learning and Systems 2 (2020): 429-450.
20 |
21 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedprox/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedprox/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedprox/cross_process/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/fedprox/cross_process/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedprox/cross_process/client.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import argparse
4 | import os
5 |
6 | from setting import get_model, get_dataset
7 | from torch import nn, optim
8 | from fedlab.core.client.manager import ClientPassiveManager
9 | from fedlab.core.network import DistNetwork
10 | from fedlab.utils.logger import Logger
11 |
12 | sys.path.append("../")
13 | from fedprox_trainer import FedProxTrainer
14 |
15 | if __name__ == "__main__":
16 |
17 | parser = argparse.ArgumentParser(description="Distbelief training example")
18 |
19 | parser.add_argument("--ip", type=str)
20 |
21 | parser.add_argument("--port", type=str)
22 |
23 | parser.add_argument("--world_size", type=int)
24 |
25 | parser.add_argument("--rank", type=int)
26 |
27 | parser.add_argument("--lr", type=float, default=0.01)
28 |
29 | parser.add_argument("--epoch", type=int, default=5)
30 |
31 | parser.add_argument("--dataset", type=str)
32 |
33 | parser.add_argument("--batch_size", type=int, default=100)
34 |
35 | parser.add_argument("--gpu", type=str, default="0,1,2,3")
36 |
37 | parser.add_argument("--ethernet", type=str, default=None)
38 |
39 | parser.add_argument(
40 | "--straggler", type=float, default=0.0
41 | ) # vaild value should be in range [0, 1] and mod 0.1 == 0
42 |
43 | parser.add_argument(
44 | "--optimizer", type=str, default="sgd"
45 | ) # valid value: {"sgd", "adam", "rmsprop"}
46 |
47 | parser.add_argument(
48 | "--mu", type=float, default=0.0
49 | ) # recommended value: {0.001, 0.01, 0.1, 1.0}
50 |
51 | args = parser.parse_args()
52 |
53 | if args.gpu != "-1":
54 | args.cuda = True
55 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
56 | device = torch.device(args.gpu)
57 | else:
58 | args.cuda = False
59 | device = torch.device("cpu")
60 |
61 | model = get_model(args).to(device)
62 | trainloader, testloader = get_dataset(args)
63 | optimizer_dict = dict(sgd=optim.SGD, adam=optim.Adam, rmsprop=optim.RMSprop)
64 | optimizer = optimizer_dict[args.optimizer](model.parameters(), lr=args.lr)
65 | criterion = nn.CrossEntropyLoss()
66 |
67 | network = DistNetwork(
68 | address=(args.ip, args.port),
69 | world_size=args.world_size,
70 | rank=args.rank,
71 | ethernet=args.ethernet,
72 | )
73 |
74 | LOGGER = Logger(log_name="client " + str(args.rank))
75 |
76 | trainer = FedProxTrainer(
77 | model=model,
78 | data_loader=trainloader,
79 | epochs=args.epoch,
80 | optimizer=optimizer,
81 | criterion=criterion,
82 | mu=args.mu,
83 | cuda=args.cuda,
84 | logger=LOGGER,
85 | )
86 |
87 | manager_ = ClientPassiveManager(trainer=trainer, network=network, logger=LOGGER)
88 | manager_.run()
89 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedprox/cross_process/server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from fedlab.utils.logger import Logger
4 | from fedlab.core.server.handler import SyncParameterServerHandler
5 | from fedlab.core.server.manager import ServerSynchronousManager
6 | from fedlab.core.network import DistNetwork
7 | from setting import get_model, get_dataset
8 |
9 | if __name__ == "__main__":
10 | parser = argparse.ArgumentParser(description="FL server example")
11 |
12 | parser.add_argument("--ip", type=str)
13 |
14 | parser.add_argument("--port", type=str)
15 |
16 | parser.add_argument("--world_size", type=int)
17 |
18 | parser.add_argument("--round", type=int, default=5)
19 |
20 | parser.add_argument("--dataset", type=str)
21 |
22 | parser.add_argument("--ethernet", type=str, default=None)
23 |
24 | parser.add_argument("--sample", type=float, default=1)
25 |
26 | args = parser.parse_args()
27 |
28 | model = get_model(args)
29 | LOGGER = Logger(log_name="server")
30 | handler = SyncParameterServerHandler(model,
31 | global_round=args.round,
32 | logger=LOGGER,
33 | sample_ratio=args.sample)
34 | network = DistNetwork(
35 | address=(args.ip, args.port),
36 | world_size=args.world_size,
37 | rank=0,
38 | ethernet=args.ethernet,
39 | )
40 | manager_ = ServerSynchronousManager(handler=handler,
41 | network=network,
42 | logger=LOGGER)
43 | manager_.run()
44 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedprox/cross_process/setting.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../../")
3 |
4 | import argparse
5 | from models.cnn import CNN_CIFAR10, CNN_FEMNIST, CNN_MNIST
6 | from models.rnn import RNN_Shakespeare
7 | from models.mlp import MLP_CelebA
8 | from leaf.dataloader import get_LEAF_dataloader
9 | from torchvision import transforms, datasets
10 | from torch.utils.data import DataLoader
11 | from fedlab.utils.dataset.sampler import RawPartitionSampler
12 |
13 |
14 | def get_model(args):
15 | if args.dataset == "mnist":
16 | model = CNN_MNIST()
17 | elif args.dataset == "femnist":
18 | model = CNN_FEMNIST()
19 | elif args.dataset == "shakespeare":
20 | model = RNN_Shakespeare()
21 | elif args.dataset == "celeba":
22 | model = MLP_CelebA()
23 | else:
24 | raise ValueError("Invalid dataset:", args.dataset)
25 |
26 | return model
27 |
28 |
29 | def get_dataset(args):
30 | if args.dataset == "mnist":
31 | root = "../../datasets/mnist/"
32 | train_transform = transforms.Compose([transforms.ToTensor(),])
33 | test_transform = transforms.Compose([transforms.ToTensor(),])
34 | trainset = datasets.MNIST(
35 | root=root, train=True, download=True, transform=train_transform
36 | )
37 |
38 | testset = datasets.MNIST(
39 | root=root, train=False, download=True, transform=test_transform
40 | )
41 |
42 | trainloader = DataLoader(
43 | trainset,
44 | sampler=RawPartitionSampler(
45 | trainset, client_id=args.rank, num_replicas=args.world_size - 1
46 | ),
47 | batch_size=args.batch_size,
48 | drop_last=True,
49 | num_workers=args.world_size,
50 | )
51 |
52 | testloader = DataLoader(
53 | testset, batch_size=int(len(testset) / 10), drop_last=False, shuffle=False
54 | )
55 | elif args.dataset == "femnist":
56 | trainloader, testloader = get_LEAF_dataloader(
57 | dataset=args.dataset, client_id=args.rank
58 | )
59 | elif args.dataset == "shakespeare":
60 | trainloader, testloader = get_LEAF_dataloader(
61 | dataset=args.dataset, client_id=args.rank
62 | )
63 | elif args.dataset == "celeba":
64 | trainloader, testloader = get_LEAF_dataloader(
65 | dataset=args.dataset, client_id=args.rank
66 | )
67 | else:
68 | raise ValueError("Invalid dataset:", args.dataset)
69 |
70 | return trainloader, testloader
71 |
72 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/fedprox/fedprox_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 |
4 | from fedlab.core.client import ClientSGDTrainer
5 | from fedlab.utils.serialization import SerializationTool
6 | from fedlab.utils import Logger
7 | from tqdm import tqdm
8 |
9 |
10 | class FedProxTrainer(ClientSGDTrainer):
11 | """FedProxTrainer.
12 |
13 | Details of FedProx are available in paper: https://arxiv.org/abs/1812.06127
14 |
15 | Args:
16 | model (torch.nn.Module): PyTorch model.
17 | data_loader (torch.utils.data.DataLoader): :class:`torch.utils.data.DataLoader` for this client.
18 | epochs (int): the number of local epoch.
19 | optimizer (torch.optim.Optimizer, optional): optimizer for this client's model.
20 | criterion (torch.nn.Loss, optional): loss function used in local training process.
21 | cuda (bool, optional): use GPUs or not. Default: ``True``.
22 | logger (Logger, optional): :object of :class:`Logger`.
23 | mu (float): hyper-parameter of FedProx.
24 | """
25 |
26 | def __init__(
27 | self,
28 | model,
29 | data_loader,
30 | epochs,
31 | optimizer,
32 | criterion,
33 | mu,
34 | cuda=True,
35 | logger=Logger(),
36 | ):
37 | super().__init__(
38 | model, data_loader, epochs, optimizer, criterion, cuda=cuda, logger=logger
39 | )
40 |
41 | self.mu = mu
42 |
43 | def train(self, model_parameters):
44 | """Client trains its local model on local dataset.
45 |
46 | Args:
47 | model_parameters (torch.Tensor): Serialized model parameters.
48 | """
49 | frz_model = deepcopy(self._model)
50 | SerializationTool.deserialize_model(frz_model, model_parameters)
51 | SerializationTool.deserialize_model(
52 | self._model, model_parameters
53 | ) # load parameters
54 | self._LOGGER.info("Local train procedure is running")
55 | for ep in range(self.epochs):
56 | self._model.train()
57 | for inputs, labels in tqdm(
58 | self._data_loader, desc="{}, Epoch {}".format(self._LOGGER.name, ep)
59 | ):
60 | if self.cuda:
61 | inputs, labels = inputs.cuda(self.gpu), labels.cuda(self.gpu)
62 |
63 | outputs = self._model(inputs)
64 | l1 = self.criterion(outputs, labels)
65 | l2 = 0.0
66 |
67 | for w0, w in zip(frz_model.parameters(), self._model.parameters()):
68 | l2 += torch.sum(torch.pow(w - w0, 2))
69 |
70 | loss = l1 + 0.5 * self.mu * l2
71 |
72 | self.optimizer.zero_grad()
73 | loss.backward()
74 | self.optimizer.step()
75 | self._LOGGER.info("Local train procedure is finished")
76 |
77 | return self.model_parameters
78 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/README_tmp.md:
--------------------------------------------------------------------------------
1 | # PROCESS_DATA README
2 |
3 | This folders contains processed dataset pickle files for leaf datasets.
4 |
5 | You can run `create_pickle_dataset.py` script for processed leaf data
6 |
7 | Notice:
8 | 1. please make sure leaf dataset is downloaded and processed by leaf. (leaf code in `fedlab_benchmarks/datasets/data`)
9 | 2. please make sure `fedlab_benchmarks/datasets/data/{dataset_name}/{train,test}` path existing for train data and test data.
10 | 3. example script:
11 | `python create_pickle_dataset.py --data_root "../../datasets/data" --save_root "pickle_dataset" --dataset_name "shakespeare"`
12 | 3. usage example:
13 | ```
14 | pdataset = PickleDataset(pickle_root="pickle_datasets", dataset_name="shakespeare")
15 | pdataset.create_pickle_dataset(data_root="../datasets")
16 | dataset = pdataset.get_dataset_pickle(dataset_type="test", client_id="2")
17 | ```
18 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/leaf/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # from .femnist_dataset import FemnistDataset
3 | # from .shakespeare_dataset import ShakespeareDataset
4 | # from .celeba_dataset import CelebADataset
5 | # from sent140_dataset import Sent140Dataset
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/dataset/celeba_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | from PIL import Image
17 | from pathlib import Path
18 | from torch.utils.data import Dataset
19 |
20 |
21 | class CelebADataset(Dataset):
22 | def __init__(self, client_id: int, client_str: str, data: list, targets: list, image_root: str, transform=None):
23 | """get `Dataset` for CelebA dataset
24 |
25 | Args:
26 | client_id (int): client id
27 | client_str (str): client name string
28 | data (list): input image name list data
29 | targets (list): output label list
30 | """
31 | self.client_id = client_id
32 | self.client_str = client_str
33 | self.image_root = Path(__file__).parent.resolve() / image_root
34 | self.transform = transform
35 | self.data = data
36 | self.targets = targets
37 | self._process_data_target()
38 |
39 | def _process_data_target(self):
40 | """process client's data and target
41 | """
42 | data = []
43 | targets = []
44 | for idx in range(len(self.data)):
45 | image_path = self.image_root / self.data[idx]
46 | image = Image.open(image_path).convert('RGB')
47 | data.append(image)
48 | targets.append(torch.tensor(self.targets[idx], dtype=torch.long))
49 | self.data = data
50 | self.targets = targets
51 |
52 | def __len__(self):
53 | return len(self.targets)
54 |
55 | def __getitem__(self, index):
56 | data = self.data[index]
57 | if self.transform:
58 | data = self.transform(data)
59 | target = self.targets[index]
60 | return data, target
61 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/dataset/femnist_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import torch
17 | from torch.utils.data import Dataset
18 |
19 |
20 | class FemnistDataset(Dataset):
21 |
22 | def __init__(self, client_id: int, client_str: str, data: list, targets: list):
23 | """get `Dataset` for femnist dataset
24 |
25 | Args:
26 | client_id (int): client id
27 | client_str (str): client name string
28 | data (list): image data list
29 | targets (list): image class target list
30 | """
31 | self.client_id = client_id
32 | self.client_str = client_str
33 | self.data = data
34 | self.targets = targets
35 | self._process_data_target()
36 |
37 | def _process_data_target(self):
38 | """process client's data and target
39 |
40 | """
41 | self.data = torch.tensor(self.data, dtype=torch.float32).reshape(-1, 1, 28, 28)
42 | self.targets = torch.tensor(self.targets, dtype=torch.long)
43 |
44 | def __len__(self):
45 | return len(self.targets)
46 |
47 | def __getitem__(self, index):
48 | return self.data[index], self.targets[index]
49 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/dataset/reddit_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/dataset/shakespeare_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import torch
17 | from torch.utils.data import Dataset
18 |
19 |
20 | class ShakespeareDataset(Dataset):
21 |
22 | def __init__(self, client_id: int, client_str: str, data: list, targets: list):
23 | """get `Dataset` for shakespeare dataset
24 |
25 | Args:
26 | client_id (int): client id
27 | client_str (str): client name string
28 | data (list): sentence list data
29 | targets (list): next-character target list
30 | """
31 | self.client_id = client_id
32 | self.client_str = client_str
33 | self.ALL_LETTERS, self.VOCAB_SIZE = self._build_vocab()
34 | self.data = data
35 | self.targets = targets
36 | self._process_data_target()
37 |
38 | def _build_vocab(self):
39 | """ according all letters to build vocab
40 |
41 | Vocabulary re-used from the Federated Learning for Text Generation tutorial.
42 | https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
43 |
44 | Returns:
45 | all letters vocabulary list and length of vocab list
46 | """
47 | ALL_LETTERS = "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz}"
48 | VOCAB_SIZE = len(ALL_LETTERS)
49 | return ALL_LETTERS, VOCAB_SIZE
50 |
51 | def _process_data_target(self):
52 | """process client's data and target
53 | """
54 | self.data = torch.tensor(
55 | [self.__sentence_to_indices(sentence) for sentence in self.data])
56 | self.targets = torch.tensor(
57 | [self.__letter_to_index(letter) for letter in self.targets])
58 |
59 | def __sentence_to_indices(self, sentence: str):
60 | """Returns list of integer for character indices in ALL_LETTERS
61 |
62 | Args:
63 | sentence (str): input sentence
64 |
65 | Returns: a integer list of character indices
66 | """
67 | indices = []
68 | for c in sentence:
69 | indices.append(self.ALL_LETTERS.find(c))
70 | return indices
71 |
72 | def __letter_to_index(self, letter: str):
73 | """Returns index in ALL_LETTERS of given letter
74 |
75 | Args:
76 | letter (char/str[0]): input letter
77 |
78 | Returns: int index of input letter
79 | """
80 | index = self.ALL_LETTERS.find(letter)
81 | return index
82 |
83 | def __len__(self):
84 | return len(self.targets)
85 |
86 | def __getitem__(self, index):
87 | return self.data[index], self.targets[index]
88 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/gen_pickle_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Example: bash gen_pickle_dataset.sh "shakespeare" "../datasets" "./pickle_datasets"
4 | # Example: bash gen_pickle_dataset.sh "sent140" "../datasets" "./pickle_datasets" 1
5 | dataset=$1
6 | data_root=${2:-'../datasets'}
7 | pickle_root=${3:-'./pickle_datasets'}
8 | # for nlp datasets
9 | build_vocab=${4:-'0'}
10 | vocab_save_root=${5:-'./nlp_utils/dataset_vocab'}
11 | vector_save_root=${6:-'./nlp_utils/glove'}
12 | vocab_limit_size=${7:-'50000'}
13 |
14 |
15 | python pickle_dataset.py \
16 | --dataset ${dataset} \
17 | --data_root ${data_root} \
18 | --pickle_root ${pickle_root} \
19 | --build_vocab ${build_vocab} \
20 | --vocab_save_root ${vocab_save_root} \
21 | --vector_save_root ${vector_save_root} \
22 | --vocab_limit_size ${vocab_limit_size}
23 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/nlp_utils/README.md:
--------------------------------------------------------------------------------
1 | # NLP——UTILS
2 |
3 | This folder contains some lightweight utils for nlp process.
4 |
5 | - get_glove.sh: glove download script, from http://nlp.stanford.edu/data/glove.6B.zip to get glove.6B.300d.txt.
6 |
7 | - build_vocab.sh: provide a way to sample some clients' train data for building a vocabulary for federated nlp tasks,
8 | which is a simple alternative method compared with use all data directly in other implementation.
9 |
10 | - sample_build_vocab.py: provide a way to sample some clients' train data for building a vocabulary for federated nlp tasks.
11 |
12 | - tokenizer.py: provide `class Tokenizer`, splitting an entire text into smaller units called tokens, such as individual words or terms.
13 |
14 | - vocab.py: provide `class Vocab`, to encapsulate vocabulary operations in nlp,
15 | such as getting word2idx for tokenized input data, and get vector list from pretrained word_vec_file, such as glove.6B.300d.txt.
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/nlp_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/leaf/nlp_utils/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/leaf/nlp_utils/download_glove.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This is modified by [LEAF/models/sent140/get_embs.sh]
3 | # https://github.com/TalwalkarLab/leaf/blob/master/models/sent140/get_embs.sh
4 |
5 | if [ ! -f 'glove.6B.300d.txt' ]; then
6 | wget http://nlp.stanford.edu/data/glove.6B.zip
7 | unzip glove.6B.zip
8 | rm glove.6B.50d.txt glove.6B.100d.txt glove.6B.200d.txt glove.6B.zip
9 |
10 | if [ ! -d ./glove ];then
11 | mkdir glove
12 | fi
13 | mv glove.6B.300d.txt ./glove
14 | echo "download glove.6B.300d.txt successfully"
15 | fi
--------------------------------------------------------------------------------
/fedlab_benchmarks/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/models/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/models/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch.nn as nn
16 |
17 |
18 | class MLP_CelebA(nn.Module):
19 | """Used for celeba experiment"""
20 |
21 | def __init__(self):
22 | super(MLP_CelebA, self).__init__()
23 | self.fc1 = nn.Linear(12288, 2048) # image_size=64, 64*64*3
24 | self.relu1 = nn.ReLU()
25 | self.fc2 = nn.Linear(2048, 500)
26 | self.relu2 = nn.ReLU()
27 | self.fc3 = nn.Linear(500, 100)
28 | self.relu3 = nn.ReLU()
29 | self.fc4 = nn.Linear(100, 2)
30 |
31 | def forward(self, x):
32 | x = x.view(x.shape[0], -1)
33 | x = self.relu1(self.fc1(x))
34 | x = self.relu2(self.fc2(x))
35 | x = self.relu3(self.fc3(x))
36 | x = self.fc4(x)
37 | return x
38 |
39 |
40 | class MLP(nn.Module):
41 | def __init__(self, input_size, output_size):
42 | super(MLP, self).__init__()
43 | self.fc1 = nn.Linear(input_size, 200)
44 | self.fc2 = nn.Linear(200, 200)
45 | self.fc3 = nn.Linear(200, output_size)
46 | self.relu = nn.ReLU()
47 |
48 | def forward(self, x):
49 | x = x.view(x.shape[0], -1)
50 | x = self.relu(self.fc1(x))
51 | x = self.relu(self.fc2(x))
52 | x = self.fc3(x)
53 | return x
--------------------------------------------------------------------------------
/fedlab_benchmarks/perfedavg/README.md:
--------------------------------------------------------------------------------
1 | # Personalized-FedAvg
2 |
3 | Personalized-FedAvg: [Improving Federated Learning Personalization via Model Agnostic Meta Learning](https://arxiv.org/abs/1909.12488)
4 |
5 | ## Further reading
6 |
7 | - MAML: [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400)
8 | - Reptile: [On First-Order Meta-Learning Algorithms](https://arxiv.org/abs/1803.02999)
9 | - LEAF: [LEAF: A Benchmark for Federated Settings](https://arxiv.org/abs/1812.01097)
10 | ## Dataset
11 |
12 | - Experiments are run on FEMNIST(derived from [LEAF](https://github.com/TalwalkarLab/leaf))
13 |
14 | - Download and preprocess FEMNIST:
15 |
16 | - `cd leaf/datasets/femnist ; sh ./preprocess.sh -s niid --iu 1.0 --sf 1.0 -k 0 -t user --tf 0.8`
17 |
18 | You can get more details about this command in `/leaf/datasets/femnist/README.md`.
19 |
20 | ☝ Make sure you have download and preprocess FEMNIST before running the experiment.
21 |
22 |
23 |
24 | ## Requirements
25 |
26 | fedlab==1.1.2
27 |
28 | ## Run
29 |
30 | There're two way to run experiment in **Linux**. I have already set all hyper parameters well according to paper. Of course those can be modified. You can check `utils.get_args()` for more details about all hyper parameters.
31 |
32 | ### Single-process
33 |
34 | ```python
35 | python single_process.py
36 | ```
37 |
38 | ### Multi-process (needs more computational power)
39 |
40 | I have set 3 workers(process) to handle all training tasks.
41 |
42 | ```python
43 | cd multi_process/ ; sh quick_start.sh
44 | ```
45 |
46 | ## Performance
47 |
48 | Evaluation result after fine-tuned is shown below.
49 |
50 | Communication round: `500`
51 |
52 | Fine-tune: outer loop: `100`; inner loop: `10`
53 |
54 | Personalization round: `5`
55 |
56 | | FedAvg local training epochs (5 clients) | Initial loss | Initial Acc | Personalized loss | Personalized Acc |
57 | | ---------------------------------------- | ------------ | ----------- | ----------------- | ---------------- |
58 | | 20 | 2.3022 | 79.35% | 1.5766 | 84.86% |
59 | | 10 | 1.8387 | 80.53% | 1.1231 | 87.22% |
60 | | 5 | 1.4899 | **83.19%** | 0.9809 | **88.97%** |
61 | | 2 | 1.4613 | 81.70% | 0.9359 | 88.49% |
62 |
63 | | FedAvg local training epochs (20 clients) | Initial loss | Initial Acc | Personalized loss | Personalized Acc |
64 | | ----------------------------------------- | ------------ | ----------- | ----------------- | ---------------- |
65 | | 20 | 2.2398 | 82.40% | 0.9756 | 90.29% |
66 | | 10 | 1.6560 | **83.23**% | 0.8488 | 90.72% |
67 | | 5 | 1.5485 | 81.48% | 0.7452 | **90.77**% |
68 | | 2 | 1.2707 | 82.48% | 0.7139 | 90.48% |
69 |
70 | Experiment result from [paper](https://arxiv.org/abs/1909.12488) is shown below
71 |
72 | 
73 |
74 | I ascribe the gap between mine and paper's results to the difference of hyper parameters setting, and I think there is no big mistake in algorithm implementation. If it actually has, please open an issue at this repo or [PerFedAvg](https://github.com/KarhouTam/PerFedAvg). 🙏
75 |
76 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/perfedavg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/perfedavg/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/perfedavg/fine_tuner.py:
--------------------------------------------------------------------------------
1 | from sys import path
2 |
3 | path.append("../")
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from fedlab.utils.serialization import SerializationTool
8 | from fedlab.utils.logger import Logger
9 | from fedlab.utils.functional import get_best_gpu
10 | from fedlab.core.client.trainer import ClientTrainer
11 | from utils import get_optimizer
12 | from tqdm import trange
13 | from leaf.pickle_dataset import PickleDataset
14 |
15 |
16 | class LocalFineTuner(ClientTrainer):
17 | """
18 | Args:
19 | model (torch.nn.Module): Global model's architecture
20 | optimizer_type (str): Local optimizer.
21 | optimizer_args (dict): Provides necessary args for build local optimizer.
22 | criterion (torch.nn.CrossEntropyLoss / torch.nn.MSELoss()): Local loss function.
23 | epochs (int): Num of local training epoch. Personalization's local epochs may differ from others.
24 | batch_size (int): Batch size of local training.
25 | cuda (bool): True for using GPUs.
26 | logger (fedlab.utils.Logger): Object of Logger.
27 | """
28 |
29 | def __init__(
30 | self,
31 | model,
32 | optimizer_type,
33 | optimizer_args,
34 | criterion,
35 | epochs,
36 | batch_size,
37 | cuda,
38 | logger=Logger(),
39 | ):
40 | super(LocalFineTuner, self).__init__(model, cuda)
41 | if torch.cuda.is_available() and cuda:
42 | self.device = get_best_gpu()
43 | else:
44 | self.device = torch.device("cpu")
45 | self.epochs = epochs
46 | self._criterion = criterion
47 | self._optimizer = get_optimizer(self._model, optimizer_type, optimizer_args)
48 | self._logger = logger
49 | self.batch_size = batch_size
50 | self.dataset = PickleDataset("femnist")
51 |
52 | def train(self, client_id, model_parameters):
53 | trainloader = DataLoader(
54 | self.dataset.get_dataset_pickle("train", client_id), self.batch_size
55 | )
56 | SerializationTool.deserialize_model(self._model, model_parameters)
57 | gradients = []
58 | for param in self._model.parameters():
59 | gradients.append(
60 | torch.zeros(param.size(), requires_grad=True, device=param.device)
61 | )
62 | for _ in trange(self.epochs, desc="client [{}]".format(client_id)):
63 |
64 | for x, y in trainloader:
65 | x, y = x.to(self.device), y.to(self.device)
66 |
67 | logit = self._model(x)
68 | loss = self._criterion(logit, y)
69 |
70 | self._optimizer.zero_grad()
71 | loss.backward()
72 | self._optimizer.step()
73 |
74 | for idx, param in enumerate(self._model.parameters()):
75 | gradients[idx].data.add_(param.grad.data)
76 | return gradients
77 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/perfedavg/image/paper_exp_res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/perfedavg/image/paper_exp_res.png
--------------------------------------------------------------------------------
/fedlab_benchmarks/perfedavg/models.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | # Model's architecture refers to https://github.com/tensorflow/federated/blob/main/tensorflow_federated/python/simulation/models/mnist.py
4 | class EmnistCNN(nn.Module):
5 | def __init__(self):
6 | super(EmnistCNN, self).__init__()
7 | self.net = nn.Sequential(
8 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
9 | nn.ReLU(inplace=True),
10 | nn.MaxPool2d(2, stride=2, padding=0),
11 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
12 | nn.ReLU(inplace=True),
13 | nn.MaxPool2d(2, stride=2, padding=0),
14 | nn.Flatten(),
15 | nn.Linear(7 * 7 * 64, out_features=1024),
16 | nn.ReLU(inplace=True),
17 | nn.Dropout(0.6),
18 | nn.Linear(in_features=1024, out_features=62),
19 | )
20 |
21 | def forward(self, x):
22 | return self.net(x)
23 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/perfedavg/multi_process/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/perfedavg/multi_process/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/perfedavg/multi_process/client.py:
--------------------------------------------------------------------------------
1 | from sys import path
2 |
3 | path.append("../")
4 | path.append("../../")
5 |
6 | import argparse
7 | from torch import nn
8 | from fedlab.core.network import DistNetwork
9 | from fedlab.utils.logger import Logger
10 | from client_manager import PerFedAvgClientManager
11 | from trainer import PerFedAvgTrainer
12 | from fine_tuner import LocalFineTuner
13 | from models import EmnistCNN
14 | from utils import get_args
15 |
16 | if __name__ == "__main__":
17 | parser = argparse.ArgumentParser()
18 | args = get_args(parser)
19 | model = EmnistCNN()
20 | criterion = nn.CrossEntropyLoss()
21 | network = DistNetwork(
22 | address=(args.ip, args.port),
23 | world_size=args.world_size,
24 | rank=args.rank,
25 | ethernet=args.ethernet,
26 | )
27 |
28 | LOGGER = Logger(log_name="client process " + str(args.rank))
29 |
30 | perfedavg_trainer = PerFedAvgTrainer(
31 | model=model,
32 | optimizer_type="sgd",
33 | optimizer_args=dict(lr=args.local_lr),
34 | criterion=criterion,
35 | epochs=args.inner_loops,
36 | batch_size=args.batch_size,
37 | pers_round=args.pers_round,
38 | cuda=args.cuda,
39 | logger=Logger(log_name="node {}".format(args.rank)),
40 | )
41 |
42 | finetuner = LocalFineTuner(
43 | model=model,
44 | optimizer_type="adam",
45 | optimizer_args=dict(lr=args.fine_tune_local_lr, betas=(0, 0.999)),
46 | criterion=criterion,
47 | epochs=args.fine_tune_inner_loops,
48 | batch_size=args.batch_size,
49 | cuda=args.cuda,
50 | logger=Logger(log_name="node {}".format(args.rank)),
51 | )
52 |
53 | manager_ = PerFedAvgClientManager(
54 | network=network, fedavg_trainer=perfedavg_trainer, fine_tuner=finetuner,
55 | )
56 | manager_.run()
57 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/perfedavg/multi_process/quick_start.sh:
--------------------------------------------------------------------------------
1 | #!bin/bash
2 | python server.py --ip 127.0.0.1 --port 3002 --world_size 4 --rank 0 &
3 |
4 | python client.py --ip 127.0.0.1 --port 3002 --world_size 4 --rank 1 &
5 |
6 | python client.py --ip 127.0.0.1 --port 3002 --world_size 4 --rank 2 &
7 |
8 | python client.py --ip 127.0.0.1 --port 3002 --world_size 4 --rank 3 &
9 |
10 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/perfedavg/multi_process/server.py:
--------------------------------------------------------------------------------
1 | from sys import path
2 |
3 | path.append("../")
4 |
5 | import argparse
6 | from handler import PersonalizaitonHandler, FedAvgHandler, FineTuneHandler
7 | from fedlab.core.network import DistNetwork
8 | from fedlab.utils.logger import Logger
9 | from server_manager import PerFedAvgSyncServerManager
10 | from utils import get_args
11 | from models import EmnistCNN
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 | args = get_args(parser)
16 |
17 | model = EmnistCNN()
18 | fedavg_handler = FedAvgHandler(
19 | model=model,
20 | global_round=args.epochs,
21 | client_num_in_total=int(0.8 * args.client_num_in_total),
22 | client_num_per_round=args.client_num_per_round,
23 | optimizer_type="momentum_sgd",
24 | optimizer_args=dict(lr=args.server_lr, momentum=0.9),
25 | cuda=args.cuda,
26 | logger=Logger(log_name="fedavg"),
27 | )
28 |
29 | finetune_handler = (
30 | FineTuneHandler(
31 | model=model,
32 | global_round=args.fine_tune_outer_loops,
33 | client_num_in_total=int(0.8 * args.client_num_in_total),
34 | client_num_per_round=args.client_num_per_round,
35 | optimizer_type="sgd",
36 | optimizer_args=dict(lr=args.fine_tune_server_lr),
37 | cuda=args.cuda,
38 | logger=Logger(log_name="fine-tune"),
39 | )
40 | if args.fine_tune
41 | else None
42 | )
43 |
44 | personalization_handler = PersonalizaitonHandler(
45 | model=model,
46 | global_round=args.test_round,
47 | client_num_in_total=args.client_num_in_total
48 | - int(0.8 * args.client_num_in_total),
49 | client_num_per_round=args.client_num_per_round,
50 | cuda=args.cuda,
51 | logger=Logger(log_name="personalization"),
52 | )
53 | network = DistNetwork(
54 | address=(args.ip, args.port),
55 | world_size=args.world_size,
56 | rank=0,
57 | ethernet=args.ethernet,
58 | )
59 | manager_ = PerFedAvgSyncServerManager(
60 | network=network,
61 | fedavg_handler=fedavg_handler,
62 | finetune_handler=finetune_handler,
63 | personalization_handler=personalization_handler,
64 | logger=Logger(log_name="manager_server"),
65 | )
66 |
67 | manager_.run()
68 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/qfedavg/README.md:
--------------------------------------------------------------------------------
1 | # qFedAvg
2 |
3 | An reproduction of paper ["Fair resource allocation in federated learning"](https://arxiv.org/abs/1905.10497) via fedlab.
4 |
5 | ## Requirements
6 |
7 | fedlab==1.2.0
8 |
9 | ## Run
10 |
11 | $ bash run.sh
12 |
13 | ## Performance
14 |
15 | Null.
16 |
17 | ## References
18 |
19 | Li, Tian, et al. "Fair resource allocation in federated learning." arXiv preprint arXiv:1905.10497 (2019).
20 |
--------------------------------------------------------------------------------
/fedlab_benchmarks/qfedavg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/qfedavg/__init__.py
--------------------------------------------------------------------------------
/fedlab_benchmarks/qfedavg/mnist_iid_10.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/qfedavg/mnist_iid_10.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/qfedavg/mnist_noniid_200_10.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SMILELab-FL/FedLab-benchmarks/9b7575a769712014865b1f99962cabce03642245/fedlab_benchmarks/qfedavg/mnist_noniid_200_10.pkl
--------------------------------------------------------------------------------
/fedlab_benchmarks/qfedavg/run.sh:
--------------------------------------------------------------------------------
1 | #!bin/bash
2 |
3 | python server.py --ip 127.0.0.1 --port 3002 --world_size 11 --dataset mnist &
4 |
5 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 1 --dataset mnist &
6 |
7 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 2 --dataset mnist &
8 |
9 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 3 --dataset mnist &
10 |
11 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 4 --dataset mnist &
12 |
13 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 5 --dataset mnist &
14 |
15 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 6 --dataset mnist &
16 |
17 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 7 --dataset mnist &
18 |
19 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 8 --dataset mnist &
20 |
21 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 9 --dataset mnist &
22 |
23 | python client.py --ip 127.0.0.1 --port 3002 --world_size 11 --rank 10 --dataset mnist &
24 |
25 | wait
--------------------------------------------------------------------------------
/fedlab_benchmarks/report_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Algorithm
2 |
3 | Algorithm and introductions.
4 |
5 | ## Requirements
6 |
7 | Clarify requirements, especially the version of fedlab (our version is updated frequently).
8 |
9 | ## Run
10 |
11 | Explain how to start your codes.
12 |
13 | ## Performance
14 |
15 | Please report the performance.
16 |
17 | ## References
18 |
19 | References here
20 |
21 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fedlab
2 | scikit-learn
3 | spacy
4 | matplotlib
5 | scipy
--------------------------------------------------------------------------------