├── .gitignore ├── FedCP.pdf ├── LICENSE ├── README.md ├── dataset └── mnist-0.1-npz │ ├── config.json │ ├── test │ ├── test0_.npz │ ├── test10_.npz │ ├── test11_.npz │ ├── test12_.npz │ ├── test13_.npz │ ├── test14_.npz │ ├── test15_.npz │ ├── test16_.npz │ ├── test17_.npz │ ├── test18_.npz │ ├── test19_.npz │ ├── test1_.npz │ ├── test2_.npz │ ├── test3_.npz │ ├── test4_.npz │ ├── test5_.npz │ ├── test6_.npz │ ├── test7_.npz │ ├── test8_.npz │ └── test9_.npz │ └── train │ ├── train0_.npz │ ├── train10_.npz │ ├── train11_.npz │ ├── train12_.npz │ ├── train13_.npz │ ├── train14_.npz │ ├── train15_.npz │ ├── train16_.npz │ ├── train17_.npz │ ├── train18_.npz │ ├── train19_.npz │ ├── train1_.npz │ ├── train2_.npz │ ├── train3_.npz │ ├── train4_.npz │ ├── train5_.npz │ ├── train6_.npz │ ├── train7_.npz │ ├── train8_.npz │ └── train9_.npz ├── figs ├── CPN.png ├── example.png └── feature_separation.png └── system ├── env_linux.yaml ├── flcore ├── clients │ └── clientcp.py ├── servers │ └── servercp.py └── trainmodel │ └── models.py ├── main.py ├── run_me.sh └── utils ├── data_utils.py └── mem_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /FedCP.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/FedCP.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This is the implementation of our paper [*FedCP: Separating Feature Information for Personalized Federated Learning via Conditional Policy*](https://arxiv.org/pdf/2307.01217v2.pdf) (accepted by KDD 2023). 
4 | 5 | - [Oral PPT](./FedCP.pdf) 6 | 7 | 8 | **Citation** 9 | 10 | ``` 11 | @inproceedings{Zhang2023fedcp, 12 | author = {Zhang, Jianqing and Hua, Yang and Wang, Hao and Song, Tao and Xue, Zhengui and Ma, Ruhui and Guan, Haibing}, 13 | title = {FedCP: Separating Feature Information for Personalized Federated Learning via Conditional Policy}, 14 | year = {2023}, 15 | booktitle = {Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining} 16 | } 17 | ``` 18 | 19 | **TL;DR**: Existing personalized federated learning (pFL) methods focus on exploiting global and personalized information in model parameters rather than the source of information: ***data***, so we propose FedCP to automatically separate global and personalized information from data (i.e., feature representations) in the iterative federated learning procedure, as shown in the following figures. 20 | 21 | ![](./figs/example.png) 22 | 23 | ![](./figs/feature_separation.png) 24 | 25 | 26 | # Datasets and Environments 27 | 28 | Due to the file size limits of GitHub repositories, we only upload the MNIST dataset with the default practical setting ($\beta=0.1$). You can generate other datasets and environment settings based on my other repository [PFLlib](https://github.com/TsingZ0/PFLlib). 29 | 30 | 31 | # System 32 | 33 | - `main.py`: configurations of **FedCP**. 34 | - `run_me.sh`: start **FedCP**. 35 | - `env_linux.yaml`: Python environment to run **FedCP** on Linux. 36 | - `./flcore`: 37 | - `./clients/clientcp.py`: the code on the client. 38 | - `./servers/servercp.py`: the code on the server. 39 | - `./trainmodel/models.py`: the code for backbones. 40 | - `./utils`: 41 | - `data_utils.py`: the code to read the dataset. 42 | 43 | 44 | # Federated Conditional Policy (FedCP) 45 | 46 | ![](./figs/CPN.png) 47 | 48 | 49 | # Training and Evaluation 50 | 51 | All code corresponding to **FedCP** is stored in `./system`. Just run the following commands. 
52 | 53 | ``` 54 | cd ./system 55 | sh run_me.sh # for Linux 56 | ``` 57 | -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/config.json: -------------------------------------------------------------------------------- 1 | {"num_clients":20,"num_classes":10,"non_iid":true,"balance":false,"partition":"dir","Size of samples for labels in clients":[[[0,140],[1,890],[4,1],[5,319],[7,29],[8,1067],[9,184]],[[0,5],[2,27],[5,19],[6,335],[8,6],[9,107]],[[0,3],[3,143],[6,1461],[9,23]],[[0,155],[4,1],[7,2381],[8,4]],[[0,71],[1,13],[3,207],[5,1129],[6,6],[8,40],[9,451]],[[1,38],[3,1],[4,39],[8,25],[9,6086]],[[1,873],[2,176],[3,46],[6,42],[8,13],[9,106]],[[1,21],[2,5],[3,11],[5,787],[7,4],[8,441]],[[0,1],[1,3599]],[[0,633],[1,1997],[2,89],[4,519],[6,768]],[[0,920],[1,2],[2,1450],[3,513],[4,134],[5,97]],[[2,159],[3,3055],[5,558]],[[0,8],[1,180],[2,3277],[5,148]],[[1,237],[2,343],[4,6],[5,453],[7,1095]],[[5,2719],[7,3011]],[[0,31],[3,1785],[5,16],[6,4],[7,756],[8,2856]],[[0,3628]],[[1,26],[2,1463],[3,1379],[4,335],[5,60],[7,17],[8,2373]],[[0,998],[5,8],[6,4260]],[[0,310],[1,1],[2,1],[3,1],[4,5789],[9,1]]],"alpha":0.1,"batch_size":10} -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test0_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test0_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test10_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test10_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test11_.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test11_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test12_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test12_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test13_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test13_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test14_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test14_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test15_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test15_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test16_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test16_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test17_.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test17_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test18_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test18_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test19_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test19_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test1_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test1_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test2_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test2_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test3_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test3_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test4_.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test4_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test5_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test5_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test6_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test6_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test7_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test7_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test8_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test8_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/test/test9_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/test/test9_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train0_.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train0_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train10_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train10_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train11_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train11_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train12_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train12_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train13_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train13_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train14_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train14_.npz -------------------------------------------------------------------------------- 
/dataset/mnist-0.1-npz/train/train15_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train15_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train16_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train16_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train17_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train17_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train18_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train18_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train19_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train19_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train1_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train1_.npz 
-------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train2_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train2_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train3_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train3_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train4_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train4_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train5_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train5_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train6_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train6_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train7_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train7_.npz 
-------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train8_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train8_.npz -------------------------------------------------------------------------------- /dataset/mnist-0.1-npz/train/train9_.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/dataset/mnist-0.1-npz/train/train9_.npz -------------------------------------------------------------------------------- /figs/CPN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/figs/CPN.png -------------------------------------------------------------------------------- /figs/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/figs/example.png -------------------------------------------------------------------------------- /figs/feature_separation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsingZ0/FedCP/7110e973a4069db7b05904cde6e76ca6795d7ab8/figs/feature_separation.png -------------------------------------------------------------------------------- /system/env_linux.yaml: -------------------------------------------------------------------------------- 1 | name: fl_torch 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - blas=1.0 8 | - brotlipy=0.7.0 9 | - bzip2=1.0.8 10 | - ca-certificates=2021.7.5 11 | - certifi=2021.5.30 12 | - cffi=1.14.5 13 | - chardet=4.0.0 14 | - 
cryptography=3.4.7 15 | - cudatoolkit=10.2.89 16 | - cycler=0.10.0 17 | - dbus=1.13.18 18 | - et_xmlfile=1.1.0 19 | - expat=2.2.10 20 | - ffmpeg=4.3 21 | - fontconfig=2.13.1 22 | - freetype=2.10.4 23 | - glib=2.67.4 24 | - gmp=6.2.1 25 | - gnutls=3.6.15 26 | - gst-plugins-base=1.14.0 27 | - gstreamer=1.14.0 28 | - h5py=2.10.0 29 | - hdf5=1.10.6 30 | - icu=58.2 31 | - idna=2.10 32 | - intel-openmp=2020.2 33 | - jdcal=1.4.1 34 | - joblib=1.0.1 35 | - jpeg=9b 36 | - kiwisolver=1.3.1 37 | - lame=3.100 38 | - lcms2=2.11 39 | - ld_impl_linux-64=2.33.1 40 | - libedit=3.1.20191231 41 | - libffi=3.3 42 | - libgcc-ng=9.1.0 43 | - libgfortran-ng=7.3.0 44 | - libiconv=1.15 45 | - libidn2=2.3.1 46 | - libpng=1.6.37 47 | - libstdcxx-ng=9.1.0 48 | - libtasn1=4.16.0 49 | - libtiff=4.1.0 50 | - libunistring=0.9.10 51 | - libuuid=1.0.3 52 | - libuv=1.40.0 53 | - libxcb=1.14 54 | - libxml2=2.9.10 55 | - lz4-c=1.9.3 56 | - matplotlib=3.3.4 57 | - matplotlib-base=3.3.4 58 | - mkl=2020.2 59 | - mkl-service=2.3.0 60 | - mkl_fft=1.3.0 61 | - mkl_random=1.1.1 62 | - ncurses=6.2 63 | - nettle=3.7.3 64 | - ninja=1.10.2 65 | - numpy=1.19.2 66 | - numpy-base=1.19.2 67 | - olefile=0.46 68 | - openh264=2.1.0 69 | - openpyxl=3.0.7 70 | - openssl=1.1.1k 71 | - pcre=8.44 72 | - pillow=8.1.1 73 | - pip=21.0.1 74 | - pycparser=2.20 75 | - pyopenssl=20.0.1 76 | - pyparsing=2.4.7 77 | - pyqt=5.9.2 78 | - pysocks=1.7.1 79 | - python=3.8.8 80 | - python-dateutil=2.8.1 81 | - qt=5.9.7 82 | - readline=8.1 83 | - requests=2.25.1 84 | - scikit-learn=0.24.1 85 | - scipy=1.6.1 86 | - setuptools=52.0.0 87 | - sip=4.19.13 88 | - six=1.15.0 89 | - sqlite=3.33.0 90 | - threadpoolctl=2.1.0 91 | - tk=8.6.10 92 | - torchaudio=0.9.0 93 | - torchtext=0.10.0 94 | - tornado=6.1 95 | - tqdm=4.56.0 96 | - typing_extensions=3.7.4.3 97 | - ujson=4.0.2 98 | - wheel=0.36.2 99 | - xz=5.2.5 100 | - zlib=1.2.11 101 | - zstd=1.4.5 102 | - pip: 103 | - absl-py==0.12.0 104 | - cachetools==4.2.2 105 | - calmsize==0.1.3 106 | - 
google-auth==1.31.0 107 | - google-auth-oauthlib==0.4.4 108 | - grpcio==1.38.0 109 | - markdown==3.3.4 110 | - memory-profiler==0.58.0 111 | - oauthlib==3.1.1 112 | - opacus==0.15.0 113 | - opencv-python==4.5.5.64 114 | - pandas==1.2.4 115 | - protobuf==3.17.3 116 | - psutil==5.8.0 117 | - pyasn1==0.4.8 118 | - pyasn1-modules==0.2.8 119 | - pytorch-memlab==0.2.3 120 | - pytz==2021.1 121 | - requests-oauthlib==1.3.0 122 | - rsa==4.7.2 123 | - seaborn==0.11.2 124 | - tensorboard==2.5.0 125 | - tensorboard-data-server==0.6.1 126 | - tensorboard-plugin-wit==1.8.0 127 | - torch==1.8.0 128 | - torch-tb-profiler==0.1.0 129 | - torchvision==0.9.0 130 | - ttach==0.0.3 131 | - urllib3==1.26.5 132 | - werkzeug==2.0.1 133 | prefix: /slstore/tsing/miniconda3/envs/fl_torch 134 | -------------------------------------------------------------------------------- /system/flcore/clients/clientcp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | from sklearn.preprocessing import label_binarize 8 | from sklearn import metrics 9 | from utils.data_utils import read_client_data 10 | 11 | 12 | class clientCP: 13 | def __init__(self, args, id, train_samples, test_samples, **kwargs): 14 | self.model = copy.deepcopy(args.model) 15 | self.dataset = args.dataset 16 | self.device = args.device 17 | self.id = id 18 | 19 | self.num_classes = args.num_classes 20 | self.train_samples = train_samples 21 | self.test_samples = test_samples 22 | self.batch_size = args.batch_size 23 | self.learning_rate = args.local_learning_rate 24 | self.local_steps = args.local_steps 25 | 26 | self.loss = nn.CrossEntropyLoss() 27 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate) 28 | 29 | self.lamda = args.lamda 30 | 31 | in_dim = list(args.model.head.parameters())[0].shape[1] 32 | 
self.context = torch.rand(1, in_dim).to(self.device) 33 | 34 | self.model = Ensemble( 35 | model=self.model, 36 | cs=copy.deepcopy(kwargs['ConditionalSelection']), 37 | head_g=copy.deepcopy(self.model.head), 38 | feature_extractor=copy.deepcopy(self.model.feature_extractor) 39 | ) 40 | self.opt= torch.optim.SGD(self.model.parameters(), lr=self.learning_rate) 41 | 42 | self.pm_train = [] 43 | self.pm_test = [] 44 | 45 | def load_train_data(self, batch_size=None): 46 | if batch_size == None: 47 | batch_size = self.batch_size 48 | train_data = read_client_data(self.dataset, self.id, is_train=True) 49 | return DataLoader(train_data, batch_size, drop_last=True, shuffle=False) 50 | 51 | def load_test_data(self, batch_size=None): 52 | if batch_size == None: 53 | batch_size = self.batch_size 54 | test_data = read_client_data(self.dataset, self.id, is_train=False) 55 | return DataLoader(test_data, batch_size, drop_last=True, shuffle=False) 56 | 57 | def set_parameters(self, feature_extractor): 58 | for new_param, old_param in zip(feature_extractor.parameters(), self.model.model.feature_extractor.parameters()): 59 | old_param.data = new_param.data.clone() 60 | 61 | for new_param, old_param in zip(feature_extractor.parameters(), self.model.feature_extractor.parameters()): 62 | old_param.data = new_param.data.clone() 63 | 64 | 65 | def set_head_g(self, head): 66 | headw_ps = [] 67 | for name, mat in self.model.model.head.named_parameters(): 68 | if 'weight' in name: 69 | headw_ps.append(mat.data) 70 | headw_p = headw_ps[-1] 71 | for mat in headw_ps[-2::-1]: 72 | headw_p = torch.matmul(headw_p, mat) 73 | headw_p.detach_() 74 | self.context = torch.sum(headw_p, dim=0, keepdim=True) 75 | 76 | for new_param, old_param in zip(head.parameters(), self.model.head_g.parameters()): 77 | old_param.data = new_param.data.clone() 78 | 79 | def set_cs(self, cs): 80 | for new_param, old_param in zip(cs.parameters(), self.model.gate.cs.parameters()): 81 | old_param.data = 
new_param.data.clone() 82 | 83 | def save_con_items(self, items, tag='', item_path=None): 84 | self.save_item(self.pm_train, 'pm_train' + '_' + tag, item_path) 85 | self.save_item(self.pm_test, 'pm_test' + '_' + tag, item_path) 86 | for idx, it in enumerate(items): 87 | self.save_item(it, 'item_' + str(idx) + '_' + tag, item_path) 88 | 89 | def generate_upload_head(self): 90 | for (np, pp), (ng, pg) in zip(self.model.model.head.named_parameters(), self.model.head_g.named_parameters()): 91 | pg.data = pp * 0.5 + pg * 0.5 92 | 93 | def test_metrics(self): 94 | testloader = self.load_test_data() 95 | self.model.eval() 96 | 97 | test_acc = 0 98 | test_num = 0 99 | y_prob = [] 100 | y_true = [] 101 | self.model.gate.pm_ = [] 102 | self.model.gate.gm_ = [] 103 | self.pm_test = [] 104 | 105 | with torch.no_grad(): 106 | for x, y in testloader: 107 | if type(x) == type([]): 108 | x[0] = x[0].to(self.device) 109 | else: 110 | x = x.to(self.device) 111 | y = y.to(self.device) 112 | output = self.model(x, is_rep=False, context=self.context) 113 | 114 | test_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item() 115 | test_num += y.shape[0] 116 | 117 | y_prob.append(F.softmax(output).detach().cpu().numpy()) 118 | nc = self.num_classes 119 | if self.num_classes == 2: 120 | nc += 1 121 | lb = label_binarize(y.detach().cpu().numpy(), classes=np.arange(nc)) 122 | if self.num_classes == 2: 123 | lb = lb[:, :2] 124 | y_true.append(lb) 125 | 126 | y_prob = np.concatenate(y_prob, axis=0) 127 | y_true = np.concatenate(y_true, axis=0) 128 | 129 | auc = metrics.roc_auc_score(y_true, y_prob, average='micro') 130 | 131 | self.pm_test.extend(self.model.gate.pm_) 132 | 133 | return test_acc, test_num, auc 134 | 135 | 136 | def train_cs_model(self): 137 | trainloader = self.load_train_data() 138 | self.model.train() 139 | 140 | for _ in range(self.local_steps): 141 | self.model.gate.pm = [] 142 | self.model.gate.gm = [] 143 | self.pm_train = [] 144 | for i, (x, y) in 
enumerate(trainloader): 145 | if type(x) == type([]): 146 | x[0] = x[0].to(self.device) 147 | else: 148 | x = x.to(self.device) 149 | y = y.to(self.device) 150 | output, rep, rep_base = self.model(x, is_rep=True, context=self.context) 151 | loss = self.loss(output, y) 152 | loss += MMD(rep, rep_base, 'rbf', self.device) * self.lamda 153 | self.opt.zero_grad() 154 | loss.backward() 155 | self.opt.step() 156 | 157 | self.pm_train.extend(self.model.gate.pm) 158 | scores = [torch.mean(pm).item() for pm in self.pm_train] 159 | print(np.mean(scores), np.std(scores)) 160 | 161 | 162 | def MMD(x, y, kernel, device='cpu'): 163 | """Emprical maximum mean discrepancy. The lower the result 164 | the more evidence that distributions are the same. 165 | 166 | Args: 167 | x: first sample, distribution P 168 | y: second sample, distribution Q 169 | kernel: kernel type such as "multiscale" or "rbf" 170 | """ 171 | xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) 172 | rx = (xx.diag().unsqueeze(0).expand_as(xx)) 173 | ry = (yy.diag().unsqueeze(0).expand_as(yy)) 174 | 175 | dxx = rx.t() + rx - 2. * xx # Used for A in (1) 176 | dyy = ry.t() + ry - 2. * yy # Used for B in (1) 177 | dxy = rx.t() + ry - 2. * zz # Used for C in (1) 178 | 179 | XX, YY, XY = (torch.zeros(xx.shape).to(device), 180 | torch.zeros(xx.shape).to(device), 181 | torch.zeros(xx.shape).to(device)) 182 | 183 | if kernel == "multiscale": 184 | 185 | bandwidth_range = [0.2, 0.5, 0.9, 1.3] 186 | for a in bandwidth_range: 187 | XX += a**2 * (a**2 + dxx)**-1 188 | YY += a**2 * (a**2 + dyy)**-1 189 | XY += a**2 * (a**2 + dxy)**-1 190 | 191 | if kernel == "rbf": 192 | 193 | bandwidth_range = [10, 15, 20, 50] 194 | for a in bandwidth_range: 195 | XX += torch.exp(-0.5*dxx/a) 196 | YY += torch.exp(-0.5*dyy/a) 197 | XY += torch.exp(-0.5*dxy/a) 198 | 199 | return torch.mean(XX + YY - 2. 
* XY) 200 | 201 | 202 | class Ensemble(nn.Module): 203 | def __init__(self, model, cs, head_g, feature_extractor) -> None: 204 | super().__init__() 205 | 206 | self.model = model 207 | self.head_g = head_g 208 | self.feature_extractor = feature_extractor 209 | 210 | for param in self.head_g.parameters(): 211 | param.requires_grad = False 212 | for param in self.feature_extractor.parameters(): 213 | param.requires_grad = False 214 | 215 | self.flag = 0 216 | self.tau = 1 217 | self.hard = False 218 | self.context = None 219 | 220 | self.gate = Gate(cs) 221 | 222 | def forward(self, x, is_rep=False, context=None): 223 | rep = self.model.feature_extractor(x) 224 | 225 | gate_in = rep 226 | 227 | if context != None: 228 | context = F.normalize(context, p=2, dim=1) 229 | if type(x) == type([]): 230 | self.context = torch.tile(context, (x[0].shape[0], 1)) 231 | else: 232 | self.context = torch.tile(context, (x.shape[0], 1)) 233 | 234 | if self.context != None: 235 | gate_in = rep * self.context 236 | 237 | if self.flag == 0: 238 | rep_p, rep_g = self.gate(rep, self.tau, self.hard, gate_in, self.flag) 239 | output = self.model.head(rep_p) + self.head_g(rep_g) 240 | elif self.flag == 1: 241 | rep_p = self.gate(rep, self.tau, self.hard, gate_in, self.flag) 242 | output = self.model.head(rep_p) 243 | else: 244 | rep_g = self.gate(rep, self.tau, self.hard, gate_in, self.flag) 245 | output = self.head_g(rep_g) 246 | 247 | if is_rep: 248 | return output, rep, self.feature_extractor(x) 249 | else: 250 | return output 251 | 252 | 253 | class Gate(nn.Module): 254 | def __init__(self, cs) -> None: 255 | super().__init__() 256 | 257 | self.cs = cs 258 | self.pm = [] 259 | self.gm = [] 260 | self.pm_ = [] 261 | self.gm_ = [] 262 | 263 | def forward(self, rep, tau=1, hard=False, context=None, flag=0): 264 | pm, gm = self.cs(context, tau=tau, hard=hard) 265 | if self.training: 266 | self.pm.extend(pm) 267 | self.gm.extend(gm) 268 | else: 269 | self.pm_.extend(pm) 270 | 
self.gm_.extend(gm) 271 | 272 | if flag == 0: 273 | rep_p = rep * pm 274 | rep_g = rep * gm 275 | return rep_p, rep_g 276 | elif flag == 1: 277 | return rep * pm 278 | else: 279 | return rep * gm -------------------------------------------------------------------------------- /system/flcore/servers/servercp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import time 5 | from flcore.clients.clientcp import * 6 | from utils.data_utils import read_client_data 7 | from threading import Thread 8 | 9 | 10 | class FedCP: 11 | def __init__(self, args, times): 12 | self.device = args.device 13 | self.dataset = args.dataset 14 | self.global_rounds = args.global_rounds 15 | self.global_modules = copy.deepcopy(args.model) 16 | self.num_clients = args.num_clients 17 | self.join_ratio = args.join_ratio 18 | self.random_join_ratio = args.random_join_ratio 19 | self.join_clients = int(self.num_clients * self.join_ratio) 20 | 21 | self.clients = [] 22 | self.selected_clients = [] 23 | 24 | self.uploaded_weights = [] 25 | self.uploaded_ids = [] 26 | self.uploaded_models = [] 27 | 28 | self.rs_test_acc = [] 29 | self.rs_train_loss = [] 30 | 31 | self.times = times 32 | self.eval_gap = args.eval_gap 33 | 34 | in_dim = list(args.model.head.parameters())[0].shape[1] 35 | cs = ConditionalSelection(in_dim, in_dim).to(args.device) 36 | 37 | for i in range(self.num_clients): 38 | train_data = read_client_data(self.dataset, i, is_train=True) 39 | test_data = read_client_data(self.dataset, i, is_train=False) 40 | client = clientCP(args, 41 | id=i, 42 | train_samples=len(train_data), 43 | test_samples=len(test_data), 44 | ConditionalSelection=cs) 45 | self.clients.append(client) 46 | 47 | print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}") 48 | print("Finished creating server and clients.") 49 | 50 | # self.load_model() 51 | self.Budget = [] 52 | self.head = None 53 | 
self.cs = None 54 | 55 | 56 | def select_clients(self): 57 | if self.random_join_ratio: 58 | join_clients = np.random.choice(range(self.join_clients, self.num_clients+1), 1, replace=False)[0] 59 | else: 60 | join_clients = self.join_clients 61 | selected_clients = list(np.random.choice(self.clients, join_clients, replace=False)) 62 | 63 | return selected_clients 64 | 65 | def send_models(self): 66 | assert (len(self.clients) > 0) 67 | 68 | for client in self.clients: 69 | client.set_parameters(self.global_modules) 70 | 71 | def add_parameters(self, w, client_model): 72 | for server_param, client_param in zip(self.global_modules.parameters(), client_model.parameters()): 73 | server_param.data += client_param.data.clone() * w 74 | 75 | def aggregate_parameters(self): 76 | assert (len(self.uploaded_models) > 0) 77 | 78 | self.global_modules = copy.deepcopy(self.uploaded_models[0]) 79 | for param in self.global_modules.parameters(): 80 | param.data = torch.zeros_like(param.data) 81 | 82 | for w, client_model in zip(self.uploaded_weights, self.uploaded_models): 83 | self.add_parameters(w, client_model) 84 | 85 | def test_metrics(self): 86 | num_samples = [] 87 | tot_correct = [] 88 | tot_auc = [] 89 | for c in self.clients: 90 | ct, ns, auc = c.test_metrics() 91 | print(f'Client {c.id}: Acc: {ct*1.0/ns}, AUC: {auc}') 92 | tot_correct.append(ct*1.0) 93 | tot_auc.append(auc*ns) 94 | num_samples.append(ns) 95 | 96 | ids = [c.id for c in self.clients] 97 | 98 | return ids, num_samples, tot_correct, tot_auc 99 | 100 | def evaluate(self, acc=None): 101 | stats = self.test_metrics() 102 | 103 | test_acc = sum(stats[2])*1.0 / sum(stats[1]) 104 | test_auc = sum(stats[3])*1.0 / sum(stats[1]) 105 | 106 | if acc == None: 107 | self.rs_test_acc.append(test_acc) 108 | else: 109 | acc.append(test_acc) 110 | 111 | print("Averaged Test Accurancy: {:.4f}".format(test_acc)) 112 | print("Averaged Test AUC: {:.4f}".format(test_auc)) 113 | 114 | 115 | def train(self): 116 | for i in 
range(self.global_rounds+1): 117 | s_t = time.time() 118 | self.selected_clients = self.select_clients() 119 | 120 | if i%self.eval_gap == 0: 121 | print(f"\n-------------Round number: {i}-------------") 122 | print("\nEvaluate before local training") 123 | self.evaluate() 124 | 125 | for client in self.selected_clients: 126 | client.train_cs_model() 127 | client.generate_upload_head() 128 | 129 | self.receive_models() 130 | self.aggregate_parameters() 131 | self.send_models() 132 | self.global_head() 133 | self.global_cs() 134 | 135 | self.Budget.append(time.time() - s_t) 136 | print('-'*50, self.Budget[-1]) 137 | 138 | print("\nBest accuracy.") 139 | print(max(self.rs_test_acc)) 140 | print(sum(self.Budget[1:])/len(self.Budget[1:])) 141 | 142 | 143 | def receive_models(self): 144 | assert (len(self.selected_clients) > 0) 145 | 146 | active_train_samples = 0 147 | for client in self.selected_clients: 148 | active_train_samples += client.train_samples 149 | 150 | self.uploaded_weights = [] 151 | self.uploaded_ids = [] 152 | self.uploaded_models = [] 153 | for client in self.selected_clients: 154 | self.uploaded_weights.append(client.train_samples / active_train_samples) 155 | self.uploaded_ids.append(client.id) 156 | self.uploaded_models.append(client.model.model.feature_extractor) 157 | 158 | def global_head(self): 159 | self.uploaded_model_gs = [] 160 | for client in self.selected_clients: 161 | self.uploaded_model_gs.append(client.model.head_g) 162 | 163 | self.head = copy.deepcopy(self.uploaded_model_gs[0]) 164 | for param in self.head.parameters(): 165 | param.data = torch.zeros_like(param.data) 166 | 167 | for w, client_model in zip(self.uploaded_weights, self.uploaded_model_gs): 168 | self.add_head(w, client_model) 169 | 170 | for client in self.selected_clients: 171 | client.set_head_g(self.head) 172 | 173 | def add_head(self, w, head): 174 | for server_param, client_param in zip(self.head.parameters(), head.parameters()): 175 | server_param.data += 
client_param.data.clone() * w 176 | 177 | def global_cs(self): 178 | self.uploaded_model_gs = [] 179 | for client in self.selected_clients: 180 | self.uploaded_model_gs.append(client.model.gate.cs) 181 | 182 | self.cs = copy.deepcopy(self.uploaded_model_gs[0]) 183 | for param in self.cs.parameters(): 184 | param.data = torch.zeros_like(param.data) 185 | 186 | for w, client_model in zip(self.uploaded_weights, self.uploaded_model_gs): 187 | self.add_cs(w, client_model) 188 | 189 | for client in self.selected_clients: 190 | client.set_cs(self.cs) 191 | 192 | def add_cs(self, w, cs): 193 | for server_param, client_param in zip(self.cs.parameters(), cs.parameters()): 194 | server_param.data += client_param.data.clone() * w 195 | 196 | 197 | class ConditionalSelection(nn.Module): 198 | def __init__(self, in_dim, h_dim): 199 | super(ConditionalSelection, self).__init__() 200 | 201 | self.fc = nn.Sequential( 202 | nn.Linear(in_dim, h_dim*2), 203 | nn.LayerNorm([h_dim*2]), 204 | nn.ReLU(), 205 | ) 206 | 207 | def forward(self, x, tau=1, hard=False): 208 | shape = x.shape 209 | x = self.fc(x) 210 | x = x.view(shape[0], 2, -1) 211 | x = F.gumbel_softmax(x, dim=1, tau=tau, hard=hard) 212 | return x[:, 0, :], x[:, 1, :] 213 | -------------------------------------------------------------------------------- /system/flcore/trainmodel/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | batch_size = 16 6 | 7 | class LocalModel(nn.Module): 8 | def __init__(self, feature_extractor, head): 9 | super(LocalModel, self).__init__() 10 | 11 | self.feature_extractor = feature_extractor 12 | self.head = head 13 | 14 | def forward(self, x, feat=False): 15 | out = self.feature_extractor(x) 16 | if feat: 17 | return out 18 | else: 19 | out = self.head(out) 20 | return out 21 | 22 | 23 | # https://github.com/FengHZ/KD3A/blob/master/model/amazon.py 24 | class AmazonMLP(nn.Module): 25 | def 
__init__(self): 26 | super(AmazonMLP, self).__init__() 27 | self.encoder = nn.Sequential( 28 | nn.Linear(5000, 1000), 29 | nn.ReLU(), 30 | nn.Linear(1000, 500), 31 | nn.ReLU(), 32 | nn.Linear(500, 100), 33 | nn.ReLU() 34 | ) 35 | self.fc = nn.Linear(100, 2) 36 | 37 | def forward(self, x): 38 | out = self.encoder(x) 39 | out = self.fc(out) 40 | return out 41 | 42 | 43 | class FedAvgCNN(nn.Module): 44 | def __init__(self, in_features=1, num_classes=10, dim=1024, dim1=512): 45 | super().__init__() 46 | self.conv1 = nn.Sequential( 47 | nn.Conv2d(in_features, 48 | 32, 49 | kernel_size=5, 50 | padding=0, 51 | stride=1, 52 | bias=True), 53 | nn.ReLU(inplace=True), 54 | nn.MaxPool2d(kernel_size=(2, 2)) 55 | ) 56 | self.conv2 = nn.Sequential( 57 | nn.Conv2d(32, 58 | 64, 59 | kernel_size=5, 60 | padding=0, 61 | stride=1, 62 | bias=True), 63 | nn.ReLU(inplace=True), 64 | nn.MaxPool2d(kernel_size=(2, 2)) 65 | ) 66 | self.fc1 = nn.Sequential( 67 | nn.Linear(dim, dim1), 68 | nn.ReLU(inplace=True) 69 | ) 70 | self.fc = nn.Linear(dim1, num_classes) 71 | 72 | def forward(self, x): 73 | out = self.conv1(x) 74 | out = self.conv2(out) 75 | out = torch.flatten(out, 1) 76 | out = self.fc1(out) 77 | out = self.fc(out) 78 | return out 79 | 80 | 81 | class fastText(nn.Module): 82 | def __init__(self, hidden_dim, padding_idx=0, vocab_size=98635, num_classes=10): 83 | super(fastText, self).__init__() 84 | 85 | # Embedding Layer 86 | self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx) 87 | 88 | # Hidden Layer 89 | self.fc1 = nn.Linear(hidden_dim, hidden_dim) 90 | 91 | # Output Layer 92 | self.fc = nn.Linear(hidden_dim, num_classes) 93 | 94 | def forward(self, x): 95 | text, text_lengths = x 96 | 97 | embedded_sent = self.embedding(text) 98 | h = self.fc1(embedded_sent.mean(1)) 99 | z = self.fc(h) 100 | out = z 101 | 102 | return out 103 | -------------------------------------------------------------------------------- /system/main.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import argparse 4 | import os 5 | import time 6 | import warnings 7 | import numpy as np 8 | import torchvision 9 | 10 | from flcore.servers.servercp import FedCP 11 | from flcore.trainmodel.models import * 12 | 13 | from utils.mem_utils import MemReporter 14 | 15 | warnings.simplefilter("ignore") 16 | torch.manual_seed(0) 17 | 18 | # hyper-params for AG News 19 | vocab_size = 98635 20 | max_len=200 21 | 22 | hidden_dim=32 23 | 24 | def run(args): 25 | 26 | time_list = [] 27 | reporter = MemReporter() 28 | model_str = args.model 29 | 30 | for i in range(args.prev, args.times): 31 | print(f"\n============= Running time: {i}th =============") 32 | print("Creating server and clients ...") 33 | start = time.time() 34 | 35 | # Generate args.model 36 | if model_str == "cnn": 37 | if args.dataset[:5] == "mnist": 38 | args.model = FedAvgCNN(in_features=1, num_classes=args.num_classes, dim=1024).to(args.device) 39 | elif args.dataset[:5] == "Cifar": 40 | args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device) 41 | else: 42 | args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=10816).to(args.device) 43 | 44 | elif model_str == "resnet": 45 | args.model = torchvision.models.resnet18(pretrained=False, num_classes=args.num_classes).to(args.device) 46 | 47 | elif model_str == "fastText": 48 | args.model = fastText(hidden_dim=hidden_dim, vocab_size=vocab_size, num_classes=args.num_classes).to(args.device) 49 | 50 | else: 51 | raise NotImplementedError 52 | 53 | print(args.model) 54 | 55 | if args.algorithm == "FedCP": 56 | args.head = copy.deepcopy(args.model.fc) 57 | args.model.fc = nn.Identity() 58 | args.model = LocalModel(args.model, args.head) 59 | server = FedCP(args, i) 60 | else: 61 | raise NotImplementedError 62 | 63 | server.train() 64 | 65 | # torch.cuda.empty_cache() 66 | 67 | 
time_list.append(time.time()-start) 68 | 69 | reporter.report() 70 | 71 | print(f"\nAverage time cost: {round(np.average(time_list), 2)}s.") 72 | 73 | print("All done!") 74 | 75 | 76 | if __name__ == "__main__": 77 | total_start = time.time() 78 | 79 | parser = argparse.ArgumentParser() 80 | # general 81 | parser.add_argument('-dev', "--device", type=str, default="cuda", 82 | choices=["cpu", "cuda"]) 83 | parser.add_argument('-did', "--device_id", type=str, default="0") 84 | parser.add_argument('-data', "--dataset", type=str, default="mnist") 85 | parser.add_argument('-nb', "--num_classes", type=int, default=10) 86 | parser.add_argument('-m', "--model", type=str, default="cnn") 87 | parser.add_argument('-lbs', "--batch_size", type=int, default=10) 88 | parser.add_argument('-lr', "--local_learning_rate", type=float, default=0.005, 89 | help="Local learning rate") 90 | parser.add_argument('-gr', "--global_rounds", type=int, default=1000) 91 | parser.add_argument('-ls', "--local_steps", type=int, default=1) 92 | parser.add_argument('-algo', "--algorithm", type=str, default="FedGP") 93 | parser.add_argument('-jr', "--join_ratio", type=float, default=1.0, 94 | help="Ratio of clients per round") 95 | parser.add_argument('-rjr', "--random_join_ratio", type=bool, default=False, 96 | help="Random ratio of clients per round") 97 | parser.add_argument('-nc', "--num_clients", type=int, default=20, 98 | help="Total number of clients") 99 | parser.add_argument('-pv', "--prev", type=int, default=0, 100 | help="Previous Running times") 101 | parser.add_argument('-t', "--times", type=int, default=1, 102 | help="Running times") 103 | parser.add_argument('-eg', "--eval_gap", type=int, default=1, 104 | help="Rounds gap for evaluation") 105 | 106 | parser.add_argument('-lam', "--lamda", type=float, default=0.0) 107 | 108 | args = parser.parse_args() 109 | 110 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id 111 | # torch.cuda.set_device(int(args.device_id)) 112 | 113 | if 
args.device == "cuda" and not torch.cuda.is_available(): 114 | print("\ncuda is not avaiable.\n") 115 | args.device = "cpu" 116 | 117 | run(args) -------------------------------------------------------------------------------- /system/run_me.sh: -------------------------------------------------------------------------------- 1 | # Due to the file size limitation of the supplementary material (250MB), we only upload the mnist dataset. 2 | 3 | nohup python -u main.py -t 1 -jr 1 -nc 20 -nb 10 -data mnist-0.1-npz -m cnn -algo FedCP -did 6 -lam 5 > result-mnist-0.1-npz.out 2>&1 & -------------------------------------------------------------------------------- /system/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | 5 | 6 | def read_data(dataset, idx, is_train=True): 7 | if is_train: 8 | train_data_dir = os.path.join('../dataset', dataset, 'train/') 9 | 10 | train_file = train_data_dir + str(idx) + '.npz' 11 | with open(train_file, 'rb') as f: 12 | train_data = np.load(f, allow_pickle=True)['data'].tolist() 13 | 14 | return train_data 15 | 16 | else: 17 | test_data_dir = os.path.join('../dataset', dataset, 'test/') 18 | 19 | test_file = test_data_dir + str(idx) + '.npz' 20 | with open(test_file, 'rb') as f: 21 | test_data = np.load(f, allow_pickle=True)['data'].tolist() 22 | 23 | return test_data 24 | 25 | 26 | def read_client_data(dataset, idx, is_train=True): 27 | if "News" in dataset: 28 | return read_client_data_text(dataset, idx, is_train) 29 | elif "Shakespeare" in dataset: 30 | return read_client_data_Shakespeare(dataset, idx) 31 | 32 | if is_train: 33 | train_data = read_data(dataset, idx, is_train) 34 | X_train = torch.Tensor(train_data['x']).type(torch.float32) 35 | y_train = torch.Tensor(train_data['y']).type(torch.int64) 36 | 37 | train_data = [(x, y) for x, y in zip(X_train, y_train)] 38 | return train_data 39 | else: 40 | test_data = 
read_data(dataset, idx, is_train) 41 | X_test = torch.Tensor(test_data['x']).type(torch.float32) 42 | y_test = torch.Tensor(test_data['y']).type(torch.int64) 43 | test_data = [(x, y) for x, y in zip(X_test, y_test)] 44 | return test_data 45 | 46 | 47 | def read_client_data_text(dataset, idx, is_train=True): 48 | if is_train: 49 | train_data = read_data(dataset, idx, is_train) 50 | X_train, X_train_lens = list(zip(*train_data['x'])) 51 | y_train = train_data['y'] 52 | 53 | X_train = torch.Tensor(X_train).type(torch.int64) 54 | X_train_lens = torch.Tensor(X_train_lens).type(torch.int64) 55 | y_train = torch.Tensor(train_data['y']).type(torch.int64) 56 | 57 | train_data = [((x, lens), y) for x, lens, y in zip(X_train, X_train_lens, y_train)] 58 | return train_data 59 | else: 60 | test_data = read_data(dataset, idx, is_train) 61 | X_test, X_test_lens = list(zip(*test_data['x'])) 62 | y_test = test_data['y'] 63 | 64 | X_test = torch.Tensor(X_test).type(torch.int64) 65 | X_test_lens = torch.Tensor(X_test_lens).type(torch.int64) 66 | y_test = torch.Tensor(test_data['y']).type(torch.int64) 67 | 68 | test_data = [((x, lens), y) for x, lens, y in zip(X_test, X_test_lens, y_test)] 69 | return test_data 70 | 71 | 72 | def read_client_data_Shakespeare(dataset, idx, is_train=True): 73 | if is_train: 74 | train_data = read_data(dataset, idx, is_train) 75 | X_train = torch.Tensor(train_data['x']).type(torch.int64) 76 | y_train = torch.Tensor(train_data['y']).type(torch.int64) 77 | 78 | train_data = [(x, y) for x, y in zip(X_train, y_train)] 79 | return train_data 80 | else: 81 | test_data = read_data(dataset, idx, is_train) 82 | X_test = torch.Tensor(test_data['x']).type(torch.int64) 83 | y_test = torch.Tensor(test_data['y']).type(torch.int64) 84 | test_data = [(x, y) for x, y in zip(X_test, y_test)] 85 | return test_data 86 | 87 | -------------------------------------------------------------------------------- /system/utils/mem_utils.py: 
import math
import gc
from collections import defaultdict
from typing import Optional, Tuple, List

import torch

from math import isnan
from calmsize import size as calmsize


def readable_size(num_bytes: int) -> str:
    """Return *num_bytes* as a human-readable string with two decimals.

    NaN input yields an empty string; every other value is converted by
    ``calmsize`` and formatted with ``'{:.2f}'``.
    """
    return '' if isnan(num_bytes) else '{:.2f}'.format(calmsize(num_bytes))


LEN = 79

# some pytorch low-level memory management constant
# the minimal allocate memory size (Byte)
PYTORCH_MIN_ALLOCATE = 2 ** 9
# the minimal cache memory size (Byte)
PYTORCH_MIN_CACHE = 2 ** 20


class MemReporter:
    """A memory reporter that collects tensors and memory usages

    Parameters:
        - model: an extra nn.Module can be passed to infer the name
        of Tensors

    """

    def __init__(self, model: Optional[torch.nn.Module] = None):
        # id(tensor) -> display name
        self.tensor_name = {}
        # device -> tensors found on that device by `collect_tensor`
        self.device_mapping = defaultdict(list)
        # device -> list of (name, shape, numel, memory_size) tuples
        self.device_tensor_stat = {}
        # to numbering the unknown tensors
        self.name_idx = 0

        tensor_names = defaultdict(list)
        if model is not None:
            assert isinstance(model, torch.nn.Module)
            # for model with tying weight, multiple parameters may share
            # the same underlying tensor
            for name, param in model.named_parameters():
                tensor_names[param].append(name)

        for param, name in tensor_names.items():
            self.tensor_name[id(param)] = '+'.join(name)

    def _get_tensor_name(self, tensor: torch.Tensor) -> str:
        """Return the display name of *tensor*, creating one if needed.

        Names are keyed by ``id(tensor)``. A tensor without a known name
        gets ``<type name><running index>``, which is remembered for later
        lookups.
        """
        tensor_id = id(tensor)
        if tensor_id in self.tensor_name:
            name = self.tensor_name[tensor_id]
        # use numbering if no name can be inferred
        else:
            name = type(tensor).__name__ + str(self.name_idx)
            self.tensor_name[tensor_id] = name
            self.name_idx += 1
        return name

    def collect_tensor(self):
        """Collect all tensor objects tracked by python

        NOTICE:
            - the buffers for backward which is implemented in C++ are
            not tracked by python's reference counting.
            - the gradients(.grad) of Parameters is not collected, and
            I don't know why.
        """
        # FIXME: make the grad tensor collected by gc
        objects = gc.get_objects()
        tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)]
        for t in tensors:
            self.device_mapping[t.device].append(t)

    def get_stats(self):
        """Get the memory stat of tensors and then release them

        As a memory profiler, we cannot hold the reference to any tensors,
        which causes possibly inaccurate memory usage stats, so we delete
        the tensors after getting required stats
        """
        # storage data pointer -> name of the first tensor seen using it
        visited_data = {}
        self.device_tensor_stat.clear()

        def get_tensor_stat(
            tensor: torch.Tensor,
        ) -> List[Tuple[str, Tuple[int, ...], int, int]]:
            """Get the stat of a single tensor

            Returns:
                - stat: a list of tuples containing (tensor_name,
                tensor_size, tensor_numel, tensor_memory). A sparse tensor
                contributes one entry for its indices and one for its
                values.
            """
            assert isinstance(tensor, torch.Tensor)

            # looked up before the sparse check so that the running index
            # of unnamed tensors advances for sparse tensors as well
            name = self._get_tensor_name(tensor)
            if tensor.is_sparse:
                indices_stat = get_tensor_stat(tensor._indices())
                values_stat = get_tensor_stat(tensor._values())
                return indices_stat + values_stat

            numel = tensor.numel()
            element_size = tensor.element_size()
            # tensor.storage should be the actual object related to memory
            # allocation; fetch it once and reuse it for size and pointer
            storage = tensor.storage()
            fact_numel = storage.size()
            fact_memory_size = fact_numel * element_size
            # since pytorch allocate at least 512 Bytes for any tensor, round
            # up to a multiple of 512
            memory_size = math.ceil(fact_memory_size / PYTORCH_MIN_ALLOCATE) \
                * PYTORCH_MIN_ALLOCATE

            data_ptr = storage.data_ptr()
            if data_ptr in visited_data:
                name = '{}(->{})'.format(
                    name,
                    visited_data[data_ptr],
                )
                # don't count the memory for reusing same underlying storage
                memory_size = 0
            else:
                visited_data[data_ptr] = name

            size = tuple(tensor.size())
            # torch scalar has empty size
            if not size:
                size = (1,)

            return [(name, size, numel, memory_size)]

        for device, tensors in self.device_mapping.items():
            tensor_stats = []
            for tensor in tensors:
                if tensor.numel() == 0:
                    continue
                # each entry is (name, shape, numel, memory_size)
                tensor_stats += get_tensor_stat(tensor)
                if isinstance(tensor, torch.nn.Parameter):
                    if tensor.grad is not None:
                        # manually specify the name of gradient tensor
                        self.tensor_name[id(tensor.grad)] = '{}.grad'.format(
                            self._get_tensor_name(tensor)
                        )
                        tensor_stats += get_tensor_stat(tensor.grad)

            self.device_tensor_stat[device] = tensor_stats

        # drop the tensor references gathered by `collect_tensor`
        self.device_mapping.clear()

    def print_stats(self, verbose: bool = False, target_device: Optional[torch.device] = None) -> None:
        """Print the per-device totals gathered by `get_stats`.

        args:
            - verbose: accepted for interface compatibility; per-tensor
            lines are not printed, so the flag currently has no effect.
            - target_device: only report this device; report every device
            when it is None.

        The figure printed as "Total Tensors" is the sum of the element
        counts (numel) of the collected tensors.
        """
        for device, tensor_stats in self.device_tensor_stat.items():
            # By default, if the target_device is not specified,
            # print tensors on all devices
            if target_device is not None and device != target_device:
                continue
            print('\nStorage on {}'.format(device))
            total_mem = 0
            total_numel = 0
            for _name, _size, numel, mem in tensor_stats:
                total_mem += mem
                total_numel += numel

            print('-'*LEN)
            print('Total Tensors: {} \tUsed Memory: {}'.format(
                total_numel, readable_size(total_mem),
            ))

            # Only CUDA devices can be queried through torch.cuda; treating
            # every non-CPU device (e.g. meta, mps) as CUDA makes
            # torch.cuda.device() raise.
            if device.type == 'cuda':
                with torch.cuda.device(device):
                    memory_allocated = torch.cuda.memory_allocated()
                print('The allocated memory on {}: {}'.format(
                    device, readable_size(memory_allocated),
                ))
                if memory_allocated != total_mem:
                    print('Memory differs due to the matrix alignment or'
                          ' invisible gradient buffer tensors')
            print('-'*LEN)

    def report(self, verbose: bool = False, device: Optional[torch.device] = None) -> None:
        """Interface for end-users to directly print the memory usage

        args:
            - verbose: flag to show tensor.storage reuse information
            - device: `torch.device` object, specify the target device
            to report detailed memory usage. It will print memory usage
            on all devices if not specified. Usually we only want to
            print the memory usage on CUDA devices.

        """
        self.collect_tensor()
        self.get_stats()
        self.print_stats(verbose, target_device=device)