├── .github ├── dependabot.yaml └── workflows │ └── test.yaml ├── .gitignore ├── .style.yapf ├── LICENSE ├── README.md ├── docs └── horovod_faq.md ├── examples ├── format.sh ├── ray_lightning ├── __init__.py ├── accelerators │ ├── __init__.py │ └── delayed_gpu_accelerator.py ├── examples │ ├── __init__.py │ ├── ray_ddp_example.py │ ├── ray_ddp_sharded_example.py │ ├── ray_ddp_tune.py │ └── ray_horovod_example.py ├── launchers │ ├── __init__.py │ ├── ray_horovod_launcher.py │ ├── ray_launcher.py │ └── utils.py ├── ray_ddp.py ├── ray_ddp_sharded.py ├── ray_horovod.py ├── session.py ├── tests │ ├── __init__.py │ ├── test_client.py │ ├── test_client_2.py │ ├── test_client_3.py │ ├── test_ddp.py │ ├── test_ddp_gpu.py │ ├── test_ddp_sharded.py │ ├── test_horovod.py │ ├── test_lightning_cli.py │ ├── test_tune.py │ └── utils.py ├── tune.py └── util.py ├── requirements-lint.txt ├── requirements-test.txt └── setup.py /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | reviewers: 8 | - "amogkam" 9 | ignore: 10 | - dependency-name: flake8 11 | - dependency-name: flake8-quotes 12 | - dependency-name: yapf 13 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: pytest on push 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test_lint: 7 | runs-on: ubuntu-latest 8 | timeout-minutes: 3 9 | steps: 10 | - uses: actions/checkout@v2 11 | - name: Set up Python 3.7 12 | uses: actions/setup-python@v2 13 | with: 14 | python-version: 3.7 15 | - name: Install dependencies 16 | run: | 17 | python -m pip install --upgrade pip 18 | python -m pip install codecov 19 | python -m pip install -U -r requirements-lint.txt 20 | - name: Run format script 21 | run: | 22 | ./format.sh --all 23 | 24 | test_linux_ray_master_1: 25 | runs-on: ubuntu-latest 26 | timeout-minutes: 40 27 | steps: 28 | - uses: actions/checkout@v2 29 | - name: Set up Python 3.7 30 | uses: actions/setup-python@v2 31 | with: 32 | python-version: 3.7 33 | - name: Install dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | python -m pip install --upgrade setuptools 37 | python -m pip install codecov 38 | python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl 39 | if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi 40 | - name: Install package 41 | run: | 42 | python -m pip install -e . 
43 | - name: Test with Pytest 44 | run: | 45 | pushd ray_lightning/tests 46 | python -m pytest -v --durations=0 -x test_ddp.py 47 | python -m pytest -v --durations=0 -x test_ddp_sharded.py 48 | 49 | test_linux_ray_master_2: 50 | runs-on: ubuntu-latest 51 | timeout-minutes: 40 52 | steps: 53 | - uses: actions/checkout@v2 54 | - name: Set up Python 3.7 55 | uses: actions/setup-python@v2 56 | with: 57 | python-version: 3.7 58 | - name: Install dependencies 59 | run: | 60 | python -m pip install --upgrade pip 61 | python -m pip install --upgrade setuptools 62 | python -m pip install codecov 63 | python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl 64 | if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi 65 | HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install git+https://github.com/horovod/horovod.git 66 | - name: Install package 67 | run: | 68 | python -m pip install -e . 69 | - name: Test with Pytest 70 | run: | 71 | pushd ray_lightning/tests 72 | python -m pytest -v --durations=0 -x test_horovod.py 73 | python -m pytest -v --durations=0 -x test_tune.py 74 | 75 | test_linux_ray_master_examples: 76 | runs-on: ubuntu-latest 77 | timeout-minutes: 40 78 | steps: 79 | - uses: actions/checkout@v2 80 | - name: Set up Python 3.7 81 | uses: actions/setup-python@v2 82 | with: 83 | python-version: 3.7 84 | - name: Install dependencies 85 | run: | 86 | python -m pip install --upgrade pip 87 | python -m pip install --upgrade setuptools 88 | python -m pip install codecov 89 | python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl 90 | if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi 91 | HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install git+https://github.com/horovod/horovod.git 92 | - name: Install package 93 | run: | 94 | python -m pip install -e . 
95 | - name: Run Examples 96 | run: | 97 | pushd ray_lightning/examples 98 | echo "running ray_ddp_example.py" && python ray_ddp_example.py --smoke-test 99 | echo "running ray_ddp_example.py with Tune" && python ray_ddp_example.py --smoke-test --tune 100 | echo "running ray_ddp_tune.py" && python ray_ddp_tune.py --smoke-test 101 | echo "running ray_horovod_example.py" && python ray_horovod_example.py --smoke-test 102 | echo "running ray_horovod_example.py with Tune" && python ray_horovod_example.py --smoke-test --tune 103 | popd 104 | pushd ray_lightning/tests 105 | echo "running examples with Ray Client 1" && python -m pytest -v --durations=0 -x test_client.py 106 | echo "running examples with Ray Client 2" && python -m pytest -v --durations=0 -x test_client_2.py 107 | echo "running examples with Ray Client 3" && python -m pytest -v --durations=0 -x test_client_3.py 108 | 109 | 110 | test_linux_ray_release_1: 111 | runs-on: ubuntu-latest 112 | timeout-minutes: 40 113 | steps: 114 | - uses: actions/checkout@v2 115 | - name: Set up Python 3.7 116 | uses: actions/setup-python@v2 117 | with: 118 | python-version: 3.7 119 | - name: Install dependencies 120 | run: | 121 | python -m pip install --upgrade pip 122 | python -m pip install --upgrade setuptools 123 | python -m pip install codecov 124 | python -m pip install -U ray 125 | if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi 126 | - name: Install package 127 | run: | 128 | python -m pip install -e . 129 | - name: Test with Pytest 130 | run: | 131 | pushd ray_lightning/tests 132 | python -m pytest -v --durations=0 -x test_ddp.py 133 | python -m pytest -v --durations=0 -x test_ddp_sharded.py 134 | python -m pytest -v --durations=0 -x test_lightning_cli.py 135 | 136 | test_linux_ray_release_2: 137 | runs-on: ubuntu-latest 138 | timeout-minutes: 40 139 | steps: 140 | - uses: actions/checkout@v2 141 | - name: Set up Python 3.7 142 | uses: actions/setup-python@v2 143 | with: 144 | python-version: 3.7 145 | - name: Install dependencies 146 | run: | 147 | python -m pip install --upgrade pip 148 | python -m pip install --upgrade setuptools 149 | python -m pip install codecov 150 | python -m pip install -U ray 151 | if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi 152 | HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git 153 | - name: Install package 154 | run: | 155 | python -m pip install -e . 156 | - name: Test with Pytest 157 | run: | 158 | pushd ray_lightning/tests 159 | python -m pytest -v --durations=0 -x test_horovod.py 160 | python -m pytest -v --durations=0 -x test_tune.py 161 | 162 | 163 | test_linux_ray_release_examples: 164 | runs-on: ubuntu-latest 165 | timeout-minutes: 40 166 | steps: 167 | - uses: actions/checkout@v2 168 | - name: Set up Python 3.7 169 | uses: actions/setup-python@v2 170 | with: 171 | python-version: 3.7 172 | - name: Install dependencies 173 | run: | 174 | python -m pip install --upgrade pip 175 | python -m pip install --upgrade setuptools 176 | python -m pip install codecov 177 | python -m pip install -U ray 178 | if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi 179 | HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git 180 | - name: Install package 181 | run: | 182 | python -m pip install -e . 
183 | - name: Run Examples 184 | run: | 185 | pushd ray_lightning/examples 186 | echo "running ray_ddp_example.py" && python ray_ddp_example.py --smoke-test 187 | echo "running ray_ddp_example.py with Tune" && python ray_ddp_example.py --smoke-test --tune 188 | echo "running ray_ddp_tune.py" && python ray_ddp_tune.py --smoke-test 189 | echo "running ray_horovod_example.py" && python ray_horovod_example.py --smoke-test 190 | echo "running ray_horovod_example.py with Tune" && python ray_horovod_example.py --smoke-test --tune 191 | popd 192 | pushd ray_lightning/tests 193 | echo "running examples with Ray Client 1" && python -m pytest -v --durations=0 -x test_client.py 194 | echo "running examples with Ray Client 2" && python -m pytest -v --durations=0 -x test_client_2.py 195 | echo "running examples with Ray Client 3" && python -m pytest -v --durations=0 -x test_client_3.py 196 | 197 | test_linux_compat: 198 | # Test compatibility when optional libraries are not installed. 199 | runs-on: ubuntu-latest 200 | timeout-minutes: 40 201 | steps: 202 | - uses: actions/checkout@v2 203 | - name: Set up Python 3.7 204 | uses: actions/setup-python@v2 205 | with: 206 | python-version: 3.7 207 | - name: Install dependencies 208 | run: | 209 | python -m pip install --upgrade pip 210 | python -m pip install --upgrade setuptools 211 | python -m pip install codecov 212 | python -m pip install -U ray 213 | if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi 214 | HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git 215 | - name: Uninstall unavailable dependencies 216 | run: | 217 | # Uninstall Tune 218 | pip uninstall -y tabulate 219 | - name: Install package 220 | run: | 221 | python -m pip install -e . 222 | - name: Test with Pytest 223 | run: | 224 | pushd ray_lightning/tests 225 | python -m pytest -v --durations=0 -x test_ddp.py 226 | python -m pytest -v --durations=0 -x test_horovod.py 227 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style=pep8 3 | allow_split_before_dict_value=False 4 | join_multiple_lines=False 5 | allow_multiline_lambdas=True -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Distributed PyTorch Lightning Training on Ray 4 | 5 | ## ⚠️ `ray_lightning` is no longer supported 6 | 7 | This project is no longer actively maintained and has been archived. For distributed PyTorch Lightning on Ray, visit [Ray Train](https://docs.ray.io/en/latest/train/train.html). 8 | 9 | For more details, see [this issue](https://github.com/ray-project/ray_lightning/issues/258). 10 | 11 | ## Overview 12 | 13 | This library adds new PyTorch Lightning strategies for distributed training using the Ray distributed computing framework. 14 | 15 | These PyTorch Lightning strategies on Ray enable quick and easy parallel training while still leveraging all the benefits of PyTorch Lightning and using your desired training protocol, either [PyTorch Distributed Data Parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) or [Horovod](https://github.com/horovod/horovod). 16 | 17 | Once you add your strategy to the PyTorch Lightning Trainer, you can parallelize training to all the cores in your laptop, or across a massive multi-node, multi-GPU cluster with no additional code changes. 18 | 19 | This library also comes with an integration with [Ray Tune](https://tune.io) for distributed hyperparameter tuning experiments. 20 | 21 | 22 | ## Table of Contents 23 | 1. [Installation](#installation) 24 | 2. [PyTorch Lightning Compatibility](#pytorch-lightning-compatibility) 25 | 3. [PyTorch Distributed Data Parallel Strategy on Ray](#pytorch-distributed-data-parallel-strategy-on-ray) 26 | 4. [Multi-Node Distributed Training](#multi-node-distributed-training) 27 | 5. [Multi-Node Training from your Laptop](#multi-node-training-from-your-laptop) 28 | 6. [Horovod Strategy on Ray](#horovod-strategy-on-ray) 29 | 7. [Model Parallel Sharded Training on Ray](#model-parallel-sharded-training-on-ray) 30 | 8. [Hyperparameter Tuning with Ray Tune](#hyperparameter-tuning-with-ray-tune) 31 | 9. [FAQ](#faq) 32 | 33 | 34 | 35 | ## Installation 36 | You can install Ray Lightning via `pip`: 37 | 38 | `pip install ray_lightning` 39 | 40 | Or to install master: 41 | 42 | `pip install git+https://github.com/ray-project/ray_lightning#ray_lightning` 43 | 44 | ## PyTorch Lightning Compatibility 45 | Here are the supported PyTorch Lightning versions: 46 | 47 | | Ray Lightning | PyTorch Lightning | 48 | |---------------|-------------------| 49 | | 0.1 | 1.4 | 50 | | 0.2 | 1.5 | 51 | | 0.3 | 1.6 | 52 | | master | 1.6 | 53 | 54 | 55 | ## PyTorch Distributed Data Parallel Strategy on Ray 56 | The `RayStrategy` provides Distributed Data Parallel training on a Ray cluster. PyTorch DDP is used as the distributed training protocol, and Ray is used to launch and manage the training worker processes. 57 | 58 | Here is a simplified example: 59 | 60 | ```python 61 | import pytorch_lightning as pl 62 | from ray_lightning import RayStrategy 63 | 64 | # Create your PyTorch Lightning model here. 65 | ptl_model = MNISTClassifier(...) 66 | strategy = RayStrategy(num_workers=4, num_cpus_per_worker=1, use_gpu=True) 67 | 68 | # Don't set ``gpus`` in the ``Trainer``. 69 | # The actual number of GPUs is determined by ``num_workers``.
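# With ``num_workers=4`` and ``use_gpu=True``, this example reserves 4 GPUs in total (one per worker), which Ray schedules across the cluster.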
70 | trainer = pl.Trainer(..., strategy=strategy) 71 | trainer.fit(ptl_model) 72 | ``` 73 | 74 | Because Ray is used to launch processes, instead of the same script being called multiple times, you CAN use this strategy even in cases where you cannot use the standard `DDPStrategy`, such as: 75 | - Jupyter Notebooks, Google Colab, Kaggle 76 | - Calling `fit` or `test` multiple times in the same script 77 | 78 | ## Multi-node Distributed Training 79 | Using the same examples above, you can run distributed training on a multi-node cluster with just a couple of simple steps. 80 | 81 | First, use Ray's [Cluster launcher](https://docs.ray.io/en/latest/cluster/quickstart.html) to start a Ray cluster: 82 | 83 | ```bash 84 | ray up my_cluster_config.yaml 85 | ``` 86 | 87 | Then, run your Ray script using one of the following options: 88 | 89 | 1. on the head node of the cluster (``python train_script.py``) 90 | 2. via ``ray job submit`` ([docs](https://docs.ray.io/en/latest/cluster/job-submission.html)) from your laptop (``ray job submit -- python train.py``) 91 | 92 | ## Multi-node Training from your Laptop 93 | Ray provides capabilities to run multi-node and GPU training all from your laptop through 94 | [Ray Client](https://docs.ray.io/en/master/cluster/ray-client.html). 95 | 96 | Use Ray's [Cluster launcher](https://docs.ray.io/en/latest/cluster/quickstart.html) to set up the cluster. 97 | Then, add this line to the beginning of your script to connect to the cluster: 98 | ```python 99 | import ray 100 | # replace with the appropriate host and port 101 | ray.init("ray://<head_node_host>:10001") 102 | ``` 103 | Now you can run your training script on your laptop, but have it execute as if your laptop had all the resources of the cluster, essentially providing you with an **infinite laptop**. 104 | 105 | **Note:** When using Ray Client, you must disable checkpointing and logging for your Trainer by setting `checkpoint_callback` and `logger` to `False`. 106 | 107 | ## Horovod Strategy on Ray 108 | If you prefer to use Horovod as the distributed training protocol, use the `HorovodRayStrategy` instead. 109 | 110 | ```python 111 | import pytorch_lightning as pl 112 | from ray_lightning import HorovodRayStrategy 113 | 114 | # Create your PyTorch Lightning model here. 115 | ptl_model = MNISTClassifier(...) 116 | 117 | # 2 workers, 1 CPU and 1 GPU each. 118 | strategy = HorovodRayStrategy(num_workers=2, use_gpu=True) 119 | 120 | # Don't set ``gpus`` in the ``Trainer``. 121 | # The actual number of GPUs is determined by ``num_workers``. 122 | trainer = pl.Trainer(..., strategy=strategy) 123 | trainer.fit(ptl_model) 124 | ``` 125 | 126 | ## Model Parallel Sharded Training on Ray 127 | The `RayShardedStrategy` integrates with [FairScale](https://github.com/facebookresearch/fairscale) to provide sharded DDP training on a Ray cluster. 128 | With sharded training, you can leverage the scalability of data parallel training while drastically reducing memory usage when training large models. 129 | 130 | ```python 131 | import pytorch_lightning as pl 132 | from ray_lightning import RayShardedStrategy 133 | 134 | # Create your PyTorch Lightning model here. 135 | ptl_model = MNISTClassifier(...) 136 | strategy = RayShardedStrategy(num_workers=4, num_cpus_per_worker=1, use_gpu=True) 137 | 138 | # Don't set ``gpus`` in the ``Trainer``. 139 | # The actual number of GPUs is determined by ``num_workers``.
140 | trainer = pl.Trainer(..., strategy=strategy) 141 | trainer.fit(ptl_model) 142 | ``` 143 | See the [Pytorch Lightning docs](https://pytorch-lightning.readthedocs.io/en/stable/advanced/model_parallel.html#sharded-training) for more information on sharded training. 144 | 145 | ## Hyperparameter Tuning with Ray Tune 146 | `ray_lightning` also integrates with Ray Tune to provide distributed hyperparameter tuning for your distributed model training. You can run multiple PyTorch Lightning training runs in parallel, each with a different hyperparameter configuration, and each training run parallelized by itself. All you have to do is move your training code to a function, pass the function to tune.run, and make sure to add the appropriate callback (Either `TuneReportCallback` or `TuneReportCheckpointCallback`) to your PyTorch Lightning Trainer. 147 | 148 | Example using `ray_lightning` with Tune: 149 | 150 | ```python 151 | from ray import tune 152 | 153 | from ray_lightning import RayStrategy 154 | from ray_lightning.examples.ray_ddp_example import MNISTClassifier 155 | from ray_lightning.tune import TuneReportCallback, get_tune_resources 156 | 157 | import pytorch_lightning as pl 158 | 159 | 160 | def train_mnist(config): 161 | 162 | # Create your PTL model. 163 | model = MNISTClassifier(config) 164 | 165 | # Create the Tune Reporting Callback 166 | metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"} 167 | callbacks = [TuneReportCallback(metrics, on="validation_end")] 168 | 169 | trainer = pl.Trainer( 170 | max_epochs=4, 171 | callbacks=callbacks, 172 | strategy=RayStrategy(num_workers=4, use_gpu=False)) 173 | trainer.fit(model) 174 | 175 | config = { 176 | "layer_1": tune.choice([32, 64, 128]), 177 | "layer_2": tune.choice([64, 128, 256]), 178 | "lr": tune.loguniform(1e-4, 1e-1), 179 | "batch_size": tune.choice([32, 64, 128]), 180 | } 181 | 182 | # Make sure to pass in ``resources_per_trial`` using the ``get_tune_resources`` utility. 183 | analysis = tune.run( 184 | train_mnist, 185 | metric="loss", 186 | mode="min", 187 | config=config, 188 | num_samples=2, 189 | resources_per_trial=get_tune_resources(num_workers=4), 190 | name="tune_mnist") 191 | 192 | print("Best hyperparameters found were: ", analysis.best_config) 193 | ``` 194 | **Note:** Ray Tune requires 1 additional CPU per trial to use for the Trainable driver. So the actual number of resources each trial requires is `num_workers * num_cpus_per_worker + 1`. 195 | 196 | ## FAQ 197 | > I see that `RayStrategy` is based off of Pytorch Lightning's `DDPSpawnStrategy`. However, doesn't the PTL team discourage the use of spawn? 198 | 199 | As discussed [here](https://github.com/pytorch/pytorch/issues/51688#issuecomment-773539003), using a spawn approach instead of launch is not all that detrimental. The original factors for discouraging spawn were: 200 | 1. not being able to use 'spawn' in a Jupyter or Colab notebook, and 201 | 2. not being able to use multiple workers for data loading. 202 | 203 | Neither of these should be an issue with the `RayStrategy` due to Ray's serialization mechanisms. The only thing to keep in mind is that when using this strategy, your model does have to be serializable/pickleable. 
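A quick way to catch serialization problems before launching a long run is to check the model up front. The snippet below is a minimal sketch of such a check; it assumes Ray's `ray.util.inspect_serializability` debugging helper is available in your installed Ray version and reuses the `MNISTClassifier` and config from the example above (substitute your own model).

```python
import pickle

from ray.util import inspect_serializability

from ray_lightning.examples.ray_ddp_example import MNISTClassifier

# Build the model exactly as you would pass it to ``trainer.fit``.
config = {"layer_1": 32, "layer_2": 64, "lr": 1e-1, "batch_size": 32}
model = MNISTClassifier(config)

# ``inspect_serializability`` walks the object and reports which attribute,
# if any, cannot be pickled (commonly a lambda, an open file handle, or a
# lock stored on the module).
serializable, failures = inspect_serializability(model, name="model")
assert serializable, f"Model is not serializable: {failures}"

# A plain ``pickle.dumps`` round trip is an equivalent low-tech check.
pickle.dumps(model)
```

If the check fails, the printed report names the offending attribute, which is usually easier to act on than a raw `pickle` traceback.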
204 | 205 | > Horovod installation issue 206 | please see [details](https://github.com/ray-project/ray_lightning/blob/main/docs/horovod_faq.md) 207 | 208 | 235 | -------------------------------------------------------------------------------- /docs/horovod_faq.md: -------------------------------------------------------------------------------- 1 | # Horovod installation issue 2 | 3 | > ``` 4 | > Extension horovod.torch has not been built: /home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-38-x86_64-linux-gnu.so not found 5 | > If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error. 6 | >Warning! MPI libs are missing, but python applications are still avaiable. 7 | > ``` 8 | 9 | One might fix this issue by 10 | ```python 11 | $ pip uninstall -y horovod 12 | $ conda install gcc_linux-64 gxx_linux-64 13 | $ [flags] pip install --no-cache-dir horovod 14 | ``` 15 | 16 | from [here](https://github.com/horovod/horovod/issues/656), [here](https://github.com/tlkh/ai-lab/issues/27) and [here](https://horovod.readthedocs.io/en/stable/install_include.html) 17 | 18 | - install horovod from scratch with torch 19 | 20 | ```python 21 | conda create -n hd python=3.8 scipy numpy pandas -y 22 | conda activate hd 23 | conda install pytorch=1.11 torchvision torchaudio cudatoolkit=11.3 -c pytorch -y 24 | sudo rm -rf /usr/local/cuda 25 | sudo ln -s /usr/local/cuda-11.3 /usr/local/cuda 26 | conda install gxx_linux-64 -y 27 | conda install cxx-compiler=1.0 -y 28 | export TORCH_CUDA_ARCH_LIST="3.7;5.0;6.0;7.0;7.5;8.0" 29 | echo $TORCH_CUDA_ARCH_LIST 30 | sudo apt-get purge -y cmake 31 | wget -q https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz 32 | tar -zxvf cmake-3.20.2.tar.gz 33 | cd cmake-3.20.2 34 | ./bootstrap -- -DCMAKE_USE_OPENSSL=OFF 35 | make -j10 36 | sudo make install 37 | cmake --version 38 | export CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda 39 | export HOROVOD_NCCL_HOME=/usr/local/cuda/ 40 | export HOROVOD_NCCL_INCLUDE=/usr/local/cuda/include 41 | export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST//";8.0"/} 42 | export HOROVOD_BUILD_CUDA_CC_LIST=${TORCH_CUDA_ARCH_LIST//";"/","} 43 | export HOROVOD_BUILD_CUDA_CC_LIST=${HOROVOD_BUILD_CUDA_CC_LIST//"."/""} 44 | export PATH=/usr/local/cuda/bin/:$PATH 45 | export HOROVOD_NCCL_LIB=/usr/local/cuda/lib/ 46 | HOROVOD_NCCL_HOME=/usr/local/cuda HOROVOD_GPU_OPERATIONS=NCCL HOROVOD_WITH_PYTORCH=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITHOUT_GLOO=1 pip install --no-cache-dir horovod 47 | ``` 48 | 49 | [reference 1](https://stackoverflow.com/questions/54948216/usr-lib-x86-64-linux-gnu-libstdc-so-6-version-glibcxx-3-4-21-not-found-req) and [reference 2](https://github.com/horovod/horovod/issues/401) and [reference 3](https://github.com/Lightning-AI/lightning/issues/4472) and [reference 4](https://github.com/horovod/horovod/issues/2276) and [reference 5](https://github.com/Lightning-AI/lightning/blob/master/dockers/base-cuda/Dockerfile#L105-L121) and [reference 6](https://horovod.readthedocs.io/en/stable/gpus_include.html) and [reference 7](https://horovod.readthedocs.io/en/stable/conda_include.html) and [reference 8](https://github.com/horovod/horovod/issues/3545) and [reference 9](https://github.com/KAUST-CTL/horovod-gpu-data-science-project) and [reference 10](https://kose-y.github.io/blog/2017/12/installing-cuda-aware-mpi/) -------------------------------------------------------------------------------- /examples: 
-------------------------------------------------------------------------------- 1 | ray_lightning/examples/ -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. 3 | # You are encouraged to run this locally before pushing changes for review. 4 | 5 | # Cause the script to exit if a single command fails 6 | set -eo pipefail 7 | 8 | FLAKE8_VERSION_REQUIRED="3.7.7" 9 | YAPF_VERSION_REQUIRED="0.23.0" 10 | 11 | check_command_exist() { 12 | VERSION="" 13 | case "$1" in 14 | yapf) 15 | VERSION=$YAPF_VERSION_REQUIRED 16 | ;; 17 | flake8) 18 | VERSION=$FLAKE8_VERSION_REQUIRED 19 | ;; 20 | *) 21 | echo "$1 is not a required dependency" 22 | exit 1 23 | esac 24 | if ! [ -x "$(command -v $1)" ]; then 25 | echo "$1 not installed. pip install $1==$VERSION" 26 | exit 1 27 | fi 28 | } 29 | 30 | check_command_exist yapf 31 | check_command_exist flake8 32 | 33 | ver=$(yapf --version) 34 | if ! echo $ver | grep -q 0.23.0; then 35 | echo "Wrong YAPF version installed: 0.23.0 is required, not $ver. $YAPF_DOWNLOAD_COMMAND_MSG" 36 | exit 1 37 | fi 38 | 39 | # this stops git rev-parse from failing if we run this from the .git directory 40 | builtin cd "$(dirname "${BASH_SOURCE:-$0}")" 41 | 42 | ROOT="$(git rev-parse --show-toplevel)" 43 | builtin cd "$ROOT" || exit 1 44 | 45 | # Add the upstream remote if it doesn't exist 46 | if ! git remote -v | grep -q upstream; then 47 | git remote add 'upstream' 'https://github.com/ray-project/ray_lightning_accelerators.git' 48 | fi 49 | 50 | FLAKE8_VERSION=$(flake8 --version | awk '{print $1}') 51 | YAPF_VERSION=$(yapf --version | awk '{print $2}') 52 | 53 | # params: tool name, tool version, required version 54 | tool_version_check() { 55 | if [[ $2 != $3 ]]; then 56 | echo "WARNING: Ray uses $1 $3, You currently are using $2. This might generate different results." 57 | fi 58 | } 59 | 60 | tool_version_check "flake8" $FLAKE8_VERSION $FLAKE8_VERSION_REQUIRED 61 | tool_version_check "yapf" $YAPF_VERSION $YAPF_VERSION_REQUIRED 62 | 63 | if which clang-format >/dev/null; then 64 | CLANG_FORMAT_VERSION=$(clang-format --version | awk '{print $3}') 65 | tool_version_check "clang-format" $CLANG_FORMAT_VERSION "7.0.0" 66 | else 67 | echo "WARNING: clang-format is not installed!" 68 | fi 69 | 70 | # Only fetch main since that's the branch we're diffing against. 71 | git fetch upstream main || true 72 | 73 | YAPF_FLAGS=( 74 | '--style' "$ROOT/.style.yapf" 75 | '--recursive' 76 | '--parallel' 77 | ) 78 | 79 | YAPF_EXCLUDES=( 80 | # '--exclude' 'python/ray/cloudpickle/*' 81 | # '--exclude' 'python/build/*' 82 | # '--exclude' 'python/ray/core/src/ray/gcs/*' 83 | # '--exclude' 'python/ray/thirdparty_files/*' 84 | ) 85 | 86 | # Format specified files 87 | format() { 88 | yapf --in-place "${YAPF_FLAGS[@]}" -- "$@" 89 | } 90 | 91 | # Format files that differ from main branch. Ignores dirs that are not slated 92 | # for autoformat yet. 93 | format_changed() { 94 | # The `if` guard ensures that the list of filenames is not empty, which 95 | # could cause yapf to receive 0 positional arguments, making it hang 96 | # waiting for STDIN. 97 | # 98 | # `diff-filter=ACRM` and $MERGEBASE is to ensure we only format files that 99 | # exist on both branches. 100 | MERGEBASE="$(git merge-base upstream/main HEAD)" 101 | 102 | if ! 
git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then 103 | git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ 104 | yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" 105 | if which flake8 >/dev/null; then 106 | git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ 107 | flake8 --inline-quotes '"' --no-avoid-escape --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 108 | fi 109 | fi 110 | 111 | if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then 112 | if which flake8 >/dev/null; then 113 | git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ 114 | flake8 --inline-quotes '"' --no-avoid-escape --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 115 | fi 116 | fi 117 | } 118 | 119 | # Format all files, and print the diff to stdout for travis. 120 | format_all() { 121 | yapf --diff "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" ray_lightning 122 | flake8 --inline-quotes '"' --no-avoid-escape --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 ray_lightning 123 | } 124 | 125 | # This flag formats individual files. --files *must* be the first command line 126 | # arg to use this option. 127 | if [[ "$1" == '--files' ]]; then 128 | format "${@:2}" 129 | # If `--all` is passed, then any further arguments are ignored and the 130 | # entire python directory is formatted. 131 | elif [[ "$1" == '--all' ]]; then 132 | format_all 133 | else 134 | # Format only the files that changed in last commit. 135 | format_changed 136 | fi 137 | 138 | if ! git diff --quiet &>/dev/null; then 139 | echo 'Reformatted changed files. Please review and stage the changes.' 140 | echo 'Files updated:' 141 | echo 142 | 143 | git --no-pager diff --name-only 144 | 145 | exit 1 146 | fi 147 | 148 | echo 'Linting check finished successfully.' -------------------------------------------------------------------------------- /ray_lightning/__init__.py: -------------------------------------------------------------------------------- 1 | from ray_lightning.ray_ddp import RayStrategy 2 | from ray_lightning.ray_horovod import HorovodRayStrategy 3 | from ray_lightning.ray_ddp_sharded import RayShardedStrategy 4 | 5 | __all__ = ["RayStrategy", "HorovodRayStrategy", "RayShardedStrategy"] 6 | -------------------------------------------------------------------------------- /ray_lightning/accelerators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | from pytorch_lightning.accelerators.registry import \ 14 | call_register_accelerators # noqa: F401 15 | from ray_lightning.accelerators.delayed_gpu_accelerator import _GPUAccelerator 16 | 17 | # these lines are to register the delayed gpu accelerator as `_gpu` 18 | ACCELERATORS_BASE_MODULE = "ray_lightning.accelerators" 19 | call_register_accelerators(ACCELERATORS_BASE_MODULE) 20 | 21 | __all__ = ["_GPUAccelerator"] 22 | -------------------------------------------------------------------------------- /ray_lightning/accelerators/delayed_gpu_accelerator.py: -------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Dict, List 15 | 16 | import torch 17 | 18 | from pytorch_lightning.accelerators import Accelerator,\ 19 | GPUAccelerator 20 | 21 | 22 | class _GPUAccelerator(GPUAccelerator): 23 | """Accelerator for GPU devices. 24 | 25 | adapted from: 26 | https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/accelerators/gpu.py#L43 27 | but remove `torch.cuda.set_device(root_device)` in `setup_environment` 28 | """ 29 | 30 | def setup_environment(self, root_device: torch.device) -> None: 31 | """ 32 | modified: remove `torch.cuda.set_device(root_device)` 33 | and call `torch.cuda.set_device(self.device)` at the later time 34 | inside the `ray_launcher` or `horovod_launcher` 35 | """ 36 | Accelerator.setup_environment(self, root_device) 37 | 38 | @staticmethod 39 | def get_parallel_devices(devices: List[int]) -> List[torch.device]: 40 | """Gets parallel devices for the Accelerator.""" 41 | # modified: return None when no devices are available 42 | if devices: 43 | return [torch.device("cuda", i) for i in devices] 44 | else: 45 | return [] 46 | 47 | @staticmethod 48 | def is_available() -> bool: 49 | # modified to always return True 50 | return True 51 | 52 | @classmethod 53 | def register_accelerators(cls, accelerator_registry: Dict) -> None: 54 | # the delayed gpu accelerator is registered as `_gpu` 55 | # in the accelerator registry 56 | accelerator_registry.register( 57 | "_gpu", 58 | cls, 59 | description=f"{cls.__class__.__name__}", 60 | ) 61 | -------------------------------------------------------------------------------- /ray_lightning/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/ray_lightning/24f5922ee069f5c72087e300034c9be11b65016f/ray_lightning/examples/__init__.py -------------------------------------------------------------------------------- /ray_lightning/examples/ray_ddp_example.py: -------------------------------------------------------------------------------- 1 | """Example using Pytorch Lightning with Pytorch DDP on Ray Accelerator.""" 2 | import os 3 | import tempfile 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from torch.utils.data import random_split, 
DataLoader 8 | from torchvision.datasets import MNIST 9 | from torchvision import transforms 10 | 11 | import ray 12 | from ray import tune 13 | from ray_lightning.tune import TuneReportCallback, get_tune_resources 14 | from ray_lightning import RayStrategy 15 | from ray_lightning.tests.utils import LightningMNISTClassifier 16 | 17 | 18 | class MNISTClassifier(LightningMNISTClassifier): 19 | def __init__(self, config, data_dir=None): 20 | super().__init__(config, data_dir) 21 | self.batch_size = config["batch_size"] 22 | 23 | def prepare_data(self): 24 | self.dataset = MNIST( 25 | self.data_dir, 26 | train=True, 27 | download=True, 28 | transform=transforms.ToTensor()) 29 | 30 | def train_dataloader(self): 31 | dataset = self.dataset 32 | train_length = len(dataset) 33 | dataset_train, _ = random_split( 34 | dataset, [train_length - 5000, 5000], 35 | generator=torch.Generator().manual_seed(0)) 36 | loader = DataLoader( 37 | dataset_train, 38 | batch_size=self.batch_size, 39 | num_workers=1, 40 | drop_last=True, 41 | pin_memory=True, 42 | ) 43 | return loader 44 | 45 | def val_dataloader(self): 46 | dataset = self.dataset 47 | train_length = len(dataset) 48 | _, dataset_val = random_split( 49 | dataset, [train_length - 5000, 5000], 50 | generator=torch.Generator().manual_seed(0)) 51 | loader = DataLoader( 52 | dataset_val, 53 | batch_size=self.batch_size, 54 | num_workers=1, 55 | drop_last=True, 56 | pin_memory=True, 57 | ) 58 | return loader 59 | 60 | 61 | def train_mnist(config, 62 | checkpoint_dir=None, 63 | data_dir=None, 64 | num_epochs=10, 65 | num_workers=1, 66 | use_gpu=False, 67 | callbacks=None, 68 | **trainer_kwargs): 69 | model = MNISTClassifier(config, data_dir) 70 | 71 | callbacks = callbacks or [] 72 | 73 | trainer = pl.Trainer( 74 | max_epochs=num_epochs, 75 | callbacks=callbacks, 76 | strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu), 77 | **trainer_kwargs) 78 | trainer.fit(model) 79 | 80 | 81 | def tune_mnist(data_dir, 82 | num_samples=10, 83 | num_epochs=10, 84 | num_workers=1, 85 | use_gpu=False, 86 | **trainer_kwargs): 87 | config = { 88 | "layer_1": tune.choice([32, 64, 128]), 89 | "layer_2": tune.choice([64, 128, 256]), 90 | "lr": tune.loguniform(1e-4, 1e-1), 91 | "batch_size": tune.choice([32, 64, 128]), 92 | } 93 | 94 | # Add Tune callback. 
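# The ``TuneReportCallback`` below maps the metrics logged by the model
# (``ptl/val_loss``, ``ptl/val_accuracy``) to the names Tune sees (``loss``,
# ``acc``) and reports them to Tune at the end of each validation epoch.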
95 | metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"} 96 | callbacks = [TuneReportCallback(metrics, on="validation_end")] 97 | trainable = tune.with_parameters( 98 | train_mnist, 99 | data_dir=data_dir, 100 | num_epochs=num_epochs, 101 | num_workers=num_workers, 102 | use_gpu=use_gpu, 103 | callbacks=callbacks, 104 | **trainer_kwargs) 105 | analysis = tune.run( 106 | trainable, 107 | metric="loss", 108 | mode="min", 109 | config=config, 110 | num_samples=num_samples, 111 | resources_per_trial=get_tune_resources( 112 | num_workers=num_workers, use_gpu=use_gpu), 113 | name="tune_mnist") 114 | 115 | print("Best hyperparameters found were: ", analysis.best_config) 116 | 117 | 118 | if __name__ == "__main__": 119 | import argparse 120 | 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument( 123 | "--num-workers", 124 | type=int, 125 | help="Number of training workers to use.", 126 | default=1) 127 | parser.add_argument( 128 | "--use-gpu", action="store_true", help="Use GPU for training.") 129 | parser.add_argument( 130 | "--tune", 131 | action="store_true", 132 | help="Use Ray Tune for hyperparameter tuning.") 133 | parser.add_argument( 134 | "--num-samples", 135 | type=int, 136 | default=10, 137 | help="Number of samples to tune.") 138 | parser.add_argument( 139 | "--num-epochs", 140 | type=int, 141 | default=10, 142 | help="Number of epochs to train for.") 143 | parser.add_argument( 144 | "--smoke-test", action="store_true", help="Finish quickly for testing") 145 | parser.add_argument( 146 | "--address", 147 | required=False, 148 | type=str, 149 | help="the address to use for Ray") 150 | args, _ = parser.parse_known_args() 151 | 152 | num_epochs = 1 if args.smoke_test else args.num_epochs 153 | num_workers = 1 if args.smoke_test else args.num_workers 154 | use_gpu = False if args.smoke_test else args.use_gpu 155 | num_samples = 1 if args.smoke_test else args.num_samples 156 | 157 | if args.smoke_test: 158 | ray.init(num_cpus=2) 159 | else: 160 | ray.init(address=args.address) 161 | 162 | data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") 163 | 164 | if args.tune: 165 | tune_mnist(data_dir, num_samples, num_epochs, num_workers, use_gpu) 166 | else: 167 | config = {"layer_1": 32, "layer_2": 64, "lr": 1e-1, "batch_size": 32} 168 | train_mnist( 169 | config, 170 | data_dir=data_dir, 171 | num_epochs=num_epochs, 172 | num_workers=num_workers, 173 | use_gpu=use_gpu) 174 | -------------------------------------------------------------------------------- /ray_lightning/examples/ray_ddp_sharded_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import time 4 | 5 | import ray 6 | import torch 7 | from pl_bolts.datamodules import MNISTDataModule 8 | from pl_bolts.models.vision import ImageGPT 9 | 10 | import pytorch_lightning as pl 11 | from pytorch_lightning import Callback 12 | 13 | from ray_lightning import RayShardedStrategy 14 | 15 | 16 | class CUDACallback(Callback): 17 | def on_train_epoch_start(self, trainer, pl_module): 18 | # Reset the memory use counter 19 | torch.cuda.reset_peak_memory_stats(trainer.root_gpu) 20 | torch.cuda.synchronize(trainer.root_gpu) 21 | self.start_time = time.time() 22 | 23 | def on_train_epoch_end(self, trainer, pl_module): 24 | torch.cuda.synchronize(trainer.root_gpu) 25 | max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20 26 | epoch_time = time.time() - self.start_time 27 | 28 | max_memory = torch.tensor( 29 | max_memory, 
dtype=torch.int, device=trainer.root_gpu) 30 | epoch_time = torch.tensor( 31 | epoch_time, dtype=torch.int, device=trainer.root_gpu) 32 | 33 | torch.distributed.all_reduce( 34 | max_memory, op=torch.distributed.ReduceOp.SUM) 35 | torch.distributed.all_reduce( 36 | epoch_time, op=torch.distributed.ReduceOp.SUM) 37 | 38 | world_size = torch.distributed.get_world_size() 39 | 40 | print( 41 | f"Average Epoch time: {epoch_time.item() / float(world_size):.2f} " 42 | f"seconds") 43 | print( 44 | f"Average Peak memory {max_memory.item() / float(world_size):.2f}" 45 | f"MiB") 46 | 47 | 48 | def train(data_dir, num_workers, use_gpu, batch_size, embed_dim, max_epochs, 49 | max_steps): 50 | # Make sure data is downloaded on all nodes. 51 | def download_data(): 52 | from filelock import FileLock 53 | with FileLock(os.path.join(data_dir, ".lock")): 54 | MNISTDataModule(data_dir=data_dir).prepare_data() 55 | 56 | strategy = RayShardedStrategy( 57 | num_workers=num_workers, use_gpu=use_gpu, init_hook=download_data) 58 | 59 | dm = MNISTDataModule(data_dir, batch_size=batch_size) 60 | 61 | model = ImageGPT( 62 | embed_dim=embed_dim, layers=16, heads=4, vocab_size=32, num_pixels=28) 63 | 64 | trainer = pl.Trainer( 65 | max_epochs=max_epochs, 66 | precision=16 if use_gpu else 32, 67 | callbacks=[CUDACallback()] if use_gpu else [], 68 | strategy=strategy, 69 | max_steps=max_steps) 70 | 71 | trainer.fit(model, dm) 72 | 73 | 74 | if __name__ == "__main__": 75 | import argparse 76 | 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument( 79 | "--num-workers", 80 | type=int, 81 | help="Number of training workers to use.", 82 | default=1) 83 | parser.add_argument( 84 | "--use-gpu", action="store_true", help="Use GPU for training.") 85 | parser.add_argument( 86 | "--num-epochs", 87 | type=int, 88 | default=10, 89 | help="Number of epochs to train for.") 90 | parser.add_argument( 91 | "--batch-size", 92 | type=int, 93 | default=4, 94 | help="Batch size to use for training.") 95 | parser.add_argument( 96 | "--embed-dim", 97 | type=int, 98 | default=2048, 99 | help="Number of embedding dimensions for ImageGPT model.") 100 | parser.add_argument( 101 | "--smoke-test", action="store_true", help="Finish quickly for testing") 102 | parser.add_argument( 103 | "--address", 104 | required=False, 105 | type=str, 106 | help="the address to use for Ray") 107 | args, _ = parser.parse_known_args() 108 | 109 | if args.smoke_test: 110 | ray.init(num_cpus=2) 111 | else: 112 | ray.init(address=args.address) 113 | 114 | data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") 115 | 116 | if args.smoke_test: 117 | train( 118 | data_dir=data_dir, 119 | num_workers=2, 120 | use_gpu=False, 121 | batch_size=32, 122 | embed_dim=16, 123 | max_epochs=1, 124 | max_steps=1) 125 | else: 126 | train( 127 | data_dir=data_dir, 128 | num_workers=args.num_workers, 129 | use_gpu=args.use_gpu, 130 | batch_size=args.batch_size, 131 | embed_dim=args.embed_dim, 132 | max_epochs=args.num_epochs, 133 | max_steps=None) 134 | -------------------------------------------------------------------------------- /ray_lightning/examples/ray_ddp_tune.py: -------------------------------------------------------------------------------- 1 | """Simple example using RayAccelerator and Ray Tune""" 2 | import os 3 | import tempfile 4 | 5 | from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule 6 | 7 | import pytorch_lightning as pl 8 | import ray 9 | from ray import tune 10 | from ray_lightning.tune import TuneReportCallback, get_tune_resources 11 | 
from ray_lightning import RayStrategy 12 | from ray_lightning.tests.utils import LightningMNISTClassifier 13 | 14 | 15 | def train_mnist(config, 16 | data_dir=None, 17 | num_epochs=10, 18 | num_workers=1, 19 | use_gpu=False, 20 | callbacks=None): 21 | # Make sure data is downloaded on all nodes. 22 | def download_data(): 23 | from filelock import FileLock 24 | with FileLock(os.path.join(data_dir, ".lock")): 25 | MNISTDataModule(data_dir=data_dir).prepare_data() 26 | 27 | model = LightningMNISTClassifier(config, data_dir) 28 | 29 | callbacks = callbacks or [] 30 | 31 | trainer = pl.Trainer( 32 | max_epochs=num_epochs, 33 | callbacks=callbacks, 34 | progress_bar_refresh_rate=0, 35 | strategy=RayStrategy( 36 | num_workers=num_workers, use_gpu=use_gpu, init_hook=download_data)) 37 | dm = MNISTDataModule( 38 | data_dir=data_dir, num_workers=1, batch_size=config["batch_size"]) 39 | trainer.fit(model, dm) 40 | 41 | 42 | def tune_mnist(data_dir, 43 | num_samples=10, 44 | num_epochs=10, 45 | num_workers=1, 46 | use_gpu=False): 47 | config = { 48 | "layer_1": tune.choice([32, 64, 128]), 49 | "layer_2": tune.choice([64, 128, 256]), 50 | "lr": tune.loguniform(1e-4, 1e-1), 51 | "batch_size": tune.choice([32, 64, 128]), 52 | } 53 | 54 | # Add Tune callback. 55 | metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"} 56 | callbacks = [TuneReportCallback(metrics, on="validation_end")] 57 | trainable = tune.with_parameters( 58 | train_mnist, 59 | data_dir=data_dir, 60 | num_epochs=num_epochs, 61 | num_workers=num_workers, 62 | use_gpu=use_gpu, 63 | callbacks=callbacks) 64 | analysis = tune.run( 65 | trainable, 66 | metric="loss", 67 | mode="min", 68 | config=config, 69 | num_samples=num_samples, 70 | resources_per_trial=get_tune_resources( 71 | num_workers=num_workers, use_gpu=use_gpu), 72 | name="tune_mnist") 73 | 74 | print("Best hyperparameters found were: ", analysis.best_config) 75 | 76 | 77 | if __name__ == "__main__": 78 | import argparse 79 | 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument( 82 | "--num-workers", 83 | type=int, 84 | help="Number of training workers to use.", 85 | default=1) 86 | parser.add_argument( 87 | "--use-gpu", action="store_true", help="Use GPU for training.") 88 | parser.add_argument( 89 | "--num-samples", 90 | type=int, 91 | default=10, 92 | help="Number of samples to tune.") 93 | parser.add_argument( 94 | "--num-epochs", 95 | type=int, 96 | default=10, 97 | help="Number of epochs to train for.") 98 | parser.add_argument( 99 | "--smoke-test", action="store_true", help="Finish quickly for testing") 100 | parser.add_argument( 101 | "--address", 102 | required=False, 103 | type=str, 104 | help="the address to use for Ray") 105 | args, _ = parser.parse_known_args() 106 | 107 | num_epochs = 1 if args.smoke_test else args.num_epochs 108 | num_workers = 1 if args.smoke_test else args.num_workers 109 | use_gpu = False if args.smoke_test else args.use_gpu 110 | num_samples = 1 if args.smoke_test else args.num_samples 111 | 112 | if args.smoke_test: 113 | ray.init(num_cpus=2) 114 | else: 115 | ray.init(address=args.address) 116 | 117 | data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") 118 | tune_mnist(data_dir, num_samples, num_epochs, num_workers, use_gpu) 119 | -------------------------------------------------------------------------------- /ray_lightning/examples/ray_horovod_example.py: -------------------------------------------------------------------------------- 1 | """Example using Pytorch Lightning with a Horovod on Ray Accelerator.""" 2 | 
import os 3 | import tempfile 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from torch.utils.data import random_split, DataLoader 8 | from torchvision.datasets import MNIST 9 | from torchvision import transforms 10 | 11 | import ray 12 | from ray import tune 13 | from ray_lightning.tune import TuneReportCallback, get_tune_resources 14 | from ray_lightning import HorovodRayStrategy 15 | from ray_lightning.tests.utils import LightningMNISTClassifier 16 | 17 | 18 | class MNISTClassifier(LightningMNISTClassifier): 19 | def __init__(self, config, data_dir=None): 20 | super().__init__(config, data_dir) 21 | self.batch_size = config["batch_size"] 22 | 23 | def prepare_data(self): 24 | self.dataset = MNIST( 25 | self.data_dir, 26 | train=True, 27 | download=True, 28 | transform=transforms.ToTensor()) 29 | 30 | def train_dataloader(self): 31 | dataset = self.dataset 32 | train_length = len(dataset) 33 | dataset_train, _ = random_split( 34 | dataset, [train_length - 5000, 5000], 35 | generator=torch.Generator().manual_seed(0)) 36 | loader = DataLoader( 37 | dataset_train, 38 | batch_size=self.batch_size, 39 | shuffle=True, 40 | num_workers=1, 41 | drop_last=True, 42 | pin_memory=True, 43 | ) 44 | return loader 45 | 46 | def val_dataloader(self): 47 | dataset = self.dataset 48 | train_length = len(dataset) 49 | _, dataset_val = random_split( 50 | dataset, [train_length - 5000, 5000], 51 | generator=torch.Generator().manual_seed(0)) 52 | loader = DataLoader( 53 | dataset_val, 54 | batch_size=self.batch_size, 55 | shuffle=False, 56 | num_workers=1, 57 | drop_last=True, 58 | pin_memory=True, 59 | ) 60 | return loader 61 | 62 | 63 | def train_mnist(config, 64 | data_dir=None, 65 | num_epochs=10, 66 | num_workers=4, 67 | use_gpu=False, 68 | callbacks=None): 69 | model = MNISTClassifier(config, data_dir) 70 | 71 | callbacks = callbacks or [] 72 | 73 | trainer = pl.Trainer( 74 | max_epochs=num_epochs, 75 | callbacks=callbacks, 76 | strategy=HorovodRayStrategy(num_workers=num_workers, use_gpu=use_gpu)) 77 | trainer.fit(model) 78 | 79 | 80 | def tune_mnist(data_dir, 81 | num_samples=10, 82 | num_epochs=10, 83 | num_workers=4, 84 | use_gpu=False): 85 | config = { 86 | "layer_1": tune.choice([32, 64, 128]), 87 | "layer_2": tune.choice([64, 128, 256]), 88 | "lr": tune.loguniform(1e-4, 1e-1), 89 | "batch_size": tune.choice([32, 64, 128]), 90 | } 91 | 92 | # Add Tune callback. 
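# TuneReportCallback forwards the PTL-logged metrics named below to Ray Tune at
# the end of each validation epoch, and get_tune_resources() builds the per-trial
# resource request so a trial reserves a driver bundle plus one bundle per Ray
# worker. A minimal sketch of that wiring (metric names are assumed to match what
# the LightningModule logs; "my_trainable" is illustrative only):
#
#     callback = TuneReportCallback(
#         {"loss": "ptl/val_loss"}, on="validation_end")
#     tune.run(
#         my_trainable,
#         resources_per_trial=get_tune_resources(
#             num_workers=2, use_gpu=False))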
93 | metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"} 94 | callbacks = [TuneReportCallback(metrics, on="validation_end")] 95 | trainable = tune.with_parameters( 96 | train_mnist, 97 | data_dir=data_dir, 98 | num_epochs=num_epochs, 99 | num_workers=num_workers, 100 | use_gpu=use_gpu, 101 | callbacks=callbacks) 102 | analysis = tune.run( 103 | trainable, 104 | metric="loss", 105 | mode="min", 106 | config=config, 107 | num_samples=num_samples, 108 | resources_per_trial=get_tune_resources( 109 | num_workers=num_workers, use_gpu=use_gpu), 110 | name="tune_mnist") 111 | 112 | print("Best hyperparameters found were: ", analysis.best_config) 113 | 114 | 115 | if __name__ == "__main__": 116 | import argparse 117 | 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument( 120 | "--num-workers", 121 | type=int, 122 | help="Number of training workers to use.", 123 | default=1) 124 | parser.add_argument( 125 | "--use-gpu", action="store_true", help="Use GPU for " 126 | "training.") 127 | parser.add_argument( 128 | "--tune", 129 | action="store_true", 130 | help="Use Ray Tune " 131 | "for " 132 | "hyperparameter " 133 | "tuning.") 134 | parser.add_argument( 135 | "--num-samples", 136 | type=int, 137 | default=10, 138 | help="Number " 139 | "of " 140 | "samples to tune.") 141 | parser.add_argument( 142 | "--num-epochs", 143 | type=int, 144 | default=10, 145 | help="Number " 146 | "of " 147 | "epochs " 148 | "to train for.") 149 | parser.add_argument( 150 | "--smoke-test", action="store_true", help="Finish quickly for testing") 151 | parser.add_argument( 152 | "--address", 153 | required=False, 154 | type=str, 155 | help="the address to use for Ray") 156 | args, _ = parser.parse_known_args() 157 | 158 | num_epochs = 1 if args.smoke_test else args.num_epochs 159 | num_workers = 1 if args.smoke_test else args.num_workers 160 | use_gpu = False if args.smoke_test else args.use_gpu 161 | num_samples = 1 if args.smoke_test else args.num_samples 162 | 163 | if args.smoke_test: 164 | ray.init(num_cpus=2) 165 | else: 166 | ray.init(address=args.address) 167 | 168 | data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") 169 | 170 | if args.tune: 171 | tune_mnist(data_dir, num_samples, num_epochs, num_workers, use_gpu) 172 | else: 173 | config = {"layer_1": 32, "layer_2": 64, "lr": 1e-1, "batch_size": 32} 174 | train_mnist(config, data_dir, num_epochs, num_workers, use_gpu) 175 | -------------------------------------------------------------------------------- /ray_lightning/launchers/__init__.py: -------------------------------------------------------------------------------- 1 | from ray_lightning.launchers.ray_launcher import RayLauncher 2 | from ray_lightning.launchers.ray_horovod_launcher import RayHorovodLauncher 3 | 4 | __all__ = ["RayLauncher", "RayHorovodLauncher"] 5 | -------------------------------------------------------------------------------- /ray_lightning/launchers/ray_horovod_launcher.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any, Optional 2 | 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.strategies.launchers import _Launcher 5 | from pytorch_lightning.utilities.apply_func import apply_to_collection, \ 6 | move_data_to_device 7 | import numpy as np 8 | import torch 9 | 10 | import ray 11 | from pytorch_lightning.utilities.rank_zero import rank_zero_debug 12 | from pytorch_lightning.strategies import Strategy 13 | from ray.util.queue import Queue 14 | 15 | from 
ray_lightning.session import init_session 16 | from ray_lightning.util import process_results, Unavailable, to_state_stream, \ 17 | load_state_stream, set_cuda_device_if_used 18 | 19 | from ray_lightning.tune import TUNE_INSTALLED, is_session_enabled 20 | 21 | try: 22 | import horovod.torch as hvd 23 | from horovod.ray import RayExecutor 24 | except (ModuleNotFoundError, ImportError): 25 | HOROVOD_AVAILABLE = False 26 | RayExecutor = Unavailable 27 | hvd = Unavailable 28 | else: 29 | HOROVOD_AVAILABLE = True 30 | 31 | from pytorch_lightning.utilities import rank_zero_only 32 | from ray_lightning.accelerators import \ 33 | _GPUAccelerator # noqa: F401 34 | from ray_lightning.launchers.utils import _RayOutput, get_executable_cls 35 | 36 | 37 | class RayHorovodLauncher(_Launcher): 38 | def __init__(self, strategy: "Strategy") -> None: 39 | """Initialize the Ray horovod launcher.""" 40 | self._strategy = strategy 41 | self._executor = strategy.executor 42 | 43 | if not ray.is_initialized(): 44 | ray.init() 45 | 46 | self.tune_queue = None 47 | 48 | @property 49 | def global_rank(self) -> int: 50 | """Return the global rank of the current process. 51 | 52 | This function is run on the worker process. 53 | """ 54 | if not hvd.is_initialized(): 55 | return 0 56 | return hvd.rank() 57 | 58 | @property 59 | def local_rank(self) -> int: 60 | """Return the local rank of the current process. 61 | 62 | This function is run on the worker process. 63 | """ 64 | if not hvd.is_initialized(): 65 | return 0 66 | return hvd.local_rank() 67 | 68 | @property 69 | def world_size(self) -> int: 70 | """Return the world size of the current process. 71 | 72 | This function is run on the worker process. 73 | """ 74 | if not hvd.is_initialized(): 75 | return self.num_workers 76 | return hvd.size() 77 | 78 | def is_interactive_compatible(self) -> bool: 79 | """Return whether the launcher is interactive compatible.""" 80 | return True 81 | 82 | def launch(self, 83 | function: Callable, 84 | *args: Any, 85 | trainer: Optional["pl.Trainer"] = None, 86 | **kwargs: Any) -> Any: 87 | """Launch the function on the workers and collect the results. 88 | 89 | This function is run on the driver process. 90 | """ 91 | ray_output = self.run_function_on_workers( 92 | function, *args, trainer=trainer, **kwargs) 93 | 94 | if trainer is None: 95 | raise NotImplementedError( 96 | "Ray launcher does not support trainer is None! " 97 | "Did you override the `trainer` variable? " 98 | "If not, please help file an issue on Github.") 99 | self._recover_results_in_main_process(ray_output, trainer) 100 | return_value = ray_output.trainer_results 101 | 102 | return return_value 103 | 104 | def run_function_on_workers(self, 105 | function: Callable, 106 | *args: Any, 107 | trainer: Optional["pl.Trainer"] = None, 108 | **kwargs: Any): 109 | """Run the function on the workers and collect the results. 110 | 111 | This function is run on the driver process. 112 | 113 | `executor.run_remote` is used to launch multiple ray remote tasks 114 | to distributed training the model using the horovod backend. 
115 | """ 116 | 117 | # put the model as the ray object 118 | # this reduce the memory comsumption 119 | # and remove the model temporarily from the args 120 | model = trainer.model 121 | model_ref = ray.put(model) 122 | trainer.model = None 123 | # the model always be at the 0th position in the args 124 | new_args = tuple([None] + list(args[1:])) 125 | 126 | # remove the executor temporarily from the args 127 | # in order to avoid the ray.get() call in the function 128 | # because executor is not pickleable 129 | executor = self._executor 130 | self._executor = None 131 | self._strategy.executor = None 132 | 133 | executor.start(executable_cls=get_executable_cls()) 134 | 135 | if TUNE_INSTALLED and is_session_enabled(): 136 | # Create communication queue and send to all the workers. 137 | self.tune_queue = Queue(actor_options={"num_cpus": 0}) 138 | 139 | self._futures = executor.run_remote(lambda: self._wrapping_function( 140 | function, model_ref, new_args, kwargs, self.tune_queue)) 141 | 142 | # put back the executor and model 143 | self._executor = executor 144 | self._strategy.executor = executor 145 | trainer.model = model 146 | 147 | results = process_results(self._futures, self.tune_queue) 148 | executor.shutdown() 149 | self._strategy.teardown() 150 | 151 | return results[0] 152 | 153 | def _wrapping_function( 154 | self, 155 | function: Callable, 156 | model_ref: Any, 157 | args: Any, 158 | kwargs: Any, 159 | tune_queue: Queue, 160 | ) -> Any: 161 | """Wrapping function to run the function on the workers. 162 | 163 | This function is run on the worker process. 164 | 165 | `_wrapping_function` is run on each remote worker. 166 | `function(*args, **kwargs)` is where the actual training happens. 167 | """ 168 | 169 | self._strategy.set_remote(True) 170 | 171 | # `function` is a trainer's instance method 172 | # in the ray remote tasks, its bound instance `trainer` 173 | # will also be copied when the function is remoted. 174 | # 175 | # ALERT: passing the trainer as an argument of `_wrapping_function` 176 | # does not fulfill our purpose. Ray remote tasks will 177 | # create another copy of trainer so that 178 | # `function.__self__ != trainer`, in which the side effect only 179 | # happens to `function.__self__` when running 180 | # `function(*args, **kwargs)` (see SOLUTION below). 181 | # 182 | # SOLUTION: we find the trainer directly from `function` 183 | # by calling `function.__self__` so that we can restore 184 | # all the side effects happened to `function.__self__` 185 | trainer = function.__self__ 186 | model = ray.get(model_ref) 187 | trainer.model = model 188 | args = tuple([model] + list(args[1:])) 189 | 190 | trainer._data_connector.prepare_data() 191 | 192 | hvd.init() 193 | rank_zero_only.rank = self.global_rank 194 | set_cuda_device_if_used(trainer.strategy) 195 | 196 | # Move the model to the appropriate device. 197 | trainer.strategy.model_to_device() 198 | 199 | if tune_queue is not None: 200 | # Initialize session. 201 | init_session(rank=self.global_rank, queue=tune_queue) 202 | 203 | results = function(*args, **kwargs) 204 | 205 | if trainer is not None: 206 | results = self._collect_rank_zero_results(function.__self__, 207 | results) 208 | 209 | if self.local_rank == 0: 210 | return move_data_to_device(results, "cpu") 211 | 212 | return None 213 | 214 | def _collect_rank_zero_results(self, trainer: "pl.Trainer", 215 | results: Any) -> Optional["_RayOutput"]: 216 | """Collect the results from the rank zero process. 
217 | 218 | This function is run on the worker process. 219 | """ 220 | rank_zero_debug("Finalizing the ray horovod launcher environment.") 221 | checkpoint_callback = trainer.checkpoint_callback 222 | best_model_path = checkpoint_callback.best_model_path \ 223 | if checkpoint_callback else None 224 | 225 | state_dict = trainer.lightning_module.state_dict() 226 | 227 | if self._strategy.global_rank != 0: 228 | return None 229 | 230 | # PyTorch Lightning saves the model weights in a temp file and 231 | # loads it back on the driver. 232 | # This won't work in a multi-node setup though, so we return the 233 | # model state stream directly. 234 | model_state_stream = to_state_stream(state_dict) 235 | 236 | # adds the `callback_metrics` 237 | callback_metrics: dict = apply_to_collection( 238 | trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy( 239 | )) # send as numpy to avoid issues with memory sharing 240 | 241 | # Same for logged_metrics 242 | logged_metrics: dict = apply_to_collection( 243 | trainer.logged_metrics, torch.Tensor, lambda x: x.cpu().numpy( 244 | )) # send as numpy to avoid issues with memory sharing 245 | 246 | return _RayOutput(best_model_path, model_state_stream, trainer.state, 247 | results, callback_metrics, logged_metrics) 248 | 249 | def _recover_results_in_main_process(self, ray_output: "_RayOutput", 250 | trainer: "pl.Trainer") -> None: 251 | """Recover the results in the main process. 252 | 253 | This function is run on the worker process. 254 | """ 255 | # transfer back the best path to the trainer 256 | if trainer.checkpoint_callback: 257 | trainer.checkpoint_callback.best_model_path = str( 258 | ray_output.best_model_path) 259 | 260 | if ray_output.weights_path is not None: 261 | state_stream = ray_output.weights_path 262 | # DDPSpawnPlugin.__recover_child_process_weights begin 263 | # Difference here is that instead of writing the model weights to a 264 | # file and loading it, we use the state dict of the model directly. 265 | state_dict = load_state_stream( 266 | state_stream, to_gpu=self._strategy.use_gpu) 267 | # Set the state for PTL using the output from remote training. 
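# A rough sketch of the state-stream round trip used here, assuming that
# to_state_stream()/load_state_stream() in ray_lightning.util wrap
# torch.save()/torch.load() over an in-memory buffer rather than a temp file:
#
#     import io
#     buf = io.BytesIO()
#     torch.save(state_dict, buf)                        # worker side
#     restored = torch.load(
#         io.BytesIO(buf.getvalue()), map_location="cpu")  # driver side
#     trainer.lightning_module.load_state_dict(restored)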
268 | trainer.lightning_module.load_state_dict(state_dict) 269 | 270 | trainer.state = ray_output.trainer_state 271 | 272 | trainer.callback_metrics.update( 273 | apply_to_collection(ray_output.callback_metrics, 274 | np.ndarray, lambda x: torch.tensor(x))) 275 | trainer.logged_metrics.update( 276 | apply_to_collection(ray_output.logged_metrics, 277 | np.ndarray, lambda x: torch.tensor(x))) 278 | -------------------------------------------------------------------------------- /ray_lightning/launchers/ray_launcher.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Any, Tuple, Optional 2 | 3 | from collections import defaultdict 4 | import os 5 | 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.strategies.launchers import _Launcher 8 | from pytorch_lightning.utilities.apply_func import apply_to_collection,\ 9 | move_data_to_device 10 | import numpy as np 11 | import torch 12 | 13 | import ray 14 | from ray import ObjectRef 15 | from pytorch_lightning.utilities.rank_zero import rank_zero_debug 16 | from ray.util.queue import Queue 17 | 18 | from ray_lightning.session import init_session 19 | from ray_lightning.util import process_results, to_state_stream, \ 20 | load_state_stream, set_cuda_device_if_used 21 | from ray_lightning.tune import TUNE_INSTALLED, is_session_enabled 22 | from pytorch_lightning.strategies import Strategy 23 | from ray_lightning.launchers.utils import _RayOutput, find_free_port,\ 24 | RayExecutor 25 | 26 | 27 | class RayLauncher(_Launcher): 28 | def __init__(self, strategy: "Strategy") -> None: 29 | """Initializes RayLauncher.""" 30 | self._strategy = strategy 31 | self._start_method = "ray" 32 | self._workers = [] 33 | self._futures = [] 34 | self._master_addr = None 35 | self._master_port = None 36 | 37 | self._global_to_local = None 38 | 39 | self.tune_queue = None 40 | 41 | if not ray.is_initialized(): 42 | ray.init() 43 | 44 | def is_interactive_compatible(self) -> bool: 45 | """Returns True if the launcher is interactive compatible.""" 46 | return True 47 | 48 | def launch(self, 49 | function: Callable, 50 | *args: Any, 51 | trainer: Optional["pl.Trainer"] = None, 52 | **kwargs: Any) -> Any: 53 | """Launches the function on the workers from the driver node. 54 | 55 | This function is run on the driver process. 56 | """ 57 | self.setup_workers() 58 | ray_output = self.run_function_on_workers( 59 | function, *args, trainer=trainer, **kwargs) 60 | 61 | if trainer is None: 62 | raise NotImplementedError( 63 | "Ray launcher does not support trainer is None!") 64 | self._recover_results_in_main_process(ray_output, trainer) 65 | return_value = ray_output.trainer_results 66 | 67 | self.teardown_workers() 68 | self._strategy.teardown() 69 | return return_value 70 | 71 | def setup_workers(self, tune_enabled: bool = True) -> None: 72 | """Creates the Ray actors and sets up PTL Trainer environment. 73 | 74 | This function is run on the driver process. 75 | """ 76 | self._workers = [ 77 | self._create_worker() for _ in range(self._strategy.num_workers) 78 | ] 79 | if self._strategy.init_hook: 80 | ray.get([ 81 | w.execute.remote(self._strategy.init_hook) 82 | for w in self._workers 83 | ]) 84 | 85 | self._master_addr = ray.get(self._workers[0].get_node_ip.remote()) 86 | self._master_port = str( 87 | ray.get(self._workers[0].execute.remote(find_free_port))) 88 | 89 | # Sets environment variables for all workers. 90 | # This will set the MASTER_ADDR and MASTER_PORT on each Ray actor. 
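# The rendezvous follows torch.distributed's standard "env://" scheme: every rank
# reads MASTER_ADDR and MASTER_PORT from its environment and then joins the
# process group. A minimal sketch of the same idea outside Ray (address, port,
# rank and world_size values here are illustrative only):
#
#     os.environ["MASTER_ADDR"] = "10.0.0.1"
#     os.environ["MASTER_PORT"] = "29500"
#     torch.distributed.init_process_group(
#         "gloo", rank=rank, world_size=world_size, init_method="env://")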
91 | self._setup_env_vars() 92 | 93 | if self._strategy.use_gpu: 94 | # Set the CUDA_VISIBLE_DEVICES for all workers. 95 | self._share_cuda_visible_devices() 96 | 97 | # Get the mapping from global ranks to the respective local ranks. 98 | self._global_to_local = self.get_local_ranks() 99 | # Todo: put model into object store? 100 | 101 | if tune_enabled and TUNE_INSTALLED and is_session_enabled(): 102 | # Create communication queue and send to all the workers. 103 | self.tune_queue = Queue(actor_options={"num_cpus": 0}) 104 | 105 | def _create_worker(self) -> ray.actor.ActorHandle: 106 | """Creates Ray actor workers. 107 | 108 | This function is run on the driver process. 109 | """ 110 | worker = RayExecutor.options( 111 | num_cpus=self._strategy.num_cpus_per_worker, 112 | num_gpus=self._strategy.num_gpus_per_worker, 113 | resources=self._strategy.additional_resources_per_worker).remote() 114 | return worker 115 | 116 | def teardown_workers(self): 117 | """Tears down the Ray actors and PTL Trainer environment 118 | 119 | This function is run on the driver process. 120 | """ 121 | if self.tune_queue: 122 | # Shutdown the queue. 123 | self.tune_queue.shutdown() 124 | 125 | for w in self._workers: 126 | ray.kill(w, no_restart=True) 127 | del w 128 | self._workers = [] 129 | 130 | def get_local_ranks(self) -> List[Optional[Tuple[int, int]]]: 131 | """Creates a mapping of global ranks to local ranks/node ranks. 132 | 133 | This function is run on the driver process. 134 | """ 135 | # Get the local ranks for all the workers and store as a list. 136 | # First get the IP address of each remote worker. 137 | node_ips = ray.get([w.get_node_ip.remote() for w in self._workers]) 138 | 139 | node_rank_map = {} 140 | counter = 0 141 | for ip in node_ips: 142 | # If this is a new IP address, then increment counter. 143 | if ip not in node_rank_map: 144 | node_rank_map[ip] = counter 145 | counter += 1 146 | 147 | rank_counter_dict = defaultdict(int) 148 | global_to_local = [None] * self._strategy.num_workers 149 | 150 | for global_rank in range(self._strategy.num_workers): 151 | ip = node_ips[global_rank] 152 | global_to_local[global_rank] = ( 153 | rank_counter_dict[ip], # local rank 154 | node_rank_map[ip]) # node rank 155 | rank_counter_dict[ip] += 1 156 | 157 | return global_to_local 158 | 159 | def _setup_env_vars(self): 160 | """Sets environment variables for all workers. 161 | 162 | This function is run on the driver process. 163 | """ 164 | # Get rank 0 worker address and port for DDP connection. 165 | os.environ["MASTER_ADDR"] = self._master_addr 166 | os.environ["MASTER_PORT"] = self._master_port 167 | 168 | # Set environment variables for remote workers. 169 | keys = [ 170 | "PL_GLOBAL_SEED", "PL_TORCH_DISTRIBUTED_BACKEND", "MASTER_ADDR", 171 | "MASTER_PORT" 172 | ] 173 | values = [os.getenv(k) for k in keys] 174 | 175 | ray.get([w.set_env_vars.remote(keys, values) for w in self._workers]) 176 | 177 | def _share_cuda_visible_devices(self): 178 | """Sets CUDA_VISIBLE_DEVICES on all workers. 179 | 180 | This function is run on the driver process. 181 | 182 | For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs 183 | visible to all workers on that worker's node. 184 | This allows GPU workers on the same node to communicate with one 185 | another. 
186 | Example: 187 | Setup: 188 | - Node1: 189 | - Worker1: {0, 1} 190 | - Worker2: {2, 3} 191 | - Node2: 192 | - Worker3: {0, 1} 193 | CUDA_VISIBLE_DEVICES: 194 | - Worker1: "0,1,2,3" 195 | - Worker2: "0,1,2,3" 196 | - Worker2: "0,1" 197 | """ 198 | node_ids_and_gpu_ids = ray.get( 199 | [w.get_node_and_gpu_ids.remote() for w in self._workers]) 200 | 201 | node_id_to_worker_id = defaultdict(set) 202 | node_id_to_gpu_ids = defaultdict(set) 203 | 204 | for worker_id, (node_id, gpu_ids) in enumerate(node_ids_and_gpu_ids): 205 | node_id_to_worker_id[node_id].add(worker_id) 206 | node_id_to_gpu_ids[node_id].update(gpu_ids) 207 | 208 | futures = [] 209 | for node_id, gpu_ids in node_id_to_gpu_ids.items(): 210 | all_gpu_ids = ",".join([str(gpu_id) for gpu_id in gpu_ids]) 211 | 212 | def set_gpu_ids(): 213 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 214 | os.environ["CUDA_VISIBLE_DEVICES"] = all_gpu_ids 215 | 216 | for worker_id in node_id_to_worker_id[node_id]: 217 | futures.append( 218 | self._workers[worker_id].execute.remote(set_gpu_ids)) 219 | ray.get(futures) 220 | 221 | def run_function_on_workers(self, 222 | function: Callable, 223 | *args: Any, 224 | trainer: Optional["pl.Trainer"] = None, 225 | **kwargs: Any): 226 | """launch a function on all workers. 227 | 228 | This function is run on the driver process. 229 | 230 | The actual training parts are run inside `_wrapping_function` 231 | """ 232 | # put the model as the ray object 233 | # and remove the model temporarily from the args 234 | model = trainer.model 235 | model_ref = ray.put(model) 236 | trainer.model = None 237 | new_args = tuple([None] + list(args[1:])) 238 | 239 | # train the model and get the result to rank 0 node 240 | self._futures = [ 241 | w.execute.remote(self._wrapping_function, i, self._global_to_local, 242 | function, model_ref, new_args, kwargs, 243 | self.tune_queue) 244 | for i, w in enumerate(self._workers) 245 | ] 246 | 247 | trainer.model = model 248 | 249 | results = process_results(self._futures, self.tune_queue) 250 | return results[0] 251 | 252 | def _wrapping_function( 253 | self, 254 | global_rank: int, 255 | global_to_local: List[Optional[Tuple[int, int]]], 256 | function: Callable, 257 | model_ref: ObjectRef, 258 | args: Any, 259 | kwargs: Any, 260 | tune_queue: Queue, 261 | ) -> Any: 262 | """Wraps the function to run on the workers. 263 | 264 | This function is run on the worker process. 265 | 266 | `results = function(*args, **kwargs)` is where the 267 | actual training parts are run. 268 | """ 269 | self._strategy.set_remote(True) 270 | self._strategy.set_global_to_local(global_to_local) 271 | 272 | # `function` is a trainer's instance method 273 | # in the ray remote tasks, its bound instance `trainer` 274 | # will also be copied when the function is remoted. 275 | # 276 | # ALERT: passing the trainer as an argument of `_wrapping_function` 277 | # does not fulfill our purpose. Ray remote tasks will 278 | # create another copy of trainer so that 279 | # `function.__self__ != trainer`, in which the side effect only 280 | # happens to `function.__self__` when running 281 | # `function(*args, **kwargs)` (see SOLUTION below). 
282 | # 283 | # SOLUTION: we find the trainer directly from `function` 284 | # by calling `function.__self__` so that we can restore 285 | # all the side effects happened to `function.__self__` 286 | trainer = function.__self__ 287 | trainer.model = model_ref 288 | args = tuple([model_ref] + list(args[1:])) 289 | 290 | trainer._data_connector.prepare_data() 291 | if tune_queue is not None: 292 | # Initialize session. 293 | init_session(rank=global_rank, queue=tune_queue) 294 | 295 | self._strategy._worker_setup(process_idx=global_rank) 296 | trainer.strategy.root_device = self._strategy.root_device 297 | trainer.strategy.global_rank = self._strategy.global_rank 298 | trainer.strategy.local_rank = self._strategy.local_rank 299 | set_cuda_device_if_used(trainer.strategy) 300 | 301 | results = function(*args, **kwargs) 302 | 303 | if trainer is not None: 304 | return self._collect_rank_zero_results(trainer, results) 305 | else: 306 | return None 307 | 308 | trainer._teardown() 309 | trainer._call_teardown_hook() 310 | return None 311 | 312 | def _collect_rank_zero_results(self, trainer: "pl.Trainer", 313 | results: Any) -> Optional["_RayOutput"]: 314 | """Collects the results from the worker node 0. 315 | 316 | This function is run on the worker process. 317 | """ 318 | rank_zero_debug("Finalizing the Ray launcher environment.") 319 | checkpoint_callback = trainer.checkpoint_callback 320 | best_model_path = checkpoint_callback.best_model_path \ 321 | if checkpoint_callback else None 322 | 323 | state_dict = trainer.lightning_module.state_dict() 324 | 325 | if self._strategy.global_rank != 0: 326 | return None 327 | 328 | # Move state_dict to cpu before converting it to model state stream 329 | if trainer.strategy.local_rank == 0: 330 | state_dict = move_data_to_device(state_dict, "cpu") 331 | 332 | # PyTorch Lightning saves the model weights in a temp file and 333 | # loads it back on the driver. 334 | # This won't work in a multi-node setup though, so we return the 335 | # model state stream directly. 336 | model_state_stream = to_state_stream(state_dict) 337 | 338 | # adds the `callback_metrics` 339 | callback_metrics: dict = apply_to_collection( 340 | trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy( 341 | )) # send as numpy to avoid issues with memory sharing 342 | 343 | # Same for logged_metrics 344 | logged_metrics: dict = apply_to_collection( 345 | trainer.logged_metrics, torch.Tensor, lambda x: x.cpu().numpy( 346 | )) # send as numpy to avoid issues with memory sharing 347 | 348 | return _RayOutput(best_model_path, model_state_stream, trainer.state, 349 | results, callback_metrics, logged_metrics) 350 | 351 | def _recover_results_in_main_process(self, ray_output: "_RayOutput", 352 | trainer: "pl.Trainer") -> None: 353 | """Recovers the results in the main process. 354 | 355 | This function is run on the worker process. 356 | """ 357 | # transfer back the best path to the trainer 358 | if trainer.checkpoint_callback: 359 | trainer.checkpoint_callback.best_model_path = str( 360 | ray_output.best_model_path) 361 | 362 | if ray_output.weights_path is not None: 363 | state_stream = ray_output.weights_path 364 | # DDPSpawnPlugin.__recover_child_process_weights begin 365 | # Difference here is that instead of writing the model weights to a 366 | # file and loading it, we use the state dict of the model directly. 367 | state_dict = load_state_stream( 368 | state_stream, to_gpu=self._strategy.use_gpu) 369 | # Set the state for PTL using the output from remote training. 
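# Just below, the callback/logged metrics (shipped from the workers as numpy
# arrays to avoid sharing CUDA memory across processes) are converted back into
# CPU tensors on the driver. A minimal sketch of that conversion with an assumed
# metrics dict:
#
#     metrics = {"val_loss": np.array(0.25)}
#     restored = apply_to_collection(
#         metrics, np.ndarray, lambda x: torch.tensor(x))
#     # restored["val_loss"] is now a CPU torch.Tensor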
370 | trainer.lightning_module.load_state_dict(state_dict) 371 | 372 | trainer.state = ray_output.trainer_state 373 | 374 | trainer.callback_metrics.update( 375 | apply_to_collection(ray_output.callback_metrics, 376 | np.ndarray, lambda x: torch.tensor(x))) 377 | trainer.logged_metrics.update( 378 | apply_to_collection(ray_output.logged_metrics, 379 | np.ndarray, lambda x: torch.tensor(x))) 380 | -------------------------------------------------------------------------------- /ray_lightning/launchers/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, NamedTuple, Dict, List, Callable 2 | from pytorch_lightning.utilities.types import _PATH 3 | from pytorch_lightning.trainer.states import TrainerState 4 | 5 | from contextlib import closing 6 | import socket 7 | 8 | import ray 9 | import os 10 | 11 | 12 | def find_free_port(): 13 | """ Find a free port on the machines. """ 14 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: 15 | s.bind(("", 0)) 16 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 17 | return s.getsockname()[1] 18 | 19 | 20 | def get_executable_cls(): 21 | # Only used for testing purposes, currently. 22 | # Only used in `ray_horovod_launcher.py` 23 | # We need to override this in tests to ensure test path is set correctly. 24 | return None 25 | 26 | 27 | @ray.remote 28 | class RayExecutor: 29 | """A class to execute any arbitrary function remotely.""" 30 | 31 | def set_env_var(self, key: str, value: str): 32 | """Set an environment variable with the provided values.""" 33 | if value is not None: 34 | value = str(value) 35 | os.environ[key] = value 36 | 37 | def set_env_vars(self, keys: List[str], values: List[str]): 38 | """Sets multiple env vars with the provided values""" 39 | assert len(keys) == len(values) 40 | for key, value in zip(keys, values): 41 | self.set_env_var(key, value) 42 | 43 | def get_node_ip(self): 44 | """Returns the IP address of the node that this Ray actor is on.""" 45 | return ray.util.get_node_ip_address() 46 | 47 | def get_node_and_gpu_ids(self): 48 | return ray.get_runtime_context().node_id.hex(), ray.get_gpu_ids() 49 | 50 | def execute(self, fn: Callable, *args, **kwargs): 51 | """Execute the provided function and return the result.""" 52 | return fn(*args, **kwargs) 53 | 54 | 55 | class _RayOutput(NamedTuple): 56 | """Ray output tuple with the following fields: 57 | - `best_model_path`: path to the best model 58 | - `weights_path`: path to the weights 59 | - `trainer_state`: trainer state 60 | - `trainer_results`: trainer result 61 | - `callback_results`: callback result 62 | - `logged_metrics`: logged metrics 63 | """ 64 | best_model_path: Optional[_PATH] 65 | weights_path: Optional[_PATH] 66 | trainer_state: TrainerState 67 | trainer_results: Any 68 | callback_metrics: Dict[str, Any] 69 | logged_metrics: Dict[str, Any] 70 | -------------------------------------------------------------------------------- /ray_lightning/ray_ddp.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List, Union, Any, Tuple, Optional 2 | 3 | import warnings 4 | 5 | import torch 6 | 7 | from pytorch_lightning.strategies import DDPSpawnStrategy 8 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 9 | 10 | import ray 11 | from pytorch_lightning.utilities.rank_zero import rank_zero_info 12 | from pytorch_lightning.utilities.seed import reset_seed, log 13 | from ray.util import 
PublicAPI 14 | 15 | from ray_lightning.launchers import RayLauncher 16 | from ray_lightning.accelerators import \ 17 | _GPUAccelerator # noqa: F401 18 | 19 | import os 20 | 21 | 22 | @PublicAPI(stability="beta") 23 | class RayStrategy(DDPSpawnStrategy): 24 | """Pytorch Lightning strategy for DDP training on a Ray cluster. 25 | 26 | This strategy is used to manage distributed training using DDP and 27 | Ray for process launching. Internally, the specified number of 28 | Ray actors are launched in the cluster and are registered as part of a 29 | Pytorch DDP process group. The Pytorch Lightning trainer is instantiated 30 | on the driver and sent to each of these training workers where training is 31 | executed. The distributed training protocol is handled by Pytorch DDP. 32 | Each training worker is configured to reserve ``num_cpus_per_worker`` 33 | CPUs and 1 GPU if ``use_gpu`` is set to ``True``. 34 | If using this strategy, you should run your code like a normal Python 35 | script: ``python train.py``, and only on the head node if running in a 36 | distributed Ray cluster. There is no need to run this script on every 37 | single node. 38 | 39 | Args: 40 | num_workers (int): Number of training workers to use. 41 | num_cpus_per_worker (int): Number of CPUs per worker. 42 | use_gpu (bool): Whether to use GPU for allocation. For GPU to be 43 | used, you must also set the ``gpus`` arg in your Pytorch Lightning 44 | Trainer to a value > 0. 45 | init_hook (Callable): A function to run on each worker 46 | upon instantiation. 47 | resources_per_worker (Optional[Dict]): If specified, the resources 48 | defined in this Dict will be reserved for each worker. The 49 | ``CPU`` and ``GPU`` keys (case-sensitive) can be defined to 50 | override the number of CPU/GPUs used by each worker. 51 | **ddp_kwargs: Additional arguments to pass into 52 | ``DistributedDataParallel`` initialization. 53 | Example: 54 | .. code-block:: python 55 | 56 | import pytorch_lightning as pl 57 | from ray_lightning import RayStrategy 58 | ptl_model = MNISTClassifier(...) 59 | strategy = RayStrategy(num_workers=4, num_cpus_per_worker=1, 60 | use_gpu=True) 61 | # Don't set ``gpus`` in ``Trainer``. 62 | # The actual number of GPUs is determined by ``num_workers``. 63 | trainer = pl.Trainer(..., strategy=strategy) 64 | trainer.fit(ptl_model) 65 | """ 66 | 67 | strategy_name = "ddp_ray" 68 | 69 | def __init__(self, 70 | num_workers: int = 1, 71 | num_cpus_per_worker: int = 1, 72 | use_gpu: bool = False, 73 | init_hook: Optional[Callable] = None, 74 | resources_per_worker: Optional[Dict] = None, 75 | **ddp_kwargs: Union[Any, Dict[str, Any]]): 76 | """Initialize the Ray strategy.""" 77 | resources_per_worker = resources_per_worker if resources_per_worker \ 78 | else {} 79 | self.nickname = "ddp_ray" 80 | self.num_workers = int(num_workers) 81 | self.num_cpus_per_worker = resources_per_worker.pop( 82 | "CPU", num_cpus_per_worker) 83 | 84 | if "GPU" in resources_per_worker: 85 | self.num_gpus_per_worker = resources_per_worker.pop("GPU") 86 | else: 87 | self.num_gpus_per_worker = int(use_gpu) 88 | 89 | self.use_gpu = self.num_gpus_per_worker > 0 90 | 91 | if self.use_gpu and self.num_gpus_per_worker < 1 and num_workers > 1: 92 | warnings.warn("Less than 1 GPU is set per worker. " 93 | "If using NCCL backend (which is the default for " 94 | "GPU training), GPU devices cannot be shared " 95 | "across processes/workers and training is likely " 96 | "to fail. 
It is recommended to use 1 GPU per " 97 | "worker for training, or if you must use " 98 | "fractional GPUs, then use the gloo backend by " 99 | "setting PL_TORCH_DISTRIBUTED_BACKEND=gloo " 100 | "environment variable.") 101 | 102 | self.additional_resources_per_worker = resources_per_worker 103 | self.init_hook = init_hook 104 | 105 | self._local_rank = 0 106 | self._global_rank = 0 107 | self._node_rank = 0 108 | 109 | self._is_remote = False 110 | self._device = None 111 | 112 | super().__init__( 113 | accelerator="_gpu" if use_gpu else "cpu", 114 | parallel_devices=[], 115 | cluster_environment=None, 116 | **ddp_kwargs) 117 | 118 | def _configure_launcher(self): 119 | """Configure the Ray launcher. 120 | 121 | This function is overriding ddp_spawn_strategy's method. 122 | It is run on the driver process. 123 | 124 | the distributed training logic is handled by the launcher. 125 | """ 126 | self._launcher = RayLauncher(self) 127 | 128 | def set_remote(self, remote: bool): 129 | """Set the remote flag. (this is useful for the remote workers) 130 | 131 | This function is a new RayStrategy method. 132 | It is run on the worker processes. 133 | """ 134 | self._is_remote = remote 135 | 136 | def set_global_to_local(self, 137 | global_to_local: List[Optional[Tuple[int, int]]]): 138 | """Set the global to local rank mapping. 139 | 140 | This function is a new RayStrategy method. 141 | It is run on the worker processes. 142 | """ 143 | self.global_to_local = global_to_local 144 | 145 | def set_world_ranks(self, process_idx: int = 0): 146 | """Set the appropriate rank attributes for the trainer. 147 | 148 | This function is overriding ddp_spawn_strategy's method. 149 | It is run on the worker processes. 150 | """ 151 | # Ranks should only be set once all the actors are created and 152 | # training has begun (otherwise self.global_to_local has not been 153 | # initialized). 154 | # If this method is called on the driver (i.e. self._is_remote is 155 | # False, then do a no-op). 156 | if self._is_remote: 157 | self._global_rank = process_idx 158 | self._local_rank, self._node_rank = self.global_to_local[ 159 | self.global_rank] 160 | 161 | def _worker_setup(self, process_idx: int): 162 | """Setup the workers and pytorch DDP connections. 163 | 164 | This function is overriding ddp_spawn_strategy's method. 165 | It is run on the worker processes. 166 | """ 167 | reset_seed() 168 | self.set_world_ranks(process_idx) 169 | rank_zero_only.rank = self.global_rank 170 | self._process_group_backend = self._get_process_group_backend() 171 | 172 | # Copied from 173 | # pytorch_lightning.utilities.distributed.init_dist_connection 174 | if not torch.distributed.is_available(): 175 | raise RuntimeError("torch.distributed is not available. " 176 | "Cannot initialize distributed process group") 177 | 178 | if torch.distributed.is_initialized(): 179 | log.debug( 180 | "torch.distributed is already initialized. 
Exiting early") 181 | return 182 | 183 | global_rank = self.global_rank 184 | world_size = self.world_size 185 | torch_distributed_backend = self.torch_distributed_backend 186 | 187 | # Taken from pytorch_lightning.utilities.distributed 188 | if torch.distributed.is_available( 189 | ) and not torch.distributed.is_initialized(): 190 | log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, " 191 | f"MEMBER: {global_rank + 1}/{world_size}") 192 | torch.distributed.init_process_group( 193 | torch_distributed_backend, 194 | rank=global_rank, 195 | world_size=world_size, 196 | init_method="env://") 197 | 198 | # on rank=0 let everyone know training is starting 199 | rank_zero_info(f"{'-' * 100}\n" 200 | f"distributed_backend={torch_distributed_backend}\n" 201 | f"All distributed processes registered. " 202 | f"Starting with {world_size} processes\n" 203 | f"{'-' * 100}\n") 204 | 205 | @property 206 | def world_size(self) -> int: 207 | """Return the world size. 208 | 209 | This function is a new RayStrategy method. 210 | It is run on the worker processes. 211 | """ 212 | return self.num_workers 213 | 214 | @property 215 | def local_rank(self) -> int: 216 | """Return the local rank. 217 | 218 | This function is a new RayStrategy method. 219 | It is run on the worker processes. 220 | """ 221 | return self._local_rank 222 | 223 | @local_rank.setter 224 | def local_rank(self, value: int): 225 | """Set the local rank. 226 | 227 | This function is a new RayStrategy method. 228 | It is run on the worker processes. 229 | """ 230 | self._local_rank = value 231 | 232 | @property 233 | def global_rank(self) -> int: 234 | """Return the global rank. 235 | 236 | This function is a new RayStrategy method. 237 | It is run on the worker processes. 238 | """ 239 | return self._global_rank 240 | 241 | @global_rank.setter 242 | def global_rank(self, value: int): 243 | """Set the global rank. 244 | 245 | This function is a new RayStrategy method. 246 | It is run on the worker processes. 247 | """ 248 | self._global_rank = value 249 | 250 | @property 251 | def node_rank(self) -> int: 252 | """Return the node rank. 253 | 254 | This function is a new RayStrategy method. 255 | It is run on the worker processes. 256 | """ 257 | return self._node_rank 258 | 259 | @property 260 | def root_device(self): 261 | """Return the root device. 262 | 263 | This function is overriding ddp_spawn_strategy's method. 264 | It is run on the worker processes. 265 | """ 266 | # get the root device 267 | # if the root device not set, figure it out 268 | # thru `get_gpu_ids` if `use_gpu` is True 269 | if self._device: 270 | return self._device 271 | if self.use_gpu and torch.cuda.is_available(): 272 | if self._is_remote: 273 | # GPU IDs are assigned by Ray after you specify "use_gpu" 274 | # GPU `ray.get_gpu_ids()` may return ints or may return 275 | # strings. We should always convert to strings. 276 | gpu_ids = [str(id) for id in ray.get_gpu_ids()] 277 | 278 | if len(gpu_ids) > 0: 279 | # By default, there should only be one GPU ID if 280 | # `use_gpu=True`. 281 | # If there are multiple GPUs, use the first one. 282 | # If using fractional GPUs, these IDs are not guaranteed 283 | # to be unique across different processes. 
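# What the lines below compute, sketched with assumed values: if Ray assigned
# this worker GPU id "2" and the node-wide CUDA_VISIBLE_DEVICES (set up in the
# launcher's _share_cuda_visible_devices) is "0,1,2,3", the worker should train
# on local CUDA index 2:
#
#     gpu_ids = ["2"]                     # str(id) for id in ray.get_gpu_ids()
#     visible = "0,1,2,3".split(",")      # CUDA_VISIBLE_DEVICES
#     device = torch.device(f"cuda:{visible.index(gpu_ids[0])}")  # cuda:2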
284 | gpu_id = gpu_ids[0] 285 | 286 | cuda_visible_str = os.environ.get("CUDA_VISIBLE_DEVICES", 287 | "") 288 | if cuda_visible_str and cuda_visible_str != "NoDevFiles": 289 | cuda_visible_list = cuda_visible_str.split(",") 290 | device_id = cuda_visible_list.index(gpu_id) 291 | else: 292 | raise RuntimeError( 293 | "CUDA_VISIBLE_DEVICES set incorrectly. " 294 | f"Got {cuda_visible_str}, expected to include " 295 | "{gpu_id}. Did you override the " 296 | "`CUDA_VISIBLE_DEVICES` environment variable? " 297 | "If not, please help file an issue on Github.") 298 | else: 299 | # If the root device is requested on the driver, just return 300 | # the 0th device. 301 | device_id = 0 302 | return torch.device(f"cuda:{device_id}") 303 | else: 304 | return torch.device("cpu") 305 | 306 | @root_device.setter 307 | def root_device(self, device): 308 | """Set the root device. 309 | 310 | This function is a new RayStrategy method. 311 | It is run on the worker processes. 312 | """ 313 | self._device = device 314 | 315 | @property 316 | def distributed_sampler_kwargs(self): 317 | """Returns the args to use for torch.data.DistributedSampler. 318 | 319 | This function is overriding ddp_spawn_strategy's method. 320 | It is run on the worker processes. 321 | """ 322 | distributed_sampler_kwargs = dict( 323 | num_replicas=self.num_workers, rank=self.global_rank) 324 | return distributed_sampler_kwargs 325 | 326 | def teardown(self) -> None: 327 | """Teardown the workers and pytorch DDP connections. 328 | 329 | This function is overriding ddp_spawn_strategy's method. 330 | It is run on the driver processes. 331 | """ 332 | self.accelerator = None 333 | super().teardown() 334 | -------------------------------------------------------------------------------- /ray_lightning/ray_ddp_sharded.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.strategies import DDPSpawnShardedStrategy 2 | 3 | from ray.util import PublicAPI 4 | 5 | from ray_lightning import RayStrategy 6 | 7 | 8 | # C3 linearization of parent classes will do breadth first since both 9 | # RayStrategy and DDPSpawnShardedStrategy share 10 | # a common parent of DDPSpawnStrategy 11 | @PublicAPI(stability="beta") 12 | class RayShardedStrategy(RayStrategy, DDPSpawnShardedStrategy): 13 | strategy_name = "ddp_sharded_ray" 14 | -------------------------------------------------------------------------------- /ray_lightning/ray_horovod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ray.util import PublicAPI 4 | 5 | from pytorch_lightning.strategies import HorovodStrategy, ParallelStrategy 6 | import ray 7 | 8 | from ray_lightning.util import Unavailable 9 | 10 | try: 11 | import horovod.torch as hvd 12 | from horovod.ray import RayExecutor 13 | except (ModuleNotFoundError, ImportError): 14 | HOROVOD_AVAILABLE = False 15 | RayExecutor = Unavailable 16 | hvd = Unavailable 17 | else: 18 | HOROVOD_AVAILABLE = True 19 | 20 | from ray_lightning.launchers import RayHorovodLauncher 21 | from ray_lightning.accelerators import \ 22 | _GPUAccelerator # noqa: F401 23 | 24 | 25 | def get_executable_cls(): 26 | # Only used for testing purposes, currently. 27 | # We need to override this in tests to ensure test path is set correctly. 28 | return None 29 | 30 | 31 | @PublicAPI(stability="beta") 32 | class HorovodRayStrategy(HorovodStrategy): 33 | """Pytorch Lightning Strategy for Horovod training on a Ray cluster. 
34 | 35 | This strategy is used to manage distributed training on a Ray cluster 36 | via the Horovod training framework. Internally, the specified number of 37 | Ray actors are launched in the cluster and are configured as part of the 38 | Horovod ring. The Pytorch Lightning trainer is instantiated on the 39 | driver and sent to each of these training workers where training is 40 | executed. The distributed training protocol is handled by Horovod. 41 | 42 | Each training worker is configured to reserve 1 CPU and 1 GPU if 43 | ``use_gpu`` is set to ``True``. 44 | 45 | If using this strategy, you should run your code like a normal Python 46 | script: ``python train.py``, and not with ``horovodrun``. 47 | 48 | Args: 49 | num_workers (int): Number of training workers to use. 50 | num_cpus_per_worker (int): Number of CPUs per worker. 51 | use_gpu (bool): Whether to use GPU for allocation. For GPU to be 52 | used, you must also set the ``gpus`` arg in your Pytorch Lightning 53 | Trainer to a value > 0. 54 | 55 | Example: 56 | 57 | .. code-block:: python 58 | 59 | import pytorch_lightning as pl 60 | from ray_lightning import HorovodRayStrategy 61 | 62 | ptl_model = MNISTClassifier(...) 63 | strategy = HorovodRayStrategy(num_workers=2, use_gpu=True) 64 | 65 | # Don't set ``gpus`` in ``Trainer``. 66 | # The actual number of GPUs is determined by ``num_workers``. 67 | trainer = pl.Trainer(..., strategy=strategy) 68 | trainer.fit(ptl_model) 69 | 70 | """ 71 | strategy_name = "horovod_ray" 72 | 73 | def __init__(self, 74 | num_workers: int, 75 | num_cpus_per_worker: int = 1, 76 | use_gpu: bool = False): 77 | """Initialize HorovodRayStrategy.""" 78 | if not HOROVOD_AVAILABLE: 79 | raise RuntimeError("Please install Horovod to use this strategy.") 80 | if not ray.is_initialized(): 81 | ray.init() 82 | ParallelStrategy.__init__( 83 | self, accelerator="_gpu" if use_gpu else "cpu") 84 | self.num_workers = num_workers 85 | self.cpus_per_worker = num_cpus_per_worker 86 | self.use_gpu = use_gpu 87 | self.executor = None 88 | self._exit_stack = None 89 | self._local_rank = 0 90 | 91 | self._is_remote = False 92 | 93 | def _configure_launcher(self): 94 | """Configure the Ray launcher. 95 | 96 | This function is overriding horovod_strategy's method. 97 | It is run on the driver process. 98 | 99 | The horovod launcher is used to launch the Ray actors. 100 | """ 101 | settings = RayExecutor.create_settings(timeout_s=30) 102 | self.executor = RayExecutor( 103 | settings, 104 | num_workers=self.num_workers, 105 | cpus_per_worker=self.cpus_per_worker, 106 | use_gpu=self.use_gpu) 107 | 108 | self._launcher = RayHorovodLauncher(self) 109 | 110 | @property 111 | def global_rank(self) -> int: 112 | """Return the global rank of the current process. 113 | 114 | This function is overriding horovod_strategy's method. 115 | It is run on the worker processes. 116 | """ 117 | if not hvd.is_initialized(): 118 | return 0 119 | return hvd.rank() 120 | 121 | @property 122 | def local_rank(self) -> int: 123 | """Return the local rank of the current process. 124 | 125 | This function is overriding horovod_strategy's method. 126 | It is run on the worker processes. 127 | """ 128 | if not hvd.is_initialized(): 129 | return 0 130 | return hvd.local_rank() 131 | 132 | @property 133 | def world_size(self) -> int: 134 | """Return the world size of the current process. 135 | 136 | This function is overriding horovod_strategy's method. 137 | It is run on the worker processes. 
138 | """ 139 | if not hvd.is_initialized(): 140 | return self.num_workers 141 | return hvd.size() 142 | 143 | def teardown(self) -> None: 144 | """Teardown the strategy. 145 | 146 | This function is overriding horovod_strategy's method. 147 | It is run on the driver process. 148 | """ 149 | self.join() 150 | self.accelerator = None 151 | super().teardown() 152 | 153 | @property 154 | def is_distributed(self): 155 | """Return whether the strategy is distributed. 156 | 157 | This function is a new HorovodStrategy method. 158 | It is run on the worker processes. 159 | """ 160 | return True 161 | 162 | def set_remote(self, remote: bool): 163 | """Set the remote flag. 164 | 165 | This function is a new RayStrategy method. 166 | It is run on the worker processes. 167 | """ 168 | self._is_remote = remote 169 | 170 | @property 171 | def root_device(self): 172 | """Return the root device. 173 | 174 | This function is overriding horovod_strategy's method. 175 | It is run on the worker processes. 176 | """ 177 | if self.use_gpu and torch.cuda.is_available(): 178 | if hvd.is_initialized(): 179 | return torch.device("cuda", hvd.local_rank()) 180 | else: 181 | return torch.device("cuda", 0) 182 | else: 183 | return torch.device("cpu") 184 | -------------------------------------------------------------------------------- /ray_lightning/session.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from ray.util.queue import Queue 4 | 5 | 6 | class RayLightningSession: 7 | def __init__(self, rank: int, queue: Optional[Queue]): 8 | self._rank = rank 9 | self._queue = queue 10 | 11 | def get_actor_rank(self): 12 | return self._rank 13 | 14 | def set_queue(self, queue): 15 | self._queue = queue 16 | 17 | def put_queue(self, item): 18 | if self._queue is None: 19 | raise ValueError( 20 | "Trying to put something into session queue, but queue " 21 | "was not initialized. This is probably a bug, please raise " 22 | "an issue at " 23 | "https://github.com/ray-project/ray_lightning_accelerators") 24 | self._queue.put((self._rank, item)) 25 | 26 | 27 | _session = None 28 | 29 | 30 | def init_session(*args, **kwargs): 31 | global _session 32 | if _session: 33 | raise ValueError( 34 | "Trying to initialize RayLightningSession twice." 35 | "\nFIX THIS by not calling `init_session()` manually.") 36 | _session = RayLightningSession(*args, **kwargs) 37 | 38 | 39 | def get_session() -> RayLightningSession: 40 | global _session 41 | if not _session or not isinstance(_session, RayLightningSession): 42 | raise ValueError( 43 | "Trying to access RayLightningSession from outside an Pytorch " 44 | "Lightning run." 
45 | "\nFIX THIS by calling function in `session.py` like " 46 | "`get_actor_rank()` only from within an Pytorch Lightning actor " 47 | "session.") 48 | return _session 49 | 50 | 51 | def set_session_queue(queue: Queue): 52 | session = get_session() 53 | session.set_queue(queue) 54 | 55 | 56 | def get_actor_rank() -> int: 57 | session = get_session() 58 | return session.get_actor_rank() 59 | 60 | 61 | def put_queue(*args, **kwargs): 62 | session = get_session() 63 | session.put_queue(*args, **kwargs) 64 | -------------------------------------------------------------------------------- /ray_lightning/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/ray_lightning/24f5922ee069f5c72087e300034c9be11b65016f/ray_lightning/tests/__init__.py -------------------------------------------------------------------------------- /ray_lightning/tests/test_client.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import pytest 5 | 6 | import ray 7 | from ray.util.client.ray_client_helpers import ray_start_client_server 8 | 9 | 10 | @pytest.fixture 11 | def start_ray_client_server_2_cpus(): 12 | ray.init(num_cpus=2) 13 | with ray_start_client_server() as client: 14 | yield client 15 | 16 | 17 | def test_ddp_example(start_ray_client_server_2_cpus): 18 | assert ray.util.client.ray.is_connected() 19 | from ray_lightning.examples.ray_ddp_example import train_mnist 20 | data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") 21 | config = {"layer_1": 32, "layer_2": 64, "lr": 1e-1, "batch_size": 32} 22 | train_mnist(config, data_dir, num_epochs=1, num_workers=1, use_gpu=False) 23 | 24 | 25 | def test_ddp_example_tune(start_ray_client_server_2_cpus): 26 | assert ray.util.client.ray.is_connected() 27 | from ray_lightning.examples.ray_ddp_example import tune_mnist 28 | data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") 29 | tune_mnist( 30 | data_dir, num_samples=1, num_epochs=1, num_workers=1, use_gpu=False) 31 | -------------------------------------------------------------------------------- /ray_lightning/tests/test_client_2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import pytest 5 | 6 | import ray 7 | from ray.util.client.ray_client_helpers import ray_start_client_server 8 | 9 | 10 | @pytest.fixture 11 | def start_ray_client_server_2_cpus(): 12 | ray.init(num_cpus=2) 13 | with ray_start_client_server() as client: 14 | yield client 15 | 16 | 17 | def test_ddp_tune(start_ray_client_server_2_cpus): 18 | assert ray.util.client.ray.is_connected() 19 | from ray_lightning.examples.ray_ddp_tune import tune_mnist 20 | data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") 21 | tune_mnist( 22 | data_dir, num_samples=1, num_epochs=1, num_workers=1, use_gpu=False) 23 | -------------------------------------------------------------------------------- /ray_lightning/tests/test_client_3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import pytest 5 | 6 | import ray 7 | from ray.util.client.ray_client_helpers import ray_start_client_server 8 | 9 | 10 | @pytest.fixture 11 | def start_ray_client_server_2_cpus(): 12 | ray.init(num_cpus=2) 13 | with ray_start_client_server() as client: 14 | yield client 15 | 16 | 17 | def test_horovod_example(start_ray_client_server_2_cpus): 18 | 
assert ray.util.client.ray.is_connected() 19 | from ray_lightning.examples.ray_horovod_example import train_mnist 20 | data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") 21 | config = {"layer_1": 32, "layer_2": 64, "lr": 1e-1, "batch_size": 32} 22 | train_mnist(config, data_dir, num_epochs=1, num_workers=1, use_gpu=False) 23 | 24 | 25 | def test_horovod_example_tune(start_ray_client_server_2_cpus): 26 | assert ray.util.client.ray.is_connected() 27 | from ray_lightning.examples.ray_horovod_example import tune_mnist 28 | data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") 29 | tune_mnist( 30 | data_dir, num_samples=1, num_epochs=1, num_workers=1, use_gpu=False) 31 | -------------------------------------------------------------------------------- /ray_lightning/tests/test_ddp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ray.util.client.ray_client_helpers import ray_start_client_server 3 | import torch 4 | from torch.utils.data import DistributedSampler 5 | 6 | from pl_bolts.datamodules import MNISTDataModule 7 | import pytorch_lightning as pl 8 | from pytorch_lightning import Callback 9 | from pytorch_lightning.callbacks import EarlyStopping 10 | 11 | import ray 12 | from ray.cluster_utils import Cluster 13 | 14 | from ray_lightning import RayStrategy 15 | from ray_lightning.tests.utils import get_trainer, train_test, \ 16 | load_test, predict_test, BoringModel, LightningMNISTClassifier, \ 17 | XORModel, XORDataModule 18 | 19 | 20 | @pytest.fixture 21 | def ray_start_2_cpus(): 22 | address_info = ray.init(num_cpus=2) 23 | yield address_info 24 | ray.shutdown() 25 | 26 | 27 | @pytest.fixture 28 | def ray_start_4_cpus(): 29 | address_info = ray.init(num_cpus=4) 30 | yield address_info 31 | ray.shutdown() 32 | 33 | 34 | @pytest.fixture 35 | def ray_start_4_cpus_4_extra(): 36 | address_info = ray.init(num_cpus=4, resources={"extra": 4}) 37 | yield address_info 38 | # The code after the yield will run as teardown code. 
39 | ray.shutdown() 40 | 41 | 42 | @pytest.fixture 43 | def start_ray_client_server_2_cpus(): 44 | ray.init(num_cpus=2) 45 | with ray_start_client_server() as client: 46 | yield client 47 | 48 | 49 | @pytest.fixture 50 | def seed(): 51 | pl.seed_everything(0) 52 | 53 | 54 | @pytest.fixture 55 | def ray_start_cluster_2_node_2_cpu(): 56 | cluster = Cluster() 57 | cluster.add_node(num_cpus=2) 58 | cluster.add_node(num_cpus=2) 59 | address_info = ray.init(cluster.address) 60 | yield address_info 61 | ray.shutdown() 62 | cluster.shutdown() 63 | 64 | 65 | @pytest.mark.parametrize("num_workers", [1, 2]) 66 | def test_actor_creation(tmpdir, ray_start_2_cpus, num_workers): 67 | """Tests whether the appropriate number of training actors are created.""" 68 | model = BoringModel() 69 | 70 | def check_num_actor(): 71 | assert len(ray.state.actors()) == num_workers 72 | 73 | model.on_epoch_end = check_num_actor 74 | 75 | strategy = RayStrategy(num_workers=num_workers) 76 | trainer = get_trainer(tmpdir, strategy=strategy) 77 | trainer.fit(model) 78 | 79 | 80 | def test_global_local_ranks(ray_start_4_cpus): 81 | """Tests local rank and node rank map is correct.""" 82 | 83 | @ray.remote 84 | class Node1Actor: 85 | def get_node_ip(self): 86 | return "1" 87 | 88 | @ray.remote 89 | class Node2Actor: 90 | def get_node_ip(self): 91 | return "2" 92 | 93 | strategy = RayStrategy(num_workers=4, use_gpu=False) 94 | strategy._configure_launcher() 95 | 96 | # 2 workers on "Node 1", 2 workers on "Node 2" 97 | strategy._launcher._workers = [ 98 | Node1Actor.remote(), 99 | Node1Actor.remote(), 100 | Node2Actor.remote(), 101 | Node2Actor.remote() 102 | ] 103 | 104 | global_to_local = strategy._launcher.get_local_ranks() 105 | assert len(global_to_local) == 4 106 | local_ranks = {ranks[0] for ranks in global_to_local} 107 | node_ranks = {ranks[1] for ranks in global_to_local} 108 | 109 | assert local_ranks == set(range(2)) 110 | assert node_ranks == set(range(2)) 111 | 112 | # Make sure the rank 0 worker has local rank and node rank of 0. 
113 | assert global_to_local[0][0] == 0 114 | assert global_to_local[0][1] == 0 115 | 116 | 117 | @pytest.mark.parametrize("num_workers", [1, 2]) 118 | @pytest.mark.parametrize("extra_resource_per_worker", [1, 2]) 119 | @pytest.mark.parametrize("num_cpus_per_worker", [1, 2]) 120 | def test_actor_creation_resources(tmpdir, ray_start_4_cpus_4_extra, 121 | num_workers, extra_resource_per_worker, 122 | num_cpus_per_worker): 123 | """Tests if training actors are created with custom resources.""" 124 | model = BoringModel() 125 | strategy = RayStrategy( 126 | num_workers=num_workers, 127 | num_cpus_per_worker=num_cpus_per_worker, 128 | resources_per_worker={"extra": 1}) 129 | 130 | def check_num_actor(): 131 | assert len(ray.state.actors()) == num_workers 132 | 133 | model.on_epoch_end = check_num_actor 134 | trainer = get_trainer(tmpdir, strategy=strategy) 135 | trainer.fit(model) 136 | 137 | 138 | def test_resource_override(ray_start_2_cpus): 139 | """Tests if CPU and GPU resources are overridden if manually passed in.""" 140 | 141 | strategy = RayStrategy(num_workers=1, num_cpus_per_worker=2, use_gpu=True) 142 | assert strategy.num_cpus_per_worker == 2 143 | assert strategy.use_gpu 144 | 145 | strategy = RayStrategy( 146 | num_workers=1, 147 | num_cpus_per_worker=2, 148 | use_gpu=True, 149 | resources_per_worker={"CPU": 3}) 150 | assert strategy.num_cpus_per_worker == 3 151 | assert strategy.use_gpu 152 | 153 | strategy = RayStrategy( 154 | num_workers=1, 155 | num_cpus_per_worker=2, 156 | use_gpu=True, 157 | resources_per_worker={"GPU": 0}) 158 | assert strategy.num_cpus_per_worker == 2 159 | assert not strategy.use_gpu 160 | 161 | strategy = RayStrategy( 162 | num_workers=1, 163 | num_cpus_per_worker=2, 164 | use_gpu=False, 165 | resources_per_worker={"GPU": 1}) 166 | assert strategy.num_cpus_per_worker == 2 167 | assert strategy.use_gpu 168 | 169 | strategy = RayStrategy( 170 | num_workers=1, 171 | num_cpus_per_worker=2, 172 | use_gpu=False, 173 | resources_per_worker={"GPU": 2}) 174 | assert strategy.num_cpus_per_worker == 2 175 | assert strategy.num_gpus_per_worker == 2 176 | assert strategy.use_gpu 177 | 178 | 179 | def test_distributed_sampler(tmpdir, ray_start_2_cpus): 180 | """Tests if distributed sampler is properly set.""" 181 | model = BoringModel() 182 | train_dataloader = model.train_dataloader() 183 | initial_sampler = train_dataloader.sampler 184 | assert not isinstance(initial_sampler, DistributedSampler) 185 | 186 | class DistributedSamplerCallback(Callback): 187 | def on_train_start(self, trainer, pl_module): 188 | train_sampler = trainer.train_dataloader.sampler 189 | assert isinstance(train_sampler, DistributedSampler) 190 | assert train_sampler.shuffle 191 | assert train_sampler.num_replicas == 2 192 | assert train_sampler.rank == trainer.global_rank 193 | 194 | def on_validation_start(self, trainer, pl_module): 195 | train_sampler = trainer.val_dataloaders[0].sampler 196 | assert isinstance(train_sampler, DistributedSampler) 197 | assert not train_sampler.shuffle 198 | assert train_sampler.num_replicas == 2 199 | assert train_sampler.rank == trainer.global_rank 200 | 201 | def on_test_start(self, trainer, pl_module): 202 | train_sampler = trainer.test_dataloaders[0].sampler 203 | assert isinstance(train_sampler, DistributedSampler) 204 | assert not train_sampler.shuffle 205 | assert train_sampler.num_replicas == 2 206 | assert train_sampler.rank == trainer.global_rank 207 | 208 | strategy = RayStrategy(num_workers=2) 209 | trainer = get_trainer( 210 | tmpdir, 
strategy=strategy, callbacks=[DistributedSamplerCallback()]) 211 | trainer.fit(model) 212 | 213 | 214 | @pytest.mark.parametrize("num_workers", [1, 2]) 215 | def test_train(tmpdir, ray_start_2_cpus, num_workers): 216 | """Tests if training modifies model weights.""" 217 | model = BoringModel() 218 | strategy = RayStrategy(num_workers=num_workers) 219 | trainer = get_trainer(tmpdir, strategy=strategy) 220 | train_test(trainer, model) 221 | 222 | 223 | @pytest.mark.parametrize("num_workers", [1, 2]) 224 | def test_train_client(tmpdir, start_ray_client_server_2_cpus, num_workers): 225 | assert ray.util.client.ray.is_connected() 226 | model = BoringModel() 227 | strategy = RayStrategy(num_workers=num_workers) 228 | trainer = get_trainer(tmpdir, strategy=strategy) 229 | train_test(trainer, model) 230 | 231 | 232 | def test_test_with_dataloader_workers(tmpdir, ray_start_2_cpus, seed): 233 | """Tests trainer.test with >0 workers for data loading.""" 234 | model = BoringModel() 235 | strategy = RayStrategy(num_workers=1, use_gpu=False) 236 | trainer = get_trainer( 237 | tmpdir, limit_train_batches=20, max_epochs=1, strategy=strategy) 238 | trainer.test(model) 239 | 240 | 241 | @pytest.mark.parametrize("num_workers", [1, 2]) 242 | def test_load(tmpdir, ray_start_2_cpus, num_workers): 243 | """Tests if model checkpoint can be loaded.""" 244 | model = BoringModel() 245 | strategy = RayStrategy(num_workers=num_workers, use_gpu=False) 246 | trainer = get_trainer(tmpdir, strategy=strategy) 247 | load_test(trainer, model) 248 | 249 | 250 | @pytest.mark.parametrize("num_workers", [1, 2]) 251 | def test_predict(tmpdir, ray_start_2_cpus, seed, num_workers): 252 | """Tests if trained model has high accuracy on test set.""" 253 | config = { 254 | "layer_1": 32, 255 | "layer_2": 32, 256 | "lr": 1e-2, 257 | "batch_size": 32, 258 | } 259 | 260 | model = LightningMNISTClassifier(config, tmpdir) 261 | dm = MNISTDataModule( 262 | data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"]) 263 | strategy = RayStrategy(num_workers=num_workers, use_gpu=False) 264 | trainer = get_trainer( 265 | tmpdir, limit_train_batches=20, max_epochs=1, strategy=strategy) 266 | predict_test(trainer, model, dm) 267 | 268 | 269 | @pytest.mark.parametrize("num_workers", [1, 2]) 270 | def test_predict_client(tmpdir, start_ray_client_server_2_cpus, seed, 271 | num_workers): 272 | assert ray.util.client.ray.is_connected() 273 | config = { 274 | "layer_1": 32, 275 | "layer_2": 32, 276 | "lr": 1e-2, 277 | "batch_size": 32, 278 | } 279 | 280 | model = LightningMNISTClassifier(config, tmpdir) 281 | dm = MNISTDataModule( 282 | data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"]) 283 | strategy = RayStrategy(num_workers=num_workers, use_gpu=False) 284 | trainer = get_trainer( 285 | tmpdir, limit_train_batches=20, max_epochs=1, strategy=strategy) 286 | predict_test(trainer, model, dm) 287 | 288 | 289 | def test_early_stop(tmpdir, ray_start_2_cpus): 290 | """Tests if early stopping callback works correctly.""" 291 | model = BoringModel() 292 | strategy = RayStrategy(num_workers=1, use_gpu=False) 293 | patience = 2 294 | early_stop = EarlyStopping( 295 | monitor="val_loss", patience=patience, verbose=True) 296 | trainer = get_trainer( 297 | tmpdir, 298 | max_epochs=500, 299 | strategy=strategy, 300 | callbacks=[early_stop], 301 | num_sanity_val_steps=0, 302 | limit_train_batches=1.0, 303 | limit_val_batches=1.0, 304 | progress_bar_refresh_rate=1) 305 | trainer.fit(model) 306 | trained_model = BoringModel.load_from_checkpoint( 307 
| trainer.checkpoint_callback.best_model_path) 308 | assert trained_model.val_epoch == patience + 1, trained_model.val_epoch 309 | 310 | 311 | def test_unused_parameters(tmpdir, ray_start_2_cpus): 312 | """Tests if find_unused_parameters is properly passed to model.""" 313 | model = BoringModel() 314 | strategy = RayStrategy( 315 | num_workers=2, use_gpu=False, find_unused_parameters=False) 316 | 317 | class UnusedParameterCallback(Callback): 318 | def on_train_start(self, trainer, pl_module): 319 | assert trainer.model.find_unused_parameters is False 320 | 321 | trainer = get_trainer( 322 | tmpdir, strategy=strategy, callbacks=[UnusedParameterCallback()]) 323 | trainer.fit(model) 324 | 325 | 326 | def test_metrics(tmpdir, ray_start_2_cpus): 327 | """Tests if metrics are returned correctly""" 328 | model = XORModel() 329 | strategy = RayStrategy(num_workers=2, find_unused_parameters=False) 330 | trainer = get_trainer( 331 | tmpdir, 332 | strategy=strategy, 333 | max_epochs=1, 334 | num_sanity_val_steps=0, 335 | limit_train_batches=2, 336 | limit_val_batches=2, 337 | reload_dataloaders_every_n_epochs=1) 338 | dataset = XORDataModule() 339 | trainer.fit(model, dataset) 340 | callback_metrics = trainer.callback_metrics 341 | logged_metrics = trainer.logged_metrics 342 | assert callback_metrics["avg_val_loss"] == logged_metrics["avg_val_loss"] 343 | assert logged_metrics["val_foo"] == torch.tensor(1.234) 344 | assert callback_metrics["val_foo"] == torch.tensor(1.234) 345 | # forked name is used for on_step logged metrics 346 | forked_name_loss = "val_loss" + "_step" 347 | forked_name_bar = "val_bar" + "_step" 348 | assert forked_name_loss in logged_metrics.keys() 349 | assert logged_metrics[forked_name_bar] == torch.tensor(5.678) 350 | # callback_metrics doesn't record on_step metrics 351 | assert forked_name_loss not in callback_metrics.keys() 352 | assert forked_name_bar not in callback_metrics.keys() 353 | -------------------------------------------------------------------------------- /ray_lightning/tests/test_ddp_gpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import torch 4 | 5 | from pl_bolts.datamodules import MNISTDataModule 6 | import pytorch_lightning as pl 7 | from pytorch_lightning import Callback 8 | 9 | import ray 10 | 11 | from ray_lightning import RayStrategy 12 | from ray_lightning.tests.utils import get_trainer, train_test, BoringModel, \ 13 | predict_test, LightningMNISTClassifier 14 | 15 | 16 | @pytest.fixture 17 | def ray_start_2_gpus(): 18 | address_info = ray.init(num_cpus=2, num_gpus=2) 19 | yield address_info 20 | ray.shutdown() 21 | 22 | 23 | @pytest.fixture 24 | def ray_start_4_gpus(): 25 | address_info = ray.init(num_cpus=4, num_gpus=4) 26 | yield address_info 27 | ray.shutdown() 28 | 29 | 30 | @pytest.fixture 31 | def seed(): 32 | pl.seed_everything(0) 33 | 34 | 35 | @pytest.mark.skipif( 36 | torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") 37 | @pytest.mark.parametrize("num_workers", [1, 2]) 38 | def test_train(tmpdir, ray_start_2_gpus, num_workers): 39 | """Tests if training modifies model weights.""" 40 | model = BoringModel() 41 | strategy = RayStrategy(num_workers=num_workers, use_gpu=True) 42 | trainer = get_trainer(tmpdir, strategy=strategy) 43 | train_test(trainer, model) 44 | 45 | 46 | @pytest.mark.skipif( 47 | torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") 48 | @pytest.mark.parametrize("num_workers", [1, 2]) 49 | def 
test_predict(tmpdir, ray_start_2_gpus, seed, num_workers): 50 | """Tests if trained model has high accuracy on test set.""" 51 | config = { 52 | "layer_1": 32, 53 | "layer_2": 32, 54 | "lr": 1e-2, 55 | "batch_size": 32, 56 | } 57 | model = LightningMNISTClassifier(config, tmpdir) 58 | dm = MNISTDataModule( 59 | data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"]) 60 | strategy = RayStrategy(num_workers=num_workers, use_gpu=True) 61 | trainer = get_trainer( 62 | tmpdir, limit_train_batches=20, max_epochs=1, strategy=strategy) 63 | predict_test(trainer, model, dm) 64 | 65 | 66 | @pytest.mark.skipif( 67 | torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") 68 | def test_model_to_gpu(tmpdir, ray_start_2_gpus): 69 | """Tests if model is placed on CUDA device.""" 70 | model = BoringModel() 71 | 72 | class CheckGPUCallback(Callback): 73 | def on_epoch_end(self, trainer, pl_module): 74 | assert next(pl_module.parameters()).is_cuda 75 | 76 | strategy = RayStrategy(num_workers=2, use_gpu=True) 77 | trainer = get_trainer( 78 | tmpdir, strategy=strategy, callbacks=[CheckGPUCallback()]) 79 | trainer.fit(model) 80 | 81 | 82 | @pytest.mark.skipif( 83 | torch.cuda.device_count() < 4, reason="test requires multi-GPU machine") 84 | @pytest.mark.parametrize("num_gpus_per_worker", [0.4, 0.5, 1, 2]) 85 | def test_correct_devices(tmpdir, ray_start_4_gpus, num_gpus_per_worker, 86 | monkeypatch): 87 | """Tests if GPU devices are correctly set.""" 88 | model = BoringModel() 89 | 90 | if num_gpus_per_worker < 1: 91 | monkeypatch.setenv("PL_TORCH_DISTRIBUTED_BACKEND", "gloo") 92 | 93 | def get_gpu_placement(current_worker_index, num_gpus_per_worker): 94 | """Simulates GPU resource bin packing.""" 95 | next_gpu_index = 0 96 | starting_resource_count = num_gpus_per_worker 97 | for _ in range(current_worker_index + 1): 98 | current_gpu_index = next_gpu_index 99 | next_resources = starting_resource_count + \ 100 | num_gpus_per_worker - 0.0001 101 | # If the next worker cannot fit on the current GPU, then we move 102 | # onto the next GPU. 103 | if int(next_resources) != current_gpu_index: 104 | increment = max(1, int(num_gpus_per_worker)) 105 | next_gpu_index = current_gpu_index + increment 106 | 107 | return current_gpu_index 108 | 109 | class CheckDevicesCallback(Callback): 110 | def on_epoch_end(self, trainer, pl_module): 111 | assert trainer.strategy.root_device.index == get_gpu_placement( 112 | trainer.local_rank, num_gpus_per_worker) 113 | assert trainer.strategy.root_device.index == pl_module.device.index 114 | assert torch.cuda.current_device( 115 | ) == trainer.strategy.root_device.index 116 | 117 | strategy = RayStrategy( 118 | num_workers=2, 119 | use_gpu=True, 120 | resources_per_worker={"GPU": num_gpus_per_worker}) 121 | trainer = get_trainer( 122 | tmpdir, strategy=strategy, callbacks=[CheckDevicesCallback()]) 123 | trainer.fit(model) 124 | 125 | 126 | @pytest.mark.skipif( 127 | os.environ.get("CLUSTER", "0") != "1", 128 | reason="Should not be run in CI. 
Requires multi-node Ray " 129 | "cluster.") 130 | def test_multi_node(tmpdir): 131 | """Tests if multi-node GPU training works.""" 132 | ray.init("auto") 133 | num_gpus = ray.available_resources()["GPU"] 134 | model = BoringModel() 135 | strategy = RayStrategy(num_workers=num_gpus, use_gpu=True) 136 | trainer = get_trainer(tmpdir, strategy=strategy) 137 | train_test(trainer, model) 138 | -------------------------------------------------------------------------------- /ray_lightning/tests/test_ddp_sharded.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | import pytorch_lightning as pl 7 | from pytorch_lightning import Callback, Trainer 8 | 9 | import ray 10 | 11 | from ray_lightning import RayShardedStrategy 12 | from ray_lightning.tests.utils import BoringModel 13 | 14 | 15 | @pytest.fixture 16 | def ray_start_2_cpus(): 17 | address_info = ray.init(num_cpus=2) 18 | yield address_info 19 | ray.shutdown() 20 | 21 | 22 | @pytest.fixture 23 | def seed(): 24 | pl.seed_everything(0) 25 | 26 | 27 | def test_ddp_choice_sharded(tmpdir, ray_start_2_cpus, seed): 28 | """Tests if sharded strategy is properly recognized.""" 29 | 30 | class CB(Callback): 31 | def on_fit_start(self, trainer, pl_module): 32 | assert isinstance(trainer.strategy, RayShardedStrategy) 33 | raise ValueError() 34 | 35 | model = BoringModel() 36 | trainer = Trainer( 37 | fast_dev_run=True, 38 | strategy=RayShardedStrategy(num_workers=2), 39 | callbacks=[CB()], 40 | ) 41 | 42 | with pytest.raises(ValueError): 43 | trainer.fit(model) 44 | 45 | 46 | def test_ddp_sharded_plugin_checkpoint(tmpdir, ray_start_2_cpus, seed): 47 | """Tests if checkpoint is saved correctly.""" 48 | model = BoringModel() 49 | trainer = Trainer( 50 | strategy=RayShardedStrategy(num_workers=2), 51 | fast_dev_run=True, 52 | ) 53 | 54 | trainer.fit(model) 55 | 56 | checkpoint_path = os.path.join(tmpdir, "model.pt") 57 | trainer.save_checkpoint(checkpoint_path) 58 | saved_model = BoringModel.load_from_checkpoint(checkpoint_path) 59 | 60 | # Assert model parameters are identical after loading. 
61 | for ddp_param, shard_param in zip(model.parameters(), 62 | saved_model.parameters()): 63 | assert torch.equal(ddp_param, shard_param) 64 | 65 | 66 | def test_ddp_sharded_plugin_finetune(tmpdir, ray_start_2_cpus, seed): 67 | """Tests if we can save and restart training.""" 68 | model = BoringModel() 69 | trainer = Trainer( 70 | strategy=RayShardedStrategy(num_workers=2), 71 | fast_dev_run=True, 72 | ) 73 | trainer.fit(model) 74 | 75 | checkpoint_path = os.path.join(tmpdir, "model.pt") 76 | trainer.save_checkpoint(checkpoint_path) 77 | saved_model = BoringModel.load_from_checkpoint(checkpoint_path) 78 | 79 | trainer = Trainer(fast_dev_run=True, ) 80 | trainer.fit(saved_model) 81 | 82 | 83 | def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir, ray_start_2_cpus, 84 | seed): 85 | """Tests if resuming from checkpoint works.""" 86 | model = BoringModel() 87 | trainer = Trainer( 88 | strategy=RayShardedStrategy(num_workers=2), 89 | fast_dev_run=True, 90 | ) 91 | 92 | trainer.fit(model) 93 | 94 | checkpoint_path = os.path.join(tmpdir, "model.pt") 95 | trainer.save_checkpoint(checkpoint_path) 96 | 97 | model = BoringModel() 98 | 99 | trainer = Trainer( 100 | strategy=RayShardedStrategy(num_workers=2), 101 | fast_dev_run=True, 102 | resume_from_checkpoint=checkpoint_path) 103 | 104 | trainer.fit(model) 105 | 106 | 107 | def test_ddp_sharded_plugin_test(tmpdir, ray_start_2_cpus, seed): 108 | """Tests if test works without fit.""" 109 | model = BoringModel() 110 | trainer = Trainer( 111 | strategy=RayShardedStrategy(num_workers=2), 112 | fast_dev_run=True, 113 | ) 114 | 115 | trainer.test(model) 116 | 117 | 118 | def test_ddp_sharded_plugin_resume_from_checkpoint_downsize( 119 | tmpdir, ray_start_2_cpus, seed): 120 | """Tests if we can save and resume training with less workers.""" 121 | model = BoringModel() 122 | trainer = Trainer( 123 | strategy=RayShardedStrategy(num_workers=2), fast_dev_run=True) 124 | 125 | trainer.fit(model) 126 | 127 | checkpoint_path = os.path.join(tmpdir, "model.pt") 128 | trainer.save_checkpoint(checkpoint_path) 129 | 130 | model = BoringModel() 131 | 132 | trainer = Trainer( 133 | strategy=RayShardedStrategy(num_workers=1), 134 | fast_dev_run=True, 135 | resume_from_checkpoint=checkpoint_path) 136 | 137 | trainer.fit(model) 138 | -------------------------------------------------------------------------------- /ray_lightning/tests/test_horovod.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule 5 | import pytorch_lightning as pl 6 | from ray.util.client.ray_client_helpers import ray_start_client_server 7 | 8 | try: 9 | import horovod # noqa: F401 10 | except ImportError: 11 | HOROVOD_AVAILABLE = False 12 | else: 13 | HOROVOD_AVAILABLE = True 14 | 15 | import ray 16 | 17 | from ray_lightning import HorovodRayStrategy 18 | from ray_lightning.tests.utils import get_trainer, BoringModel, \ 19 | train_test, load_test, predict_test, LightningMNISTClassifier 20 | 21 | 22 | @pytest.fixture 23 | def ray_start_2_cpus(): 24 | address_info = ray.init(num_cpus=2) 25 | yield address_info 26 | ray.shutdown() 27 | 28 | 29 | @pytest.fixture 30 | def ray_start_2_gpus(): 31 | address_info = ray.init(num_cpus=2, num_gpus=2) 32 | yield address_info 33 | ray.shutdown() 34 | 35 | 36 | @pytest.fixture 37 | def start_ray_client_server_2_cpus(): 38 | ray.init(num_cpus=2) 39 | with ray_start_client_server() as client: 40 | yield client 41 | 42 | 43 | 
@pytest.fixture 44 | def seed(): 45 | pl.seed_everything(0) 46 | 47 | 48 | @pytest.mark.parametrize("num_workers", [1, 2]) 49 | def test_train(tmpdir, ray_start_2_cpus, seed, num_workers): 50 | """Tests if training modifies model weights.""" 51 | model = BoringModel() 52 | strategy = HorovodRayStrategy(num_workers=num_workers, use_gpu=False) 53 | trainer = get_trainer(tmpdir, strategy=strategy) 54 | train_test(trainer, model) 55 | 56 | 57 | @pytest.mark.parametrize("num_workers", [1, 2]) 58 | def test_train_client(tmpdir, start_ray_client_server_2_cpus, seed, 59 | num_workers): 60 | """Tests if training modifies model weights.""" 61 | assert ray.util.client.ray.is_connected() 62 | model = BoringModel() 63 | strategy = HorovodRayStrategy(num_workers=num_workers, use_gpu=False) 64 | trainer = get_trainer(tmpdir, strategy=strategy) 65 | train_test(trainer, model) 66 | 67 | 68 | @pytest.mark.parametrize("num_workers", [1, 2]) 69 | def test_load(tmpdir, ray_start_2_cpus, seed, num_workers): 70 | """Tests if model checkpoint can be loaded.""" 71 | model = BoringModel() 72 | strategy = HorovodRayStrategy(num_workers=num_workers, use_gpu=False) 73 | trainer = get_trainer(tmpdir, strategy=strategy) 74 | load_test(trainer, model) 75 | 76 | 77 | @pytest.mark.parametrize("num_workers", [1, 2]) 78 | def test_predict(tmpdir, ray_start_2_cpus, seed, num_workers): 79 | """Tests if trained model has high accuracy on test set.""" 80 | config = { 81 | "layer_1": 32, 82 | "layer_2": 32, 83 | "lr": 1e-2, 84 | "batch_size": 32, 85 | } 86 | model = LightningMNISTClassifier(config, tmpdir) 87 | dm = MNISTDataModule( 88 | data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"]) 89 | strategy = HorovodRayStrategy(num_workers=num_workers, use_gpu=False) 90 | trainer = get_trainer( 91 | tmpdir, limit_train_batches=20, max_epochs=1, strategy=strategy) 92 | predict_test(trainer, model, dm) 93 | 94 | 95 | @pytest.mark.parametrize("num_workers", [1, 2]) 96 | def test_predict_client(tmpdir, start_ray_client_server_2_cpus, seed, 97 | num_workers): 98 | assert ray.util.client.ray.is_connected() 99 | config = { 100 | "layer_1": 32, 101 | "layer_2": 32, 102 | "lr": 1e-2, 103 | "batch_size": 32, 104 | } 105 | model = LightningMNISTClassifier(config, tmpdir) 106 | dm = MNISTDataModule( 107 | data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"]) 108 | strategy = HorovodRayStrategy(num_workers=num_workers, use_gpu=False) 109 | trainer = get_trainer( 110 | tmpdir, limit_train_batches=20, max_epochs=1, strategy=strategy) 111 | predict_test(trainer, model, dm) 112 | 113 | 114 | @pytest.mark.skipif( 115 | torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") 116 | @pytest.mark.parametrize("num_workers", [1, 2]) 117 | def test_train_gpu(tmpdir, ray_start_2_gpus, seed, num_workers): 118 | """Tests if training modifies model weights.""" 119 | model = BoringModel() 120 | strategy = HorovodRayStrategy(num_workers=num_workers, use_gpu=True) 121 | trainer = get_trainer(tmpdir, strategy=strategy) 122 | train_test(trainer, model) 123 | 124 | 125 | @pytest.mark.skipif( 126 | torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") 127 | @pytest.mark.parametrize("num_workers", [1, 2]) 128 | def test_load_gpu(tmpdir, ray_start_2_gpus, seed, num_workers): 129 | """Tests if model checkpoint can be loaded.""" 130 | model = BoringModel() 131 | strategy = HorovodRayStrategy(num_workers=num_workers, use_gpu=True) 132 | trainer = get_trainer(tmpdir, strategy=strategy) 133 | load_test(trainer, model) 
134 | 135 | 136 | @pytest.mark.skipif( 137 | torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") 138 | @pytest.mark.parametrize("num_workers", [1, 2]) 139 | def test_predict_gpu(tmpdir, ray_start_2_gpus, seed, num_workers): 140 | """Tests if trained model has high accuracy on test set.""" 141 | config = { 142 | "layer_1": 32, 143 | "layer_2": 32, 144 | "lr": 1e-2, 145 | "batch_size": 32, 146 | } 147 | model = LightningMNISTClassifier(config, tmpdir) 148 | dm = MNISTDataModule( 149 | data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"]) 150 | strategy = HorovodRayStrategy(num_workers=num_workers, use_gpu=True) 151 | trainer = get_trainer( 152 | tmpdir, limit_train_batches=20, max_epochs=1, strategy=strategy) 153 | predict_test(trainer, model, dm) 154 | -------------------------------------------------------------------------------- /ray_lightning/tests/test_lightning_cli.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from importlib.util import find_spec 3 | from pytorch_lightning.utilities.cli import LightningCLI 4 | from ray_lightning import RayStrategy 5 | from ray_lightning.tests.utils import BoringModel 6 | from unittest import mock 7 | 8 | 9 | @pytest.mark.skipif( 10 | not find_spec("jsonargparse"), reason="jsonargparse required") 11 | def test_lightning_cli_raystrategy_instantiation(): 12 | init_args = { 13 | "num_workers": 4, # Resolve from RayStrategy.__init__ 14 | "use_gpu": False, # Resolve from RayStrategy.__init__ 15 | "bucket_cap_mb": 50, # Resolve from DistributedDataParallel.__init__ 16 | } 17 | cli_args = ["--trainer.strategy=RayStrategy"] 18 | cli_args += [f"--trainer.strategy.{k}={v}" for k, v in init_args.items()] 19 | 20 | with mock.patch("sys.argv", ["any.py"] + cli_args): 21 | cli = LightningCLI(BoringModel, run=False) 22 | 23 | assert isinstance(cli.config_init["trainer"]["strategy"], RayStrategy) 24 | assert { 25 | k: cli.config["trainer"]["strategy"]["init_args"][k] 26 | for k in init_args 27 | } == init_args 28 | -------------------------------------------------------------------------------- /ray_lightning/tests/test_tune.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import ray 4 | import torch 5 | from ray import tune 6 | 7 | from ray_lightning import RayStrategy, HorovodRayStrategy 8 | from ray_lightning.tests.utils import BoringModel, get_trainer 9 | from ray_lightning.tune import TuneReportCallback, \ 10 | TuneReportCheckpointCallback, get_tune_resources 11 | 12 | 13 | @pytest.fixture 14 | def ray_start_4_cpus(): 15 | address_info = ray.init(num_cpus=4) 16 | yield address_info 17 | ray.shutdown() 18 | 19 | 20 | @pytest.fixture 21 | def ray_start_4_cpus_4_gpus(): 22 | address_info = ray.init(num_cpus=4, num_gpus=4) 23 | yield address_info 24 | ray.shutdown() 25 | 26 | 27 | def train_func(dir, strategy, callbacks=None): 28 | def _inner_train(config): 29 | model = BoringModel() 30 | trainer = get_trainer( 31 | dir, 32 | callbacks=callbacks, 33 | strategy=strategy, 34 | checkpoint_callback=False, 35 | **config) 36 | trainer.fit(model) 37 | 38 | return _inner_train 39 | 40 | 41 | def tune_test(dir, strategy): 42 | callbacks = [TuneReportCallback(on="validation_end")] 43 | analysis = tune.run( 44 | train_func(dir, strategy, callbacks=callbacks), 45 | config={"max_epochs": tune.choice([1, 2, 3])}, 46 | resources_per_trial=get_tune_resources( 47 | num_workers=strategy.num_workers, use_gpu=strategy.use_gpu), 48 | 
num_samples=2) 49 | # fix TUNE_RESULT_DELIM issue 50 | config_max_epochs = analysis.results_df.get("config.max_epochs", False) 51 | if config_max_epochs is False: 52 | config_max_epochs = analysis.results_df.get("config/max_epochs", False) 53 | assert all(analysis.results_df["training_iteration"] == config_max_epochs) 54 | 55 | 56 | def test_tune_iteration_ddp(tmpdir, ray_start_4_cpus): 57 | """Tests if each RayStrategy runs the correct number of iterations.""" 58 | strategy = RayStrategy(num_workers=2, use_gpu=False) 59 | tune_test(tmpdir, strategy) 60 | 61 | 62 | def test_tune_iteration_horovod(tmpdir, ray_start_4_cpus): 63 | """Tests if each HorovodRay trial runs the correct number of iterations.""" 64 | strategy = HorovodRayStrategy(num_workers=2, use_gpu=False) 65 | tune_test(tmpdir, strategy) 66 | 67 | 68 | def checkpoint_test(dir, strategy): 69 | callbacks = [TuneReportCheckpointCallback(on="validation_end")] 70 | analysis = tune.run( 71 | train_func(dir, strategy, callbacks=callbacks), 72 | config={"max_epochs": 2}, 73 | resources_per_trial=get_tune_resources( 74 | num_workers=strategy.num_workers, use_gpu=strategy.use_gpu), 75 | num_samples=1, 76 | local_dir=dir, 77 | log_to_file=True, 78 | metric="val_loss", 79 | mode="min") 80 | assert analysis.best_checkpoint is not None 81 | 82 | 83 | def test_checkpoint_ddp(tmpdir, ray_start_4_cpus): 84 | """Tests if Tune checkpointing works with RayAccelerator.""" 85 | strategy = RayStrategy(num_workers=2, use_gpu=False) 86 | checkpoint_test(tmpdir, strategy) 87 | 88 | 89 | def test_checkpoint_horovod(tmpdir, ray_start_4_cpus): 90 | """Tests if Tune checkpointing works with HorovodRayAccelerator.""" 91 | strategy = HorovodRayStrategy(num_workers=2, use_gpu=False) 92 | checkpoint_test(tmpdir, strategy) 93 | 94 | 95 | @pytest.mark.skipif( 96 | torch.cuda.device_count() < 4, reason="test requires multi-GPU machine") 97 | def test_checkpoint_ddp_gpu(tmpdir, ray_start_4_cpus_4_gpus): 98 | """Tests if Tune checkpointing works with RayAccelerator.""" 99 | strategy = RayStrategy(num_workers=2, use_gpu=True) 100 | checkpoint_test(tmpdir, strategy) 101 | 102 | 103 | @pytest.mark.skipif( 104 | torch.cuda.device_count() < 4, reason="test requires multi-GPU machine") 105 | def test_checkpoint_horovod_gpu(tmpdir, ray_start_4_cpus_4_gpus): 106 | """Tests if Tune checkpointing works with HorovodRayAccelerator.""" 107 | strategy = HorovodRayStrategy(num_workers=2, use_gpu=True) 108 | checkpoint_test(tmpdir, strategy) 109 | 110 | 111 | @pytest.mark.skipif( 112 | torch.cuda.device_count() < 4, reason="test requires multi-GPU machine") 113 | def test_tune_iteration_ddp_gpu(tmpdir, ray_start_4_cpus_4_gpus): 114 | """Tests if each RayStrategy runs the correct number of iterations.""" 115 | strategy = RayStrategy(num_workers=2, use_gpu=True) 116 | tune_test(tmpdir, strategy) 117 | -------------------------------------------------------------------------------- /ray_lightning/tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | 8 | import pytorch_lightning as pl 9 | from pytorch_lightning.strategies import Strategy 10 | from pytorch_lightning import LightningModule, Callback, Trainer, \ 11 | LightningDataModule 12 | 13 | import torchmetrics 14 | 15 | 16 | class RandomDataset(Dataset): 17 | def __init__(self, size: int, length: int): 18 | self.len = length 19 | self.data 
= torch.randn(length, size) 20 | 21 | def __getitem__(self, index: int): 22 | return self.data[index] 23 | 24 | def __len__(self): 25 | return self.len 26 | 27 | 28 | class BoringModel(LightningModule): 29 | def __init__(self): 30 | super().__init__() 31 | self.layer = torch.nn.Linear(32, 2) 32 | self.val_epoch = 0 33 | 34 | def forward(self, x): 35 | return self.layer(x) 36 | 37 | def loss(self, batch, prediction): 38 | # Arbitrary loss to have a loss that updates the model weights 39 | # during `Trainer.fit` calls 40 | return torch.nn.functional.mse_loss(prediction, 41 | torch.ones_like(prediction)) 42 | 43 | def step(self, x): 44 | x = self(x) 45 | out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) 46 | return out 47 | 48 | def training_step(self, batch, batch_idx): 49 | output = self.layer(batch) 50 | loss = self.loss(batch, output) 51 | return {"loss": loss} 52 | 53 | def training_step_end(self, training_step_outputs): 54 | return training_step_outputs 55 | 56 | def training_epoch_end(self, outputs) -> None: 57 | torch.stack([x["loss"] for x in outputs]).mean() 58 | 59 | def validation_step(self, batch, batch_idx): 60 | self.layer(batch) 61 | loss = torch.tensor(1.0) 62 | self.log("val_loss", loss) 63 | return {"x": loss} 64 | 65 | def validation_epoch_end(self, outputs) -> None: 66 | torch.stack([x["x"] for x in outputs]).mean() 67 | self.val_epoch += 1 68 | 69 | def test_step(self, batch, batch_idx): 70 | output = self.layer(batch) 71 | loss = self.loss(batch, output) 72 | return {"y": loss} 73 | 74 | def test_epoch_end(self, outputs) -> None: 75 | torch.stack([x["y"] for x in outputs]).mean() 76 | 77 | def configure_optimizers(self): 78 | optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) 79 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) 80 | return [optimizer], [lr_scheduler] 81 | 82 | def train_dataloader(self): 83 | return torch.utils.data.DataLoader(RandomDataset(32, 64)) 84 | 85 | def val_dataloader(self): 86 | return torch.utils.data.DataLoader(RandomDataset(32, 64)) 87 | 88 | def test_dataloader(self): 89 | return torch.utils.data.DataLoader( 90 | RandomDataset(32, 64), num_workers=1) 91 | 92 | def on_save_checkpoint(self, checkpoint): 93 | checkpoint["val_epoch"] = self.val_epoch 94 | 95 | def on_load_checkpoint(self, checkpoint) -> None: 96 | self.val_epoch = checkpoint["val_epoch"] 97 | 98 | 99 | class LightningMNISTClassifier(pl.LightningModule): 100 | def __init__(self, config, data_dir=None): 101 | super(LightningMNISTClassifier, self).__init__() 102 | 103 | self.data_dir = data_dir or os.getcwd() 104 | self.lr = config["lr"] 105 | layer_1, layer_2 = config["layer_1"], config["layer_2"] 106 | self.batch_size = config["batch_size"] 107 | 108 | # mnist images are (1, 28, 28) (channels, width, height) 109 | self.layer_1 = torch.nn.Linear(28 * 28, layer_1) 110 | self.layer_2 = torch.nn.Linear(layer_1, layer_2) 111 | self.layer_3 = torch.nn.Linear(layer_2, 10) 112 | self.accuracy = torchmetrics.Accuracy() 113 | 114 | def forward(self, x): 115 | batch_size, channels, width, height = x.size() 116 | x = x.view(batch_size, -1) 117 | x = self.layer_1(x) 118 | x = torch.relu(x) 119 | x = self.layer_2(x) 120 | x = torch.relu(x) 121 | x = self.layer_3(x) 122 | x = F.log_softmax(x, dim=1) 123 | return x 124 | 125 | def configure_optimizers(self): 126 | return torch.optim.Adam(self.parameters(), lr=self.lr) 127 | 128 | def training_step(self, train_batch, batch_idx): 129 | x, y = train_batch 130 | logits = self.forward(x) 131 | loss = 
F.nll_loss(logits, y.long()) 132 | acc = self.accuracy(logits, y) 133 | self.log("ptl/train_loss", loss) 134 | self.log("ptl/train_accuracy", acc) 135 | return loss 136 | 137 | def validation_step(self, val_batch, batch_idx): 138 | x, y = val_batch 139 | logits = self.forward(x) 140 | loss = F.nll_loss(logits, y.long()) 141 | acc = self.accuracy(logits, y) 142 | return {"val_loss": loss, "val_accuracy": acc} 143 | 144 | def validation_epoch_end(self, outputs): 145 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 146 | avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean() 147 | self.log("ptl/val_loss", avg_loss) 148 | self.log("ptl/val_accuracy", avg_acc) 149 | 150 | 151 | class XORModel(LightningModule): 152 | def __init__(self, input_dim=2, output_dim=1): 153 | super(XORModel, self).__init__() 154 | self.save_hyperparameters() 155 | self.lin1 = torch.nn.Linear(input_dim, 8) 156 | self.lin2 = torch.nn.Linear(8, output_dim) 157 | 158 | def forward(self, features): 159 | x = features.float() 160 | x = self.lin1(x) 161 | x = torch.tanh(x) 162 | x = self.lin2(x) 163 | x = torch.sigmoid(x) 164 | return x 165 | 166 | def configure_optimizers(self): 167 | return torch.optim.Adam(self.parameters(), lr=0.02) 168 | 169 | def training_step(self, batch, batch_nb): 170 | x, y = batch["x"], batch["y"].unsqueeze(1) 171 | y_hat = self(x) 172 | loss = F.binary_cross_entropy(y_hat, y.float()) 173 | return loss 174 | 175 | def validation_step(self, batch, batch_nb): 176 | x, y = batch["x"], batch["y"].unsqueeze(1) 177 | y_hat = self(x) 178 | loss = F.binary_cross_entropy(y_hat, y.float()) 179 | self.log("val_loss", loss, on_step=True) 180 | # Log a constant for test purpose 181 | self.log("val_bar", torch.tensor(5.678), on_step=True) 182 | return loss 183 | 184 | def validation_epoch_end(self, outputs): 185 | avg_loss = torch.stack(outputs).mean() 186 | self.log("avg_val_loss", avg_loss) 187 | # Log a constant for test purpose 188 | self.log("val_foo", torch.tensor(1.234)) 189 | 190 | 191 | class XORDataModule(LightningDataModule): 192 | def train_dataloader(self): 193 | input_train = [{ 194 | "x": torch.tensor([[0.0, 0.0]]), 195 | "y": torch.tensor([0]) 196 | }, { 197 | "x": torch.tensor([[1.0, 1.0]]), 198 | "y": torch.tensor([0]) 199 | }] 200 | return iter(input_train) 201 | 202 | def val_dataloader(self): 203 | input_val = [{ 204 | "x": torch.tensor([[0.0, 1.0]]), 205 | "y": torch.tensor([1]) 206 | }, { 207 | "x": torch.tensor([[1.0, 0.0]]), 208 | "y": torch.tensor([1]) 209 | }] 210 | return iter(input_val) 211 | 212 | 213 | def get_trainer(dir, 214 | strategy: Strategy, 215 | max_epochs: int = 1, 216 | limit_train_batches: int = 10, 217 | limit_val_batches: int = 10, 218 | callbacks: Optional[List[Callback]] = None, 219 | checkpoint_callback: bool = True, 220 | **trainer_kwargs) -> Trainer: 221 | """Returns a Pytorch Lightning Trainer with the provided arguments.""" 222 | callbacks = [] if not callbacks else callbacks 223 | trainer = pl.Trainer( 224 | default_root_dir=dir, 225 | callbacks=callbacks, 226 | strategy=strategy, 227 | max_epochs=max_epochs, 228 | limit_train_batches=limit_train_batches, 229 | limit_val_batches=limit_val_batches, 230 | enable_progress_bar=False, 231 | checkpoint_callback=checkpoint_callback, 232 | **trainer_kwargs) 233 | return trainer 234 | 235 | 236 | def train_test(trainer: Trainer, model: LightningModule): 237 | """Checks if training the provided model updates its weights.""" 238 | initial_values = torch.tensor( 239 | [torch.sum(torch.abs(x)) 
for x in model.parameters()]) 240 | trainer.fit(model) 241 | post_train_values = torch.tensor( 242 | [torch.sum(torch.abs(x)) for x in model.parameters()]) 243 | assert trainer.state.finished, f"Trainer failed with {trainer.state}" 244 | # Check that the model is actually changed post-training. 245 | assert torch.norm(initial_values - post_train_values) > 0.1 246 | 247 | 248 | def load_test(trainer: Trainer, model: LightningModule): 249 | """Checks if the model checkpoint can be loaded.""" 250 | trainer.fit(model) 251 | trained_model = BoringModel.load_from_checkpoint( 252 | trainer.checkpoint_callback.best_model_path) 253 | assert trained_model is not None, "loading model failed" 254 | 255 | 256 | def predict_test(trainer: Trainer, model: LightningModule, 257 | dm: LightningDataModule): 258 | """Checks if the trained model has high accuracy on the test set.""" 259 | trainer.fit(model, datamodule=dm) 260 | model = trainer.lightning_module 261 | dm.setup(stage="test") 262 | test_loader = dm.test_dataloader() 263 | acc = torchmetrics.Accuracy() 264 | for batch in test_loader: 265 | x, y = batch 266 | with torch.no_grad(): 267 | y_hat = model(x) 268 | y_hat = y_hat.cpu() 269 | acc.update(y_hat, y) 270 | average_acc = acc.compute() 271 | assert average_acc >= 0.5, f"This model is expected to get > {0.5} in " \ 272 | f"test set (it got {average_acc})" 273 | -------------------------------------------------------------------------------- /ray_lightning/tune.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union, Optional 2 | import warnings 3 | 4 | import fsspec 5 | import os 6 | 7 | from pytorch_lightning import Trainer, LightningModule 8 | from ray.util import PublicAPI 9 | 10 | from ray_lightning.session import put_queue, get_actor_rank 11 | from ray_lightning.util import to_state_stream, Unavailable 12 | 13 | try: 14 | from ray import tune 15 | from ray.tune.integration.pytorch_lightning import TuneCallback 16 | from ray.tune import is_session_enabled 17 | TUNE_INSTALLED = True 18 | except ImportError: 19 | tune = None 20 | TuneCallback = Unavailable 21 | 22 | def is_session_enabled(): 23 | return False 24 | 25 | get_tune_resources = Unavailable 26 | 27 | TUNE_INSTALLED = False 28 | 29 | if TUNE_INSTALLED: 30 | 31 | @PublicAPI(stability="beta") 32 | def get_tune_resources( 33 | num_workers: int = 1, 34 | num_cpus_per_worker: int = 1, 35 | use_gpu: bool = False, 36 | # Deprecated args. 37 | cpus_per_worker: Optional[int] = None, 38 | ) -> Dict[str, int]: 39 | """Returns the PlacementGroupFactory to use for Ray Tune.""" 40 | from ray.tune import PlacementGroupFactory 41 | 42 | if cpus_per_worker is not None: 43 | # TODO(amogkam): Remove `cpus_per_worker` on next major release. 44 | num_cpus_per_worker = cpus_per_worker 45 | warnings.warn( 46 | "`cpus_per_worker` will be deprecated in the " 47 | "future. 
Use " 48 | "`num_cpus_per_worker` instead.", PendingDeprecationWarning) 49 | 50 | head_bundle = {"CPU": 1} 51 | child_bundle = {"CPU": num_cpus_per_worker, "GPU": int(use_gpu)} 52 | child_bundles = [child_bundle.copy() for _ in range(num_workers)] 53 | bundles = [head_bundle] + child_bundles 54 | placement_group_factory = PlacementGroupFactory( 55 | bundles, strategy="PACK") 56 | return placement_group_factory 57 | 58 | @PublicAPI(stability="beta") 59 | class TuneReportCallback(TuneCallback): 60 | """Distributed PyTorch Lightning to Ray Tune reporting callback 61 | 62 | Reports metrics to Ray Tune, specifically when training is done 63 | remotely with Ray via the accelerators in this library. 64 | 65 | Args: 66 | metrics (str|list|dict): Metrics to report to Tune. 67 | If this is a list, each item describes the metric key 68 | reported to PyTorch Lightning, and it will reported 69 | under the same name to Tune. If this is a dict, each key 70 | will be the name reported to Tune and the respective 71 | value will be the metric key reported to PyTorch Lightning. 72 | on (str|list): When to trigger checkpoint creations. 73 | Must be one of the PyTorch Lightning event hooks (less 74 | the ``on_``), e.g. "batch_start", or "train_end". 75 | Defaults to "validation_end". 76 | 77 | Example: 78 | 79 | .. code-block:: python 80 | 81 | import pytorch_lightning as pl 82 | from ray_lightning import RayStrategy 83 | from ray_lightning.tune import TuneReportCallback 84 | 85 | # Create strategy. 86 | ray_plugin = RayStrategy(num_workers=4, use_gpu=True) 87 | 88 | # Report loss and accuracy to Tune after each validation epoch: 89 | trainer = pl.Trainer(strategy=[ray_plugin], callbacks=[ 90 | TuneReportCallback(["val_loss", "val_acc"], 91 | on="validation_end")]) 92 | 93 | # Same as above, but report as `loss` and `mean_accuracy`: 94 | trainer = pl.Trainer(strategy=[ray_plugin], callbacks=[ 95 | TuneReportCallback( 96 | {"loss": "val_loss", "mean_accuracy": "val_acc"}, 97 | on="validation_end")]) 98 | 99 | """ 100 | 101 | def __init__( 102 | self, 103 | metrics: Union[None, str, List[str], Dict[str, str]] = None, 104 | on: Union[str, List[str]] = "validation_end"): 105 | super(TuneReportCallback, self).__init__(on) 106 | if isinstance(metrics, str): 107 | metrics = [metrics] 108 | self._metrics = metrics 109 | 110 | def _get_report_dict(self, trainer: Trainer, 111 | pl_module: LightningModule): 112 | # Don't report if just doing initial validation sanity checks. 113 | if trainer.sanity_checking: 114 | return 115 | if not self._metrics: 116 | report_dict = { 117 | k: v.item() 118 | for k, v in trainer.callback_metrics.items() 119 | } 120 | else: 121 | report_dict = {} 122 | for key in self._metrics: 123 | if isinstance(self._metrics, dict): 124 | metric = self._metrics[key] 125 | else: 126 | metric = key 127 | report_dict[key] = trainer.callback_metrics[metric].item() 128 | return report_dict 129 | 130 | def _handle(self, trainer: Trainer, pl_module: LightningModule): 131 | if get_actor_rank() == 0: 132 | report_dict = self._get_report_dict(trainer, pl_module) 133 | if report_dict is not None: 134 | put_queue(lambda: tune.report(**report_dict)) 135 | 136 | class _TuneCheckpointCallback(TuneCallback): 137 | """Distributed PyTorch Lightning to Ray Tune checkpoint callback 138 | 139 | Saves checkpoints after each validation step. To be used 140 | specifically with the strategies in this library. 141 | 142 | Checkpoint are currently not registered if no ``tune.report()`` 143 | call is made afterwards. 
Consider using 144 | ``TuneReportCheckpointCallback`` instead. 145 | 146 | Args: 147 | filename (str): Filename of the checkpoint within the 148 | checkpoint directory. Defaults to "checkpoint". 149 | on (str|list): When to trigger checkpoint creations. 150 | Must be one of the PyTorch Lightning event hooks (less 151 | the ``on_``), e.g. "batch_start", or "train_end". 152 | Defaults to "validation_end". 153 | """ 154 | 155 | def __init__(self, 156 | filename: str = "checkpoint", 157 | on: Union[str, List[str]] = "validation_end"): 158 | super(_TuneCheckpointCallback, self).__init__(on) 159 | self._filename = filename 160 | 161 | @staticmethod 162 | def _create_checkpoint(checkpoint_stream, global_step: int, 163 | filename: str): 164 | with tune.checkpoint_dir(step=global_step) as checkpoint_dir: 165 | file_path = os.path.join(checkpoint_dir, filename) 166 | with fsspec.open(file_path, "wb") as f: 167 | f.write(checkpoint_stream) 168 | 169 | def _handle(self, trainer: Trainer, pl_module: LightningModule): 170 | if trainer.sanity_checking: 171 | return 172 | checkpoint_dict = trainer._checkpoint_connector.dump_checkpoint() 173 | # Convert to a state stream first. 174 | checkpoint_stream = to_state_stream(checkpoint_dict) 175 | global_step = trainer.global_step 176 | if get_actor_rank() == 0: 177 | put_queue(lambda: self._create_checkpoint( 178 | checkpoint_stream, global_step, self._filename)) 179 | 180 | @PublicAPI(stability="beta") 181 | class TuneReportCheckpointCallback(TuneCallback): 182 | """PyTorch Lightning to Tune reporting and checkpointing callback. 183 | 184 | Saves checkpoints after each validation step. Also reports metrics 185 | to Tune, which is needed for checkpoint registration. To be used 186 | specifically with the strategies in this library. 187 | 188 | Args: 189 | metrics (str|list|dict): Metrics to report to Tune. 190 | If this is a list, each item describes the metric key 191 | reported to PyTorch Lightning, and it will be reported 192 | under the same name to Tune. If this is a dict, each key 193 | will be the name reported to Tune and the respective 194 | value will be the metric key reported to PyTorch Lightning. 195 | filename (str): Filename of the checkpoint within the 196 | checkpoint directory. Defaults to "checkpoint". 197 | on (str|list): When to trigger checkpoint creations. Must be 198 | one of the PyTorch Lightning event hooks (less the 199 | ``on_``), e.g. "batch_start", or "train_end". Defaults 200 | to "validation_end". 201 | 202 | 203 | Example: 204 | 205 | .. code-block:: python 206 | 207 | import pytorch_lightning as pl 208 | from ray_lightning import RayStrategy 209 | from ray_lightning.tune import TuneReportCheckpointCallback 210 | 211 | # Create the Ray strategy. 212 | ray_plugin = RayStrategy() 213 | 214 | # Save a checkpoint and report metrics after each 215 | # validation epoch.
216 | trainer = pl.Trainer(strategy=ray_plugin, callbacks=[ 217 | TuneReportCheckpointCallback( 218 | metrics={"loss": "val_loss", 219 | "mean_accuracy": "val_acc"}, 220 | filename="trainer.ckpt", on="validation_end")]) 221 | 222 | 223 | """ 224 | 225 | def __init__( 226 | self, 227 | metrics: Union[None, str, List[str], Dict[str, str]] = None, 228 | filename: str = "checkpoint", 229 | on: Union[str, List[str]] = "validation_end"): 230 | super(TuneReportCheckpointCallback, self).__init__(on) 231 | self._checkpoint = _TuneCheckpointCallback(filename, on) 232 | self._report = TuneReportCallback(metrics, on) 233 | 234 | def _handle(self, trainer: Trainer, pl_module: LightningModule): 235 | self._checkpoint._handle(trainer, pl_module) 236 | self._report._handle(trainer, pl_module) 237 | 238 | else: 239 | # If Tune is not installed. 240 | TuneReportCallback = Unavailable 241 | TuneReportCheckpointCallback = Unavailable 242 | -------------------------------------------------------------------------------- /ray_lightning/util.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import Callable 3 | 4 | import torch 5 | from pytorch_lightning.accelerators import GPUAccelerator 6 | from pytorch_lightning import Trainer 7 | from pytorch_lightning.strategies import Strategy 8 | from pytorch_lightning.utilities.rank_zero import rank_zero_info 9 | 10 | import ray 11 | 12 | 13 | class DelayedGPUAccelerator(GPUAccelerator): 14 | """Same as GPUAccelerator, but doesn't do any CUDA setup. 15 | 16 | This allows the driver script to be launched from CPU-only machines ( 17 | like a laptop) but have training still execute on GPU. 18 | """ 19 | 20 | def setup_environment(self) -> None: 21 | # Don't do any CUDA setup. 22 | # Directly call the setup_environment method of the superclass of 23 | # GPUAccelerator. 24 | super(GPUAccelerator, self).setup_environment() 25 | 26 | def setup( 27 | self, 28 | trainer: Trainer, 29 | ) -> None: 30 | # Don't do any CUDA setup. 31 | # Directly call the setup method of the superclass of 32 | # GPUAccelerator. 33 | return super(GPUAccelerator, self).setup(trainer) 34 | 35 | def on_train_start(self) -> None: 36 | if "cuda" not in str(self.root_device): 37 | raise RuntimeError("GPUs were requested but are not available.") 38 | torch.cuda.set_device(self.root_device) 39 | super(DelayedGPUAccelerator, self).on_train_start() 40 | 41 | 42 | class Unavailable: 43 | """No object should be an instance of this class""" 44 | 45 | def __init__(self, *args, **kwargs): 46 | raise RuntimeError("This class should never be instantiated.") 47 | 48 | 49 | def _handle_queue(queue): 50 | """Process results from the queue.""" 51 | while not queue.empty(): 52 | (actor_rank, item) = queue.get() 53 | if isinstance(item, Callable): 54 | item() 55 | 56 | 57 | def process_results(training_result_futures, queue=None): 58 | """Process results from the queue, and return results from the futures.""" 59 | not_ready = training_result_futures 60 | while not_ready: 61 | if queue: 62 | _handle_queue(queue) 63 | ready, not_ready = ray.wait(not_ready, timeout=0) 64 | ray.get(ready) 65 | ray.get(ready) 66 | 67 | if queue: 68 | # Process any remaining items in queue.
69 | _handle_queue(queue) 70 | return ray.get(training_result_futures) 71 | 72 | 73 | def to_state_stream(model_state_dict): 74 | """Converts the given state dict to a stream of bytes.""" 75 | _buffer = io.BytesIO() 76 | torch.save(model_state_dict, _buffer) 77 | return _buffer.getvalue() 78 | 79 | 80 | def load_state_stream(state_stream, to_gpu): 81 | """Converts the state stream to a state dict on the appropriate device. 82 | 83 | Converts to GPU if ``to_gpu`` is True and CUDA is available. 84 | 85 | """ 86 | _buffer = io.BytesIO(state_stream) 87 | to_gpu = to_gpu and torch.cuda.is_available() 88 | state_dict = torch.load( 89 | _buffer, 90 | map_location=("cpu" 91 | if not to_gpu else lambda storage, loc: storage.cuda())) 92 | return state_dict 93 | 94 | 95 | def set_cuda_device_if_used(strategy: "Strategy") -> None: 96 | """Set the CUDA device to use for the root node.""" 97 | if strategy.use_gpu: 98 | # overwrite the logger 99 | rank_zero_info("GPU available: True (cuda), used: True " 100 | "(Please ignore the previous info [GPU used: False]).") 101 | 102 | torch.cuda.set_device(strategy.root_device) 103 | -------------------------------------------------------------------------------- /requirements-lint.txt: -------------------------------------------------------------------------------- 1 | flake8==3.9.1 2 | flake8-comprehensions 3 | flake8-quotes 4 | importlib_metadata==4.13.0 5 | yapf==0.23.0 6 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | fairscale 2 | pytest 3 | pytorch-lightning==1.6.4 4 | lightning-bolts==0.3.3 5 | ray[tune] 6 | torch==1.12.0 7 | torchmetrics==0.9.3 8 | torchvision 9 | protobuf<=3.20.1 10 | jsonargparse>=4.13.2 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="ray_lightning", 5 | packages=find_packages(where=".", include="ray_lightning*"), 6 | version="0.3.0", 7 | author="Ray Team", 8 | description="Ray distributed strategies for Pytorch Lightning.", 9 | long_description="Custom Pytorch Lightning distributed strategies " 10 | "built on top of distributed computing framework Ray.", 11 | url="https://github.com/ray-project/ray_lightning_accelerators", 12 | install_requires=["pytorch-lightning==1.6.*", "ray"]) 13 | --------------------------------------------------------------------------------
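For orientation, the pieces above compose as follows: a `RayStrategy` (or `HorovodRayStrategy` / `RayShardedStrategy`) instance is passed as the `strategy` of a PyTorch Lightning `Trainer`, while `TuneReportCallback` and `get_tune_resources` from `ray_lightning.tune` connect that training function to Ray Tune. The sketch below is illustrative only and is not part of the repository; `MyLightningModule` is a hypothetical placeholder for any user-defined `LightningModule` that logs a "val_loss" metric, and the argument values simply mirror those used in the tests above.

import pytorch_lightning as pl
import ray
from ray import tune

from ray_lightning import RayStrategy
from ray_lightning.tune import TuneReportCallback, get_tune_resources

ray.init()  # or ray.init("auto") to attach to an existing cluster


def train_fn(config):
    # `MyLightningModule` is a placeholder for any LightningModule that
    # logs "val_loss" during validation.
    model = MyLightningModule(config)
    trainer = pl.Trainer(
        max_epochs=config["max_epochs"],
        # Distribute training over 2 Ray actors (CPU-only in this sketch).
        strategy=RayStrategy(num_workers=2, use_gpu=False),
        # Report the logged "val_loss" back to Tune as "loss" after each
        # validation epoch.
        callbacks=[
            TuneReportCallback({"loss": "val_loss"}, on="validation_end")
        ],
        enable_progress_bar=False)
    trainer.fit(model)


analysis = tune.run(
    train_fn,
    config={"max_epochs": tune.choice([1, 2, 3])},
    # One bundle for the trainable itself plus one bundle per training worker.
    resources_per_trial=get_tune_resources(num_workers=2, use_gpu=False),
    num_samples=2,
    metric="loss",
    mode="min")
print(analysis.best_config)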