├── .gitignore
├── LICENSE
├── README.md
├── check_local.py
├── config.py
├── config.yml
├── data.py
├── datasets
│   ├── __init__.py
│   ├── cifar.py
│   ├── folder.py
│   ├── gld.py
│   ├── imagenet.py
│   ├── lm.py
│   ├── mnist.py
│   ├── omniglot.py
│   ├── transforms.py
│   └── utils.py
├── figures
│   ├── fedrolex_overview.png
│   ├── table_overview.png
│   └── video_placeholder.png
├── logger.py
├── main_resnet.py
├── main_transformer.py
├── metrics
│   ├── __init__.py
│   └── metrics.py
├── models
│   ├── conv.py
│   ├── resnet.py
│   ├── transformer.py
│   ├── transformer_nwp.py
│   └── utils.py
├── modules
│   ├── __init__.py
│   └── modules.py
├── requirements.txt
├── resnet_client.py
├── resnet_server.py
├── resnet_test.py
├── test_classifier.py
├── train_classifier.py
├── transformer_client.py
├── transformer_server.py
├── utils.py
└── weighted_server.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | .idea/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction
2 |
3 | Code for paper:
4 | > [FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction](https://openreview.net/forum?id=OtxyysUdBE)\
5 | > Samiul Alam, Luyang Liu, Ming Yan, and Mi Zhang.\
6 | > _NeurIPS 2022_.
7 |
8 | The repository is built upon [HeteroFL](https://github.com/dem123456789/HeteroFL-Computation-and-Communication-Efficient-Federated-Learning-for-Heterogeneous-Clients).
9 |
10 | # Overview
11 |
12 | Most cross-device federated learning studies focus on the model-homogeneous setting where the global server model and local client models are identical. However, such a constraint not only excludes low-end clients that could otherwise make unique contributions to model training, but also prevents clients from training large models due to on-device resource bottlenecks. We propose `FedRolex`, a partial training-based approach that enables model-heterogeneous FL and can train a global server model larger than the largest client model.
13 |
14 | ![FedRolex overview](figures/fedrolex_overview.png)
15 |
18 | The key difference between `FedRolex` and existing partial training-based methods is how the sub-models are extracted for each client over communication rounds in the federated training process. Specifically, instead of extracting sub-models in either random or static manner, `FedRolex` proposes a rolling sub-model extraction scheme, where the sub-model is extracted from the global server model using a rolling window that advances in each communication round. Since the window is rolling, sub-models from different parts of the global model are extracted in sequence in different rounds. As a result, all the parameters of the global server model are evenly trained over the local data of client devices.
19 |
20 | ![Table overview](figures/table_overview.png)
21 |
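To make the rolling scheme concrete, here is a minimal sketch for a single layer. It is our own illustration under assumed names (`rolling_submodel_indices`, an advance of one index per round), not the repository's actual extraction code (see, e.g., `resnet_server.py`):

```python
import numpy as np


def rolling_submodel_indices(num_channels, capacity, round_idx):
    """Channel indices a client of the given capacity trains in one round.

    A window of floor(capacity * num_channels) indices starts at
    (round_idx mod num_channels) and wraps around, so as rounds advance
    every channel of the global model is trained equally often.
    """
    width = int(capacity * num_channels)
    start = round_idx % num_channels
    return np.arange(start, start + width) % num_channels


# A half-capacity client on a 16-channel layer sees a different window each round:
print(rolling_submodel_indices(16, 0.5, 0))  # [0 1 2 3 4 5 6 7]
print(rolling_submodel_indices(16, 0.5, 4))  # [ 4  5  6  7  8  9 10 11]
```

Because the start offset advances every round, a low-capacity client eventually trains every slice of the global model rather than always the same one, which is what distinguishes rolling extraction from the static scheme.
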
24 | # Video Brief
25 | Click the figure to watch this short video explaining our work.
26 |
27 | [![FedRolex video brief](figures/video_placeholder.png)](https://recorder-v3.slideslive.com/#/share?share=74286&s=8264b1ae-a2a0-459d-aa99-88d7baf2d51f)
28 |
29 | # Usage
30 | ## Setup
31 | ```commandline
32 | pip install -r requirements.txt
33 | ```
34 |
35 | ## Training
36 | Train a ResNet-18 model on the CIFAR-10 dataset.
37 | ```commandline
38 | python main_resnet.py --data_name CIFAR10 \
39 | --model_name resnet18 \
40 | --control_name 1_100_0.1_non-iid-2_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
41 | --exp_name roll_test \
42 | --algo roll \
43 | --g_epoch 3200 \
44 | --l_epoch 1 \
45 | --lr 2e-4 \
46 | --schedule 1200 \
47 | --seed 31 \
48 | --num_experiments 3 \
49 | --devices 0 1 2
50 | ```
51 | `data_name`: CIFAR10 or CIFAR100 \
52 | `model_name`: resnet18 or vgg \
53 | `control_name`: 1_{num users}_{participation fraction}_{iid or non-iid-{num classes}}_{dynamic or fix}
54 | _{heterogeneity distribution}_{batch norm (bn) or group norm (gn)}_{scaler, 1 or 0}_{masked cross entropy, 1 or 0} (decoded in the example below) \
55 | `exp_name`: string value \
56 | `algo`: roll, random or static \
57 | `g_epoch`: num global epochs \
58 | `l_epoch`: num local epochs \
59 | `lr`: learning rate \
60 | `schedule`: lr schedule, space-separated list of integers less than g_epoch \
61 | `seed`: integer number \
62 | `num_experiments`: integer number, will run `num_experiments` trials with `seed` incrementing each time \
63 | `devices`: Index of GPUs to use \
64 |
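For example, the `control_name` used above, `1_100_0.1_non-iid-2_dynamic_a1-b1-c1-d1-e1_bn_1_1`, federates 100 users at a participation fraction of 0.1 per round, splits the data non-IID with 2 classes per user, assigns client capacities dynamically with an even spread over the five capacity levels a1 through e1, and uses batch norm with the scaler and masked cross-entropy both enabled.
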
65 | To train a Transformer model on the StackOverflow dataset, use `main_transformer.py` instead.
66 | ```commandline
67 | python main_transformer.py --data_name Stackoverflow \
68 | --model_name transformer \
69 | --control_name 1_100_0.1_iid_dynamic_a6-b10-c11-d18-e55_bn_1_1 \
70 | --exp_name roll_so_test \
71 | --algo roll \
72 | --g_epoch 1500 \
73 | --l_epoch 1 \
74 | --lr 2e-4 \
75 | --schedule 600 1000 \
76 | --seed 31 \
77 | --num_experiments 3 \
78 | --devices 0 1 2 3 4 5 6 7
79 | ```
80 | To train in a data- and model-homogeneous setting, the command would look like this.
81 | ```commandline
82 | python main_resnet.py --data_name CIFAR10 \
83 | --model_name resnet18 \
84 | --control_name 1_100_0.1_iid_dynamic_a1_bn_1_1 \
85 | --exp_name homogeneous_largest_low_heterogeneity \
86 | --algo static \
87 | --g_epoch 3200 \
88 | --l_epoch 1 \
89 | --lr 2e-4 \
90 | --schedule 800 1200 \
91 | --seed 31 \
92 | --num_experiments 3 \
93 | --devices 0 1 2
94 | ```
95 | To reproduce the results in Table 3 of the paper, please run the following commands:
96 |
97 | CIFAR-10
98 | ``` commandline
99 | python main_resnet.py --data_name CIFAR10 \
100 | --model_name resnet18 \
101 | --control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
102 | --exp_name homogeneous_largest_low_heterogeneity \
103 | --algo static \
104 | --g_epoch 3200 \
105 | --l_epoch 1 \
106 | --lr 2e-4 \
107 | --schedule 800 1200 \
108 | --seed 31 \
109 | --num_experiments 5 \
110 | --devices 0 1 2
111 | ```
112 | CIFAR-100
113 | ``` commandline
114 | python main_resnet.py --data_name CIFAR100 \
115 | --model_name resnet18 \
116 | --control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
117 | --exp_name homogeneous_largest_low_heterogeneity \
118 | --algo static \
119 | --g_epoch 2500 \
120 | --l_epoch 1 \
121 | --lr 2e-4 \
122 | --schedule 800 1200 \
123 | --seed 31 \
124 | --num_experiments 5 \
125 | --devices 0 1 2
126 | ```
127 | StackOverflow
128 | ```commandline
129 | python main_transformer.py --data_name Stackoverflow \
130 | --model_name transformer \
131 | --control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
132 | --exp_name roll_so_test \
133 | --algo roll \
134 | --g_epoch 1500 \
135 | --l_epoch 1 \
136 | --lr 2e-4 \
137 | --schedule 600 1000 \
138 | --seed 31 \
139 | --num_experiments 5 \
140 | --devices 0 1 2 3 4 5 6 7
141 | ```
142 |
143 | Note: To get the results based on the real-world device distribution as in Table 4, use `a6-b10-c11-d18-e55` as the
144 | distribution, as in the example below.
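
For instance, the first CIFAR-10 training command with the real-world distribution swapped in would be (the `exp_name` label here is our own placeholder):

```commandline
python main_resnet.py --data_name CIFAR10 \
--model_name resnet18 \
--control_name 1_100_0.1_non-iid-2_dynamic_a6-b10-c11-d18-e55_bn_1_1 \
--exp_name roll_real_world \
--algo roll \
--g_epoch 3200 \
--l_epoch 1 \
--lr 2e-4 \
--schedule 1200 \
--seed 31 \
--num_experiments 3 \
--devices 0 1 2
```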
145 |
146 | ## Citation
147 | If you find this useful for your work, please consider citing:
148 |
149 | ```
150 | @InProceedings{alam2022fedrolex,
151 | title = {FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction},
152 | author = {Alam, Samiul and Liu, Luyang and Yan, Ming and Zhang, Mi},
153 | booktitle = {Conference on Neural Information Processing Systems (NeurIPS)},
154 | year = {2022}
155 | }
156 | ```
157 |
158 |
159 |
--------------------------------------------------------------------------------
/check_local.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 |
5 | title = 'CIFAR100 \nLow Heterogeneity'
6 | axisfont_minor = 12
7 | axisfont_major = 20
8 | label_font = 18
9 | title_font = 24
10 | b = 12
11 | tag = 'output/runs/31_CIFAR100_label_resnet18_1_100_0.1_non-iid-50_dynamic_'
12 | base = f'{tag}e1_bn_1_1_real_world.pt'
13 | compare = f'{tag}a6-b10-c11-d18-e55_bn_1_1_real_world.pt'
14 | acc = torch.load(compare)
15 |
16 | x = np.array([i[0][0]['Local-Accuracy'] for i in acc])
17 | x += np.random.normal(0, 1, x.shape)
18 | x = np.clip(x, 0, 100)
19 |
20 | fig, ax = plt.subplots()
21 | ax.tick_params(axis='both', which='major', labelsize=axisfont_major)
22 | ax.tick_params(axis='both', which='minor', labelsize=axisfont_minor)
23 |
24 | n, bins, patches = plt.hist(x, bins=b, facecolor='#2ab0ff', edgecolor='#e0e0e0', linewidth=0.5, alpha=0.75)
25 |
26 | n = n.astype('int') # it MUST be integer
27 | # Good old loop. Choose colormap of your taste
28 | for i in range(len(patches)):
29 | patches[i].set_facecolor(plt.cm.get_cmap('Oranges')(n[i] / max(n)))
30 |
31 | acc = torch.load(base)
32 |
33 | x = np.array([i[0][0]['Local-Accuracy'] for i in acc[0]])
34 | x += np.random.normal(0, 1, x.shape)
35 | x = np.clip(x, 0, 100)
36 |
37 | n, bins, patches = plt.hist(x, bins=b, facecolor='#2ab0ff', edgecolor='#e0e0e0', linewidth=0.5, alpha=0.7)
38 |
39 | n = n.astype('int') # it MUST be integer
40 | # Good old loop. Choose colormap of your taste
41 | for i in range(len(patches)):
42 | patches[i].set_facecolor(plt.cm.get_cmap('Blues')(n[i] / max(n)))
43 |
44 | plt.title(f'{title}', fontsize=title_font)
45 | plt.xlabel('Accuracy', fontsize=label_font)
46 | plt.ylabel('Counts', fontsize=label_font)
47 |
48 | # plt.rcParams['font.weight'] = 'bold'
49 | # plt.rcParams['axes.labelweight'] = 'bold'
50 | plt.rcParams['font.size'] = label_font
51 | # plt.rcParams['axes.linewidth'] = 1.0
52 | # plt.show()
53 |
54 | plt.savefig(f'{tag}a6-b10-c11-d18-e55_bn_1_1_real_world_accuracy_distribution.pdf', bbox_inches="tight")
55 | # 'Accent', 'Accent_r', 'Blues', 'Blues_r',
56 | # 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', 'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', 'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', 'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', 'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', 'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', 'Reds_r', 'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', 'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', 'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', 'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cividis', 'cividis_r', 'cool', 'cool_r', 'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'crest', 'crest_r', 'cubehelix', 'cubehelix_r', 'flag', 'flag_r', 'flare', 'flare_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', 'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', 'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', 'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'icefire', 'icefire_r', 'inferno', 'inferno_r', 'jet', 'jet_r', 'magma', 'magma_r', 'mako', 'mako_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', 'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', 'rainbow_r', 'rocket', 'rocket_r', 'seismic', 'seismic_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', 'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', 'terrain_r', 'turbo', 'turbo_r', 'twilight', 'twilight_r', 'twilight_shifted', 'twilight_shifted_r', 'viridis', 'viridis_r', 'vlag', 'vlag_r', 'winter', 'winter_r'
57 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 | global cfg
4 | if 'cfg' not in globals():
5 | with open('config.yml', 'r') as f:
6 | cfg = yaml.load(f, Loader=yaml.FullLoader)
7 |
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | ---
2 | # control
3 | exp_name: hetero_fl_roll_50_100
4 | control:
5 | fed: '1'
6 | num_users: '100'
7 | frac: '0.1'
8 | data_split_mode: 'iid'
9 | model_split_mode: 'fix'
10 | model_mode: 'a1'
11 | norm: 'bn'
12 | scale: '1'
13 | mask: '1'
14 | # data
15 | data_name: CIFAR10
16 | subset: label
17 | batch_size:
18 | train: 128
19 | test: 128
20 | shuffle:
21 | train: True
22 | test: False
23 | num_workers: 0
24 | model_name: resnet18
25 | metric_name:
26 | train:
27 | - Loss
28 | - Accuracy
29 | test:
30 | - Loss
31 | - Accuracy
32 | # optimizer
33 | optimizer_name: Adam
34 | lr: 3.0e-4
35 | momentum: 0.9
36 | weight_decay: 5.0e-4
37 | # scheduler
38 | scheduler_name: None
39 | step_size: 1
40 | milestones:
41 | - 100
42 | - 150
43 | patience: 10
44 | threshold: 1.0e-3
45 | factor: 0.5
46 | min_lr: 1.0e-4
47 | # experiment
48 | init_seed: 31
49 | num_experiments: 1
50 | num_epochs: 200
51 | log_interval: 0.25
52 | device: cuda
53 | world_size: 1
54 | resume_mode: 0
55 | # other
56 | save_format: pdf
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Dataset
6 | from torch.utils.data.dataloader import default_collate
7 | from torchvision import transforms
8 |
9 | import datasets
10 | from config import cfg
11 | from datasets.gld import GLD160
12 |
13 |
14 | def fetch_dataset(data_name, subset):
15 | dataset = {}
16 | print('fetching data {}...'.format(data_name))
17 | root = './data/{}'.format(data_name)
18 | if data_name == 'MNIST':
19 | dataset['train'] = datasets.MNIST(root=root, split='train', subset=subset, transform=datasets.Compose(
20 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
21 | dataset['test'] = datasets.MNIST(root=root, split='test', subset=subset, transform=datasets.Compose(
22 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
23 | elif data_name == 'CIFAR10':
24 | dataset['train'] = datasets.CIFAR10(root=root, split='train', subset=subset, transform=datasets.Compose(
25 | [transforms.RandomCrop(32, padding=4),
26 | transforms.RandomHorizontalFlip(),
27 | transforms.ToTensor(),
28 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]))
29 | dataset['test'] = datasets.CIFAR10(root=root, split='test', subset=subset, transform=datasets.Compose(
30 | [transforms.ToTensor(),
31 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]))
32 | elif data_name == 'CIFAR100':
33 | dataset['train'] = datasets.CIFAR100(root=root, split='train', subset=subset, transform=datasets.Compose(
34 | [transforms.RandomCrop(32, padding=4),
35 | transforms.RandomHorizontalFlip(),
36 | transforms.ToTensor(),
37 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]))
38 | dataset['test'] = datasets.CIFAR100(root=root, split='test', subset=subset, transform=datasets.Compose(
39 | [transforms.ToTensor(),
40 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]))
41 | elif data_name in ['PennTreebank', 'WikiText2', 'WikiText103']:
42 | dataset['train'] = eval('datasets.{}(root=root, split=\'train\')'.format(data_name))
43 | dataset['test'] = eval('datasets.{}(root=root, split=\'test\')'.format(data_name))
44 | elif data_name in ['Stackoverflow']:
45 | dataset['train'] = torch.load(os.path.join('/egr/research-zhanglambda/samiul/stackoverflow/',
46 | 'stackoverflow_{}.pt'.format('train')))
47 | dataset['test'] = torch.load(os.path.join('/egr/research-zhanglambda/samiul/stackoverflow/',
48 | 'stackoverflow_{}.pt'.format('val')))
49 | dataset['vocab'] = torch.load(os.path.join('/egr/research-zhanglambda/samiul/stackoverflow/',
50 | 'meta.pt'))
51 | elif data_name in ['gld']:
52 | dataset['train'] = torch.load(os.path.join('gld_160k/',
53 | '{}.pt'.format('train')))
54 | dataset['test'] = torch.load(os.path.join('gld_160k/',
55 | '{}.pt'.format('test')))
56 | else:
57 | raise ValueError('Not valid dataset name')
58 | print('data ready')
59 | return dataset
60 |
61 |
62 | def input_collate(batch):
63 | if isinstance(batch[0], dict):
64 | output = {key: [] for key in batch[0].keys()}
65 | for b in batch:
66 | for key in b:
67 | output[key].append(b[key])
68 | return output
69 | else:
70 | return default_collate(batch)
71 |
72 |
73 | def split_dataset(dataset, num_users, data_split_mode):
74 | data_split = {}
75 | if cfg['data_name'] in ['gld']:
76 | data_split['train'] = [GLD160(usr_data, usr_labels) for usr_data, usr_labels, _ in dataset['train'].values()]
77 | data_split['test'] = GLD160(*dataset['test'])
78 | label_split = [list(usr_lbl_split.keys()) for _, _, usr_lbl_split in dataset['train'].values()]
79 | return data_split, label_split
80 | if data_split_mode == 'iid':
81 | data_split['train'], label_split = iid(dataset['train'], num_users)
82 | data_split['test'], _ = iid(dataset['test'], num_users)
83 | elif 'non-iid' in data_split_mode:
84 | data_split['train'], label_split = non_iid(dataset['train'], num_users)
85 | data_split['test'], _ = non_iid(dataset['test'], num_users, label_split)
87 | else:
88 | raise ValueError('Not valid data split mode')
89 | return data_split, label_split
90 |
91 |
92 | def iid(dataset, num_users):
93 | if cfg['data_name'] in ['MNIST', 'CIFAR10', 'CIFAR100']:
94 | label = torch.tensor(dataset.target)
95 | elif cfg['data_name'] in ['WikiText2', 'WikiText103', 'PennTreebank']:
96 | label = dataset.token
97 | else:
98 | raise ValueError('Not valid data name')
99 | num_items = int(len(dataset) / num_users)
100 | data_split, idx = {}, list(range(len(dataset)))
101 | label_split = {}
102 | for i in range(num_users):
103 | num_items_i = min(len(idx), num_items)
104 | data_split[i] = torch.tensor(idx)[torch.randperm(len(idx))[:num_items_i]].tolist()
105 | label_split[i] = torch.unique(label[data_split[i]]).tolist()
106 | idx = list(set(idx) - set(data_split[i]))
107 | return data_split, label_split
108 |
109 |
110 | def non_iid(dataset, num_users, label_split=None):
111 | label = np.array(dataset.target)
112 | cfg['non-iid-n'] = int(cfg['data_split_mode'].split('-')[-1])
113 | shard_per_user = cfg['non-iid-n']
114 | data_split = {i: [] for i in range(num_users)}
115 | label_idx_split = {}
116 | for i in range(len(label)):
117 | label_i = label[i].item()
118 | if label_i not in label_idx_split:
119 | label_idx_split[label_i] = []
120 | label_idx_split[label_i].append(i)
121 | shard_per_class = int(shard_per_user * num_users / cfg['classes_size'])  # shards each class is split into
122 | for label_i in label_idx_split:
123 | label_idx = label_idx_split[label_i]
124 | num_leftover = len(label_idx) % shard_per_class
125 | leftover = label_idx[-num_leftover:] if num_leftover > 0 else []
126 | new_label_idx = np.array(label_idx[:-num_leftover]) if num_leftover > 0 else np.array(label_idx)
127 | new_label_idx = new_label_idx.reshape((shard_per_class, -1)).tolist()
128 | for i, leftover_label_idx in enumerate(leftover):
129 | new_label_idx[i] = np.concatenate([new_label_idx[i], [leftover_label_idx]])
130 | label_idx_split[label_i] = new_label_idx
131 | if label_split is None:
132 | label_split = list(range(cfg['classes_size'])) * shard_per_class
133 | label_split = torch.tensor(label_split)[torch.randperm(len(label_split))].tolist()
134 | label_split = np.array(label_split).reshape((num_users, -1)).tolist()
135 | for i in range(len(label_split)):
136 | label_split[i] = np.unique(label_split[i]).tolist()
137 | for i in range(num_users):
138 | for label_i in label_split[i]:
139 | idx = torch.arange(len(label_idx_split[label_i]))[torch.randperm(len(label_idx_split[label_i]))[0]].item()
140 | data_split[i].extend(label_idx_split[label_i].pop(idx))
141 | return data_split, label_split
142 |
143 |
144 | def make_data_loader(dataset):
145 | data_loader = {}
146 | for k in dataset:
147 | data_loader[k] = torch.utils.data.DataLoader(dataset=dataset[k], shuffle=cfg['shuffle'][k],
148 | batch_size=cfg['batch_size'][k], pin_memory=True,
149 | num_workers=cfg['num_workers'], collate_fn=input_collate)
150 | return data_loader
151 |
152 |
153 | class SplitDataset(Dataset):
154 | def __init__(self, dataset, idx):
155 | super().__init__()
156 | self.dataset = dataset
157 | self.idx = idx
158 |
159 | def __len__(self):
160 | return len(self.idx)
161 |
162 | def __getitem__(self, index):
163 | return self.dataset[self.idx[index]]
164 |
165 |
166 | class GenericDataset(Dataset):
167 | def __init__(self, dataset):
168 | super().__init__()
169 | self.dataset = dataset
170 |
171 | def __len__(self):
172 | return len(self.dataset)
173 |
174 | def __getitem__(self, index):
175 | input = self.dataset[index]
176 | return input
177 |
178 |
179 | class BatchDataset(Dataset):
180 | def __init__(self, dataset, seq_length):
181 | super().__init__()
182 | self.dataset = dataset
183 | self.seq_length = seq_length
184 | self.S = dataset[0]['label'].size(0)
185 | self.idx = list(range(0, self.S, seq_length))
186 |
187 | def __len__(self):
188 | return len(self.idx)
189 |
190 | def __getitem__(self, index):
191 | seq_length = min(self.seq_length, self.S - index)
192 | return {'label': self.dataset[:]['label'][:, self.idx[index]:self.idx[index] + seq_length]}
193 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar import CIFAR10, CIFAR100
2 | from .folder import ImageFolder
3 | from .imagenet import ImageNet
4 | from .lm import PennTreebank, WikiText2, WikiText103
5 | from .mnist import MNIST, EMNIST, FashionMNIST
6 | from .transforms import *
7 | from .utils import *
8 |
9 | __all__ = ('MNIST', 'EMNIST', 'FashionMNIST',
10 | 'CIFAR10', 'CIFAR100',
11 | 'ImageNet',
12 | 'PennTreebank', 'WikiText2', 'WikiText103',
13 | 'ImageFolder')
14 |
--------------------------------------------------------------------------------
/datasets/cifar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import anytree
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 |
10 | from utils import check_exists, makedir_exist_ok, save, load
11 | from .utils import download_url, extract_file, make_classes_counts, make_tree, make_flat_index
12 |
13 |
14 | class CIFAR10(Dataset):
15 | data_name = 'CIFAR10'
16 | file = [('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 'c58f30108f718f92721af3b95e74349a')]
17 |
18 | def __init__(self, root, split, subset, transform=None):
19 | self.root = os.path.expanduser(root)
20 | self.split = split
21 | self.subset = subset
22 | self.transform = transform
23 | if not check_exists(self.processed_folder):
24 | self.process()
25 | self.img, self.target = load(os.path.join(self.processed_folder, '{}.pt'.format(self.split)))
26 | self.target = self.target[self.subset]
27 | self.classes_counts = make_classes_counts(self.target)
28 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, 'meta.pt'))
29 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset]
30 |
31 | def __getitem__(self, index):
32 | img, target = Image.fromarray(self.img[index]), torch.tensor(self.target[index])
33 | input = {'img': img, self.subset: target}
34 | if self.transform is not None:
35 | input = self.transform(input)
36 | return input
37 |
38 | def __len__(self):
39 | return len(self.img)
40 |
41 | @property
42 | def processed_folder(self):
43 | return os.path.join(self.root, 'processed')
44 |
45 | @property
46 | def raw_folder(self):
47 | return os.path.join(self.root, 'raw')
48 |
49 | def process(self):
50 | if not check_exists(self.raw_folder):
51 | self.download()
52 | train_set, test_set, meta = self.make_data()
53 | save(train_set, os.path.join(self.processed_folder, 'train.pt'))
54 | save(test_set, os.path.join(self.processed_folder, 'test.pt'))
55 | save(meta, os.path.join(self.processed_folder, 'meta.pt'))
56 | return
57 |
58 | def download(self):
59 | makedir_exist_ok(self.raw_folder)
60 | for (url, md5) in self.file:
61 | filename = os.path.basename(url)
62 | download_url(url, self.raw_folder, filename, md5)
63 | extract_file(os.path.join(self.raw_folder, filename))
64 | return
65 |
66 | def __repr__(self):
67 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nTransforms: {}'.format(
68 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.transform.__repr__())
69 | return fmt_str
70 |
71 | def make_data(self):
72 | train_filenames = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']
73 | test_filenames = ['test_batch']
74 | train_img, train_label = read_pickle_file(os.path.join(self.raw_folder, 'cifar-10-batches-py'), train_filenames)
75 | test_img, test_label = read_pickle_file(os.path.join(self.raw_folder, 'cifar-10-batches-py'), test_filenames)
76 | train_target, test_target = {'label': train_label}, {'label': test_label}
77 | with open(os.path.join(self.raw_folder, 'cifar-10-batches-py', 'batches.meta'), 'rb') as f:
78 | data = pickle.load(f, encoding='latin1')
79 | classes = data['label_names']
80 | classes_to_labels = {'label': anytree.Node('U', index=[])}
81 | for c in classes:
82 | make_tree(classes_to_labels['label'], [c])
83 | classes_size = {'label': make_flat_index(classes_to_labels['label'])}
84 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size)
85 |
86 |
87 | class CIFAR100(CIFAR10):
88 | data_name = 'CIFAR100'
89 | file = [('https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', 'eb9058c3a382ffc7106e4002c42a8d85')]
90 |
91 | def make_data(self):
92 | train_filenames = ['train']
93 | test_filenames = ['test']
94 | train_img, train_label = read_pickle_file(os.path.join(self.raw_folder, 'cifar-100-python'), train_filenames)
95 | test_img, test_label = read_pickle_file(os.path.join(self.raw_folder, 'cifar-100-python'), test_filenames)
96 | train_target, test_target = {'label': train_label}, {'label': test_label}
97 | with open(os.path.join(self.raw_folder, 'cifar-100-python', 'meta'), 'rb') as f:
98 | data = pickle.load(f, encoding='latin1')
99 | classes = data['fine_label_names']
100 | classes_to_labels = {'label': anytree.Node('U', index=[])}
101 | for c in classes:
102 | for k in CIFAR100_classes:
103 | if c in CIFAR100_classes[k]:
104 | c = [k, c]
105 | break
106 | make_tree(classes_to_labels['label'], c)
107 | classes_size = {'label': make_flat_index(classes_to_labels['label'], classes)}
108 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size)
109 |
110 |
111 | def read_pickle_file(path, filenames):
112 | img, label = [], []
113 | for filename in filenames:
114 | file_path = os.path.join(path, filename)
115 | with open(file_path, 'rb') as f:
116 | entry = pickle.load(f, encoding='latin1')
117 | img.append(entry['data'])
118 | label.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
119 | img = np.vstack(img).reshape(-1, 3, 32, 32)
120 | img = img.transpose((0, 2, 3, 1))
121 | return img, label
122 |
123 |
124 | CIFAR100_classes = {
125 | 'aquatic mammals': ['beaver', 'dolphin', 'otter', 'seal', 'whale'],
126 | 'fish': ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'],
127 | 'flowers': ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],
128 | 'food containers': ['bottle', 'bowl', 'can', 'cup', 'plate'],
129 | 'fruit and vegetables': ['apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'],
130 | 'household electrical devices': ['clock', 'keyboard', 'lamp', 'telephone', 'television'],
131 | 'household furniture': ['bed', 'chair', 'couch', 'table', 'wardrobe'],
132 | 'insects': ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
133 | 'large carnivores': ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
134 | 'large man-made outdoor things': ['bridge', 'castle', 'house', 'road', 'skyscraper'],
135 | 'large natural outdoor scenes': ['cloud', 'forest', 'mountain', 'plain', 'sea'],
136 | 'large omnivores and herbivores': ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],
137 | 'medium-sized mammals': ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
138 | 'non-insect invertebrates': ['crab', 'lobster', 'snail', 'spider', 'worm'],
139 | 'people': ['baby', 'boy', 'girl', 'man', 'woman'],
140 | 'reptiles': ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
141 | 'small mammals': ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
142 | 'trees': ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'],
143 | 'vehicles 1': ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'],
144 | 'vehicles 2': ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor']
145 | }
146 |
--------------------------------------------------------------------------------
/datasets/folder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 | from utils import check_exists, save, load
8 | from .utils import find_classes, make_img, make_classes_counts
9 |
10 |
11 | class ImageFolder(Dataset):
12 |
13 | def __init__(self, root, split, subset, transform=None):
14 | self.data_name = os.path.basename(root)
15 | self.root = os.path.expanduser(root)
16 | self.split = split
17 | self.subset = subset
18 | self.transform = transform
19 | if not check_exists(self.processed_folder):
20 | self.process()
21 | self.img, self.target = load(os.path.join(self.processed_folder, '{}.pt'.format(self.split)))
22 | self.target = self.target[self.subset]
23 | self.classes_counts = make_classes_counts(self.target)
24 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, 'meta.pt'))
25 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset]
26 |
27 | def __getitem__(self, index):
28 | img, target = Image.open(self.img[index], mode='r').convert('RGB'), torch.tensor(self.target[index])
29 | input = {'img': img, self.subset: target}
30 | if self.transform is not None:
31 | input = self.transform(input)
32 | return input
33 |
34 | def __len__(self):
35 | return len(self.img)
36 |
37 | @property
38 | def processed_folder(self):
39 | return os.path.join(self.root, 'processed')
40 |
41 | @property
42 | def raw_folder(self):
43 | return os.path.join(self.root, 'raw')
44 |
45 | def process(self):
46 | if not check_exists(self.raw_folder):
47 | raise RuntimeError('Dataset not found')
48 | train_set, test_set, meta = self.make_data()
49 | save(train_set, os.path.join(self.processed_folder, 'train.pt'))
50 | save(test_set, os.path.join(self.processed_folder, 'test.pt'))
51 | save(meta, os.path.join(self.processed_folder, 'meta.pt'))
52 | return
53 |
54 | def __repr__(self):
55 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nTransforms: {}'.format(
56 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.transform.__repr__())
57 | return fmt_str
58 |
59 | def make_data(self):
60 | classes_to_labels, classes_size = find_classes(os.path.join(self.raw_folder, 'train'))
61 | train_img, train_label = make_img(os.path.join(self.raw_folder, 'train'), classes_to_labels['label'])
62 | test_img, test_label = make_img(os.path.join(self.raw_folder, 'test'), classes_to_labels['label'])
63 | train_target, test_target = {'label': train_label}, {'label': test_label}
64 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size)
65 |
--------------------------------------------------------------------------------
/datasets/gld.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as T
2 | from torch.utils.data import Dataset
3 |
4 |
5 | class GLD160(Dataset):
6 | data_name = 'GLD'
7 |
8 | def __init__(self, images, targets, transform=T.RandomCrop(92)):
9 | self.transform = transform
10 | self.img, self.target = images, targets
11 | self.classes_counts = 2028
12 |
13 | def __getitem__(self, index):
14 | img = self.img[index]
15 | target = self.target[index]
16 | inp = {'img': img, 'label': target}
17 | if self.transform is not None:
18 | inp['img'] = self.transform(inp['img'])
19 | return inp
20 |
21 | def __len__(self):
22 | return len(self.img)
23 |
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import anytree
5 | import numpy as np
6 | import torch
7 | from PIL import Image, ImageFile
8 | from torch.utils.data import Dataset
9 |
10 | from utils import check_exists, save, load
11 | from .utils import extract_file, make_classes_counts, make_img, make_tree, make_flat_index
12 |
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 |
15 |
16 | class ImageNet(Dataset):
17 | data_name = 'ImageNet'
18 |
19 | def __init__(self, root, split, subset, size, transform=None):
20 | self.root = os.path.expanduser(root)
21 | self.split = split
22 | self.subset = subset
23 | self.transform = transform
24 | self.size = size
25 | if not check_exists(os.path.join(self.processed_folder, str(self.size))):
26 | self.process()
27 | self.img, self.target = load(os.path.join(self.processed_folder, str(self.size), '{}.pt'.format(self.split)))
28 | self.target = self.target[self.subset]
29 | self.classes_counts = make_classes_counts(self.target)
30 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, str(self.size), 'meta.pt'))
31 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset]
32 |
33 | def __getitem__(self, index):
34 | img, target = Image.open(self.img[index], mode='r').convert('RGB'), torch.tensor(self.target[index])
35 | input = {'img': img, self.subset: target}
36 | if self.transform is not None:
37 | input = self.transform(input)
38 | return input
39 |
40 | def __len__(self):
41 | return len(self.img)
42 |
43 | @property
44 | def processed_folder(self):
45 | return os.path.join(self.root, 'processed')
46 |
47 | @property
48 | def raw_folder(self):
49 | return os.path.join(self.root, 'raw')
50 |
51 | def process(self):
52 | if not check_exists(self.raw_folder):
53 | raise RuntimeError('Dataset not found')
54 | train_set, test_set, meta = self.make_data()
55 | save(train_set, os.path.join(self.processed_folder, str(self.size), 'train.pt'))
56 | save(test_set, os.path.join(self.processed_folder, str(self.size), 'test.pt'))
57 | save(meta, os.path.join(self.processed_folder, str(self.size), 'meta.pt'))
58 | return
59 |
60 | def __repr__(self):
61 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nImage size: {}\nTransforms: {}'.format(
62 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.size,
63 | self.transform.__repr__())
64 | return fmt_str
65 |
66 | def make_data(self):
67 | if not check_exists(os.path.join(self.raw_folder, 'base')):
68 | train_path = os.path.join(self.raw_folder, 'ILSVRC2012_img_train')
69 | test_path = os.path.join(self.raw_folder, 'ILSVRC2012_img_val')
70 | meta_path = os.path.join(self.raw_folder, 'ILSVRC2012_devkit_t12')
71 | extract_file(os.path.join(self.raw_folder, 'ILSVRC2012_img_train.tar'), train_path)
72 | extract_file(os.path.join(self.raw_folder, 'ILSVRC2012_img_val.tar'), test_path)
73 | extract_file(os.path.join(self.raw_folder, 'ILSVRC2012_devkit_t12.tar'), meta_path)
74 | for archive in [os.path.join(train_path, archive) for archive in os.listdir(train_path)]:
75 | extract_file(archive, os.path.splitext(archive)[0], delete=True)
76 | classes_to_labels, classes_size = make_meta(meta_path)
77 | with open(os.path.join(meta_path, 'data', 'ILSVRC2012_validation_ground_truth.txt'), 'r') as f:
78 | test_id = f.readlines()
79 | test_id = [int(i) for i in test_id]
80 | test_img = sorted([os.path.join(test_path, file) for file in os.listdir(test_path)])
81 | test_wnid = []
82 | for test_id_i in test_id:
83 | test_node_i = anytree.find_by_attr(classes_to_labels['label'], name='id', value=test_id_i)
84 | test_wnid.append(test_node_i.name)
85 | for test_wnid_i in set(test_wnid):
86 | os.mkdir(os.path.join(test_path, test_wnid_i))
87 | for test_wnid_i, test_img_i in zip(test_wnid, test_img):
88 | shutil.move(test_img_i, os.path.join(test_path, test_wnid_i, os.path.basename(test_img_i)))
89 | shutil.move(train_path, os.path.join(self.raw_folder, 'base', 'ILSVRC2012_img_train'))
90 | shutil.move(test_path, os.path.join(self.raw_folder, 'base', 'ILSVRC2012_img_val'))
91 | shutil.move(meta_path, os.path.join(self.raw_folder, 'base', 'ILSVRC2012_devkit_t12'))
92 | if not check_exists(os.path.join(self.raw_folder, str(self.size))):
93 | raise ValueError('Need to run resizer')
94 | classes_to_labels, classes_size = make_meta(os.path.join(self.raw_folder, 'base', 'ILSVRC2012_devkit_t12'))
95 | train_img, train_label = make_img(os.path.join(self.raw_folder, str(self.size), 'ILSVRC2012_img_train'),
96 | classes_to_labels['label'])
97 | test_img, test_label = make_img(os.path.join(self.raw_folder, str(self.size), 'ILSVRC2012_img_val'),
98 | classes_to_labels['label'])
99 | train_target = {'label': train_label}
100 | test_target = {'label': test_label}
101 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size)
102 |
103 |
104 | def make_meta(path):
105 | import scipy.io as sio
106 | meta = sio.loadmat(os.path.join(path, 'data', 'meta.mat'), squeeze_me=True)['synsets']
107 | num_children = list(zip(*meta))[4]
108 | leaf_meta = [meta[i] for (i, n) in enumerate(num_children) if n == 0]
109 | branch_meta = [meta[i] for (i, n) in enumerate(num_children) if n > 0]
110 | names, attributes = [], []
111 | for i in range(len(leaf_meta)):
112 | name, attribute = make_node(leaf_meta[i], branch_meta)
113 | names.append(name)
114 | attributes.append(attribute)
115 | classes_to_labels = {'label': anytree.Node('U', index=[])}
116 | classes = []
117 | for (name, attribute) in zip(names, attributes):
118 | make_tree(classes_to_labels['label'], name, attribute)
119 | classes.append(name[-1])
120 | classes_size = {'label': make_flat_index(classes_to_labels['label'], classes)}
121 | return classes_to_labels, classes_size
122 |
123 |
124 | def make_node(node, branch):
125 | id, wnid, classes = node.item()[:3]
126 | if classes == 'entity':
127 | name = []
128 | attribute = {'id': [], 'class': []}
129 | for i in range(len(branch)):
130 | branch_children = branch[i].item()[5]
131 | if (isinstance(branch_children, int) and id == branch_children) or (
132 | isinstance(branch_children, np.ndarray) and id in branch_children):
133 | parent_name, parent_attribute = make_node(branch[i], branch)
134 | name = parent_name + [wnid]
135 | attribute = {'id': parent_attribute['id'] + [id], 'class': parent_attribute['class'] + [classes]}
136 | break
137 | return name, attribute
138 |
--------------------------------------------------------------------------------
/datasets/lm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import abstractmethod
3 |
4 | import torch
5 | from torch.utils.data import Dataset
6 |
7 | from utils import check_exists, makedir_exist_ok, save, load
8 | from .utils import download_url, extract_file
9 |
10 |
11 | class Vocab:
12 | def __init__(self):
13 | self.symbol_to_index = {u'<unk>': 0, u'<eos>': 1}
14 | self.index_to_symbol = [u'<unk>', u'<eos>']
15 |
16 | def add(self, symbol):
17 | if symbol not in self.symbol_to_index:
18 | self.index_to_symbol.append(symbol)
19 | self.symbol_to_index[symbol] = len(self.index_to_symbol) - 1
20 | return
21 |
22 | def delete(self, symbol):
23 | if symbol in self.symbol_to_index:
24 | self.index_to_symbol.remove(symbol)
25 | self.symbol_to_index.pop(symbol, None)
26 | return
27 |
28 | def __len__(self):
29 | return len(self.index_to_symbol)
30 |
31 | def __getitem__(self, input):
32 | if isinstance(input, int):
33 | if len(self.index_to_symbol) > input >= 0:
34 | output = self.index_to_symbol[input]
35 | else:
36 | output = u'<unk>'
37 | elif isinstance(input, str):
38 | if input not in self.symbol_to_index:
39 | output = self.symbol_to_index[u'<unk>']
40 | else:
41 | output = self.symbol_to_index[input]
42 | else:
43 | raise ValueError('Not valid data type')
44 | return output
45 |
46 | def __contains__(self, input):
47 | if isinstance(input, int):
48 | exist = len(self.index_to_symbol) > input >= 0
49 | elif isinstance(input, str):
50 | exist = input in self.symbol_to_index
51 | else:
52 | raise ValueError('Not valid data type')
53 | return exist
54 |
55 |
56 | class LanguageModeling(Dataset):
57 | def __init__(self, root, split):
58 | self.root = os.path.expanduser(root)
59 | self.split = split
60 | if not check_exists(self.processed_folder):
61 | self.process()
62 | self.token = load(os.path.join(self.processed_folder, '{}.pt'.format(split)))
63 | self.vocab = load(os.path.join(self.processed_folder, 'meta.pt'))
64 |
65 | def __getitem__(self, index):
66 | input = {'label': self.token[index]}
67 | return input
68 |
69 | def __len__(self):
70 | return len(self.token)
71 |
72 | @property
73 | def processed_folder(self):
74 | return os.path.join(self.root, 'processed')
75 |
76 | @property
77 | def raw_folder(self):
78 | return os.path.join(self.root, 'raw')
79 |
80 | def _check_exists(self):
81 | return os.path.exists(self.processed_folder)
82 |
83 | @abstractmethod
84 | def process(self):
85 | raise NotImplementedError
86 |
87 | @abstractmethod
88 | def download(self):
89 | raise NotImplementedError
90 |
91 | def __repr__(self):
92 | fmt_str = 'Dataset {}\nRoot: {}\nSplit: {}'.format(
93 | self.__class__.__name__, self.root, self.split)
94 | return fmt_str
95 |
96 |
97 | class PennTreebank(LanguageModeling):
98 | data_name = 'PennTreebank'
99 | file = [('https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt', None),
100 | ('https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt', None),
101 | ('https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt', None)]
102 |
103 | def __init__(self, root, split):
104 | super().__init__(root, split)
105 |
106 | def process(self):
107 | if not check_exists(self.raw_folder):
108 | self.download()
109 | train_set, valid_set, test_set, meta = self.make_data()
110 | save(train_set, os.path.join(self.processed_folder, 'train.pt'))
111 | save(valid_set, os.path.join(self.processed_folder, 'valid.pt'))
112 | save(test_set, os.path.join(self.processed_folder, 'test.pt'))
113 | save(meta, os.path.join(self.processed_folder, 'meta.pt'))
114 | return
115 |
116 | def download(self):
117 | makedir_exist_ok(self.raw_folder)
118 | for (url, md5) in self.file:
119 | filename = os.path.basename(url)
120 | download_url(url, self.raw_folder, filename, md5)
121 | extract_file(os.path.join(self.raw_folder, filename))
122 | return
123 |
124 | def make_data(self):
125 | vocab = Vocab()
126 | read_token(vocab, os.path.join(self.raw_folder, 'ptb.train.txt'))
127 | read_token(vocab, os.path.join(self.raw_folder, 'ptb.valid.txt'))
128 | train_token = make_token(vocab, os.path.join(self.raw_folder, 'ptb.train.txt'))
129 | valid_token = make_token(vocab, os.path.join(self.raw_folder, 'ptb.valid.txt'))
130 | test_token = make_token(vocab, os.path.join(self.raw_folder, 'ptb.test.txt'))
131 | return train_token, valid_token, test_token, vocab
132 |
133 |
134 | class WikiText2(LanguageModeling):
135 | data_name = 'WikiText2'
136 | file = [('https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip', None)]
137 |
138 | def __init__(self, root, split):
139 | super().__init__(root, split)
140 |
141 | def process(self):
142 | if not check_exists(self.raw_folder):
143 | self.download()
144 | train_set, valid_set, test_set, meta = self.make_data()
145 | save(train_set, os.path.join(self.processed_folder, 'train.pt'))
146 | save(valid_set, os.path.join(self.processed_folder, 'valid.pt'))
147 | save(test_set, os.path.join(self.processed_folder, 'test.pt'))
148 | save(meta, os.path.join(self.processed_folder, 'meta.pt'))
149 | return
150 |
151 | def download(self):
152 | makedir_exist_ok(self.raw_folder)
153 | for (url, md5) in self.file:
154 | filename = os.path.basename(url)
155 | download_url(url, self.raw_folder, filename, md5)
156 | extract_file(os.path.join(self.raw_folder, filename))
157 | return
158 |
159 | def make_data(self):
160 | vocab = Vocab()
161 | read_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.train.tokens'))
162 |         read_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.valid.tokens'))
163 | train_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.train.tokens'))
164 | valid_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.valid.tokens'))
165 | test_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.test.tokens'))
166 | return train_token, valid_token, test_token, vocab
167 |
168 |
169 | class WikiText103(LanguageModeling):
170 | data_name = 'WikiText103'
171 | file = [('https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip', None)]
172 |
173 | def __init__(self, root, split):
174 | super().__init__(root, split)
175 |
176 | def process(self):
177 | if not check_exists(self.raw_folder):
178 | self.download()
179 | train_set, valid_set, test_set, meta = self.make_data()
180 | save(train_set, os.path.join(self.processed_folder, 'train.pt'))
181 | save(valid_set, os.path.join(self.processed_folder, 'valid.pt'))
182 | save(test_set, os.path.join(self.processed_folder, 'test.pt'))
183 | save(meta, os.path.join(self.processed_folder, 'meta.pt'))
184 | return
185 |
186 | def download(self):
187 | makedir_exist_ok(self.raw_folder)
188 | for (url, md5) in self.file:
189 | filename = os.path.basename(url)
190 | download_url(url, self.raw_folder, filename, md5)
191 | extract_file(os.path.join(self.raw_folder, filename))
192 | return
193 |
194 | def make_data(self):
195 | vocab = Vocab()
196 | read_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.train.tokens'))
197 |         read_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.valid.tokens'))
198 | train_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.train.tokens'))
199 | valid_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.valid.tokens'))
200 | test_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.test.tokens'))
201 | return train_token, valid_token, test_token, vocab
202 |
203 |
204 | class StackOverflowClientDataset(Dataset):
205 | def __init__(self, token, seq_length, batch_size):
206 | self.seq_length = seq_length
207 | self.token = token
208 | num_batch = len(token) // (batch_size * seq_length)
209 | self.token = self.token.narrow(0, 0, num_batch * batch_size * seq_length)
210 | self.token = self.token.reshape(-1, batch_size, seq_length)
211 |
212 | def __getitem__(self, index):
213 | return {'label': self.token[index, :, :].reshape(-1, self.seq_length)}
214 |
215 | def __len__(self):
216 | return len(self.token)
217 |
218 |
219 | def read_token(vocab, token_path):
220 | with open(token_path, 'r', encoding='utf-8') as f:
221 | for line in f:
222 |             line = line.split() + [u'<eos>']
223 | for symbol in line:
224 | vocab.add(symbol)
225 | return
226 |
227 |
228 | def make_token(vocab, token_path):
229 | token = []
230 | with open(token_path, 'r', encoding='utf-8') as f:
231 | for line in f:
232 |             line = line.split() + [u'<eos>']
233 | for symbol in line:
234 | token.append(vocab[symbol])
235 | token = torch.tensor(token, dtype=torch.long)
236 | return token
237 |
--------------------------------------------------------------------------------
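
A note on `StackOverflowClientDataset` above: `narrow` silently drops any trailing tokens that do not fill a complete `(batch_size, seq_length)` block before reshaping. A standalone sketch of that shape arithmetic, with illustrative values not taken from the repo:

```python
import torch

token = torch.arange(100)                  # a flat stream of 100 token ids
batch_size, seq_length = 4, 6
num_batch = len(token) // (batch_size * seq_length)              # 4 full blocks
token = token.narrow(0, 0, num_batch * batch_size * seq_length)  # keep 96, drop 4
token = token.reshape(-1, batch_size, seq_length)                # (4, 4, 6)
print(token.shape)                              # torch.Size([4, 4, 6])
print(token[0].reshape(-1, seq_length).shape)   # __getitem__(0): torch.Size([4, 6])
```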
/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import os
3 |
4 | import anytree
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 |
10 | from utils import check_exists, makedir_exist_ok, save, load
11 | from .utils import download_url, extract_file, make_classes_counts, make_tree, make_flat_index
12 |
13 |
14 | class MNIST(Dataset):
15 | data_name = 'MNIST'
16 | file = [('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'),
17 | ('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'),
18 | ('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'),
19 | ('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c')]
20 |
21 | def __init__(self, root, split, subset, transform=None):
22 | self.root = os.path.expanduser(root)
23 | self.split = split
24 | self.subset = subset
25 | self.transform = transform
26 | if not check_exists(self.processed_folder):
27 | self.process()
28 | self.img, self.target = load(os.path.join(self.processed_folder, '{}.pt'.format(self.split)))
29 | self.target = self.target[self.subset]
30 | self.classes_counts = make_classes_counts(self.target)
31 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, 'meta.pt'))
32 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset]
33 |
34 | def __getitem__(self, index):
35 | img, target = Image.fromarray(self.img[index], mode='L'), torch.tensor(self.target[index])
36 | input = {'img': img, self.subset: target}
37 | if self.transform is not None:
38 | input = self.transform(input)
39 | return input
40 |
41 | def __len__(self):
42 | return len(self.img)
43 |
44 | @property
45 | def processed_folder(self):
46 | return os.path.join(self.root, 'processed')
47 |
48 | @property
49 | def raw_folder(self):
50 | return os.path.join(self.root, 'raw')
51 |
52 | def process(self):
53 | if not check_exists(self.raw_folder):
54 | self.download()
55 | train_set, test_set, meta = self.make_data()
56 | save(train_set, os.path.join(self.processed_folder, 'train.pt'))
57 | save(test_set, os.path.join(self.processed_folder, 'test.pt'))
58 | save(meta, os.path.join(self.processed_folder, 'meta.pt'))
59 | return
60 |
61 | def download(self):
62 | makedir_exist_ok(self.raw_folder)
63 | for (url, md5) in self.file:
64 | filename = os.path.basename(url)
65 | download_url(url, self.raw_folder, filename, md5)
66 | extract_file(os.path.join(self.raw_folder, filename))
67 | return
68 |
69 | def __repr__(self):
70 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nTransforms: {}'.format(
71 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.transform.__repr__())
72 | return fmt_str
73 |
74 | def make_data(self):
75 | train_img = read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte'))
76 | test_img = read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte'))
77 | train_label = read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
78 | test_label = read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
79 | train_target, test_target = {'label': train_label}, {'label': test_label}
80 | classes_to_labels = {'label': anytree.Node('U', index=[])}
81 | classes = list(map(str, list(range(10))))
82 | for c in classes:
83 | make_tree(classes_to_labels['label'], [c])
84 | classes_size = {'label': make_flat_index(classes_to_labels['label'])}
85 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size)
86 |
87 |
88 | class EMNIST(MNIST):
89 | data_name = 'EMNIST'
90 | file = [('http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip', '58c8d27c78d21e728a6bc7b3cc06412e')]
91 |
92 | def __init__(self, root, split, subset, transform=None):
93 | super().__init__(root, split, subset, transform)
94 | self.img = self.img[self.subset]
95 |
96 | def make_data(self):
97 | gzip_folder = os.path.join(self.raw_folder, 'gzip')
98 | for gzip_file in os.listdir(gzip_folder):
99 | if gzip_file.endswith('.gz'):
100 | extract_file(os.path.join(gzip_folder, gzip_file))
101 | subsets = ['byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist']
102 | train_img, test_img, train_target, test_target = {}, {}, {}, {}
103 | digits_classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
104 | upper_letters_classes = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q',
105 | 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
106 | lower_letters_classes = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
107 | 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
108 | merged_classes = ['c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z']
109 | unmerged_classes = list(set(lower_letters_classes) - set(merged_classes))
110 | classes = {'byclass': digits_classes + upper_letters_classes + lower_letters_classes,
111 | 'bymerge': digits_classes + upper_letters_classes + unmerged_classes,
112 | 'balanced': digits_classes + upper_letters_classes + unmerged_classes,
113 | 'letters': upper_letters_classes + unmerged_classes, 'digits': digits_classes,
114 | 'mnist': digits_classes}
115 | classes_to_labels = {s: anytree.Node('U', index=[]) for s in subsets}
116 | classes_size = {}
117 | for subset in subsets:
118 | train_img[subset] = read_image_file(
119 | os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(subset)))
120 | train_img[subset] = np.transpose(train_img[subset], [0, 2, 1])
121 | test_img[subset] = read_image_file(
122 | os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(subset)))
123 | test_img[subset] = np.transpose(test_img[subset], [0, 2, 1])
124 | train_target[subset] = read_label_file(
125 | os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(subset)))
126 | test_target[subset] = read_label_file(
127 | os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(subset)))
128 | for c in classes[subset]:
129 |                 make_tree(classes_to_labels[subset], [c])
130 | classes_size[subset] = make_flat_index(classes_to_labels[subset])
131 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size)
132 |
133 |
134 | class FashionMNIST(MNIST):
135 | data_name = 'FashionMNIST'
136 | file = [('http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
137 | '8d4fb7e6c68d591d4c3dfef9ec88bf0d'),
138 | ('http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
139 | 'bef4ecab320f06d8554ea6380940ec79'),
140 | ('http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
141 | '25c81989df183df01b3e8a0aad5dffbe'),
142 | ('http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
143 | 'bb300cfdad3c16e7a12a480ee83cd310')]
144 |
145 | def make_data(self):
146 | train_img = read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte'))
147 | test_img = read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte'))
148 | train_label = read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
149 | test_label = read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
150 | train_target = {'label': train_label}
151 | test_target = {'label': test_label}
152 | classes_to_labels = {'label': anytree.Node('U', index=[])}
153 | classes = ['T-shirt_top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag',
154 | 'Ankle boot']
155 | for c in classes:
156 |             make_tree(classes_to_labels['label'], [c])  # wrap in a list so multi-character names stay one node
157 | classes_size = {'label': make_flat_index(classes_to_labels['label'])}
158 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size)
159 |
160 |
161 | def get_int(b):
162 | return int(codecs.encode(b, 'hex'), 16)
163 |
164 |
165 | def read_image_file(path):
166 | with open(path, 'rb') as f:
167 | data = f.read()
168 | assert get_int(data[:4]) == 2051
169 | length = get_int(data[4:8])
170 | num_rows = get_int(data[8:12])
171 | num_cols = get_int(data[12:16])
172 | parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape((length, num_rows, num_cols))
173 | return parsed
174 |
175 |
176 | def read_label_file(path):
177 | with open(path, 'rb') as f:
178 | data = f.read()
179 | assert get_int(data[:4]) == 2049
180 | length = get_int(data[4:8])
181 | parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape(length).astype(np.int64)
182 | return parsed
183 |
--------------------------------------------------------------------------------
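
`read_image_file` and `read_label_file` above parse the classic IDX binary format: a big-endian header whose magic number is 2051 for images and 2049 for labels, followed by the raw `uint8` payload. A self-contained sketch against a fabricated in-memory file:

```python
import codecs

import numpy as np

def get_int(b):                     # same helper as in the file above
    return int(codecs.encode(b, 'hex'), 16)

# Fabricate a tiny IDX image file in memory: magic 2051, 2 images of 3x3.
header = (2051).to_bytes(4, 'big') + (2).to_bytes(4, 'big') \
         + (3).to_bytes(4, 'big') + (3).to_bytes(4, 'big')
data = header + bytes(range(18))    # 2 * 3 * 3 = 18 pixel bytes
assert get_int(data[:4]) == 2051    # 2049 would mark a label file
imgs = np.frombuffer(data, dtype=np.uint8, offset=16).reshape((2, 3, 3))
print(imgs.shape)                   # (2, 3, 3)
```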
/datasets/omniglot.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import anytree
4 | import torch
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 | from utils import check_exists, makedir_exist_ok, save, load
9 | from .utils import IMG_EXTENSIONS
10 | from .utils import download_url, extract_file, make_classes_counts, make_data, make_tree, make_flat_index
11 |
12 |
13 | class Omniglot(Dataset):
14 | data_name = 'Omniglot'
15 | file = [
16 | ('https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
17 | '68d2efa1b9178cc56df9314c21c6e718'),
18 | ('https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip',
19 | '6b91aef0f799c5bb55b94e3f2daec811')
20 | ]
21 |
22 | def __init__(self, root, split, subset, transform=None):
23 | self.root = os.path.expanduser(root)
24 | self.split = split
25 | self.subset = subset
26 | self.transform = transform
27 | if not check_exists(self.processed_folder):
28 | self.process()
29 | self.img, self.target = load(os.path.join(self.processed_folder, '{}.pt'.format(self.split)))
30 | self.target = self.target[self.subset]
31 | self.classes_counts = make_classes_counts(self.target)
32 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, 'meta.pt'))
33 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset]
34 |
35 | def __getitem__(self, index):
36 | img, target = Image.open(self.img[index], mode='r').convert('L'), torch.tensor(self.target[index])
37 | input = {'img': img, self.subset: target}
38 | if self.transform is not None:
39 | input = self.transform(input)
40 | return input
41 |
42 | def __len__(self):
43 | return len(self.img)
44 |
45 | @property
46 | def processed_folder(self):
47 | return os.path.join(self.root, 'processed')
48 |
49 | @property
50 | def raw_folder(self):
51 | return os.path.join(self.root, 'raw')
52 |
53 | def process(self):
54 | if not check_exists(self.raw_folder):
55 | self.download()
56 | train_set, test_set, meta = self.make_data()
57 | save(train_set, os.path.join(self.processed_folder, 'train.pt'))
58 | save(test_set, os.path.join(self.processed_folder, 'test.pt'))
59 | save(meta, os.path.join(self.processed_folder, 'meta.pt'))
60 | return
61 |
62 | def download(self):
63 | makedir_exist_ok(self.raw_folder)
64 | for (url, md5) in self.file:
65 | filename = os.path.basename(url)
66 | download_url(url, self.raw_folder, filename, md5)
67 | extract_file(os.path.join(self.raw_folder, filename))
68 | return
69 |
70 | def __repr__(self):
71 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nTransforms: {}'.format(
72 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.transform.__repr__())
73 | return fmt_str
74 |
75 | def make_data(self):
76 |         img = make_data(self.raw_folder, IMG_EXTENSIONS)  # the module-level helper from .utils, not this method
77 | classes = set()
78 | train_img = []
79 | test_img = []
80 | train_label = []
81 | test_label = []
82 | for i in range(len(img)):
83 | img_i = img[i]
84 | class_i = '/'.join(os.path.normpath(img_i).split(os.path.sep)[-3:-1])
85 | classes.add(class_i)
86 | idx_i = int(os.path.splitext(os.path.basename(img_i))[0].split('_')[1])
87 | if idx_i <= 10:
88 | train_img.append(img_i)
89 | else:
90 | test_img.append(img_i)
91 | classes = sorted(list(classes))
92 | classes_to_labels = {'label': anytree.Node('U', index=[])}
93 | for c in classes:
94 | make_tree(classes_to_labels['label'], c.split('/'))
95 | classes_size = {'label': make_flat_index(classes_to_labels['label'])}
96 | r = anytree.resolver.Resolver()
97 | for i in range(len(train_img)):
98 | train_img_i = train_img[i]
99 | train_class_i = '/'.join(os.path.normpath(train_img_i).split(os.path.sep)[-3:-1])
100 | node = r.get(classes_to_labels['label'], train_class_i)
101 | train_label.append(node.flat_index)
102 | for i in range(len(test_img)):
103 | test_img_i = test_img[i]
104 | test_class_i = '/'.join(os.path.normpath(test_img_i).split(os.path.sep)[-3:-1])
105 | node = r.get(classes_to_labels['label'], test_class_i)
106 | test_label.append(node.flat_index)
107 | train_target = {'label': train_label}
108 | test_target = {'label': test_label}
109 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size)
110 |
--------------------------------------------------------------------------------
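
`Omniglot.make_data` above derives the class from the alphabet/character directory pair and splits on the drawing index encoded in the filename: drawings 1-10 of each character go to train, the rest to test. A small sketch with hypothetical paths mimicking the Omniglot layout (POSIX-style separators assumed):

```python
import os

paths = ['raw/images_background/Latin/character01/0101_03.png',
         'raw/images_background/Latin/character01/0101_17.png']
for p in paths:
    class_name = '/'.join(os.path.normpath(p).split(os.path.sep)[-3:-1])
    drawing_idx = int(os.path.splitext(os.path.basename(p))[0].split('_')[1])
    split = 'train' if drawing_idx <= 10 else 'test'
    print(class_name, drawing_idx, split)
# Latin/character01 3 train
# Latin/character01 17 test
```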
/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | class CustomTransform(object):
2 | def __call__(self, input):
3 | return input['img']
4 |
5 | def __repr__(self):
6 | return self.__class__.__name__
7 |
8 |
9 | class BoundingBoxCrop(CustomTransform):
10 | def __init__(self):
11 | pass
12 |
13 | def __call__(self, input):
14 | x, y, width, height = input['bbox'].long().tolist()
15 | left, top, right, bottom = x, y, x + width, y + height
16 | bboxc_img = input['img'].crop((left, top, right, bottom))
17 | return bboxc_img
18 |
--------------------------------------------------------------------------------
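
A minimal usage sketch for `BoundingBoxCrop`, assuming Pillow is installed and this module is importable as `datasets.transforms`: the bbox is read from the sample dict as `(x, y, width, height)` and only the cropped image is returned.

```python
import torch
from PIL import Image

from datasets.transforms import BoundingBoxCrop

img = Image.new('RGB', (100, 100))
sample = {'img': img, 'bbox': torch.tensor([10, 20, 30, 40])}  # x, y, width, height
cropped = BoundingBoxCrop()(sample)
print(cropped.size)   # (30, 40): PIL reports (width, height)
```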
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import gzip
3 | import hashlib
4 | import os
5 | import tarfile
6 | import zipfile
7 | from collections import Counter
8 |
9 | import anytree
10 | import numpy as np
11 | from PIL import Image
12 | from tqdm import tqdm
13 |
14 | from utils import makedir_exist_ok
15 | from .transforms import *
16 |
17 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
18 |
19 |
20 | def find_classes(dir):
21 | classes = [d.name for d in os.scandir(dir) if d.is_dir()]
22 | classes.sort()
23 | classes_to_labels = {classes[i]: i for i in range(len(classes))}
24 | return classes_to_labels
25 |
26 |
27 | def pil_loader(path):
28 | with open(path, 'rb') as f:
29 | img = Image.open(f)
30 | return img.convert('RGB')
31 |
32 |
33 | def accimage_loader(path):
34 | import accimage
35 | try:
36 | return accimage.Image(path)
37 | except IOError:
38 | return pil_loader(path)
39 |
40 |
41 | def default_loader(path):
42 | from torchvision import get_image_backend
43 | if get_image_backend() == 'accimage':
44 | return accimage_loader(path)
45 | else:
46 | return pil_loader(path)
47 |
48 |
49 | def has_file_allowed_extension(filename, extensions):
50 | filename_lower = filename.lower()
51 | return any(filename_lower.endswith(ext) for ext in extensions)
52 |
53 |
54 | def make_classes_counts(label):
55 | label = np.array(label)
56 | if label.ndim > 1:
57 | label = label.sum(axis=tuple([i for i in range(1, label.ndim)]))
58 | classes_counts = Counter(label)
59 | return classes_counts
60 |
61 |
62 | def make_bar_updater(pbar):
63 | def bar_update(count, block_size, total_size):
64 | if pbar.total is None and total_size:
65 | pbar.total = total_size
66 | progress_bytes = count * block_size
67 | pbar.update(progress_bytes - pbar.n)
68 |
69 | return bar_update
70 |
71 |
72 | def calculate_md5(path, chunk_size=1024 * 1024):
73 | md5 = hashlib.md5()
74 | with open(path, 'rb') as f:
75 | for chunk in iter(lambda: f.read(chunk_size), b''):
76 | md5.update(chunk)
77 | return md5.hexdigest()
78 |
79 |
80 | def check_md5(path, md5, **kwargs):
81 | return md5 == calculate_md5(path, **kwargs)
82 |
83 |
84 | def check_integrity(path, md5=None):
85 | if not os.path.isfile(path):
86 | return False
87 | if md5 is None:
88 | return True
89 | return check_md5(path, md5)
90 |
91 |
92 | def download_url(url, root, filename, md5):
93 | from six.moves import urllib
94 | path = os.path.join(root, filename)
95 | makedir_exist_ok(root)
96 | if os.path.isfile(path) and check_integrity(path, md5):
97 | print('Using downloaded and verified file: ' + path)
98 | else:
99 | try:
100 | print('Downloading ' + url + ' to ' + path)
101 | urllib.request.urlretrieve(url, path, reporthook=make_bar_updater(tqdm(unit='B', unit_scale=True)))
102 | except OSError:
103 | if url[:5] == 'https':
104 | url = url.replace('https:', 'http:')
105 | print('Failed download. Trying https -> http instead.'
106 | ' Downloading ' + url + ' to ' + path)
107 | urllib.request.urlretrieve(url, path, reporthook=make_bar_updater(tqdm(unit='B', unit_scale=True)))
108 | if not check_integrity(path, md5):
109 | raise RuntimeError('Not valid downloaded file')
110 | return
111 |
112 |
113 | def extract_file(src, dest=None, delete=False):
114 | print('Extracting {}'.format(src))
115 | dest = os.path.dirname(src) if dest is None else dest
116 | filename = os.path.basename(src)
117 | if filename.endswith('.zip'):
118 | with zipfile.ZipFile(src, "r") as zip_f:
119 | zip_f.extractall(dest)
120 | elif filename.endswith('.tar'):
121 | with tarfile.open(src) as tar_f:
122 | tar_f.extractall(dest)
123 | elif filename.endswith('.tar.gz') or filename.endswith('.tgz'):
124 | with tarfile.open(src, 'r:gz') as tar_f:
125 | tar_f.extractall(dest)
126 | elif filename.endswith('.gz'):
127 | with open(src.replace('.gz', ''), 'wb') as out_f, gzip.GzipFile(src) as zip_f:
128 | out_f.write(zip_f.read())
129 | if delete:
130 | os.remove(src)
131 | return
132 |
133 |
134 | def make_data(root, extensions):
135 | path = []
136 | files = glob.glob('{}/**/*'.format(root), recursive=True)
137 | for file in files:
138 | if has_file_allowed_extension(file, extensions):
139 | path.append(os.path.normpath(file))
140 | return path
141 |
142 |
143 | def make_img(path, classes_to_labels, extensions=IMG_EXTENSIONS):
144 | img, label = [], []
145 | classes = []
146 | leaf_nodes = classes_to_labels.leaves
147 | for node in leaf_nodes:
148 | classes.append(node.name)
149 | for c in sorted(classes):
150 | d = os.path.join(path, c)
151 | if not os.path.isdir(d):
152 | continue
153 | for root, _, filenames in sorted(os.walk(d)):
154 | for filename in sorted(filenames):
155 | if has_file_allowed_extension(filename, extensions):
156 | cur_path = os.path.join(root, filename)
157 | img.append(cur_path)
158 | label.append(anytree.find_by_attr(classes_to_labels, c).flat_index)
159 | return img, label
160 |
161 |
162 | def make_tree(root, name, attribute=None):
163 | if len(name) == 0:
164 | return
165 | if attribute is None:
166 | attribute = {}
167 | this_name = name[0]
168 | next_name = name[1:]
169 | this_attribute = {k: attribute[k][0] for k in attribute}
170 | next_attribute = {k: attribute[k][1:] for k in attribute}
171 | this_node = anytree.find_by_attr(root, this_name)
172 | this_index = root.index + [len(root.children)]
173 | if this_node is None:
174 | this_node = anytree.Node(this_name, parent=root, index=this_index, **this_attribute)
175 | make_tree(this_node, next_name, next_attribute)
176 | return
177 |
178 |
179 | def make_flat_index(root, given=None):
180 | if given:
181 | classes_size = 0
182 | for node in anytree.PreOrderIter(root):
183 | if len(node.children) == 0:
184 | node.flat_index = given.index(node.name)
185 |                 classes_size = max(classes_size, given.index(node.name) + 1)
186 | else:
187 | classes_size = 0
188 | for node in anytree.PreOrderIter(root):
189 | if len(node.children) == 0:
190 | node.flat_index = classes_size
191 | classes_size += 1
192 | return classes_size
193 |
194 |
195 | class Compose(object):
196 | def __init__(self, transforms):
197 | self.transforms = transforms
198 |
199 | def __call__(self, input):
200 | for t in self.transforms:
201 | if isinstance(t, CustomTransform):
202 | input['img'] = t(input)
203 | else:
204 | input['img'] = t(input['img'])
205 | return input
206 |
207 | def __repr__(self):
208 | format_string = self.__class__.__name__ + '('
209 | for t in self.transforms:
210 | format_string += '\n'
211 | format_string += ' {0}'.format(t)
212 | format_string += '\n)'
213 | return format_string
214 |
--------------------------------------------------------------------------------
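
The interplay of `make_tree` and `make_flat_index` above is the subtlest part of this module: class names are inserted as lists of path parts, and the leaves are then numbered 0..N-1 in pre-order. A sketch assuming `anytree` is installed and the script runs from the repository root:

```python
import anytree

from datasets.utils import make_tree, make_flat_index

root = anytree.Node('U', index=[])
for c in ['cat', 'dog', 'bird']:
    make_tree(root, [c])          # names are passed as a list of path parts
classes_size = make_flat_index(root)
print(classes_size)                                   # 3
print([(n.name, n.flat_index) for n in root.leaves])  # [('cat', 0), ('dog', 1), ('bird', 2)]
```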
/figures/fedrolex_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/FedRolex/813510997e1802eb756d53baa4229c90ce5ac008/figures/fedrolex_overview.png
--------------------------------------------------------------------------------
/figures/table_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/FedRolex/813510997e1802eb756d53baa4229c90ce5ac008/figures/table_overview.png
--------------------------------------------------------------------------------
/figures/video_placeholder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/FedRolex/813510997e1802eb756d53baa4229c90ce5ac008/figures/video_placeholder.png
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from collections.abc import Iterable
3 | from numbers import Number
4 |
5 | from torch.utils.tensorboard import SummaryWriter
6 |
7 | from utils import ntuple
8 |
9 |
10 | class Logger():
11 | def __init__(self, log_path):
12 | self.log_path = log_path
13 | self.writer = None
14 | self.tracker = defaultdict(int)
15 | self.counter = defaultdict(int)
16 | self.mean = defaultdict(int)
17 | self.history = defaultdict(list)
18 | self.iterator = defaultdict(int)
19 | self.hist = defaultdict(list)
20 |
21 | def safe(self, write):
22 | if write:
23 | self.writer = SummaryWriter(self.log_path)
24 | else:
25 | if self.writer is not None:
26 | self.writer.close()
27 | self.writer = None
28 | for name in self.mean:
29 | self.history[name].append(self.mean[name])
30 | return
31 |
32 | def reset(self):
33 | self.tracker = defaultdict(int)
34 | self.counter = defaultdict(int)
35 | self.mean = defaultdict(int)
36 | self.hist = defaultdict(list)
37 | return
38 |
39 | def append(self, result, tag, n=1, mean=True):
40 | for k in result:
41 | name = '{}/{}'.format(tag, k)
42 | self.tracker[name] = result[k]
43 | if mean:
44 | if isinstance(result[k], Number):
45 | self.counter[name] += n
46 | if 'local' in name.lower():
47 | self.hist[name].append(result[k])
48 | self.mean[name] = ((self.counter[name] - n) * self.mean[name] + n * result[k]) / self.counter[name]
49 | elif isinstance(result[k], Iterable):
50 | if name not in self.mean:
51 | self.counter[name] = [0 for _ in range(len(result[k]))]
52 | self.mean[name] = [0 for _ in range(len(result[k]))]
53 | _ntuple = ntuple(len(result[k]))
54 | n = _ntuple(n)
55 | for i in range(len(result[k])):
56 | self.counter[name][i] += n[i]
57 | if 'local' in name.lower():
58 |                         self.hist[name].append(result[k][i])
59 | self.mean[name][i] = ((self.counter[name][i] - n[i]) * self.mean[name][i] + n[i] *
60 | result[k][i]) / self.counter[name][i]
61 | else:
62 | raise ValueError('Not valid data type')
63 | return
64 |
65 | def write(self, tag, metric_names):
66 | names = ['{}/{}'.format(tag, k) for k in metric_names]
67 | evaluation_info = []
68 | for name in names:
69 | tag, k = name.split('/')
70 | if isinstance(self.mean[name], Number):
71 | s = self.mean[name]
72 | evaluation_info.append('{}: {:.4f}'.format(k, s))
73 | if self.writer is not None:
74 | self.iterator[name] += 1
75 | self.writer.add_scalar(name, s, self.iterator[name])
76 | elif isinstance(self.mean[name], Iterable):
77 | s = tuple(self.mean[name])
78 | evaluation_info.append('{}: {}'.format(k, s))
79 | if self.writer is not None:
80 | self.iterator[name] += 1
81 | self.writer.add_scalar(name, s[0], self.iterator[name])
82 | if 'local' in name.lower():
83 | self.writer.add_histogram(f'{name}_hist', self.hist[name], self.iterator[name])
84 | else:
85 | raise ValueError('Not valid data type')
86 | info_name = '{}/info'.format(tag)
87 | info = self.tracker[info_name]
88 | info[2:2] = evaluation_info
89 | info = ' '.join(info)
90 | print(info)
91 | if self.writer is not None:
92 | self.iterator[info_name] += 1
93 | self.writer.add_text(info_name, info, self.iterator[info_name])
94 | return
95 |
96 | def flush(self):
97 |         if self.writer is not None: self.writer.flush()
98 | return
99 |
--------------------------------------------------------------------------------
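
A usage sketch for `Logger.append`, which maintains a count-weighted running mean per `tag/key`; the log path here is hypothetical, and no TensorBoard writer is opened unless `safe(True)` is called first:

```python
from logger import Logger

logger = Logger('output/runs/demo')   # hypothetical log path
logger.append({'Loss': 2.0}, 'train', n=10)
logger.append({'Loss': 1.0}, 'train', n=30)
print(logger.mean['train/Loss'])      # (10 * 2.0 + 30 * 1.0) / 40 = 1.25
logger.safe(False)                    # snapshots the current means into logger.history
```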
/main_resnet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import datetime
4 | import os
5 | import random
6 | import shutil
7 | import time
8 |
9 | import numpy as np
10 | import ray
11 | import torch
12 | import torch.backends.cudnn as cudnn
13 |
14 | from config import cfg
15 | from data import fetch_dataset, make_data_loader, split_dataset
16 | from logger import Logger
17 | from metrics import Metric
18 | from models import resnet
19 | from resnet_client import ResnetClient
20 | from utils import save, to_device, process_control, process_dataset, make_optimizer, make_scheduler, collate
21 |
22 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
23 |
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 |
27 |
28 |
29 |
30 | parser = argparse.ArgumentParser(description='cfg')
31 | for k in cfg:
32 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k))
33 | parser.add_argument('--control_name', default=None, type=str)
34 | parser.add_argument('--seed', default=None, type=int)
35 | parser.add_argument('--devices', default=None, nargs='+', type=int)
36 | parser.add_argument('--algo', default='roll', type=str)
37 | parser.add_argument('--weighting', default='avg', type=str)
38 | # parser.add_argument('--lr', default=None, type=int)
39 | parser.add_argument('--g_epochs', default=None, type=int)
40 | parser.add_argument('--l_epochs', default=None, type=int)
41 | parser.add_argument('--overlap', default=None, type=float)
42 |
43 | parser.add_argument('--schedule', default=None, nargs='+', type=int)
44 | # parser.add_argument('--exp_name', default=None, type=str)
45 | args = vars(parser.parse_args())
46 |
47 | cfg['overlap'] = args['overlap']
48 | cfg['weighting'] = args['weighting']
49 | cfg['init_seed'] = int(args['seed'])
50 | if args['algo'] == 'roll':
51 | from resnet_server import ResnetServerRoll as Server
52 | elif args['algo'] == 'random':
53 | from resnet_server import ResnetServerRandom as Server
54 | elif args['algo'] == 'static':
55 | from resnet_server import ResnetServerStatic as Server
56 | if args['devices'] is not None:
57 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in args['devices']])
58 | for k in cfg:
59 | cfg[k] = args[k]
60 | if args['control_name']:
61 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \
62 | if args['control_name'] != 'None' else {}
63 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']])
64 | cfg['pivot_metric'] = 'Global-Accuracy'
65 | cfg['pivot'] = -float('inf')
66 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Accuracy']},
67 | 'test': {'Local': ['Local-Loss', 'Local-Accuracy'], 'Global': ['Global-Loss', 'Global-Accuracy']}}
68 |
69 | ray.init()
70 | rates = None
71 |
72 |
73 | def main():
74 | process_control()
75 |
76 | if args['schedule'] is not None:
77 | cfg['milestones'] = args['schedule']
78 |
79 | if args['g_epochs'] is not None and args['l_epochs'] is not None:
80 | cfg['num_epochs'] = {'global': args['g_epochs'], 'local': args['l_epochs']}
81 | cfg['init_seed'] = int(args['seed'])
82 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments']))
83 | for i in range(cfg['num_experiments']):
84 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']]
85 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x])
86 | print('Experiment: {}'.format(cfg['model_tag']))
87 | print('Seed: {}'.format(cfg['init_seed']))
88 | run_experiment()
89 | return
90 |
91 |
92 | def run_experiment():
93 | seed = int(cfg['model_tag'].split('_')[0])
94 | torch.manual_seed(seed)
95 | torch.cuda.manual_seed(seed)
96 | random.seed(seed)
97 | np.random.seed(seed)
98 | torch.cuda.manual_seed_all(seed)
99 | torch.set_deterministic_debug_mode('default')
100 | os.environ['PYTHONHASHSEED'] = str(seed)
101 | dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
102 | process_dataset(dataset)
103 | global_model = resnet.resnet18(model_rate=cfg["global_model_rate"], cfg=cfg).to(cfg['device'])
104 | optimizer = make_optimizer(global_model, cfg['lr'])
105 | scheduler = make_scheduler(optimizer)
106 | last_epoch = 1
107 | data_split, label_split = split_dataset(dataset, cfg['num_users'], cfg['data_split_mode'])
108 | logger_path = os.path.join('output', 'runs', 'train_{}'.format(f'{cfg["model_tag"]}_{cfg["exp_name"]}'))
109 | logger = Logger(logger_path)
110 |
111 | num_active_users = int(np.ceil(cfg['frac'] * cfg['num_users']))
112 | cfg['active_user'] = num_active_users
113 | cfg_id = ray.put(cfg)
114 | dataset_ref = {
115 | 'dataset': ray.put(dataset['train']),
116 | 'split': ray.put(data_split['train']),
117 | 'label_split': ray.put(label_split)}
118 |
119 | server = Server(global_model, cfg['model_rate'], dataset_ref, cfg_id)
120 | local = [ResnetClient.remote(logger.log_path, [cfg_id]) for _ in range(num_active_users)]
121 | rates = server.model_rate
122 | for epoch in range(last_epoch, cfg['num_epochs']['global'] + 1):
123 | t0 = time.time()
124 | logger.safe(True)
125 | scheduler.step()
126 | lr = optimizer.param_groups[0]['lr']
127 | local, param_idx, user_idx = server.broadcast(local, lr)
128 | t1 = time.time()
129 |
130 | num_active_users = len(local)
131 | start_time = time.time()
132 | dt = ray.get([client.step.remote(m, num_active_users, start_time)
133 | for m, client in enumerate(local)])
134 |
135 |         local_parameters = list(dt)
136 |
137 | # for i, p in enumerate(local_parameters):
138 | # with open(f'local_param_pulled_{i}.pickle', 'w') as f:
139 | # pickle.dump(p, f)
140 | # local_parameters = [{k: torch.tensor(v, device=cfg['device']) for k, v in p.items()} for p in local_parameters]
141 | # for lp in local_parameters:
142 | # for k, p in lp.items():
143 | # print(k, torch.var_mean(p, unbiased=False))
144 |
145 | # local_parameters = [None for _ in range(num_active_users)]
146 | # for m in range(num_active_users):
147 | # local[m].step(m, num_active_users, start_time)
148 | # local_parameters[m] = local[m].pull()
149 | t2 = time.time()
150 | server.step(local_parameters, param_idx, user_idx)
151 | t3 = time.time()
152 |
153 | global_model = server.global_model
154 | test_model = global_model
155 | t4 = time.time()
156 |
157 | test(dataset['test'], data_split['test'], label_split, test_model, logger, epoch, local)
158 | t5 = time.time()
159 | logger.safe(False)
160 | model_state_dict = global_model.state_dict()
161 | save_result = {
162 | 'cfg': cfg, 'epoch': epoch + 1, 'data_split': data_split, 'label_split': label_split,
163 | 'model_dict': model_state_dict, 'optimizer_dict': optimizer.state_dict(),
164 | 'scheduler_dict': scheduler.state_dict(), 'logger': logger}
165 | save(save_result, './output/model/{}_checkpoint.pt'.format(cfg['model_tag']))
166 | if cfg['pivot'] < logger.mean['test/{}'.format(cfg['pivot_metric'])]:
167 | cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])]
168 | shutil.copy('./output/model/{}_checkpoint.pt'.format(cfg['model_tag']),
169 | './output/model/{}_best.pt'.format(cfg['model_tag']))
170 | logger.reset()
171 | t6 = time.time()
172 | print(f'Broadcast Time : {datetime.timedelta(seconds=t1 - t0)}')
173 | print(f'Client Step Time : {datetime.timedelta(seconds=t2 - t1)}')
174 | print(f'Server Step Time : {datetime.timedelta(seconds=t3 - t2)}')
175 | print(f'Stats Time : {datetime.timedelta(seconds=t4 - t3)}')
176 | print(f'Test Time : {datetime.timedelta(seconds=t5 - t4)}')
177 | print(f'Output Copy Time : {datetime.timedelta(seconds=t6 - t5)}')
178 | print(f'<>: {datetime.timedelta(seconds=t6 - t0)}')
179 | logger.safe(False)
180 | [ray.kill(client) for client in local]
181 | return
182 |
183 |
184 | def test(dataset, data_split, label_split, model, logger, epoch, local):
185 | with torch.no_grad():
186 | model.train(False)
187 | dataset_id = ray.put(dataset)
188 | data_split_id = ray.put(data_split)
189 | model_id = ray.put(copy.deepcopy(model))
190 | label_split_id = ray.put(label_split)
191 | all_res = []
192 | for m in range(0, cfg['num_users'], len(local)):
193 | processes = []
194 | for k in range(m, min(m + len(local), cfg['num_users'])):
195 | processes.append(local[k % len(local)]
196 | .test_model_for_user.remote(k,
197 | [dataset_id, data_split_id, model_id, label_split_id]))
198 | results = ray.get(processes)
199 | for result in results:
200 | all_res.append(result)
201 | for r in result:
202 | evaluation, input_size = r
203 | logger.append(evaluation, 'test', input_size)
204 | # Save all_res for plotting
205 | # torch.save((all_res, rates), f'./output/runs/{cfg["model_tag"]}_real_world.pt')
206 | data_loader = make_data_loader({'test': dataset})['test']
207 | metric = Metric()
208 | model.cuda()
209 | for i, data_input in enumerate(data_loader):
210 | data_input = collate(data_input)
211 | input_size = data_input['img'].size(0)
212 | data_input = to_device(data_input, 'cuda')
213 | output = model(data_input)
214 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
215 | evaluation = metric.evaluate(cfg['metric_name']['test']['Global'], data_input, output)
216 | logger.append(evaluation, 'test', input_size)
217 | info = {'info': ['Model: {}'.format(cfg['model_tag']),
218 | 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
219 | logger.append(info, 'test', mean=False)
220 | logger.write('test', cfg['metric_name']['test']['Local'] + cfg['metric_name']['test']['Global'])
221 | return
222 |
223 |
224 | if __name__ == "__main__":
225 | main()
226 |
--------------------------------------------------------------------------------
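
Two small pieces of the loop above are worth spelling out: how many clients a round activates, and the best-checkpoint rule keyed on `cfg['pivot_metric']`. A standalone sketch with made-up numbers:

```python
import numpy as np

cfg = {'frac': 0.1, 'num_users': 100, 'pivot': -float('inf')}
num_active_users = int(np.ceil(cfg['frac'] * cfg['num_users']))
print(num_active_users)            # 10 clients are sampled into each round

mean_global_accuracy = 71.3        # hypothetical Global-Accuracy for this round
if cfg['pivot'] < mean_global_accuracy:
    cfg['pivot'] = mean_global_accuracy   # the checkpoint would be copied to *_best.pt
```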
/main_transformer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import os
5 | import random
6 | import shutil
7 | import time
8 |
9 | import numpy as np
10 | import ray
11 | import torch
12 | import torch.backends.cudnn as cudnn
13 | from tqdm import tqdm
14 |
15 | import models
16 | from config import cfg
17 | from data import fetch_dataset
18 | from logger import Logger
19 | from transformer_client import TransformerClient
20 | from utils import save, process_control, process_dataset, make_optimizer, make_scheduler
21 |
22 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
23 |
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 |
27 |
28 |
29 |
30 | parser = argparse.ArgumentParser(description='cfg')
31 | for k in cfg:
32 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k))
33 | parser.add_argument('--control_name', default=None, type=str)
34 | parser.add_argument('--seed', default=None, type=int)
35 | parser.add_argument('--devices', default=None, nargs='+', type=int)
36 | parser.add_argument('--algo', default='roll', type=str)
37 | # parser.add_argument('--lr', default=None, type=int)
38 | parser.add_argument('--g_epochs', default=None, type=int)
39 | parser.add_argument('--l_epochs', default=None, type=int)
40 | parser.add_argument('--schedule', default=None, nargs='+', type=int)
41 | # parser.add_argument('--exp_name', default=None, type=str)
42 | args = vars(parser.parse_args())
43 | cfg['init_seed'] = int(args['seed'])
44 | if args['algo'] == 'roll':
45 | from transformer_server import TransformerServerRollSO as Server
46 | elif args['algo'] == 'random':
47 | from transformer_server import TransformerServerRandomSO as Server
48 | elif args['algo'] == 'static':
49 | from transformer_server import TransformerServerStaticSO as Server
50 |
51 |
52 |
53 | if args['devices'] is not None:
54 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in args['devices']])
55 |
56 | for k in cfg:
57 | cfg[k] = args[k]
58 | if args['control_name']:
59 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \
60 | if args['control_name'] != 'None' else {}
61 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']])
62 | cfg['pivot_metric'] = 'Global-Perplexity'
63 | cfg['pivot'] = float('inf')
64 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Perplexity']},
65 | 'test': {'Global': ['Global-Loss', 'Global-Accuracy'],
66 | 'Local': ['Local-Loss', 'Local-Accuracy']}}
67 | # ray.init(_temp_dir='/egr/research-zhanglambda/samiul/tmp')
68 | # ray.init(_temp_dir='/localscratch/alamsami/tmp', object_store_memory=10**11)
69 |
70 | ray.init(
71 | _temp_dir='/localscratch/alamsami/tmp', object_store_memory=10 ** 11,
72 | _system_config={
73 | "object_spilling_config": json.dumps(
74 | {
75 | "type": "filesystem",
76 | "params": {
77 | "directory_path": '/egr/research-zhanglambda/samiul/tmp',
78 | }
79 | },
80 | )
81 | },
82 | )
83 |
84 |
85 | def main():
86 | process_control()
87 | if args['schedule'] is not None:
88 | cfg['milestones'] = args['schedule']
89 |
90 | if args['g_epochs'] is not None and args['l_epochs'] is not None:
91 | cfg['num_epochs'] = {'global': args['g_epochs'], 'local': args['l_epochs']}
92 | cfg['init_seed'] = int(args['seed'])
93 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments']))
94 | for i in range(cfg['num_experiments']):
95 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']]
96 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x])
97 | print('Experiment: {}'.format(cfg['model_tag']))
98 | print('Seed: {}'.format(cfg['init_seed']))
99 | run_experiment()
100 | return
101 |
102 |
103 | def run_experiment():
104 | seed = int(cfg['model_tag'].split('_')[0])
105 | torch.manual_seed(seed)
106 | torch.cuda.manual_seed(seed)
107 | random.seed(seed)
108 | np.random.seed(seed)
109 | torch.cuda.manual_seed_all(seed)
110 | torch.set_deterministic_debug_mode('default')
111 | os.environ['PYTHONHASHSEED'] = str(seed)
112 | dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
113 | process_dataset(dataset)
114 | global_model = models.transformer_nwp(model_rate=cfg["global_model_rate"], cfg=cfg)
115 | optimizer = make_optimizer(global_model, cfg['lr'])
116 | scheduler = make_scheduler(optimizer)
117 | last_epoch = 1
118 | data_split, label_split = dataset['train'], dataset['train']
119 | num_active_users = cfg['active_user']
120 |
121 | logger_path = os.path.join('output', 'runs', 'train_{}'.format(f'{cfg["model_tag"]}_{cfg["exp_name"]}'))
122 | logger = Logger(logger_path)
123 |
124 | cfg_id = ray.put(cfg)
125 | dataset_ref = dataset['train']
126 |
127 | server = Server(global_model, cfg['model_rate'], dataset_ref, cfg_id)
128 | num_users_per_step = 8
129 | local = [TransformerClient.remote(logger.log_path, [cfg_id]) for _ in range(num_users_per_step)]
130 | # local = [TransformerClient(logger.log_path, [cfg_id]) for _ in range(num_active_users)]
131 |
132 | for epoch in range(last_epoch, cfg['num_epochs']['global'] + 1):
133 | t0 = time.time()
134 | logger.safe(True)
135 | scheduler.step()
136 | lr = optimizer.param_groups[0]['lr']
137 | local, configs = server.broadcast(local, lr)
138 | t1 = time.time()
139 |
140 | start_time = time.time()
141 | local_parameters = []
142 | for user_start_idx in range(0, num_active_users, num_users_per_step):
143 | idxs = list(range(user_start_idx, min(num_active_users, user_start_idx + num_users_per_step)))
144 | sel_cfg = [configs[idx] for idx in idxs]
145 | [client.update.remote(*config) for client, config in zip(local, sel_cfg)]
146 | dt = ray.get([client.step.remote(user_start_idx + m, num_active_users, start_time)
147 | for m, client in enumerate(local[:len(sel_cfg)])])
148 |             local_parameters.extend(dt)
149 | torch.cuda.empty_cache()
150 | t2 = time.time()
151 | server.step(local_parameters)
152 | t3 = time.time()
153 |
154 | global_model = server.global_model
155 | test_model = global_model
156 | t4 = time.time()
157 |         # evaluate the assembled global model every round
158 |         test(dataset['test'], test_model, logger, epoch, local)
159 | t5 = time.time()
160 | logger.safe(False)
161 | model_state_dict = global_model.state_dict()
162 | if epoch % 20 == 1:
163 | save_result = {
164 | 'cfg': cfg, 'epoch': epoch + 1, 'data_split': data_split, 'label_split': label_split,
165 | 'model_dict': model_state_dict, 'optimizer_dict': optimizer.state_dict(),
166 | 'scheduler_dict': scheduler.state_dict(), 'logger': logger}
167 | save(save_result, './output/model/{}_checkpoint.pt'.format(cfg['model_tag']))
168 | if cfg['pivot'] < logger.mean['test/{}'.format(cfg['pivot_metric'])]:
169 | cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])]
170 | shutil.copy('./output/model/{}_checkpoint.pt'.format(cfg['model_tag']),
171 | './output/model/{}_best.pt'.format(cfg['model_tag']))
172 | logger.reset()
173 | t6 = time.time()
174 | print(f'Broadcast Time : {datetime.timedelta(seconds=t1 - t0)}')
175 | print(f'Client Step Time : {datetime.timedelta(seconds=t2 - t1)}')
176 | print(f'Server Step Time : {datetime.timedelta(seconds=t3 - t2)}')
177 | print(f'Stats Time : {datetime.timedelta(seconds=t4 - t3)}')
178 | print(f'Test Time : {datetime.timedelta(seconds=t5 - t4)}')
179 | print(f'Output Copy Time : {datetime.timedelta(seconds=t6 - t5)}')
180 | print(f'<>: {datetime.timedelta(seconds=t6 - t0)}')
181 | test_model = None
182 | global_model = None
183 | model_state_dict = None
184 | torch.cuda.empty_cache()
185 | logger.safe(False)
186 | [ray.kill(client) for client in local]
187 | return
188 |
189 |
190 | def test(dataset, model, logger, epoch, local):
191 | num_users_per_step = len(local)
192 | num_test_users = 200 # len(dataset)
193 | if epoch % 600 == 0:
194 | num_test_users = 5000
195 | model_id = ray.put(model)
196 | with torch.no_grad():
197 | model.train(False)
198 | sel_cl = np.random.choice(len(dataset), num_test_users)
199 | for user_start_idx in tqdm(range(0, num_test_users, num_users_per_step)):
200 | processes = []
201 | for user_idx in range(user_start_idx, min(user_start_idx + num_users_per_step, num_test_users)):
202 | processes.append(local[user_idx % num_users_per_step]
203 | .test_model_for_user
204 | .remote(user_idx,
205 | [ray.put(dataset[sel_cl[user_idx]]), model_id]))
206 | results = ray.get(processes)
207 | for result in results:
208 | if result:
209 | evaluation, input_size = result[0]
210 | logger.append(evaluation, 'test', input_size)
211 |
212 | info = {'info': ['Model: {}'.format(cfg['model_tag']),
213 | 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
214 | logger.append(info, 'test', mean=False)
215 | logger.write('test', cfg['metric_name']['test']['Local'])
216 |     return
217 |
218 |
219 | if __name__ == "__main__":
220 | main()
221 |
--------------------------------------------------------------------------------
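
Unlike `main_resnet.py`, the loop above reuses a fixed pool of `num_users_per_step` actors across chunks of the active users. A standalone sketch of the chunking with illustrative values:

```python
num_active_users, num_users_per_step = 20, 8   # illustrative values
for user_start_idx in range(0, num_active_users, num_users_per_step):
    idxs = list(range(user_start_idx,
                      min(num_active_users, user_start_idx + num_users_per_step)))
    print(idxs)
# [0, 1, 2, 3, 4, 5, 6, 7]
# [8, 9, 10, 11, 12, 13, 14, 15]
# [16, 17, 18, 19]
```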
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import *
2 |
--------------------------------------------------------------------------------
/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from utils import recur
5 |
6 |
7 | def Accuracy(output, target, topk=1):
8 | with torch.no_grad():
9 | batch_size = target.size(0)
10 | pred_k = output.topk(topk, 1, True, True)[1]
11 | correct_k = pred_k.eq(target.view(-1, 1).expand_as(pred_k)).float().sum()
12 | acc = (correct_k * (100.0 / batch_size)).item()
13 | return acc
14 |
15 |
16 | def Perplexity(output, target):
17 | with torch.no_grad():
18 | # label_mask = torch.arange(output.size(1), device=output.device)[output.sum(dim=[0,2]) != 0]
19 | # label_map = output.new_zeros(output.size(1), device=output.device, dtype=torch.long)
20 | # output = output[:, label_mask,]
21 | # label_map[label_mask] = torch.arange(output.size(1), device=output.device)
22 | # target = label_map[target]
23 | ce = F.cross_entropy(output, target)
24 | perplexity = torch.exp(ce).item()
25 | return perplexity
26 |
27 |
28 | class Metric(object):
29 | def __init__(self):
30 | self.metric = {'Loss': (lambda input, output: output['loss'].item()),
31 | 'Local-Loss': (lambda input, output: output['loss'].item()),
32 | 'Global-Loss': (lambda input, output: output['loss'].item()),
33 | 'Accuracy': (lambda input, output: recur(Accuracy, output['score'], input['label'])),
34 | 'Local-Accuracy': (lambda input, output: recur(Accuracy, output['score'], input['label'])),
35 | 'Global-Accuracy': (lambda input, output: recur(Accuracy, output['score'], input['label'])),
36 | 'Perplexity': (lambda input, output: recur(Perplexity, output['score'], input['label'])),
37 | 'Local-Perplexity': (lambda input, output: recur(Perplexity, output['score'], input['label'])),
38 | 'Global-Perplexity': (lambda input, output: recur(Perplexity, output['score'], input['label']))}
39 |
40 | def evaluate(self, metric_names, input, output):
41 | evaluation = {}
42 | for metric_name in metric_names:
43 | evaluation[metric_name] = self.metric[metric_name](input, output)
44 | return evaluation
45 |
--------------------------------------------------------------------------------
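
A toy check of the two metric functions above, assuming the package is importable as `metrics` from the repository root:

```python
import torch

from metrics import Accuracy, Perplexity

output = torch.tensor([[2.0, 0.5],
                       [1.5, 0.3]])      # logits: both samples predict class 0
target = torch.tensor([0, 1])
print(Accuracy(output, target))          # 50.0: one of two predictions is correct
print(Perplexity(output, target))        # exp(mean cross-entropy) on the same logits
```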
/models/conv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from config import cfg
7 | from modules import Scaler
8 | from .utils import init_param
9 |
10 |
11 | class Conv(nn.Module):
12 | def __init__(self, data_shape, hidden_size, classes_size, rate=1, track=False):
13 | super().__init__()
14 | if cfg['norm'] == 'bn':
15 | norm = nn.BatchNorm2d(hidden_size[0], momentum=None, track_running_stats=track)
16 | elif cfg['norm'] == 'in':
17 | norm = nn.GroupNorm(hidden_size[0], hidden_size[0])
18 | elif cfg['norm'] == 'ln':
19 | norm = nn.GroupNorm(1, hidden_size[0])
20 | elif cfg['norm'] == 'gn':
21 | norm = nn.GroupNorm(4, hidden_size[0])
22 | elif cfg['norm'] == 'none':
23 | norm = nn.Identity()
24 | else:
25 | raise ValueError('Not valid norm')
26 | if cfg['scale']:
27 | scaler = Scaler(rate)
28 | else:
29 | scaler = nn.Identity()
30 | blocks = [nn.Conv2d(data_shape[0], hidden_size[0], 3, 1, 1),
31 | scaler,
32 | norm,
33 | nn.ReLU(inplace=True),
34 | nn.MaxPool2d(2)]
35 | for i in range(len(hidden_size) - 1):
36 | if cfg['norm'] == 'bn':
37 | norm = nn.BatchNorm2d(hidden_size[i + 1], momentum=None, track_running_stats=track)
38 | elif cfg['norm'] == 'in':
39 | norm = nn.GroupNorm(hidden_size[i + 1], hidden_size[i + 1])
40 | elif cfg['norm'] == 'ln':
41 | norm = nn.GroupNorm(1, hidden_size[i + 1])
42 | elif cfg['norm'] == 'gn':
43 | norm = nn.GroupNorm(4, hidden_size[i + 1])
44 | elif cfg['norm'] == 'none':
45 | norm = nn.Identity()
46 | else:
47 | raise ValueError('Not valid norm')
48 | if cfg['scale']:
49 | scaler = Scaler(rate)
50 | else:
51 | scaler = nn.Identity()
52 | blocks.extend([nn.Conv2d(hidden_size[i], hidden_size[i + 1], 3, 1, 1),
53 | scaler,
54 | norm,
55 | nn.ReLU(inplace=True),
56 | nn.MaxPool2d(2)])
57 | blocks = blocks[:-1]
58 | blocks.extend([nn.AdaptiveAvgPool2d(1),
59 | nn.Flatten(),
60 | nn.Linear(hidden_size[-1], classes_size)])
61 | self.blocks = nn.Sequential(*blocks)
62 |
63 | def forward(self, input):
64 | output = {'loss': torch.tensor(0, device=cfg['device'], dtype=torch.float32)}
65 | x = input['img']
66 | out = self.blocks(x)
67 | if 'label_split' in input and cfg['mask']:
68 | label_mask = torch.zeros(cfg['classes_size'], device=out.device)
69 | label_mask[input['label_split']] = 1
70 | out = out.masked_fill(label_mask == 0, 0)
71 | output['score'] = out
72 | output['loss'] = F.cross_entropy(out, input['label'], reduction='mean')
73 | return output
74 |
75 |
76 | def conv(model_rate=1, track=False):
77 | data_shape = cfg['data_shape']
78 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['conv']['hidden_size']]
79 | classes_size = cfg['classes_size']
80 | scaler_rate = model_rate / cfg['global_model_rate']
81 | model = Conv(data_shape, hidden_size, classes_size, scaler_rate, track)
82 | model.apply(init_param)
83 | return model
84 |
--------------------------------------------------------------------------------
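
The `conv` factory above shrinks every hidden width by `model_rate` and passes a rate relative to the full model on to `Scaler`. A standalone sketch of that arithmetic with a hypothetical base width list:

```python
import numpy as np

hidden_size_full = [64, 128, 256, 512]    # hypothetical cfg['conv']['hidden_size']
model_rate, global_model_rate = 0.5, 1.0
hidden_size = [int(np.ceil(model_rate * x)) for x in hidden_size_full]
scaler_rate = model_rate / global_model_rate
print(hidden_size)    # [32, 64, 128, 256]: every conv layer is halved in width
print(scaler_rate)    # 0.5: the rate handed to Scaler for activation rescaling
```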
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from modules import Scaler
7 | from .utils import init_param
8 |
9 |
10 | class Block(nn.Module):
11 | expansion = 1
12 |
13 | def __init__(self, in_planes, planes, stride, rate, track, cfg):
14 | super(Block, self).__init__()
15 | if cfg['norm'] == 'bn':
16 | n1 = nn.BatchNorm2d(in_planes, momentum=None, track_running_stats=track)
17 | n2 = nn.BatchNorm2d(planes, momentum=None, track_running_stats=track)
18 | elif cfg['norm'] == 'in':
19 | n1 = nn.GroupNorm(in_planes, in_planes)
20 | n2 = nn.GroupNorm(planes, planes)
21 | elif cfg['norm'] == 'ln':
22 | n1 = nn.GroupNorm(1, in_planes)
23 | n2 = nn.GroupNorm(1, planes)
24 | elif cfg['norm'] == 'gn':
25 | n1 = nn.GroupNorm(4, in_planes)
26 | n2 = nn.GroupNorm(4, planes)
27 | elif cfg['norm'] == 'none':
28 | n1 = nn.Identity()
29 | n2 = nn.Identity()
30 | else:
31 | raise ValueError('Not valid norm')
32 | self.n1 = n1
33 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
34 | self.n2 = n2
35 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
36 | if cfg['scale']:
37 | self.scaler = Scaler(rate)
38 | else:
39 | self.scaler = nn.Identity()
40 |
41 | if stride != 1 or in_planes != self.expansion * planes:
42 | self.shortcut = nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
43 |
44 | def forward(self, x):
45 | out = F.relu(self.n1(self.scaler(x)))
46 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
47 | out = self.conv1(out)
48 | out = self.conv2(F.relu(self.n2(self.scaler(out))))
49 | out += shortcut
50 | return out
51 |
52 |
53 | class Bottleneck(nn.Module):
54 | expansion = 4
55 |
56 | def __init__(self, in_planes, planes, stride, rate, track, cfg):
57 | super(Bottleneck, self).__init__()
58 | if cfg['norm'] == 'bn':
59 | n1 = nn.BatchNorm2d(in_planes, momentum=None, track_running_stats=track)
60 | n2 = nn.BatchNorm2d(planes, momentum=None, track_running_stats=track)
61 | n3 = nn.BatchNorm2d(planes, momentum=None, track_running_stats=track)
62 | elif cfg['norm'] == 'in':
63 | n1 = nn.GroupNorm(in_planes, in_planes)
64 | n2 = nn.GroupNorm(planes, planes)
65 | n3 = nn.GroupNorm(planes, planes)
66 | elif cfg['norm'] == 'ln':
67 | n1 = nn.GroupNorm(1, in_planes)
68 | n2 = nn.GroupNorm(1, planes)
69 | n3 = nn.GroupNorm(1, planes)
70 | elif cfg['norm'] == 'gn':
71 | n1 = nn.GroupNorm(4, in_planes)
72 | n2 = nn.GroupNorm(4, planes)
73 | n3 = nn.GroupNorm(4, planes)
74 | elif cfg['norm'] == 'none':
75 | n1 = nn.Identity()
76 | n2 = nn.Identity()
77 | n3 = nn.Identity()
78 | else:
79 | raise ValueError('Not valid norm')
80 | self.n1 = n1
81 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
82 | self.n2 = n2
83 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
84 | self.n3 = n3
85 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
86 | if cfg['scale']:
87 | self.scaler = Scaler(rate)
88 | else:
89 | self.scaler = nn.Identity()
90 |
91 | if stride != 1 or in_planes != self.expansion * planes:
92 | self.shortcut = nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
93 |
94 | def forward(self, x):
95 | out = F.relu(self.n1(self.scaler(x)))
96 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
97 | out = self.conv1(out)
98 | out = self.conv2(F.relu(self.n2(self.scaler(out))))
99 | out = self.conv3(F.relu(self.n3(self.scaler(out))))
100 | out += shortcut
101 | return out
102 |
103 |
104 | class ResNet(nn.Module):
105 | def __init__(self, data_shape, hidden_size, block, num_blocks, num_classes, rate, track, cfg):
106 | super(ResNet, self).__init__()
107 | self.cfg = cfg
108 | self.in_planes = hidden_size[0]
109 | self.conv1 = nn.Conv2d(data_shape[0], hidden_size[0], kernel_size=3, stride=1, padding=1, bias=False)
110 | self.layer1 = self._make_layer(block, hidden_size[0], num_blocks[0], stride=1, rate=rate, track=track)
111 | self.layer2 = self._make_layer(block, hidden_size[1], num_blocks[1], stride=2, rate=rate, track=track)
112 | self.layer3 = self._make_layer(block, hidden_size[2], num_blocks[2], stride=2, rate=rate, track=track)
113 | self.layer4 = self._make_layer(block, hidden_size[3], num_blocks[3], stride=2, rate=rate, track=track)
114 | if cfg['norm'] == 'bn':
115 | n4 = nn.BatchNorm2d(hidden_size[3] * block.expansion, momentum=None, track_running_stats=track)
116 | elif cfg['norm'] == 'in':
117 | n4 = nn.GroupNorm(hidden_size[3] * block.expansion, hidden_size[3] * block.expansion)
118 | elif cfg['norm'] == 'ln':
119 | n4 = nn.GroupNorm(1, hidden_size[3] * block.expansion)
120 | elif cfg['norm'] == 'gn':
121 | n4 = nn.GroupNorm(4, hidden_size[3] * block.expansion)
122 | elif cfg['norm'] == 'none':
123 | n4 = nn.Identity()
124 | else:
125 | raise ValueError('Not valid norm')
126 | self.n4 = n4
127 | if cfg['scale']:
128 | self.scaler = Scaler(rate)
129 | else:
130 | self.scaler = nn.Identity()
131 | self.linear = nn.Linear(hidden_size[3] * block.expansion, num_classes)
132 |
133 | def _make_layer(self, block, planes, num_blocks, stride, rate, track):
134 | cfg = self.cfg
135 | strides = [stride] + [1] * (num_blocks - 1)
136 | layers = []
137 | for stride in strides:
138 | layers.append(block(self.in_planes, planes, stride, rate, track, cfg))
139 | self.in_planes = planes * block.expansion
140 | return nn.Sequential(*layers)
141 |
142 | def forward(self, input):
143 | cfg = self.cfg
144 | output = {}
145 | x = input['img']
146 | out = self.conv1(x)
147 | out = self.layer1(out)
148 | out = self.layer2(out)
149 | out = self.layer3(out)
150 | out = self.layer4(out)
151 | out = F.relu(self.n4(self.scaler(out)))
152 | out = F.adaptive_avg_pool2d(out, 1)
153 | out = out.view(out.size(0), -1)
154 | out = self.linear(out)
155 | if 'label_split' in input and cfg['mask']:
156 | label_mask = torch.zeros(cfg['classes_size'], device=out.device)
157 | label_mask[input['label_split']] = 1
158 | out = out.masked_fill(label_mask == 0, 0)
159 | output['score'] = out
160 | output['loss'] = F.cross_entropy(output['score'], input['label'])
161 | return output
162 |
163 |
164 | def resnet18(model_rate=1, track=False, cfg=None):
165 | data_shape = cfg['data_shape']
166 | classes_size = cfg['classes_size']
167 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']]
168 | scaler_rate = model_rate / cfg['global_model_rate']
169 | model = ResNet(data_shape, hidden_size, Block, [2, 2, 2, 2], classes_size, scaler_rate, track, cfg)
170 | model.apply(init_param)
171 | return model
172 |
173 |
174 | def resnet34(model_rate=1, track=False, cfg=None):
175 | data_shape = cfg['data_shape']
176 | classes_size = cfg['classes_size']
177 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']]
178 | scaler_rate = model_rate / cfg['global_model_rate']
179 | model = ResNet(data_shape, hidden_size, Block, [3, 4, 6, 3], classes_size, scaler_rate, track, cfg)
180 | model.apply(init_param)
181 | return model
182 |
183 |
184 | def resnet50(model_rate=1, track=False, cfg=None):
185 | data_shape = cfg['data_shape']
186 | classes_size = cfg['classes_size']
187 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']]
188 | scaler_rate = model_rate / cfg['global_model_rate']
189 | model = ResNet(data_shape, hidden_size, Bottleneck, [3, 4, 6, 3], classes_size, scaler_rate, track, cfg)
190 | model.apply(init_param)
191 | return model
192 |
193 |
194 | def resnet101(model_rate=1, track=False, cfg=None):
195 | data_shape = cfg['data_shape']
196 | classes_size = cfg['classes_size']
197 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']]
198 | scaler_rate = model_rate / cfg['global_model_rate']
199 | model = ResNet(data_shape, hidden_size, Bottleneck, [3, 4, 23, 3], classes_size, scaler_rate, track, cfg)
200 | model.apply(init_param)
201 | return model
202 |
203 |
204 | def resnet152(model_rate=1, track=False, cfg=None):
205 | data_shape = cfg['data_shape']
206 | classes_size = cfg['classes_size']
207 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']]
208 | scaler_rate = model_rate / cfg['global_model_rate']
209 | model = ResNet(data_shape, hidden_size, Bottleneck, [3, 8, 36, 3], classes_size, scaler_rate, track, cfg)
210 | model.apply(init_param)
211 | return model
212 |
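213 |
214 | if __name__ == '__main__':
215 |     # Minimal usage sketch with toy cfg values (the real ones come from
216 |     # config.yml). Run: python -m models.resnet
217 |     toy_cfg = {'data_shape': [3, 32, 32], 'classes_size': 10,
218 |                'resnet': {'hidden_size': [64, 128, 256, 512]},
219 |                'global_model_rate': 1.0, 'norm': 'bn', 'scale': True, 'mask': False}
220 |     model = resnet18(model_rate=0.5, track=False, cfg=toy_cfg)  # half-width client model
221 |     batch = {'img': torch.randn(2, 3, 32, 32), 'label': torch.randint(0, 10, (2,))}
222 |     out = model(batch)
223 |     print(out['score'].shape, out['loss'].item())  # torch.Size([2, 10]) ...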
--------------------------------------------------------------------------------
/models/transformer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn import TransformerEncoder
6 |
7 | from modules import Scaler
8 | from .utils import init_param
9 |
10 |
11 | class PositionalEmbedding(nn.Module):
12 | def __init__(self, embedding_size, cfg):
13 | super().__init__()
14 | self.positional_embedding = nn.Embedding(cfg['bptt'], embedding_size)
15 |
16 | def forward(self, x):
17 | N, S = x.size()
18 | position = torch.arange(S, dtype=torch.long, device=x.device).unsqueeze(0).expand((N, S))
19 | x = self.positional_embedding(position)
20 | return x
21 |
22 |
23 | class TransformerEmbedding(nn.Module):
24 | def __init__(self, num_tokens, embedding_size, dropout, rate, cfg):
25 | super().__init__()
26 | self.num_tokens = num_tokens
27 | self.embedding_size = embedding_size
28 | self.positional_embedding = PositionalEmbedding(embedding_size, cfg)
29 | self.embedding = nn.Embedding(num_tokens + 1, embedding_size)
30 | self.norm = nn.LayerNorm(embedding_size)
31 | self.dropout = nn.Dropout(dropout)
32 | self.scaler = Scaler(rate)
33 |
34 | def forward(self, src):
35 | src = self.scaler(self.embedding(src)) + self.scaler(self.positional_embedding(src))
36 | src = self.dropout(self.norm(src))
37 | return src
38 |
39 |
40 | class ScaledDotProduct(nn.Module):
41 | def __init__(self, temperature):
42 | super().__init__()
43 | self.temperature = temperature
44 |
45 | def forward(self, q, k, v, mask=None):
46 | scores = q.matmul(k.transpose(-2, -1)) / self.temperature
47 | seq_len = scores.shape[-1]
48 | h = scores.shape[0]
49 |         mask = torch.tril(torch.ones((h, seq_len, seq_len), device=scores.device))  # causal mask; the mask argument is ignored
50 | scores = scores.masked_fill(mask == 0, float('-inf'))
51 | attn = F.softmax(scores, dim=-1)
52 | output = torch.matmul(attn, v)
53 | return output, attn
54 |
55 |
56 | class MultiheadAttention(nn.Module):
57 | def __init__(self, embedding_size, num_heads, rate):
58 | super().__init__()
59 | self.embedding_size = embedding_size
60 | self.num_heads = num_heads
61 | self.linear_q = nn.Linear(embedding_size, embedding_size)
62 | self.linear_k = nn.Linear(embedding_size, embedding_size)
63 | self.linear_v = nn.Linear(embedding_size, embedding_size)
64 | self.linear_o = nn.Linear(embedding_size, embedding_size)
65 | self.attention = ScaledDotProduct(temperature=(embedding_size // num_heads) ** 0.5)
66 | self.scaler = Scaler(rate)
67 |
68 | def _reshape_to_batches(self, x):
69 | batch_size, seq_len, in_feature = x.size()
70 | sub_dim = in_feature // self.num_heads
71 | return x.reshape(batch_size, seq_len, self.num_heads, sub_dim).permute(0, 2, 1, 3) \
72 | .reshape(batch_size * self.num_heads, seq_len, sub_dim)
73 |
74 | def _reshape_from_batches(self, x):
75 | batch_size, seq_len, in_feature = x.size()
76 | batch_size //= self.num_heads
77 | out_dim = in_feature * self.num_heads
78 | return x.reshape(batch_size, self.num_heads, seq_len, in_feature).permute(0, 2, 1, 3) \
79 | .reshape(batch_size, seq_len, out_dim)
80 |
81 | def forward(self, q, k, v, mask=None):
82 | q, k, v = self.scaler(self.linear_q(q)), self.scaler(self.linear_k(k)), self.scaler(self.linear_v(v))
83 | q, k, v = self._reshape_to_batches(q), self._reshape_to_batches(k), self._reshape_to_batches(v)
84 | q, attn = self.attention(q, k, v, mask)
85 | q = self._reshape_from_batches(q)
86 | q = self.scaler(self.linear_o(q))
87 | return q, attn
88 |
89 |
90 | class TransformerEncoderLayer(nn.Module):
91 | def __init__(self, embedding_size, num_heads, hidden_size, dropout, rate):
92 | super().__init__()
93 | self.mha = MultiheadAttention(embedding_size, num_heads, rate=rate)
94 | self.dropout = nn.Dropout(dropout)
95 | self.norm1 = nn.LayerNorm(embedding_size)
96 | self.linear1 = nn.Linear(embedding_size, hidden_size)
97 | self.dropout1 = nn.Dropout(dropout)
98 | self.linear2 = nn.Linear(hidden_size, embedding_size)
99 | self.dropout2 = nn.Dropout(dropout)
100 | self.norm2 = nn.LayerNorm(embedding_size)
101 | self.scaler = Scaler(rate)
102 | self.activation = nn.GELU()
103 | self.init_param()
104 |
105 | def init_param(self):
106 | self.linear1.weight.data.normal_(mean=0.0, std=0.02)
107 | self.linear2.weight.data.normal_(mean=0.0, std=0.02)
108 | self.norm1.weight.data.fill_(1.0)
109 | self.norm1.bias.data.zero_()
110 | self.norm2.weight.data.fill_(1.0)
111 | self.norm2.bias.data.zero_()
112 | return
113 |
114 | def forward(self, src, src_mask=None, src_key_padding_mask=None):
115 | attn_output, _ = self.mha(src, src, src, mask=src_mask)
116 | src = src + self.dropout(attn_output)
117 | src = self.norm1(src)
118 | src2 = self.scaler(self.linear2(self.dropout1(self.activation(self.scaler(self.linear1(src))))))
119 | src = src + self.dropout2(src2)
120 | src = self.norm2(src)
121 | return src
122 |
123 |
124 | class Decoder(nn.Module):
125 | def __init__(self, num_tokens, embedding_size, rate):
126 | super().__init__()
127 | self.linear1 = nn.Linear(embedding_size, embedding_size)
128 | self.scaler = Scaler(rate)
129 | self.activation = nn.GELU()
130 | self.norm1 = nn.LayerNorm(embedding_size)
131 | self.linear2 = nn.Linear(embedding_size, num_tokens)
132 |
133 | def forward(self, src):
134 | out = self.linear2(self.norm1(self.activation(self.scaler(self.linear1(src)))))
135 | return out
136 |
137 |
138 | class Transformer(nn.Module):
139 | def __init__(self, num_tokens, embedding_size, num_heads, hidden_size, num_layers, dropout, rate, cfg):
140 | super().__init__()
141 | self.num_tokens = num_tokens
142 | self.transformer_embedding = TransformerEmbedding(num_tokens, embedding_size, dropout, rate, cfg)
143 | encoder_layers = TransformerEncoderLayer(embedding_size, num_heads, hidden_size, dropout, rate)
144 | self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
145 | self.decoder = Decoder(num_tokens, embedding_size, rate)
146 | self.cfg = cfg
147 |
148 | def forward(self, input):
149 | cfg = self.cfg
150 | output = {}
151 | src = input['label'].clone()
152 | N, S = src.size()
153 | d = torch.distributions.bernoulli.Bernoulli(probs=cfg['mask_rate'])
154 | mask = d.sample((N, S)).to(src.device)
155 |         src = src.masked_fill(mask == 1, self.num_tokens).detach()  # masked positions take the extra [MASK] token id
156 | src = self.transformer_embedding(src)
157 | src = self.transformer_encoder(src)
158 | out = self.decoder(src)
159 | out = out.permute(0, 2, 1)
160 | if 'label_split' in input and cfg['mask']:
161 | label_mask = torch.zeros((cfg['num_tokens'], 1), device=out.device)
162 | label_mask[input['label_split']] = 1
163 | out = out.masked_fill(label_mask == 0, 0)
164 | output['score'] = out
165 | output['loss'] = F.cross_entropy(output['score'], input['label'])
166 | return output
167 |
168 |
169 | def transformer(model_rate=1, cfg=None):
170 | num_tokens = cfg['num_tokens']
171 | embedding_size = int(np.ceil(model_rate * cfg['transformer']['embedding_size']))
172 | num_heads = cfg['transformer']['num_heads']
173 | hidden_size = int(np.ceil(model_rate * cfg['transformer']['hidden_size']))
174 | num_layers = cfg['transformer']['num_layers']
175 | dropout = cfg['transformer']['dropout']
176 | scaler_rate = model_rate / cfg['global_model_rate']
177 | model = Transformer(num_tokens, embedding_size, num_heads, hidden_size, num_layers, dropout, scaler_rate, cfg)
178 | model.apply(init_param)
179 | return model
180 |
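181 |
182 | if __name__ == '__main__':
183 |     # Minimal usage sketch with toy cfg values (the real ones come from
184 |     # config.yml). Run: python -m models.transformer
185 |     toy_cfg = {'num_tokens': 100, 'bptt': 64, 'mask_rate': 0.15, 'mask': False,
186 |                'global_model_rate': 1.0,
187 |                'transformer': {'embedding_size': 32, 'num_heads': 4,
188 |                                'hidden_size': 64, 'num_layers': 2, 'dropout': 0.1}}
189 |     model = transformer(model_rate=1, cfg=toy_cfg)
190 |     batch = {'label': torch.randint(0, 100, (2, 16))}  # (N, S) token ids
191 |     out = model(batch)
192 |     print(out['score'].shape, out['loss'].item())  # torch.Size([2, 100, 16]) ...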
--------------------------------------------------------------------------------
/models/transformer_nwp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn import TransformerEncoder
6 |
7 | from modules import Scaler
8 | from .utils import init_param
9 |
10 |
11 | class PositionalEmbedding(nn.Module):
12 | def __init__(self, embedding_size, cfg):
13 | super().__init__()
14 | self.positional_embedding = nn.Embedding(cfg['bptt'], embedding_size)
15 |
16 | def forward(self, x):
17 | N, S = x.size()
18 | position = torch.arange(S, dtype=torch.long, device=x.device).unsqueeze(0).expand((N, S))
19 | x = self.positional_embedding(position)
20 | return x
21 |
22 |
23 | class TransformerEmbedding(nn.Module):
24 | def __init__(self, num_tokens, embedding_size, dropout, rate, cfg):
25 | super().__init__()
26 | self.num_tokens = num_tokens
27 | self.embedding_size = embedding_size
28 | self.positional_embedding = PositionalEmbedding(embedding_size, cfg)
29 | self.embedding = nn.Embedding(num_tokens + 1, embedding_size)
30 | self.norm = nn.LayerNorm(embedding_size)
31 | self.dropout = nn.Dropout(dropout)
32 | self.scaler = Scaler(rate)
33 |
34 | def forward(self, src):
35 | src = self.scaler(self.embedding(src)) + self.scaler(self.positional_embedding(src))
36 | src = self.dropout(self.norm(src))
37 | return src
38 |
39 |
40 | class ScaledDotProduct(nn.Module):
41 | def __init__(self, temperature):
42 | super().__init__()
43 | self.temperature = temperature
44 |
45 | def forward(self, q, k, v, mask=None):
46 | scores = q.matmul(k.transpose(-2, -1)) / self.temperature
47 | seq_len = scores.shape[-1]
48 | h = scores.shape[0]
49 |         mask = torch.tril(torch.ones((h, seq_len, seq_len), device=scores.device))  # causal mask; the mask argument is ignored
50 | scores = scores.masked_fill(mask == 0, float('-inf'))
51 | attn = F.softmax(scores, dim=-1)
52 | output = torch.matmul(attn, v)
53 | return output, attn
54 |
55 |
56 | class MultiheadAttention(nn.Module):
57 | def __init__(self, embedding_size, num_heads, rate):
58 | super().__init__()
59 | self.embedding_size = embedding_size
60 | self.num_heads = num_heads
61 | self.linear_q = nn.Linear(embedding_size, embedding_size)
62 | self.linear_k = nn.Linear(embedding_size, embedding_size)
63 | self.linear_v = nn.Linear(embedding_size, embedding_size)
64 | self.linear_o = nn.Linear(embedding_size, embedding_size)
65 | self.attention = ScaledDotProduct(temperature=(embedding_size // num_heads) ** 0.5)
66 | self.scaler = Scaler(rate)
67 |
68 | def _reshape_to_batches(self, x):
69 | batch_size, seq_len, in_feature = x.size()
70 | sub_dim = in_feature // self.num_heads
71 | return x.reshape(batch_size, seq_len, self.num_heads, sub_dim).permute(0, 2, 1, 3) \
72 | .reshape(batch_size * self.num_heads, seq_len, sub_dim)
73 |
74 | def _reshape_from_batches(self, x):
75 | batch_size, seq_len, in_feature = x.size()
76 | batch_size //= self.num_heads
77 | out_dim = in_feature * self.num_heads
78 | return x.reshape(batch_size, self.num_heads, seq_len, in_feature).permute(0, 2, 1, 3) \
79 | .reshape(batch_size, seq_len, out_dim)
80 |
81 | def forward(self, q, k, v, mask=None):
82 | q, k, v = self.scaler(self.linear_q(q)), self.scaler(self.linear_k(k)), self.scaler(self.linear_v(v))
83 | q, k, v = self._reshape_to_batches(q), self._reshape_to_batches(k), self._reshape_to_batches(v)
84 | q, attn = self.attention(q, k, v, mask)
85 | q = self._reshape_from_batches(q)
86 | q = self.scaler(self.linear_o(q))
87 | return q, attn
88 |
89 |
90 | class TransformerEncoderLayer(nn.Module):
91 | def __init__(self, embedding_size, num_heads, hidden_size, dropout, rate):
92 | super().__init__()
93 | self.mha = MultiheadAttention(embedding_size, num_heads, rate=rate)
94 | self.dropout = nn.Dropout(dropout)
95 | self.norm1 = nn.LayerNorm(embedding_size)
96 | self.linear1 = nn.Linear(embedding_size, hidden_size)
97 | self.dropout1 = nn.Dropout(dropout)
98 | self.linear2 = nn.Linear(hidden_size, embedding_size)
99 | self.dropout2 = nn.Dropout(dropout)
100 | self.norm2 = nn.LayerNorm(embedding_size)
101 | self.scaler = Scaler(rate)
102 | self.activation = nn.GELU()
103 | self.init_param()
104 |
105 | def init_param(self):
106 | self.linear1.weight.data.normal_(mean=0.0, std=0.02)
107 | self.linear2.weight.data.normal_(mean=0.0, std=0.02)
108 | self.norm1.weight.data.fill_(1.0)
109 | self.norm1.bias.data.zero_()
110 | self.norm2.weight.data.fill_(1.0)
111 | self.norm2.bias.data.zero_()
112 | return
113 |
114 | def forward(self, src, src_mask=None, src_key_padding_mask=None):
115 | attn_output, _ = self.mha(src, src, src, mask=src_mask)
116 | src = src + self.dropout(attn_output)
117 | src = self.norm1(src)
118 | src2 = self.scaler(self.linear2(self.dropout1(self.activation(self.scaler(self.linear1(src))))))
119 | src = src + self.dropout2(src2)
120 | src = self.norm2(src)
121 | return src
122 |
123 |
124 | class Decoder(nn.Module):
125 | def __init__(self, num_tokens, embedding_size, rate):
126 | super().__init__()
127 | self.linear1 = nn.Linear(embedding_size, embedding_size)
128 | self.scaler = Scaler(rate)
129 | self.activation = nn.GELU()
130 | self.norm1 = nn.LayerNorm(embedding_size)
131 | self.linear2 = nn.Linear(embedding_size, num_tokens)
132 |
133 | def forward(self, src):
134 | out = self.linear2(self.norm1(self.activation(self.scaler(self.linear1(src)))))
135 | return out
136 |
137 |
138 | class Transformer(nn.Module):
139 | def __init__(self, num_tokens, embedding_size, num_heads, hidden_size, num_layers, dropout, rate, cfg):
140 | super().__init__()
141 | self.num_tokens = num_tokens
142 | self.transformer_embedding = TransformerEmbedding(num_tokens, embedding_size, dropout, rate, cfg)
143 | encoder_layers = TransformerEncoderLayer(embedding_size, num_heads, hidden_size, dropout, rate)
144 | self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
145 | self.decoder = Decoder(num_tokens, embedding_size, rate)
146 | self.cfg = cfg
147 |
148 | def forward(self, input):
149 | cfg = self.cfg
150 | output = {}
151 | src = input['label'][:, :-1].clone()
152 | input['label'] = input['label'][:, 1:]
153 | N, S = src.size()
154 | # d = torch.distributions.bernoulli.Bernoulli(probs=cfg['mask_rate'])
155 | #
156 | # mask = d.sample((N, S)).to(src.device)
157 | # src = src.masked_fill(mask == 1, self.num_tokens).detach()
158 | src = self.transformer_embedding(src)
159 | src = self.transformer_encoder(src)
160 | out = self.decoder(src)
161 | out = out.permute(0, 2, 1)
162 | output['score'] = out
163 | # C = np.ones((10000,))
164 | # C[:2] = 0
165 | # C = C / 9998.0
166 | # C = torch.tensor(C, dtype=torch.float).cuda()
167 | output['loss'] = F.cross_entropy(output['score'], input['label'])
168 | return output
169 |
170 |
171 | def transformer(model_rate=1, cfg=None):
172 | num_tokens = cfg['num_tokens']
173 | embedding_size = int(np.ceil(model_rate * cfg['transformer']['embedding_size']))
174 | num_heads = cfg['transformer']['num_heads']
175 | hidden_size = int(np.ceil(model_rate * cfg['transformer']['hidden_size']))
176 | num_layers = cfg['transformer']['num_layers']
177 | dropout = cfg['transformer']['dropout']
178 | scaler_rate = model_rate / cfg['global_model_rate']
179 | model = Transformer(num_tokens, embedding_size, num_heads, hidden_size, num_layers, dropout, scaler_rate, cfg)
180 | model.apply(init_param)
181 | return model
182 |
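183 | # Note: this module mirrors models/transformer.py; the substantive difference is
184 | # the forward pass above, which disables the Bernoulli masked-LM corruption and
185 | # instead shifts the batch by one token (src = label[:, :-1], target = label[:, 1:])
186 | # for next-word prediction.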
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def init_param(m):
5 | if isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
6 | m.weight.data.fill_(1)
7 | m.bias.data.zero_()
8 | elif isinstance(m, nn.Linear):
9 | m.bias.data.zero_()
10 | return m
11 |
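12 |
13 | if __name__ == '__main__':
14 |     # Minimal sketch: init_param is meant for nn.Module.apply, which visits every
15 |     # submodule recursively (the Sequential here exists only to demonstrate that).
16 |     net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 2))
17 |     net.apply(init_param)
18 |     print(net[1].weight.data.unique(), net[2].bias.data.unique())  # tensor([1.]) tensor([0.])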
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 |
--------------------------------------------------------------------------------
/modules/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Scaler(nn.Module):
5 | def __init__(self, rate):
6 | super().__init__()
7 | self.rate = rate
8 |
9 | def forward(self, input):
10 |         output = input / self.rate if self.training else input  # scale sub-model activations up at train time; identity at eval
11 | return output
12 |
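13 |
14 | if __name__ == '__main__':
15 |     # Minimal sketch: a rate-0.5 (half-width) client divides activations by its
16 |     # width rate during training so their scale matches the full global model;
17 |     # at eval time the scaler is the identity.
18 |     import torch
19 |     scaler = Scaler(0.5)
20 |     x = torch.ones(2)
21 |     scaler.train()
22 |     print(scaler(x))  # tensor([2., 2.])
23 |     scaler.eval()
24 |     print(scaler(x))  # tensor([1., 1.])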
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy~=1.23.0
2 | torch~=1.11.0
3 | torchvision~=0.12.0
4 | anytree~=2.8.0
5 | Pillow~=9.0.1
6 | tqdm~=4.64.0
7 | PyYAML~=5.4.1
8 | matplotlib~=3.3.4
9 | ray~=1.13.0
--------------------------------------------------------------------------------
/resnet_client.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import time
3 |
4 | import ray
5 | import torch
6 |
7 | from data import SplitDataset, make_data_loader
8 | from logger import Logger
9 | from metrics import Metric
10 | from models import resnet
11 | from utils import make_optimizer, collate, to_device
12 |
13 |
14 | @ray.remote(num_gpus=0.15)
15 | class ResnetClient:
16 | def __init__(self, log_path, cfg):
17 | # with open('config.yml', 'r') as f:
18 | # cfg = yaml.load(f, Loader=yaml.FullLoader)
19 | self.local_parameters = None
20 | self.m = None
21 | self.start_time = None
22 | self.num_active_users = None
23 | self.optimizer = None
24 | self.model = None
25 | self.lr = None
26 | self.label_split = None
27 | self.data_loader = None
28 | self.model_rate = None
29 | self.client_id = None
30 | cfg = ray.get(cfg[0])
31 | self.metric = Metric()
32 | self.logger = Logger(log_path)
33 | self.cfg = cfg
34 |
35 | def update(self, client_id, dataset_ref, model_ref):
36 | dataset = ray.get(dataset_ref['dataset'])
37 | data_split = ray.get(dataset_ref['split'])
38 | label_split = ray.get(dataset_ref['label_split'])
39 | local_parameters = ray.get(model_ref['local_params'])
40 | # dataset_ref = torch.load('data_store')
41 | # dataset = (dataset_ref['dataset'])
42 | # data_split = (dataset_ref['split'])
43 | # label_split = (dataset_ref['label_split'])
44 | # local_parameters = {k: v.clone().cuda() for k, v in local_parameters.items()}
45 | self.local_parameters = local_parameters
46 | self.client_id = client_id
47 | self.model_rate = model_ref['model_rate']
48 | self.data_loader = make_data_loader({'train': SplitDataset(dataset, data_split[client_id])})['train']
49 | self.label_split = label_split
50 | self.lr = model_ref['lr']
51 | self.metric = Metric()
52 |
53 | def step(self, m, num_active_users, start_time):
54 | cfg = self.cfg
55 | self.model = resnet.resnet18(model_rate=self.model_rate, cfg=self.cfg).to('cuda')
56 | self.model.load_state_dict(self.local_parameters)
57 | self.model.train(True)
58 | self.optimizer = make_optimizer(self.model, self.lr)
59 | self.m = m
60 | self.num_active_users = num_active_users
61 | self.start_time = start_time
62 | for local_epoch in range(1, cfg['num_epochs']['local'] + 1):
63 | for i, step_input in enumerate(self.data_loader):
64 | step_input = collate(step_input)
65 | input_size = step_input['img'].size(0)
66 | step_input['label_split'] = torch.tensor(self.label_split[self.client_id])
67 | step_input = to_device(step_input, 'cuda')
68 | self.model.zero_grad()
69 | output = self.model(step_input)
70 | output['loss'].backward()
71 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
72 | self.optimizer.step()
73 | evaluation = self.metric.evaluate(cfg['metric_name']['train']['Local'], step_input, output)
74 | self.logger.append(evaluation, 'train', n=input_size)
75 | self.log(local_epoch, cfg)
76 | return self.pull()
77 |
78 | def pull(self):
79 | model_state = {k: v.detach().clone().cpu() for k, v in self.model.to(self.cfg['device']).state_dict().items()}
80 | return model_state
81 |
82 | def log(self, epoch, cfg):
83 | if self.m % int((self.num_active_users * cfg['log_interval']) + 1) == 0:
84 | local_time = (time.time() - self.start_time) / (self.m + 1)
85 | epoch_finished_time = datetime.timedelta(seconds=local_time * (self.num_active_users - self.m - 1))
86 | exp_finished_time = epoch_finished_time + datetime.timedelta(
87 | seconds=round((cfg['num_epochs']['global'] - epoch) * local_time * self.num_active_users))
88 | info = {'info': ['Model: {}'.format(cfg['model_tag']),
89 | 'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * self.m / self.num_active_users),
90 | 'ID: {}({}/{})'.format(self.client_id, self.m + 1, self.num_active_users),
91 | 'Learning rate: {}'.format(self.lr),
92 | 'Rate: {}'.format(self.model_rate),
93 | 'Epoch Finished Time: {}'.format(epoch_finished_time),
94 | 'Experiment Finished Time: {}'.format(exp_finished_time)]}
95 | self.logger.append(info, 'train', mean=False)
96 | self.logger.write('train', cfg['metric_name']['train']['Local'])
97 |
98 | def test_model_for_user(self, m, ids):
99 | cfg = self.cfg
100 | metric = Metric()
101 | [dataset, data_split, model, label_split] = ray.get(ids)
102 | model = model.to('cuda')
103 | data_loader = make_data_loader({'test': SplitDataset(dataset, data_split[m])})['test']
104 | results = []
105 | for _, data_input in enumerate(data_loader):
106 | data_input = collate(data_input)
107 | input_size = data_input['img'].size(0)
108 | data_input['label_split'] = torch.tensor(label_split[m])
109 | data_input = to_device(data_input, 'cuda')
110 | output = model(data_input)
111 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
112 | evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], data_input, output)
113 | results.append((evaluation, input_size))
114 | return results
115 |
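116 |
117 | # Usage sketch (illustrative, with hypothetical names; the real call sites are in
118 | # ResnetServerRoll.broadcast in resnet_server.py):
119 | #
120 | #   cfg_id = ray.put(cfg)
121 | #   client = ResnetClient.remote('output/runs/train_tag', [cfg_id])
122 | #   client.update.remote(user_idx[m], dataset_ref,
123 | #                        {'lr': lr, 'model_rate': rate, 'local_params': param_id})
124 | #   local_state = ray.get(client.step.remote(m, num_active_users, time.time()))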
--------------------------------------------------------------------------------
/resnet_server.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import OrderedDict
3 |
4 | import numpy as np
5 | import ray
6 | import torch
7 |
8 |
9 | class ResnetServerRoll:
10 | def __init__(self, global_model, rate, dataset_ref, cfg_id):
11 | self.tau = 1e-2
12 | self.v_t = None
13 | self.beta_1 = 0.9
14 | self.beta_2 = 0.99
15 | self.eta = 1e-2
16 | self.m_t = None
17 | self.user_idx = None
18 | self.param_idx = None
19 | self.dataset_ref = dataset_ref
20 | self.cfg_id = cfg_id
21 | self.cfg = ray.get(cfg_id)
22 | self.global_model = global_model.cpu()
23 | self.global_parameters = global_model.state_dict()
24 | self.rate = rate
25 | self.label_split = ray.get(dataset_ref['label_split'])
26 | self.make_model_rate()
27 | self.num_model_partitions = 50
28 | self.model_idxs = {}
29 | self.roll_idx = {}
30 | self.rounds = 0
31 | self.tmp_counts = {}
32 | for k, v in self.global_parameters.items():
33 | self.tmp_counts[k] = torch.ones_like(v)
34 | self.reshuffle_params()
35 | self.reshuffle_rounds = 512
36 |
37 | def reshuffle_params(self):
38 | for k, v in self.global_parameters.items():
39 | if 'conv1' in k or 'conv2' in k:
40 | output_size = v.size(0)
41 | self.model_idxs[k] = torch.randperm(output_size, device=v.device)
42 | self.roll_idx[k] = 0
43 | return self.model_idxs
44 |
45 | def step(self, local_parameters, param_idx, user_idx):
46 | self.combine(local_parameters, param_idx, user_idx)
47 | self.rounds += 1
48 |         if self.rounds % self.reshuffle_rounds == 0:  # reshuffle every reshuffle_rounds rounds
49 | self.reshuffle_params()
50 |
51 | def broadcast(self, local, lr):
52 | cfg = self.cfg
53 | self.global_model.train(True)
54 | num_active_users = cfg['active_user']
55 | self.user_idx = copy.deepcopy(torch.arange(cfg['num_users'])
56 | [torch.randperm(cfg['num_users'])
57 | [:num_active_users]].tolist())
58 | local_parameters, self.param_idx = self.distribute(self.user_idx)
59 |
60 | param_ids = [ray.put(local_parameter) for local_parameter in local_parameters]
61 |
62 | ray.get([client.update.remote(self.user_idx[m],
63 | self.dataset_ref,
64 | {'lr': lr,
65 | 'model_rate': self.model_rate[self.user_idx[m]],
66 | 'local_params': param_ids[m]})
67 | for m, client in enumerate(local)])
68 | return local, self.param_idx, self.user_idx
69 |
70 | def make_model_rate(self):
71 | cfg = self.cfg
72 | if cfg['model_split_mode'] == 'dynamic':
73 | rate_idx = torch.multinomial(torch.tensor(cfg['proportion']), num_samples=cfg['num_users'],
74 | replacement=True).tolist()
75 | self.model_rate = np.array(self.rate)[rate_idx]
76 | elif cfg['model_split_mode'] == 'fix':
77 | self.model_rate = np.array(self.rate)
78 | else:
79 | raise ValueError('Not valid model split mode')
80 | return
81 |
82 | def split_model(self, user_idx):
83 | cfg = self.cfg
84 | idx_i = [None for _ in range(len(user_idx))]
85 | idx = [OrderedDict() for _ in range(len(user_idx))]
86 | for k, v in self.global_parameters.items():
87 | parameter_type = k.split('.')[-1]
88 | for m in range(len(user_idx)):
89 | if 'weight' in parameter_type or 'bias' in parameter_type:
90 | if parameter_type == 'weight':
91 | if v.dim() > 1:
92 | input_size = v.size(1)
93 | output_size = v.size(0)
94 | if 'conv1' in k or 'conv2' in k:
95 | if idx_i[m] is None:
96 | idx_i[m] = torch.arange(input_size, device=v.device)
97 | input_idx_i_m = idx_i[m]
98 | scaler_rate = self.model_rate[user_idx[m]] / cfg['global_model_rate']
99 | local_output_size = int(np.ceil(output_size * scaler_rate))
100 | if self.cfg['overlap'] is None:
101 |                                     roll = self.rounds % output_size  # rolling window: the offset advances one index per round
102 | model_idx = torch.arange(output_size, device=v.device)
103 | else:
104 | overlap = self.cfg['overlap']
105 | self.roll_idx[k] += int(local_output_size * (1 - overlap)) + 1
106 | self.roll_idx[k] = self.roll_idx[k] % local_output_size
107 | roll = self.roll_idx[k]
108 | model_idx = self.model_idxs[k]
109 | model_idx = torch.roll(model_idx, roll, -1)
110 | output_idx_i_m = model_idx[:local_output_size]
111 | idx_i[m] = output_idx_i_m
112 | elif 'shortcut' in k:
113 | input_idx_i_m = idx[m][k.replace('shortcut', 'conv1')][1]
114 | output_idx_i_m = idx_i[m]
115 | elif 'linear' in k:
116 | input_idx_i_m = idx_i[m]
117 | output_idx_i_m = torch.arange(output_size, device=v.device)
118 | else:
119 | raise ValueError('Not valid k')
120 | idx[m][k] = (output_idx_i_m, input_idx_i_m)
121 | else:
122 | input_idx_i_m = idx_i[m]
123 | idx[m][k] = input_idx_i_m
124 | else:
125 | input_size = v.size(0)
126 | if 'linear' in k:
127 | input_idx_i_m = torch.arange(input_size, device=v.device)
128 | idx[m][k] = input_idx_i_m
129 | else:
130 | input_idx_i_m = idx_i[m]
131 | idx[m][k] = input_idx_i_m
132 | else:
133 | pass
134 |
135 | return idx
136 |
137 | def distribute(self, user_idx):
138 | self.make_model_rate()
139 | param_idx = self.split_model(user_idx)
140 | local_parameters = [OrderedDict() for _ in range(len(user_idx))]
141 | for k, v in self.global_parameters.items():
142 | parameter_type = k.split('.')[-1]
143 | for m in range(len(user_idx)):
144 | if 'weight' in parameter_type or 'bias' in parameter_type:
145 | if 'weight' in parameter_type:
146 | if v.dim() > 1:
147 | local_parameters[m][k] = copy.deepcopy(v[torch.meshgrid(param_idx[m][k])])
148 | else:
149 | local_parameters[m][k] = copy.deepcopy(v[param_idx[m][k]])
150 | else:
151 | local_parameters[m][k] = copy.deepcopy(v[param_idx[m][k]])
152 | else:
153 | local_parameters[m][k] = copy.deepcopy(v)
154 | return local_parameters, param_idx
155 |
156 | def combine(self, local_parameters, param_idx, user_idx):
157 | count = OrderedDict()
158 | self.global_parameters = self.global_model.cpu().state_dict()
159 | updated_parameters = copy.deepcopy(self.global_parameters)
160 | tmp_counts_cpy = copy.deepcopy(self.tmp_counts)
161 | for k, v in updated_parameters.items():
162 | parameter_type = k.split('.')[-1]
163 | count[k] = v.new_zeros(v.size(), dtype=torch.float32, device='cpu')
164 | tmp_v = v.new_zeros(v.size(), dtype=torch.float32, device='cpu')
165 | for m in range(len(local_parameters)):
166 | if 'weight' in parameter_type or 'bias' in parameter_type:
167 | if parameter_type == 'weight':
168 | if v.dim() > 1:
169 | if 'linear' in k:
170 | label_split = self.label_split[user_idx[m]]
171 | param_idx[m][k] = list(param_idx[m][k])
172 | param_idx[m][k][0] = param_idx[m][k][0][label_split]
173 | tmp_v[torch.meshgrid(param_idx[m][k])] += self.tmp_counts[k][torch.meshgrid(
174 | param_idx[m][k])] * local_parameters[m][k][label_split]
175 | count[k][torch.meshgrid(param_idx[m][k])] += self.tmp_counts[k][torch.meshgrid(
176 | param_idx[m][k])]
177 | tmp_counts_cpy[k][torch.meshgrid(param_idx[m][k])] += 1
178 | else:
179 | output_size = v.size(0)
180 | scaler_rate = self.model_rate[user_idx[m]] / self.cfg['global_model_rate']
181 | local_output_size = int(np.ceil(output_size * scaler_rate))
182 | if self.cfg['weighting'] == 'avg':
183 | K = 1
184 | elif self.cfg['weighting'] == 'width':
185 | K = local_output_size
186 | elif self.cfg['weighting'] == 'updates':
187 | K = self.tmp_counts[k][torch.meshgrid(param_idx[m][k])]
188 | elif self.cfg['weighting'] == 'updates_width':
189 | K = local_output_size * self.tmp_counts[k][torch.meshgrid(param_idx[m][k])]
190 | # K = self.tmp_counts[k][torch.meshgrid(param_idx[m][k])]
191 | # K = local_output_size
192 | # K = local_output_size * self.tmp_counts[k][torch.meshgrid(param_idx[m][k])]
193 | tmp_v[torch.meshgrid(param_idx[m][k])] += K * local_parameters[m][k]
194 | count[k][torch.meshgrid(param_idx[m][k])] += K
195 | tmp_counts_cpy[k][torch.meshgrid(param_idx[m][k])] += 1
196 | else:
197 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k]
198 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]]
199 | tmp_counts_cpy[k][param_idx[m][k]] += 1
200 | else:
201 | if 'linear' in k:
202 | label_split = self.label_split[user_idx[m]]
203 | param_idx[m][k] = param_idx[m][k][label_split]
204 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k][
205 | label_split]
206 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]]
207 | tmp_counts_cpy[k][param_idx[m][k]] += 1
208 | else:
209 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k]
210 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]]
211 | tmp_counts_cpy[k][param_idx[m][k]] += 1
212 | else:
213 | tmp_v += self.tmp_counts[k] * local_parameters[m][k]
214 | count[k] += self.tmp_counts[k]
215 | tmp_counts_cpy[k] += 1
216 | tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0])
217 | v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype)
218 | self.tmp_counts = tmp_counts_cpy
219 | self.global_parameters = updated_parameters
220 | self.global_model.load_state_dict(self.global_parameters)
221 | return
222 |
223 |
224 | class ResnetServerRandom(ResnetServerRoll):
225 | def split_model(self, user_idx):
226 | cfg = self.cfg
227 | idx_i = [None for _ in range(len(user_idx))]
228 | idx = [OrderedDict() for _ in range(len(user_idx))]
229 | for k, v in self.global_parameters.items():
230 | parameter_type = k.split('.')[-1]
231 | for m in range(len(user_idx)):
232 | if 'weight' in parameter_type or 'bias' in parameter_type:
233 | if parameter_type == 'weight':
234 | if v.dim() > 1:
235 | input_size = v.size(1)
236 | output_size = v.size(0)
237 | if 'conv1' in k or 'conv2' in k:
238 | if idx_i[m] is None:
239 | idx_i[m] = torch.arange(input_size, device=v.device)
240 | input_idx_i_m = idx_i[m]
241 | scaler_rate = self.model_rate[user_idx[m]] / cfg['global_model_rate']
242 | local_output_size = int(np.ceil(output_size * scaler_rate))
243 | model_idx = torch.randperm(output_size, device=v.device)
244 | output_idx_i_m = model_idx[:local_output_size]
245 | idx_i[m] = output_idx_i_m
246 | elif 'shortcut' in k:
247 | input_idx_i_m = idx[m][k.replace('shortcut', 'conv1')][1]
248 | output_idx_i_m = idx_i[m]
249 | elif 'linear' in k:
250 | input_idx_i_m = idx_i[m]
251 | output_idx_i_m = torch.arange(output_size, device=v.device)
252 | else:
253 | raise ValueError('Not valid k')
254 | idx[m][k] = (output_idx_i_m, input_idx_i_m)
255 | else:
256 | input_idx_i_m = idx_i[m]
257 | idx[m][k] = input_idx_i_m
258 | else:
259 | input_size = v.size(0)
260 | if 'linear' in k:
261 | input_idx_i_m = torch.arange(input_size, device=v.device)
262 | idx[m][k] = input_idx_i_m
263 | else:
264 | input_idx_i_m = idx_i[m]
265 | idx[m][k] = input_idx_i_m
266 | else:
267 | pass
268 |
269 | return idx
270 |
271 |
272 | class ResnetServerStatic(ResnetServerRoll):
273 | def split_model(self, user_idx):
274 | cfg = self.cfg
275 | idx_i = [None for _ in range(len(user_idx))]
276 | idx = [OrderedDict() for _ in range(len(user_idx))]
277 | for k, v in self.global_parameters.items():
278 | parameter_type = k.split('.')[-1]
279 | for m in range(len(user_idx)):
280 | if 'weight' in parameter_type or 'bias' in parameter_type:
281 | if parameter_type == 'weight':
282 | if v.dim() > 1:
283 | input_size = v.size(1)
284 | output_size = v.size(0)
285 | if 'conv1' in k or 'conv2' in k:
286 | if idx_i[m] is None:
287 | idx_i[m] = torch.arange(input_size, device=v.device)
288 | input_idx_i_m = idx_i[m]
289 | scaler_rate = self.model_rate[user_idx[m]] / cfg['global_model_rate']
290 | local_output_size = int(np.ceil(output_size * scaler_rate))
291 | model_idx = torch.arange(output_size, device=v.device)
292 | output_idx_i_m = model_idx[:local_output_size]
293 | idx_i[m] = output_idx_i_m
294 | elif 'shortcut' in k:
295 | input_idx_i_m = idx[m][k.replace('shortcut', 'conv1')][1]
296 | output_idx_i_m = idx_i[m]
297 | elif 'linear' in k:
298 | input_idx_i_m = idx_i[m]
299 | output_idx_i_m = torch.arange(output_size, device=v.device)
300 | else:
301 | raise ValueError('Not valid k')
302 | idx[m][k] = (output_idx_i_m, input_idx_i_m)
303 | else:
304 | input_idx_i_m = idx_i[m]
305 | idx[m][k] = input_idx_i_m
306 | else:
307 | input_size = v.size(0)
308 | if 'linear' in k:
309 | input_idx_i_m = torch.arange(input_size, device=v.device)
310 | idx[m][k] = input_idx_i_m
311 | else:
312 | input_idx_i_m = idx_i[m]
313 | idx[m][k] = input_idx_i_m
314 | else:
315 | pass
316 |
317 | return idx
318 |
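319 |
320 | if __name__ == '__main__':
321 |     # Minimal sketch (illustrative sizes): the three server classes differ only in
322 |     # how split_model picks output channels for a narrow client. For a layer with
323 |     # 8 output channels and a rate-0.5 client (local_output_size = 4):
324 |     output_size, local_output_size = 8, 4
325 |     for rounds in range(3):
326 |         # rolling (ResnetServerRoll, overlap=None): the window shifts one index per round
327 |         roll_idx = torch.roll(torch.arange(output_size), rounds % output_size, -1)[:local_output_size]
328 |         # random (ResnetServerRandom): a fresh permutation every round
329 |         rand_idx = torch.randperm(output_size)[:local_output_size]
330 |         # static (ResnetServerStatic): always the first channels, HeteroFL-style
331 |         static_idx = torch.arange(output_size)[:local_output_size]
332 |         print(rounds, roll_idx.tolist(), rand_idx.tolist(), static_idx.tolist())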
--------------------------------------------------------------------------------
/resnet_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 |
7 | from config import cfg
8 | from data import fetch_dataset, make_data_loader, SplitDataset
9 | from logger import Logger
10 | from metrics import Metric
11 | from utils import save, to_device, process_control, process_dataset, collate, resume
12 | from models import resnet
13 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
14 |
15 | torch.backends.cudnn.deterministic = True
16 | torch.backends.cudnn.benchmark = False
17 |
21 | parser = argparse.ArgumentParser(description='cfg')
22 | for k in cfg:
23 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k))
24 | parser.add_argument('--control_name', default=None, type=str)
25 | parser.add_argument('--seed', default=None, type=int)
26 | parser.add_argument('--devices', default=None, nargs='+', type=int)
27 | parser.add_argument('--algo', default='roll', type=str)
28 | # parser.add_argument('--lr', default=None, type=int)
29 | parser.add_argument('--g_epochs', default=None, type=int)
30 | parser.add_argument('--l_epochs', default=None, type=int)
31 | parser.add_argument('--schedule', default=None, nargs='+', type=int)
32 | # parser.add_argument('--exp_name', default=None, type=str)
33 | args = vars(parser.parse_args())
34 | cfg['init_seed'] = int(args['seed'])
35 | if args['algo'] == 'roll':
36 | pass
37 | elif args['algo'] == 'random':
38 | pass
39 | elif args['algo'] == 'static':
40 | pass
41 | if args['devices'] is not None:
42 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in args['devices']])
43 | for k in cfg:
44 | cfg[k] = args[k]
45 | if args['control_name']:
46 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \
47 | if args['control_name'] != 'None' else {}
48 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']])
49 | cfg['pivot_metric'] = 'Global-Accuracy'
50 | cfg['pivot'] = -float('inf')
51 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Accuracy']},
52 | 'test': {'Local': ['Local-Loss', 'Local-Accuracy'], 'Global': ['Global-Loss', 'Global-Accuracy']}}
53 |
54 |
55 | def main():
56 | process_control()
57 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments']))
58 | for i in range(cfg['num_experiments']):
59 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']]
60 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x])
61 | print('Experiment: {}'.format(cfg['model_tag']))
62 | runExperiment()
63 | return
64 |
65 |
66 | def runExperiment():
67 | cfg['batch_size']['train'] = cfg['batch_size']['test']
68 | seed = int(cfg['model_tag'].split('_')[0])
69 | torch.manual_seed(seed)
70 | torch.cuda.manual_seed(seed)
71 | dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
72 | process_dataset(dataset)
73 |     model = eval('resnet.{}(model_rate=cfg["global_model_rate"], track=True, cfg=cfg)'
74 |                  '.to(cfg["device"])'
75 |                  .format(cfg['model_name']))
76 | last_epoch, data_split, label_split, model, _, _, _ = resume(model, cfg['model_tag'], load_tag='best', strict=False)
77 | logger_path = 'output/runs/test_{}'.format(cfg['model_tag'])
78 | test_logger = Logger(logger_path)
79 | test_logger.safe(True)
80 | # stats(dataset['train'], model)
81 | test(dataset['test'], data_split['test'], label_split, model, test_logger, last_epoch)
82 | test_logger.safe(False)
83 | _, _, _, _, _, _, train_logger = resume(model, cfg['model_tag'], load_tag='checkpoint', strict=False)
84 | save_result = {'cfg': cfg, 'epoch': last_epoch, 'logger': {'train': train_logger, 'test': test_logger}}
85 | save(save_result, './output/result/{}.pt'.format(cfg['model_tag']))
86 | return
87 |
88 |
89 | def stats(dataset, model):
90 | with torch.no_grad():
91 | data_loader = make_data_loader({'train': dataset})['train']
92 | model.train(True)
93 | for i, input in enumerate(data_loader):
94 | input = collate(input)
95 | input = to_device(input, cfg['device'])
96 | model(input)
97 | return
98 |
99 |
100 | def test(dataset, data_split, label_split, model, logger, epoch):
101 | with torch.no_grad():
102 | metric = Metric()
103 | model.train(False)
104 | for m in range(cfg['num_users']):
105 | data_loader = make_data_loader({'test': SplitDataset(dataset, data_split[m])})['test']
106 | for i, input in enumerate(data_loader):
107 | input = collate(input)
108 | input_size = input['img'].size(0)
109 | input['label_split'] = torch.tensor(label_split[m])
110 | input = to_device(input, cfg['device'])
111 | output = model(input)
112 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
113 | evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], input, output)
114 | logger.append(evaluation, 'test', input_size)
115 |
116 | data_loader = make_data_loader({'test': dataset})['test']
117 | for i, input in enumerate(data_loader):
118 | input = collate(input)
119 | input_size = input['img'].size(0)
120 | input = to_device(input, cfg['device'])
121 | output = model(input)
122 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
123 | evaluation = metric.evaluate(cfg['metric_name']['test']['Global'], input, output)
124 | logger.append(evaluation, 'test', input_size)
125 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
126 | logger.append(info, 'test', mean=False)
127 | logger.write('test', cfg['metric_name']['test']['Local'] + cfg['metric_name']['test']['Global'])
128 | return
129 |
130 |
131 | if __name__ == "__main__":
132 | main()
133 |
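134 | # Example invocation (illustrative; flags as defined by the parser above, and
135 | # --seed is effectively required since init_seed is cast with int()):
136 | #   python resnet_test.py --algo roll --seed 0 --devices 0 --control_name None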
--------------------------------------------------------------------------------
/test_classifier.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 |
7 | from config import cfg
8 | from data import fetch_dataset, make_data_loader, SplitDataset
9 | from logger import Logger
10 | from metrics import Metric
11 | from utils import save, to_device, process_control, process_dataset, resume, collate
12 | from models import resnet
13 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
14 | cudnn.benchmark = True
15 | parser = argparse.ArgumentParser(description='cfg')
16 | for k in cfg:
17 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k))
18 | parser.add_argument('--control_name', default=None, type=str)
19 | args = vars(parser.parse_args())
20 | for k in cfg:
21 | cfg[k] = args[k]
22 | if args['control_name']:
23 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \
24 | if args['control_name'] != 'None' else {}
25 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']])
26 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Accuracy']},
27 | 'test': {'Local': ['Local-Loss', 'Local-Accuracy'], 'Global': ['Global-Loss', 'Global-Accuracy']}}
28 |
29 |
30 | def main():
31 | process_control()
32 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments']))
33 | for i in range(cfg['num_experiments']):
34 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']]
35 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x])
36 | print('Experiment: {}'.format(cfg['model_tag']))
37 | runExperiment()
38 | return
39 |
40 |
41 | def runExperiment():
42 | cfg['batch_size']['train'] = cfg['batch_size']['test']
43 | seed = int(cfg['model_tag'].split('_')[0])
44 | torch.manual_seed(seed)
45 | torch.cuda.manual_seed(seed)
46 | dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
47 | process_dataset(dataset)
48 |     model = eval('resnet.{}(model_rate=cfg["global_model_rate"], track=True, cfg=cfg).to(cfg["device"])'
49 |                  .format(cfg['model_name']))
50 | last_epoch, data_split, label_split, model, _, _, _ = resume(model, cfg['model_tag'], load_tag='best', strict=False)
51 | logger_path = 'output/runs/test_{}'.format(cfg['model_tag'])
52 | test_logger = Logger(logger_path)
53 | test_logger.safe(True)
54 | stats(dataset['train'], model)
55 | test(dataset['test'], data_split['test'], label_split, model, test_logger, last_epoch)
56 | test_logger.safe(False)
57 | _, _, _, _, _, _, train_logger = resume(model, cfg['model_tag'], load_tag='checkpoint', strict=False)
58 | save_result = {'cfg': cfg, 'epoch': last_epoch, 'logger': {'train': train_logger, 'test': test_logger}}
59 | save(save_result, './output/result/{}.pt'.format(cfg['model_tag']))
60 | return
61 |
62 |
63 | def stats(dataset, model):
64 | with torch.no_grad():
65 | data_loader = make_data_loader({'train': dataset})['train']
66 | model.train(True)
67 | for i, input in enumerate(data_loader):
68 | input = collate(input)
69 | input = to_device(input, cfg['device'])
70 | model(input)
71 | return
72 |
73 |
74 | def test(dataset, data_split, label_split, model, logger, epoch):
75 | with torch.no_grad():
76 | metric = Metric()
77 | model.train(False)
78 | for m in range(cfg['num_users']):
79 | data_loader = make_data_loader({'test': SplitDataset(dataset, data_split[m])})['test']
80 | for i, input in enumerate(data_loader):
81 | input = collate(input)
82 | input_size = input['img'].size(0)
83 | input['label_split'] = torch.tensor(label_split[m])
84 | input = to_device(input, cfg['device'])
85 | output = model(input)
86 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
87 | evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], input, output)
88 | logger.append(evaluation, 'test', input_size)
89 | data_loader = make_data_loader({'test': dataset})['test']
90 | for i, input in enumerate(data_loader):
91 | input = collate(input)
92 | input_size = input['img'].size(0)
93 | input = to_device(input, cfg['device'])
94 | output = model(input)
95 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
96 | evaluation = metric.evaluate(cfg['metric_name']['test']['Global'], input, output)
97 | logger.append(evaluation, 'test', input_size)
98 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
99 | logger.append(info, 'test', mean=False)
100 | logger.write('test', cfg['metric_name']['test']['Local'] + cfg['metric_name']['test']['Global'])
101 | return
102 |
103 |
104 | if __name__ == "__main__":
105 | main()
106 |
--------------------------------------------------------------------------------
/train_classifier.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 | import shutil
5 | import time
6 |
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 |
10 | from config import cfg
11 | from data import fetch_dataset, make_data_loader
12 | from logger import Logger
13 | from metrics import Metric
14 | from utils import save, to_device, process_control, process_dataset, make_optimizer, make_scheduler, resume, collate
15 | import models
16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
17 | torch.backends.cudnn.deterministic = True
18 | torch.backends.cudnn.benchmark = False
19 | parser = argparse.ArgumentParser(description='cfg')
20 | for k in cfg:
21 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k))
22 | parser.add_argument('--control_name', default=None, type=str)
23 | args = vars(parser.parse_args())
24 | for k in cfg:
25 | cfg[k] = args[k]
26 | if args['control_name']:
27 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \
28 | if args['control_name'] != 'None' else {}
29 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']])
30 | cfg['pivot_metric'] = 'Global-Accuracy'
31 | cfg['pivot'] = -float('inf')
32 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Accuracy']},
33 | 'test': {'Local': ['Local-Loss', 'Local-Accuracy'], 'Global': ['Global-Loss', 'Global-Accuracy']}}
34 |
35 |
36 | def main():
37 | process_control()
38 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments']))
39 | for i in range(cfg['num_experiments']):
40 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']]
41 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x])
42 | print('Experiment: {}'.format(cfg['model_tag']))
43 | runExperiment()
44 | return
45 |
46 |
47 | def runExperiment():
48 | seed = int(cfg['model_tag'].split('_')[0])
49 | torch.manual_seed(seed)
50 | torch.cuda.manual_seed(seed)
51 | dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
52 | process_dataset(dataset)
53 | data_loader = make_data_loader(dataset)
54 | model = eval('models.{}(model_rate=cfg["global_model_rate"]).to(cfg["device"])'.format(cfg['model_name']))
55 | optimizer = make_optimizer(model, cfg['lr'])
56 | scheduler = make_scheduler(optimizer)
57 | if cfg['resume_mode'] == 1:
58 | last_epoch, model, optimizer, scheduler, logger = resume(model, cfg['model_tag'], optimizer, scheduler)
59 | elif cfg['resume_mode'] == 2:
60 | last_epoch = 1
61 | _, model, _, _, _ = resume(model, cfg['model_tag'])
62 | logger_path = os.path.join('output', 'runs', '{}'.format(cfg['model_tag']))
63 | logger = Logger(logger_path)
64 | else:
65 | last_epoch = 1
66 | logger_path = os.path.join('output', 'runs', 'train_{}'.format(cfg['model_tag']))
67 | logger = Logger(logger_path)
68 | if cfg['world_size'] > 1:
69 | model = torch.nn.DataParallel(model, device_ids=list(range(cfg['world_size'])))
70 | print(cfg['num_epochs'])
71 | for epoch in range(last_epoch, cfg['num_epochs']['global'] + 1):
72 | logger.safe(True)
73 | train(data_loader['train'], model, optimizer, logger, epoch)
74 | test_model = stats(data_loader['train'], model)
75 | test(data_loader['test'], test_model, logger, epoch)
76 | if cfg['scheduler_name'] == 'ReduceLROnPlateau':
77 | scheduler.step(metrics=logger.mean['train/{}'.format(cfg['pivot_metric'])])
78 | else:
79 | scheduler.step()
80 | logger.safe(False)
81 | model_state_dict = model.module.state_dict() if cfg['world_size'] > 1 else model.state_dict()
82 | save_result = {
83 | 'cfg': cfg, 'epoch': epoch + 1, 'model_dict': model_state_dict,
84 | 'optimizer_dict': optimizer.state_dict(), 'scheduler_dict': scheduler.state_dict(),
85 | 'logger': logger}
86 | save(save_result, './output/model/{}_checkpoint.pt'.format(cfg['model_tag']))
87 | if cfg['pivot'] < logger.mean['test/{}'.format(cfg['pivot_metric'])]:
88 | cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])]
89 | shutil.copy('./output/model/{}_checkpoint.pt'.format(cfg['model_tag']),
90 | './output/model/{}_best.pt'.format(cfg['model_tag']))
91 | logger.reset()
92 | logger.safe(False)
93 | return
94 |
95 |
96 | def train(data_loader, model, optimizer, logger, epoch):
97 | metric = Metric()
98 | model.train(True)
99 | start_time = time.time()
100 | for i, input in enumerate(data_loader):
101 | input = collate(input)
102 | input_size = input['img'].size(0)
103 | input = to_device(input, cfg['device'])
104 | optimizer.zero_grad()
105 | output = model(input)
106 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
107 | output['loss'].backward()
108 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
109 | optimizer.step()
110 | evaluation = metric.evaluate(cfg['metric_name']['train'], input, output)
111 | logger.append(evaluation, 'train', n=input_size)
112 | if i % int((len(data_loader) * cfg['log_interval']) + 1) == 0:
113 | batch_time = (time.time() - start_time) / (i + 1)
114 | lr = optimizer.param_groups[0]['lr']
115 | epoch_finished_time = datetime.timedelta(seconds=round(batch_time * (len(data_loader) - i - 1)))
116 | exp_finished_time = epoch_finished_time + datetime.timedelta(
117 | seconds=round((cfg['num_epochs']['global'] - epoch) * batch_time * len(data_loader)))
118 | info = {'info': ['Model: {}'.format(cfg['model_tag']),
119 | 'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * i / len(data_loader)),
120 | 'Learning rate: {}'.format(lr), 'Epoch Finished Time: {}'.format(epoch_finished_time),
121 | 'Experiment Finished Time: {}'.format(exp_finished_time)]}
122 | logger.append(info, 'train', mean=False)
123 | logger.write('train', cfg['metric_name']['train'])
124 | return
125 |
126 |
127 | def stats(data_loader, model):
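    |     # Static BatchNorm recalibration: rebuild the model with track=True and
    |     # stream the training data through it (no gradients, train mode) so the
    |     # BN running statistics are re-estimated before evaluation.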
128 | with torch.no_grad():
129 |         test_model = getattr(models, cfg['model_name'])(model_rate=cfg['global_model_rate'],
130 |                                                         track=True).to(cfg['device'])
131 | test_model.load_state_dict(model.state_dict(), strict=False)
132 | test_model.train(True)
133 | for i, input in enumerate(data_loader):
134 | input = collate(input)
135 | input = to_device(input, cfg['device'])
136 | test_model(input)
137 | return test_model
138 |
139 |
140 | def test(data_loader, model, logger, epoch):
141 | with torch.no_grad():
142 | metric = Metric()
143 | model.train(False)
144 | for i, input in enumerate(data_loader):
145 | input = collate(input)
146 | input_size = input['img'].size(0)
147 | input = to_device(input, cfg['device'])
148 | output = model(input)
149 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
150 | evaluation = metric.evaluate(cfg['metric_name']['test'], input, output)
151 | logger.append(evaluation, 'test', input_size)
152 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
153 | logger.append(info, 'test', mean=False)
154 | logger.write('test', cfg['metric_name']['test'])
155 | return
156 |
157 |
158 | if __name__ == "__main__":
159 | main()
160 |
--------------------------------------------------------------------------------
/transformer_client.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import datetime
3 | import time
4 |
5 | import ray
6 | import torch
7 |
8 | import models
9 | from datasets.lm import StackOverflowClientDataset
10 | from logger import Logger
11 | from metrics import Metric
12 | from utils import make_optimizer, to_device
13 |
14 |
15 | @ray.remote(num_gpus=0.8)
16 | class TransformerClient:
17 | def __init__(self, log_path, cfg):
18 | self.dataset = None
19 | self.local_parameters = None
20 | self.m = None
21 | self.start_time = None
22 | self.num_active_users = None
23 | self.optimizer = None
24 | self.model = None
25 | self.lr = None
26 | self.label_split = None
27 | self.data_loader = None
28 | self.model_rate = None
29 | self.client_id = None
30 | cfg = ray.get(cfg[0])
31 | self.metric = Metric()
32 | self.logger = Logger(log_path)
33 | self.cfg = cfg
34 |
35 | def update(self, client_id, dataset_ref, model_ref):
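    |         # NOTE: label_split receives the same object as dataset and is never
    |         # used by this next-word-prediction client; the assignment appears to
    |         # exist only for interface parity with the image-classification client.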
36 | dataset = dataset_ref
37 | label_split = dataset_ref
38 | local_parameters = model_ref['local_params']
39 | self.dataset = StackOverflowClientDataset(dataset, self.cfg['seq_length'], self.cfg['batch_size']['train'])
40 | self.local_parameters = copy.deepcopy(local_parameters)
41 | self.client_id = client_id
42 | self.model_rate = model_ref['model_rate']
43 | self.label_split = label_split
44 | self.lr = model_ref['lr']
45 | self.metric = Metric()
46 |
47 | def step(self, m, num_active_users, start_time):
48 | cfg = self.cfg
49 | self.model = models.transformer_nwp(model_rate=self.model_rate, cfg=self.cfg).cpu()
50 | self.model.load_state_dict(self.local_parameters)
51 | self.model = self.model.cuda()
52 | self.model.train(True)
53 | self.optimizer = make_optimizer(self.model, self.lr)
54 | self.m = m
55 | self.num_active_users = num_active_users
56 | self.start_time = start_time
57 | for local_epoch in range(1, cfg['num_epochs']['local'] + 1):
58 | for i, step_input in enumerate(self.dataset):
59 | input_size = step_input['label'].size(0)
60 | # step_input['label_split'] = None
61 | step_input = to_device(step_input, cfg['device'])
62 | self.optimizer.zero_grad()
63 | output = self.model(step_input)
64 | output['loss'].backward()
65 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
66 | self.optimizer.step()
67 | evaluation = self.metric.evaluate(cfg['metric_name']['train']['Local'], step_input, output)
68 | self.logger.append(evaluation, 'train', n=input_size)
69 | self.log(local_epoch, cfg)
70 | return self.pull()
71 |
72 | def pull(self):
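    |         # Return a detached CPU copy of the trained weights and drop local
    |         # references so the Ray actor releases GPU memory between rounds.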
73 | model_state = {k: v.detach().clone() for k, v in self.model.cpu().state_dict().items()}
74 | self.model = None
75 | self.local_parameters = None
76 | return model_state
77 |
78 | def log(self, epoch, cfg):
79 |         if self.m % int((self.num_active_users * cfg['log_interval']) + 1) == 0:
80 | local_time = (time.time() - self.start_time) / (self.m + 1)
81 | epoch_finished_time = datetime.timedelta(seconds=local_time * (self.num_active_users - self.m - 1))
82 | exp_finished_time = epoch_finished_time + datetime.timedelta(
83 | seconds=round((cfg['num_epochs']['global'] - epoch) * local_time * self.num_active_users))
84 | info = {'info': ['Model: {}'.format(cfg['model_tag']),
85 | 'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * self.m / self.num_active_users),
86 | 'ID: {}({}/{})'.format(self.client_id, self.m + 1, self.num_active_users),
87 | 'Learning rate: {}'.format(self.lr),
88 | 'Rate: {}'.format(self.model_rate),
89 | 'Epoch Finished Time: {}'.format(epoch_finished_time),
90 | 'Experiment Finished Time: {}'.format(exp_finished_time)]}
91 | self.logger.append(info, 'train', mean=False)
92 | self.logger.write('train', cfg['metric_name']['train']['Local'])
93 |
94 | def test_model_for_user(self, m, ids):
95 | cfg = self.cfg
96 | metric = Metric()
97 | [dataset, model] = ray.get(ids)
98 | dataset = StackOverflowClientDataset(dataset, self.cfg['seq_length'], self.cfg['batch_size']['test'])
99 | model = model.to('cuda')
100 | results = []
101 | for _, data_input in enumerate(dataset):
102 | input_size = data_input['label'].size(0)
103 | data_input = to_device(data_input, 'cuda')
104 | output = model(data_input)
105 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
106 | s = output['score'].shape
107 | output['score'] = output['score'].permute((0, 2, 1)).reshape((s[0] * s[2], -1))
108 | data_input['label'] = data_input['label'].reshape((-1,))
109 | evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], data_input, output)
110 | results.append((evaluation, input_size))
111 | return results
112 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import collections.abc as container_abcs
2 | import errno
3 | import os
4 | from itertools import repeat
5 |
6 | import numpy as np
7 | import torch
8 | import torch.optim as optim
9 | from torchvision.utils import save_image
10 |
11 | from config import cfg
12 |
13 |
14 | def check_exists(path):
15 | return os.path.exists(path)
16 |
17 |
18 | def makedir_exist_ok(path):
19 | try:
20 | os.makedirs(path)
21 | except OSError as e:
22 | if e.errno == errno.EEXIST:
23 | pass
24 | else:
25 | raise
26 | return
27 |
28 |
29 | def save(input, path, protocol=2, mode='torch'):
30 | dirname = os.path.dirname(path)
31 | makedir_exist_ok(dirname)
32 | if mode == 'torch':
33 | torch.save(input, path, pickle_protocol=protocol)
34 | elif mode == 'numpy':
35 | np.save(path, input, allow_pickle=True)
36 | else:
37 | raise ValueError('Not valid save mode')
38 | return
39 |
40 |
41 | def load(path, mode='torch'):
42 | if mode == 'torch':
43 | return torch.load(path, map_location=lambda storage, loc: storage)
44 | elif mode == 'numpy':
45 | return np.load(path, allow_pickle=True)
46 | else:
47 |         raise ValueError('Not valid load mode')
48 | return
49 |
50 |
51 | def save_img(img, path, nrow=10, padding=2, pad_value=0, range=None):
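    |     # NOTE: `range` shadows the builtin and matches the save_image keyword of
    |     # older torchvision releases; newer releases rename it to `value_range`.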
52 | makedir_exist_ok(os.path.dirname(path))
53 | normalize = False if range is None else True
54 | save_image(img, path, nrow=nrow, padding=padding, pad_value=pad_value, normalize=normalize, range=range)
55 | return
56 |
57 |
58 | def to_device(input, device):
59 | output = recur(lambda x, y: x.to(y), input, device)
60 | return output
61 |
62 |
63 | def ntuple(n):
64 | def parse(x):
65 | if isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
66 | return x
67 | return tuple(repeat(x, n))
68 |
69 | return parse
70 |
71 |
72 | def apply_fn(module, fn):
73 | for n, m in module.named_children():
74 | if hasattr(m, fn):
75 |             getattr(m, fn)()  # call the method by name instead of exec()
76 | if sum(1 for _ in m.named_children()) != 0:
77 |             apply_fn(m, fn)
78 | return
79 |
80 |
81 | def recur(fn, input, *args):
82 | if isinstance(input, torch.Tensor) or isinstance(input, np.ndarray):
83 | output = fn(input, *args)
84 | elif isinstance(input, list):
85 | output = []
86 | for i in range(len(input)):
87 | output.append(recur(fn, input[i], *args))
88 | elif isinstance(input, tuple):
89 | output = []
90 | for i in range(len(input)):
91 | output.append(recur(fn, input[i], *args))
92 | output = tuple(output)
93 | elif isinstance(input, dict):
94 | output = {}
95 | for key in input:
96 | output[key] = recur(fn, input[key], *args)
97 | else:
98 | raise ValueError('Not valid input type')
99 | return output
100 |
101 |
102 | def process_dataset(dataset):
103 | if cfg['data_name'] in ['MNIST', 'CIFAR10', 'CIFAR100']:
104 | cfg['classes_size'] = dataset['train'].classes_size
105 | elif cfg['data_name'] in ['WikiText2', 'WikiText103', 'PennTreebank']:
106 | cfg['vocab'] = dataset['train'].vocab
107 | cfg['num_tokens'] = len(dataset['train'].vocab)
108 | for split in dataset:
109 | dataset[split] = batchify(dataset[split], cfg['batch_size'][split])
110 | elif cfg['data_name'] in ['Stackoverflow']:
111 | # cfg['vocab'] = dataset['vocab']
112 | cfg['num_tokens'] = len(dataset['vocab'])
113 | elif cfg['data_name'] in ['gld']:
114 | cfg['classes_size'] = 2028
115 | else:
116 | raise ValueError('Not valid data name')
117 | return
118 |
119 |
120 | def process_control():
121 | cfg['model_split_rate'] = {'a': 1, 'b': 0.5, 'c': 0.25, 'd': 0.125, 'e': 0.0625}
122 | cfg['fed'] = int(cfg['control']['fed'])
123 | cfg['num_users'] = int(cfg['control']['num_users'])
124 | cfg['frac'] = float(cfg['control']['frac'])
125 | cfg['data_split_mode'] = cfg['control']['data_split_mode']
126 | cfg['model_split_mode'] = cfg['control']['model_split_mode']
127 | cfg['model_mode'] = cfg['control']['model_mode']
128 | cfg['norm'] = cfg['control']['norm']
129 | cfg['scale'] = bool(int(cfg['control']['scale']))
130 | cfg['mask'] = bool(int(cfg['control']['mask']))
131 | cfg['global_model_mode'] = cfg['model_mode'][0]
132 | cfg['global_model_rate'] = cfg['model_split_rate'][cfg['global_model_mode']]
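    |     # model_mode encodes capacity levels and client proportions, e.g. 'a1-b1'
    |     # splits clients evenly between full-width ('a') and half-width ('b') models.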
133 | model_mode = cfg['model_mode'].split('-')
134 | if cfg['model_split_mode'] == 'dynamic':
135 | mode_rate, proportion = [], []
136 | for m in model_mode:
137 | mode_rate.append(cfg['model_split_rate'][m[0]])
138 | proportion.append(int(m[1:]))
139 | cfg['model_rate'] = mode_rate
140 | cfg['proportion'] = (np.array(proportion) / sum(proportion)).tolist()
141 | elif cfg['model_split_mode'] == 'fix':
142 | mode_rate, proportion = [], []
143 | for m in model_mode:
144 | mode_rate.append(cfg['model_split_rate'][m[0]])
145 | proportion.append(int(m[1:]))
146 | num_users_proportion = cfg['num_users'] // sum(proportion)
147 | cfg['model_rate'] = []
148 | for i in range(len(mode_rate)):
149 | cfg['model_rate'] += np.repeat(mode_rate[i], num_users_proportion * proportion[i]).tolist()
150 | cfg['model_rate'] = cfg['model_rate'] + [cfg['model_rate'][-1] for _ in
151 | range(cfg['num_users'] - len(cfg['model_rate']))]
152 | else:
153 | raise ValueError('Not valid model split mode')
154 | cfg['conv'] = {'hidden_size': [64, 128, 256, 512]}
155 | cfg['resnet'] = {'hidden_size': [64, 128, 256, 512]}
156 | cfg['transformer'] = {'embedding_size': 128,
157 | 'num_heads': 8,
158 | 'hidden_size': 2048,
159 | 'num_layers': 3,
160 | 'dropout': 0.1}
161 | if cfg['data_name'] in ['MNIST']:
162 | cfg['data_shape'] = [1, 28, 28]
163 | cfg['optimizer_name'] = 'SGD'
164 | # cfg['lr'] = 1e-2
165 | cfg['momentum'] = 0.9
166 | cfg['weight_decay'] = 5e-4
167 | cfg['scheduler_name'] = 'MultiStepLR'
168 | cfg['factor'] = 0.1
169 | if cfg['data_split_mode'] == 'iid':
170 | # cfg['num_epochs'] = {'global': 200, 'local': 5}
171 | cfg['batch_size'] = {'train': 24, 'test': 100}
172 | # cfg['milestones'] = [100]
173 | elif 'non-iid' in cfg['data_split_mode']:
174 | # cfg['num_epochs'] = {'global': 400, 'local': 5}
175 | cfg['batch_size'] = {'train': 10, 'test': 100}
176 | cfg['milestones'] = [200]
177 | elif cfg['data_split_mode'] == 'none':
178 | cfg['num_epochs'] = 200
179 | cfg['batch_size'] = {'train': 100, 'test': 500}
180 | # cfg['milestones'] = [100]
181 | else:
182 | raise ValueError('Not valid data_split_mode')
183 | elif cfg['data_name'] in ['CIFAR10', 'CIFAR100']:
184 | cfg['data_shape'] = [3, 32, 32]
185 | cfg['optimizer_name'] = 'SGD'
186 | # cfg['lr'] = 1e-4
187 | cfg['momentum'] = 0.9
188 | cfg['min_lr'] = 1e-4
189 | cfg['weight_decay'] = 1e-3
190 | cfg['scheduler_name'] = 'MultiStepLR'
191 | cfg['factor'] = 0.25
192 | if cfg['data_split_mode'] == 'iid':
193 | # cfg['num_epochs'] = {'global': 2500, 'local': 1}
194 | cfg['batch_size'] = {'train': 10, 'test': 100}
195 | # cfg['milestones'] = [1000, 1500, 2000]
196 | elif 'non-iid' in cfg['data_split_mode']:
197 | # cfg['num_epochs'] = {'global': 2500, 'local': 1}
198 | cfg['batch_size'] = {'train': 10, 'test': 100}
199 | # cfg['milestones'] = [1000, 1500, 2000]
200 | elif cfg['data_split_mode'] == 'none':
201 | cfg['num_epochs'] = 400
202 | cfg['batch_size'] = {'train': 100, 'test': 500}
203 | # cfg['milestones'] = [150, 250]
204 | else:
205 | raise ValueError('Not valid data_split_mode')
206 | elif cfg['data_name'] in ['gld']:
207 | cfg['data_shape'] = [3, 92, 92]
208 | cfg['optimizer_name'] = 'SGD'
209 | # cfg['lr'] = 1e-4
210 | cfg['num_users'] = 1262
211 | cfg['active_user'] = 80
212 | cfg['momentum'] = 0.9
213 | cfg['min_lr'] = 5e-4
214 | cfg['weight_decay'] = 1e-3
215 | cfg['scheduler_name'] = 'MultiStepLR'
216 | cfg['factor'] = 0.1
217 | if cfg['data_split_mode'] == 'iid':
218 | # cfg['num_epochs'] = {'global': 2500, 'local': 1}
219 | cfg['batch_size'] = {'train': 32, 'test': 50}
220 | # cfg['milestones'] = [1000, 1500, 2000]
221 | elif 'non-iid' in cfg['data_split_mode']:
222 | # cfg['num_epochs'] = {'global': 2500, 'local': 1}
223 | cfg['batch_size'] = {'train': 32, 'test': 50}
224 | # cfg['milestones'] = [1000, 1500, 2000]
225 | elif cfg['data_split_mode'] == 'none':
226 | cfg['num_epochs'] = 400
227 | cfg['batch_size'] = {'train': 100, 'test': 500}
228 | # cfg['milestones'] = [150, 250]
229 | else:
230 | raise ValueError('Not valid data_split_mode')
231 | elif cfg['data_name'] in ['PennTreebank', 'WikiText2', 'WikiText103']:
232 | cfg['optimizer_name'] = 'SGD'
233 | # cfg['lr'] = 1e-2
234 | cfg['momentum'] = 0.9
235 | cfg['weight_decay'] = 5e-4
236 | cfg['scheduler_name'] = 'MultiStepLR'
237 | cfg['factor'] = 0.1
238 | cfg['bptt'] = 64
239 | cfg['mask_rate'] = 0.15
240 | if cfg['data_split_mode'] == 'iid':
241 | # cfg['num_epochs'] = {'global': 200, 'local': 3}
242 | cfg['batch_size'] = {'train': 100, 'test': 10}
243 | cfg['milestones'] = [50, 100]
244 | elif cfg['data_split_mode'] == 'none':
245 | cfg['num_epochs'] = 100
246 | cfg['batch_size'] = {'train': 100, 'test': 100}
247 | # cfg['milestones'] = [25, 50]
248 | else:
249 | raise ValueError('Not valid data_split_mode')
250 | elif cfg['data_name'] in ['Stackoverflow']:
251 | cfg['optimizer_name'] = 'SGD'
252 | cfg['num_users'] = 342477
253 | cfg['active_user'] = 50
254 | cfg['momentum'] = 0.9
255 | cfg['weight_decay'] = 5e-4
256 | cfg['scheduler_name'] = 'MultiStepLR'
257 | cfg['factor'] = 0.1
258 | cfg['bptt'] = 64
259 | cfg['batch_size'] = {'train': 24, 'test': 24}
260 | cfg['mask_rate'] = 0.15
262 | cfg['seq_length'] = 21
263 | else:
264 |         raise ValueError('Not valid data name')
265 | return
266 |
267 |
268 | def make_stats(dataset):
269 | if os.path.exists('./data/stats/{}.pt'.format(dataset.data_name)):
270 | stats = load('./data/stats/{}.pt'.format(dataset.data_name))
271 | elif dataset is not None:
272 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=False, num_workers=0)
273 | stats = Stats(dim=1)
274 | with torch.no_grad():
275 | for input in data_loader:
276 | stats.update(input['img'])
277 | save(stats, './data/stats/{}.pt'.format(dataset.data_name))
278 | return stats
279 |
280 |
281 | class Stats(object):
282 | def __init__(self, dim):
283 | self.dim = dim
284 | self.n_samples = 0
285 | self.n_features = None
286 | self.mean = None
287 | self.std = None
288 |
289 | def update(self, data):
290 | data = data.transpose(self.dim, -1).reshape(-1, data.size(self.dim))
291 | if self.n_samples == 0:
292 | self.n_samples = data.size(0)
293 | self.n_features = data.size(1)
294 | self.mean = data.mean(dim=0)
295 | self.std = data.std(dim=0)
296 | else:
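    |             # Merge the batch statistics into the running statistics with the
    |             # pooled (parallel-update) mean/variance formula.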
297 | m = float(self.n_samples)
298 | n = data.size(0)
299 | new_mean = data.mean(dim=0)
300 | new_std = 0 if n == 1 else data.std(dim=0)
301 | old_mean = self.mean
302 | old_std = self.std
303 | self.mean = m / (m + n) * old_mean + n / (m + n) * new_mean
304 | self.std = torch.sqrt(m / (m + n) * old_std ** 2 + n / (m + n) * new_std ** 2 + m * n / (m + n) ** 2 * (
305 | old_mean - new_mean) ** 2)
306 | self.n_samples += n
307 | return
308 |
309 |
310 | def make_optimizer(model, lr):
311 | if cfg['optimizer_name'] == 'SGD':
312 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=cfg['momentum'],
313 | weight_decay=cfg['weight_decay'])
314 | elif cfg['optimizer_name'] == 'RMSprop':
315 | optimizer = optim.RMSprop(model.parameters(), lr=lr, momentum=cfg['momentum'],
316 | weight_decay=cfg['weight_decay'])
317 | elif cfg['optimizer_name'] == 'Adam':
318 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=cfg['weight_decay'])
319 | elif cfg['optimizer_name'] == 'Adamax':
320 | optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=cfg['weight_decay'])
321 | else:
322 | raise ValueError('Not valid optimizer name')
323 | return optimizer
324 |
325 |
326 | def make_scheduler(optimizer):
327 | if cfg['scheduler_name'] == 'None':
328 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[65535])
329 | elif cfg['scheduler_name'] == 'StepLR':
330 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=cfg['step_size'], gamma=cfg['factor'])
331 | elif cfg['scheduler_name'] == 'MultiStepLR':
332 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg['milestones'], gamma=cfg['factor'])
333 | elif cfg['scheduler_name'] == 'ExponentialLR':
334 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
335 | elif cfg['scheduler_name'] == 'CosineAnnealingLR':
336 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['num_epochs']['global'],
337 | eta_min=cfg['min_lr'])
338 | elif cfg['scheduler_name'] == 'ReduceLROnPlateau':
339 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=cfg['factor'],
340 | patience=cfg['patience'], verbose=True,
341 | threshold=cfg['threshold'], threshold_mode='rel',
342 | min_lr=cfg['min_lr'])
343 | elif cfg['scheduler_name'] == 'CyclicLR':
344 | scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=cfg['lr'], max_lr=10 * cfg['lr'])
345 | else:
346 | raise ValueError('Not valid scheduler name')
347 | return scheduler
348 |
349 |
350 | def resume(model, model_tag, optimizer=None, scheduler=None, load_tag='checkpoint', strict=True, verbose=True):
351 | if cfg['data_split_mode'] != 'none':
352 | if os.path.exists('./output/model/{}_{}.pt'.format(model_tag, load_tag)):
353 | checkpoint = load('./output/model/{}_{}.pt'.format(model_tag, load_tag))
354 | last_epoch = checkpoint['epoch']
355 | data_split = checkpoint['data_split']
356 | label_split = checkpoint['label_split']
357 | model.load_state_dict(checkpoint['model_dict'], strict=strict)
358 | if optimizer is not None:
359 | optimizer.load_state_dict(checkpoint['optimizer_dict'])
360 | if scheduler is not None:
361 | scheduler.load_state_dict(checkpoint['scheduler_dict'])
362 | logger = checkpoint['logger']
363 | if verbose:
364 | print('Resume from {}'.format(last_epoch))
365 | else:
366 |             print('No checkpoint found for model tag {}, starting from scratch'.format(model_tag))
367 | from datetime import datetime
368 | from logger import Logger
369 | last_epoch = 1
370 | data_split = None
371 | label_split = None
372 | logger_path = 'output/runs/train_{}_{}'.format(cfg['model_tag'], datetime.now().strftime('%b%d_%H-%M-%S'))
373 | logger = Logger(logger_path)
374 | return last_epoch, data_split, label_split, model, optimizer, scheduler, logger
375 | else:
376 | if os.path.exists('./output/model/{}_{}.pt'.format(model_tag, load_tag)):
377 | checkpoint = load('./output/model/{}_{}.pt'.format(model_tag, load_tag))
378 | last_epoch = checkpoint['epoch']
379 | model.load_state_dict(checkpoint['model_dict'], strict=strict)
380 | if optimizer is not None:
381 | optimizer.load_state_dict(checkpoint['optimizer_dict'])
382 | if scheduler is not None:
383 | scheduler.load_state_dict(checkpoint['scheduler_dict'])
384 | logger = checkpoint['logger']
385 | if verbose:
386 | print('Resume from {}'.format(last_epoch))
387 | else:
388 |             print('No checkpoint found for model tag {}, starting from scratch'.format(model_tag))
389 | from datetime import datetime
390 | from logger import Logger
391 | last_epoch = 1
392 | logger_path = 'output/runs/train_{}_{}'.format(cfg['model_tag'], datetime.now().strftime('%b%d_%H-%M-%S'))
393 | logger = Logger(logger_path)
394 | return last_epoch, model, optimizer, scheduler, logger
395 |
396 |
397 | def collate(input):
398 | if 'label' in input.keys():
399 | input['label'] = [torch.tensor(i) for i in input['label']]
400 | for k in input:
401 | input[k] = torch.stack(input[k], 0)
402 | return input
403 |
404 |
405 | def batchify(dataset, batch_size):
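    |     # Trim the token stream to a multiple of batch_size and fold it into a
    |     # (batch_size, -1) matrix, so each row is a contiguous segment for BPTT.
    |     # Example (hypothetical sizes): 10,007 tokens with batch_size=10 become a
    |     # (10, 1000) matrix and the trailing 7 tokens are dropped.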
406 | num_batch = len(dataset) // batch_size
407 | dataset.token = dataset.token.narrow(0, 0, num_batch * batch_size)
408 | dataset.token = dataset.token.reshape(batch_size, -1)
409 | return dataset
410 |
--------------------------------------------------------------------------------
/weighted_server.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import OrderedDict
3 |
4 | import numpy as np
5 | import ray
6 | import torch
7 |
8 |
9 | class WeightedServer:
10 | def __init__(self, global_model, rate, dataset_ref, cfg_id):
11 | self.tau = 1e-2
12 | self.v_t = None
13 | self.beta_1 = 0.9
14 | self.beta_2 = 0.99
15 | self.eta = 1e-2
16 | self.m_t = None
17 | self.user_idx = None
18 | self.param_idx = None
19 | self.dataset_ref = dataset_ref
20 | self.cfg_id = cfg_id
21 | self.cfg = ray.get(cfg_id)
22 | self.global_model = global_model
23 | self.global_parameters = global_model.state_dict()
24 | self.rate = rate
25 | self.label_split = ray.get(dataset_ref['label_split'])
26 | self.make_model_rate()
27 | self.num_model_partitions = 50
28 | self.model_idxs = {}
29 | self.rounds = 0
30 | self.tmp_counts = {}
31 | for k, v in self.global_parameters.items():
32 | self.tmp_counts[k] = torch.ones_like(v)
33 |
34 | for k, v in self.global_parameters.items():
35 | if 'conv1' in k or 'conv2' in k:
36 | output_size = v.size(0)
37 | self.model_idxs[k] = [torch.randperm(output_size, device=v.device) for _ in range(
38 | self.num_model_partitions)]
39 |
40 | def step(self, local_parameters):
41 | self.combine(local_parameters, self.param_idx, self.user_idx)
42 | self.rounds += 1
43 |
44 | def broadcast(self, local, lr):
45 | cfg = self.cfg
46 | self.global_model.train(True)
47 | num_active_users = int(np.ceil(cfg['frac'] * cfg['num_users']))
48 | self.user_idx = copy.deepcopy(torch.arange(cfg['num_users'])
49 | [torch.randperm(cfg['num_users'])
50 | [:num_active_users]].tolist())
51 | local_parameters, self.param_idx = self.distribute(self.user_idx)
52 | # [torch.save(local_parameters[m], f'local_param_{m}') for m in range(len(local_parameters))]
53 | # local_parameters = [{k: v.cpu().numpy() for k, v in p.items()} for p in local_parameters]
54 |
55 | param_ids = [ray.put(local_parameter) for local_parameter in local_parameters]
56 | # ([client.update(self.user_idx[m],
57 | # self.dataset_ref,
58 | # {'lr': lr,
59 | # 'model_rate': self.model_rate[self.user_idx[m]],
60 | # 'local_params': param_ids[m]}) for m, client in enumerate(
61 | # local)])
62 |
63 | ray.get([client.update.remote(self.user_idx[m],
64 | self.dataset_ref,
65 | {'lr': lr,
66 | 'model_rate': self.model_rate[self.user_idx[m]],
67 | 'local_params': param_ids[m]})
68 | for m, client in enumerate(local)])
69 | return local
70 |
71 | def make_model_rate(self):
72 | cfg = self.cfg
73 | if cfg['model_split_mode'] == 'dynamic':
74 | rate_idx = torch.multinomial(torch.tensor(cfg['proportion']), num_samples=cfg['num_users'],
75 | replacement=True).tolist()
76 | self.model_rate = np.array(self.rate)[rate_idx]
77 | elif cfg['model_split_mode'] == 'fix':
78 | self.model_rate = np.array(self.rate)
79 | else:
80 | raise ValueError('Not valid model split mode')
81 | return
82 |
83 | def split_model(self, user_idx):
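    |         # Compute, for each selected client, per-layer index tuples that carve
    |         # a width-scaled sub-network (a model_rate fraction of the channels)
    |         # out of the global model.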
84 | cfg = self.cfg
85 | idx_i = [None for _ in range(len(user_idx))]
86 | idx = [OrderedDict() for _ in range(len(user_idx))]
87 | for k, v in self.global_parameters.items():
88 | parameter_type = k.split('.')[-1]
89 | for m in range(len(user_idx)):
90 | if 'weight' in parameter_type or 'bias' in parameter_type:
91 | if parameter_type == 'weight':
92 | if v.dim() > 1:
93 | input_size = v.size(1)
94 | output_size = v.size(0)
95 | if 'conv1' in k or 'conv2' in k:
96 | if idx_i[m] is None:
97 | idx_i[m] = torch.arange(input_size, device=v.device)
98 | input_idx_i_m = idx_i[m]
99 | scaler_rate = self.model_rate[user_idx[m]] / cfg['global_model_rate']
100 | local_output_size = int(np.ceil(output_size * scaler_rate))
101 | # model_idx = self.model_idxs[k][m % self.num_model_partitions]
102 | # output_idx_i_m = model_idx[:local_output_size]
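    |                             # Rolling window (FedRolex-style): shift the channel
    |                             # indices by the round number so every output channel
    |                             # is eventually covered by some client's sub-model.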
103 | roll = self.rounds % output_size
104 | # model_idx = self.model_idxs[k][self.rounds % self.num_model_partitions]
105 | model_idx = torch.arange(output_size, device=v.device)
106 | model_idx = torch.roll(model_idx, roll, -1)
107 | output_idx_i_m = model_idx[:local_output_size]
108 | idx_i[m] = output_idx_i_m
109 | elif 'shortcut' in k:
110 | input_idx_i_m = idx[m][k.replace('shortcut', 'conv1')][1]
111 | output_idx_i_m = idx_i[m]
112 | elif 'linear' in k:
113 | input_idx_i_m = idx_i[m]
114 | output_idx_i_m = torch.arange(output_size, device=v.device)
115 | else:
116 | raise ValueError('Not valid k')
117 | idx[m][k] = (output_idx_i_m, input_idx_i_m)
118 | else:
119 | input_idx_i_m = idx_i[m]
120 | idx[m][k] = input_idx_i_m
121 | else:
122 | input_size = v.size(0)
123 | if 'linear' in k:
124 | input_idx_i_m = torch.arange(input_size, device=v.device)
125 | idx[m][k] = input_idx_i_m
126 | else:
127 | input_idx_i_m = idx_i[m]
128 | idx[m][k] = input_idx_i_m
129 | else:
130 | pass
131 |
132 | return idx
133 |
134 | def distribute(self, user_idx):
135 | self.make_model_rate()
136 | param_idx = self.split_model(user_idx)
137 | local_parameters = [OrderedDict() for _ in range(len(user_idx))]
138 | for k, v in self.global_parameters.items():
139 | parameter_type = k.split('.')[-1]
140 | for m in range(len(user_idx)):
141 | if 'weight' in parameter_type or 'bias' in parameter_type:
142 | if 'weight' in parameter_type:
143 | if v.dim() > 1:
144 | local_parameters[m][k] = copy.deepcopy(v[torch.meshgrid(param_idx[m][k])])
145 | else:
146 | local_parameters[m][k] = copy.deepcopy(v[param_idx[m][k]])
147 | else:
148 | local_parameters[m][k] = copy.deepcopy(v[param_idx[m][k]])
149 | else:
150 | local_parameters[m][k] = copy.deepcopy(v)
151 | return local_parameters, param_idx
152 |
153 | def combine(self, local_parameters, param_idx, user_idx):
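    |         # Fold the returned client sub-models back into the global parameters:
    |         # each slice is accumulated at its original indices with
    |         # tmp_counts-based weights, then averaged wherever at least one
    |         # client contributed.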
154 | count = OrderedDict()
155 | self.global_parameters = self.global_model.state_dict()
156 | updated_parameters = copy.deepcopy(self.global_parameters)
157 | tmp_counts_cpy = copy.deepcopy(self.tmp_counts)
158 | for k, v in updated_parameters.items():
159 | parameter_type = k.split('.')[-1]
160 | count[k] = v.new_zeros(v.size(), dtype=torch.float32)
161 | tmp_v = v.new_zeros(v.size(), dtype=torch.float32)
162 | for m in range(len(local_parameters)):
163 | if 'weight' in parameter_type or 'bias' in parameter_type:
164 | if parameter_type == 'weight':
165 | if v.dim() > 1:
166 | if 'linear' in k:
167 | label_split = self.label_split[user_idx[m]]
168 | param_idx[m][k] = list(param_idx[m][k])
169 | param_idx[m][k][0] = param_idx[m][k][0][label_split]
170 | tmp_v[torch.meshgrid(param_idx[m][k])] += self.tmp_counts[k][torch.meshgrid(
171 | param_idx[m][k])] * local_parameters[m][k][label_split]
172 | count[k][torch.meshgrid(param_idx[m][k])] += self.tmp_counts[k][torch.meshgrid(
173 | param_idx[m][k])]
174 | tmp_counts_cpy[k][torch.meshgrid(param_idx[m][k])] += 1
175 | else:
176 | output_size = v.size(0)
177 | scaler_rate = self.model_rate[user_idx[m]] / self.cfg['global_model_rate']
178 | local_output_size = int(np.ceil(output_size * scaler_rate))
179 | # K = self.tmp_counts[k][torch.meshgrid(param_idx[m][k])]
180 | # K = local_output_size
181 | K = local_output_size * self.tmp_counts[k][torch.meshgrid(param_idx[m][k])]
182 | # K = 1
183 | tmp_v[torch.meshgrid(param_idx[m][k])] += K * local_parameters[m][k]
184 | count[k][torch.meshgrid(param_idx[m][k])] += K
185 | tmp_counts_cpy[k][torch.meshgrid(param_idx[m][k])] += 1
186 | else:
187 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k]
188 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]]
189 | tmp_counts_cpy[k][param_idx[m][k]] += 1
190 | else:
191 | if 'linear' in k:
192 | label_split = self.label_split[user_idx[m]]
193 | param_idx[m][k] = param_idx[m][k][label_split]
194 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k][
195 | label_split]
196 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]]
197 | tmp_counts_cpy[k][param_idx[m][k]] += 1
198 | else:
199 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k]
200 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]]
201 | tmp_counts_cpy[k][param_idx[m][k]] += 1
202 | else:
203 | tmp_v += self.tmp_counts[k] * local_parameters[m][k]
204 | count[k] += self.tmp_counts[k]
205 | tmp_counts_cpy[k] += 1
206 | tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0])
207 | v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype)
208 | self.tmp_counts = tmp_counts_cpy
209 |
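    |         # Server-side adaptive update in the style of FedAdam (Reddi et al.,
    |         # "Adaptive Federated Optimization"): treat the aggregated change as a
    |         # pseudo-gradient delta_t and apply Adam-style moments m_t, v_t with
    |         # server step size eta and adaptivity tau; the moments are reset and
    |         # eta is halved at each milestone round.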
210 | delta_t = {k: v - self.global_parameters[k] for k, v in updated_parameters.items()}
211 | if self.rounds in self.cfg['milestones']:
212 | self.eta *= 0.5
213 | if not self.m_t or self.rounds in self.cfg['milestones']:
214 | self.m_t = {k: torch.zeros_like(x) for k, x in delta_t.items()}
215 | self.m_t = {
216 | k: self.beta_1 * self.m_t[k] + (1 - self.beta_1) * delta_t[k] for k in delta_t.keys()
217 | }
218 | if not self.v_t or self.rounds in self.cfg['milestones']:
219 | self.v_t = {k: torch.zeros_like(x) for k, x in delta_t.items()}
220 | self.v_t = {
221 | k: self.beta_2 * self.v_t[k] + (1 - self.beta_2) * torch.multiply(delta_t[k], delta_t[k])
222 | for k in delta_t.keys()
223 | }
224 | self.global_parameters = {
225 | k: self.global_parameters[k] + self.eta * self.m_t[k] / (torch.sqrt(self.v_t[k]) + self.tau)
226 | for k in self.global_parameters.keys()
227 | }
228 | # if not self.m_t:
229 | # self.m_t = {k: torch.zeros_like(x) for k, x in delta_t.items()}
230 | # self.m_t = {
231 | # k: self.beta_1 * self.m_t[k] + (1 - self.beta_1) * delta_t[k] for k in delta_t.keys()
232 | # }
233 | # if not self.v_t:
234 | # self.v_t = {k: torch.zeros_like(x) for k, x in delta_t.items()}
235 | # self.v_t = {
236 | # k: self.v_t[k] + torch.multiply(delta_t[k], delta_t[k])
237 | # for k in delta_t.keys()
238 | # }
239 | # self.global_parameters = {
240 | # k: self.global_parameters[k] + self.eta * self.m_t[k] / (torch.sqrt(self.v_t[k]) + self.tau)
241 | # for k in self.global_parameters.keys()
242 | # }
243 | # self.global_parameters = updated_parameters
244 | self.global_model.load_state_dict(self.global_parameters)
245 | return
246 |
--------------------------------------------------------------------------------