├── .gitignore
├── LICENSE
├── README.md
├── check_local.py
├── config.py
├── config.yml
├── data.py
├── datasets
│   ├── __init__.py
│   ├── cifar.py
│   ├── folder.py
│   ├── gld.py
│   ├── imagenet.py
│   ├── lm.py
│   ├── mnist.py
│   ├── omniglot.py
│   ├── transforms.py
│   └── utils.py
├── figures
│   ├── fedrolex_overview.png
│   ├── table_overview.png
│   └── video_placeholder.png
├── logger.py
├── main_resnet.py
├── main_transformer.py
├── metrics
│   ├── __init__.py
│   └── metrics.py
├── models
│   ├── conv.py
│   ├── resnet.py
│   ├── transformer.py
│   ├── transformer_nwp.py
│   └── utils.py
├── modules
│   ├── __init__.py
│   └── modules.py
├── requirements.txt
├── resnet_client.py
├── resnet_server.py
├── resnet_test.py
├── test_classifier.py
├── train_classifier.py
├── transformer_client.py
├── transformer_server.py
├── utils.py
└── weighted_server.py

/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | .idea/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction 2 | 3 | Code for paper: 4 | > [FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction](https://openreview.net/forum?id=OtxyysUdBE)\ 5 | > Samiul Alam, Luyang Liu, Ming Yan, and Mi Zhang.\ 6 | > _NeurIPS 2022_. 
7 | 8 | The repository is built upon [HeteroFL](https://github.com/dem123456789/HeteroFL-Computation-and-Communication-Efficient-Federated-Learning-for-Heterogeneous-Clients). 9 | 10 | # Overview 11 | 12 | Most cross-device federated learning studies focus on the model-homogeneous setting where the global server model and local client models are identical. However, such a constraint not only excludes low-end clients, who would otherwise make unique contributions to model training, but also restrains clients from training large models due to on-device resource bottlenecks. We propose `FedRolex`, a partial training-based approach that enables model-heterogeneous FL and can train a global server model larger than the largest client model. 13 | 14 |
15 | ![comparison](figures/table_overview.png) 16 |
17 | 18 | The key difference between `FedRolex` and existing partial training-based methods is how the sub-models are extracted for each client over communication rounds in the federated training process. Specifically, instead of extracting sub-models in either a random or a static manner, `FedRolex` proposes a rolling sub-model extraction scheme, where the sub-model is extracted from the global server model using a rolling window that advances in each communication round. Since the window is rolling, sub-models from different parts of the global model are extracted in sequence in different rounds. As a result, all the parameters of the global server model are evenly trained over the local data of client devices; a minimal sketch of this scheme follows the figure below. 19 | 20 |
21 | ![fedrolex](figures/fedrolex_overview.png) 22 |
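For intuition, the rolling extraction can be sketched in a few lines of NumPy. This is an illustrative toy, not the repository's implementation; the function name, its arguments, and the capacity ratio `beta` are ours:

```python
import numpy as np

def rolling_window_indices(round_idx, num_global_channels, beta):
    """Indices of the global channels assigned to one client in a given round."""
    sub_size = max(1, int(beta * num_global_channels))
    start = round_idx % num_global_channels  # the window advances by one index per round
    return np.arange(start, start + sub_size) % num_global_channels  # and wraps around

# A quarter-capacity client sees a different slice of an 8-channel layer every round:
for t in range(3):
    print(t, rolling_window_indices(t, 8, 0.25))  # [0 1], then [1 2], then [2 3]
```

Because the window keeps advancing, every global parameter is visited at the same rate over rounds, which a static window cannot achieve and a random one achieves only in expectation.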
23 | 24 | # Video Brief 25 | Click the figure to watch this short video explaining our work. 26 | 27 | [![slideslive_link](figures/video_placeholder.png)](https://recorder-v3.slideslive.com/#/share?share=74286&s=8264b1ae-a2a0-459d-aa99-88d7baf2d51f) 28 | 29 | # Usage 30 | ## Setup 31 | ```commandline 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## Training 36 | Train a ResNet-18 model on the CIFAR-10 dataset: 37 | ```commandline 38 | python main_resnet.py --data_name CIFAR10 \ 39 | --model_name resnet18 \ 40 | --control_name 1_100_0.1_non-iid-2_dynamic_a1-b1-c1-d1-e1_bn_1_1 \ 41 | --exp_name roll_test \ 42 | --algo roll \ 43 | --g_epoch 3200 \ 44 | --l_epoch 1 \ 45 | --lr 2e-4 \ 46 | --schedule 1200 \ 47 | --seed 31 \ 48 | --num_experiments 3 \ 49 | --devices 0 1 2 50 | ``` 51 | `data_name`: CIFAR10 or CIFAR100 \ 52 | `model_name`: resnet18 or vgg \ 53 | `control_name`: 1_{num users}_{fraction of participating users}_{iid or non-iid-{num classes}}_{dynamic or fix} 54 | _{heterogeneity distribution}_{batch norm (bn) or group norm (gn)}_{scalar, 1 or 0}_{masked cross entropy, 1 or 0} (see the decoding sketch below) \ 55 | `exp_name`: string value \ 56 | `algo`: roll, random, or static \ 57 | `g_epoch`: num global epochs \ 58 | `l_epoch`: num local epochs \ 59 | `lr`: learning rate \ 60 | `schedule`: lr schedule, space-separated list of integers less than g_epoch \ 61 | `seed`: integer number \ 62 | `num_experiments`: integer number; runs `num_experiments` trials with `seed` incremented each time \ 63 | `devices`: indices of the GPUs to use 64 | 65 | To train a Transformer model on the StackOverflow dataset, use main_transformer.py instead: 66 | ```commandline 67 | python main_transformer.py --data_name Stackoverflow \ 68 | --model_name transformer \ 69 | --control_name 1_100_0.1_iid_dynamic_a6-b10-c11-d18-e55_bn_1_1 \ 70 | --exp_name roll_so_test \ 71 | --algo roll \ 72 | --g_epoch 1500 \ 73 | --l_epoch 1 \ 74 | --lr 2e-4 \ 75 | --schedule 600 1000 \ 76 | --seed 31 \ 77 | --num_experiments 3 \ 78 | --devices 0 1 2 3 4 5 6 7 79 | ```
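For example, the control string used in the command above can be decoded by splitting on underscores. A minimal sketch, assuming none of the fields themselves contain underscores (true for the strings in this README); the variable names are ours, not the repository's:

```python
control = '1_100_0.1_iid_dynamic_a6-b10-c11-d18-e55_bn_1_1'
fed, num_users, frac, split, mode, dist, norm, scalar, mask = control.split('_')
# 100 clients, 10% of clients sampled per round, iid data, dynamic sub-model
# assignment, capacity mix a6-b10-c11-d18-e55, batch norm, scaler on, masked CE on
print(num_users, frac, split, dist)
```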
80 | To train a data- and model-homogeneous baseline, the command would look like this: 81 | ```commandline 82 | python main_resnet.py --data_name CIFAR10 \ 83 | --model_name resnet18 \ 84 | --control_name 1_100_0.1_iid_dynamic_a1_bn_1_1 \ 85 | --exp_name homogeneous_largest_low_heterogeneity \ 86 | --algo static \ 87 | --g_epoch 3200 \ 88 | --l_epoch 1 \ 89 | --lr 2e-4 \ 90 | --schedule 800 1200 \ 91 | --seed 31 \ 92 | --num_experiments 3 \ 93 | --devices 0 1 2 94 | ``` 95 | To reproduce the results in Table 3 of the paper, please run the following commands: 96 | 97 | CIFAR-10 98 | ```commandline 99 | python main_resnet.py --data_name CIFAR10 \ 100 | --model_name resnet18 \ 101 | --control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \ 102 | --exp_name homogeneous_largest_low_heterogeneity \ 103 | --algo static \ 104 | --g_epoch 3200 \ 105 | --l_epoch 1 \ 106 | --lr 2e-4 \ 107 | --schedule 800 1200 \ 108 | --seed 31 \ 109 | --num_experiments 5 \ 110 | --devices 0 1 2 111 | ``` 112 | CIFAR-100 113 | ```commandline 114 | python main_resnet.py --data_name CIFAR100 \ 115 | --model_name resnet18 \ 116 | --control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \ 117 | --exp_name homogeneous_largest_low_heterogeneity \ 118 | --algo static \ 119 | --g_epoch 2500 \ 120 | --l_epoch 1 \ 121 | --lr 2e-4 \ 122 | --schedule 800 1200 \ 123 | --seed 31 \ 124 | --num_experiments 5 \ 125 | --devices 0 1 2 126 | ``` 127 | StackOverflow 128 | ```commandline 129 | python main_transformer.py --data_name Stackoverflow \ 130 | --model_name transformer \ 131 | --control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \ 132 | --exp_name roll_so_test \ 133 | --algo roll \ 134 | --g_epoch 1500 \ 135 | --l_epoch 1 \ 136 | --lr 2e-4 \ 137 | --schedule 600 1000 \ 138 | --seed 31 \ 139 | --num_experiments 5 \ 140 | --devices 0 1 2 3 4 5 6 7 141 | ``` 142 | 143 | Note: To get the results based on the real-world distribution as in Table 4, use `a6-b10-c11-d18-e55` as the 144 | distribution. 145 | 146 | ## Citation 147 | If you find this useful for your work, please consider citing: 148 | 149 | ``` 150 | @InProceedings{alam2022fedrolex, 151 | title = {FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction}, 152 | author = {Alam, Samiul and Liu, Luyang and Yan, Ming and Zhang, Mi}, 153 | booktitle = {Conference on Neural Information Processing Systems (NeurIPS)}, 154 | year = {2022} 155 | } 156 | ``` 157 | 158 | 159 | -------------------------------------------------------------------------------- /check_local.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | title = 'CIFAR100 \nLow Heterogeneity' 6 | axisfont_minor = 12 7 | axisfont_major = 20 8 | label_font = 18 9 | title_font = 24 10 | b = 12 11 | tag = 'output/runs/31_CIFAR100_label_resnet18_1_100_0.1_non-iid-50_dynamic_' 12 | base = f'{tag}e1_bn_1_1_real_world.pt' 13 | compare = f'{tag}a6-b10-c11-d18-e55_bn_1_1_real_world.pt' 14 | acc = torch.load(compare) 15 | 16 | x = np.array([i[0][0]['Local-Accuracy'] for i in acc]) 17 | x += np.random.normal(0, 1, x.shape)  # add slight Gaussian jitter to the accuracies 18 | x = np.clip(x, 0, 100) 19 | 20 | fig, ax = plt.subplots() 21 | ax.tick_params(axis='both', which='major', labelsize=axisfont_major) 22 | ax.tick_params(axis='both', which='minor', labelsize=axisfont_minor) 23 | 24 | n, bins, patches = plt.hist(x, bins=b, facecolor='#2ab0ff', edgecolor='#e0e0e0', linewidth=0.5, alpha=0.75) 25 | 26 | n = n.astype('int')  # bin counts must be integers 27 | # Color each bar by its height using the chosen colormap
28 | for i in range(len(patches)): 29 | patches[i].set_facecolor(plt.cm.get_cmap('Oranges')(n[i] / max(n))) 30 | 31 | acc = torch.load(base) 32 | 33 | x = np.array([i[0][0]['Local-Accuracy'] for i in acc[0]]) 34 | x += np.random.normal(0, 1, x.shape) 35 | x = np.clip(x, 0, 100) 36 | 37 | n, bins, patches = plt.hist(x, bins=b, facecolor='#2ab0ff', edgecolor='#e0e0e0', linewidth=0.5, alpha=0.7) 38 | 39 | n = n.astype('int')  # bin counts must be integers 40 | # Color each bar by its height using the chosen colormap 41 | for i in range(len(patches)): 42 | patches[i].set_facecolor(plt.cm.get_cmap('Blues')(n[i] / max(n))) 43 | 44 | plt.title(f'{title}', fontsize=title_font) 45 | plt.xlabel('Accuracy', fontsize=label_font) 46 | plt.ylabel('Counts', fontsize=label_font) 47 | 48 | # plt.rcParams['font.weight'] = 'bold' 49 | # plt.rcParams['axes.labelweight'] = 'bold' 50 | plt.rcParams['font.size'] = label_font 51 | # plt.rcParams['axes.linewidth'] = 1.0 52 | # plt.show() 53 | 54 | plt.savefig(f'{tag}a6-b10-c11-d18-e55_bn_1_1_real_world_accuracy_distribution.pdf', bbox_inches="tight") 55 | # Other palettes: any name accepted by plt.cm.get_cmap works above, 56 | # e.g. 'viridis', 'plasma', or any other entry in plt.colormaps(). 57 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | global cfg 4 | if 'cfg' not in globals(): 5 | with open('config.yml', 'r') as f: 6 | cfg = yaml.load(f, Loader=yaml.FullLoader) 7 | -------------------------------------------------------------------------------- /config.yml:
-------------------------------------------------------------------------------- 1 | --- 2 | # control 3 | exp_name: hetero_fl_roll_50_100 4 | control: 5 | fed: '1' 6 | num_users: '100' 7 | frac: '0.1' 8 | data_split_mode: 'iid' 9 | model_split_mode: 'fix' 10 | model_mode: 'a1' 11 | norm: 'bn' 12 | scale: '1' 13 | mask: '1' 14 | # data 15 | data_name: CIFAR10 16 | subset: label 17 | batch_size: 18 | train: 128 19 | test: 128 20 | shuffle: 21 | train: True 22 | test: False 23 | num_workers: 0 24 | model_name: resnet18 25 | metric_name: 26 | train: 27 | - Loss 28 | - Accuracy 29 | test: 30 | - Loss 31 | - Accuracy 32 | # optimizer 33 | optimizer_name: Adam 34 | lr: 3.0e-4 35 | momentum: 0.9 36 | weight_decay: 5.0e-4 37 | # scheduler 38 | scheduler_name: None 39 | step_size: 1 40 | milestones: 41 | - 100 42 | - 150 43 | patience: 10 44 | threshold: 1.0e-3 45 | factor: 0.5 46 | min_lr: 1.0e-4 47 | # experiment 48 | init_seed: 31 49 | num_experiments: 1 50 | num_epochs: 200 51 | log_interval: 0.25 52 | device: cuda 53 | world_size: 1 54 | resume_mode: 0 55 | # other 56 | save_format: pdf -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torch.utils.data.dataloader import default_collate 7 | from torchvision import transforms 8 | 9 | import datasets 10 | from config import cfg 11 | from datasets.gld import GLD160 12 | 13 | 14 | def fetch_dataset(data_name, subset): 15 | dataset = {} 16 | print('fetching data {}...'.format(data_name)) 17 | root = './data/{}'.format(data_name) 18 | if data_name == 'MNIST': 19 | dataset['train'] = datasets.MNIST(root=root, split='train', subset=subset, transform=datasets.Compose( 20 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) 21 | dataset['test'] = datasets.MNIST(root=root, split='test', subset=subset, transform=datasets.Compose( 22 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) 23 | elif data_name == 'CIFAR10': 24 | dataset['train'] = datasets.CIFAR10(root=root, split='train', subset=subset, transform=datasets.Compose( 25 | [transforms.RandomCrop(32, padding=4), 26 | transforms.RandomHorizontalFlip(), 27 | transforms.ToTensor(), 28 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])) 29 | dataset['test'] = datasets.CIFAR10(root=root, split='test', subset=subset, transform=datasets.Compose( 30 | [transforms.ToTensor(), 31 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])) 32 | elif data_name == 'CIFAR100': 33 | dataset['train'] = datasets.CIFAR100(root=root, split='train', subset=subset, transform=datasets.Compose( 34 | [transforms.RandomCrop(32, padding=4), 35 | transforms.RandomHorizontalFlip(), 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])) 38 | dataset['test'] = datasets.CIFAR100(root=root, split='test', subset=subset, transform=datasets.Compose( 39 | [transforms.ToTensor(), 40 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])) 41 | elif data_name in ['PennTreebank', 'WikiText2', 'WikiText103']: 42 | dataset['train'] = eval('datasets.{}(root=root, split=\'train\')'.format(data_name)) 43 | dataset['test'] = eval('datasets.{}(root=root, split=\'test\')'.format(data_name)) 44 | elif data_name in ['Stackoverflow']: 45 | dataset['train'] = 
torch.load(os.path.join('/egr/research-zhanglambda/samiul/stackoverflow/', 46 | 'stackoverflow_{}.pt'.format('train'))) 47 | dataset['test'] = torch.load(os.path.join('/egr/research-zhanglambda/samiul/stackoverflow/', 48 | 'stackoverflow_{}.pt'.format('val'))) 49 | dataset['vocab'] = torch.load(os.path.join('/egr/research-zhanglambda/samiul/stackoverflow/', 50 | 'meta.pt')) 51 | elif data_name in ['gld']: 52 | dataset['train'] = torch.load(os.path.join('gld_160k/', 53 | '{}.pt'.format('train'))) 54 | dataset['test'] = torch.load(os.path.join('gld_160k/', 55 | '{}.pt'.format('test'))) 56 | else: 57 | raise ValueError('Not valid dataset name') 58 | print('data ready') 59 | return dataset 60 | 61 | 62 | def input_collate(batch): 63 | if isinstance(batch[0], dict): 64 | output = {key: [] for key in batch[0].keys()} 65 | for b in batch: 66 | for key in b: 67 | output[key].append(b[key]) 68 | return output 69 | else: 70 | return default_collate(batch) 71 | 72 | 73 | def split_dataset(dataset, num_users, data_split_mode): 74 | data_split = {} 75 | if cfg['data_name'] in ['gld']: 76 | data_split['train'] = [GLD160(usr_data, usr_labels) for usr_data, usr_labels, _ in dataset['train'].values()] 77 | data_split['test'] = GLD160(*dataset['test']) 78 | label_split = [list(usr_lbl_split.keys()) for _, _, usr_lbl_split in dataset['train'].values()] 79 | return data_split, label_split 80 | if data_split_mode == 'iid': 81 | data_split['train'], label_split = iid(dataset['train'], num_users) 82 | data_split['test'], _ = iid(dataset['test'], num_users) 83 | elif 'non-iid' in cfg['data_split_mode']: 84 | data_split['train'], label_split = non_iid(dataset['train'], num_users) 85 | data_split['test'], _ = non_iid(dataset['test'], num_users, label_split) 86 | 87 | else: 88 | raise ValueError('Not valid data split mode') 89 | return data_split, label_split 90 | 91 | 92 | def iid(dataset, num_users): 93 | if cfg['data_name'] in ['MNIST', 'CIFAR10', 'CIFAR100']: 94 | label = torch.tensor(dataset.target) 95 | elif cfg['data_name'] in ['WikiText2', 'WikiText103', 'PennTreebank']: 96 | label = dataset.token 97 | else: 98 | raise ValueError('Not valid data name') 99 | num_items = int(len(dataset) / num_users) 100 | data_split, idx = {}, list(range(len(dataset))) 101 | label_split = {} 102 | for i in range(num_users): 103 | num_items_i = min(len(idx), num_items) 104 | data_split[i] = torch.tensor(idx)[torch.randperm(len(idx))[:num_items_i]].tolist() 105 | label_split[i] = torch.unique(label[data_split[i]]).tolist() 106 | idx = list(set(idx) - set(data_split[i])) 107 | return data_split, label_split 108 | 109 | 110 | def non_iid(dataset, num_users, label_split=None): 111 | label = np.array(dataset.target) 112 | cfg['non-iid-n'] = int(cfg['data_split_mode'].split('-')[-1]) 113 | shard_per_user = cfg['non-iid-n'] 114 | data_split = {i: [] for i in range(num_users)} 115 | label_idx_split = {} 116 | for i in range(len(label)): 117 | label_i = label[i].item() 118 | if label_i not in label_idx_split: 119 | label_idx_split[label_i] = [] 120 | label_idx_split[label_i].append(i) 121 | shard_per_class = int(shard_per_user * num_users / cfg['classes_size']) 122 | for label_i in label_idx_split: 123 | label_idx = label_idx_split[label_i] 124 | num_leftover = len(label_idx) % shard_per_class 125 | leftover = label_idx[-num_leftover:] if num_leftover > 0 else [] 126 | new_label_idx = np.array(label_idx[:-num_leftover]) if num_leftover > 0 else np.array(label_idx) 127 | new_label_idx = new_label_idx.reshape((shard_per_class, 
-1)).tolist() 128 | for i, leftover_label_idx in enumerate(leftover): 129 | new_label_idx[i] = np.concatenate([new_label_idx[i], [leftover_label_idx]]) 130 | label_idx_split[label_i] = new_label_idx 131 | if label_split is None: 132 | label_split = list(range(cfg['classes_size'])) * shard_per_class 133 | label_split = torch.tensor(label_split)[torch.randperm(len(label_split))].tolist() 134 | label_split = np.array(label_split).reshape((num_users, -1)).tolist() 135 | for i in range(len(label_split)): 136 | label_split[i] = np.unique(label_split[i]).tolist() 137 | for i in range(num_users): 138 | for label_i in label_split[i]: 139 | idx = torch.arange(len(label_idx_split[label_i]))[torch.randperm(len(label_idx_split[label_i]))[0]].item() 140 | data_split[i].extend(label_idx_split[label_i].pop(idx)) 141 | return data_split, label_split 142 | 143 | 144 | def make_data_loader(dataset): 145 | data_loader = {} 146 | for k in dataset: 147 | data_loader[k] = torch.utils.data.DataLoader(dataset=dataset[k], shuffle=cfg['shuffle'][k], 148 | batch_size=cfg['batch_size'][k], pin_memory=True, 149 | num_workers=cfg['num_workers'], collate_fn=input_collate) 150 | return data_loader 151 | 152 | 153 | class SplitDataset(Dataset): 154 | def __init__(self, dataset, idx): 155 | super().__init__() 156 | self.dataset = dataset 157 | self.idx = idx 158 | 159 | def __len__(self): 160 | return len(self.idx) 161 | 162 | def __getitem__(self, index): 163 | return self.dataset[self.idx[index]] 164 | 165 | 166 | class GenericDataset(Dataset): 167 | def __init__(self, dataset): 168 | super().__init__() 169 | self.dataset = dataset 170 | 171 | def __len__(self): 172 | return len(self.dataset) 173 | 174 | def __getitem__(self, index): 175 | input = self.dataset[index] 176 | return input 177 | 178 | 179 | class BatchDataset(Dataset): 180 | def __init__(self, dataset, seq_length): 181 | super().__init__() 182 | self.dataset = dataset 183 | self.seq_length = seq_length 184 | self.S = dataset[0]['label'].size(0) 185 | self.idx = list(range(0, self.S, seq_length)) 186 | 187 | def __len__(self): 188 | return len(self.idx) 189 | 190 | def __getitem__(self, index): 191 | seq_length = min(self.seq_length, self.S - index) 192 | return {'label': self.dataset[:]['label'][:, self.idx[index]:self.idx[index] + seq_length]} 193 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar import CIFAR10, CIFAR100 2 | from .folder import ImageFolder 3 | from .imagenet import ImageNet 4 | from .lm import PennTreebank, WikiText2, WikiText103 5 | from .mnist import MNIST, EMNIST, FashionMNIST 6 | from .transforms import * 7 | from .utils import * 8 | 9 | __all__ = ('MNIST', 'EMNIST', 'FashionMNIST', 10 | 'CIFAR10', 'CIFAR100', 11 | 'ImageNet', 12 | 'PennTreebank', 'WikiText2', 'WikiText103', 13 | 'ImageFolder') 14 | -------------------------------------------------------------------------------- /datasets/cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import anytree 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | from utils import check_exists, makedir_exist_ok, save, load 11 | from .utils import download_url, extract_file, make_classes_counts, make_tree, make_flat_index 12 | 13 | 14 | class CIFAR10(Dataset): 15 | data_name = 'CIFAR10' 16 | file = 
[('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 'c58f30108f718f92721af3b95e74349a')] 17 | 18 | def __init__(self, root, split, subset, transform=None): 19 | self.root = os.path.expanduser(root) 20 | self.split = split 21 | self.subset = subset 22 | self.transform = transform 23 | if not check_exists(self.processed_folder): 24 | self.process() 25 | self.img, self.target = load(os.path.join(self.processed_folder, '{}.pt'.format(self.split))) 26 | self.target = self.target[self.subset] 27 | self.classes_counts = make_classes_counts(self.target) 28 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, 'meta.pt')) 29 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset] 30 | 31 | def __getitem__(self, index): 32 | img, target = Image.fromarray(self.img[index]), torch.tensor(self.target[index]) 33 | input = {'img': img, self.subset: target} 34 | if self.transform is not None: 35 | input = self.transform(input) 36 | return input 37 | 38 | def __len__(self): 39 | return len(self.img) 40 | 41 | @property 42 | def processed_folder(self): 43 | return os.path.join(self.root, 'processed') 44 | 45 | @property 46 | def raw_folder(self): 47 | return os.path.join(self.root, 'raw') 48 | 49 | def process(self): 50 | if not check_exists(self.raw_folder): 51 | self.download() 52 | train_set, test_set, meta = self.make_data() 53 | save(train_set, os.path.join(self.processed_folder, 'train.pt')) 54 | save(test_set, os.path.join(self.processed_folder, 'test.pt')) 55 | save(meta, os.path.join(self.processed_folder, 'meta.pt')) 56 | return 57 | 58 | def download(self): 59 | makedir_exist_ok(self.raw_folder) 60 | for (url, md5) in self.file: 61 | filename = os.path.basename(url) 62 | download_url(url, self.raw_folder, filename, md5) 63 | extract_file(os.path.join(self.raw_folder, filename)) 64 | return 65 | 66 | def __repr__(self): 67 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nTransforms: {}'.format( 68 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.transform.__repr__()) 69 | return fmt_str 70 | 71 | def make_data(self): 72 | train_filenames = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5'] 73 | test_filenames = ['test_batch'] 74 | train_img, train_label = read_pickle_file(os.path.join(self.raw_folder, 'cifar-10-batches-py'), train_filenames) 75 | test_img, test_label = read_pickle_file(os.path.join(self.raw_folder, 'cifar-10-batches-py'), test_filenames) 76 | train_target, test_target = {'label': train_label}, {'label': test_label} 77 | with open(os.path.join(self.raw_folder, 'cifar-10-batches-py', 'batches.meta'), 'rb') as f: 78 | data = pickle.load(f, encoding='latin1') 79 | classes = data['label_names'] 80 | classes_to_labels = {'label': anytree.Node('U', index=[])} 81 | for c in classes: 82 | make_tree(classes_to_labels['label'], [c]) 83 | classes_size = {'label': make_flat_index(classes_to_labels['label'])} 84 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size) 85 | 86 | 87 | class CIFAR100(CIFAR10): 88 | data_name = 'CIFAR100' 89 | file = [('https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', 'eb9058c3a382ffc7106e4002c42a8d85')] 90 | 91 | def make_data(self): 92 | train_filenames = ['train'] 93 | test_filenames = ['test'] 94 | train_img, train_label = read_pickle_file(os.path.join(self.raw_folder, 'cifar-100-python'), train_filenames) 95 | test_img, test_label = 
read_pickle_file(os.path.join(self.raw_folder, 'cifar-100-python'), test_filenames) 96 | train_target, test_target = {'label': train_label}, {'label': test_label} 97 | with open(os.path.join(self.raw_folder, 'cifar-100-python', 'meta'), 'rb') as f: 98 | data = pickle.load(f, encoding='latin1') 99 | classes = data['fine_label_names'] 100 | classes_to_labels = {'label': anytree.Node('U', index=[])} 101 | for c in classes: 102 | for k in CIFAR100_classes: 103 | if c in CIFAR100_classes[k]: 104 | c = [k, c] 105 | break 106 | make_tree(classes_to_labels['label'], c) 107 | classes_size = {'label': make_flat_index(classes_to_labels['label'], classes)} 108 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size) 109 | 110 | 111 | def read_pickle_file(path, filenames): 112 | img, label = [], [] 113 | for filename in filenames: 114 | file_path = os.path.join(path, filename) 115 | with open(file_path, 'rb') as f: 116 | entry = pickle.load(f, encoding='latin1') 117 | img.append(entry['data']) 118 | label.extend(entry['labels']) if 'labels' in entry else label.extend(entry['fine_labels']) 119 | img = np.vstack(img).reshape(-1, 3, 32, 32) 120 | img = img.transpose((0, 2, 3, 1)) 121 | return img, label 122 | 123 | 124 | CIFAR100_classes = { 125 | 'aquatic mammals': ['beaver', 'dolphin', 'otter', 'seal', 'whale'], 126 | 'fish': ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'], 127 | 'flowers': ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'], 128 | 'food containers': ['bottle', 'bowl', 'can', 'cup', 'plate'], 129 | 'fruit and vegetables': ['apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'], 130 | 'household electrical devices': ['clock', 'keyboard', 'lamp', 'telephone', 'television'], 131 | 'household furniture': ['bed', 'chair', 'couch', 'table', 'wardrobe'], 132 | 'insects': ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'], 133 | 'large carnivores': ['bear', 'leopard', 'lion', 'tiger', 'wolf'], 134 | 'large man-made outdoor things': ['bridge', 'castle', 'house', 'road', 'skyscraper'], 135 | 'large natural outdoor scenes': ['cloud', 'forest', 'mountain', 'plain', 'sea'], 136 | 'large omnivores and herbivores': ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'], 137 | 'medium-sized mammals': ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'], 138 | 'non-insect invertebrates': ['crab', 'lobster', 'snail', 'spider', 'worm'], 139 | 'people': ['baby', 'boy', 'girl', 'man', 'woman'], 140 | 'reptiles': ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'], 141 | 'small mammals': ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'], 142 | 'trees': ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'], 143 | 'vehicles 1': ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'], 144 | 'vehicles 2': ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor'] 145 | } 146 | -------------------------------------------------------------------------------- /datasets/folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | from utils import check_exists, save, load 8 | from .utils import find_classes, make_img, make_classes_counts 9 | 10 | 11 | class ImageFolder(Dataset): 12 | 13 | def __init__(self, root, split, subset, transform=None): 14 | self.data_name = os.path.basename(root) 15 | self.root = os.path.expanduser(root) 16 | self.split = split 17 | self.subset = subset 18 | self.transform = 
transform 19 | if not check_exists(self.processed_folder): 20 | self.process() 21 | self.img, self.target = load(os.path.join(self.processed_folder, '{}.pt'.format(self.split))) 22 | self.target = self.target[self.subset] 23 | self.classes_counts = make_classes_counts(self.target) 24 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, 'meta.pt')) 25 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset] 26 | 27 | def __getitem__(self, index): 28 | img, target = Image.open(self.img[index], mode='r').convert('RGB'), torch.tensor(self.target[index]) 29 | input = {'img': img, self.subset: target} 30 | if self.transform is not None: 31 | input = self.transform(input) 32 | return input 33 | 34 | def __len__(self): 35 | return len(self.img) 36 | 37 | @property 38 | def processed_folder(self): 39 | return os.path.join(self.root, 'processed') 40 | 41 | @property 42 | def raw_folder(self): 43 | return os.path.join(self.root, 'raw') 44 | 45 | def process(self): 46 | if not check_exists(self.raw_folder): 47 | raise RuntimeError('Dataset not found') 48 | train_set, test_set, meta = self.make_data() 49 | save(train_set, os.path.join(self.processed_folder, 'train.pt')) 50 | save(test_set, os.path.join(self.processed_folder, 'test.pt')) 51 | save(meta, os.path.join(self.processed_folder, 'meta.pt')) 52 | return 53 | 54 | def __repr__(self): 55 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nTransforms: {}'.format( 56 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.transform.__repr__()) 57 | return fmt_str 58 | 59 | def make_data(self): 60 | classes_to_labels, classes_size = find_classes(os.path.join(self.raw_folder, 'train')) 61 | train_img, train_label = make_img(os.path.join(self.raw_folder, 'train'), classes_to_labels['label']) 62 | test_img, test_label = make_img(os.path.join(self.raw_folder, 'test'), classes_to_labels['label']) 63 | train_target, test_target = {'label': train_label}, {'label': test_label} 64 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size) 65 | -------------------------------------------------------------------------------- /datasets/gld.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class GLD160(Dataset): 6 | data_name = 'GLD' 7 | 8 | def __init__(self, images, targets, transform=T.RandomCrop(92)): 9 | self.transform = transform 10 | self.img, self.target = images, targets 11 | self.classes_counts = 2028 12 | 13 | def __getitem__(self, index): 14 | img = self.img[index] 15 | target = self.target[index] 16 | inp = {'img': img, 'label': target} 17 | if self.transform is not None: 18 | inp['img'] = self.transform(inp['img']) 19 | return inp 20 | 21 | def __len__(self): 22 | return len(self.img) 23 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import anytree 5 | import numpy as np 6 | import torch 7 | from PIL import Image, ImageFile 8 | from torch.utils.data import Dataset 9 | 10 | from utils import check_exists, save, load 11 | from .utils import extract_file, make_classes_counts, make_img, make_tree, make_flat_index 12 | 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | 15 | 16 | class 
ImageNet(Dataset): 17 | data_name = 'ImageNet' 18 | 19 | def __init__(self, root, split, subset, size, transform=None): 20 | self.root = os.path.expanduser(root) 21 | self.split = split 22 | self.subset = subset 23 | self.transform = transform 24 | self.size = size 25 | if not check_exists(os.path.join(self.processed_folder, str(self.size))): 26 | self.process() 27 | self.img, self.target = load(os.path.join(self.processed_folder, str(self.size), '{}.pt'.format(self.split))) 28 | self.target = self.target[self.subset] 29 | self.classes_counts = make_classes_counts(self.target) 30 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, str(self.size), 'meta.pt')) 31 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset] 32 | 33 | def __getitem__(self, index): 34 | img, target = Image.open(self.img[index], mode='r').convert('RGB'), torch.tensor(self.target[index]) 35 | input = {'img': img, self.subset: target} 36 | if self.transform is not None: 37 | input = self.transform(input) 38 | return input 39 | 40 | def __len__(self): 41 | return len(self.img) 42 | 43 | @property 44 | def processed_folder(self): 45 | return os.path.join(self.root, 'processed') 46 | 47 | @property 48 | def raw_folder(self): 49 | return os.path.join(self.root, 'raw') 50 | 51 | def process(self): 52 | if not check_exists(self.raw_folder): 53 | raise RuntimeError('Dataset not found') 54 | train_set, test_set, meta = self.make_data() 55 | save(train_set, os.path.join(self.processed_folder, str(self.size), 'train.pt')) 56 | save(test_set, os.path.join(self.processed_folder, str(self.size), 'test.pt')) 57 | save(meta, os.path.join(self.processed_folder, str(self.size), 'meta.pt')) 58 | return 59 | 60 | def __repr__(self): 61 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nSize: {}\nTransforms: {}'.format( 62 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.size, 63 | self.transform.__repr__()) 64 | return fmt_str 65 | 66 | def make_data(self): 67 | if not check_exists(os.path.join(self.raw_folder, 'base')): 68 | train_path = os.path.join(self.raw_folder, 'ILSVRC2012_img_train') 69 | test_path = os.path.join(self.raw_folder, 'ILSVRC2012_img_val') 70 | meta_path = os.path.join(self.raw_folder, 'ILSVRC2012_devkit_t12') 71 | extract_file(os.path.join(self.raw_folder, 'ILSVRC2012_img_train.tar'), train_path) 72 | extract_file(os.path.join(self.raw_folder, 'ILSVRC2012_img_val.tar'), test_path) 73 | extract_file(os.path.join(self.raw_folder, 'ILSVRC2012_devkit_t12.tar'), meta_path) 74 | for archive in [os.path.join(train_path, archive) for archive in os.listdir(train_path)]: 75 | extract_file(archive, os.path.splitext(archive)[0], delete=True) 76 | classes_to_labels, classes_size = make_meta(meta_path) 77 | with open(os.path.join(meta_path, 'data', 'ILSVRC2012_validation_ground_truth.txt'), 'r') as f: 78 | test_id = f.readlines() 79 | test_id = [int(i) for i in test_id] 80 | test_img = sorted([os.path.join(test_path, file) for file in os.listdir(test_path)]) 81 | test_wnid = [] 82 | for test_id_i in test_id: 83 | test_node_i = anytree.find_by_attr(classes_to_labels['label'], name='id', value=test_id_i) 84 | test_wnid.append(test_node_i.name) 85 | for test_wnid_i in set(test_wnid): 86 | os.mkdir(os.path.join(test_path, test_wnid_i)) 87 | for test_wnid_i, test_img in zip(test_wnid, test_img): 88 | shutil.move(test_img, os.path.join(test_path, test_wnid_i, os.path.basename(test_img))) 89 | 
shutil.move(train_path, os.path.join(self.raw_folder, 'base', 'ILSVRC2012_img_train')) 90 | shutil.move(test_path, os.path.join(self.raw_folder, 'base', 'ILSVRC2012_img_val')) 91 | shutil.move(meta_path, os.path.join(self.raw_folder, 'base', 'ILSVRC2012_devkit_t12')) 92 | if not check_exists(os.path.join(self.raw_folder, str(self.size))): 93 | raise ValueError('Need to run resizer') 94 | classes_to_labels, classes_size = make_meta(os.path.join(self.raw_folder, 'base', 'ILSVRC2012_devkit_t12')) 95 | train_img, train_label = make_img(os.path.join(self.raw_folder, str(self.size), 'ILSVRC2012_img_train'), 96 | classes_to_labels['label']) 97 | test_img, test_label = make_img(os.path.join(self.raw_folder, str(self.size), 'ILSVRC2012_img_val'), 98 | classes_to_labels['label']) 99 | train_target = {'label': train_label} 100 | test_target = {'label': test_label} 101 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size) 102 | 103 | 104 | def make_meta(path): 105 | import scipy.io as sio 106 | meta = sio.loadmat(os.path.join(path, 'data', 'meta.mat'), squeeze_me=True)['synsets'] 107 | num_children = list(zip(*meta))[4] 108 | leaf_meta = [meta[i] for (i, n) in enumerate(num_children) if n == 0] 109 | branch_meta = [meta[i] for (i, n) in enumerate(num_children) if n > 0] 110 | names, attributes = [], [] 111 | for i in range(len(leaf_meta)): 112 | name, attribute = make_node(leaf_meta[i], branch_meta) 113 | names.append(name) 114 | attributes.append(attribute) 115 | classes_to_labels = {'label': anytree.Node('U', index=[])} 116 | classes = [] 117 | for (name, attribute) in zip(names, attributes): 118 | make_tree(classes_to_labels['label'], name, attribute) 119 | classes.append(name[-1]) 120 | classes_size = {'label': make_flat_index(classes_to_labels['label'], classes)} 121 | return classes_to_labels, classes_size 122 | 123 | 124 | def make_node(node, branch): 125 | id, wnid, classes = node.item()[:3] 126 | if classes == 'entity': 127 | name = [] 128 | attribute = {'id': [], 'class': []} 129 | for i in range(len(branch)): 130 | branch_children = branch[i].item()[5] 131 | if (isinstance(branch_children, int) and id == branch_children) or ( 132 | isinstance(branch_children, np.ndarray) and id in branch_children): 133 | parent_name, parent_attribute = make_node(branch[i], branch) 134 | name = parent_name + [wnid] 135 | attribute = {'id': parent_attribute['id'] + [id], 'class': parent_attribute['class'] + [classes]} 136 | break 137 | return name, attribute 138 | -------------------------------------------------------------------------------- /datasets/lm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from utils import check_exists, makedir_exist_ok, save, load 8 | from .utils import download_url, extract_file 9 | 10 | 11 | class Vocab: 12 | def __init__(self): 13 | self.symbol_to_index = {u'<unk>': 0, u'<eos>': 1}  # special tokens: unknown word and end of sequence (token names assumed) 14 | self.index_to_symbol = [u'<unk>', u'<eos>'] 15 | 16 | def add(self, symbol): 17 | if symbol not in self.symbol_to_index: 18 | self.index_to_symbol.append(symbol) 19 | self.symbol_to_index[symbol] = len(self.index_to_symbol) - 1 20 | return 21 | 22 | def delete(self, symbol): 23 | if symbol in self.symbol_to_index: 24 | self.index_to_symbol.remove(symbol) 25 | self.symbol_to_index.pop(symbol, None) 26 | return 27 | 28 | def __len__(self): 29 | return len(self.index_to_symbol) 30 | 31 | def __getitem__(self, input):
32 | if isinstance(input, int): 33 | if len(self.index_to_symbol) > input >= 0: 34 | output = self.index_to_symbol[input] 35 | else: 36 | output = u'<unk>' 37 | elif isinstance(input, str): 38 | if input not in self.symbol_to_index: 39 | output = self.symbol_to_index[u'<unk>'] 40 | else: 41 | output = self.symbol_to_index[input] 42 | else: 43 | raise ValueError('Not valid data type') 44 | return output 45 | 46 | def __contains__(self, input): 47 | if isinstance(input, int): 48 | exist = len(self.index_to_symbol) > input >= 0 49 | elif isinstance(input, str): 50 | exist = input in self.symbol_to_index 51 | else: 52 | raise ValueError('Not valid data type') 53 | return exist 54 | 55 | 56 | class LanguageModeling(Dataset): 57 | def __init__(self, root, split): 58 | self.root = os.path.expanduser(root) 59 | self.split = split 60 | if not check_exists(self.processed_folder): 61 | self.process() 62 | self.token = load(os.path.join(self.processed_folder, '{}.pt'.format(split))) 63 | self.vocab = load(os.path.join(self.processed_folder, 'meta.pt')) 64 | 65 | def __getitem__(self, index): 66 | input = {'label': self.token[index]} 67 | return input 68 | 69 | def __len__(self): 70 | return len(self.token) 71 | 72 | @property 73 | def processed_folder(self): 74 | return os.path.join(self.root, 'processed') 75 | 76 | @property 77 | def raw_folder(self): 78 | return os.path.join(self.root, 'raw') 79 | 80 | def _check_exists(self): 81 | return os.path.exists(self.processed_folder) 82 | 83 | @abstractmethod 84 | def process(self): 85 | raise NotImplementedError 86 | 87 | @abstractmethod 88 | def download(self): 89 | raise NotImplementedError 90 | 91 | def __repr__(self): 92 | fmt_str = 'Dataset {}\nRoot: {}\nSplit: {}'.format( 93 | self.__class__.__name__, self.root, self.split) 94 | return fmt_str 95 | 96 | 97 | class PennTreebank(LanguageModeling): 98 | data_name = 'PennTreebank' 99 | file = [('https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt', None), 100 | ('https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt', None), 101 | ('https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt', None)] 102 | 103 | def __init__(self, root, split): 104 | super().__init__(root, split) 105 | 106 | def process(self): 107 | if not check_exists(self.raw_folder): 108 | self.download() 109 | train_set, valid_set, test_set, meta = self.make_data() 110 | save(train_set, os.path.join(self.processed_folder, 'train.pt')) 111 | save(valid_set, os.path.join(self.processed_folder, 'valid.pt')) 112 | save(test_set, os.path.join(self.processed_folder, 'test.pt')) 113 | save(meta, os.path.join(self.processed_folder, 'meta.pt')) 114 | return 115 | 116 | def download(self): 117 | makedir_exist_ok(self.raw_folder) 118 | for (url, md5) in self.file: 119 | filename = os.path.basename(url) 120 | download_url(url, self.raw_folder, filename, md5) 121 | extract_file(os.path.join(self.raw_folder, filename)) 122 | return 123 | 124 | def make_data(self): 125 | vocab = Vocab() 126 | read_token(vocab, os.path.join(self.raw_folder, 'ptb.train.txt')) 127 | read_token(vocab, os.path.join(self.raw_folder, 'ptb.valid.txt')) 128 | train_token = make_token(vocab, os.path.join(self.raw_folder, 'ptb.train.txt')) 129 | valid_token = make_token(vocab, os.path.join(self.raw_folder, 'ptb.valid.txt')) 130 | test_token = make_token(vocab, os.path.join(self.raw_folder, 'ptb.test.txt')) 131 | return train_token, valid_token, test_token, vocab 132 | 133 | 134 | class WikiText2(LanguageModeling):
134 | class WikiText2(LanguageModeling): 135 | data_name = 'WikiText2' 136 | file = [('https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip', None)] 137 | 138 | def __init__(self, root, split): 139 | super().__init__(root, split) 140 | 141 | def process(self): 142 | if not check_exists(self.raw_folder): 143 | self.download() 144 | train_set, valid_set, test_set, meta = self.make_data() 145 | save(train_set, os.path.join(self.processed_folder, 'train.pt')) 146 | save(valid_set, os.path.join(self.processed_folder, 'valid.pt')) 147 | save(test_set, os.path.join(self.processed_folder, 'test.pt')) 148 | save(meta, os.path.join(self.processed_folder, 'meta.pt')) 149 | return 150 | 151 | def download(self): 152 | makedir_exist_ok(self.raw_folder) 153 | for (url, md5) in self.file: 154 | filename = os.path.basename(url) 155 | download_url(url, self.raw_folder, filename, md5) 156 | extract_file(os.path.join(self.raw_folder, filename)) 157 | return 158 | 159 | def make_data(self): 160 | vocab = Vocab() 161 | read_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.train.tokens')) 162 | read_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.valid.tokens')) 163 | train_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.train.tokens')) 164 | valid_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.valid.tokens')) 165 | test_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-2', 'wiki.test.tokens')) 166 | return train_token, valid_token, test_token, vocab 167 | 168 | 169 | class WikiText103(LanguageModeling): 170 | data_name = 'WikiText103' 171 | file = [('https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip', None)] 172 | 173 | def __init__(self, root, split): 174 | super().__init__(root, split) 175 | 176 | def process(self): 177 | if not check_exists(self.raw_folder): 178 | self.download() 179 | train_set, valid_set, test_set, meta = self.make_data() 180 | save(train_set, os.path.join(self.processed_folder, 'train.pt')) 181 | save(valid_set, os.path.join(self.processed_folder, 'valid.pt')) 182 | save(test_set, os.path.join(self.processed_folder, 'test.pt')) 183 | save(meta, os.path.join(self.processed_folder, 'meta.pt')) 184 | return 185 | 186 | def download(self): 187 | makedir_exist_ok(self.raw_folder) 188 | for (url, md5) in self.file: 189 | filename = os.path.basename(url) 190 | download_url(url, self.raw_folder, filename, md5) 191 | extract_file(os.path.join(self.raw_folder, filename)) 192 | return 193 | 194 | def make_data(self): 195 | vocab = Vocab() 196 | read_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.train.tokens')) 197 | read_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.valid.tokens')) 198 | train_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.train.tokens')) 199 | valid_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.valid.tokens')) 200 | test_token = make_token(vocab, os.path.join(self.raw_folder, 'wikitext-103', 'wiki.test.tokens')) 201 | return train_token, valid_token, test_token, vocab 202 | 203 | 204 | class StackOverflowClientDataset(Dataset): 205 | def __init__(self, token, seq_length, batch_size): 206 | self.seq_length = seq_length 207 | self.token = token 208 | num_batch = len(token) // (batch_size * seq_length) 209 | self.token = self.token.narrow(0, 0, num_batch * batch_size * seq_length) 210 | self.token = self.token.reshape(-1, batch_size, seq_length) 211 |
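    # NOTE (editor): shape walk-through with illustrative numbers, not from the
    # repo. With len(token) = 1000, batch_size = 4 and seq_length = 20:
    #   num_batch = 1000 // (4 * 20) = 12
    #   narrow  -> keeps the first 960 tokens, dropping the trailing 40
    #   reshape -> self.token has shape (12, 4, 20)
    # so __getitem__ below yields one (4, 20) block of token ids per index.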
212 | def __getitem__(self, index): 213 | return {'label': self.token[index, :, :].reshape(-1, self.seq_length)} 214 | 215 | def __len__(self): 216 | return len(self.token) 217 | 218 | 219 | def read_token(vocab, token_path): 220 | with open(token_path, 'r', encoding='utf-8') as f: 221 | for line in f: 222 | line = line.split() + [u'<eos>'] 223 | for symbol in line: 224 | vocab.add(symbol) 225 | return 226 | 227 | 228 | def make_token(vocab, token_path): 229 | token = [] 230 | with open(token_path, 'r', encoding='utf-8') as f: 231 | for line in f: 232 | line = line.split() + [u'<eos>'] 233 | for symbol in line: 234 | token.append(vocab[symbol]) 235 | token = torch.tensor(token, dtype=torch.long) 236 | return token 237 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | 4 | import anytree 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | from utils import check_exists, makedir_exist_ok, save, load 11 | from .utils import download_url, extract_file, make_classes_counts, make_tree, make_flat_index 12 | 13 | 14 | class MNIST(Dataset): 15 | data_name = 'MNIST' 16 | file = [('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'), 17 | ('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'), 18 | ('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'), 19 | ('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c')] 20 | 21 | def __init__(self, root, split, subset, transform=None): 22 | self.root = os.path.expanduser(root) 23 | self.split = split 24 | self.subset = subset 25 | self.transform = transform 26 | if not check_exists(self.processed_folder): 27 | self.process() 28 | self.img, self.target = load(os.path.join(self.processed_folder, '{}.pt'.format(self.split))) 29 | self.target = self.target[self.subset] 30 | self.classes_counts = make_classes_counts(self.target) 31 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, 'meta.pt')) 32 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset] 33 | 34 | def __getitem__(self, index): 35 | img, target = Image.fromarray(self.img[index], mode='L'), torch.tensor(self.target[index]) 36 | input = {'img': img, self.subset: target} 37 | if self.transform is not None: 38 | input = self.transform(input) 39 | return input 40 | 41 | def __len__(self): 42 | return len(self.img) 43 | 44 | @property 45 | def processed_folder(self): 46 | return os.path.join(self.root, 'processed') 47 | 48 | @property 49 | def raw_folder(self): 50 | return os.path.join(self.root, 'raw') 51 | 52 | def process(self): 53 | if not check_exists(self.raw_folder): 54 | self.download() 55 | train_set, test_set, meta = self.make_data() 56 | save(train_set, os.path.join(self.processed_folder, 'train.pt')) 57 | save(test_set, os.path.join(self.processed_folder, 'test.pt')) 58 | save(meta, os.path.join(self.processed_folder, 'meta.pt')) 59 | return 60 | 61 | def download(self): 62 | makedir_exist_ok(self.raw_folder) 63 | for (url, md5) in self.file: 64 | filename = os.path.basename(url) 65 | download_url(url, self.raw_folder, filename, md5) 66 | extract_file(os.path.join(self.raw_folder, filename)) 67 | return 68 |
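    # NOTE (editor): illustrative usage sketch, not part of mnist.py; the root
    # path is an assumed placeholder. Samples are dicts, so transforms receive
    # and return {'img': ..., 'label': ...} rather than a bare image.
    #
    #   dataset = MNIST(root='./data/MNIST', split='train', subset='label')
    #   sample = dataset[0]          # {'img': 28x28 PIL image, 'label': tensor}
    #   print(dataset.classes_size)  # 10
69 | def __repr__(self):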
70 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nTransforms: {}'.format( 71 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.transform.__repr__()) 72 | return fmt_str 73 | 74 | def make_data(self): 75 | train_img = read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')) 76 | test_img = read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')) 77 | train_label = read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')) 78 | test_label = read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')) 79 | train_target, test_target = {'label': train_label}, {'label': test_label} 80 | classes_to_labels = {'label': anytree.Node('U', index=[])} 81 | classes = list(map(str, list(range(10)))) 82 | for c in classes: 83 | make_tree(classes_to_labels['label'], [c]) 84 | classes_size = {'label': make_flat_index(classes_to_labels['label'])} 85 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size) 86 | 87 | 88 | class EMNIST(MNIST): 89 | data_name = 'EMNIST' 90 | file = [('http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip', '58c8d27c78d21e728a6bc7b3cc06412e')] 91 | 92 | def __init__(self, root, split, subset, transform=None): 93 | super().__init__(root, split, subset, transform) 94 | self.img = self.img[self.subset] 95 | 96 | def make_data(self): 97 | gzip_folder = os.path.join(self.raw_folder, 'gzip') 98 | for gzip_file in os.listdir(gzip_folder): 99 | if gzip_file.endswith('.gz'): 100 | extract_file(os.path.join(gzip_folder, gzip_file)) 101 | subsets = ['byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'] 102 | train_img, test_img, train_target, test_target = {}, {}, {}, {} 103 | digits_classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 104 | upper_letters_classes = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 105 | 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] 106 | lower_letters_classes = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 107 | 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 108 | merged_classes = ['c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'] 109 | unmerged_classes = list(set(lower_letters_classes) - set(merged_classes)) 110 | classes = {'byclass': digits_classes + upper_letters_classes + lower_letters_classes, 111 | 'bymerge': digits_classes + upper_letters_classes + unmerged_classes, 112 | 'balanced': digits_classes + upper_letters_classes + unmerged_classes, 113 | 'letters': upper_letters_classes + unmerged_classes, 'digits': digits_classes, 114 | 'mnist': digits_classes} 115 | classes_to_labels = {s: anytree.Node('U', index=[]) for s in subsets} 116 | classes_size = {} 117 | for subset in subsets: 118 | train_img[subset] = read_image_file( 119 | os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(subset))) 120 | train_img[subset] = np.transpose(train_img[subset], [0, 2, 1]) 121 | test_img[subset] = read_image_file( 122 | os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(subset))) 123 | test_img[subset] = np.transpose(test_img[subset], [0, 2, 1]) 124 | train_target[subset] = read_label_file( 125 | os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(subset))) 126 | test_target[subset] = read_label_file( 127 | os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(subset))) 128 | for c in classes[subset]: 129 | 
make_tree(classes_to_labels[subset], [c]) 130 | classes_size[subset] = make_flat_index(classes_to_labels[subset]) 131 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size) 132 | 133 | 134 | class FashionMNIST(MNIST): 135 | data_name = 'FashionMNIST' 136 | file = [('http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz', 137 | '8d4fb7e6c68d591d4c3dfef9ec88bf0d'), 138 | ('http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz', 139 | 'bef4ecab320f06d8554ea6380940ec79'), 140 | ('http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz', 141 | '25c81989df183df01b3e8a0aad5dffbe'), 142 | ('http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz', 143 | 'bb300cfdad3c16e7a12a480ee83cd310')] 144 | 145 | def make_data(self): 146 | train_img = read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')) 147 | test_img = read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')) 148 | train_label = read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')) 149 | test_label = read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')) 150 | train_target = {'label': train_label} 151 | test_target = {'label': test_label} 152 | classes_to_labels = {'label': anytree.Node('U', index=[])} 153 | classes = ['T-shirt_top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 154 | 'Ankle boot'] 155 | for c in classes: 156 | make_tree(classes_to_labels['label'], [c]) 157 | classes_size = {'label': make_flat_index(classes_to_labels['label'])} 158 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size) 159 | 160 | 161 | def get_int(b): 162 | return int(codecs.encode(b, 'hex'), 16) 163 | 164 | 165 | def read_image_file(path): 166 | with open(path, 'rb') as f: 167 | data = f.read() 168 | assert get_int(data[:4]) == 2051 169 | length = get_int(data[4:8]) 170 | num_rows = get_int(data[8:12]) 171 | num_cols = get_int(data[12:16]) 172 | parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape((length, num_rows, num_cols)) 173 | return parsed 174 | 175 | 176 | def read_label_file(path): 177 | with open(path, 'rb') as f: 178 | data = f.read() 179 | assert get_int(data[:4]) == 2049 180 | length = get_int(data[4:8]) 181 | parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape(length).astype(np.int64) 182 | return parsed 183 | -------------------------------------------------------------------------------- /datasets/omniglot.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import anytree 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from utils import check_exists, makedir_exist_ok, save, load 9 | from .utils import IMG_EXTENSIONS 10 | from .utils import download_url, extract_file, make_classes_counts, make_data, make_tree, make_flat_index 11 | 12 | 13 | class Omniglot(Dataset): 14 | data_name = 'Omniglot' 15 | file = [ 16 | ('https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', 17 | '68d2efa1b9178cc56df9314c21c6e718'), 18 | ('https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip', 19 | '6b91aef0f799c5bb55b94e3f2daec811') 20 | ] 21 | 22 | def __init__(self, root, split, subset, transform=None): 23 | self.root = os.path.expanduser(root) 24 | self.split = split 25 | self.subset = subset
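        # NOTE (editor): descriptive comment added for clarity -- `subset`
        # selects which annotation dict to expose (only 'label' exists here),
        # mirroring the MNIST-family datasets above.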
26 | self.transform = transform 27 | if not check_exists(self.processed_folder): 28 | self.process() 29 | self.img, self.target = load(os.path.join(self.processed_folder, '{}.pt'.format(self.split))) 30 | self.target = self.target[self.subset] 31 | self.classes_counts = make_classes_counts(self.target) 32 | self.classes_to_labels, self.classes_size = load(os.path.join(self.processed_folder, 'meta.pt')) 33 | self.classes_to_labels, self.classes_size = self.classes_to_labels[self.subset], self.classes_size[self.subset] 34 | 35 | def __getitem__(self, index): 36 | img, target = Image.open(self.img[index], mode='r').convert('L'), torch.tensor(self.target[index]) 37 | input = {'img': img, self.subset: target} 38 | if self.transform is not None: 39 | input = self.transform(input) 40 | return input 41 | 42 | def __len__(self): 43 | return len(self.img) 44 | 45 | @property 46 | def processed_folder(self): 47 | return os.path.join(self.root, 'processed') 48 | 49 | @property 50 | def raw_folder(self): 51 | return os.path.join(self.root, 'raw') 52 | 53 | def process(self): 54 | if not check_exists(self.raw_folder): 55 | self.download() 56 | train_set, test_set, meta = self.make_data() 57 | save(train_set, os.path.join(self.processed_folder, 'train.pt')) 58 | save(test_set, os.path.join(self.processed_folder, 'test.pt')) 59 | save(meta, os.path.join(self.processed_folder, 'meta.pt')) 60 | return 61 | 62 | def download(self): 63 | makedir_exist_ok(self.raw_folder) 64 | for (url, md5) in self.file: 65 | filename = os.path.basename(url) 66 | download_url(url, self.raw_folder, filename, md5) 67 | extract_file(os.path.join(self.raw_folder, filename)) 68 | return 69 | 70 | def __repr__(self): 71 | fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nSubset: {}\nTransforms: {}'.format( 72 | self.__class__.__name__, self.__len__(), self.root, self.split, self.subset, self.transform.__repr__()) 73 | return fmt_str 74 | 75 | def make_data(self): 76 | img = make_data(self.raw_folder, IMG_EXTENSIONS) 77 | classes = set() 78 | train_img = [] 79 | test_img = [] 80 | train_label = [] 81 | test_label = [] 82 | for i in range(len(img)): 83 | img_i = img[i] 84 | class_i = '/'.join(os.path.normpath(img_i).split(os.path.sep)[-3:-1]) 85 | classes.add(class_i) 86 | idx_i = int(os.path.splitext(os.path.basename(img_i))[0].split('_')[1]) 87 | if idx_i <= 10: 88 | train_img.append(img_i) 89 | else: 90 | test_img.append(img_i) 91 | classes = sorted(list(classes)) 92 | classes_to_labels = {'label': anytree.Node('U', index=[])} 93 | for c in classes: 94 | make_tree(classes_to_labels['label'], c.split('/')) 95 | classes_size = {'label': make_flat_index(classes_to_labels['label'])} 96 | r = anytree.resolver.Resolver() 97 | for i in range(len(train_img)): 98 | train_img_i = train_img[i] 99 | train_class_i = '/'.join(os.path.normpath(train_img_i).split(os.path.sep)[-3:-1]) 100 | node = r.get(classes_to_labels['label'], train_class_i) 101 | train_label.append(node.flat_index) 102 | for i in range(len(test_img)): 103 | test_img_i = test_img[i] 104 | test_class_i = '/'.join(os.path.normpath(test_img_i).split(os.path.sep)[-3:-1]) 105 | node = r.get(classes_to_labels['label'], test_class_i) 106 | test_label.append(node.flat_index) 107 | train_target = {'label': train_label} 108 | test_target = {'label': test_label} 109 | return (train_img, train_target), (test_img, test_target), (classes_to_labels, classes_size) 110 | -------------------------------------------------------------------------------- /datasets/transforms.py: 
-------------------------------------------------------------------------------- 1 | class CustomTransform(object): 2 | def __call__(self, input): 3 | return input['img'] 4 | 5 | def __repr__(self): 6 | return self.__class__.__name__ 7 | 8 | 9 | class BoundingBoxCrop(CustomTransform): 10 | def __init__(self): 11 | pass 12 | 13 | def __call__(self, input): 14 | x, y, width, height = input['bbox'].long().tolist() 15 | left, top, right, bottom = x, y, x + width, y + height 16 | bboxc_img = input['img'].crop((left, top, right, bottom)) 17 | return bboxc_img 18 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import gzip 3 | import hashlib 4 | import os 5 | import tarfile 6 | import zipfile 7 | from collections import Counter 8 | 9 | import anytree 10 | import numpy as np 11 | from PIL import Image 12 | from tqdm import tqdm 13 | 14 | from utils import makedir_exist_ok 15 | from .transforms import * 16 | 17 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 18 | 19 | 20 | def find_classes(dir): 21 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 22 | classes.sort() 23 | classes_to_labels = {classes[i]: i for i in range(len(classes))} 24 | return classes_to_labels 25 | 26 | 27 | def pil_loader(path): 28 | with open(path, 'rb') as f: 29 | img = Image.open(f) 30 | return img.convert('RGB') 31 | 32 | 33 | def accimage_loader(path): 34 | import accimage 35 | try: 36 | return accimage.Image(path) 37 | except IOError: 38 | return pil_loader(path) 39 | 40 | 41 | def default_loader(path): 42 | from torchvision import get_image_backend 43 | if get_image_backend() == 'accimage': 44 | return accimage_loader(path) 45 | else: 46 | return pil_loader(path) 47 | 48 | 49 | def has_file_allowed_extension(filename, extensions): 50 | filename_lower = filename.lower() 51 | return any(filename_lower.endswith(ext) for ext in extensions) 52 | 53 | 54 | def make_classes_counts(label): 55 | label = np.array(label) 56 | if label.ndim > 1: 57 | label = label.sum(axis=tuple([i for i in range(1, label.ndim)])) 58 | classes_counts = Counter(label) 59 | return classes_counts 60 | 61 | 62 | def make_bar_updater(pbar): 63 | def bar_update(count, block_size, total_size): 64 | if pbar.total is None and total_size: 65 | pbar.total = total_size 66 | progress_bytes = count * block_size 67 | pbar.update(progress_bytes - pbar.n) 68 | 69 | return bar_update 70 | 71 | 72 | def calculate_md5(path, chunk_size=1024 * 1024): 73 | md5 = hashlib.md5() 74 | with open(path, 'rb') as f: 75 | for chunk in iter(lambda: f.read(chunk_size), b''): 76 | md5.update(chunk) 77 | return md5.hexdigest() 78 | 79 | 80 | def check_md5(path, md5, **kwargs): 81 | return md5 == calculate_md5(path, **kwargs) 82 | 83 | 84 | def check_integrity(path, md5=None): 85 | if not os.path.isfile(path): 86 | return False 87 | if md5 is None: 88 | return True 89 | return check_md5(path, md5) 90 | 91 | 92 | def download_url(url, root, filename, md5): 93 | from six.moves import urllib 94 | path = os.path.join(root, filename) 95 | makedir_exist_ok(root) 96 | if os.path.isfile(path) and check_integrity(path, md5): 97 | print('Using downloaded and verified file: ' + path) 98 | else: 99 | try: 100 | print('Downloading ' + url + ' to ' + path) 101 | urllib.request.urlretrieve(url, path, reporthook=make_bar_updater(tqdm(unit='B', unit_scale=True))) 102 | except OSError: 103 | if url[:5] == 
'https': 104 | url = url.replace('https:', 'http:') 105 | print('Failed download. Trying https -> http instead.' 106 | ' Downloading ' + url + ' to ' + path) 107 | urllib.request.urlretrieve(url, path, reporthook=make_bar_updater(tqdm(unit='B', unit_scale=True))) 108 | if not check_integrity(path, md5): 109 | raise RuntimeError('Not valid downloaded file') 110 | return 111 | 112 | 113 | def extract_file(src, dest=None, delete=False): 114 | print('Extracting {}'.format(src)) 115 | dest = os.path.dirname(src) if dest is None else dest 116 | filename = os.path.basename(src) 117 | if filename.endswith('.zip'): 118 | with zipfile.ZipFile(src, "r") as zip_f: 119 | zip_f.extractall(dest) 120 | elif filename.endswith('.tar'): 121 | with tarfile.open(src) as tar_f: 122 | tar_f.extractall(dest) 123 | elif filename.endswith('.tar.gz') or filename.endswith('.tgz'): 124 | with tarfile.open(src, 'r:gz') as tar_f: 125 | tar_f.extractall(dest) 126 | elif filename.endswith('.gz'): 127 | with open(src.replace('.gz', ''), 'wb') as out_f, gzip.GzipFile(src) as zip_f: 128 | out_f.write(zip_f.read()) 129 | if delete: 130 | os.remove(src) 131 | return 132 | 133 | 134 | def make_data(root, extensions): 135 | path = [] 136 | files = glob.glob('{}/**/*'.format(root), recursive=True) 137 | for file in files: 138 | if has_file_allowed_extension(file, extensions): 139 | path.append(os.path.normpath(file)) 140 | return path 141 | 142 | 143 | def make_img(path, classes_to_labels, extensions=IMG_EXTENSIONS): 144 | img, label = [], [] 145 | classes = [] 146 | leaf_nodes = classes_to_labels.leaves 147 | for node in leaf_nodes: 148 | classes.append(node.name) 149 | for c in sorted(classes): 150 | d = os.path.join(path, c) 151 | if not os.path.isdir(d): 152 | continue 153 | for root, _, filenames in sorted(os.walk(d)): 154 | for filename in sorted(filenames): 155 | if has_file_allowed_extension(filename, extensions): 156 | cur_path = os.path.join(root, filename) 157 | img.append(cur_path) 158 | label.append(anytree.find_by_attr(classes_to_labels, c).flat_index) 159 | return img, label 160 | 161 | 162 | def make_tree(root, name, attribute=None): 163 | if len(name) == 0: 164 | return 165 | if attribute is None: 166 | attribute = {} 167 | this_name = name[0] 168 | next_name = name[1:] 169 | this_attribute = {k: attribute[k][0] for k in attribute} 170 | next_attribute = {k: attribute[k][1:] for k in attribute} 171 | this_node = anytree.find_by_attr(root, this_name) 172 | this_index = root.index + [len(root.children)] 173 | if this_node is None: 174 | this_node = anytree.Node(this_name, parent=root, index=this_index, **this_attribute) 175 | make_tree(this_node, next_name, next_attribute) 176 | return 177 | 178 | 179 | def make_flat_index(root, given=None): 180 | if given: 181 | classes_size = 0 182 | for node in anytree.PreOrderIter(root): 183 | if len(node.children) == 0: 184 | node.flat_index = given.index(node.name) 185 | classes_size = given.index(node.name) + 1 if given.index(node.name) + 1 > classes_size else classes_size 186 | else: 187 | classes_size = 0 188 | for node in anytree.PreOrderIter(root): 189 | if len(node.children) == 0: 190 | node.flat_index = classes_size 191 | classes_size += 1 192 | return classes_size 193 | 194 | 195 | class Compose(object): 196 | def __init__(self, transforms): 197 | self.transforms = transforms 198 | 199 | def __call__(self, input): 200 | for t in self.transforms: 201 | if isinstance(t, CustomTransform): 202 | input['img'] = t(input) 203 | else: 204 | input['img'] = t(input['img']) 205 
| return input 206 | 207 | def __repr__(self): 208 | format_string = self.__class__.__name__ + '(' 209 | for t in self.transforms: 210 | format_string += '\n' 211 | format_string += ' {0}'.format(t) 212 | format_string += '\n)' 213 | return format_string 214 | -------------------------------------------------------------------------------- /figures/fedrolex_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/FedRolex/813510997e1802eb756d53baa4229c90ce5ac008/figures/fedrolex_overview.png -------------------------------------------------------------------------------- /figures/table_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/FedRolex/813510997e1802eb756d53baa4229c90ce5ac008/figures/table_overview.png -------------------------------------------------------------------------------- /figures/video_placeholder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/FedRolex/813510997e1802eb756d53baa4229c90ce5ac008/figures/video_placeholder.png -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from collections.abc import Iterable 3 | from numbers import Number 4 | 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | from utils import ntuple 8 | 9 | 10 | class Logger(): 11 | def __init__(self, log_path): 12 | self.log_path = log_path 13 | self.writer = None 14 | self.tracker = defaultdict(int) 15 | self.counter = defaultdict(int) 16 | self.mean = defaultdict(int) 17 | self.history = defaultdict(list) 18 | self.iterator = defaultdict(int) 19 | self.hist = defaultdict(list) 20 | 21 | def safe(self, write): 22 | if write: 23 | self.writer = SummaryWriter(self.log_path) 24 | else: 25 | if self.writer is not None: 26 | self.writer.close() 27 | self.writer = None 28 | for name in self.mean: 29 | self.history[name].append(self.mean[name]) 30 | return 31 | 32 | def reset(self): 33 | self.tracker = defaultdict(int) 34 | self.counter = defaultdict(int) 35 | self.mean = defaultdict(int) 36 | self.hist = defaultdict(list) 37 | return 38 | 39 | def append(self, result, tag, n=1, mean=True): 40 | for k in result: 41 | name = '{}/{}'.format(tag, k) 42 | self.tracker[name] = result[k] 43 | if mean: 44 | if isinstance(result[k], Number): 45 | self.counter[name] += n 46 | if 'local' in name.lower(): 47 | self.hist[name].append(result[k]) 48 | self.mean[name] = ((self.counter[name] - n) * self.mean[name] + n * result[k]) / self.counter[name] 49 | elif isinstance(result[k], Iterable): 50 | if name not in self.mean: 51 | self.counter[name] = [0 for _ in range(len(result[k]))] 52 | self.mean[name] = [0 for _ in range(len(result[k]))] 53 | _ntuple = ntuple(len(result[k])) 54 | n = _ntuple(n) 55 | for i in range(len(result[k])): 56 | self.counter[name][i] += n[i] 57 | if 'local' in name.lower(): 58 | self.hist[name].append(result[k][i]) 59 | self.mean[name][i] = ((self.counter[name][i] - n[i]) * self.mean[name][i] + n[i] * 60 | result[k][i]) / self.counter[name][i] 61 | else: 62 | raise ValueError('Not valid data type') 63 | return 64 | 65 | def write(self, tag, metric_names): 66 | names = ['{}/{}'.format(tag, k) for k in metric_names] 67 | evaluation_info = []
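# NOTE (editor): descriptive comment added for clarity -- each metric below is
# written as one TensorBoard scalar; list-valued metrics log their first
# element, and 'Local' metrics also get a histogram of the tracked values.
68 | for name in names: 69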
| tag, k = name.split('/') 70 | if isinstance(self.mean[name], Number): 71 | s = self.mean[name] 72 | evaluation_info.append('{}: {:.4f}'.format(k, s)) 73 | if self.writer is not None: 74 | self.iterator[name] += 1 75 | self.writer.add_scalar(name, s, self.iterator[name]) 76 | elif isinstance(self.mean[name], Iterable): 77 | s = tuple(self.mean[name]) 78 | evaluation_info.append('{}: {}'.format(k, s)) 79 | if self.writer is not None: 80 | self.iterator[name] += 1 81 | self.writer.add_scalar(name, s[0], self.iterator[name]) 82 | if 'local' in name.lower(): 83 | self.writer.add_histogram(f'{name}_hist', self.hist[name], self.iterator[name]) 84 | else: 85 | raise ValueError('Not valid data type') 86 | info_name = '{}/info'.format(tag) 87 | info = self.tracker[info_name] 88 | info[2:2] = evaluation_info 89 | info = ' '.join(info) 90 | print(info) 91 | if self.writer is not None: 92 | self.iterator[info_name] += 1 93 | self.writer.add_text(info_name, info, self.iterator[info_name]) 94 | return 95 | 96 | def flush(self): 97 | self.writer.flush() 98 | return 99 | -------------------------------------------------------------------------------- /main_resnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import datetime 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | 9 | import numpy as np 10 | import ray 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | 14 | from config import cfg 15 | from data import fetch_dataset, make_data_loader, split_dataset 16 | from logger import Logger 17 | from metrics import Metric 18 | from models import resnet 19 | from resnet_client import ResnetClient 20 | from utils import save, to_device, process_control, process_dataset, make_optimizer, make_scheduler, collate 21 | 22 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 23 | 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | 27 | 28 | 29 | 30 | parser = argparse.ArgumentParser(description='cfg') 31 | for k in cfg: 32 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k)) 33 | parser.add_argument('--control_name', default=None, type=str) 34 | parser.add_argument('--seed', default=None, type=int) 35 | parser.add_argument('--devices', default=None, nargs='+', type=int) 36 | parser.add_argument('--algo', default='roll', type=str) 37 | parser.add_argument('--weighting', default='avg', type=str) 38 | # parser.add_argument('--lr', default=None, type=int) 39 | parser.add_argument('--g_epochs', default=None, type=int) 40 | parser.add_argument('--l_epochs', default=None, type=int) 41 | parser.add_argument('--overlap', default=None, type=float) 42 | 43 | parser.add_argument('--schedule', default=None, nargs='+', type=int) 44 | # parser.add_argument('--exp_name', default=None, type=str) 45 | args = vars(parser.parse_args()) 46 | 47 | cfg['overlap'] = args['overlap'] 48 | cfg['weighting'] = args['weighting'] 49 | cfg['init_seed'] = int(args['seed']) 50 | if args['algo'] == 'roll': 51 | from resnet_server import ResnetServerRoll as Server 52 | elif args['algo'] == 'random': 53 | from resnet_server import ResnetServerRandom as Server 54 | elif args['algo'] == 'static': 55 | from resnet_server import ResnetServerStatic as Server
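# NOTE (editor): illustrative launch command, not part of main_resnet.py; the
# values are placeholders, not recommended settings. `--algo` selects which
# server class is imported above, and `--seed` is effectively required since
# it is cast with int() before use.
#
#   python main_resnet.py --algo roll --weighting avg --seed 31 \
#       --devices 0 --g_epochs 800 --l_epochs 1 --overlap 0.5
56 | if args['devices'] is not None: 57 | os.environ['CUDA_VISIBLE_DEVICES'] =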
','.join([str(i) for i in args['devices']]) 58 | for k in cfg: 59 | cfg[k] = args[k] 60 | if args['control_name']: 61 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \ 62 | if args['control_name'] != 'None' else {} 63 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']]) 64 | cfg['pivot_metric'] = 'Global-Accuracy' 65 | cfg['pivot'] = -float('inf') 66 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Accuracy']}, 67 | 'test': {'Local': ['Local-Loss', 'Local-Accuracy'], 'Global': ['Global-Loss', 'Global-Accuracy']}} 68 | 69 | ray.init() 70 | rates = None 71 | 72 | 73 | def main(): 74 | process_control() 75 | 76 | if args['schedule'] is not None: 77 | cfg['milestones'] = args['schedule'] 78 | 79 | if args['g_epochs'] is not None and args['l_epochs'] is not None: 80 | cfg['num_epochs'] = {'global': args['g_epochs'], 'local': args['l_epochs']} 81 | cfg['init_seed'] = int(args['seed']) 82 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments'])) 83 | for i in range(cfg['num_experiments']): 84 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']] 85 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x]) 86 | print('Experiment: {}'.format(cfg['model_tag'])) 87 | print('Seed: {}'.format(cfg['init_seed'])) 88 | run_experiment() 89 | return 90 | 91 | 92 | def run_experiment(): 93 | seed = int(cfg['model_tag'].split('_')[0]) 94 | torch.manual_seed(seed) 95 | torch.cuda.manual_seed(seed) 96 | random.seed(seed) 97 | np.random.seed(seed) 98 | torch.cuda.manual_seed_all(seed) 99 | torch.set_deterministic_debug_mode('default') 100 | os.environ['PYTHONHASHSEED'] = str(seed) 101 | dataset = fetch_dataset(cfg['data_name'], cfg['subset']) 102 | process_dataset(dataset) 103 | global_model = resnet.resnet18(model_rate=cfg["global_model_rate"], cfg=cfg).to(cfg['device']) 104 | optimizer = make_optimizer(global_model, cfg['lr']) 105 | scheduler = make_scheduler(optimizer) 106 | last_epoch = 1 107 | data_split, label_split = split_dataset(dataset, cfg['num_users'], cfg['data_split_mode']) 108 | logger_path = os.path.join('output', 'runs', 'train_{}'.format(f'{cfg["model_tag"]}_{cfg["exp_name"]}')) 109 | logger = Logger(logger_path) 110 | 111 | num_active_users = int(np.ceil(cfg['frac'] * cfg['num_users'])) 112 | cfg['active_user'] = num_active_users 113 | cfg_id = ray.put(cfg) 114 | dataset_ref = { 115 | 'dataset': ray.put(dataset['train']), 116 | 'split': ray.put(data_split['train']), 117 | 'label_split': ray.put(label_split)} 118 | 119 | server = Server(global_model, cfg['model_rate'], dataset_ref, cfg_id) 120 | local = [ResnetClient.remote(logger.log_path, [cfg_id]) for _ in range(num_active_users)] 121 | rates = server.model_rate 122 | for epoch in range(last_epoch, cfg['num_epochs']['global'] + 1): 123 | t0 = time.time() 124 | logger.safe(True) 125 | scheduler.step() 126 | lr = optimizer.param_groups[0]['lr'] 127 | local, param_idx, user_idx = server.broadcast(local, lr) 128 | t1 = time.time() 129 | 130 | num_active_users = len(local) 131 | start_time = time.time() 132 | dt = ray.get([client.step.remote(m, num_active_users, start_time) 133 | for m, client in enumerate(local)]) 134 | 135 | local_parameters = [v for _k, v in enumerate(dt)] 136 | 137 | # for i, p in enumerate(local_parameters): 138 | # with open(f'local_param_pulled_{i}.pickle', 'w') as f: 139 | # pickle.dump(p, f) 140 | # local_parameters = [{k: torch.tensor(v, 
device=cfg['device']) for k, v in p.items()} for p in local_parameters] 141 | # for lp in local_parameters: 142 | # for k, p in lp.items(): 143 | # print(k, torch.var_mean(p, unbiased=False)) 144 | 145 | # local_parameters = [None for _ in range(num_active_users)] 146 | # for m in range(num_active_users): 147 | # local[m].step(m, num_active_users, start_time) 148 | # local_parameters[m] = local[m].pull() 149 | t2 = time.time() 150 | server.step(local_parameters, param_idx, user_idx) 151 | t3 = time.time() 152 | 153 | global_model = server.global_model 154 | test_model = global_model 155 | t4 = time.time() 156 | 157 | test(dataset['test'], data_split['test'], label_split, test_model, logger, epoch, local) 158 | t5 = time.time() 159 | logger.safe(False) 160 | model_state_dict = global_model.state_dict() 161 | save_result = { 162 | 'cfg': cfg, 'epoch': epoch + 1, 'data_split': data_split, 'label_split': label_split, 163 | 'model_dict': model_state_dict, 'optimizer_dict': optimizer.state_dict(), 164 | 'scheduler_dict': scheduler.state_dict(), 'logger': logger} 165 | save(save_result, './output/model/{}_checkpoint.pt'.format(cfg['model_tag'])) 166 | if cfg['pivot'] < logger.mean['test/{}'.format(cfg['pivot_metric'])]: 167 | cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])] 168 | shutil.copy('./output/model/{}_checkpoint.pt'.format(cfg['model_tag']), 169 | './output/model/{}_best.pt'.format(cfg['model_tag'])) 170 | logger.reset() 171 | t6 = time.time() 172 | print(f'Broadcast Time : {datetime.timedelta(seconds=t1 - t0)}') 173 | print(f'Client Step Time : {datetime.timedelta(seconds=t2 - t1)}') 174 | print(f'Server Step Time : {datetime.timedelta(seconds=t3 - t2)}') 175 | print(f'Stats Time : {datetime.timedelta(seconds=t4 - t3)}') 176 | print(f'Test Time : {datetime.timedelta(seconds=t5 - t4)}') 177 | print(f'Output Copy Time : {datetime.timedelta(seconds=t6 - t5)}') 178 | print(f'<>: {datetime.timedelta(seconds=t6 - t0)}') 179 | logger.safe(False) 180 | [ray.kill(client) for client in local] 181 | return 182 | 183 | 184 | def test(dataset, data_split, label_split, model, logger, epoch, local): 185 | with torch.no_grad(): 186 | model.train(False) 187 | dataset_id = ray.put(dataset) 188 | data_split_id = ray.put(data_split) 189 | model_id = ray.put(copy.deepcopy(model)) 190 | label_split_id = ray.put(label_split) 191 | all_res = [] 192 | for m in range(0, cfg['num_users'], len(local)): 193 | processes = [] 194 | for k in range(m, min(m + len(local), cfg['num_users'])): 195 | processes.append(local[k % len(local)] 196 | .test_model_for_user.remote(k, 197 | [dataset_id, data_split_id, model_id, label_split_id])) 198 | results = ray.get(processes) 199 | for result in results: 200 | all_res.append(result) 201 | for r in result: 202 | evaluation, input_size = r 203 | logger.append(evaluation, 'test', input_size) 204 | # Save all_res for plotting 205 | # torch.save((all_res, rates), f'./output/runs/{cfg["model_tag"]}_real_world.pt') 206 | data_loader = make_data_loader({'test': dataset})['test'] 207 | metric = Metric() 208 | model.cuda() 209 | for i, data_input in enumerate(data_loader): 210 | data_input = collate(data_input) 211 | input_size = data_input['img'].size(0) 212 | data_input = to_device(data_input, 'cuda') 213 | output = model(data_input) 214 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss'] 215 | evaluation = metric.evaluate(cfg['metric_name']['test']['Global'], data_input, output) 216 | logger.append(evaluation, 'test', input_size) 217 | 
info = {'info': ['Model: {}'.format(cfg['model_tag']), 218 | 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]} 219 | logger.append(info, 'test', mean=False) 220 | logger.write('test', cfg['metric_name']['test']['Local'] + cfg['metric_name']['test']['Global']) 221 | return 222 | 223 | 224 | if __name__ == "__main__": 225 | main() 226 | -------------------------------------------------------------------------------- /main_transformer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | 9 | import numpy as np 10 | import ray 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | from tqdm import tqdm 14 | 15 | import models 16 | from config import cfg 17 | from data import fetch_dataset 18 | from logger import Logger 19 | from transformer_client import TransformerClient 20 | from utils import save, process_control, process_dataset, make_optimizer, make_scheduler 21 | 22 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 23 | 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | 27 | 28 | 29 | 30 | parser = argparse.ArgumentParser(description='cfg') 31 | for k in cfg: 32 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k)) 33 | parser.add_argument('--control_name', default=None, type=str) 34 | parser.add_argument('--seed', default=None, type=int) 35 | parser.add_argument('--devices', default=None, nargs='+', type=int) 36 | parser.add_argument('--algo', default='roll', type=str) 37 | # parser.add_argument('--lr', default=None, type=int) 38 | parser.add_argument('--g_epochs', default=None, type=int) 39 | parser.add_argument('--l_epochs', default=None, type=int) 40 | parser.add_argument('--schedule', default=None, nargs='+', type=int) 41 | # parser.add_argument('--exp_name', default=None, type=str) 42 | args = vars(parser.parse_args()) 43 | cfg['init_seed'] = int(args['seed']) 44 | if args['algo'] == 'roll': 45 | from transformer_server import TransformerServerRollSO as Server 46 | elif args['algo'] == 'random': 47 | from transformer_server import TransformerServerRandomSO as Server 48 | elif args['algo'] == 'static': 49 | from transformer_server import TransformerServerStaticSO as Server 50 | 51 | 52 | 53 | if args['devices'] is not None: 54 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in args['devices']]) 55 | 56 | for k in cfg: 57 | cfg[k] = args[k] 58 | if args['control_name']: 59 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \ 60 | if args['control_name'] != 'None' else {} 61 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']]) 62 | cfg['pivot_metric'] = 'Global-Perplexity' 63 | cfg['pivot'] = float('inf') 64 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Perplexity']}, 65 | 'test': {'Global': ['Global-Loss', 'Global-Accuracy'], 66 | 'Local': ['Local-Loss', 'Local-Accuracy']}} 67 | # ray.init(_temp_dir='/egr/research-zhanglambda/samiul/tmp') 68 | # ray.init(_temp_dir='/localscratch/alamsami/tmp', object_store_memory=10**11) 69 | 70 | ray.init( 71 | _temp_dir='/localscratch/alamsami/tmp', object_store_memory=10 ** 11, 72 | _system_config={
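# NOTE (editor): the `_temp_dir` and spill paths here are machine-specific and
# must be adapted for any other cluster; `object_spilling_config` below tells
# Ray to spill objects that overflow the ~100 GB object store to local disk.
73 |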
"object_spilling_config": json.dumps( 74 | { 75 | "type": "filesystem", 76 | "params": { 77 | "directory_path": '/egr/research-zhanglambda/samiul/tmp', 78 | } 79 | }, 80 | ) 81 | }, 82 | ) 83 | 84 | 85 | def main(): 86 | process_control() 87 | if args['schedule'] is not None: 88 | cfg['milestones'] = args['schedule'] 89 | 90 | if args['g_epochs'] is not None and args['l_epochs'] is not None: 91 | cfg['num_epochs'] = {'global': args['g_epochs'], 'local': args['l_epochs']} 92 | cfg['init_seed'] = int(args['seed']) 93 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments'])) 94 | for i in range(cfg['num_experiments']): 95 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']] 96 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x]) 97 | print('Experiment: {}'.format(cfg['model_tag'])) 98 | print('Seed: {}'.format(cfg['init_seed'])) 99 | run_experiment() 100 | return 101 | 102 | 103 | def run_experiment(): 104 | seed = int(cfg['model_tag'].split('_')[0]) 105 | torch.manual_seed(seed) 106 | torch.cuda.manual_seed(seed) 107 | random.seed(seed) 108 | np.random.seed(seed) 109 | torch.cuda.manual_seed_all(seed) 110 | torch.set_deterministic_debug_mode('default') 111 | os.environ['PYTHONHASHSEED'] = str(seed) 112 | dataset = fetch_dataset(cfg['data_name'], cfg['subset']) 113 | process_dataset(dataset) 114 | global_model = models.transformer_nwp(model_rate=cfg["global_model_rate"], cfg=cfg) 115 | optimizer = make_optimizer(global_model, cfg['lr']) 116 | scheduler = make_scheduler(optimizer) 117 | last_epoch = 1 118 | data_split, label_split = dataset['train'], dataset['train'] 119 | num_active_users = cfg['active_user'] 120 | 121 | logger_path = os.path.join('output', 'runs', 'train_{}'.format(f'{cfg["model_tag"]}_{cfg["exp_name"]}')) 122 | logger = Logger(logger_path) 123 | 124 | cfg_id = ray.put(cfg) 125 | dataset_ref = dataset['train'] 126 | 127 | server = Server(global_model, cfg['model_rate'], dataset_ref, cfg_id) 128 | num_users_per_step = 8 129 | local = [TransformerClient.remote(logger.log_path, [cfg_id]) for _ in range(num_users_per_step)] 130 | # local = [TransformerClient(logger.log_path, [cfg_id]) for _ in range(num_active_users)] 131 | 132 | for epoch in range(last_epoch, cfg['num_epochs']['global'] + 1): 133 | t0 = time.time() 134 | logger.safe(True) 135 | scheduler.step() 136 | lr = optimizer.param_groups[0]['lr'] 137 | local, configs = server.broadcast(local, lr) 138 | t1 = time.time() 139 | 140 | start_time = time.time() 141 | local_parameters = [] 142 | for user_start_idx in range(0, num_active_users, num_users_per_step): 143 | idxs = list(range(user_start_idx, min(num_active_users, user_start_idx + num_users_per_step))) 144 | sel_cfg = [configs[idx] for idx in idxs] 145 | [client.update.remote(*config) for client, config in zip(local, sel_cfg)] 146 | dt = ray.get([client.step.remote(user_start_idx + m, num_active_users, start_time) 147 | for m, client in enumerate(local[:len(sel_cfg)])]) 148 | local_parameters += [v for _k, v in enumerate(dt)] 149 | torch.cuda.empty_cache() 150 | t2 = time.time() 151 | server.step(local_parameters) 152 | t3 = time.time() 153 | 154 | global_model = server.global_model 155 | test_model = global_model 156 | t4 = time.time() 157 | if True or epoch % 20 == 1: 158 | test(dataset['test'], test_model, logger, epoch, local) 159 | t5 = time.time() 160 | logger.safe(False) 161 | model_state_dict = global_model.state_dict() 162 | if epoch % 20 == 1: 163 | save_result = { 
164 | 'cfg': cfg, 'epoch': epoch + 1, 'data_split': data_split, 'label_split': label_split, 165 | 'model_dict': model_state_dict, 'optimizer_dict': optimizer.state_dict(), 166 | 'scheduler_dict': scheduler.state_dict(), 'logger': logger} 167 | save(save_result, './output/model/{}_checkpoint.pt'.format(cfg['model_tag'])) 168 | if cfg['pivot'] > logger.mean['test/{}'.format(cfg['pivot_metric'])]: 169 | cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])] 170 | shutil.copy('./output/model/{}_checkpoint.pt'.format(cfg['model_tag']), 171 | './output/model/{}_best.pt'.format(cfg['model_tag'])) 172 | logger.reset() 173 | t6 = time.time() 174 | print(f'Broadcast Time : {datetime.timedelta(seconds=t1 - t0)}') 175 | print(f'Client Step Time : {datetime.timedelta(seconds=t2 - t1)}') 176 | print(f'Server Step Time : {datetime.timedelta(seconds=t3 - t2)}') 177 | print(f'Stats Time : {datetime.timedelta(seconds=t4 - t3)}') 178 | print(f'Test Time : {datetime.timedelta(seconds=t5 - t4)}') 179 | print(f'Output Copy Time : {datetime.timedelta(seconds=t6 - t5)}') 180 | print(f'<>: {datetime.timedelta(seconds=t6 - t0)}') 181 | test_model = None 182 | global_model = None 183 | model_state_dict = None 184 | torch.cuda.empty_cache() 185 | logger.safe(False) 186 | [ray.kill(client) for client in local] 187 | return 188 | 189 | 190 | def test(dataset, model, logger, epoch, local): 191 | num_users_per_step = len(local) 192 | num_test_users = 200 # len(dataset) 193 | if epoch % 600 == 0: 194 | num_test_users = 5000 195 | model_id = ray.put(model) 196 | with torch.no_grad(): 197 | model.train(False) 198 | sel_cl = np.random.choice(len(dataset), num_test_users) 199 | for user_start_idx in tqdm(range(0, num_test_users, num_users_per_step)): 200 | processes = [] 201 | for user_idx in range(user_start_idx, min(user_start_idx + num_users_per_step, num_test_users)): 202 | processes.append(local[user_idx % num_users_per_step] 203 | .test_model_for_user 204 | .remote(user_idx, 205 | [ray.put(dataset[sel_cl[user_idx]]), model_id])) 206 | results = ray.get(processes) 207 | for result in results: 208 | if result: 209 | evaluation, input_size = result[0] 210 | logger.append(evaluation, 'test', input_size) 211 | 212 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 213 | 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]} 214 | logger.append(info, 'test', mean=False) 215 | logger.write('test', cfg['metric_name']['test']['Local']) 216 | return evaluation 217 | 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import * 2 | -------------------------------------------------------------------------------- /metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from utils import recur 5 | 6 | 7 | def Accuracy(output, target, topk=1): 8 | with torch.no_grad(): 9 | batch_size = target.size(0) 10 | pred_k = output.topk(topk, 1, True, True)[1] 11 | correct_k = pred_k.eq(target.view(-1, 1).expand_as(pred_k)).float().sum() 12 | acc = (correct_k * (100.0 / batch_size)).item() 13 | return acc 14 | 15 |
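# NOTE (editor): worked example with illustrative tensors, not from the repo.
# For output = [[0.1, 0.9], [0.8, 0.2]] and target = [1, 1], the top-1
# predictions are [1, 0], so correct_k = 1.0 and
# acc = 1.0 * (100.0 / 2) = 50.0 (percent).
16 | def Perplexity(output, target): 17 | with torch.no_grad(): 18 | # label_mask = torch.arange(output.size(1), device=output.device)[output.sum(dim=[0,2]) != 0] 19 | # label_map = output.new_zeros(output.size(1),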
device=output.device, dtype=torch.long) 20 | # output = output[:, label_mask,] 21 | # label_map[label_mask] = torch.arange(output.size(1), device=output.device) 22 | # target = label_map[target] 23 | ce = F.cross_entropy(output, target) 24 | perplexity = torch.exp(ce).item() 25 | return perplexity 26 | 27 | 28 | class Metric(object): 29 | def __init__(self): 30 | self.metric = {'Loss': (lambda input, output: output['loss'].item()), 31 | 'Local-Loss': (lambda input, output: output['loss'].item()), 32 | 'Global-Loss': (lambda input, output: output['loss'].item()), 33 | 'Accuracy': (lambda input, output: recur(Accuracy, output['score'], input['label'])), 34 | 'Local-Accuracy': (lambda input, output: recur(Accuracy, output['score'], input['label'])), 35 | 'Global-Accuracy': (lambda input, output: recur(Accuracy, output['score'], input['label'])), 36 | 'Perplexity': (lambda input, output: recur(Perplexity, output['score'], input['label'])), 37 | 'Local-Perplexity': (lambda input, output: recur(Perplexity, output['score'], input['label'])), 38 | 'Global-Perplexity': (lambda input, output: recur(Perplexity, output['score'], input['label']))} 39 | 40 | def evaluate(self, metric_names, input, output): 41 | evaluation = {} 42 | for metric_name in metric_names: 43 | evaluation[metric_name] = self.metric[metric_name](input, output) 44 | return evaluation 45 | -------------------------------------------------------------------------------- /models/conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from config import cfg 7 | from modules import Scaler 8 | from .utils import init_param 9 | 10 | 11 | class Conv(nn.Module): 12 | def __init__(self, data_shape, hidden_size, classes_size, rate=1, track=False): 13 | super().__init__() 14 | if cfg['norm'] == 'bn': 15 | norm = nn.BatchNorm2d(hidden_size[0], momentum=None, track_running_stats=track) 16 | elif cfg['norm'] == 'in': 17 | norm = nn.GroupNorm(hidden_size[0], hidden_size[0]) 18 | elif cfg['norm'] == 'ln': 19 | norm = nn.GroupNorm(1, hidden_size[0]) 20 | elif cfg['norm'] == 'gn': 21 | norm = nn.GroupNorm(4, hidden_size[0]) 22 | elif cfg['norm'] == 'none': 23 | norm = nn.Identity() 24 | else: 25 | raise ValueError('Not valid norm') 26 | if cfg['scale']: 27 | scaler = Scaler(rate) 28 | else: 29 | scaler = nn.Identity() 30 | blocks = [nn.Conv2d(data_shape[0], hidden_size[0], 3, 1, 1), 31 | scaler, 32 | norm, 33 | nn.ReLU(inplace=True), 34 | nn.MaxPool2d(2)] 35 | for i in range(len(hidden_size) - 1): 36 | if cfg['norm'] == 'bn': 37 | norm = nn.BatchNorm2d(hidden_size[i + 1], momentum=None, track_running_stats=track) 38 | elif cfg['norm'] == 'in': 39 | norm = nn.GroupNorm(hidden_size[i + 1], hidden_size[i + 1]) 40 | elif cfg['norm'] == 'ln': 41 | norm = nn.GroupNorm(1, hidden_size[i + 1]) 42 | elif cfg['norm'] == 'gn': 43 | norm = nn.GroupNorm(4, hidden_size[i + 1]) 44 | elif cfg['norm'] == 'none': 45 | norm = nn.Identity() 46 | else: 47 | raise ValueError('Not valid norm') 48 | if cfg['scale']: 49 | scaler = Scaler(rate) 50 | else: 51 | scaler = nn.Identity() 52 | blocks.extend([nn.Conv2d(hidden_size[i], hidden_size[i + 1], 3, 1, 1), 53 | scaler, 54 | norm, 55 | nn.ReLU(inplace=True), 56 | nn.MaxPool2d(2)]) 57 | blocks = blocks[:-1] 58 | blocks.extend([nn.AdaptiveAvgPool2d(1), 59 | nn.Flatten(), 60 | nn.Linear(hidden_size[-1], classes_size)]) 61 | self.blocks = nn.Sequential(*blocks) 62 | 63 | def forward(self, 
input): 64 | output = {'loss': torch.tensor(0, device=cfg['device'], dtype=torch.float32)} 65 | x = input['img'] 66 | out = self.blocks(x) 67 | if 'label_split' in input and cfg['mask']: 68 | label_mask = torch.zeros(cfg['classes_size'], device=out.device) 69 | label_mask[input['label_split']] = 1 70 | out = out.masked_fill(label_mask == 0, 0) 71 | output['score'] = out 72 | output['loss'] = F.cross_entropy(out, input['label'], reduction='mean') 73 | return output 74 | 75 | 76 | def conv(model_rate=1, track=False): 77 | data_shape = cfg['data_shape'] 78 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['conv']['hidden_size']] 79 | classes_size = cfg['classes_size'] 80 | scaler_rate = model_rate / cfg['global_model_rate'] 81 | model = Conv(data_shape, hidden_size, classes_size, scaler_rate, track) 82 | model.apply(init_param) 83 | return model 84 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from modules import Scaler 7 | from .utils import init_param 8 | 9 | 10 | class Block(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, in_planes, planes, stride, rate, track, cfg): 14 | super(Block, self).__init__() 15 | if cfg['norm'] == 'bn': 16 | n1 = nn.BatchNorm2d(in_planes, momentum=None, track_running_stats=track) 17 | n2 = nn.BatchNorm2d(planes, momentum=None, track_running_stats=track) 18 | elif cfg['norm'] == 'in': 19 | n1 = nn.GroupNorm(in_planes, in_planes) 20 | n2 = nn.GroupNorm(planes, planes) 21 | elif cfg['norm'] == 'ln': 22 | n1 = nn.GroupNorm(1, in_planes) 23 | n2 = nn.GroupNorm(1, planes) 24 | elif cfg['norm'] == 'gn': 25 | n1 = nn.GroupNorm(4, in_planes) 26 | n2 = nn.GroupNorm(4, planes) 27 | elif cfg['norm'] == 'none': 28 | n1 = nn.Identity() 29 | n2 = nn.Identity() 30 | else: 31 | raise ValueError('Not valid norm') 32 | self.n1 = n1 33 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 34 | self.n2 = n2 35 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 36 | if cfg['scale']: 37 | self.scaler = Scaler(rate) 38 | else: 39 | self.scaler = nn.Identity() 40 | 41 | if stride != 1 or in_planes != self.expansion * planes: 42 | self.shortcut = nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False) 43 | 44 | def forward(self, x): 45 | out = F.relu(self.n1(self.scaler(x))) 46 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 47 | out = self.conv1(out) 48 | out = self.conv2(F.relu(self.n2(self.scaler(out)))) 49 | out += shortcut 50 | return out 51 | 52 | 53 | class Bottleneck(nn.Module): 54 | expansion = 4 55 | 56 | def __init__(self, in_planes, planes, stride, rate, track, cfg): 57 | super(Bottleneck, self).__init__() 58 | if cfg['norm'] == 'bn': 59 | n1 = nn.BatchNorm2d(in_planes, momentum=None, track_running_stats=track) 60 | n2 = nn.BatchNorm2d(planes, momentum=None, track_running_stats=track) 61 | n3 = nn.BatchNorm2d(planes, momentum=None, track_running_stats=track) 62 | elif cfg['norm'] == 'in': 63 | n1 = nn.GroupNorm(in_planes, in_planes) 64 | n2 = nn.GroupNorm(planes, planes) 65 | n3 = nn.GroupNorm(planes, planes) 66 | elif cfg['norm'] == 'ln': 67 | n1 = nn.GroupNorm(1, in_planes) 68 | n2 = nn.GroupNorm(1, planes) 69 | n3 = nn.GroupNorm(1, planes) 70 | elif cfg['norm'] == 'gn': 71 | n1 = 
nn.GroupNorm(4, in_planes) 72 | n2 = nn.GroupNorm(4, planes) 73 | n3 = nn.GroupNorm(4, planes) 74 | elif cfg['norm'] == 'none': 75 | n1 = nn.Identity() 76 | n2 = nn.Identity() 77 | n3 = nn.Identity() 78 | else: 79 | raise ValueError('Not valid norm') 80 | self.n1 = n1 81 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 82 | self.n2 = n2 83 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 84 | self.n3 = n3 85 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 86 | if cfg['scale']: 87 | self.scaler = Scaler(rate) 88 | else: 89 | self.scaler = nn.Identity() 90 | 91 | if stride != 1 or in_planes != self.expansion * planes: 92 | self.shortcut = nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False) 93 | 94 | def forward(self, x): 95 | out = F.relu(self.n1(self.scaler(x))) 96 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 97 | out = self.conv1(out) 98 | out = self.conv2(F.relu(self.n2(self.scaler(out)))) 99 | out = self.conv3(F.relu(self.n3(self.scaler(out)))) 100 | out += shortcut 101 | return out 102 | 103 | 104 | class ResNet(nn.Module): 105 | def __init__(self, data_shape, hidden_size, block, num_blocks, num_classes, rate, track, cfg): 106 | super(ResNet, self).__init__() 107 | self.cfg = cfg 108 | self.in_planes = hidden_size[0] 109 | self.conv1 = nn.Conv2d(data_shape[0], hidden_size[0], kernel_size=3, stride=1, padding=1, bias=False) 110 | self.layer1 = self._make_layer(block, hidden_size[0], num_blocks[0], stride=1, rate=rate, track=track) 111 | self.layer2 = self._make_layer(block, hidden_size[1], num_blocks[1], stride=2, rate=rate, track=track) 112 | self.layer3 = self._make_layer(block, hidden_size[2], num_blocks[2], stride=2, rate=rate, track=track) 113 | self.layer4 = self._make_layer(block, hidden_size[3], num_blocks[3], stride=2, rate=rate, track=track) 114 | if cfg['norm'] == 'bn': 115 | n4 = nn.BatchNorm2d(hidden_size[3] * block.expansion, momentum=None, track_running_stats=track) 116 | elif cfg['norm'] == 'in': 117 | n4 = nn.GroupNorm(hidden_size[3] * block.expansion, hidden_size[3] * block.expansion) 118 | elif cfg['norm'] == 'ln': 119 | n4 = nn.GroupNorm(1, hidden_size[3] * block.expansion) 120 | elif cfg['norm'] == 'gn': 121 | n4 = nn.GroupNorm(4, hidden_size[3] * block.expansion) 122 | elif cfg['norm'] == 'none': 123 | n4 = nn.Identity() 124 | else: 125 | raise ValueError('Not valid norm') 126 | self.n4 = n4 127 | if cfg['scale']: 128 | self.scaler = Scaler(rate) 129 | else: 130 | self.scaler = nn.Identity() 131 | self.linear = nn.Linear(hidden_size[3] * block.expansion, num_classes) 132 | 133 | def _make_layer(self, block, planes, num_blocks, stride, rate, track): 134 | cfg = self.cfg 135 | strides = [stride] + [1] * (num_blocks - 1) 136 | layers = [] 137 | for stride in strides: 138 | layers.append(block(self.in_planes, planes, stride, rate, track, cfg)) 139 | self.in_planes = planes * block.expansion 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, input): 143 | cfg = self.cfg 144 | output = {} 145 | x = input['img'] 146 | out = self.conv1(x) 147 | out = self.layer1(out) 148 | out = self.layer2(out) 149 | out = self.layer3(out) 150 | out = self.layer4(out) 151 | out = F.relu(self.n4(self.scaler(out))) 152 | out = F.adaptive_avg_pool2d(out, 1) 153 | out = out.view(out.size(0), -1) 154 | out = self.linear(out) 155 | if 'label_split' in input and cfg['mask']: 156 | label_mask = 
torch.zeros(cfg['classes_size'], device=out.device) 157 | label_mask[input['label_split']] = 1 158 | out = out.masked_fill(label_mask == 0, 0) 159 | output['score'] = out 160 | output['loss'] = F.cross_entropy(output['score'], input['label']) 161 | return output 162 | 163 | 164 | def resnet18(model_rate=1, track=False, cfg=None): 165 | data_shape = cfg['data_shape'] 166 | classes_size = cfg['classes_size'] 167 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']] 168 | scaler_rate = model_rate / cfg['global_model_rate'] 169 | model = ResNet(data_shape, hidden_size, Block, [2, 2, 2, 2], classes_size, scaler_rate, track, cfg) 170 | model.apply(init_param) 171 | return model 172 | 173 | 174 | def resnet34(model_rate=1, track=False, cfg=None): 175 | data_shape = cfg['data_shape'] 176 | classes_size = cfg['classes_size'] 177 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']] 178 | scaler_rate = model_rate / cfg['global_model_rate'] 179 | model = ResNet(data_shape, hidden_size, Block, [3, 4, 6, 3], classes_size, scaler_rate, track, cfg) 180 | model.apply(init_param) 181 | return model 182 | 183 | 184 | def resnet50(model_rate=1, track=False, cfg=None): 185 | data_shape = cfg['data_shape'] 186 | classes_size = cfg['classes_size'] 187 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']] 188 | scaler_rate = model_rate / cfg['global_model_rate'] 189 | model = ResNet(data_shape, hidden_size, Bottleneck, [3, 4, 6, 3], classes_size, scaler_rate, track, cfg) 190 | model.apply(init_param) 191 | return model 192 | 193 | 194 | def resnet101(model_rate=1, track=False, cfg=None): 195 | data_shape = cfg['data_shape'] 196 | classes_size = cfg['classes_size'] 197 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']] 198 | scaler_rate = model_rate / cfg['global_model_rate'] 199 | model = ResNet(data_shape, hidden_size, Bottleneck, [3, 4, 23, 3], classes_size, scaler_rate, track, cfg) 200 | model.apply(init_param) 201 | return model 202 | 203 | 204 | def resnet152(model_rate=1, track=False, cfg=None): 205 | data_shape = cfg['data_shape'] 206 | classes_size = cfg['classes_size'] 207 | hidden_size = [int(np.ceil(model_rate * x)) for x in cfg['resnet']['hidden_size']] 208 | scaler_rate = model_rate / cfg['global_model_rate'] 209 | model = ResNet(data_shape, hidden_size, Bottleneck, [3, 8, 36, 3], classes_size, scaler_rate, track, cfg) 210 | model.apply(init_param) 211 | return model 212 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import TransformerEncoder 6 | 7 | from modules import Scaler 8 | from .utils import init_param 9 | 10 | 11 | class PositionalEmbedding(nn.Module): 12 | def __init__(self, embedding_size, cfg): 13 | super().__init__() 14 | self.positional_embedding = nn.Embedding(cfg['bptt'], embedding_size) 15 | 16 | def forward(self, x): 17 | N, S = x.size() 18 | position = torch.arange(S, dtype=torch.long, device=x.device).unsqueeze(0).expand((N, S)) 19 | x = self.positional_embedding(position) 20 | return x 21 | 22 | 23 | class TransformerEmbedding(nn.Module): 24 | def __init__(self, num_tokens, embedding_size, dropout, rate, cfg): 25 | super().__init__() 26 | self.num_tokens = num_tokens 27 | self.embedding_size = 
embedding_size 28 | self.positional_embedding = PositionalEmbedding(embedding_size, cfg) 29 | self.embedding = nn.Embedding(num_tokens + 1, embedding_size) 30 | self.norm = nn.LayerNorm(embedding_size) 31 | self.dropout = nn.Dropout(dropout) 32 | self.scaler = Scaler(rate) 33 | 34 | def forward(self, src): 35 | src = self.scaler(self.embedding(src)) + self.scaler(self.positional_embedding(src)) 36 | src = self.dropout(self.norm(src)) 37 | return src 38 | 39 | 40 | class ScaledDotProduct(nn.Module): 41 | def __init__(self, temperature): 42 | super().__init__() 43 | self.temperature = temperature 44 | 45 | def forward(self, q, k, v, mask=None): 46 | scores = q.matmul(k.transpose(-2, -1)) / self.temperature 47 | seq_len = scores.shape[-1] 48 | h = scores.shape[0] 49 | mask = torch.tril(torch.ones((h, seq_len, seq_len))).to(str(scores.device)) 50 | scores = scores.masked_fill(mask == 0, float('-inf')) 51 | attn = F.softmax(scores, dim=-1) 52 | output = torch.matmul(attn, v) 53 | return output, attn 54 | 55 | 56 | class MultiheadAttention(nn.Module): 57 | def __init__(self, embedding_size, num_heads, rate): 58 | super().__init__() 59 | self.embedding_size = embedding_size 60 | self.num_heads = num_heads 61 | self.linear_q = nn.Linear(embedding_size, embedding_size) 62 | self.linear_k = nn.Linear(embedding_size, embedding_size) 63 | self.linear_v = nn.Linear(embedding_size, embedding_size) 64 | self.linear_o = nn.Linear(embedding_size, embedding_size) 65 | self.attention = ScaledDotProduct(temperature=(embedding_size // num_heads) ** 0.5) 66 | self.scaler = Scaler(rate) 67 | 68 | def _reshape_to_batches(self, x): 69 | batch_size, seq_len, in_feature = x.size() 70 | sub_dim = in_feature // self.num_heads 71 | return x.reshape(batch_size, seq_len, self.num_heads, sub_dim).permute(0, 2, 1, 3) \ 72 | .reshape(batch_size * self.num_heads, seq_len, sub_dim) 73 | 74 | def _reshape_from_batches(self, x): 75 | batch_size, seq_len, in_feature = x.size() 76 | batch_size //= self.num_heads 77 | out_dim = in_feature * self.num_heads 78 | return x.reshape(batch_size, self.num_heads, seq_len, in_feature).permute(0, 2, 1, 3) \ 79 | .reshape(batch_size, seq_len, out_dim) 80 | 81 | def forward(self, q, k, v, mask=None): 82 | q, k, v = self.scaler(self.linear_q(q)), self.scaler(self.linear_k(k)), self.scaler(self.linear_v(v)) 83 | q, k, v = self._reshape_to_batches(q), self._reshape_to_batches(k), self._reshape_to_batches(v) 84 | q, attn = self.attention(q, k, v, mask) 85 | q = self._reshape_from_batches(q) 86 | q = self.scaler(self.linear_o(q)) 87 | return q, attn 88 | 89 | 90 | class TransformerEncoderLayer(nn.Module): 91 | def __init__(self, embedding_size, num_heads, hidden_size, dropout, rate): 92 | super().__init__() 93 | self.mha = MultiheadAttention(embedding_size, num_heads, rate=rate) 94 | self.dropout = nn.Dropout(dropout) 95 | self.norm1 = nn.LayerNorm(embedding_size) 96 | self.linear1 = nn.Linear(embedding_size, hidden_size) 97 | self.dropout1 = nn.Dropout(dropout) 98 | self.linear2 = nn.Linear(hidden_size, embedding_size) 99 | self.dropout2 = nn.Dropout(dropout) 100 | self.norm2 = nn.LayerNorm(embedding_size) 101 | self.scaler = Scaler(rate) 102 | self.activation = nn.GELU() 103 | self.init_param() 104 | 105 | def init_param(self): 106 | self.linear1.weight.data.normal_(mean=0.0, std=0.02) 107 | self.linear2.weight.data.normal_(mean=0.0, std=0.02) 108 | self.norm1.weight.data.fill_(1.0) 109 | self.norm1.bias.data.zero_() 110 | self.norm2.weight.data.fill_(1.0) 111 | self.norm2.bias.data.zero_() 112 | 
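# Note: ScaledDotProduct above ignores the `mask` argument handed down from
# MultiheadAttention and instead builds its own causal (lower-triangular)
# mask of shape (batch*heads, seq_len, seq_len), so every head can only
# attend to earlier positions. The Scaler wrapped around linear1/linear2
# divides activations by `rate` during training only (see
# modules/modules.py), which is how width-reduced client models keep their
# activation statistics comparable to the full-width server model.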
return 113 | 114 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 115 | attn_output, _ = self.mha(src, src, src, mask=src_mask) 116 | src = src + self.dropout(attn_output) 117 | src = self.norm1(src) 118 | src2 = self.scaler(self.linear2(self.dropout1(self.activation(self.scaler(self.linear1(src)))))) 119 | src = src + self.dropout2(src2) 120 | src = self.norm2(src) 121 | return src 122 | 123 | 124 | class Decoder(nn.Module): 125 | def __init__(self, num_tokens, embedding_size, rate): 126 | super().__init__() 127 | self.linear1 = nn.Linear(embedding_size, embedding_size) 128 | self.scaler = Scaler(rate) 129 | self.activation = nn.GELU() 130 | self.norm1 = nn.LayerNorm(embedding_size) 131 | self.linear2 = nn.Linear(embedding_size, num_tokens) 132 | 133 | def forward(self, src): 134 | out = self.linear2(self.norm1(self.activation(self.scaler(self.linear1(src))))) 135 | return out 136 | 137 | 138 | class Transformer(nn.Module): 139 | def __init__(self, num_tokens, embedding_size, num_heads, hidden_size, num_layers, dropout, rate, cfg): 140 | super().__init__() 141 | self.num_tokens = num_tokens 142 | self.transformer_embedding = TransformerEmbedding(num_tokens, embedding_size, dropout, rate, cfg) 143 | encoder_layers = TransformerEncoderLayer(embedding_size, num_heads, hidden_size, dropout, rate) 144 | self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers) 145 | self.decoder = Decoder(num_tokens, embedding_size, rate) 146 | self.cfg = cfg 147 | 148 | def forward(self, input): 149 | cfg = self.cfg 150 | output = {} 151 | src = input['label'].clone() 152 | N, S = src.size() 153 | d = torch.distributions.bernoulli.Bernoulli(probs=cfg['mask_rate']) 154 | mask = d.sample((N, S)).to(src.device) 155 | src = src.masked_fill(mask == 1, self.num_tokens).detach() 156 | src = self.transformer_embedding(src) 157 | src = self.transformer_encoder(src) 158 | out = self.decoder(src) 159 | out = out.permute(0, 2, 1) 160 | if 'label_split' in input and cfg['mask']: 161 | label_mask = torch.zeros((cfg['num_tokens'], 1), device=out.device) 162 | label_mask[input['label_split']] = 1 163 | out = out.masked_fill(label_mask == 0, 0) 164 | output['score'] = out 165 | output['loss'] = F.cross_entropy(output['score'], input['label']) 166 | return output 167 | 168 | 169 | def transformer(model_rate=1, cfg=None): 170 | num_tokens = cfg['num_tokens'] 171 | embedding_size = int(np.ceil(model_rate * cfg['transformer']['embedding_size'])) 172 | num_heads = cfg['transformer']['num_heads'] 173 | hidden_size = int(np.ceil(model_rate * cfg['transformer']['hidden_size'])) 174 | num_layers = cfg['transformer']['num_layers'] 175 | dropout = cfg['transformer']['dropout'] 176 | scaler_rate = model_rate / cfg['global_model_rate'] 177 | model = Transformer(num_tokens, embedding_size, num_heads, hidden_size, num_layers, dropout, scaler_rate, cfg) 178 | model.apply(init_param) 179 | return model 180 | -------------------------------------------------------------------------------- /models/transformer_nwp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import TransformerEncoder 6 | 7 | from modules import Scaler 8 | from .utils import init_param 9 | 10 | 11 | class PositionalEmbedding(nn.Module): 12 | def __init__(self, embedding_size, cfg): 13 | super().__init__() 14 | self.positional_embedding = nn.Embedding(cfg['bptt'], embedding_size) 15 | 16 | 
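# transformer_nwp.py is the next-word-prediction variant of
# models/transformer.py: instead of randomly masking tokens with
# Bernoulli(cfg['mask_rate']) as in the file above, Transformer.forward
# below shifts the sequence and predicts label[:, 1:] from label[:, :-1].
# Positions are learned embeddings indexed 0..S-1, so sequences longer
# than cfg['bptt'] tokens would raise an index error.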
def forward(self, x): 17 | N, S = x.size() 18 | position = torch.arange(S, dtype=torch.long, device=x.device).unsqueeze(0).expand((N, S)) 19 | x = self.positional_embedding(position) 20 | return x 21 | 22 | 23 | class TransformerEmbedding(nn.Module): 24 | def __init__(self, num_tokens, embedding_size, dropout, rate, cfg): 25 | super().__init__() 26 | self.num_tokens = num_tokens 27 | self.embedding_size = embedding_size 28 | self.positional_embedding = PositionalEmbedding(embedding_size, cfg) 29 | self.embedding = nn.Embedding(num_tokens + 1, embedding_size) 30 | self.norm = nn.LayerNorm(embedding_size) 31 | self.dropout = nn.Dropout(dropout) 32 | self.scaler = Scaler(rate) 33 | 34 | def forward(self, src): 35 | src = self.scaler(self.embedding(src)) + self.scaler(self.positional_embedding(src)) 36 | src = self.dropout(self.norm(src)) 37 | return src 38 | 39 | 40 | class ScaledDotProduct(nn.Module): 41 | def __init__(self, temperature): 42 | super().__init__() 43 | self.temperature = temperature 44 | 45 | def forward(self, q, k, v, mask=None): 46 | scores = q.matmul(k.transpose(-2, -1)) / self.temperature 47 | seq_len = scores.shape[-1] 48 | h = scores.shape[0] 49 | mask = torch.tril(torch.ones((h, seq_len, seq_len))).to(str(scores.device)) 50 | scores = scores.masked_fill(mask == 0, float('-inf')) 51 | attn = F.softmax(scores, dim=-1) 52 | output = torch.matmul(attn, v) 53 | return output, attn 54 | 55 | 56 | class MultiheadAttention(nn.Module): 57 | def __init__(self, embedding_size, num_heads, rate): 58 | super().__init__() 59 | self.embedding_size = embedding_size 60 | self.num_heads = num_heads 61 | self.linear_q = nn.Linear(embedding_size, embedding_size) 62 | self.linear_k = nn.Linear(embedding_size, embedding_size) 63 | self.linear_v = nn.Linear(embedding_size, embedding_size) 64 | self.linear_o = nn.Linear(embedding_size, embedding_size) 65 | self.attention = ScaledDotProduct(temperature=(embedding_size // num_heads) ** 0.5) 66 | self.scaler = Scaler(rate) 67 | 68 | def _reshape_to_batches(self, x): 69 | batch_size, seq_len, in_feature = x.size() 70 | sub_dim = in_feature // self.num_heads 71 | return x.reshape(batch_size, seq_len, self.num_heads, sub_dim).permute(0, 2, 1, 3) \ 72 | .reshape(batch_size * self.num_heads, seq_len, sub_dim) 73 | 74 | def _reshape_from_batches(self, x): 75 | batch_size, seq_len, in_feature = x.size() 76 | batch_size //= self.num_heads 77 | out_dim = in_feature * self.num_heads 78 | return x.reshape(batch_size, self.num_heads, seq_len, in_feature).permute(0, 2, 1, 3) \ 79 | .reshape(batch_size, seq_len, out_dim) 80 | 81 | def forward(self, q, k, v, mask=None): 82 | q, k, v = self.scaler(self.linear_q(q)), self.scaler(self.linear_k(k)), self.scaler(self.linear_v(v)) 83 | q, k, v = self._reshape_to_batches(q), self._reshape_to_batches(k), self._reshape_to_batches(v) 84 | q, attn = self.attention(q, k, v, mask) 85 | q = self._reshape_from_batches(q) 86 | q = self.scaler(self.linear_o(q)) 87 | return q, attn 88 | 89 | 90 | class TransformerEncoderLayer(nn.Module): 91 | def __init__(self, embedding_size, num_heads, hidden_size, dropout, rate): 92 | super().__init__() 93 | self.mha = MultiheadAttention(embedding_size, num_heads, rate=rate) 94 | self.dropout = nn.Dropout(dropout) 95 | self.norm1 = nn.LayerNorm(embedding_size) 96 | self.linear1 = nn.Linear(embedding_size, hidden_size) 97 | self.dropout1 = nn.Dropout(dropout) 98 | self.linear2 = nn.Linear(hidden_size, embedding_size) 99 | self.dropout2 = nn.Dropout(dropout) 100 | self.norm2 = 
nn.LayerNorm(embedding_size) 101 | self.scaler = Scaler(rate) 102 | self.activation = nn.GELU() 103 | self.init_param() 104 | 105 | def init_param(self): 106 | self.linear1.weight.data.normal_(mean=0.0, std=0.02) 107 | self.linear2.weight.data.normal_(mean=0.0, std=0.02) 108 | self.norm1.weight.data.fill_(1.0) 109 | self.norm1.bias.data.zero_() 110 | self.norm2.weight.data.fill_(1.0) 111 | self.norm2.bias.data.zero_() 112 | return 113 | 114 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 115 | attn_output, _ = self.mha(src, src, src, mask=src_mask) 116 | src = src + self.dropout(attn_output) 117 | src = self.norm1(src) 118 | src2 = self.scaler(self.linear2(self.dropout1(self.activation(self.scaler(self.linear1(src)))))) 119 | src = src + self.dropout2(src2) 120 | src = self.norm2(src) 121 | return src 122 | 123 | 124 | class Decoder(nn.Module): 125 | def __init__(self, num_tokens, embedding_size, rate): 126 | super().__init__() 127 | self.linear1 = nn.Linear(embedding_size, embedding_size) 128 | self.scaler = Scaler(rate) 129 | self.activation = nn.GELU() 130 | self.norm1 = nn.LayerNorm(embedding_size) 131 | self.linear2 = nn.Linear(embedding_size, num_tokens) 132 | 133 | def forward(self, src): 134 | out = self.linear2(self.norm1(self.activation(self.scaler(self.linear1(src))))) 135 | return out 136 | 137 | 138 | class Transformer(nn.Module): 139 | def __init__(self, num_tokens, embedding_size, num_heads, hidden_size, num_layers, dropout, rate, cfg): 140 | super().__init__() 141 | self.num_tokens = num_tokens 142 | self.transformer_embedding = TransformerEmbedding(num_tokens, embedding_size, dropout, rate, cfg) 143 | encoder_layers = TransformerEncoderLayer(embedding_size, num_heads, hidden_size, dropout, rate) 144 | self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers) 145 | self.decoder = Decoder(num_tokens, embedding_size, rate) 146 | self.cfg = cfg 147 | 148 | def forward(self, input): 149 | cfg = self.cfg 150 | output = {} 151 | src = input['label'][:, :-1].clone() 152 | input['label'] = input['label'][:, 1:] 153 | N, S = src.size() 154 | # d = torch.distributions.bernoulli.Bernoulli(probs=cfg['mask_rate']) 155 | # 156 | # mask = d.sample((N, S)).to(src.device) 157 | # src = src.masked_fill(mask == 1, self.num_tokens).detach() 158 | src = self.transformer_embedding(src) 159 | src = self.transformer_encoder(src) 160 | out = self.decoder(src) 161 | out = out.permute(0, 2, 1) 162 | output['score'] = out 163 | # C = np.ones((10000,)) 164 | # C[:2] = 0 165 | # C = C / 9998.0 166 | # C = torch.tensor(C, dtype=torch.float).cuda() 167 | output['loss'] = F.cross_entropy(output['score'], input['label']) 168 | return output 169 | 170 | 171 | def transformer(model_rate=1, cfg=None): 172 | num_tokens = cfg['num_tokens'] 173 | embedding_size = int(np.ceil(model_rate * cfg['transformer']['embedding_size'])) 174 | num_heads = cfg['transformer']['num_heads'] 175 | hidden_size = int(np.ceil(model_rate * cfg['transformer']['hidden_size'])) 176 | num_layers = cfg['transformer']['num_layers'] 177 | dropout = cfg['transformer']['dropout'] 178 | scaler_rate = model_rate / cfg['global_model_rate'] 179 | model = Transformer(num_tokens, embedding_size, num_heads, hidden_size, num_layers, dropout, scaler_rate, cfg) 180 | model.apply(init_param) 181 | return model 182 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 
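# init_param is applied via model.apply(init_param) by the factory
# functions in this package: BatchNorm/InstanceNorm layers are reset to
# weight=1, bias=0 and Linear biases are zeroed. Conv2d and Linear
# weights keep PyTorch's default (Kaiming-uniform) initialization, and
# the GroupNorm layers used for the 'in'/'ln'/'gn' options already
# default to weight=1, bias=0.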
2 | 3 | 4 | def init_param(m): 5 | if isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)): 6 | m.weight.data.fill_(1) 7 | m.bias.data.zero_() 8 | elif isinstance(m, nn.Linear): 9 | m.bias.data.zero_() 10 | return m 11 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | -------------------------------------------------------------------------------- /modules/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Scaler(nn.Module): 5 | def __init__(self, rate): 6 | super().__init__() 7 | self.rate = rate 8 | 9 | def forward(self, input): 10 | output = input / self.rate if self.training else input 11 | return output 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.23.0 2 | torch~=1.11.0 3 | torchvision~=0.12.0 4 | anytree~=2.8.0 5 | Pillow~=9.0.1 6 | tqdm~=4.64.0 7 | PyYAML~=5.4.1 8 | matplotlib~=3.3.4 9 | ray~=1.13.0 -------------------------------------------------------------------------------- /resnet_client.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | 4 | import ray 5 | import torch 6 | 7 | from data import SplitDataset, make_data_loader 8 | from logger import Logger 9 | from metrics import Metric 10 | from models import resnet 11 | from utils import make_optimizer, collate, to_device 12 | 13 | 14 | @ray.remote(num_gpus=0.15) 15 | class ResnetClient: 16 | def __init__(self, log_path, cfg): 17 | # with open('config.yml', 'r') as f: 18 | # cfg = yaml.load(f, Loader=yaml.FullLoader) 19 | self.local_parameters = None 20 | self.m = None 21 | self.start_time = None 22 | self.num_active_users = None 23 | self.optimizer = None 24 | self.model = None 25 | self.lr = None 26 | self.label_split = None 27 | self.data_loader = None 28 | self.model_rate = None 29 | self.client_id = None 30 | cfg = ray.get(cfg[0]) 31 | self.metric = Metric() 32 | self.logger = Logger(log_path) 33 | self.cfg = cfg 34 | 35 | def update(self, client_id, dataset_ref, model_ref): 36 | dataset = ray.get(dataset_ref['dataset']) 37 | data_split = ray.get(dataset_ref['split']) 38 | label_split = ray.get(dataset_ref['label_split']) 39 | local_parameters = ray.get(model_ref['local_params']) 40 | # dataset_ref = torch.load('data_store') 41 | # dataset = (dataset_ref['dataset']) 42 | # data_split = (dataset_ref['split']) 43 | # label_split = (dataset_ref['label_split']) 44 | # local_parameters = {k: v.clone().cuda() for k, v in local_parameters.items()} 45 | self.local_parameters = local_parameters 46 | self.client_id = client_id 47 | self.model_rate = model_ref['model_rate'] 48 | self.data_loader = make_data_loader({'train': SplitDataset(dataset, data_split[client_id])})['train'] 49 | self.label_split = label_split 50 | self.lr = model_ref['lr'] 51 | self.metric = Metric() 52 | 53 | def step(self, m, num_active_users, start_time): 54 | cfg = self.cfg 55 | self.model = resnet.resnet18(model_rate=self.model_rate, cfg=self.cfg).to('cuda') 56 | self.model.load_state_dict(self.local_parameters) 57 | self.model.train(True) 58 | self.optimizer = make_optimizer(self.model, self.lr) 59 | self.m = m 60 | self.num_active_users = num_active_users 61 | self.start_time = 
start_time 62 | for local_epoch in range(1, cfg['num_epochs']['local'] + 1): 63 | for i, step_input in enumerate(self.data_loader): 64 | step_input = collate(step_input) 65 | input_size = step_input['img'].size(0) 66 | step_input['label_split'] = torch.tensor(self.label_split[self.client_id]) 67 | step_input = to_device(step_input, 'cuda') 68 | self.model.zero_grad() 69 | output = self.model(step_input) 70 | output['loss'].backward() 71 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) 72 | self.optimizer.step() 73 | evaluation = self.metric.evaluate(cfg['metric_name']['train']['Local'], step_input, output) 74 | self.logger.append(evaluation, 'train', n=input_size) 75 | self.log(local_epoch, cfg) 76 | return self.pull() 77 | 78 | def pull(self): 79 | model_state = {k: v.detach().clone().cpu() for k, v in self.model.to(self.cfg['device']).state_dict().items()} 80 | return model_state 81 | 82 | def log(self, epoch, cfg): 83 | if self.m % int((self.num_active_users * cfg['log_interval']) + 1) == 0: 84 | local_time = (time.time() - self.start_time) / (self.m + 1) 85 | epoch_finished_time = datetime.timedelta(seconds=local_time * (self.num_active_users - self.m - 1)) 86 | exp_finished_time = epoch_finished_time + datetime.timedelta( 87 | seconds=round((cfg['num_epochs']['global'] - epoch) * local_time * self.num_active_users)) 88 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 89 | 'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * self.m / self.num_active_users), 90 | 'ID: {}({}/{})'.format(self.client_id, self.m + 1, self.num_active_users), 91 | 'Learning rate: {}'.format(self.lr), 92 | 'Rate: {}'.format(self.model_rate), 93 | 'Epoch Finished Time: {}'.format(epoch_finished_time), 94 | 'Experiment Finished Time: {}'.format(exp_finished_time)]} 95 | self.logger.append(info, 'train', mean=False) 96 | self.logger.write('train', cfg['metric_name']['train']['Local']) 97 | 98 | def test_model_for_user(self, m, ids): 99 | cfg = self.cfg 100 | metric = Metric() 101 | [dataset, data_split, model, label_split] = ray.get(ids) 102 | model = model.to('cuda') 103 | data_loader = make_data_loader({'test': SplitDataset(dataset, data_split[m])})['test'] 104 | results = [] 105 | for _, data_input in enumerate(data_loader): 106 | data_input = collate(data_input) 107 | input_size = data_input['img'].size(0) 108 | data_input['label_split'] = torch.tensor(label_split[m]) 109 | data_input = to_device(data_input, 'cuda') 110 | output = model(data_input) 111 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss'] 112 | evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], data_input, output) 113 | results.append((evaluation, input_size)) 114 | return results 115 | -------------------------------------------------------------------------------- /resnet_server.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import ray 6 | import torch 7 | 8 | 9 | class ResnetServerRoll: 10 | def __init__(self, global_model, rate, dataset_ref, cfg_id): 11 | self.tau = 1e-2 12 | self.v_t = None 13 | self.beta_1 = 0.9 14 | self.beta_2 = 0.99 15 | self.eta = 1e-2 16 | self.m_t = None 17 | self.user_idx = None 18 | self.param_idx = None 19 | self.dataset_ref = dataset_ref 20 | self.cfg_id = cfg_id 21 | self.cfg = ray.get(cfg_id) 22 | self.global_model = global_model.cpu() 23 | self.global_parameters = global_model.state_dict() 24 | self.rate = rate 25 | 
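# Server-side state for sub-model extraction and aggregation: tmp_counts
# tracks how many times each individual weight entry has been updated by
# any client (consumed by the 'updates'/'updates_width' weighting in
# combine()), model_idxs holds a random output-channel permutation per
# conv layer, and roll_idx is the rolling offset used when cfg['overlap']
# is set; otherwise the roll is simply rounds % output_size.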
self.label_split = ray.get(dataset_ref['label_split']) 26 | self.make_model_rate() 27 | self.num_model_partitions = 50 28 | self.model_idxs = {} 29 | self.roll_idx = {} 30 | self.rounds = 0 31 | self.tmp_counts = {} 32 | for k, v in self.global_parameters.items(): 33 | self.tmp_counts[k] = torch.ones_like(v) 34 | self.reshuffle_params() 35 | self.reshuffle_rounds = 512 36 | 37 | def reshuffle_params(self): 38 | for k, v in self.global_parameters.items(): 39 | if 'conv1' in k or 'conv2' in k: 40 | output_size = v.size(0) 41 | self.model_idxs[k] = torch.randperm(output_size, device=v.device) 42 | self.roll_idx[k] = 0 43 | return self.model_idxs 44 | 45 | def step(self, local_parameters, param_idx, user_idx): 46 | self.combine(local_parameters, param_idx, user_idx) 47 | self.rounds += 1 48 | if self.rounds % self.reshuffle_rounds: 49 | self.reshuffle_params() 50 | 51 | def broadcast(self, local, lr): 52 | cfg = self.cfg 53 | self.global_model.train(True) 54 | num_active_users = cfg['active_user'] 55 | self.user_idx = copy.deepcopy(torch.arange(cfg['num_users']) 56 | [torch.randperm(cfg['num_users']) 57 | [:num_active_users]].tolist()) 58 | local_parameters, self.param_idx = self.distribute(self.user_idx) 59 | 60 | param_ids = [ray.put(local_parameter) for local_parameter in local_parameters] 61 | 62 | ray.get([client.update.remote(self.user_idx[m], 63 | self.dataset_ref, 64 | {'lr': lr, 65 | 'model_rate': self.model_rate[self.user_idx[m]], 66 | 'local_params': param_ids[m]}) 67 | for m, client in enumerate(local)]) 68 | return local, self.param_idx, self.user_idx 69 | 70 | def make_model_rate(self): 71 | cfg = self.cfg 72 | if cfg['model_split_mode'] == 'dynamic': 73 | rate_idx = torch.multinomial(torch.tensor(cfg['proportion']), num_samples=cfg['num_users'], 74 | replacement=True).tolist() 75 | self.model_rate = np.array(self.rate)[rate_idx] 76 | elif cfg['model_split_mode'] == 'fix': 77 | self.model_rate = np.array(self.rate) 78 | else: 79 | raise ValueError('Not valid model split mode') 80 | return 81 | 82 | def split_model(self, user_idx): 83 | cfg = self.cfg 84 | idx_i = [None for _ in range(len(user_idx))] 85 | idx = [OrderedDict() for _ in range(len(user_idx))] 86 | for k, v in self.global_parameters.items(): 87 | parameter_type = k.split('.')[-1] 88 | for m in range(len(user_idx)): 89 | if 'weight' in parameter_type or 'bias' in parameter_type: 90 | if parameter_type == 'weight': 91 | if v.dim() > 1: 92 | input_size = v.size(1) 93 | output_size = v.size(0) 94 | if 'conv1' in k or 'conv2' in k: 95 | if idx_i[m] is None: 96 | idx_i[m] = torch.arange(input_size, device=v.device) 97 | input_idx_i_m = idx_i[m] 98 | scaler_rate = self.model_rate[user_idx[m]] / cfg['global_model_rate'] 99 | local_output_size = int(np.ceil(output_size * scaler_rate)) 100 | if self.cfg['overlap'] is None: 101 | roll = self.rounds % output_size 102 | model_idx = torch.arange(output_size, device=v.device) 103 | else: 104 | overlap = self.cfg['overlap'] 105 | self.roll_idx[k] += int(local_output_size * (1 - overlap)) + 1 106 | self.roll_idx[k] = self.roll_idx[k] % local_output_size 107 | roll = self.roll_idx[k] 108 | model_idx = self.model_idxs[k] 109 | model_idx = torch.roll(model_idx, roll, -1) 110 | output_idx_i_m = model_idx[:local_output_size] 111 | idx_i[m] = output_idx_i_m 112 | elif 'shortcut' in k: 113 | input_idx_i_m = idx[m][k.replace('shortcut', 'conv1')][1] 114 | output_idx_i_m = idx_i[m] 115 | elif 'linear' in k: 116 | input_idx_i_m = idx_i[m] 117 | output_idx_i_m = torch.arange(output_size, 
device=v.device) 118 | else: 119 | raise ValueError('Not valid k') 120 | idx[m][k] = (output_idx_i_m, input_idx_i_m) 121 | else: 122 | input_idx_i_m = idx_i[m] 123 | idx[m][k] = input_idx_i_m 124 | else: 125 | input_size = v.size(0) 126 | if 'linear' in k: 127 | input_idx_i_m = torch.arange(input_size, device=v.device) 128 | idx[m][k] = input_idx_i_m 129 | else: 130 | input_idx_i_m = idx_i[m] 131 | idx[m][k] = input_idx_i_m 132 | else: 133 | pass 134 | 135 | return idx 136 | 137 | def distribute(self, user_idx): 138 | self.make_model_rate() 139 | param_idx = self.split_model(user_idx) 140 | local_parameters = [OrderedDict() for _ in range(len(user_idx))] 141 | for k, v in self.global_parameters.items(): 142 | parameter_type = k.split('.')[-1] 143 | for m in range(len(user_idx)): 144 | if 'weight' in parameter_type or 'bias' in parameter_type: 145 | if 'weight' in parameter_type: 146 | if v.dim() > 1: 147 | local_parameters[m][k] = copy.deepcopy(v[torch.meshgrid(param_idx[m][k])]) 148 | else: 149 | local_parameters[m][k] = copy.deepcopy(v[param_idx[m][k]]) 150 | else: 151 | local_parameters[m][k] = copy.deepcopy(v[param_idx[m][k]]) 152 | else: 153 | local_parameters[m][k] = copy.deepcopy(v) 154 | return local_parameters, param_idx 155 | 156 | def combine(self, local_parameters, param_idx, user_idx): 157 | count = OrderedDict() 158 | self.global_parameters = self.global_model.cpu().state_dict() 159 | updated_parameters = copy.deepcopy(self.global_parameters) 160 | tmp_counts_cpy = copy.deepcopy(self.tmp_counts) 161 | for k, v in updated_parameters.items(): 162 | parameter_type = k.split('.')[-1] 163 | count[k] = v.new_zeros(v.size(), dtype=torch.float32, device='cpu') 164 | tmp_v = v.new_zeros(v.size(), dtype=torch.float32, device='cpu') 165 | for m in range(len(local_parameters)): 166 | if 'weight' in parameter_type or 'bias' in parameter_type: 167 | if parameter_type == 'weight': 168 | if v.dim() > 1: 169 | if 'linear' in k: 170 | label_split = self.label_split[user_idx[m]] 171 | param_idx[m][k] = list(param_idx[m][k]) 172 | param_idx[m][k][0] = param_idx[m][k][0][label_split] 173 | tmp_v[torch.meshgrid(param_idx[m][k])] += self.tmp_counts[k][torch.meshgrid( 174 | param_idx[m][k])] * local_parameters[m][k][label_split] 175 | count[k][torch.meshgrid(param_idx[m][k])] += self.tmp_counts[k][torch.meshgrid( 176 | param_idx[m][k])] 177 | tmp_counts_cpy[k][torch.meshgrid(param_idx[m][k])] += 1 178 | else: 179 | output_size = v.size(0) 180 | scaler_rate = self.model_rate[user_idx[m]] / self.cfg['global_model_rate'] 181 | local_output_size = int(np.ceil(output_size * scaler_rate)) 182 | if self.cfg['weighting'] == 'avg': 183 | K = 1 184 | elif self.cfg['weighting'] == 'width': 185 | K = local_output_size 186 | elif self.cfg['weighting'] == 'updates': 187 | K = self.tmp_counts[k][torch.meshgrid(param_idx[m][k])] 188 | elif self.cfg['weighting'] == 'updates_width': 189 | K = local_output_size * self.tmp_counts[k][torch.meshgrid(param_idx[m][k])] 190 | # K = self.tmp_counts[k][torch.meshgrid(param_idx[m][k])] 191 | # K = local_output_size 192 | # K = local_output_size * self.tmp_counts[k][torch.meshgrid(param_idx[m][k])] 193 | tmp_v[torch.meshgrid(param_idx[m][k])] += K * local_parameters[m][k] 194 | count[k][torch.meshgrid(param_idx[m][k])] += K 195 | tmp_counts_cpy[k][torch.meshgrid(param_idx[m][k])] += 1 196 | else: 197 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k] 198 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] 199 | 
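# Aggregation: each client's sub-tensor is scattered back into a zeroed
# copy of the global tensor through its index tuple (torch.meshgrid for
# 2-D+ weights), accumulating K-weighted sums in tmp_v and weights in
# count. After the client loop, tmp_v / count replaces only the entries
# some client actually touched; untouched entries keep their previous
# global values. With 'avg' weighting (K=1) this reduces to a plain mean
# over the clients holding each entry.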
tmp_counts_cpy[k][param_idx[m][k]] += 1 200 | else: 201 | if 'linear' in k: 202 | label_split = self.label_split[user_idx[m]] 203 | param_idx[m][k] = param_idx[m][k][label_split] 204 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k][ 205 | label_split] 206 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] 207 | tmp_counts_cpy[k][param_idx[m][k]] += 1 208 | else: 209 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k] 210 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] 211 | tmp_counts_cpy[k][param_idx[m][k]] += 1 212 | else: 213 | tmp_v += self.tmp_counts[k] * local_parameters[m][k] 214 | count[k] += self.tmp_counts[k] 215 | tmp_counts_cpy[k] += 1 216 | tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0]) 217 | v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype) 218 | self.tmp_counts = tmp_counts_cpy 219 | self.global_parameters = updated_parameters 220 | self.global_model.load_state_dict(self.global_parameters) 221 | return 222 | 223 | 224 | class ResnetServerRandom(ResnetServerRoll): 225 | def split_model(self, user_idx): 226 | cfg = self.cfg 227 | idx_i = [None for _ in range(len(user_idx))] 228 | idx = [OrderedDict() for _ in range(len(user_idx))] 229 | for k, v in self.global_parameters.items(): 230 | parameter_type = k.split('.')[-1] 231 | for m in range(len(user_idx)): 232 | if 'weight' in parameter_type or 'bias' in parameter_type: 233 | if parameter_type == 'weight': 234 | if v.dim() > 1: 235 | input_size = v.size(1) 236 | output_size = v.size(0) 237 | if 'conv1' in k or 'conv2' in k: 238 | if idx_i[m] is None: 239 | idx_i[m] = torch.arange(input_size, device=v.device) 240 | input_idx_i_m = idx_i[m] 241 | scaler_rate = self.model_rate[user_idx[m]] / cfg['global_model_rate'] 242 | local_output_size = int(np.ceil(output_size * scaler_rate)) 243 | model_idx = torch.randperm(output_size, device=v.device) 244 | output_idx_i_m = model_idx[:local_output_size] 245 | idx_i[m] = output_idx_i_m 246 | elif 'shortcut' in k: 247 | input_idx_i_m = idx[m][k.replace('shortcut', 'conv1')][1] 248 | output_idx_i_m = idx_i[m] 249 | elif 'linear' in k: 250 | input_idx_i_m = idx_i[m] 251 | output_idx_i_m = torch.arange(output_size, device=v.device) 252 | else: 253 | raise ValueError('Not valid k') 254 | idx[m][k] = (output_idx_i_m, input_idx_i_m) 255 | else: 256 | input_idx_i_m = idx_i[m] 257 | idx[m][k] = input_idx_i_m 258 | else: 259 | input_size = v.size(0) 260 | if 'linear' in k: 261 | input_idx_i_m = torch.arange(input_size, device=v.device) 262 | idx[m][k] = input_idx_i_m 263 | else: 264 | input_idx_i_m = idx_i[m] 265 | idx[m][k] = input_idx_i_m 266 | else: 267 | pass 268 | 269 | return idx 270 | 271 | 272 | class ResnetServerStatic(ResnetServerRoll): 273 | def split_model(self, user_idx): 274 | cfg = self.cfg 275 | idx_i = [None for _ in range(len(user_idx))] 276 | idx = [OrderedDict() for _ in range(len(user_idx))] 277 | for k, v in self.global_parameters.items(): 278 | parameter_type = k.split('.')[-1] 279 | for m in range(len(user_idx)): 280 | if 'weight' in parameter_type or 'bias' in parameter_type: 281 | if parameter_type == 'weight': 282 | if v.dim() > 1: 283 | input_size = v.size(1) 284 | output_size = v.size(0) 285 | if 'conv1' in k or 'conv2' in k: 286 | if idx_i[m] is None: 287 | idx_i[m] = torch.arange(input_size, device=v.device) 288 | input_idx_i_m = idx_i[m] 289 | scaler_rate = self.model_rate[user_idx[m]] / cfg['global_model_rate'] 290 | local_output_size 
= int(np.ceil(output_size * scaler_rate)) 291 | model_idx = torch.arange(output_size, device=v.device) 292 | output_idx_i_m = model_idx[:local_output_size] 293 | idx_i[m] = output_idx_i_m 294 | elif 'shortcut' in k: 295 | input_idx_i_m = idx[m][k.replace('shortcut', 'conv1')][1] 296 | output_idx_i_m = idx_i[m] 297 | elif 'linear' in k: 298 | input_idx_i_m = idx_i[m] 299 | output_idx_i_m = torch.arange(output_size, device=v.device) 300 | else: 301 | raise ValueError('Not valid k') 302 | idx[m][k] = (output_idx_i_m, input_idx_i_m) 303 | else: 304 | input_idx_i_m = idx_i[m] 305 | idx[m][k] = input_idx_i_m 306 | else: 307 | input_size = v.size(0) 308 | if 'linear' in k: 309 | input_idx_i_m = torch.arange(input_size, device=v.device) 310 | idx[m][k] = input_idx_i_m 311 | else: 312 | input_idx_i_m = idx_i[m] 313 | idx[m][k] = input_idx_i_m 314 | else: 315 | pass 316 | 317 | return idx 318 | -------------------------------------------------------------------------------- /resnet_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | 7 | from config import cfg 8 | from data import fetch_dataset, make_data_loader, SplitDataset 9 | from logger import Logger 10 | from metrics import Metric 11 | from utils import save, to_device, process_control, process_dataset, collate, resume 12 | 13 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 14 | 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | 18 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | parser = argparse.ArgumentParser(description='cfg') 22 | for k in cfg: 23 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k)) 24 | parser.add_argument('--control_name', default=None, type=str) 25 | parser.add_argument('--seed', default=None, type=int) 26 | parser.add_argument('--devices', default=None, nargs='+', type=int) 27 | parser.add_argument('--algo', default='roll', type=str) 28 | # parser.add_argument('--lr', default=None, type=int) 29 | parser.add_argument('--g_epochs', default=None, type=int) 30 | parser.add_argument('--l_epochs', default=None, type=int) 31 | parser.add_argument('--schedule', default=None, nargs='+', type=int) 32 | # parser.add_argument('--exp_name', default=None, type=str) 33 | args = vars(parser.parse_args()) 34 | cfg['init_seed'] = int(args['seed']) 35 | if args['algo'] == 'roll': 36 | pass 37 | elif args['algo'] == 'random': 38 | pass 39 | elif args['algo'] == 'static': 40 | pass 41 | if args['devices'] is not None: 42 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in args['devices']]) 43 | for k in cfg: 44 | cfg[k] = args[k] 45 | if args['control_name']: 46 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \ 47 | if args['control_name'] != 'None' else {} 48 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']]) 49 | cfg['pivot_metric'] = 'Global-Accuracy' 50 | cfg['pivot'] = -float('inf') 51 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Accuracy']}, 52 | 'test': {'Local': ['Local-Loss', 'Local-Accuracy'], 'Global': ['Global-Loss', 'Global-Accuracy']}} 53 | 54 | 55 | def main(): 56 | process_control() 57 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments'])) 58 | for i in 
range(cfg['num_experiments']): 59 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']] 60 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x]) 61 | print('Experiment: {}'.format(cfg['model_tag'])) 62 | runExperiment() 63 | return 64 | 65 | 66 | def runExperiment(): 67 | cfg['batch_size']['train'] = cfg['batch_size']['test'] 68 | seed = int(cfg['model_tag'].split('_')[0]) 69 | torch.manual_seed(seed) 70 | torch.cuda.manual_seed(seed) 71 | dataset = fetch_dataset(cfg['data_name'], cfg['subset']) 72 | process_dataset(dataset) 73 | model = eval('resnet.{}(model_rate=cfg["global_model_rate"], track=True, cfg=cfg).to(cfg["device"]).to(cfg[' 74 | '"device"])' 75 | .format(cfg['model_name'])) 76 | last_epoch, data_split, label_split, model, _, _, _ = resume(model, cfg['model_tag'], load_tag='best', strict=False) 77 | logger_path = 'output/runs/test_{}'.format(cfg['model_tag']) 78 | test_logger = Logger(logger_path) 79 | test_logger.safe(True) 80 | # stats(dataset['train'], model) 81 | test(dataset['test'], data_split['test'], label_split, model, test_logger, last_epoch) 82 | test_logger.safe(False) 83 | _, _, _, _, _, _, train_logger = resume(model, cfg['model_tag'], load_tag='checkpoint', strict=False) 84 | save_result = {'cfg': cfg, 'epoch': last_epoch, 'logger': {'train': train_logger, 'test': test_logger}} 85 | save(save_result, './output/result/{}.pt'.format(cfg['model_tag'])) 86 | return 87 | 88 | 89 | def stats(dataset, model): 90 | with torch.no_grad(): 91 | data_loader = make_data_loader({'train': dataset})['train'] 92 | model.train(True) 93 | for i, input in enumerate(data_loader): 94 | input = collate(input) 95 | input = to_device(input, cfg['device']) 96 | model(input) 97 | return 98 | 99 | 100 | def test(dataset, data_split, label_split, model, logger, epoch): 101 | with torch.no_grad(): 102 | metric = Metric() 103 | model.train(False) 104 | for m in range(cfg['num_users']): 105 | data_loader = make_data_loader({'test': SplitDataset(dataset, data_split[m])})['test'] 106 | for i, input in enumerate(data_loader): 107 | input = collate(input) 108 | input_size = input['img'].size(0) 109 | input['label_split'] = torch.tensor(label_split[m]) 110 | input = to_device(input, cfg['device']) 111 | output = model(input) 112 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss'] 113 | evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], input, output) 114 | logger.append(evaluation, 'test', input_size) 115 | 116 | data_loader = make_data_loader({'test': dataset})['test'] 117 | for i, input in enumerate(data_loader): 118 | input = collate(input) 119 | input_size = input['img'].size(0) 120 | input = to_device(input, cfg['device']) 121 | output = model(input) 122 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss'] 123 | evaluation = metric.evaluate(cfg['metric_name']['test']['Global'], input, output) 124 | logger.append(evaluation, 'test', input_size) 125 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]} 126 | logger.append(info, 'test', mean=False) 127 | logger.write('test', cfg['metric_name']['test']['Local'] + cfg['metric_name']['test']['Global']) 128 | return 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /test_classifier.py: -------------------------------------------------------------------------------- 
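# test_classifier.py mirrors resnet_test.py for the centrally trained
# (non-federated) baseline: it loads the 'best' checkpoint, re-estimates
# BatchNorm statistics by running stats() over the training set with
# track_running_stats=True, then reports per-user Local metrics (with the
# label mask applied) and Global metrics over the full test set.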
1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | 7 | from config import cfg 8 | from data import fetch_dataset, make_data_loader, SplitDataset 9 | from logger import Logger 10 | from metrics import Metric 11 | from utils import save, to_device, process_control, process_dataset, resume, collate 12 | 13 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 14 | cudnn.benchmark = True 15 | parser = argparse.ArgumentParser(description='cfg') 16 | for k in cfg: 17 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k)) 18 | parser.add_argument('--control_name', default=None, type=str) 19 | args = vars(parser.parse_args()) 20 | for k in cfg: 21 | cfg[k] = args[k] 22 | if args['control_name']: 23 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \ 24 | if args['control_name'] != 'None' else {} 25 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']]) 26 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Accuracy']}, 27 | 'test': {'Local': ['Local-Loss', 'Local-Accuracy'], 'Global': ['Global-Loss', 'Global-Accuracy']}} 28 | 29 | 30 | def main(): 31 | process_control() 32 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments'])) 33 | for i in range(cfg['num_experiments']): 34 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']] 35 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x]) 36 | print('Experiment: {}'.format(cfg['model_tag'])) 37 | runExperiment() 38 | return 39 | 40 | 41 | def runExperiment(): 42 | cfg['batch_size']['train'] = cfg['batch_size']['test'] 43 | seed = int(cfg['model_tag'].split('_')[0]) 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed(seed) 46 | dataset = fetch_dataset(cfg['data_name'], cfg['subset']) 47 | process_dataset(dataset) 48 | model = eval('resnet.{}(model_rate=cfg["global_model_rate"], track=True).to(cfg["device"]).to(cfg["device"])' 49 | .format(cfg['model_name'])) 50 | last_epoch, data_split, label_split, model, _, _, _ = resume(model, cfg['model_tag'], load_tag='best', strict=False) 51 | logger_path = 'output/runs/test_{}'.format(cfg['model_tag']) 52 | test_logger = Logger(logger_path) 53 | test_logger.safe(True) 54 | stats(dataset['train'], model) 55 | test(dataset['test'], data_split['test'], label_split, model, test_logger, last_epoch) 56 | test_logger.safe(False) 57 | _, _, _, _, _, _, train_logger = resume(model, cfg['model_tag'], load_tag='checkpoint', strict=False) 58 | save_result = {'cfg': cfg, 'epoch': last_epoch, 'logger': {'train': train_logger, 'test': test_logger}} 59 | save(save_result, './output/result/{}.pt'.format(cfg['model_tag'])) 60 | return 61 | 62 | 63 | def stats(dataset, model): 64 | with torch.no_grad(): 65 | data_loader = make_data_loader({'train': dataset})['train'] 66 | model.train(True) 67 | for i, input in enumerate(data_loader): 68 | input = collate(input) 69 | input = to_device(input, cfg['device']) 70 | model(input) 71 | return 72 | 73 | 74 | def test(dataset, data_split, label_split, model, logger, epoch): 75 | with torch.no_grad(): 76 | metric = Metric() 77 | model.train(False) 78 | for m in range(cfg['num_users']): 79 | data_loader = make_data_loader({'test': SplitDataset(dataset, data_split[m])})['test'] 80 | for i, input in enumerate(data_loader): 81 | input = collate(input) 82 | input_size = input['img'].size(0) 83 | input['label_split'] = 
torch.tensor(label_split[m]) 84 | input = to_device(input, cfg['device']) 85 | output = model(input) 86 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss'] 87 | evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], input, output) 88 | logger.append(evaluation, 'test', input_size) 89 | data_loader = make_data_loader({'test': dataset})['test'] 90 | for i, input in enumerate(data_loader): 91 | input = collate(input) 92 | input_size = input['img'].size(0) 93 | input = to_device(input, cfg['device']) 94 | output = model(input) 95 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss'] 96 | evaluation = metric.evaluate(cfg['metric_name']['test']['Global'], input, output) 97 | logger.append(evaluation, 'test', input_size) 98 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]} 99 | logger.append(info, 'test', mean=False) 100 | logger.write('test', cfg['metric_name']['test']['Local'] + cfg['metric_name']['test']['Global']) 101 | return 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /train_classifier.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import shutil 5 | import time 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | 10 | from config import cfg 11 | from data import fetch_dataset, make_data_loader 12 | from logger import Logger 13 | from metrics import Metric 14 | from utils import save, to_device, process_control, process_dataset, make_optimizer, make_scheduler, resume, collate 15 | 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | parser = argparse.ArgumentParser(description='cfg') 20 | for k in cfg: 21 | exec('parser.add_argument(\'--{0}\', default=cfg[\'{0}\'], type=type(cfg[\'{0}\']))'.format(k)) 22 | parser.add_argument('--control_name', default=None, type=str) 23 | args = vars(parser.parse_args()) 24 | for k in cfg: 25 | cfg[k] = args[k] 26 | if args['control_name']: 27 | cfg['control'] = {k: v for k, v in zip(cfg['control'].keys(), args['control_name'].split('_'))} \ 28 | if args['control_name'] != 'None' else {} 29 | cfg['control_name'] = '_'.join([cfg['control'][k] for k in cfg['control']]) 30 | cfg['pivot_metric'] = 'Global-Accuracy' 31 | cfg['pivot'] = -float('inf') 32 | cfg['metric_name'] = {'train': {'Local': ['Local-Loss', 'Local-Accuracy']}, 33 | 'test': {'Local': ['Local-Loss', 'Local-Accuracy'], 'Global': ['Global-Loss', 'Global-Accuracy']}} 34 | 35 | 36 | def main(): 37 | process_control() 38 | seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments'])) 39 | for i in range(cfg['num_experiments']): 40 | model_tag_list = [str(seeds[i]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']] 41 | cfg['model_tag'] = '_'.join([x for x in model_tag_list if x]) 42 | print('Experiment: {}'.format(cfg['model_tag'])) 43 | runExperiment() 44 | return 45 | 46 | 47 | def runExperiment(): 48 | seed = int(cfg['model_tag'].split('_')[0]) 49 | torch.manual_seed(seed) 50 | torch.cuda.manual_seed(seed) 51 | dataset = fetch_dataset(cfg['data_name'], cfg['subset']) 52 | process_dataset(dataset) 53 | data_loader = make_data_loader(dataset) 54 | model = 
eval('models.{}(model_rate=cfg["global_model_rate"]).to(cfg["device"])'.format(cfg['model_name'])) 55 | optimizer = make_optimizer(model, cfg['lr']) 56 | scheduler = make_scheduler(optimizer) 57 | if cfg['resume_mode'] == 1: 58 | last_epoch, model, optimizer, scheduler, logger = resume(model, cfg['model_tag'], optimizer, scheduler) 59 | elif cfg['resume_mode'] == 2: 60 | last_epoch = 1 61 | _, model, _, _, _ = resume(model, cfg['model_tag']) 62 | logger_path = os.path.join('output', 'runs', '{}'.format(cfg['model_tag'])) 63 | logger = Logger(logger_path) 64 | else: 65 | last_epoch = 1 66 | logger_path = os.path.join('output', 'runs', 'train_{}'.format(cfg['model_tag'])) 67 | logger = Logger(logger_path) 68 | if cfg['world_size'] > 1: 69 | model = torch.nn.DataParallel(model, device_ids=list(range(cfg['world_size']))) 70 | print(cfg['num_epochs']) 71 | for epoch in range(last_epoch, cfg['num_epochs']['global'] + 1): 72 | logger.safe(True) 73 | train(data_loader['train'], model, optimizer, logger, epoch) 74 | test_model = stats(data_loader['train'], model) 75 | test(data_loader['test'], test_model, logger, epoch) 76 | if cfg['scheduler_name'] == 'ReduceLROnPlateau': 77 | scheduler.step(metrics=logger.mean['train/{}'.format(cfg['pivot_metric'])]) 78 | else: 79 | scheduler.step() 80 | logger.safe(False) 81 | model_state_dict = model.module.state_dict() if cfg['world_size'] > 1 else model.state_dict() 82 | save_result = { 83 | 'cfg': cfg, 'epoch': epoch + 1, 'model_dict': model_state_dict, 84 | 'optimizer_dict': optimizer.state_dict(), 'scheduler_dict': scheduler.state_dict(), 85 | 'logger': logger} 86 | save(save_result, './output/model/{}_checkpoint.pt'.format(cfg['model_tag'])) 87 | if cfg['pivot'] < logger.mean['test/{}'.format(cfg['pivot_metric'])]: 88 | cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])] 89 | shutil.copy('./output/model/{}_checkpoint.pt'.format(cfg['model_tag']), 90 | './output/model/{}_best.pt'.format(cfg['model_tag'])) 91 | logger.reset() 92 | logger.safe(False) 93 | return 94 | 95 | 96 | def train(data_loader, model, optimizer, logger, epoch): 97 | metric = Metric() 98 | model.train(True) 99 | start_time = time.time() 100 | for i, input in enumerate(data_loader): 101 | input = collate(input) 102 | input_size = input['img'].size(0) 103 | input = to_device(input, cfg['device']) 104 | optimizer.zero_grad() 105 | output = model(input) 106 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss'] 107 | output['loss'].backward() 108 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 109 | optimizer.step() 110 | evaluation = metric.evaluate(cfg['metric_name']['train'], input, output) 111 | logger.append(evaluation, 'train', n=input_size) 112 | if i % int((len(data_loader) * cfg['log_interval']) + 1) == 0: 113 | batch_time = (time.time() - start_time) / (i + 1) 114 | lr = optimizer.param_groups[0]['lr'] 115 | epoch_finished_time = datetime.timedelta(seconds=round(batch_time * (len(data_loader) - i - 1))) 116 | exp_finished_time = epoch_finished_time + datetime.timedelta( 117 | seconds=round((cfg['num_epochs']['global'] - epoch) * batch_time * len(data_loader))) 118 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 119 | 'Train Epoch: {}({:.0f}%)'.format(epoch, 100. 
* i / len(data_loader)), 120 | 'Learning rate: {}'.format(lr), 'Epoch Finished Time: {}'.format(epoch_finished_time), 121 | 'Experiment Finished Time: {}'.format(exp_finished_time)]} 122 | logger.append(info, 'train', mean=False) 123 | logger.write('train', cfg['metric_name']['train']) 124 | return 125 | 126 | 127 | def stats(data_loader, model): 128 | with torch.no_grad(): 129 | test_model = eval('models.{}(model_rate=cfg["global_model_rate"], track=True).to(cfg["device"])' 130 | .format(cfg['model_name'])) 131 | test_model.load_state_dict(model.state_dict(), strict=False) 132 | test_model.train(True) 133 | for i, input in enumerate(data_loader): 134 | input = collate(input) 135 | input = to_device(input, cfg['device']) 136 | test_model(input) 137 | return test_model 138 | 139 | 140 | def test(data_loader, model, logger, epoch): 141 | with torch.no_grad(): 142 | metric = Metric() 143 | model.train(False) 144 | for i, input in enumerate(data_loader): 145 | input = collate(input) 146 | input_size = input['img'].size(0) 147 | input = to_device(input, cfg['device']) 148 | output = model(input) 149 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss'] 150 | evaluation = metric.evaluate(cfg['metric_name']['test'], input, output) 151 | logger.append(evaluation, 'test', input_size) 152 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]} 153 | logger.append(info, 'test', mean=False) 154 | logger.write('test', cfg['metric_name']['test']) 155 | return 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /transformer_client.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import datetime 3 | import time 4 | 5 | import ray 6 | import torch 7 | 8 | import models 9 | from datasets.lm import StackOverflowClientDataset 10 | from logger import Logger 11 | from metrics import Metric 12 | from utils import make_optimizer, to_device 13 | 14 | 15 | @ray.remote(num_gpus=0.8) 16 | class TransformerClient: 17 | def __init__(self, log_path, cfg): 18 | self.dataset = None 19 | self.local_parameters = None 20 | self.m = None 21 | self.start_time = None 22 | self.num_active_users = None 23 | self.optimizer = None 24 | self.model = None 25 | self.lr = None 26 | self.label_split = None 27 | self.data_loader = None 28 | self.model_rate = None 29 | self.client_id = None 30 | cfg = ray.get(cfg[0]) 31 | self.metric = Metric() 32 | self.logger = Logger(log_path) 33 | self.cfg = cfg 34 | 35 | def update(self, client_id, dataset_ref, model_ref): 36 | dataset = dataset_ref 37 | label_split = dataset_ref 38 | local_parameters = model_ref['local_params'] 39 | self.dataset = StackOverflowClientDataset(dataset, self.cfg['seq_length'], self.cfg['batch_size']['train']) 40 | self.local_parameters = copy.deepcopy(local_parameters) 41 | self.client_id = client_id 42 | self.model_rate = model_ref['model_rate'] 43 | self.label_split = label_split 44 | self.lr = model_ref['lr'] 45 | self.metric = Metric() 46 | 47 | def step(self, m, num_active_users, start_time): 48 | cfg = self.cfg 49 | self.model = models.transformer_nwp(model_rate=self.model_rate, cfg=self.cfg).cpu() 50 | self.model.load_state_dict(self.local_parameters) 51 | self.model = self.model.cuda() 52 | self.model.train(True) 53 | self.optimizer = make_optimizer(self.model, self.lr) 54 | self.m = m 55 | self.num_active_users = num_active_users 56 | 
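# One federated round on this client: rebuild a transformer at the
# client's model_rate, load the server-distributed sub-model weights,
# train for cfg['num_epochs']['local'] epochs with gradients clipped to
# norm 1, then hand the updated weights back via pull(), which moves the
# state dict to CPU and drops the local model reference so the Ray actor
# frees GPU memory between rounds.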
self.start_time = start_time 57 | for local_epoch in range(1, cfg['num_epochs']['local'] + 1): 58 | for i, step_input in enumerate(self.dataset): 59 | input_size = step_input['label'].size(0) 60 | # step_input['label_split'] = None 61 | step_input = to_device(step_input, cfg['device']) 62 | self.optimizer.zero_grad() 63 | output = self.model(step_input) 64 | output['loss'].backward() 65 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) 66 | self.optimizer.step() 67 | evaluation = self.metric.evaluate(cfg['metric_name']['train']['Local'], step_input, output) 68 | self.logger.append(evaluation, 'train', n=input_size) 69 | self.log(local_epoch, cfg) 70 | return self.pull() 71 | 72 | def pull(self): 73 | model_state = {k: v.detach().clone() for k, v in self.model.cpu().state_dict().items()} 74 | self.model = None 75 | self.local_parameters = None 76 | return model_state 77 | 78 | def log(self, epoch, cfg): 79 | if self.m % int((self.num_active_users * cfg['log_interval']) + 1) == 0 or True: 80 | local_time = (time.time() - self.start_time) / (self.m + 1) 81 | epoch_finished_time = datetime.timedelta(seconds=local_time * (self.num_active_users - self.m - 1)) 82 | exp_finished_time = epoch_finished_time + datetime.timedelta( 83 | seconds=round((cfg['num_epochs']['global'] - epoch) * local_time * self.num_active_users)) 84 | info = {'info': ['Model: {}'.format(cfg['model_tag']), 85 | 'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * self.m / self.num_active_users), 86 | 'ID: {}({}/{})'.format(self.client_id, self.m + 1, self.num_active_users), 87 | 'Learning rate: {}'.format(self.lr), 88 | 'Rate: {}'.format(self.model_rate), 89 | 'Epoch Finished Time: {}'.format(epoch_finished_time), 90 | 'Experiment Finished Time: {}'.format(exp_finished_time)]} 91 | self.logger.append(info, 'train', mean=False) 92 | self.logger.write('train', cfg['metric_name']['train']['Local']) 93 | 94 | def test_model_for_user(self, m, ids): 95 | cfg = self.cfg 96 | metric = Metric() 97 | [dataset, model] = ray.get(ids) 98 | dataset = StackOverflowClientDataset(dataset, self.cfg['seq_length'], self.cfg['batch_size']['test']) 99 | model = model.to('cuda') 100 | results = [] 101 | for _, data_input in enumerate(dataset): 102 | input_size = data_input['label'].size(0) 103 | data_input = to_device(data_input, 'cuda') 104 | output = model(data_input) 105 | output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss'] 106 | s = output['score'].shape 107 | output['score'] = output['score'].permute((0, 2, 1)).reshape((s[0] * s[2], -1)) 108 | data_input['label'] = data_input['label'].reshape((-1,)) 109 | evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], data_input, output) 110 | results.append((evaluation, input_size)) 111 | return results 112 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc as container_abcs 2 | import errno 3 | import os 4 | from itertools import repeat 5 | 6 | import numpy as np 7 | import torch 8 | import torch.optim as optim 9 | from torchvision.utils import save_image 10 | 11 | from config import cfg 12 | 13 | 14 | def check_exists(path): 15 | return os.path.exists(path) 16 | 17 | 18 | def makedir_exist_ok(path): 19 | try: 20 | os.makedirs(path) 21 | except OSError as e: 22 | if e.errno == errno.EEXIST: 23 | pass 24 | else: 25 | raise 26 | return 27 | 28 | 29 | def save(input, path, protocol=2, mode='torch'): 30 | 
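# save() creates the parent directory if needed, then serializes with
# torch.save (pickle_protocol=protocol) or np.save depending on mode;
# load() below is its counterpart and maps tensors back to CPU storage.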
30 | dirname = os.path.dirname(path)
31 | makedir_exist_ok(dirname)
32 | if mode == 'torch':
33 | torch.save(input, path, pickle_protocol=protocol)
34 | elif mode == 'numpy':
35 | np.save(path, input, allow_pickle=True)
36 | else:
37 | raise ValueError('Not valid save mode')
38 | return
39 | 
40 | 
41 | def load(path, mode='torch'):
42 | if mode == 'torch':
43 | return torch.load(path, map_location=lambda storage, loc: storage)
44 | elif mode == 'numpy':
45 | return np.load(path, allow_pickle=True)
46 | else:
47 | raise ValueError('Not valid load mode')
48 | return
49 | 
50 | 
51 | def save_img(img, path, nrow=10, padding=2, pad_value=0, range=None):  # 'range' shadows the builtin and maps to torchvision's legacy save_image kwarg
52 | makedir_exist_ok(os.path.dirname(path))
53 | normalize = False if range is None else True
54 | save_image(img, path, nrow=nrow, padding=padding, pad_value=pad_value, normalize=normalize, range=range)
55 | return
56 | 
57 | 
58 | def to_device(input, device):
59 | output = recur(lambda x, y: x.to(y), input, device)
60 | return output
61 | 
62 | 
63 | def ntuple(n):
64 | def parse(x):
65 | if isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
66 | return x
67 | return tuple(repeat(x, n))
68 | 
69 | return parse
70 | 
71 | 
72 | def apply_fn(module, fn):
73 | for n, m in module.named_children():
74 | if hasattr(m, fn):
75 | getattr(m, fn)()
76 | if sum(1 for _ in m.named_children()) != 0:
77 | apply_fn(m, fn)
78 | return
79 | 
80 | 
81 | def recur(fn, input, *args):
82 | if isinstance(input, torch.Tensor) or isinstance(input, np.ndarray):
83 | output = fn(input, *args)
84 | elif isinstance(input, list):
85 | output = []
86 | for i in range(len(input)):
87 | output.append(recur(fn, input[i], *args))
88 | elif isinstance(input, tuple):
89 | output = []
90 | for i in range(len(input)):
91 | output.append(recur(fn, input[i], *args))
92 | output = tuple(output)
93 | elif isinstance(input, dict):
94 | output = {}
95 | for key in input:
96 | output[key] = recur(fn, input[key], *args)
97 | else:
98 | raise ValueError('Not valid input type')
99 | return output
100 | 
101 | 
102 | def process_dataset(dataset):
103 | if cfg['data_name'] in ['MNIST', 'CIFAR10', 'CIFAR100']:
104 | cfg['classes_size'] = dataset['train'].classes_size
105 | elif cfg['data_name'] in ['WikiText2', 'WikiText103', 'PennTreebank']:
106 | cfg['vocab'] = dataset['train'].vocab
107 | cfg['num_tokens'] = len(dataset['train'].vocab)
108 | for split in dataset:
109 | dataset[split] = batchify(dataset[split], cfg['batch_size'][split])
110 | elif cfg['data_name'] in ['Stackoverflow']:
111 | # cfg['vocab'] = dataset['vocab']
112 | cfg['num_tokens'] = len(dataset['vocab'])
113 | elif cfg['data_name'] in ['gld']:
114 | cfg['classes_size'] = 2028
115 | else:
116 | raise ValueError('Not valid data name')
117 | return
118 | 
119 | 
120 | def process_control():
121 | cfg['model_split_rate'] = {'a': 1, 'b': 0.5, 'c': 0.25, 'd': 0.125, 'e': 0.0625}
122 | cfg['fed'] = int(cfg['control']['fed'])
123 | cfg['num_users'] = int(cfg['control']['num_users'])
124 | cfg['frac'] = float(cfg['control']['frac'])
125 | cfg['data_split_mode'] = cfg['control']['data_split_mode']
126 | cfg['model_split_mode'] = cfg['control']['model_split_mode']
127 | cfg['model_mode'] = cfg['control']['model_mode']
128 | cfg['norm'] = cfg['control']['norm']
129 | cfg['scale'] = bool(int(cfg['control']['scale']))
130 | cfg['mask'] = bool(int(cfg['control']['mask']))
131 | cfg['global_model_mode'] = cfg['model_mode'][0]
132 | cfg['global_model_rate'] = cfg['model_split_rate'][cfg['global_model_mode']]
133 | model_mode = cfg['model_mode'].split('-')  # e.g. 'a1-b1': each token is a capacity letter followed by a client proportion
134 | if cfg['model_split_mode'] == 'dynamic':
135 | mode_rate, proportion = [], []
136 | for m in model_mode:
137 | mode_rate.append(cfg['model_split_rate'][m[0]])
138 | proportion.append(int(m[1:]))
139 | cfg['model_rate'] = mode_rate
140 | cfg['proportion'] = (np.array(proportion) / sum(proportion)).tolist()
141 | elif cfg['model_split_mode'] == 'fix':
142 | mode_rate, proportion = [], []
143 | for m in model_mode:
144 | mode_rate.append(cfg['model_split_rate'][m[0]])
145 | proportion.append(int(m[1:]))
146 | num_users_proportion = cfg['num_users'] // sum(proportion)
147 | cfg['model_rate'] = []
148 | for i in range(len(mode_rate)):
149 | cfg['model_rate'] += np.repeat(mode_rate[i], num_users_proportion * proportion[i]).tolist()
150 | cfg['model_rate'] = cfg['model_rate'] + [cfg['model_rate'][-1] for _ in
151 | range(cfg['num_users'] - len(cfg['model_rate']))]
152 | else:
153 | raise ValueError('Not valid model split mode')
154 | cfg['conv'] = {'hidden_size': [64, 128, 256, 512]}
155 | cfg['resnet'] = {'hidden_size': [64, 128, 256, 512]}
156 | cfg['transformer'] = {'embedding_size': 128,
157 | 'num_heads': 8,
158 | 'hidden_size': 2048,
159 | 'num_layers': 3,
160 | 'dropout': 0.1}
161 | if cfg['data_name'] in ['MNIST']:
162 | cfg['data_shape'] = [1, 28, 28]
163 | cfg['optimizer_name'] = 'SGD'
164 | # cfg['lr'] = 1e-2
165 | cfg['momentum'] = 0.9
166 | cfg['weight_decay'] = 5e-4
167 | cfg['scheduler_name'] = 'MultiStepLR'
168 | cfg['factor'] = 0.1
169 | if cfg['data_split_mode'] == 'iid':
170 | # cfg['num_epochs'] = {'global': 200, 'local': 5}
171 | cfg['batch_size'] = {'train': 24, 'test': 100}
172 | # cfg['milestones'] = [100]
173 | elif 'non-iid' in cfg['data_split_mode']:
174 | # cfg['num_epochs'] = {'global': 400, 'local': 5}
175 | cfg['batch_size'] = {'train': 10, 'test': 100}
176 | cfg['milestones'] = [200]
177 | elif cfg['data_split_mode'] == 'none':
178 | cfg['num_epochs'] = 200
179 | cfg['batch_size'] = {'train': 100, 'test': 500}
180 | # cfg['milestones'] = [100]
181 | else:
182 | raise ValueError('Not valid data_split_mode')
183 | elif cfg['data_name'] in ['CIFAR10', 'CIFAR100']:
184 | cfg['data_shape'] = [3, 32, 32]
185 | cfg['optimizer_name'] = 'SGD'
186 | # cfg['lr'] = 1e-4
187 | cfg['momentum'] = 0.9
188 | cfg['min_lr'] = 1e-4
189 | cfg['weight_decay'] = 1e-3
190 | cfg['scheduler_name'] = 'MultiStepLR'
191 | cfg['factor'] = 0.25
192 | if cfg['data_split_mode'] == 'iid':
193 | # cfg['num_epochs'] = {'global': 2500, 'local': 1}
194 | cfg['batch_size'] = {'train': 10, 'test': 100}
195 | # cfg['milestones'] = [1000, 1500, 2000]
196 | elif 'non-iid' in cfg['data_split_mode']:
197 | # cfg['num_epochs'] = {'global': 2500, 'local': 1}
198 | cfg['batch_size'] = {'train': 10, 'test': 100}
199 | # cfg['milestones'] = [1000, 1500, 2000]
200 | elif cfg['data_split_mode'] == 'none':
201 | cfg['num_epochs'] = 400
202 | cfg['batch_size'] = {'train': 100, 'test': 500}
203 | # cfg['milestones'] = [150, 250]
204 | else:
205 | raise ValueError('Not valid data_split_mode')
206 | elif cfg['data_name'] in ['gld']:
207 | cfg['data_shape'] = [3, 92, 92]
208 | cfg['optimizer_name'] = 'SGD'
209 | # cfg['lr'] = 1e-4
210 | cfg['num_users'] = 1262
211 | cfg['active_user'] = 80
212 | cfg['momentum'] = 0.9
213 | cfg['min_lr'] = 5e-4
214 | cfg['weight_decay'] = 1e-3
215 | cfg['scheduler_name'] = 'MultiStepLR'
216 | cfg['factor'] = 0.1
217 | if cfg['data_split_mode'] == 'iid':
218 | # cfg['num_epochs'] = {'global': 2500, 'local': 1}
219 | cfg['batch_size'] = {'train': 32, 'test': 50}
220 | # cfg['milestones'] = [1000, 1500, 2000]
221 | elif 'non-iid' in cfg['data_split_mode']:
222 | # cfg['num_epochs'] = {'global': 2500, 'local': 1}
223 | cfg['batch_size'] = {'train': 32, 'test': 50}
224 | # cfg['milestones'] = [1000, 1500, 2000]
225 | elif cfg['data_split_mode'] == 'none':
226 | cfg['num_epochs'] = 400
227 | cfg['batch_size'] = {'train': 100, 'test': 500}
228 | # cfg['milestones'] = [150, 250]
229 | else:
230 | raise ValueError('Not valid data_split_mode')
231 | elif cfg['data_name'] in ['PennTreebank', 'WikiText2', 'WikiText103']:
232 | cfg['optimizer_name'] = 'SGD'
233 | # cfg['lr'] = 1e-2
234 | cfg['momentum'] = 0.9
235 | cfg['weight_decay'] = 5e-4
236 | cfg['scheduler_name'] = 'MultiStepLR'
237 | cfg['factor'] = 0.1
238 | cfg['bptt'] = 64
239 | cfg['mask_rate'] = 0.15
240 | if cfg['data_split_mode'] == 'iid':
241 | # cfg['num_epochs'] = {'global': 200, 'local': 3}
242 | cfg['batch_size'] = {'train': 100, 'test': 10}
243 | cfg['milestones'] = [50, 100]
244 | elif cfg['data_split_mode'] == 'none':
245 | cfg['num_epochs'] = 100
246 | cfg['batch_size'] = {'train': 100, 'test': 100}
247 | # cfg['milestones'] = [25, 50]
248 | else:
249 | raise ValueError('Not valid data_split_mode')
250 | elif cfg['data_name'] in ['Stackoverflow']:
251 | cfg['optimizer_name'] = 'SGD'
252 | cfg['num_users'] = 342477
253 | cfg['active_user'] = 50
254 | cfg['momentum'] = 0.9
255 | cfg['weight_decay'] = 5e-4
256 | cfg['scheduler_name'] = 'MultiStepLR'
257 | cfg['factor'] = 0.1
258 | cfg['bptt'] = 64
259 | cfg['batch_size'] = {'train': 24, 'test': 24}
260 | cfg['mask_rate'] = 0.15
261 | 
262 | cfg['seq_length'] = 21
263 | else:
264 | raise ValueError('Not valid dataset')
265 | return
266 | 
267 | 
268 | def make_stats(dataset):
269 | if os.path.exists('./data/stats/{}.pt'.format(dataset.data_name)):
270 | stats = load('./data/stats/{}.pt'.format(dataset.data_name))
271 | elif dataset is not None:
272 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=False, num_workers=0)
273 | stats = Stats(dim=1)
274 | with torch.no_grad():
275 | for input in data_loader:
276 | stats.update(input['img'])
277 | save(stats, './data/stats/{}.pt'.format(dataset.data_name))
278 | return stats
279 | 
280 | 
281 | class Stats(object):
282 | def __init__(self, dim):
283 | self.dim = dim
284 | self.n_samples = 0
285 | self.n_features = None
286 | self.mean = None
287 | self.std = None
288 | 
289 | def update(self, data):  # streaming update of per-feature mean/std via pooled batch statistics
290 | data = data.transpose(self.dim, -1).reshape(-1, data.size(self.dim))
291 | if self.n_samples == 0:
292 | self.n_samples = data.size(0)
293 | self.n_features = data.size(1)
294 | self.mean = data.mean(dim=0)
295 | self.std = data.std(dim=0)
296 | else:
297 | m = float(self.n_samples)
298 | n = data.size(0)
299 | new_mean = data.mean(dim=0)
300 | new_std = 0 if n == 1 else data.std(dim=0)
301 | old_mean = self.mean
302 | old_std = self.std
303 | self.mean = m / (m + n) * old_mean + n / (m + n) * new_mean
304 | self.std = torch.sqrt(m / (m + n) * old_std ** 2 + n / (m + n) * new_std ** 2 + m * n / (m + n) ** 2 * (
305 | old_mean - new_mean) ** 2)
306 | self.n_samples += n
307 | return
308 | 
309 | 
310 | def make_optimizer(model, lr):
311 | if cfg['optimizer_name'] == 'SGD':
312 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=cfg['momentum'],
313 | weight_decay=cfg['weight_decay'])
314 | elif cfg['optimizer_name'] == 'RMSprop':
315 | optimizer = optim.RMSprop(model.parameters(), lr=lr, momentum=cfg['momentum'],
316 | weight_decay=cfg['weight_decay'])
317 | elif cfg['optimizer_name'] == 'Adam':
318 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=cfg['weight_decay'])
319 | elif cfg['optimizer_name'] == 'Adamax':
320 | optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=cfg['weight_decay'])
321 | else:
322 | raise ValueError('Not valid optimizer name')
323 | return optimizer
324 | 
325 | 
326 | def make_scheduler(optimizer):
327 | if cfg['scheduler_name'] == 'None':
328 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[65535])  # milestone is never reached, so the LR stays constant
329 | elif cfg['scheduler_name'] == 'StepLR':
330 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=cfg['step_size'], gamma=cfg['factor'])
331 | elif cfg['scheduler_name'] == 'MultiStepLR':
332 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg['milestones'], gamma=cfg['factor'])
333 | elif cfg['scheduler_name'] == 'ExponentialLR':
334 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
335 | elif cfg['scheduler_name'] == 'CosineAnnealingLR':
336 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['num_epochs']['global'],
337 | eta_min=cfg['min_lr'])
338 | elif cfg['scheduler_name'] == 'ReduceLROnPlateau':
339 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=cfg['factor'],
340 | patience=cfg['patience'], verbose=True,
341 | threshold=cfg['threshold'], threshold_mode='rel',
342 | min_lr=cfg['min_lr'])
343 | elif cfg['scheduler_name'] == 'CyclicLR':
344 | scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=cfg['lr'], max_lr=10 * cfg['lr'])
345 | else:
346 | raise ValueError('Not valid scheduler name')
347 | return scheduler
348 | 
349 | 
350 | def resume(model, model_tag, optimizer=None, scheduler=None, load_tag='checkpoint', strict=True, verbose=True):
351 | if cfg['data_split_mode'] != 'none':
352 | if os.path.exists('./output/model/{}_{}.pt'.format(model_tag, load_tag)):
353 | checkpoint = load('./output/model/{}_{}.pt'.format(model_tag, load_tag))
354 | last_epoch = checkpoint['epoch']
355 | data_split = checkpoint['data_split']
356 | label_split = checkpoint['label_split']
357 | model.load_state_dict(checkpoint['model_dict'], strict=strict)
358 | if optimizer is not None:
359 | optimizer.load_state_dict(checkpoint['optimizer_dict'])
360 | if scheduler is not None:
361 | scheduler.load_state_dict(checkpoint['scheduler_dict'])
362 | logger = checkpoint['logger']
363 | if verbose:
364 | print('Resume from {}'.format(last_epoch))
365 | else:
366 | print('Model tag {} does not exist, starting from scratch'.format(model_tag))
367 | from datetime import datetime
368 | from logger import Logger
369 | last_epoch = 1
370 | data_split = None
371 | label_split = None
372 | logger_path = 'output/runs/train_{}_{}'.format(cfg['model_tag'], datetime.now().strftime('%b%d_%H-%M-%S'))
373 | logger = Logger(logger_path)
374 | return last_epoch, data_split, label_split, model, optimizer, scheduler, logger
375 | else:
376 | if os.path.exists('./output/model/{}_{}.pt'.format(model_tag, load_tag)):
377 | checkpoint = load('./output/model/{}_{}.pt'.format(model_tag, load_tag))
378 | last_epoch = checkpoint['epoch']
379 | model.load_state_dict(checkpoint['model_dict'], strict=strict)
380 | if optimizer is not None:
381 | optimizer.load_state_dict(checkpoint['optimizer_dict'])
382 | if scheduler is not None:
383 | scheduler.load_state_dict(checkpoint['scheduler_dict'])
384 | logger = checkpoint['logger']
385 | if verbose:
386 | print('Resume from {}'.format(last_epoch))
387 | else:
388 | print('Model tag {} does not exist, starting from scratch'.format(model_tag))
389 | from datetime import datetime
390 | from logger import Logger
391 | last_epoch = 1
392 | logger_path = 'output/runs/train_{}_{}'.format(cfg['model_tag'], datetime.now().strftime('%b%d_%H-%M-%S'))
393 | logger = Logger(logger_path)
394 | return last_epoch, model, optimizer, scheduler, logger
395 | 
396 | 
397 | def collate(input):
398 | if 'label' in input:
399 | input['label'] = [torch.tensor(i) for i in input['label']]
400 | for k in input:
401 | input[k] = torch.stack(input[k], 0)
402 | return input
403 | 
404 | 
405 | def batchify(dataset, batch_size):
406 | num_batch = len(dataset) // batch_size  # drop the tail so the tokens split evenly into batch_size rows
407 | dataset.token = dataset.token.narrow(0, 0, num_batch * batch_size)
408 | dataset.token = dataset.token.reshape(batch_size, -1)
409 | return dataset
410 | 
--------------------------------------------------------------------------------
/weighted_server.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import OrderedDict
3 | 
4 | import numpy as np
5 | import ray
6 | import torch
7 | 
8 | 
9 | class WeightedServer:
10 | def __init__(self, global_model, rate, dataset_ref, cfg_id):
11 | self.tau = 1e-2
12 | self.v_t = None
13 | self.beta_1 = 0.9
14 | self.beta_2 = 0.99
15 | self.eta = 1e-2  # server step size for the Adam-style update in combine()
16 | self.m_t = None
17 | self.user_idx = None
18 | self.param_idx = None
19 | self.dataset_ref = dataset_ref
20 | self.cfg_id = cfg_id
21 | self.cfg = ray.get(cfg_id)
22 | self.global_model = global_model
23 | self.global_parameters = global_model.state_dict()
24 | self.rate = rate
25 | self.label_split = ray.get(dataset_ref['label_split'])
26 | self.make_model_rate()
27 | self.num_model_partitions = 50
28 | self.model_idxs = {}
29 | self.rounds = 0
30 | self.tmp_counts = {}
31 | for k, v in self.global_parameters.items():
32 | self.tmp_counts[k] = torch.ones_like(v)
33 | 
34 | for k, v in self.global_parameters.items():
35 | if 'conv1' in k or 'conv2' in k:
36 | output_size = v.size(0)
37 | self.model_idxs[k] = [torch.randperm(output_size, device=v.device) for _ in range(
38 | self.num_model_partitions)]
39 | 
40 | def step(self, local_parameters):
41 | self.combine(local_parameters, self.param_idx, self.user_idx)
42 | self.rounds += 1
43 | 
44 | def broadcast(self, local, lr):
45 | cfg = self.cfg
46 | self.global_model.train(True)
47 | num_active_users = int(np.ceil(cfg['frac'] * cfg['num_users']))
48 | self.user_idx = copy.deepcopy(torch.arange(cfg['num_users'])
49 | [torch.randperm(cfg['num_users'])
50 | [:num_active_users]].tolist())
51 | local_parameters, self.param_idx = self.distribute(self.user_idx)
52 | # [torch.save(local_parameters[m], f'local_param_{m}') for m in range(len(local_parameters))]
53 | # local_parameters = [{k: v.cpu().numpy() for k, v in p.items()} for p in local_parameters]
54 | 
55 | param_ids = [ray.put(local_parameter) for local_parameter in local_parameters]
56 | # ([client.update(self.user_idx[m],
57 | # self.dataset_ref,
58 | # {'lr': lr,
59 | # 'model_rate': self.model_rate[self.user_idx[m]],
60 | # 'local_params': param_ids[m]}) for m, client in enumerate(
61 | # local)])
62 | 
63 | ray.get([client.update.remote(self.user_idx[m],
64 | self.dataset_ref,
65 | {'lr': lr,
66 | 'model_rate': self.model_rate[self.user_idx[m]],
67 | 'local_params': param_ids[m]})
68 | for m, client in enumerate(local)])
69 | return local
70 | 
71 | def make_model_rate(self):
72 | cfg = self.cfg
73 | if cfg['model_split_mode'] == 'dynamic':
74 | rate_idx = torch.multinomial(torch.tensor(cfg['proportion']), num_samples=cfg['num_users'],
75 | replacement=True).tolist()
76 | self.model_rate = np.array(self.rate)[rate_idx]
77 | elif cfg['model_split_mode'] == 'fix':
78 | self.model_rate = np.array(self.rate)
79 | else:
80 | raise ValueError('Not valid model split mode')
81 | return
82 | 
83 | def split_model(self, user_idx):
84 | cfg = self.cfg
85 | idx_i = [None for _ in range(len(user_idx))]
86 | idx = [OrderedDict() for _ in range(len(user_idx))]
87 | for k, v in self.global_parameters.items():
88 | parameter_type = k.split('.')[-1]
89 | for m in range(len(user_idx)):
90 | if 'weight' in parameter_type or 'bias' in parameter_type:
91 | if parameter_type == 'weight':
92 | if v.dim() > 1:
93 | input_size = v.size(1)
94 | output_size = v.size(0)
95 | if 'conv1' in k or 'conv2' in k:
96 | if idx_i[m] is None:
97 | idx_i[m] = torch.arange(input_size, device=v.device)
98 | input_idx_i_m = idx_i[m]
99 | scaler_rate = self.model_rate[user_idx[m]] / cfg['global_model_rate']
100 | local_output_size = int(np.ceil(output_size * scaler_rate))
101 | # model_idx = self.model_idxs[k][m % self.num_model_partitions]
102 | # output_idx_i_m = model_idx[:local_output_size]
103 | roll = self.rounds % output_size  # rolling window: shift the channel window by one position every global round
104 | # model_idx = self.model_idxs[k][self.rounds % self.num_model_partitions]
105 | model_idx = torch.arange(output_size, device=v.device)
106 | model_idx = torch.roll(model_idx, roll, -1)
107 | output_idx_i_m = model_idx[:local_output_size]  # each client trains a contiguous (wrap-around) slice of channels
108 | idx_i[m] = output_idx_i_m
109 | elif 'shortcut' in k:
110 | input_idx_i_m = idx[m][k.replace('shortcut', 'conv1')][1]
111 | output_idx_i_m = idx_i[m]
112 | elif 'linear' in k:
113 | input_idx_i_m = idx_i[m]
114 | output_idx_i_m = torch.arange(output_size, device=v.device)
115 | else:
116 | raise ValueError('Not valid k: {}'.format(k))
117 | idx[m][k] = (output_idx_i_m, input_idx_i_m)
118 | else:
119 | input_idx_i_m = idx_i[m]
120 | idx[m][k] = input_idx_i_m
121 | else:
122 | input_size = v.size(0)
123 | if 'linear' in k:
124 | input_idx_i_m = torch.arange(input_size, device=v.device)
125 | idx[m][k] = input_idx_i_m
126 | else:
127 | input_idx_i_m = idx_i[m]
128 | idx[m][k] = input_idx_i_m
129 | else:
130 | pass
131 | 
132 | return idx
133 | 
134 | def distribute(self, user_idx):
135 | self.make_model_rate()
136 | param_idx = self.split_model(user_idx)
137 | local_parameters = [OrderedDict() for _ in range(len(user_idx))]
138 | for k, v in self.global_parameters.items():
139 | parameter_type = k.split('.')[-1]
140 | for m in range(len(user_idx)):
141 | if 'weight' in parameter_type or 'bias' in parameter_type:
142 | if 'weight' in parameter_type:
143 | if v.dim() > 1:
144 | local_parameters[m][k] = copy.deepcopy(v[torch.meshgrid(param_idx[m][k])])
145 | else:
146 | local_parameters[m][k] = copy.deepcopy(v[param_idx[m][k]])
147 | else:
148 | local_parameters[m][k] = copy.deepcopy(v[param_idx[m][k]])
149 | else:
150 | local_parameters[m][k] = copy.deepcopy(v)
151 | return local_parameters, param_idx
152 | 
153 | def combine(self, local_parameters, param_idx, user_idx):
154 | count = OrderedDict()
155 | self.global_parameters = self.global_model.state_dict()
156 | updated_parameters = copy.deepcopy(self.global_parameters)
157 | tmp_counts_cpy = copy.deepcopy(self.tmp_counts)
158 | for k, v in updated_parameters.items():
159 | parameter_type = k.split('.')[-1]
160 | count[k] = v.new_zeros(v.size(), dtype=torch.float32)
161 | tmp_v = v.new_zeros(v.size(), dtype=torch.float32)
162 | for m in range(len(local_parameters)):
163 | if 'weight' in parameter_type or 'bias' in parameter_type:
164 | if parameter_type == 'weight':
165 | if v.dim() > 1:
166 | if 'linear' in k:
167 | label_split = self.label_split[user_idx[m]]
168 | param_idx[m][k] = list(param_idx[m][k])
169 | param_idx[m][k][0] = param_idx[m][k][0][label_split]
170 | tmp_v[torch.meshgrid(param_idx[m][k])] += self.tmp_counts[k][torch.meshgrid(
171 | param_idx[m][k])] * local_parameters[m][k][label_split]
172 | count[k][torch.meshgrid(param_idx[m][k])] += self.tmp_counts[k][torch.meshgrid(
173 | param_idx[m][k])]
174 | tmp_counts_cpy[k][torch.meshgrid(param_idx[m][k])] += 1
175 | else:
176 | output_size = v.size(0)
177 | scaler_rate = self.model_rate[user_idx[m]] / self.cfg['global_model_rate']
178 | local_output_size = int(np.ceil(output_size * scaler_rate))
179 | # K = self.tmp_counts[k][torch.meshgrid(param_idx[m][k])]
180 | # K = local_output_size
181 | K = local_output_size * self.tmp_counts[k][torch.meshgrid(param_idx[m][k])]  # weight each client update by its sub-model width and past counts
182 | # K = 1
183 | tmp_v[torch.meshgrid(param_idx[m][k])] += K * local_parameters[m][k]
184 | count[k][torch.meshgrid(param_idx[m][k])] += K
185 | tmp_counts_cpy[k][torch.meshgrid(param_idx[m][k])] += 1
186 | else:
187 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k]
188 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]]
189 | tmp_counts_cpy[k][param_idx[m][k]] += 1
190 | else:
191 | if 'linear' in k:
192 | label_split = self.label_split[user_idx[m]]
193 | param_idx[m][k] = param_idx[m][k][label_split]
194 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k][
195 | label_split]
196 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]]
197 | tmp_counts_cpy[k][param_idx[m][k]] += 1
198 | else:
199 | tmp_v[param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]] * local_parameters[m][k]
200 | count[k][param_idx[m][k]] += self.tmp_counts[k][param_idx[m][k]]
201 | tmp_counts_cpy[k][param_idx[m][k]] += 1
202 | else:
203 | tmp_v += self.tmp_counts[k] * local_parameters[m][k]
204 | count[k] += self.tmp_counts[k]
205 | tmp_counts_cpy[k] += 1
206 | tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0])
207 | v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype)
208 | self.tmp_counts = tmp_counts_cpy
209 | 
210 | delta_t = {k: v - self.global_parameters[k] for k, v in updated_parameters.items()}  # aggregated pseudo-gradient
211 | if self.rounds in self.cfg['milestones']:
212 | self.eta *= 0.5  # halve the server step size at each milestone
213 | if not self.m_t or self.rounds in self.cfg['milestones']:
214 | self.m_t = {k: torch.zeros_like(x) for k, x in delta_t.items()}
215 | self.m_t = {
216 | k: self.beta_1 * self.m_t[k] + (1 - self.beta_1) * delta_t[k] for k in delta_t.keys()
217 | }
218 | if not self.v_t or self.rounds in self.cfg['milestones']:
219 | self.v_t = {k: torch.zeros_like(x) for k, x in delta_t.items()}
220 | self.v_t = {
221 | k: self.beta_2 * self.v_t[k] + (1 - self.beta_2) * torch.multiply(delta_t[k], delta_t[k])
222 | for k in delta_t.keys()
223 | }
224 | self.global_parameters = {  # Adam-style adaptive server update applied to the pseudo-gradient
225 | k: self.global_parameters[k] + self.eta * self.m_t[k] / (torch.sqrt(self.v_t[k]) + self.tau)
226 | for k in self.global_parameters.keys()
227 | }
228 | # if not self.m_t:
229 | # self.m_t = {k: torch.zeros_like(x) for k, x in delta_t.items()}
230 | # self.m_t = {
231 | # k: self.beta_1 * self.m_t[k] + (1 - self.beta_1) * delta_t[k] for k in delta_t.keys()
232 | # }
233 | # if not self.v_t:
234 | # self.v_t = {k: torch.zeros_like(x) for k, x in delta_t.items()}
235 | # self.v_t = {
236 | # k: self.v_t[k] + torch.multiply(delta_t[k], delta_t[k])
237 | # for k in delta_t.keys()
238 | # }
239 | # self.global_parameters = {
240 | # k: self.global_parameters[k] + self.eta * self.m_t[k] / (torch.sqrt(self.v_t[k]) + self.tau)
241 | # for k in self.global_parameters.keys()
242 | # }
243 | # self.global_parameters = updated_parameters
244 | self.global_model.load_state_dict(self.global_parameters)
245 | return
246 | 
--------------------------------------------------------------------------------
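
A note on the sub-model extraction in split_model above: each conv layer's output channels are taken as a contiguous, wrap-around window whose start shifts by one position (mod the layer width) every global round, so all channels of the full model are trained at roughly the same frequency across rounds. The following minimal sketch isolates just that index selection; the standalone function rolling_channel_window and its signature are illustrative only, distilled from the scaler_rate / torch.roll logic in weighted_server.py.

import numpy as np
import torch


def rolling_channel_window(output_size, model_rate, global_model_rate, rounds):
    # width of the sub-model this client can hold at its capacity rate
    scaler_rate = model_rate / global_model_rate
    local_output_size = int(np.ceil(output_size * scaler_rate))
    # rotate the channel indices by one position per global round (with wrap-around)
    model_idx = torch.roll(torch.arange(output_size), rounds % output_size, -1)
    return model_idx[:local_output_size]


# A quarter-capacity client over an 8-channel layer:
# round 0 -> tensor([0, 1]), round 1 -> tensor([7, 0]), round 2 -> tensor([6, 7]), ...
for r in range(3):
    print(rolling_channel_window(8, 0.25, 1.0, r))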