├── .dockerignore ├── .github └── workflows │ ├── check-style.yml │ └── push-docker-image.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── open_diloco ├── .gitignore ├── __init__.py ├── ckpt_utils.py ├── configs │ ├── config_14m.json │ ├── config_150m.json │ ├── config_1b.json │ ├── config_2m.json │ └── config_60m.json ├── fixed_key.pem ├── hivemind_diloco.py ├── init_weights.py ├── run_training.sh ├── train_diloco_torch.py ├── train_fsdp.py └── utils.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── scripts ├── pull-c4.sh └── pull-model.py └── tests ├── models └── llama-2m-fresh │ ├── config.json │ ├── generation_config.json │ └── model.safetensors ├── test_diloco_hivemind.py └── test_training └── test_train.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | */wandb/* 4 | */data/* 5 | */outputs/* 6 | 7 | # Git 8 | .git 9 | .gitignore 10 | 11 | # Docker 12 | docker-compose.yml 13 | .docker 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | */__pycache__/ 18 | */*/__pycache__/ 19 | */*/*/__pycache__/ 20 | *.py[cod] 21 | */*.py[cod] 22 | */*/*.py[cod] 23 | */*/*/*.py[cod] 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | env/ 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .coverage 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Virtual environment 77 | .env/ 78 | .venv/ 79 | venv/ 80 | 81 | # PyCharm 82 | .idea 83 | 84 | # Python mode for VIM 85 | .ropeproject 86 | */.ropeproject 87 | */*/.ropeproject 88 | */*/*/.ropeproject 89 | 90 | # Vim swap files 91 | *.swp 92 | */*.swp 93 | */*/*.swp 94 | */*/*/*.swp -------------------------------------------------------------------------------- /.github/workflows/check-style.yml: -------------------------------------------------------------------------------- 1 | name: Check style 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | # This will trigger the workflow for pull requests to any branch 9 | types: [opened, synchronize, reopened] 10 | 11 | jobs: 12 | ruff: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: chartboost/ruff-action@v1 17 | 18 | codespell: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v3 22 | - uses: codespell-project/actions-codespell@v1 23 | with: 24 | only_warn: 1 25 | ignore_words_list: ibrary,nd 26 | -------------------------------------------------------------------------------- /.github/workflows/push-docker-image.yml: -------------------------------------------------------------------------------- 1 | name: Push to Docker Hub 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Remove unnecessary 
packages 11 | run: | 12 | echo "=== Before pruning ===" 13 | df -h 14 | sudo rm -rf /usr/share/dotnet 15 | sudo rm -rf /usr/local/lib/android 16 | sudo rm -rf /opt/ghc 17 | echo "=== After pruning ===" 18 | df -h 19 | 20 | # Link to discussion: https://github.com/orgs/community/discussions/25678 21 | 22 | - name: Checkout 23 | uses: actions/checkout@v3 24 | - name: Docker meta 25 | id: meta 26 | uses: crazy-max/ghaction-docker-meta@v2 27 | with: 28 | images: | 29 | primeintellect/open_diloco 30 | tags: | 31 | type=ref,event=branch 32 | type=ref,event=pr 33 | type=semver,pattern={{version}} 34 | type=semver,pattern={{major}}.{{minor}} 35 | type=semver,pattern={{major}} 36 | type=sha,prefix=commit- 37 | 38 | - name: Set up Docker Buildx 39 | id: buildx 40 | uses: docker/setup-buildx-action@v1 41 | 42 | - name: Login to Docker Hub 43 | if: github.event_name != 'pull_request' 44 | uses: docker/login-action@v1 45 | with: 46 | username: ${{ secrets.DOCKER_HUB_USERNAME }} 47 | password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }} 48 | 49 | - name: Build and push 50 | id: docker_build 51 | uses: docker/build-push-action@v2 52 | with: 53 | context: . 54 | push: ${{ github.event_name != 'pull_request' }} 55 | tags: ${{ steps.meta.outputs.tags }} 56 | 57 | - name: Image digest 58 | run: echo ${{ steps.docker_build.outputs.digest }} 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # node and NPM 2 | npm-debug.log 3 | node_modules 4 | 5 | # swap files 6 | *~ 7 | *.swp 8 | 9 | examples/data/* 10 | examples/runs/* 11 | examples/.ipynb_checkpoints/* 12 | 13 | env.sh 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | bin/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | eggs/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg/ 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | 49 | # Translations 50 | *.mo 51 | 52 | # Mr Developer 53 | .mr.developer.cfg 54 | .project 55 | .pydevproject 56 | .idea 57 | .vscode 58 | .ipynb_checkpoints 59 | 60 | # Rope 61 | .ropeproject 62 | 63 | # Django stuff: 64 | *.log 65 | *.pot 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | docs/tmp* 70 | 71 | # OS X garbage 72 | .DS_Store 73 | 74 | # Debian things 75 | debian/reproducible-experiment-platform 76 | debian/files 77 | *.substvars 78 | *.debhelper.log 79 | 80 | # protobuf stuff 81 | hivemind/proto/*_pb2* 82 | 83 | # libp2p-daemon binary 84 | hivemind/hivemind_cli/p2pd 85 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "hivemind_source"] 2 | path = hivemind_source 3 | url = https://github.com/PrimeIntellect-ai/hivemind.git 4 | branch = feat-add-downloading-time 5 | [submodule "pydantic_config"] 6 | path = pydantic_config 7 | url = https://github.com/samsja/pydantic_config.git 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: 
https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.4.7 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [ --fix ] 9 | # Run the formatter. 10 | - id: ruff-format -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Development workflow 2 | 3 | This is the development workflow of prime intellect to build upon hivemind 4 | 5 | ## Install dependencies 6 | 7 | Install hivemind 8 | 9 | ```bash 10 | cd hivemind_source 11 | pip install . 12 | cp build/lib/hivemind/proto/* hivemind/proto/. 13 | pip install -e ".[all]"``` 14 | ``` 15 | 16 | ## Pre-commit hook 17 | 18 | Install the pre commit hook to keep black and isort updated on each commit: 19 | 20 | ``` 21 | pre-commit install 22 | ``` 23 | 24 | ## Testing 25 | To run the tests: 26 | 27 | ``` 28 | python -m pytest tests 29 | ``` 30 | 31 | Be sure to actually use python -m otherwise path won't be appended correctly 32 | 33 | # Development flags 34 | Add the `PRIME_INTELLECT_DEV` environment variable to your *.bashrc* or *.zshrc* so that development features are enabled. 35 | e.g. 36 | - torch compile error will crash the script instead of silently failing 37 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel 2 | LABEL maintainer="prime intellect" 3 | LABEL repository="open_diloco" 4 | 5 | # Set en_US.UTF-8 locale by default 6 | RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment 7 | 8 | # Set CUDA_HOME and update PATH 9 | ENV CUDA_HOME=/usr/local/cuda 10 | ENV PATH=$PATH:/usr/local/cuda/bin 11 | 12 | # Install packages 13 | RUN apt-get update && apt-get install -y --no-install-recommends --force-yes \ 14 | build-essential \ 15 | curl \ 16 | wget \ 17 | git \ 18 | vim \ 19 | htop \ 20 | nvtop \ 21 | iperf \ 22 | tmux \ 23 | openssh-server \ 24 | git-lfs \ 25 | && apt-get clean autoclean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 26 | 27 | # Install Git LFS 28 | RUN git-lfs install 29 | 30 | # Install Rust 31 | RUN curl https://sh.rustup.rs -sSf | sh -s -- -y 32 | ENV PATH="/root/.cargo/bin:${PATH}" 33 | RUN echo "export PATH=\"/opt/conda/bin:/root/.cargo/bin:\$PATH\"" >> /root/.bashrc 34 | 35 | # Install Python dependencies (The gradual copies help with caching) 36 | WORKDIR open_diloco 37 | RUN pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu 38 | RUN pip install flash-attn>=2.5.8 39 | COPY requirements.txt requirements.txt 40 | RUN pip install --no-cache-dir -r requirements.txt 41 | COPY requirements-dev.txt requirements-dev.txt 42 | RUN pip install --no-cache-dir -r requirements-dev.txt 43 | COPY . . 44 | RUN pip install . 45 | RUN rm -rf ~/.cache/pip 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenDiLoCo 2 | 3 | This repository contains the training code and experiment results for the paper [OpenDiLoCo: An Open-Source Framework for Globally Distributed Low-Communication Training](https://arxiv.org/abs/2407.07852). 4 | 5 | > **Important Notice**: OpenDiLoCo is no longer maintained. For our production-ready distributed training solution, please check out [prime](https://github.com/PrimeIntellect-ai/prime), which offers improved fault tolerance, bandwidth utilization, and scalability. 6 | 7 | https://github.com/user-attachments/assets/38caf4cc-51e1-4ee6-8f43-406910a35995 8 | 9 | 10 | # Setup 11 | 12 | Before running the experiment scripts, you must first setup the environment. 13 | You can clone the repository and setup a conda environment or use our pre-built docker image. 14 | 15 | ## Cloning the repository 16 | 17 | Clone the repository along with the submodules: 18 | ``` 19 | git clone https://github.com/PrimeIntellect-ai/OpenDiLoCo.git --recursive 20 | cd OpenDiLoCo 21 | ``` 22 | 23 | ## Environment setup 24 | 25 | Create a new conda environment and activate it: 26 | ```bash 27 | conda create -n OpenDiLoCo python=3.11 -y && conda activate OpenDiLoCo 28 | ``` 29 | 30 | or with virtualenv: 31 | ```bash 32 | python -m venv .venv 33 | source .venv/bin/activate 34 | ``` 35 | 36 | Install python dependencies: 37 | ```bash 38 | pip install . 39 | pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu 40 | ``` 41 | 42 | Optionally, you can install flash-attn to use Flash Attention 2. 43 | This requires your system to have cuda compiler set up. 44 | 45 | ```bash 46 | # (Optional) flash-attn 47 | pip install flash-attn>=2.5.8 48 | ``` 49 | 50 | ## Docker container 51 | 52 | If you prefer to run your experiments in a reproduceable container, you can use our pre-built docker image containing the repository and pre-installed dependencies. 53 | ```bash 54 | docker pull primeintellect/open_diloco:main 55 | docker run -d --name open-diloco --ipc=host --network=host --gpus=all primeintellect/open_diloco:main 56 | docker exec -it open-diloco bash 57 | ``` 58 | 59 | # Experiments 60 | 61 | This section describes the configurations we used for the experiments reported in the paper. 62 | 63 | The scripts to launch the experiment are in the `open_diloco` folder. 64 | The commands in this document assume you are in the `open_diloco` folder: 65 | ```bash 66 | cd open_diloco 67 | ``` 68 | 69 | ## Machine specific configurations 70 | 71 | The torchrun arguments can be changed to match your machine configuration without affecting the final results unless stated otherwise. 72 | The `per-device-train-batch-size` can also be changed to match the VRAM of your GPUs without affecting the final results. 73 | 74 | ## Hivemind Initialisation 75 | 76 | Some of our experiments utilize the hivemind library to perform distributed weight averaging. 
77 | This requires a [Distributed Hash Table (DHT)](https://en.wikipedia.org/wiki/Distributed_hash_table). 78 | An initial peer is required to initialize the DHT. 79 | This initial peer must be reachable from all machines participating in the distributed training. 80 | 81 | On the machine chosen as the initial peer, run: 82 | 83 | ```bash 84 | hivemind-dht --identity_path fixed_private_key.pem --host_maddrs /ip4/0.0.0.0/tcp/30001 85 | ``` 86 | 87 | You should receive an output similar to this: 88 | 89 | ```bash 90 | Feb 30 13:35:32.717 [INFO] Running a DHT instance. To connect other peers to this one, use --initial_peers /ip4/127.0.0.1/tcp/30001/p2p/Qmbh7opLJxFCtY22XqwETuo6bnWqijs76YXz7D69MBWEuZ 91 | Feb 30 13:35:32.717 [INFO] Full list of visible multiaddresses: /ip4/127.0.0.1/tcp/30001/p2p/Qmbh7opLJxFCtY22XqwETuo6bnWqijs76YXz7D69MBWEuZ /ip4/192.168.100.20/tcp/30001/p2p/Qmbh7opLJxFCtY22XqwETuo6bnWqijs76YXz7D69MBWEuZ 92 | Feb 30 13:35:32.719 [INFO] 1 DHT nodes (including this one) are in the local routing table 93 | Feb 30 13:35:32.719 [INFO] Local storage contains 0 keys 94 | ``` 95 | 96 | The [multiaddress](https://github.com/multiformats/multiaddr) strings listed after `Full list of visible multiaddresses: ` in the output are the multiaddresses you can use to initialize your training processes. In this example they are `/ip4/127.0.0.1/tcp/30001/p2p/Qmbh7opLJxFCtY22XqwETuo6bnWqijs76YXz7D69MBWEuZ` and `/ip4/192.168.100.20/tcp/30001/p2p/Qmbh7opLJxFCtY22XqwETuo6bnWqijs76YXz7D69MBWEuZ` 97 | 98 | ## Stopping hivemind runs 99 | 100 | The current implementation of hivemind doesn't handle Ctrl+C keyboard interrupt well. You can stop the runs using `pkill`: 101 | ```bash 102 | pkill -f torchrun 103 | ``` 104 | 105 | ## Resuming from checkpoint 106 | To resume from checkpoint, you can pass the `--resume-from-checkpoint` argument to the training script. e.g. 107 | ```bash 108 | torchrun --nproc_per_node=8 \ 109 | train_fsdp.py \ 110 | ... 111 | --resume-from-checkpoint checkpoints_1b/2024-06-20/hivemind_1b/bm5zjkzr/model_step_6000 112 | ``` 113 | 114 | ## 150m DDP Baseline 115 | In the `open_diloco` folder, run: 116 | ```bash 117 | torchrun --nproc_per_node=8 \ 118 | train_fsdp.py \ 119 | --sharding-strategy NO_SHARD \ 120 | --per-device-train-batch-size 32 \ 121 | --precision bf16-mixed \ 122 | --total-batch-size 512 \ 123 | --total-steps 88_000 \ 124 | --project OpenDiLoCo \ 125 | --lr 4e-4 \ 126 | --path_model PrimeIntellect/llama-150m-fresh \ 127 | --log-activations-steps 200 \ 128 | --ckpt.interval 8000 \ 129 | --ckpt.path 150_ckpt 130 | ``` 131 | 132 | ## 150m on 8 DiLoCo Worker with 500 local steps 133 | In the `open_diloco` folder, run: 134 | ```bash 135 | ./run_training.sh 8 1 $PEER \ 136 | --sharding-strategy NO_SHARD \ 137 | --per-device-train-batch-size 8 \ 138 | --precision bf16-mixed \ 139 | --total-batch-size 512 \ 140 | --hv.local-steps 500 \ 141 | --total-steps 88_000 \ 142 | --project OpenDiLoCo \ 143 | --hv.skip_load_from_peers \ 144 | --lr 4e-4 \ 145 | --path-model PrimeIntellect/llama-150m-fresh \ 146 | --log-activations-steps 250 \ 147 | --ckpt.interval 4975 \ 148 | --ckpt.path 150_ckpt 149 | ``` 150 | 151 | under the hood the `run_training.sh` script calls `train_fsdp.py` 8 times with the right argument to simulate 8 workers locally. 
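For reference, the launch pattern looks roughly like the sketch below. This is only an illustration, not the actual contents of `run_training.sh` (the variable names are hypothetical, and the real script also has to handle details such as pinning each worker to its own GPU and giving each `torchrun` instance its own rendezvous port):

```bash
# Rough sketch (not the real run_training.sh): spawn one train_fsdp.py
# process per simulated DiLoCo worker, each with its own hv.world-rank,
# all joining the same DHT initial peer.
NUM_WORKERS=$1        # e.g. 8
GPUS_PER_WORKER=$2    # e.g. 1
PEER=$3               # multiaddress from the Hivemind Initialisation step
shift 3               # remaining arguments are forwarded to train_fsdp.py

for RANK in $(seq 0 $((NUM_WORKERS - 1))); do
    torchrun --nproc_per_node=$GPUS_PER_WORKER \
        train_fsdp.py --hv.initial-peers "$PEER" --hv.world-rank "$RANK" "$@" &
done
wait
```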

## 150m on 8 DiLoCo Workers with 50 local steps
In the `open_diloco` folder, run:
```bash
./run_training.sh 8 1 $PEER \
    --sharding-strategy NO_SHARD \
    --per-device-train-batch-size 8 \
    --total-batch-size 512 \
    --precision bf16-mixed \
    --hv.local-steps 50 \
    --total-steps 88_000 \
    --project OpenDiLoCo \
    --hv.skip_load_from_peers \
    --lr 4e-4 \
    --path-model PrimeIntellect/llama-150m-fresh \
    --log-activations-steps 250 \
    --ckpt.interval 4975 \
    --ckpt.path 150_ckpt
```

## 1b Baseline
In the `open_diloco` folder, run:
```bash
torchrun --nproc_per_node=8 \
    train_fsdp.py \
    --sharding-strategy _HYBRID_SHARD_ZERO2 \
    --per-device-train-batch-size 16 \
    --total-batch-size 8192 \
    --precision bf16-mixed \
    --total-steps 88_000 \
    --project OpenDiLoCo \
    --lr 4e-4 \
    --path_model PrimeIntellect/llama-1b-fresh \
    --ckpt.path 1b_ckpt \
    --ckpt.interval 500
```

## 1b on 4 DiLoCo Workers with 500 local steps
Set the `PEER` environment variable to the multiaddress string obtained from the **Hivemind Initialisation** step above.
Launch the command below on 4 separate machines with the environment variable `WORLD_RANK` set to `0`, `1`, `2` and `3` respectively.

```bash
export PEER=/ip4/192.168.100.20/tcp/30001/p2p/Qmbh7opLJxFCtY22XqwETuo6bnWqijs76YXz7D69MBWEuZ
export WORLD_RANK=0

torchrun --nproc_per_node=8 \
    train_fsdp.py \
    --per-device-train-batch-size 16 \
    --total-batch-size 2048 \
    --precision bf16-mixed \
    --total-steps 88_000 \
    --hv.local_steps 500 \
    --project OpenDiLoCo \
    --lr 4e-4 \
    --path_model PrimeIntellect/llama-1b-fresh \
    --warmup-steps 1000 \
    --hv.averaging_timeout 1800 \
    --hv.skip_load_from_peers \
    --hv.initial-peers $PEER \
    --hv.galaxy-size 4 \
    --hv.world-rank $WORLD_RANK \
    --checkpoint_interval 500 \
    --ckpt.path 1b_diloco_ckpt
```

## 1b on 4 DiLoCo Workers with 125 local steps

Same as above, but with `--hv.local_steps 125` instead of `500`:

```bash
export PEER=/ip4/192.168.100.20/tcp/30001/p2p/Qmbh7opLJxFCtY22XqwETuo6bnWqijs76YXz7D69MBWEuZ
export WORLD_RANK=0

torchrun --nproc_per_node=8 \
    train_fsdp.py \
    --per-device-train-batch-size 16 \
    --total-batch-size 2048 \
    --precision bf16-mixed \
    --total-steps 88_000 \
    --project OpenDiLoCo \
    --lr 4e-4 \
    --path_model PrimeIntellect/llama-1b-fresh \
    --warmup-steps 1000 \
    --hv.averaging_timeout 1800 \
    --hv.skip_load_from_peers \
    --hv.local_steps 125 \
    --hv.initial-peers $PEER \
    --hv.galaxy-size 4 \
    --hv.world-rank $WORLD_RANK \
    --checkpoint_interval 500 \
    --ckpt.path 1b_diloco_ckpt
```

# Use OpenDiLoCo in your own code

This codebase provides both a full training script that uses OpenDiLoCo with torch FSDP and hivemind to pretrain transformers (this is what the experiment commands in this README use), and individual components for using OpenDiLoCo with other frameworks.

Specifically, if you want to use OpenDiLoCo in your own training script, you can replace your optimizer with `open_diloco.hivemind_diloco.DiLoCoOptimizer`, which is an (almost) drop-in replacement for `hivemind.optim.Optimizer`.

## Example usage of `DiLoCoOptimizer`

```python
import os
from functools import partial

import torch

from hivemind.dht.dht import DHT
from open_diloco.hivemind_diloco import DiLoCoOptimizer


dht = DHT(start=True, initial_peers=os.environ["PEERS"].split(","))  # comma-separated list of peer multiaddresses

# The inner and outer optimizers must be passed as factory callables, not instances.
inner_optimizer = partial(torch.optim.AdamW, lr=4e-4)
outer_optimizer = partial(torch.optim.SGD, lr=0.7, momentum=0.9, nesterov=True)

model = ...

optimizer = DiLoCoOptimizer(
    dht=dht,
    run_id="my_diloco_run",  # any string identifying this training run; must be the same on all peers
    params=model.parameters(),
    batch_size=512,
    num_inner_steps=500,
    inner_optimizer=inner_optimizer,
    outer_optimizer=outer_optimizer,
)

train_dataloader = ...

for step, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()
```

Note on gradient scalers: if you are using a gradient scaler, you need to call `unscale_` on the inner optimizer specifically,

```python
scaler.unscale_(optimizer.inner_optimizer)
```

and pass the scaler as a keyword argument to `optimizer.step` instead of calling `scaler.step(optimizer)`:

```python
optimizer.step(scaler=scaler)
```

We recommend using `bf16` to avoid scaling and desynchronization issues with hivemind/fsdp, and we are actively working to make it easier to handle scalers with our optimizer.

# Debugging Issues
1. `RuntimeError: CUDA error: invalid device ordinal`
A possible culprit is that your `--nproc-per-node` argument for the torchrun launcher is set incorrectly.
Please set it to an integer less than or equal to the number of GPUs on your machine.

2. `torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate...`
A possible culprit is that your `--per-device-train-batch-size` is too high.
Try a smaller value.
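For the first issue, a quick way to check how many GPUs are actually visible on the machine (and therefore the upper bound for `--nproc-per-node`) is:

```bash
# Number of CUDA devices PyTorch can see; --nproc-per-node must not exceed this.
python -c "import torch; print(torch.cuda.device_count())"
# Cross-check against the driver's view (note that CUDA_VISIBLE_DEVICES can hide devices).
nvidia-smi --list-gpus
```

For the second issue, lowering `--per-device-train-batch-size` does not affect the final results, as noted in the machine specific configurations section above.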
307 | 308 | # Citation 309 | If you use OpenDiloco for your research, please cite our [paper](https://arxiv.org/abs/2407.07852): 310 | ```bibtex 311 | @misc{jaghouar2024opendiloco, 312 | title={OpenDiLoCo: An Open-Source Framework for Globally Distributed Low-Communication Training}, 313 | author={Sami Jaghouar and Jack Min Ong and Johannes Hagemann}, 314 | year={2024}, 315 | eprint={2407.07852}, 316 | archivePrefix={arXiv}, 317 | primaryClass={cs.LG}, 318 | url={https://arxiv.org/abs/2407.07852}, 319 | } 320 | ``` 321 | -------------------------------------------------------------------------------- /open_diloco/.gitignore: -------------------------------------------------------------------------------- 1 | wandb/* 2 | data/* 3 | outputs/* 4 | *ipynb 5 | logs/* -------------------------------------------------------------------------------- /open_diloco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrimeIntellect-ai/OpenDiloco/2d750e58a692ce1424d2a2366b2b3de1f42c9bf1/open_diloco/__init__.py -------------------------------------------------------------------------------- /open_diloco/ckpt_utils.py: -------------------------------------------------------------------------------- 1 | import fsspec 2 | from pydantic_config import BaseConfig 3 | import torch 4 | from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict 5 | import torch.distributed.checkpoint as dcp 6 | import os 7 | from torchdata.stateful_dataloader import StatefulDataLoader 8 | from fsspec.generic import GenericFileSystem 9 | from hivemind.optim.optimizer import logger 10 | 11 | 12 | GLOBAL_STATE_FILE = "global_state_dict.pt" 13 | CKPT_PREFIX = "model_step" 14 | 15 | 16 | class CkptConfig(BaseConfig): 17 | resume: str | bool | None = None # if resume is a boolean, it means we should resume from the last checkpoint 18 | interval: int | None = None 19 | path: str = "outputs" 20 | topk: int | None = None # how many checkpoints to keep 21 | 22 | 23 | def get_resume_info(ckpt_config: CkptConfig) -> tuple[bool, str | None]: 24 | """ 25 | check if we should resume from a checkpoint, if yes return the path to the checkpoint, otherwise return None 26 | """ 27 | if ckpt_config.resume is None: 28 | return False, None 29 | elif isinstance(ckpt_config.resume, bool): 30 | # Using fsspec to list directory contents 31 | fs = GenericFileSystem() 32 | try: 33 | ckpt_files = [f for f in fs.ls(ckpt_config.path, detail=False) if filter_ckpt_files(f)] 34 | except FileNotFoundError: 35 | logger.info(f"Checkpoint path {ckpt_config.path} not found, starting from scratch") 36 | return False, None 37 | 38 | if len(ckpt_files) == 0: 39 | logger.info(f"No checkpoints found in {ckpt_config.path}, starting from scratch") 40 | return False, None 41 | 42 | latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1])) 43 | return True, latest_ckpt 44 | else: 45 | return True, ckpt_config.resume 46 | 47 | 48 | def save_checkpoint( 49 | checkpoint_path: str, 50 | model: torch.nn.Module, 51 | optimizer: torch.optim.Optimizer, 52 | scheduler: torch.optim.lr_scheduler.LambdaLR, 53 | outer_optimizer: torch.optim.Optimizer | None = None, 54 | scaler: torch.cuda.amp.GradScaler | None = None, 55 | loss: float | None = None, 56 | data_loader: StatefulDataLoader | None = None, 57 | save_global_state: bool = True, 58 | ): 59 | """Save the model and optimizer state to a checkpoint folderx 60 | 61 | Args: 62 | checkpoint_path: the path to the checkpoint folder 
63 | model: the model to save 64 | optimizer: the optimizer to save 65 | scheduler: the scheduler to save 66 | outer_optimizer: the outer optimizer to save 67 | loss: the loss to save 68 | data_loader: the data loader to save 69 | save_global_state: whether to save the global state 70 | """ 71 | rank = int(os.environ["RANK"]) 72 | 73 | # 1. Save distributed states 74 | fs_storage_writer = dcp.FsspecWriter(checkpoint_path, sync_files=False) 75 | # for some reason sync_files = True try to call stream.fileno which is not supported with gcp ffspec storage. 76 | 77 | model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) 78 | dcp_state_dict = { 79 | "model": model_state_dict, 80 | "optimizer": optimizer_state_dict, 81 | } 82 | dcp.save(dcp_state_dict, storage_writer=fs_storage_writer) 83 | if data_loader is not None: 84 | rank_state_dict = {} 85 | rank_state_dict["data_loader"] = data_loader.state_dict() 86 | with fsspec.open(os.path.join(checkpoint_path, f"__{rank}_0.pt"), "wb") as f: 87 | torch.save(rank_state_dict, f) 88 | 89 | if not save_global_state: 90 | return 91 | 92 | # 2. Save global states 93 | global_state_dict = {"scheduler": scheduler.state_dict(), "loss": loss if loss is not None else 0} 94 | if outer_optimizer is not None: 95 | global_state_dict["outer_optimizer"] = outer_optimizer.state_dict() 96 | if scaler is not None: 97 | global_state_dict["scaler"] = scaler.state_dict() 98 | 99 | with fsspec.open(os.path.join(checkpoint_path, GLOBAL_STATE_FILE), "wb") as f: 100 | torch.save(global_state_dict, f) 101 | 102 | 103 | def load_checkpoint( 104 | checkpoint_path: str, 105 | model: torch.nn.Module, 106 | optimizer: torch.optim.Optimizer, 107 | scheduler: torch.optim.lr_scheduler.LambdaLR | None = None, 108 | outer_optimizer: torch.optim.Optimizer | None = None, 109 | scaler: torch.cuda.amp.GradScaler | None = None, 110 | data_loader: StatefulDataLoader | None = None, 111 | ) -> float: 112 | """Load the model and optimizer state from a checkpoint folder 113 | 114 | Args: 115 | checkpoint_path: the path to the checkpoint folder 116 | model: the model to load 117 | optimizer: the optimizer to load 118 | scheduler: the scheduler to load 119 | outer_optimizer: the outer optimizer to load 120 | data_loader: the data loader to load 121 | 122 | Returns: 123 | loss: the loss from the checkpoint 124 | """ 125 | rank = int(os.environ["RANK"]) 126 | # 1. Load distributed states 127 | fs_storage_reader = dcp.FsspecReader(checkpoint_path) 128 | 129 | model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) 130 | dcp_state_dict = { 131 | "model": model_state_dict, 132 | "optimizer": optimizer_state_dict, 133 | } 134 | dcp.load(dcp_state_dict, storage_reader=fs_storage_reader) 135 | set_state_dict( 136 | model, 137 | optimizer, 138 | model_state_dict=model_state_dict, 139 | optim_state_dict=optimizer_state_dict, 140 | ) 141 | if data_loader is not None: 142 | with fsspec.open(os.path.join(checkpoint_path, f"__{rank}_0.pt"), "rb") as f: 143 | rank_state_dict = torch.load(f) 144 | data_loader.load_state_dict(rank_state_dict["data_loader"]) 145 | 146 | # 2. 
Load global states 147 | with fsspec.open(os.path.join(checkpoint_path, GLOBAL_STATE_FILE), "rb") as f: 148 | global_state_dict = torch.load(f) 149 | if scheduler is not None: 150 | scheduler.load_state_dict(global_state_dict["scheduler"]) 151 | optimizer.param_groups[0]["lr"] = scheduler.get_last_lr()[0] 152 | if outer_optimizer is not None: 153 | outer_optimizer.load_state_dict(global_state_dict["outer_optimizer"]) 154 | if scaler is not None: 155 | scaler.load_state_dict(global_state_dict["scaler"]) 156 | return global_state_dict["loss"] 157 | 158 | 159 | def filter_ckpt_files(f): 160 | if CKPT_PREFIX not in f: 161 | return False 162 | else: 163 | try: 164 | int(f.split("_")[-1]) 165 | return True 166 | except ValueError: 167 | return False 168 | 169 | 170 | def delete_old_checkpoints(checkpoint_path: str, topk: int) -> list[str]: 171 | fs = GenericFileSystem() 172 | ckpt_files = [f for f in fs.ls(checkpoint_path, detail=False) if filter_ckpt_files(f)] 173 | ckpt_files.sort(key=lambda x: int(x.split("_")[-1])) 174 | 175 | ckpt_deleted = [] 176 | for ckpt_file in ckpt_files[:-topk]: 177 | fs.rm(ckpt_file, recursive=True) 178 | ckpt_deleted.append(ckpt_file) 179 | return ckpt_deleted 180 | 181 | 182 | def check_checkpoint_path_access(checkpoint_path: str, rank: int, world_rank_hv: int | None = None): 183 | if world_rank_hv: 184 | dummy_file_path = os.path.join( 185 | checkpoint_path, get_diloco_rank_dir_name(world_rank_hv), f"dummy_file_{rank}.txt" 186 | ) 187 | else: 188 | dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt") 189 | 190 | with fsspec.open(dummy_file_path, "w") as f: 191 | f.write("This is a dummy file for testing access.") 192 | gfs = GenericFileSystem() 193 | gfs.rm(dummy_file_path) 194 | 195 | 196 | def get_diloco_rank_dir_name(world_rank_diloco: int) -> str: 197 | return f"diloco_rank_{world_rank_diloco}" 198 | -------------------------------------------------------------------------------- /open_diloco/configs/config_14m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "model_type": "llama", 6 | "hidden_size": 128, 7 | "intermediate_size": 512, 8 | "num_attention_heads": 4, 9 | "num_hidden_layers": 6, 10 | "rms_norm_eps": 1e-05, 11 | "use_cache": false 12 | } 13 | 14 | -------------------------------------------------------------------------------- /open_diloco/configs/config_150m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "model_type": "llama", 6 | "hidden_size": 1024, 7 | "intermediate_size": 2688, 8 | "num_attention_heads": 16, 9 | "num_hidden_layers": 12, 10 | "use_cache": false, 11 | "rms_norm_eps": 1e-05 12 | } 13 | 14 | -------------------------------------------------------------------------------- /open_diloco/configs/config_1b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "model_type": "llama", 6 | "hidden_size": 2048, 7 | "intermediate_size": 5632, 8 | "num_attention_heads": 32, 9 | "num_hidden_layers": 22, 10 | "use_cache": false, 11 | "rms_norm_eps": 1e-05, 12 | "num_key_value_heads": 4 13 | } 14 | -------------------------------------------------------------------------------- /open_diloco/configs/config_2m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | 
"LlamaForCausalLM" 4 | ], 5 | "model_type": "llama", 6 | "hidden_size": 64, 7 | "intermediate_size": 256, 8 | "num_attention_heads": 2, 9 | "num_hidden_layers": 2, 10 | "rms_norm_eps": 1e-05, 11 | "use_cache": false, 12 | "vocab_size": 1024 13 | } 14 | 15 | -------------------------------------------------------------------------------- /open_diloco/configs/config_60m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "model_type": "llama", 6 | "hidden_size": 1024, 7 | "intermediate_size": 2688, 8 | "num_attention_heads": 16, 9 | "num_hidden_layers": 3, 10 | "use_cache": false, 11 | "rms_norm_eps": 1e-05 12 | } 13 | 14 | -------------------------------------------------------------------------------- /open_diloco/fixed_key.pem: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrimeIntellect-ai/OpenDiloco/2d750e58a692ce1424d2a2366b2b3de1f42c9bf1/open_diloco/fixed_key.pem -------------------------------------------------------------------------------- /open_diloco/hivemind_diloco.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import time 3 | from typing import Callable, Iterator, List, Optional, Union 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from hivemind.averaging.averager import DecentralizedAverager 9 | from hivemind.averaging.control import StepControl 10 | from hivemind.compression.base import CompressionBase, NoCompression 11 | from hivemind.dht.dht import DHT 12 | from hivemind.optim.optimizer import Optimizer 13 | from hivemind.optim.progress_tracker import ( 14 | GlobalTrainingProgress, 15 | ProgressTracker, 16 | TrainingProgressSchema, 17 | ) 18 | from hivemind.optim.state_averager import ( 19 | LRSchedulerBase, 20 | OptimizerFactory, 21 | Parameters, 22 | ParamGroups, 23 | SchedulerFactory, 24 | TorchOptimizer, 25 | TrainingStateAverager, 26 | ) 27 | from hivemind.utils import get_dht_time 28 | from hivemind.utils.timed_storage import DHTExpiration 29 | from hivemind.optim.optimizer import logger 30 | from hivemind.optim.progress_tracker import LocalTrainingProgress 31 | 32 | from open_diloco.utils import found_inf_grad 33 | 34 | 35 | class DiLoCoStateAverager(TrainingStateAverager): 36 | def __init__( 37 | self, 38 | *, 39 | num_inner_steps: int, 40 | inner_optimizer: TorchOptimizer, 41 | scheduler: Optional[SchedulerFactory] = None, 42 | **kwargs, 43 | ): 44 | self.inner_optimizer = inner_optimizer 45 | self.num_inner_steps = num_inner_steps 46 | 47 | super().__init__( 48 | **kwargs 49 | ) # we specifically don't pass the scheduler here, default TrainingStateAverager would use it with the outer optimizer and we w 50 | 51 | self.scheduler_inner_optimizer = scheduler(self.inner_optimizer) if scheduler is not None else None 52 | assert isinstance(self.scheduler_inner_optimizer, (LRSchedulerBase, type(None))) 53 | 54 | def _update_scheduler(self): 55 | """Increase the scheduler state until it becomes synchronized with local epoch""" 56 | # TODO(sami) handle update scheduler 57 | # for now assuming that all scheduler are on time 58 | pass 59 | 60 | 61 | class DiLoCoGradAverager(DecentralizedAverager): 62 | """ " 63 | DiLoCoGradAverager is meant to be used in pair with DiLoCoStateAverager. 
Specifically it takes as input the offloaded optimizer of DiLoCoStateAverager, and 64 | use the grad buffer of the offloaded param as averaged_tensors for the DecentralizedAverager. In other words the DiLoCoGradAverager makes sure that the grad of the offloaded optimizer 65 | are kept in sync between peers. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | main_parameters: List[torch.nn.Parameter], 71 | offloaded_optimizer: TorchOptimizer, 72 | *, 73 | dht: DHT, 74 | prefix: str, 75 | warn: bool = True, 76 | **kwargs, 77 | ): 78 | if "client_mode" in kwargs: 79 | if kwargs["client_mode"] is not None and kwargs["client_mode"]: 80 | raise KeyError("client_mode is not supported in DiLoCoGradAverager") 81 | else: 82 | kwargs.pop("client_mode") 83 | 84 | if "averaged_grads" in kwargs: 85 | raise KeyError( 86 | "DiLoCoGradAverager does not support averaged_grads since it use the offloaded optimizer gradients directly" 87 | ) 88 | 89 | if not isinstance(main_parameters, (list, tuple)): 90 | raise ValueError( 91 | "main_parameters must be a list or tuple of torch.nn.Parameter and not an iterator otherwise parameters will be consumed" 92 | ) 93 | self.main_parameters = list(main_parameters) 94 | self.offloaded_optimizer = offloaded_optimizer 95 | 96 | self.warn = warn 97 | self.local_samples_accumulated = 0 98 | self.local_times_accumulated = 0 99 | 100 | self._new_averaged_grads = False 101 | 102 | averaged_grads = tuple(grad for grad in self._grads_from_optimizer()) 103 | 104 | super().__init__( 105 | averaged_tensors=averaged_grads, 106 | dht=dht, 107 | prefix=prefix, 108 | client_mode=False, 109 | **kwargs, 110 | ) 111 | 112 | def _grads_from_optimizer(self) -> Iterator[torch.Tensor]: 113 | """gradient buffers associated optimizer""" 114 | param_groups = self.offloaded_optimizer.param_groups 115 | for param_group in param_groups: 116 | for param in param_group["params"]: 117 | if param.grad is None: 118 | param.grad = torch.zeros_like(param) 119 | yield param.grad 120 | 121 | def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl: 122 | """ 123 | Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time. 124 | 125 | :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time 126 | :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc 127 | :note: setting weight at this stage is not supported, please leave this parameter as None 128 | :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group 129 | :note: in the current implementation, each step_control can only be used in one step. 
130 | """ 131 | assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported" 132 | return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs) 133 | 134 | def step( 135 | self, 136 | control: Optional[StepControl] = None, 137 | timeout: Optional[float] = None, 138 | wait: bool = True, 139 | **kwargs, 140 | ): 141 | """ 142 | Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators 143 | 144 | :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples 145 | :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds 146 | :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step 147 | :param timeout: if specified, await for averaging round for at most this number of seconds (if wait=True) 148 | :param wait: if True, await for the step to finish (or fail), otherwise run all-reduce in background 149 | """ 150 | if control is None: 151 | control = self.schedule_step(timeout=timeout, **kwargs) 152 | 153 | self.compute_and_load_pseudo_grad_into_averager() 154 | control.allow_allreduce() 155 | 156 | return control.result(timeout) if wait else control 157 | 158 | @torch.no_grad() 159 | def compute_and_load_pseudo_grad_into_averager(self): 160 | """compute pseudo gradient by subtracting the offloaded optimizer parameters with the main parameters and load them in the averager""" 161 | opt_parameters = [param for group in self.offloaded_optimizer.param_groups for param in group["params"]] 162 | with self.get_tensors() as averaged_grads: 163 | for opt_param, averaged_grad, main_param in zip(opt_parameters, averaged_grads, self.main_parameters): 164 | # opt_param is the param that will be all_reduce, it is suppose to be on cpu 165 | # main_param is the param that has been updated by the inner optimizer, it is suppose to be on gpu 166 | grad = opt_param.data - main_param.detach().to(opt_param.device) 167 | averaged_grad.copy_(grad, non_blocking=True) 168 | 169 | def notify_used_averaged_gradients(self): 170 | """Notify averager that the results of a previous averaging round are accounted for""" 171 | self._new_averaged_grads = False 172 | 173 | 174 | class DiloCoProgressTracker(ProgressTracker): 175 | global_progress: GlobalTrainingProgress 176 | local_progress: LocalTrainingProgress 177 | 178 | def __init__(self, batch_size: int, num_inner_steps: int, **kwargs): 179 | self.batch_size = batch_size 180 | self.num_inner_steps = num_inner_steps 181 | super().__init__(**kwargs) 182 | 183 | @property 184 | def ready_to_update_epoch(self) -> bool: 185 | """Whether or not this peer can increment epoch right away.""" 186 | return ( 187 | self.global_epoch > self.local_progress.epoch 188 | or self.local_progress.samples_accumulated 189 | >= self.target_batch_size # here we track local progress as each diloco worker need to do num_inner_steps (for now) 190 | # or get_dht_time() >= self.global_progress.eta_next_epoch # disabled for our test 191 | ) 192 | 193 | @property 194 | def estimated_next_update_time(self) -> DHTExpiration: 195 | """Estimate (absolute) time when this peer should increment epoch""" 196 | if self.ready_to_update_epoch: 197 | return get_dht_time() 198 | 199 | samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated) 200 | return samples_remaining_to_next_epoch / 
self.performance_ema.samples_per_second 201 | 202 | @property 203 | def local_step(self) -> int: 204 | return self.local_progress.samples_accumulated // self.batch_size 205 | 206 | @property 207 | def real_step(self) -> int: 208 | return self.local_step + self.local_progress.epoch * self.batch_size 209 | 210 | def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress: 211 | """Read performance statistics reported by peers, estimate progress towards next batch 212 | This function is copy paste from hivemind. Only difference is that if fix the ETA estimation. 213 | """ 214 | current_time = get_dht_time() 215 | 216 | if not isinstance(metadata, dict) or len(metadata) == 0: 217 | logger.log(self.status_loglevel, f"Found no active peers: {metadata}") 218 | samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated) 219 | local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second 220 | 221 | return GlobalTrainingProgress( 222 | self.local_progress.epoch, 223 | self.local_progress.samples_accumulated, 224 | self.target_batch_size, 225 | num_peers=0, 226 | num_clients=0, 227 | eta_next_epoch=current_time + local_eta_next_epoch, 228 | next_fetch_time=current_time + self.default_refresh_period, 229 | ) 230 | 231 | valid_peer_entries = [ 232 | LocalTrainingProgress.parse_obj(peer_state.value) 233 | for peer_state in metadata.values() 234 | if peer_state.value is not None 235 | ] 236 | 237 | num_peers = len(valid_peer_entries) 238 | num_clients = sum(peer.client_mode for peer in valid_peer_entries) 239 | 240 | global_epoch = self.local_progress.epoch 241 | for peer in valid_peer_entries: 242 | if not peer.client_mode: 243 | global_epoch = max(global_epoch, peer.epoch) 244 | 245 | total_samples_accumulated = 0 246 | total_samples_per_second = self.performance_ema.eps 247 | 248 | estimated_time_to_next_epoch = 0 249 | 250 | for peer in valid_peer_entries: 251 | total_samples_per_second += peer.samples_per_second 252 | if peer.epoch == global_epoch: 253 | samples_remaining_to_next_epoch = max(0, self.target_batch_size - peer.samples_accumulated) 254 | local_eta_next_epoch = samples_remaining_to_next_epoch / peer.samples_per_second 255 | 256 | estimated_time_to_next_epoch = max(estimated_time_to_next_epoch, local_eta_next_epoch) 257 | 258 | # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance; 259 | # the rationale behind this is that outdated peers will synchronize and begin contributing shortly. 260 | 261 | time_to_next_fetch = float( 262 | np.clip( 263 | a=estimated_time_to_next_epoch, 264 | a_min=self.min_refresh_period, 265 | a_max=self.max_refresh_period, 266 | ) 267 | ) 268 | 269 | logger.log( 270 | self.status_loglevel, 271 | f"{self.prefix} has taken {self.local_step} local steps. Peers: {num_peers}, epoch: {self.local_progress.epoch}, steps: {self.real_step}. ETA: {estimated_time_to_next_epoch:.2f}", 272 | ) 273 | 274 | return GlobalTrainingProgress( 275 | global_epoch, 276 | total_samples_accumulated, 277 | target_batch_size=self.target_batch_size, 278 | num_peers=num_peers, 279 | num_clients=num_clients, 280 | eta_next_epoch=current_time + estimated_time_to_next_epoch, 281 | next_fetch_time=current_time + time_to_next_fetch, 282 | ) 283 | 284 | 285 | class AllReduceStrategy(Enum): 286 | """ 287 | DiLoCo support multiple strategy to trigger the pseudo gradient averaging step. 
288 | 289 | stregy: 290 | * WAIT_FOR_ALL: DiLoCo will wait for all peers to finish their local updates before triggering the all reduce step 291 | use this strategy when you trust all of your peers 292 | * NO_WAIT: The fastest peer will trigger the all reduce as soon as it reach its local steps (modulo the amount of time it need to wait because of the `matchmaking_time`) 293 | use this strategy when some of your peers are unreliable 294 | """ 295 | 296 | WAIT_FOR_ALL = "WAIT_FOR_ALL" 297 | NO_WAIT = "NO_WAIT" 298 | 299 | 300 | DEFAULT_TIMEOUT_WAITING_FOR_PEERS = 600 301 | 302 | 303 | class DiLoCoOptimizer(Optimizer): 304 | """ 305 | DiLoCo optimizer extend Hivemind's Optimizer to support DiLoCo training with local updates, requiring less bandwidth to train 306 | and still converge. 307 | 308 | Pseudo gradient is the difference between the weight before and after the multiple local update of the inner optimizer. 309 | 310 | Paper: https://arxiv.org/abs/2311.08105 311 | 312 | :param: outer_optimizer: Callable to an optimizer to update the pseudo gradient, this optimizer is shared between peers. (DiLoCo used the Nesterov opt) 313 | :param: inner_optimizer: Callable to an optimizer to update the model parameter locally, this optimizer is not shared between peers (DiLoCo used the AdamW opt) 314 | :param: scheduler: callable to a learning rate scheduler to update the inner optimizer lr. 315 | :param: num_inner_steps: number of inner optimizer updates per outer optimizer update 316 | :param: batch_size: number of samples in a single batch 317 | 318 | the rest of parameters are the same as Hivemind's Optimizer, expect `optimizer` that is override by `outer_optimizer`. 319 | """ 320 | 321 | state_averager: DiLoCoStateAverager 322 | inner_optimizer: TorchOptimizer 323 | tracker: DiloCoProgressTracker 324 | diloco_grad_averager: DiLoCoGradAverager 325 | 326 | def __init__( 327 | self, 328 | *, 329 | dht: DHT, 330 | run_id: str, 331 | batch_size: int, 332 | num_inner_steps: int, 333 | outer_optimizer: OptimizerFactory, 334 | inner_optimizer: OptimizerFactory, 335 | params: Optional[Union[Parameters, ParamGroups]] = None, 336 | scheduler: Optional[SchedulerFactory] = None, 337 | averager_opts: Optional[dict] = None, 338 | grad_compression: CompressionBase = NoCompression(), 339 | tracker_opts: Optional[dict] = None, 340 | all_reduce_strategy: AllReduceStrategy = AllReduceStrategy.WAIT_FOR_ALL, 341 | timeout_waiting_for_peers: float | None = None, 342 | matchmaking_time: Optional[float] = 15.0, 343 | **kwargs, 344 | ): 345 | self._check_kwargs(kwargs) 346 | 347 | if timeout_waiting_for_peers is not None: 348 | if all_reduce_strategy == AllReduceStrategy.NO_WAIT: 349 | raise ValueError( 350 | "You cannot use timeout_waiting_for_peers with NO_WAIT strategy, use WAIT_FOR_ALL instead" 351 | ) 352 | 353 | if timeout_waiting_for_peers is not None and timeout_waiting_for_peers < matchmaking_time: 354 | raise ValueError("timeout_waiting_for_peers must be greater than matchmaking_time") 355 | 356 | if all_reduce_strategy == AllReduceStrategy.WAIT_FOR_ALL: 357 | if timeout_waiting_for_peers is None: 358 | timeout_waiting_for_peers = DEFAULT_TIMEOUT_WAITING_FOR_PEERS 359 | 360 | self.all_reduce_strategy = all_reduce_strategy 361 | self.timeout_waiting_for_peers = timeout_waiting_for_peers 362 | 363 | params = list(params) 364 | # if params is a generator (like model.parameters()) it would be consumed by the first optimizer 365 | # since we have two optimizers, we need to persist the params to a list 366 | 
self.num_inner_steps = num_inner_steps 367 | 368 | for opt_or_scheduler in [outer_optimizer, scheduler]: 369 | if not (callable(opt_or_scheduler) or opt_or_scheduler is None): 370 | raise TypeError("You need to pass inner and outer optimizer as well as scheduler as callable") 371 | 372 | if isinstance(inner_optimizer, TorchOptimizer): 373 | self.inner_optimizer = inner_optimizer 374 | elif isinstance(inner_optimizer, Callable): 375 | self.inner_optimizer = inner_optimizer(params=params) 376 | else: 377 | raise TypeError( 378 | f"Expected inner_optimizer to be TorchOptimizer or OptimizerFactory, got {type(inner_optimizer)}" 379 | ) 380 | 381 | if tracker_opts is None: 382 | tracker_opts = {} 383 | 384 | tracker_opts.update(dict(batch_size=batch_size, num_inner_steps=num_inner_steps)) 385 | 386 | if "max_refresh_period" not in tracker_opts: 387 | tracker_opts["max_refresh_period"] = 2 388 | 389 | self.scheduled_diloco_grads: Optional[StepControl] = None 390 | 391 | super().__init__( 392 | optimizer=outer_optimizer, 393 | dht=dht, 394 | run_id=run_id, 395 | target_batch_size=batch_size * num_inner_steps, 396 | batch_size_per_step=batch_size, 397 | params=params, 398 | scheduler=scheduler, 399 | use_local_updates=True, # we are handling grad scaler ourself 400 | offload_optimizer=True, # DiLoCo is always offloading optimizers bc of the pseudo gradient 401 | averager_opts=averager_opts, 402 | tracker_opts=tracker_opts, 403 | matchmaking_time=matchmaking_time, 404 | **kwargs, 405 | ) 406 | self.diloco_grad_averager = self._make_gradient_averager(compression=grad_compression) 407 | 408 | def _check_kwargs(self, kwargs) -> None: 409 | """DiLoCo Optimizer only support a subset of Hivemind Optimizer kwargs. 410 | This function raise an error if some kwargs are not supported""" 411 | 412 | if "optimizer" in kwargs: 413 | raise KeyError("optimizer should not be passed to DiLoCoOptimizer, pass rather to outer_optimizer") 414 | 415 | if "use_local_updates" in kwargs: 416 | if kwargs["use_local_updates"] is False: 417 | raise ValueError( 418 | "You cannot use DiLoCo without local updates, please use normal Optimizer if you don't want local updates" 419 | ) 420 | else: 421 | kwargs.pop("use_local_updates") 422 | 423 | if "offload_optimizer" in kwargs: 424 | if kwargs["offload_optimizer"] is False: 425 | raise ValueError("offload_optimizer=False, is not supported in DiLoCo for now") 426 | else: 427 | kwargs.pop("offload_optimizer") 428 | 429 | for arg_name in ( 430 | "delay_state_averaging", 431 | "delay_grad_averaging", 432 | "delay_optimizer_step", 433 | ): 434 | if arg_name in kwargs: 435 | if kwargs[arg_name] is True: 436 | raise ValueError(f"{arg_name} is not supported in DiLoCo for now") 437 | 438 | if "target_batch_size" in kwargs: 439 | raise KeyError( 440 | "DiLoCo does not have a target_batch_size, use batch_size instead in combination with num_inner_steps" 441 | ) 442 | 443 | if "batch_size_per_step" in kwargs: 444 | raise KeyError("DiLoCo does not have a batch_size_per_step, use batch_size instead") 445 | 446 | def _make_gradient_averager(self, **kwargs) -> DiLoCoGradAverager: 447 | assert hasattr(self, "state_averager"), "must initialize state averager first" 448 | grad_averager = DiLoCoGradAverager( 449 | dht=self.dht, 450 | prefix=f"{self.run_id}_grad_averager", 451 | main_parameters=self.state_averager.main_parameters, 452 | offloaded_optimizer=self.state_averager.optimizer, 453 | min_matchmaking_time=self.matchmaking_time, 454 | allreduce_timeout=self.allreduce_timeout, 455 | 
shutdown_timeout=self.shutdown_timeout, 456 | next_chunk_timeout=self.next_chunk_timeout, 457 | client_mode=self.client_mode, 458 | auxiliary=self.auxiliary, 459 | start=True, 460 | **kwargs, 461 | ) 462 | return grad_averager 463 | 464 | def _make_state_averager(self, **kwargs) -> DiLoCoStateAverager: 465 | return DiLoCoStateAverager( 466 | dht=self.dht, 467 | prefix=f"{self.run_id}_state_averager", 468 | min_matchmaking_time=self.matchmaking_time, 469 | allreduce_timeout=self.allreduce_timeout, 470 | shutdown_timeout=self.shutdown_timeout, 471 | offload_optimizer=self.offload_optimizer, 472 | custom_gradients=self.offload_optimizer, 473 | status_loglevel=self.status_loglevel, 474 | next_chunk_timeout=self.next_chunk_timeout, 475 | client_mode=self.client_mode, 476 | auxiliary=self.auxiliary, 477 | start=True, 478 | num_inner_steps=self.num_inner_steps, 479 | inner_optimizer=self.inner_optimizer, 480 | **kwargs, 481 | ) 482 | 483 | def step( 484 | self, 485 | closure: Optional[Callable[[], torch.Tensor]] = None, 486 | batch_size: Optional[int] = None, 487 | scaler: Optional[torch.cuda.amp.GradScaler] = None, 488 | ): 489 | """ 490 | Note: this code is copied from Hivemind's Optimizer.step; the main change is that the local step uses the **inner optimizer**, while only 491 | the global step that syncs data via all-reduce uses the **outer optimizer** state. 492 | 493 | Note: There is no gradient accumulation in our DiLoCo implementation since we use local updates. 494 | 495 | Note2: the gradient scaler is only applied to the inner optimizer step. The outer optimizer step works on pseudo gradients 496 | that do not need to be scaled. 497 | 498 | Note3: You should not call scaler.step(optimizer) but rather optimizer.step(scaler=scaler); otherwise the scaler will not work as expected because of the outer step. 499 | 500 | Update training. Depending on the configuration, this will 501 | report progress to peers, run a global or local optimizer step, average parameters, or schedule background tasks. 502 | 503 | The grad scaler must be passed to use mixed precision with the inner optimizer. One can still call unscale_ beforehand. 504 | 505 | :param closure: A closure that reevaluates the model and returns the loss. 506 | :param batch_size: optional override for batch_size_per_step from init. 507 | :param scaler: a torch.cuda.amp.GradScaler; if provided, it is used for the inner optimizer step but not the outer optimizer step. 508 | :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
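    Illustrative usage (an editor's sketch based on the training scripts in this repository; `model`, `loss`, `scaler` and `optimizer` are assumed to already exist in the training loop):

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer=optimizer.inner_optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step(scaler=scaler)  # not scaler.step(optimizer), see Note3 above
        scaler.update()
        optimizer.zero_grad()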
509 | """ 510 | ### OG HIVEMIND CODE START ### 511 | if self.batch_size_per_step is None and batch_size is None and not self.auxiliary: 512 | raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step") 513 | if self.auxiliary and (closure is not None or batch_size is not None): 514 | raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler") 515 | if scaler is not None and closure is not None: 516 | raise ValueError("You cannot use closure and scaler at the same time") 517 | 518 | batch_size = batch_size if batch_size is not None else self.batch_size_per_step 519 | 520 | # if delayed updates finished before step, apply these updates; otherwise do nothing 521 | # self.state_averager.step(apply_delayed_updates=True) 522 | 523 | loss = None 524 | if closure is not None: 525 | with torch.enable_grad(): 526 | loss = closure() 527 | 528 | if not self.auxiliary and self._should_load_state_from_peers(): 529 | logger.log(self.status_loglevel, "Peer is out of sync") 530 | self.load_state_from_peers() 531 | return loss # local gradients were computed with out-of-sync parameters, must start over 532 | 533 | ### OG HIVEMIND CODE END ### 534 | 535 | # this code is similar to the hivemind.Optimizer.step when `use_local_updates` is True 536 | # at the difference that it call the inner optimizer step as well. 537 | 538 | if not self.auxiliary: 539 | new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size 540 | self.tracker.report_local_progress(self.local_epoch, samples_accumulated=new_samples_accumulated) 541 | 542 | self._maybe_schedule_state_averaging() 543 | self._maybe_schedule_gradient_averaging() 544 | 545 | if scaler is not None: 546 | scaler.step(self.inner_optimizer) 547 | if found_inf_grad(self.inner_optimizer, scaler): 548 | logger.log(self.status_loglevel, f"Found inf grad at step {self.tracker.real_step}") 549 | else: 550 | self.inner_optimizer.step(closure=closure) 551 | 552 | if self.state_averager.scheduler_inner_optimizer: 553 | self.state_averager.scheduler_inner_optimizer.step() 554 | 555 | if self.tracker.ready_to_update_epoch: 556 | self._update_global_epoch() 557 | 558 | return loss 559 | 560 | def _compute_schema_hash(self) -> int: 561 | """this function is similar to hivemind.Optimizer._compute_schema_hash 562 | but disregard the gradient buffers of the offloaded optimizer 563 | """ 564 | optimized_param_groups = self.state_averager.optimizer.param_groups 565 | optimized_parameters = [param for group in optimized_param_groups for param in group["params"]] 566 | param_shapes = tuple(tuple(param.shape) for param in optimized_parameters) 567 | grad_ids = None 568 | return hash((grad_ids, param_shapes)) 569 | 570 | def _update_global_epoch(self) -> None: 571 | """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step 572 | 573 | NOTE: this has been mostly copied from hivemind.Optimizer._update_global_epoch, except highlighted lines 574 | """ 575 | assert self._schema_hash == self._compute_schema_hash(), "parameters changed during iteration" 576 | _epoch_start_time = time.perf_counter() 577 | 578 | if self.tracker.global_progress.num_peers > 1: 579 | if self.all_reduce_strategy == AllReduceStrategy.WAIT_FOR_ALL: 580 | if self.scheduled_diloco_grads is None: 581 | init_time_waiting = time.perf_counter() 582 | 583 | timeout_triggered = False 584 | 585 | while time.perf_counter() - init_time_waiting < self.timeout_waiting_for_peers: 586 | 
eta_next_epoch = self.tracker.global_progress.eta_next_epoch - get_dht_time() 587 | if eta_next_epoch > self.matchmaking_time: 588 | time_to_wait = max(0.1, self.tracker.global_progress.next_fetch_time - get_dht_time()) 589 | logger.log( 590 | self.status_loglevel, 591 | f"ETA next epoch {eta_next_epoch}, refresh in {time_to_wait}", 592 | ) 593 | time.sleep(time_to_wait) 594 | else: 595 | logger.log( 596 | self.status_loglevel, 597 | f"Pre-scheduling gradient averaging round in {self.matchmaking_time:.2f} sec", 598 | ) 599 | break 600 | else: 601 | timeout_triggered = True 602 | 603 | if timeout_triggered: 604 | logger.log( 605 | self.status_loglevel, 606 | "Timeout waiting for peers all-reduce was triggered. Going to skip slowest peers", 607 | ) 608 | # todo(sami) in this case we still will have to wait for min_matchmaking_time, this could be optimized 609 | 610 | with self.tracker.pause_updates(): 611 | assert not self.delay_optimizer_step, "delay_optimizer_step must be False in DiLoCo" 612 | 613 | if self.tracker.global_progress.num_peers > 1: 614 | logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}") 615 | time_0 = time.perf_counter() 616 | 617 | self.diloco_grad_averager.step( 618 | wait=True, timeout=self.averaging_timeout, control=self.scheduled_diloco_grads 619 | ) 620 | time_1 = time.perf_counter() 621 | logger.log( 622 | self.status_loglevel, 623 | f"Time taken for gradient all reduce: {time_1 - time_0} sec", 624 | ) 625 | 626 | self.diloco_grad_averager.notify_used_averaged_gradients() 627 | self.scheduled_diloco_grads = None 628 | else: 629 | self.diloco_grad_averager.compute_and_load_pseudo_grad_into_averager() 630 | 631 | next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch) 632 | swarm_not_empty = self.tracker.global_progress.num_peers > 1 633 | should_perform_optimizer_step = True # different from hivemind.Optimizer 634 | should_average_state = ( 635 | swarm_not_empty 636 | and next_epoch % self.average_state_every == 0 637 | and not self.state_averager.averaging_in_progress 638 | ) 639 | 640 | if should_average_state and self.scheduled_state is not None: 641 | if self.scheduled_state.triggered or self.scheduled_state.done(): 642 | logger.log( 643 | self.status_loglevel, 644 | f"Not using pre-scheduled group for state averaging because it" 645 | f"was already used elsewhere: {self.scheduled_state}", 646 | ) 647 | self.scheduled_state = None 648 | self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time) 649 | 650 | assert self.state_averager.custom_gradients, "custom gradient must be enable for syncing pseudo gradients" 651 | 652 | logger.info(f"Try outer optimizer step at {self.tracker.real_step} step") 653 | 654 | self.state_averager.step( 655 | increment_epoch=True, 656 | wait_for_trigger=None, 657 | optimizer_step=should_perform_optimizer_step, 658 | delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step, 659 | grad_scaler=None, 660 | averaging_round=should_average_state, 661 | delay_averaging=self.delay_state_averaging and not self.auxiliary, 662 | averaging_control=(self.scheduled_state if should_average_state else None), 663 | averaging_opts=(dict(timeout=self.averaging_timeout) if should_average_state else None), 664 | zero_grad=False, # zero grad should be done outside of diloco 665 | ) 666 | 667 | if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done(): 668 | self.scheduled_state.cancel() 669 | 
self.scheduled_state = None 670 | 671 | self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch) 672 | self._should_check_synchronization_on_update = True 673 | # the above line ensures that peers check for *strict* synchronization once per epoch 674 | 675 | if not self.client_mode: 676 | self.state_averager.state_sharing_priority = self.local_epoch 677 | 678 | self.update_main_param_after_outer_step() 679 | logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}") 680 | 681 | def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> DiloCoProgressTracker: 682 | return DiloCoProgressTracker( 683 | dht=self.dht, 684 | prefix=self.run_id, 685 | target_batch_size=target_batch_size, 686 | client_mode=self.client_mode, 687 | status_loglevel=self.status_loglevel, 688 | start=True, 689 | **kwargs, 690 | ) 691 | 692 | @property 693 | def param_groups(self) -> ParamGroups: 694 | """Inner optimizer is the main optimizer""" 695 | return self.inner_optimizer.param_groups 696 | 697 | def state_dict(self) -> dict: 698 | """we save both inner and outer optimizer states, and the local epoch""" 699 | state_dict_outer = self.state_averager.optimizer.state_dict() 700 | state_dict_outer["state"]["local_epoch"] = self.local_epoch 701 | 702 | state_dict_inner = self.inner_optimizer.state_dict() 703 | 704 | return { 705 | "state_dict_outer": state_dict_outer, 706 | "state_dict_inner": state_dict_inner, 707 | } 708 | 709 | def load_state_dict(self, state_dict: dict): 710 | if "local_epoch" in state_dict["state_dict_outer"]["state"]: 711 | self.state_averager.local_epoch = state_dict["state_dict_outer"]["state"].pop("local_epoch") 712 | 713 | self.state_averager.optimizer.load_state_dict(state_dict["state_dict_outer"]) 714 | self.inner_optimizer.load_state_dict(state_dict["state_dict_inner"]) 715 | 716 | def update_main_param_after_outer_step(self): 717 | """Update the main parameters with the inner optimizer step""" 718 | opt_parameters = [param for group in self.inner_optimizer.param_groups for param in group["params"]] 719 | for main_param, opt_param in zip(self.state_averager.main_parameters, opt_parameters): 720 | main_param.data.copy_(opt_param.data, non_blocking=True) 721 | 722 | def _maybe_schedule_gradient_averaging(self) -> None: 723 | """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch""" 724 | 725 | if self.all_reduce_strategy == AllReduceStrategy.WAIT_FOR_ALL: 726 | eta_seconds = self.tracker.global_progress.eta_next_epoch - get_dht_time() 727 | else: 728 | eta_seconds = self.tracker.estimated_next_update_time 729 | 730 | if eta_seconds <= self.matchmaking_time: 731 | if ( 732 | self.scheduled_diloco_grads is None 733 | or self.scheduled_diloco_grads.triggered 734 | or self.scheduled_diloco_grads.done() 735 | ): 736 | eta_seconds = max(eta_seconds, self.diloco_grad_averager.matchmaking_kwargs["min_matchmaking_time"]) 737 | logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec") 738 | self.scheduled_diloco_grads = self.diloco_grad_averager.schedule_step(timeout=self.averaging_timeout) 739 | -------------------------------------------------------------------------------- /open_diloco/init_weights.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # python3 init_weights.py --config-name-or-path configs/config_60m.json --hub-model-id PrimeIntellect/llama-60m-fresh 3 | from transformers import 
AutoConfig, AutoModelForCausalLM 4 | from cyclopts import App 5 | import os 6 | 7 | app = App() 8 | 9 | 10 | @app.default 11 | def main( 12 | config_name_or_path: str, 13 | hub_model_id: str | None = None, 14 | save_to_disk: str | None = None, 15 | ): 16 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path=config_name_or_path) 17 | from transformers import LlamaForCausalLM 18 | 19 | model: LlamaForCausalLM = AutoModelForCausalLM.from_config(config) 20 | print(model) 21 | if save_to_disk: 22 | os.makedirs(save_to_disk, exist_ok=True) 23 | model.save_pretrained(save_to_disk) 24 | if hub_model_id: 25 | model.push_to_hub(hub_model_id) 26 | 27 | 28 | if __name__ == "__main__": 29 | app() 30 | -------------------------------------------------------------------------------- /open_diloco/run_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## example usage 4 | # ./run_training.sh 4 2 /ip4/127.0.0.1/tcp/36593/p2p/12D3KooWEAyutJ1zFqhAbzDn1LSzraB3o1uS8GSHxQYM87QP4AHN --per-device-train-batch-size 16 --batch-size 512 --local-steps 10 --total-steps 88000 --c4-tiny 5 | # note that everything after the initial peer with will pass to all worker 6 | # 7 | ## the command above will use a total of 8 gpu and create 4 diloco workers each of them with two gpu training ddp/fsdp wise 8 | 9 | 10 | # you can either pass a fixed initial peer or set it to auto and the script will start a dht server for you 11 | ## ./run_training.sh 4 1 auto --per-device-train-batch-size 8 --total-batch-size 128 --lr 1e-2 --path-model ../tests/models/llama-2m-fresh --project debug --no-torch-compile --hv.local-steps 100 --fake-data --hv.matchmaking_time 2 --hv.fail_rank_drop --hv.skip_load_from_peers 12 | 13 | # Function to get CUDA devices based on the number of GPUs and index 14 | function get_cuda_devices() { 15 | local num_gpu=$1 16 | local index=$2 17 | local start_gpu=$((num_gpu * index)) 18 | local end_gpu=$((start_gpu + num_gpu - 1)) 19 | 20 | if [ "$num_gpu" -eq 1 ]; then 21 | echo $start_gpu 22 | else 23 | echo $(seq -s ',' $start_gpu $end_gpu) 24 | fi 25 | } 26 | 27 | # Check if at least three arguments were passed 28 | if [ "$#" -lt 3 ]; then 29 | echo "Usage: $0 [additional_python_args]" 30 | exit 1 31 | fi 32 | 33 | N=$1 # Set N from the first argument 34 | NUM_GPU=$2 35 | INITIAL_PEER=$3 # Set INITIAL_PEER from the second argument 36 | shift 3 # Remove the first three arguments so $@ contains only additional Python arguments 37 | 38 | mkdir -p logs 39 | echo "Initial peer: $INITIAL_PEER" 40 | 41 | # Check if INITIAL_PEER is set to 'auto' and adjust accordingly 42 | if [ "$INITIAL_PEER" = "auto" ]; then 43 | # start the dht server 44 | echo "Starting DHT server" 45 | hivemind-dht --host_maddr /ip4/0.0.0.0/tcp/12345 --identity_path fixed_key.pem > logs/log_dht 2>&1 & 46 | 47 | INITIAL_PEER="" 48 | # get the initial peer from the logs, loop until the peer is found 49 | while [ -z "$INITIAL_PEER" ]; do 50 | sleep 1 51 | INITIAL_PEER=$(awk '/Running a DHT instance/ {print $NF}' logs/log_dht) 52 | 53 | done 54 | fi 55 | echo "Initial peer: $INITIAL_PEER" 56 | 57 | # Ensure the logs directory exists 58 | mkdir -p logs 59 | 60 | # Execute the command for the first device and log the output, run in background 61 | CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU 0) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 train_fsdp.py --hv.initial-peers $INITIAL_PEER $@ --hv.world-rank 0 --hv.galaxy-size $N > logs/log0 2>&1 & 62 | # Wait for 1 second before 
continuing with the rest 63 | sleep 2 64 | 65 | # Loop from 1 to N-1 and execute the command with different CUDA_VISIBLE_DEVICES and seed values, logging each command's output, run each in background 66 | for i in $(seq 1 $(($N - 1))) 67 | do 68 | WANDB_MODE=disabled CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --rdzv-endpoint localhost:123$i --nnodes=1 train_fsdp.py --hv.initial-peers $INITIAL_PEER $@ --hv.world-rank $i --hv.galaxy-size $N > logs/log$i 2>&1 & 69 | done 70 | 71 | tail -f logs/log0 72 | -------------------------------------------------------------------------------- /open_diloco/train_diloco_torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from contextlib import nullcontext 4 | from datetime import datetime 5 | from typing import Literal 6 | 7 | import fsspec 8 | import torch 9 | import torch.distributed as dist 10 | import wandb 11 | from cyclopts import App 12 | from datasets import load_dataset 13 | from datasets.distributed import split_dataset_by_node 14 | from fsspec.generic import GenericFileSystem 15 | from torch.distributed import destroy_process_group, init_process_group 16 | from torch.utils.data import DataLoader 17 | from transformers import ( 18 | AutoTokenizer, 19 | DataCollatorForLanguageModeling, 20 | LlamaForCausalLM, 21 | get_cosine_schedule_with_warmup, 22 | ) 23 | 24 | from open_diloco.utils import get_grad_norm, register_hooks_log_activations 25 | 26 | 27 | # Function to initialize the distributed process group 28 | def ddp_setup(): 29 | init_process_group(backend="nccl") 30 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) 31 | 32 | 33 | # Function to load the checkpoint state into model, optimizer, and scheduler 34 | def load_checkpoint( 35 | model, 36 | inner_optimizer, 37 | outer_optimizer, 38 | scheduler, 39 | filename, 40 | resume_only_model: bool, 41 | ): 42 | with fsspec.open(filename, "rb") as f: 43 | checkpoint = torch.load(f) 44 | 45 | if resume_only_model: 46 | for key in list(checkpoint["model_state_dict"].keys()): 47 | if "module" in key: 48 | checkpoint["model_state_dict"][key.replace("module.", "")] = checkpoint["model_state_dict"].pop(key) 49 | 50 | model.load_state_dict(checkpoint["model_state_dict"]) 51 | if not resume_only_model: 52 | inner_optimizer.load_state_dict(checkpoint["inner_optimizer_state_dict"]) 53 | scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) 54 | outer_optimizer.load_state_dict(checkpoint["outer_optimizer_state_dict"]) 55 | 56 | return checkpoint["step"], checkpoint["loss"] 57 | 58 | 59 | def save_checkpoint( 60 | real_step: int, 61 | model, 62 | inner_optimizer, 63 | outer_optimizer, 64 | scheduler, 65 | loss, 66 | checkpoint_path, 67 | training_date, 68 | project, 69 | ): 70 | local_file_path = os.path.join( 71 | get_ckpt_folder(checkpoint_path, training_date, project), 72 | f"model_step_{real_step}.pt", 73 | ) 74 | checkpoint_data = { 75 | "model_state_dict": model.state_dict(), 76 | "inner_optimizer_state_dict": inner_optimizer.state_dict(), 77 | "outer_optimizer_state_dict": outer_optimizer.state_dict(), 78 | "scheduler_state_dict": scheduler.state_dict(), 79 | "loss": loss.item(), 80 | "step": real_step, 81 | } 82 | with fsspec.open(local_file_path, "wb") as f: 83 | torch.save(checkpoint_data, f) 84 | print(f"Checkpoint saved at step {real_step}") 85 | 86 | 87 | def evaluate_model(eval_dataloader, model, half_precision): 88 | loss_eval = 0 89 | step_eval = 0 90 | 91 | 
eval_start_time = time.time() 92 | for batch_eval in eval_dataloader: 93 | for key in batch_eval.keys(): 94 | batch_eval[key] = batch_eval[key].to("cuda") 95 | 96 | with torch.no_grad(): 97 | model.eval() 98 | 99 | with torch.autocast(device_type="cuda", dtype=torch.float16) if half_precision else nullcontext(): 100 | outputs = model(**batch_eval) 101 | loss_eval += outputs.loss 102 | 103 | step_eval += 1 104 | 105 | eval_end_time = time.time() 106 | model.train() 107 | 108 | print(f"Evaluation time: {eval_end_time - eval_start_time:.2f} seconds") 109 | loss_eval /= step_eval 110 | return {"eval_loss": loss_eval, "eval_perplexity": torch.exp(loss_eval)} 111 | 112 | 113 | def get_ckpt_folder(checkpoint_path, training_date, project): 114 | return os.path.join(checkpoint_path, training_date, project, wandb.run.id) 115 | 116 | 117 | def check_checkpoint_path_access(checkpoint_path: str, training_date, project): 118 | dummy_file_path = os.path.join( 119 | get_ckpt_folder( 120 | checkpoint_path=checkpoint_path, 121 | training_date=training_date, 122 | project=project, 123 | ), 124 | "dummy_file.txt", 125 | ) 126 | with fsspec.open(dummy_file_path, "w") as f: 127 | f.write("This is a dummy file for testing access.") 128 | gfs = GenericFileSystem() 129 | gfs.rm(dummy_file_path) 130 | 131 | 132 | def get_offloaded_param(outer_optimizer: torch.optim.Optimizer): 133 | return [ 134 | param.data.detach().clone().to("cpu") for group in outer_optimizer.param_groups for param in group["params"] 135 | ] 136 | 137 | 138 | app = App() 139 | 140 | 141 | @app.default 142 | def main( 143 | batch_size: int = 512, 144 | per_device_train_batch_size: int = 32, 145 | seq_length: int = 1024, 146 | c4_tiny: bool = False, 147 | checkpoint_interval: int | None = None, 148 | checkpoint_path: str = "outputs", 149 | warmup_steps: int = 1000, 150 | total_steps: int = 88_000, 151 | precision: Literal["fp16-mixed", "bf16-mixed", "32-true"] = "fp16-mixed", 152 | project: str = "hivemind_debug", 153 | model_name_or_path: str = "PrimeIntellect/llama-150m-fresh", 154 | lr: float = 4e-4, 155 | resume_from_checkpoint: str | None = None, 156 | seed_data: int | None = None, 157 | eval_steps: int | None = None, 158 | log_activations_steps: int | None = None, 159 | local_steps: int = 500, 160 | wandb_group: str | None = None, 161 | resume_only_model: bool = False, 162 | outer_lr: float = 0.7, 163 | ): 164 | local_rank = int(os.environ["LOCAL_RANK"]) 165 | world_size = int(os.environ["WORLD_SIZE"]) 166 | 167 | # batch_size is the total batch size for all GPUs 168 | # assert batch_size % world_size == 0 169 | # batch_size = batch_size / world_size 170 | 171 | assert batch_size % per_device_train_batch_size == 0 172 | gradient_accumulation_steps = batch_size // per_device_train_batch_size 173 | 174 | if local_rank == 0: 175 | wandb.init(project=project, group=wandb_group) 176 | 177 | training_date = datetime.now().strftime( 178 | "%Y-%m-%d" 179 | ) # we define the data at the beginning of training in case the training take several days 180 | 181 | check_checkpoint_path_access(checkpoint_path, training_date, project) 182 | # Load model configuration and tokenizer 183 | model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name_or_path).to(local_rank) 184 | 185 | # Setup optimizers 186 | inner_optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.1, betas=(0.9, 0.95)) 187 | outer_optimizer = torch.optim.SGD(model.parameters(), lr=outer_lr, momentum=0.9, nesterov=True) 188 | 189 | scheduler = 
get_cosine_schedule_with_warmup( 190 | inner_optimizer, 191 | num_warmup_steps=warmup_steps, 192 | num_training_steps=total_steps, 193 | ) 194 | if precision not in ["fp16-mixed", "bf16-mixed", "32-true"]: 195 | raise ValueError(f"Invalid precision: {precision}. Please choose 'fp16-mixed', 'bf16-mixed', or '32-true'.") 196 | 197 | half_precision = precision == "fp16-mixed" or precision == "bf16-mixed" 198 | half_precision_dtype = torch.bfloat16 if precision == "bf16-mixed" else torch.float16 199 | scaler = torch.cuda.amp.GradScaler(enabled=precision == "fp16-mixed") 200 | 201 | tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True) 202 | tokenizer.pad_token = "" # Ensure pad token is set for models that need it 203 | 204 | ds = ( 205 | load_dataset("PrimeIntellect/c4-tiny", "en", ignore_verifications=True) 206 | if c4_tiny 207 | else load_dataset( 208 | "allenai/c4", 209 | "en", 210 | streaming=True, 211 | data_files={ 212 | "train": "en/c4-train.*.json.gz", 213 | "validation": "en/c4-validation.00000-of-00008.json.gz", 214 | }, 215 | ) 216 | ) 217 | # we only load one eval file to be faster 218 | 219 | if seed_data is not None: 220 | ds = ds.shuffle(seed=seed_data) 221 | 222 | def tokenize_function(data): 223 | outputs = tokenizer(data["text"], truncation=True, max_length=seq_length) 224 | return outputs 225 | 226 | tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"]) 227 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 228 | train_dataset = split_dataset_by_node(tokenized_datasets["train"], world_size=world_size, rank=local_rank) 229 | train_dataloader = DataLoader(train_dataset, collate_fn=data_collator, batch_size=per_device_train_batch_size) 230 | 231 | if eval_steps is not None: 232 | eval_dataset = tokenized_datasets["validation"] 233 | eval_dataloader = DataLoader( 234 | eval_dataset, 235 | collate_fn=data_collator, 236 | batch_size=per_device_train_batch_size, 237 | ) 238 | 239 | start_step = 0 240 | 241 | if resume_from_checkpoint is not None: 242 | last_step, last_loss = load_checkpoint( 243 | model, 244 | inner_optimizer, 245 | outer_optimizer, 246 | scheduler, 247 | resume_from_checkpoint, 248 | resume_only_model, 249 | ) 250 | start_step = last_step 251 | print(f"Resumed from checkpoint at step {start_step} with loss {last_loss}") 252 | 253 | for param in model.parameters(): 254 | # this make sure all device have the same weight init 255 | dist.broadcast(param.data, src=0) 256 | 257 | params_offloaded = get_offloaded_param(outer_optimizer) 258 | 259 | model.train() 260 | 261 | log_activations = None 262 | 263 | start_time = time.time() 264 | print(f"starting from step {start_step}") 265 | 266 | check_start_step = start_step > 0 267 | 268 | handles = None 269 | 270 | loss_batch = 0 271 | 272 | for step, batch in enumerate(iterable=train_dataloader): 273 | real_step = (step + 1) // gradient_accumulation_steps 274 | step_within_grad_acc = (step + 1) % gradient_accumulation_steps 275 | 276 | if check_start_step: 277 | if real_step < start_step: 278 | continue # skipping steps before start_step 279 | else: 280 | check_start_step = False 281 | print(f"skipped step {step+1}, real_step {real_step} in {time.time() - start_time:.2f} seconds") 282 | continue 283 | 284 | if log_activations_steps is not None: 285 | if ( 286 | real_step >= log_activations_steps 287 | and real_step % log_activations_steps == 0 288 | and step_within_grad_acc == 0 289 | ): 290 | if local_rank 
== 0: 291 | print(f"Logging activations at step {real_step}") 292 | handles, log_activations = register_hooks_log_activations(model) 293 | 294 | if ( 295 | real_step - 1 >= log_activations_steps 296 | and (real_step - 1) % log_activations_steps == 0 297 | and step_within_grad_acc == 0 298 | ): 299 | if local_rank == 0: 300 | print(f"Removing activations logging at step {real_step}") 301 | 302 | # if we are after the step where we log the activations, we remove the hooks 303 | if handles is not None: 304 | for handle in handles: 305 | handle.remove() 306 | handles = None 307 | log_activations = None 308 | 309 | for key in batch.keys(): 310 | batch[key] = batch[key].to("cuda") 311 | 312 | with torch.autocast(device_type="cuda", dtype=half_precision_dtype) if half_precision else nullcontext(): 313 | outputs = model(**batch) 314 | loss = outputs.loss / gradient_accumulation_steps 315 | 316 | loss_batch += loss.detach() 317 | 318 | scaler.scale(loss).backward() 319 | 320 | if step_within_grad_acc == 0: 321 | scaler.unscale_(optimizer=inner_optimizer) 322 | 323 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # gradient clipping 324 | 325 | scaler.step(optimizer=inner_optimizer) 326 | scaler.update() 327 | scheduler.step() 328 | 329 | if log_activations_steps is not None and real_step % log_activations_steps == 0: 330 | log_norms_data = get_grad_norm(model) 331 | else: 332 | log_norms_data = None 333 | 334 | inner_optimizer.zero_grad() 335 | 336 | if real_step % local_steps == 0: 337 | if local_rank == 0: 338 | print(f"perform outer step at step {real_step}") 339 | 340 | main_param = [param for group in inner_optimizer.param_groups for param in group["params"]] 341 | 342 | for param_offloaded, param in zip(params_offloaded, main_param): 343 | param_offloaded_on_device = param_offloaded.data.to(param.device) 344 | param.grad = param_offloaded_on_device - param.data 345 | dist.all_reduce(tensor=param.grad, op=dist.ReduceOp.AVG) 346 | param.data = param_offloaded_on_device 347 | 348 | # here we don't call scaler.step. 
Indeed the scaler has already done his work (scaling down the gradients) with the optimizer.step call 349 | outer_optimizer.step() 350 | 351 | outer_optimizer.zero_grad() 352 | 353 | params_offloaded = get_offloaded_param(outer_optimizer) 354 | 355 | if local_rank == 0 and eval_steps is not None and real_step % eval_steps == 0: 356 | print(f"Evaluating at step {real_step}") 357 | 358 | if handles is not None: 359 | for handle in handles: 360 | handle.remove() 361 | handles = None 362 | dict_to_log_eval = evaluate_model(eval_dataloader, model, half_precision) 363 | 364 | else: 365 | dict_to_log_eval = {} 366 | 367 | if local_rank == 0: 368 | dict_to_log = { 369 | "Loss": loss_batch.item(), 370 | "step": real_step, 371 | "lr": [group["lr"] for group in inner_optimizer.param_groups][0], 372 | "Perplexity": torch.exp(loss_batch).item(), 373 | "effective_step": real_step * world_size, 374 | "total_samples": real_step * batch_size * world_size, 375 | } 376 | dict_to_log.update(dict_to_log_eval) 377 | 378 | if log_norms_data is not None: 379 | dict_to_log.update(log_norms_data) 380 | 381 | if log_activations: 382 | for key, _ in log_activations.items(): 383 | log_activations[key] /= gradient_accumulation_steps 384 | # log activation will accumulate all of the norm of the activations at each grad acc step 385 | # so we need to divide 386 | dict_to_log.update(log_activations) 387 | 388 | wandb.log(dict_to_log) 389 | print( 390 | f"step: {real_step}, loss: {loss_batch.item()}, lr {[group['lr'] for group in inner_optimizer.param_groups][0]}" 391 | ) 392 | loss_batch = 0 393 | 394 | # Save checkpoint every 'checkpoint_interval' steps 395 | if local_rank == 0 and checkpoint_interval is not None and real_step % checkpoint_interval == 0: 396 | print(f"saving at step {real_step}, step {step+1}") 397 | save_checkpoint( 398 | real_step, 399 | model, 400 | inner_optimizer, 401 | outer_optimizer, 402 | scheduler, 403 | loss, 404 | checkpoint_path, 405 | training_date, 406 | project, 407 | ) 408 | 409 | print("Training completed.") 410 | wandb.finish() 411 | 412 | 413 | if __name__ == "__main__": 414 | ddp_setup() 415 | app() 416 | destroy_process_group() 417 | -------------------------------------------------------------------------------- /open_diloco/train_fsdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | to test quickly do 4 | torchrun --nproc_per_node=2 \ 5 | train_fsdp.py --per-device-train-batch-size 8 --total-batch-size 128 --lr 1e-2 --path-model ../tests/models/llama-2m-fresh \ 6 | --no-torch-compile --log-activations-steps 5 --fake-data --max-steps 20 7 | """ 8 | 9 | from functools import partial 10 | import os 11 | import time 12 | from contextlib import nullcontext 13 | import datetime 14 | from typing import Any, Literal 15 | 16 | from pydantic import model_validator 17 | import torch 18 | from pydantic_config import parse_argv, BaseConfig 19 | from datasets import load_dataset 20 | from datasets.distributed import split_dataset_by_node 21 | from torch.distributed import destroy_process_group, init_process_group 22 | 23 | from torchdata.stateful_dataloader import StatefulDataLoader 24 | from transformers import ( 25 | AutoTokenizer, 26 | DataCollatorForLanguageModeling, 27 | LlamaConfig, 28 | LlamaForCausalLM, 29 | get_cosine_schedule_with_warmup, 30 | ) 31 | from torch.distributed.fsdp import ( 32 | FullyShardedDataParallel as FSDP, 33 | ShardingStrategy, 34 | MixedPrecision, 35 | ) 36 | from torch.distributed.device_mesh import 
init_device_mesh 37 | 38 | from open_diloco.ckpt_utils import ( 39 | CKPT_PREFIX, 40 | CkptConfig, 41 | check_checkpoint_path_access, 42 | delete_old_checkpoints, 43 | get_diloco_rank_dir_name, 44 | get_resume_info, 45 | load_checkpoint, 46 | save_checkpoint, 47 | ) 48 | from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer 49 | from open_diloco.utils import WandbLogger, DummyLogger 50 | 51 | from hivemind.dht.dht import DHT 52 | from hivemind.utils.networking import log_visible_maddrs 53 | from hivemind.optim.optimizer import logger 54 | 55 | 56 | from open_diloco.utils import ( 57 | FakeTokenizedDataset, 58 | get_compression_kwargs, 59 | get_sharding_strategy, 60 | register_metrics_hooks, 61 | ) 62 | 63 | 64 | TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120) 65 | TARGET_LAYER_ACTIVATIONS = ["self_attn", "lm_head"] 66 | TEST_VOCAB_SIZE = 1024 67 | 68 | 69 | # Function to initialize the distributed process group 70 | def ddp_setup(): 71 | init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES)) 72 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) 73 | 74 | 75 | def log(message): 76 | logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}") 77 | 78 | 79 | class HvConfig(BaseConfig): 80 | outer_lr: float = 0.7 81 | local_steps: int = 500 82 | initial_peers: list[str] | None = None 83 | host_maddrs: list[str] = ["/ip4/0.0.0.0/tcp/0"] 84 | announce_maddrs: list[str] | None = None 85 | matchmaking_time: float | None = None 86 | averaging_timeout: float | None = None 87 | hivemind_compression: Literal["fp16", "scaled-fp16", "uniform8bit", "quantile8bit", "blockwise8bit"] | None = None 88 | all_reduce_strategy: AllReduceStrategy = AllReduceStrategy.WAIT_FOR_ALL 89 | timeout_waiting_for_peers: float | None = None 90 | skip_load_from_peers: bool = False 91 | world_rank: int 92 | galaxy_size: int 93 | fail_rank_drop: bool = False # fail if we lose a diloco worker 94 | 95 | @model_validator(mode="before") 96 | def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]: 97 | """This allow to only pass a string and it will still be cast as a list""" 98 | for arg_name in ["initial_peers", "host_maddrs", "announce_maddrs"]: 99 | if arg_name in values.keys() and isinstance(values[arg_name], str): 100 | values[arg_name] = [values[arg_name]] 101 | return values 102 | 103 | 104 | class Config(BaseConfig): 105 | path_model: str = "PrimeIntellect/llama-150m-fresh" 106 | torch_compile: bool = True 107 | attn_implementation: str = "sdpa" 108 | # Data 109 | dataset_name_or_path: str = "allenai/c4" 110 | seq_length: int = 1024 111 | c4_tiny: bool = False 112 | num_workers: int = 4 113 | # Optimization 114 | lr: float = 4e-4 115 | total_batch_size: int = 512 116 | per_device_train_batch_size: int = 32 117 | warmup_steps: int = 1000 118 | total_steps: int = 88_000 119 | sharding_strategy: str = "NO_SHARD" 120 | precision: Literal["fp16-mixed", "bf16-mixed", "32-true"] = "fp16-mixed" 121 | # Checkpointing and logging 122 | project: str = "hivemind_debug" 123 | metric_logger_type: Literal["wandb", "dummy"] = "wandb" 124 | log_activations_steps: int | None = None 125 | ckpt: CkptConfig = CkptConfig() 126 | # Hivemind 127 | hv: HvConfig | None = None # if no hv config then hivemind is disabled 128 | fake_data: bool = False 129 | max_steps: int | None = None 130 | 131 | 132 | def get_dataloader(tokenizer, world_size, rank, local_rank, config: Config) -> StatefulDataLoader: 133 | if config.fake_data: 134 | train_dataset = 
FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE) 135 | else: 136 | ds = load_dataset(config.dataset_name_or_path, "en", streaming=True) 137 | 138 | def tokenize_function(data): 139 | outputs = tokenizer( 140 | data["text"], 141 | truncation=True, 142 | max_length=config.seq_length, 143 | padding="max_length", 144 | ) 145 | return outputs 146 | 147 | tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"])[ 148 | "train" 149 | ] 150 | 151 | if config.hv is not None: 152 | train_dataset = split_dataset_by_node( 153 | tokenized_datasets, 154 | world_size=config.hv.galaxy_size * world_size, 155 | rank=config.hv.world_rank * world_size + local_rank, 156 | ) 157 | 158 | else: 159 | train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank) 160 | 161 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 162 | 163 | return StatefulDataLoader( 164 | train_dataset, 165 | collate_fn=data_collator, 166 | batch_size=config.per_device_train_batch_size, 167 | num_workers=config.num_workers, 168 | ) 169 | 170 | 171 | def get_model(config: Config) -> LlamaForCausalLM: 172 | # Load model 173 | config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation) 174 | return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model) 175 | 176 | 177 | def train(config: Config): 178 | sharding_strategy = get_sharding_strategy(config.sharding_strategy) 179 | local_rank = int(os.environ["LOCAL_RANK"]) 180 | world_size = int(os.environ["WORLD_SIZE"]) 181 | rank = int(os.environ["RANK"]) 182 | 183 | world_messenger_hv = config.hv is not None and local_rank == 0 184 | 185 | # batch_size is the total batch size for all GPUs 186 | assert config.total_batch_size % world_size == 0 187 | batch_size = config.total_batch_size // world_size 188 | 189 | assert batch_size % config.per_device_train_batch_size == 0 190 | gradient_accumulation_steps = batch_size // config.per_device_train_batch_size 191 | 192 | if config.hv is not None: 193 | sharding_strategy = ShardingStrategy.NO_SHARD 194 | log("Hivemind is used, ShardingStrategy.NO_SHARD is used") 195 | 196 | resume_from_ckpt, resume_path = get_resume_info(config.ckpt) 197 | 198 | if rank == 0: 199 | logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger 200 | metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=resume_from_ckpt) 201 | 202 | if config.hv is not None: 203 | log("hivemind diloco enabled") 204 | 205 | if world_messenger_hv: 206 | dht = DHT( 207 | start=True, 208 | initial_peers=config.hv.initial_peers, 209 | host_maddrs=config.hv.host_maddrs, 210 | announce_maddrs=config.hv.announce_maddrs, 211 | ) 212 | log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False) 213 | 214 | if local_rank == 0: 215 | check_checkpoint_path_access(config.ckpt.path, rank, config.hv.world_rank if config.hv else None) 216 | 217 | # DataLoader preparation 218 | tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True) 219 | tokenizer.pad_token = "" # Ensure pad token is set for models that need it 220 | 221 | train_dataloader = get_dataloader(tokenizer, world_size, rank, local_rank, config) 222 | 223 | model = get_model(config) 224 | model = model.to(local_rank) 225 | 226 | half_precision = config.precision == "fp16-mixed" or config.precision == "bf16-mixed" 227 | half_precision_dtype = 
torch.bfloat16 if config.precision == "bf16-mixed" else torch.float16 228 | scaler = torch.cuda.amp.GradScaler(enabled=config.precision == "fp16-mixed") 229 | 230 | if sharding_strategy in [ 231 | ShardingStrategy._HYBRID_SHARD_ZERO2, 232 | ShardingStrategy.HYBRID_SHARD, 233 | ]: 234 | local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 235 | nnodes = world_size // local_world_size 236 | device_mesh = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local")) 237 | else: 238 | device_mesh = None 239 | model = FSDP( 240 | model, 241 | sharding_strategy=sharding_strategy, 242 | mixed_precision=MixedPrecision(param_dtype=half_precision_dtype) if half_precision else None, 243 | use_orig_params=config.torch_compile, 244 | device_mesh=device_mesh, 245 | ) 246 | if config.torch_compile: 247 | model = torch.compile(model) 248 | 249 | # Setup optimizers 250 | inner_optimizer = partial(torch.optim.AdamW, lr=config.lr, weight_decay=0.1, betas=(0.9, 0.95)) # noqa: F821 251 | 252 | if config.hv is not None: 253 | outer_optimizer = partial(torch.optim.SGD, lr=config.hv.outer_lr, momentum=0.9, nesterov=True) 254 | 255 | def scheduler_fn(opt): 256 | return get_cosine_schedule_with_warmup( 257 | opt, 258 | num_warmup_steps=config.warmup_steps, 259 | num_training_steps=config.total_steps, 260 | ) 261 | 262 | if config.hv is not None: 263 | if resume_from_ckpt: 264 | # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer 265 | # This is because the DiLoCoOptimizer makes a copy of the model parameters for the state averager which is hard to update later 266 | # We also need to do this on follower workers so that the world_messenger has friends to talk to when it does its two loads 267 | # Otherwise the world messenger will get lonely and hang 268 | fake_optimizer = inner_optimizer(model.parameters()) 269 | last_loss = load_checkpoint( 270 | checkpoint_path=os.path.join(resume_path, get_diloco_rank_dir_name(config.hv.world_rank)), 271 | model=model, 272 | optimizer=fake_optimizer, 273 | ) 274 | del fake_optimizer 275 | 276 | if resume_from_ckpt: 277 | if config.hv is not None: 278 | ckpt_path = os.path.join(resume_path, get_diloco_rank_dir_name(config.hv.world_rank)) 279 | else: 280 | ckpt_path = resume_path 281 | 282 | if world_messenger_hv: 283 | diloco_args = dict( 284 | dht=dht, 285 | run_id="llama", 286 | batch_size=batch_size, 287 | num_inner_steps=config.hv.local_steps, 288 | outer_optimizer=outer_optimizer, 289 | inner_optimizer=inner_optimizer, 290 | scheduler=None, 291 | params=model.parameters(), 292 | delay_optimizer_step=False, 293 | delay_grad_averaging=False, 294 | verbose=True, 295 | all_reduce_strategy=config.hv.all_reduce_strategy, 296 | timeout_waiting_for_peers=config.hv.timeout_waiting_for_peers, 297 | ) 298 | 299 | diloco_args.update(get_compression_kwargs(config.hv.hivemind_compression)) 300 | 301 | if config.hv.averaging_timeout is not None: 302 | diloco_args["averaging_timeout"] = config.hv.averaging_timeout 303 | 304 | if config.hv.matchmaking_time is not None: 305 | diloco_args["matchmaking_time"] = config.hv.matchmaking_time 306 | 307 | optimizer = DiLoCoOptimizer(**diloco_args) 308 | 309 | scheduler = scheduler_fn( 310 | optimizer.inner_optimizer 311 | ) # scheduler(optimizer) should work but better to make it explicit here 312 | 313 | if resume_from_ckpt: 314 | last_loss = load_checkpoint( 315 | checkpoint_path=ckpt_path, 316 | model=model, 317 | optimizer=optimizer.inner_optimizer, 318 
| scheduler=scheduler, 319 | outer_optimizer=optimizer.state_averager.optimizer, 320 | scaler=scaler, 321 | data_loader=train_dataloader, 322 | ) 323 | start_step = scheduler.last_epoch 324 | else: 325 | start_step = 0 326 | 327 | else: 328 | optimizer = inner_optimizer(model.parameters()) 329 | scheduler = scheduler_fn(optimizer) 330 | if resume_from_ckpt: 331 | last_loss = load_checkpoint( 332 | checkpoint_path=ckpt_path, 333 | model=model, 334 | optimizer=optimizer, 335 | scheduler=scheduler, 336 | scaler=scaler, 337 | data_loader=train_dataloader, 338 | ) 339 | start_step = scheduler.last_epoch 340 | else: 341 | start_step = 0 342 | 343 | if resume_from_ckpt: 344 | log(f"Resumed from checkpoint at step {start_step} with loss {last_loss}") 345 | 346 | model.train() 347 | 348 | if world_messenger_hv and not config.hv.skip_load_from_peers: 349 | optimizer.load_state_from_peers() 350 | 351 | current_time = time.time() 352 | log(f"starting from step {start_step}") 353 | 354 | loss_batch = 0 355 | 356 | if world_messenger_hv: 357 | max_num_peers = 0 358 | 359 | log_activations = {} 360 | 361 | for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps): 362 | real_step = (step + 1) // gradient_accumulation_steps 363 | is_accumulating = bool((step + 1) % gradient_accumulation_steps) 364 | 365 | logging_activations_steps = ( 366 | config.log_activations_steps is not None and real_step % config.log_activations_steps == 0 367 | ) 368 | 369 | if logging_activations_steps: 370 | handles = register_metrics_hooks( 371 | model, TARGET_LAYER_ACTIVATIONS, log_activations, gradient_accumulation_steps 372 | ) 373 | 374 | for key in batch.keys(): 375 | batch[key] = batch[key].to("cuda") 376 | 377 | with model.no_sync() if is_accumulating else nullcontext(): 378 | outputs = model(**batch) 379 | loss = outputs.loss / gradient_accumulation_steps 380 | 381 | loss_batch += loss.detach() 382 | 383 | scaler.scale(loss).backward() 384 | 385 | if logging_activations_steps: 386 | for handle in handles: 387 | handle.remove() 388 | 389 | if not is_accumulating: 390 | if world_messenger_hv: 391 | scaler.unscale_(optimizer=optimizer.inner_optimizer) 392 | else: 393 | scaler.unscale_(optimizer=optimizer) 394 | 395 | model.clip_grad_norm_(1.0) # gradient clipping 396 | 397 | if world_messenger_hv: 398 | optimizer.step(scaler=scaler) 399 | 400 | # todo(sami): refactor to use built in pytorch mechanism to handle scaler manually 401 | # should allow to just do scaler.step(optimizer) 402 | else: 403 | scaler.step(optimizer) 404 | 405 | scaler.update() 406 | 407 | scheduler.step() 408 | optimizer.zero_grad() 409 | 410 | if config.hv is not None: 411 | if int(real_step) % config.hv.local_steps == 0: 412 | for param in model.parameters(): 413 | torch.distributed.broadcast(param.data, src=0) 414 | 415 | if rank == 0: 416 | total_samples = real_step * config.total_batch_size 417 | effective_step = real_step 418 | 419 | if config.hv is not None: 420 | # Note that this assumes that we have the right amount of worker since t0. 421 | # Not robust to off/on ramping 422 | effective_step = real_step * config.hv.galaxy_size 423 | total_samples = real_step * config.total_batch_size * config.hv.galaxy_size 424 | 425 | metrics = { 426 | "Loss": loss_batch.item(), 427 | "step": real_step, 428 | "lr": [group["lr"] for group in optimizer.param_groups][0], 429 | "Perplexity": torch.exp(loss_batch).item(), 430 | "effective_step": effective_step, # at each step the we have compute total_batch_size. 
Independent of the number of GPUs 431 | "total_samples": total_samples, 432 | "time_taken": time.time() - current_time, 433 | "tokens_per_second": config.seq_length * config.total_batch_size / (time.time() - current_time), 434 | } 435 | 436 | if world_messenger_hv: 437 | outer_lr = [group["lr"] for group in optimizer.state_averager.optimizer.param_groups][0] 438 | num_peers = optimizer.tracker.global_progress.num_peers 439 | 440 | max_num_peers = max(max_num_peers, num_peers) 441 | 442 | if num_peers == 0: 443 | num_peers = 1 444 | 445 | metrics["outer_lr"] = outer_lr 446 | metrics["num_peers"] = num_peers 447 | 448 | if logging_activations_steps: 449 | metrics.update(log_activations) 450 | log_activations = {} 451 | 452 | if world_messenger_hv and num_peers < max_num_peers: 453 | log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}") 454 | if config.hv.fail_rank_drop: 455 | raise ValueError( 456 | f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}" 457 | ) 458 | 459 | current_time = time.time() 460 | 461 | metric_logger.log(metrics) 462 | 463 | if config.hv is None: 464 | log( 465 | f"step: {real_step}, loss: {loss_batch.item()}, lr {[group['lr'] for group in optimizer.param_groups][0]}" 466 | ) 467 | 468 | # Save checkpoint every 'checkpoint_interval' steps 469 | if config.ckpt.interval is not None and real_step % config.ckpt.interval == 0: 470 | log(f"saving at step {real_step}, step {step+1}") 471 | ckpt_path = os.path.join(config.ckpt.path, f"{CKPT_PREFIX}_{int(real_step)}") 472 | 473 | if config.hv: 474 | ckpt_path = os.path.join(ckpt_path, get_diloco_rank_dir_name(config.hv.world_rank)) 475 | 476 | if world_messenger_hv: 477 | assert isinstance(optimizer, DiLoCoOptimizer) 478 | with optimizer.tracker.pause_updates(): 479 | save_checkpoint( 480 | checkpoint_path=ckpt_path, 481 | model=model, 482 | optimizer=optimizer.inner_optimizer, 483 | scheduler=scheduler, 484 | outer_optimizer=optimizer.state_averager.optimizer, 485 | loss=loss_batch.item(), 486 | scaler=scaler, 487 | data_loader=train_dataloader, 488 | save_global_state=True, 489 | ) 490 | else: 491 | save_checkpoint( 492 | checkpoint_path=ckpt_path, 493 | model=model, 494 | optimizer=optimizer, 495 | scheduler=scheduler, 496 | loss=loss_batch.item(), 497 | scaler=scaler, 498 | data_loader=train_dataloader, 499 | save_global_state=rank == 0, 500 | ) 501 | 502 | if local_rank == 0: 503 | # only the rank 0 deletes the checkpoints 504 | if config.ckpt.topk is not None: 505 | ckpt_deleted = delete_old_checkpoints(config.ckpt.path, config.ckpt.topk) 506 | if ckpt_deleted: 507 | log(f"Deleted old checkpoints: {ckpt_deleted}") 508 | 509 | loss_batch = 0 510 | 511 | if config.max_steps is not None and real_step >= config.max_steps: 512 | break 513 | 514 | log("Training completed.") 515 | if rank == 0: 516 | metric_logger.finish() 517 | 518 | 519 | if __name__ == "__main__": 520 | # Allow eager fallback during production so that that the training runs dont die 521 | # However, in development, we want to know that we broke torch compile 522 | torch._dynamo.config.suppress_errors = "PRIME_INTELLECT_DEV" not in os.environ 523 | torch.set_float32_matmul_precision("high") 524 | ddp_setup() 525 | config = Config(**parse_argv()) 526 | train(config) 527 | destroy_process_group() 528 | -------------------------------------------------------------------------------- /open_diloco/utils.py: -------------------------------------------------------------------------------- 1 
| import hashlib 2 | from functools import partial 3 | import pickle 4 | from typing import Any, Generator, Protocol 5 | 6 | import torch 7 | from torch.utils.hooks import RemovableHandle 8 | from torch.distributed.fsdp import ShardingStrategy 9 | from torch.utils.data import IterableDataset 10 | import wandb 11 | 12 | 13 | _WRAPPED_NAME_TO_REMOVE = ["_forward_module.", "_fsdp_wrapped_module.", "_orig_mod."] 14 | 15 | 16 | def _remove_fsdp_prefix(name: str) -> str: 17 | for prefix in _WRAPPED_NAME_TO_REMOVE: 18 | if prefix in name: 19 | name = name.replace(prefix, "") 20 | return name 21 | 22 | 23 | @torch.compiler.disable() 24 | @torch.no_grad() 25 | def log_activations_hook( 26 | _mod: torch.nn.Module, 27 | _inp: torch.Tensor, 28 | outp: torch.Tensor | tuple[torch.Tensor, ...], 29 | mod_name: str, 30 | gradient_accumulation_steps: int, 31 | log_activations: dict[str, float], 32 | ) -> None: 33 | if isinstance(outp, tuple): 34 | outp = outp[0] 35 | norm = outp.norm(p=2) / gradient_accumulation_steps 36 | name = _remove_fsdp_prefix(mod_name) 37 | if f"activation/{name}" not in log_activations: 38 | log_activations[f"activation/{name}"] = norm 39 | else: 40 | log_activations[f"activation/{name}"] += norm 41 | 42 | 43 | def register_metrics_hooks( 44 | model: torch.nn.Module, 45 | target_layers: list[str], 46 | log_activations: dict[str, torch.Tensor], 47 | gradient_accumulation_steps: int, 48 | ) -> list[RemovableHandle]: 49 | """ 50 | this function take a torch module, a list of layer name and apply a hook function that 51 | monitor the output norm of the layers. 52 | """ 53 | handles = [] 54 | for name, mod in model.named_modules(): 55 | for layer in target_layers: 56 | if name.endswith(layer): 57 | handle = mod.register_forward_hook( 58 | partial( 59 | log_activations_hook, 60 | log_activations=log_activations, 61 | mod_name=name, 62 | gradient_accumulation_steps=gradient_accumulation_steps, 63 | ) 64 | ) 65 | handles.append(handle) 66 | 67 | return handles 68 | 69 | 70 | def _round_str(x: float): 71 | return f"{x:.4f}" 72 | 73 | 74 | def _round_flatten(a: torch.Tensor, max_size: int = 1000): 75 | bounds = int(max_size**0.5) 76 | return ",".join(_round_str(i) for i, _ in zip(a[:bounds, :bounds].flatten(), range(max_size))) 77 | 78 | 79 | def hash_tensor_content(a: torch.Tensor, max_size: int = 1000) -> str: 80 | return hashlib.md5(_round_flatten(a, max_size=max_size).encode("utf-8")).hexdigest() 81 | 82 | 83 | def get_compression_kwargs(hivemind_compression: str | None) -> dict: 84 | """Return the compression kwargs for hivemind optimizer based on the hivemind_compression argument.""" 85 | ret_kwargs = {} 86 | 87 | if hivemind_compression is None: 88 | from hivemind import NoCompression 89 | 90 | ret_kwargs["grad_compression"] = NoCompression() 91 | ret_kwargs["state_averaging_compression"] = NoCompression() 92 | 93 | elif hivemind_compression == "fp16": 94 | from hivemind import Float16Compression 95 | 96 | ret_kwargs["grad_compression"] = Float16Compression() 97 | ret_kwargs["state_averaging_compression"] = Float16Compression() 98 | elif hivemind_compression == "scaled-fp16": 99 | from hivemind import ScaledFloat16Compression 100 | 101 | ret_kwargs["grad_compression"] = ScaledFloat16Compression() 102 | ret_kwargs["state_averaging_compression"] = ScaledFloat16Compression() 103 | elif hivemind_compression == "uniform8bit": 104 | from hivemind import Uniform8BitQuantization 105 | 106 | ret_kwargs["grad_compression"] = Uniform8BitQuantization() 107 | 
ret_kwargs["state_averaging_compression"] = Uniform8BitQuantization() 108 | elif hivemind_compression == "quantile8bit": 109 | from hivemind import Quantile8BitQuantization 110 | 111 | ret_kwargs["grad_compression"] = Quantile8BitQuantization() 112 | ret_kwargs["state_averaging_compression"] = Quantile8BitQuantization() 113 | 114 | elif hivemind_compression == "blockwise8bit": 115 | from hivemind import BlockwiseQuantization 116 | 117 | ret_kwargs["grad_compression"] = BlockwiseQuantization() 118 | ret_kwargs["state_averaging_compression"] = BlockwiseQuantization() 119 | else: 120 | raise ValueError(f"Invalid hivemind_compression: {hivemind_compression}") 121 | return ret_kwargs 122 | 123 | 124 | def found_inf_grad(optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler) -> bool: 125 | """ 126 | this function checks whether the scaler has found inf grads for the optimizer. It does so by looking up the optimizer state 127 | registered inside the scaler. Code is mostly copied from / inspired by the torch GradScaler codebase. 128 | """ 129 | if not scaler._enabled: 130 | return False 131 | 132 | optimizer_state = scaler._per_optimizer_states[id(optimizer)] 133 | assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer." 134 | 135 | return sum(v.item() for v in optimizer_state["found_inf_per_device"].values()) > 0 136 | 137 | 138 | def get_sharding_strategy(sharding_strategy: str) -> ShardingStrategy: 139 | if sharding_strategy == "FULL_SHARD": 140 | return ShardingStrategy.FULL_SHARD 141 | elif sharding_strategy == "SHARD_GRAD_OP": 142 | return ShardingStrategy.SHARD_GRAD_OP 143 | elif sharding_strategy == "NO_SHARD": 144 | return ShardingStrategy.NO_SHARD 145 | elif sharding_strategy == "HYBRID_SHARD": 146 | return ShardingStrategy.HYBRID_SHARD 147 | elif sharding_strategy == "_HYBRID_SHARD_ZERO2": 148 | return ShardingStrategy._HYBRID_SHARD_ZERO2 149 | else: 150 | raise ValueError( 151 | f"Invalid sharding_strategy: {sharding_strategy}. Please choose 'FULL_SHARD', 'SHARD_GRAD_OP', 'NO_SHARD', 'HYBRID_SHARD', or '_HYBRID_SHARD_ZERO2'." 152 | ) 153 | 154 | 155 | class FakeTokenizedDataset(IterableDataset): 156 | """This is a dummy dataset that generates random token sequences of length seq_len drawn from a vocabulary of size vocab_size""" 157 | 158 | def __init__(self, seq_len: int, vocab_size: int): 159 | self.seq_len = seq_len 160 | self.vocab_size = vocab_size 161 | assert vocab_size > 3, "Vocab size must be greater than 3" 162 | 163 | def __iter__(self) -> Generator[dict[str, Any], Any, None]: 164 | while True: 165 | input_ids = torch.randint(3, self.vocab_size, (self.seq_len,)).tolist() 166 | attention_mask = [1] * self.seq_len 167 | yield {"input_ids": input_ids, "attention_mask": attention_mask} 168 | 169 | 170 | class Logger(Protocol): 171 | def __init__(self, project, config): ... 172 | 173 | def log(self, metrics: dict[str, Any]): ... 174 | 175 | def finish(self): ... 
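# Any object with this __init__/log/finish interface can serve as the metric logger in the
# training script; WandbLogger and DummyLogger below are the two implementations (the tests
# select the dummy one via --metric_logger_type and read back the pickled metrics it writes).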
176 | 177 | 178 | class WandbLogger: 179 | def __init__(self, project, config, resume: bool): 180 | wandb.init( 181 | project=project, config=config, resume="auto" if resume else None 182 | ) # make wandb reuse the same run id if possible 183 | 184 | def log(self, metrics: dict[str, Any]): 185 | wandb.log(metrics) 186 | 187 | def finish(self): 188 | wandb.finish() 189 | 190 | 191 | class DummyLogger: 192 | def __init__(self, project, config, *args, **kwargs): 193 | self.project = project 194 | self.config = config 195 | open(project, "a").close() # Create an empty file at the project path 196 | 197 | self.data = [] 198 | 199 | def log(self, metrics: dict[str, Any]): 200 | self.data.append(metrics) 201 | 202 | def finish(self): 203 | with open(self.project, "wb") as f: 204 | pickle.dump(self.data, f) 205 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "OpenDiloco" 7 | version = "0.1.0" 8 | dynamic = ["dependencies"] 9 | 10 | [tool.setuptools] 11 | packages = ["open_diloco"] 12 | 13 | [tool.ruff] 14 | line-length = 120 # thanks to johannes' screen 15 | 16 | [tool.setuptools.dynamic] 17 | dependencies = {file = ["requirements.txt"]} 18 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | ruff>=0.4.7 2 | pre-commit>=3.7.0 3 | pytest -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers~=4.40 2 | datasets>=2.19.1 3 | wandb>=0.16.4 4 | cyclopts>=2.6.1 5 | fsspec[gcs]>=2024.3.1 6 | torch==2.3.1 7 | hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff9 8 | pydantic_config @ git+https://github.com/samsja/pydantic_config.git@8e19e05 9 | 10 | -------------------------------------------------------------------------------- /scripts/pull-c4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Use HTTP: bash scripts/pull-c4.sh 3 | # Use GIT: USE_GIT=1 bash scripts/pull-c4.sh 4 | set -e 5 | # The git server seems to be faster than the http server 6 | # However, it requires ssh keys to be set up 7 | if [ "$USE_GIT" = "1" ]; then 8 | echo "Using git" 9 | GIT_LFS_SKIP_SMUDGE=1 git clone git@hf.co:datasets/allenai/c4 10 | else 11 | echo "Using http" 12 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/allenai/c4 13 | fi 14 | 15 | echo "Pulling LFS files" 16 | cd c4 17 | git lfs pull -I 'en/*.json.gz' 18 | -------------------------------------------------------------------------------- /scripts/pull-model.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3 2 | # Example: ./scripts/pull-model.py PrimeIntellect/llama-1b-fresh 3 | import sys 4 | from transformers import AutoModelForCausalLM 5 | 6 | MODEL = sys.argv[1] if len(sys.argv) >= 2 else "PrimeIntellect/llama-1b-fresh" 7 | model = AutoModelForCausalLM.from_pretrained(MODEL) 8 | 9 | print(model) 10 | -------------------------------------------------------------------------------- /tests/models/llama-2m-fresh/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "configs/config_2m.json", 3 | "architectures": [ 4 | "LlamaForCausalLM" 5 | ], 6 | "attention_bias": false, 7 | "attention_dropout": 0.0, 8 | "bos_token_id": 1, 9 | "eos_token_id": 2, 10 | "hidden_act": "silu", 11 | "hidden_size": 64, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 256, 14 | "max_position_embeddings": 2048, 15 | "mlp_bias": false, 16 | "model_type": "llama", 17 | "num_attention_heads": 2, 18 | "num_hidden_layers": 2, 19 | "num_key_value_heads": 2, 20 | "pretraining_tp": 1, 21 | "rms_norm_eps": 1e-05, 22 | "rope_scaling": null, 23 | "rope_theta": 10000.0, 24 | "tie_word_embeddings": false, 25 | "torch_dtype": "float32", 26 | "transformers_version": "4.41.2", 27 | "use_cache": false, 28 | "vocab_size": 1024 29 | } 30 | -------------------------------------------------------------------------------- /tests/models/llama-2m-fresh/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "bos_token_id": 1, 4 | "eos_token_id": 2, 5 | "transformers_version": "4.41.2", 6 | "use_cache": false 7 | } 8 | -------------------------------------------------------------------------------- /tests/models/llama-2m-fresh/model.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrimeIntellect-ai/OpenDiloco/2d750e58a692ce1424d2a2366b2b3de1f42c9bf1/tests/models/llama-2m-fresh/model.safetensors -------------------------------------------------------------------------------- /tests/test_diloco_hivemind.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import gc 3 | import multiprocessing as mp 4 | import time 5 | from functools import partial 6 | from typing import List 7 | 8 | import pytest 9 | 10 | import hivemind 11 | from hivemind.dht import DHT 12 | 13 | from hivemind.utils.crypto import RSAPrivateKey 14 | from hivemind.utils.mpfuture import MPFuture 15 | 16 | import psutil 17 | 18 | from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoGradAverager, DiLoCoOptimizer 19 | 20 | 21 | @pytest.fixture(autouse=True, scope="module") 22 | def cleanup_children(): 23 | yield 24 | 25 | with RSAPrivateKey._process_wide_key_lock: 26 | RSAPrivateKey._process_wide_key = None 27 | 28 | gc.collect() # Call .__del__() for removed objects 29 | 30 | children = psutil.Process().children(recursive=True) 31 | if children: 32 | gone, alive = psutil.wait_procs(children, timeout=0.1) 33 | for child in alive: 34 | child.terminate() 35 | gone, alive = psutil.wait_procs(alive, timeout=1) 36 | for child in alive: 37 | child.kill() 38 | 39 | MPFuture.reset_backend() 40 | 41 | 42 | def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]: 43 | dhts = [DHT(start=True, **kwargs)] 44 | initial_peers = dhts[0].get_visible_maddrs() 45 | 46 | dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in 
range(n_peers - 1)) 47 | for process in dhts[1:]: 48 | process.wait_until_ready() 49 | 50 | return dhts 51 | 52 | 53 | @pytest.mark.forked 54 | def test_allreduce_diloco_grad_averager(): 55 | import torch 56 | 57 | n_peers = 4 58 | 59 | def get_model(): 60 | return torch.nn.Linear(5, 1, bias=False) 61 | 62 | models = [get_model() for _ in range(n_peers)] 63 | offloaded_models = [get_model() for _ in range(n_peers)] 64 | optimizers = [torch.optim.SGD(model.parameters(), lr=0.1) for model in offloaded_models] 65 | 66 | dht_instances = launch_dht_instances(n_peers) 67 | averagers = [ 68 | DiLoCoGradAverager( 69 | main_parameters=tuple(model.parameters()), 70 | offloaded_optimizer=opt, 71 | dht=dht, 72 | target_group_size=4, 73 | min_matchmaking_time=15, 74 | prefix="mygroup", 75 | client_mode=False, 76 | auxiliary=False, 77 | start=True, 78 | ) 79 | for model, opt, dht in zip(models, optimizers, dht_instances) 80 | ] 81 | 82 | futures = [] 83 | for averager in averagers: 84 | futures.append(averager.step(wait=False)) 85 | for future in futures: 86 | result = future.result() 87 | for averager in averagers: 88 | assert averager.peer_id in result 89 | 90 | for averager in averagers: 91 | with averager.get_tensors() as averaged_pseudo_grads: 92 | for grad in averaged_pseudo_grads: 93 | assert not torch.isnan(grad).any(), "Averaged grad is nan" 94 | 95 | for process in averagers + dht_instances: 96 | process.shutdown() 97 | 98 | 99 | def test_load_and_save_state(): 100 | import torch 101 | import torch.nn as nn 102 | import torch.nn.functional as F 103 | 104 | inner_lr = 0.1 105 | outer_lr = 0.7 106 | 107 | model = nn.Linear(5, 1) 108 | features = torch.randn(100, 5) 109 | targets = features @ torch.randn(5, 1) 110 | 111 | def get_opt(): 112 | return DiLoCoOptimizer( 113 | run_id="test_run", 114 | batch_size=32, 115 | num_inner_steps=5, 116 | params=model.parameters(), 117 | outer_optimizer=partial(torch.optim.SGD, lr=outer_lr, nesterov=True, momentum=0.9), 118 | inner_optimizer=partial(torch.optim.AdamW, lr=inner_lr), 119 | scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1), 120 | dht=hivemind.DHT(start=True), 121 | tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0), 122 | averager_opts=dict(request_timeout=0.5), 123 | matchmaking_time=1.0, 124 | averaging_timeout=5.0, 125 | verbose=False, 126 | ) 127 | 128 | opt1 = get_opt() 129 | 130 | for _ in range(2): 131 | batch = torch.randint(0, len(features), (32,)) 132 | 133 | loss = F.mse_loss(model(features[batch]), targets[batch]) 134 | 135 | loss.backward() 136 | assert loss.item() != 0, "Loss is zero, maybe gradient exploded." 137 | 138 | opt1.step() 139 | 140 | state = opt1.state_dict() 141 | 142 | opt2 = get_opt() 143 | opt2.load_state_dict(state) 144 | 145 | assert opt1.state_dict() == opt2.state_dict() 146 | 147 | assert opt1.state_averager.optimizer.param_groups[0]["lr"] == opt2.state_averager.optimizer.param_groups[0]["lr"] 148 | assert opt1.state_averager.optimizer.state_dict() == opt2.state_averager.optimizer.state_dict() 149 | 150 | assert opt1.state_averager.inner_optimizer.state_dict() == state["state_dict_inner"] 151 | assert opt1.state_averager.inner_optimizer.state_dict() == opt2.state_averager.inner_optimizer.state_dict() 152 | 153 | 154 | # for some reason this test does not pass, though the code is correct (tested manually). 155 | # I (sami) still want to keep track of this test for the future. 
156 | @pytest.mark.skip("skip test") 157 | @pytest.mark.parametrize( 158 | "strategy, expected_peer", [(AllReduceStrategy.NO_WAIT, 1), (AllReduceStrategy.WAIT_FOR_ALL, 4)] 159 | ) 160 | def test_strategy_all_reduce(strategy: AllReduceStrategy, expected_peer: int): 161 | dht = hivemind.DHT(start=True) 162 | 163 | import torch # putting import here for multi processing 164 | import torch.nn as nn 165 | import torch.nn.functional as F 166 | 167 | sleep_time = [0.5, 0.5, 0.5, 0.1] 168 | num_peers = len(sleep_time) 169 | 170 | on_time_peer = mp.Value(ctypes.c_int32, 0) 171 | 172 | batch_size = 16 173 | total_epochs = 10 174 | num_inner_steps = 5 175 | 176 | def run_trainer(sleep_time: float): 177 | features = torch.randn(100, 5) / 100 178 | targets = features @ torch.randn(5, 1) 179 | 180 | outer_lr = 0.7 181 | inner_lr = 0.1 182 | 183 | model = nn.Linear(5, 1) 184 | 185 | assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module" 186 | 187 | optimizer = DiLoCoOptimizer( 188 | run_id="test_run", 189 | batch_size=batch_size, 190 | num_inner_steps=num_inner_steps, 191 | params=model.parameters(), 192 | all_reduce_strategy=strategy, 193 | timeout_waiting_for_peers=None if strategy == AllReduceStrategy.NO_WAIT else 10.0, 194 | outer_optimizer=partial(torch.optim.SGD, lr=outer_lr, nesterov=True, momentum=0.9), 195 | inner_optimizer=partial(torch.optim.AdamW, lr=inner_lr), 196 | scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1), 197 | dht=hivemind.DHT( 198 | initial_peers=dht.get_visible_maddrs(), 199 | start=True, 200 | ), 201 | tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0), 202 | averager_opts=dict(request_timeout=0.5), 203 | matchmaking_time=2.0, 204 | averaging_timeout=5.0, 205 | verbose=False, 206 | ) 207 | time.sleep(sleep_time) 208 | 209 | optimizer.load_state_from_peers() 210 | 211 | for _ in range(total_epochs): 212 | time.sleep(sleep_time) 213 | batch = torch.randint(0, len(features), (batch_size,)) 214 | 215 | loss = F.mse_loss(model(features[batch]), targets[batch]) 216 | 217 | loss.backward() 218 | 219 | optimizer.step() 220 | 221 | optimizer.zero_grad() 222 | 223 | if optimizer.local_epoch == optimizer.tracker.global_epoch: 224 | on_time_peer.value += 1 225 | 226 | time.sleep(1.0) 227 | optimizer.shutdown() 228 | 229 | peers = [] 230 | for index in range(num_peers): 231 | peers.append( 232 | mp.Process( 233 | target=run_trainer, 234 | name=f"trainer-{index}", 235 | kwargs=dict(sleep_time=sleep_time[index]), 236 | ) 237 | ) 238 | 239 | for peer in peers: 240 | peer.start() 241 | 242 | for peer in peers: 243 | peer.join() 244 | 245 | assert on_time_peer.value == expected_peer 246 | 247 | for process in peers: 248 | process.terminate() 249 | -------------------------------------------------------------------------------- /tests/test_training/test_train.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import subprocess 3 | import numpy as np 4 | import pytest 5 | import socket 6 | from hivemind.dht.dht import DHT 7 | from open_diloco.ckpt_utils import CKPT_PREFIX 8 | 9 | 10 | def get_random_available_port(): 11 | # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number 12 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 13 | s.bind(("", 0)) 14 | return s.getsockname()[1] 15 | 16 | 17 | @pytest.fixture(scope="session") 18 | def random_available_port(): 19 | return get_random_available_port() 20 | 21 | 22 | 
@pytest.fixture 23 | def config() -> list[str]: 24 | return [ 25 | "--path_model", 26 | "tests/models/llama-2m-fresh", 27 | "--fake_data", 28 | "--no-torch_compile", 29 | "--lr", 30 | "1e-2", 31 | "--per_device_train_batch_size", 32 | "8", 33 | "--total_batch_size", 34 | "16", 35 | "--max_steps", 36 | "50", 37 | "--metric_logger_type", 38 | "dummy", 39 | ] 40 | 41 | 42 | @pytest.mark.parametrize("num_gpu", [2]) 43 | def test_multi_gpu_ckpt(config, random_available_port, num_gpu, tmp_path): 44 | ckpt_path = f"{tmp_path}/ckpt" 45 | log_file_1 = f"{tmp_path}/log1.json" 46 | log_file_2 = f"{tmp_path}/log2.json" 47 | 48 | run_1 = ["--ckpt.path", ckpt_path, "--ckpt.interval", "10", "--project", log_file_1] 49 | 50 | cmd = [ 51 | "torchrun", 52 | f"--nproc_per_node={num_gpu}", 53 | "--rdzv-endpoint", 54 | f"localhost:{random_available_port}", 55 | "open_diloco/train_fsdp.py", 56 | *config, 57 | ] 58 | 59 | result = subprocess.run(cmd + run_1) 60 | 61 | if result.returncode != 0: 62 | pytest.fail(f"Process {result} failed {result.stderr}") 63 | 64 | run_2 = ["--ckpt.path", ckpt_path, "--ckpt.resume", f"{ckpt_path}/{CKPT_PREFIX}_20", "--project", log_file_2] 65 | 66 | results_resume = subprocess.run(cmd + run_2) 67 | 68 | if results_resume.returncode != 0: 69 | pytest.fail(f"Process {result} failed {result.stderr}") 70 | 71 | with open(log_file_1, "rb") as f: 72 | log1 = pickle.load(f) 73 | with open(log_file_2, "rb") as f: 74 | log2 = pickle.load(f) 75 | 76 | log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1} 77 | log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2} 78 | 79 | common_step = set(log1.keys()) & set(log2.keys()) 80 | 81 | for step in common_step: 82 | assert np.allclose(log1[step][0], log2[step][0], atol=1e-3), f"Loss at step {step} is different" 83 | assert log1[step][1] == log2[step][1], f"Lr at step {step} is different" 84 | 85 | 86 | @pytest.fixture 87 | def config_hv() -> list[str]: 88 | config = [ 89 | "--path_model", 90 | "tests/models/llama-2m-fresh", 91 | "--fake_data", 92 | "--no-torch_compile", 93 | "--lr", 94 | "1e-2", 95 | "--per_device_train_batch_size", 96 | "8", 97 | "--total_batch_size", 98 | "16", 99 | "--max_steps", 100 | "100", 101 | "--metric_logger_type", 102 | "dummy", 103 | ] 104 | 105 | return config + [ 106 | "--hv.local_steps", 107 | "25", 108 | "--hv.skip_load_from_peers", 109 | "--hv.fail_rank_drop", 110 | "--hv.matchmaking_time", 111 | "5", 112 | ] 113 | 114 | 115 | @pytest.mark.parametrize("num_diloco", [2]) 116 | def test_multi_gpu_hivemind(config_hv, num_diloco, tmp_path): 117 | dht = DHT( 118 | start=True, 119 | host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"], 120 | ) 121 | 122 | initial_peers = str(dht.get_visible_maddrs()[0]) 123 | 124 | results = [] 125 | 126 | ckpt_path = f"{tmp_path}/ckpt" 127 | 128 | def get_base_cmd(i, initial_peers): 129 | return [ 130 | "torchrun", 131 | f"--nproc_per_node={1}", 132 | "--rdzv-endpoint", 133 | f"localhost:{port}", 134 | "open_diloco/train_fsdp.py", 135 | *config_hv, 136 | "--hv.initial_peers", 137 | initial_peers, 138 | "--hv.world_rank", 139 | str(i), 140 | "--hv.galaxy_size", 141 | str(num_diloco), 142 | ] 143 | 144 | for i in range(num_diloco): 145 | port = get_random_available_port() 146 | 147 | cmd = get_base_cmd(i, initial_peers) + [ 148 | "--ckpt.path", 149 | ckpt_path, 150 | "--ckpt.interval", 151 | "25", 152 | "--project", 153 | f"{tmp_path}/log{i}_part1.json", 154 | ] 155 | 156 | result = subprocess.Popen(cmd) 157 | results.append(result) 158 | 159 | for 
result in results: 160 | result.wait() 161 | if result.returncode != 0: 162 | pytest.fail(f"Process {result} failed {result.stderr}") 163 | 164 | # resume from ckpt 165 | 166 | dht.shutdown() 167 | 168 | del dht 169 | dht = DHT( 170 | start=True, 171 | host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"], 172 | ) 173 | initial_peers = str(dht.get_visible_maddrs()[0]) 174 | 175 | for i in range(num_diloco): 176 | port = get_random_available_port() 177 | 178 | cmd = get_base_cmd(i, initial_peers) + [ 179 | "--ckpt.resume", 180 | f"{ckpt_path}/{CKPT_PREFIX}_50", 181 | "--project", 182 | f"{tmp_path}/log{i}_part2.json", 183 | ] 184 | 185 | result = subprocess.Popen(cmd) 186 | results.append(result) 187 | 188 | for result in results: 189 | result.wait() 190 | if result.returncode != 0: 191 | pytest.fail(f"Process {result} failed {result.stderr}") 192 | 193 | for i in range(num_diloco): 194 | with open(f"{tmp_path}/log{i}_part1.json", "rb") as f: 195 | log1 = pickle.load(f) 196 | with open(f"{tmp_path}/log{i}_part2.json", "rb") as f: 197 | log2 = pickle.load(f) 198 | 199 | log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1} 200 | log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2} 201 | 202 | common_step = set(log1.keys()) & set(log2.keys()) 203 | 204 | for step in common_step: 205 | assert np.allclose(log1[step][0], log2[step][0], atol=1e-2), f"Loss at step {step} is different" 206 | assert log1[step][1] == log2[step][1], f"Lr at step {step} is different" 207 | --------------------------------------------------------------------------------
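Usage sketch (illustrative, not a file in this repository): the short Python script below shows one way the helpers defined in open_diloco/utils.py above (FakeTokenizedDataset, register_metrics_hooks, get_sharding_strategy, hash_tensor_content) could be exercised in isolation. It assumes the package is installed (for example via pip install -e .) and that torch is importable; the TinyBlock module, the layer name "proj", and all tensor sizes are illustrative choices, not values taken from the training code.

# usage_sketch.py -- illustrative only; module and layer names below are assumptions,
# not part of the repository.
import torch

from open_diloco.utils import (
    FakeTokenizedDataset,
    get_sharding_strategy,
    hash_tensor_content,
    register_metrics_hooks,
)


class TinyBlock(torch.nn.Module):
    """A stand-in module with a sub-layer named 'proj' so the hook name matching has something to hit."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def main() -> None:
    model = TinyBlock()

    # Collect the L2 norm of the 'proj' output into this dict, divided by the number of
    # gradient accumulation steps (similar in spirit to how the training script logs activations).
    log_activations: dict[str, torch.Tensor] = {}
    handles = register_metrics_hooks(
        model,
        target_layers=["proj"],
        log_activations=log_activations,
        gradient_accumulation_steps=2,
    )

    model(torch.randn(4, 16))
    model(torch.randn(4, 16))
    print(log_activations)  # {'activation/proj': tensor(...)}, accumulated over two forward passes

    for handle in handles:
        handle.remove()

    # Draw one fake pre-tokenized sample, as the tests do when running with --fake_data.
    sample = next(iter(FakeTokenizedDataset(seq_len=8, vocab_size=1024)))
    print(len(sample["input_ids"]), sum(sample["attention_mask"]))

    # Map the CLI string to the FSDP enum, and fingerprint a weight tensor
    # (hash_tensor_content expects a 2-D tensor, since _round_flatten slices two dimensions).
    print(get_sharding_strategy("FULL_SHARD"))
    print(hash_tensor_content(model.proj.weight.detach()))


if __name__ == "__main__":
    main()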