├── .github └── workflows │ ├── lint.yml │ ├── mypy.yml │ └── pytest.yml ├── .gitignore ├── .isort.cfg ├── LICENSE ├── README.md ├── code-of-conduct.md ├── docs ├── configurations.md ├── dataloader.md ├── evaluation.md ├── fine_tuning.md └── train_details.md ├── fms_fsdp ├── __init__.py ├── config │ ├── __init__.py │ └── training.py ├── policies │ ├── __init__.py │ ├── ac_handler.py │ ├── mixed_precision.py │ ├── param_init.py │ └── wrapping.py ├── readme.md ├── requirements.txt └── utils │ ├── __init__.py │ ├── checkpointing_utils.py │ ├── config_utils.py │ ├── dataloader_utils.py │ ├── dataset_utils.py │ └── train_utils.py ├── fms_to_hf_llama.py ├── fms_to_hf_mamba.py ├── images ├── loss_curve.png └── lr.png ├── main_training_llama.py ├── main_training_mamba.py ├── requirements-speculator.txt ├── requirements.txt ├── scripts ├── README_SPECULATOR.md ├── train.sh ├── train.slurm └── train_speculator.sh ├── setup.py ├── speculator ├── __init__.py ├── train_speculator.py └── train_speculator_utils.py ├── test-requirements.txt └── tests ├── conftest.py ├── test_datasets.py └── test_selective_ac.py /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [pull_request] 4 | 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: psf/black@stable 14 | with: 15 | options: "--check --diff --color" 16 | src: "." 17 | version: "~= 23.3.0" 18 | - uses: isort/isort-action@master 19 | with: 20 | sort-paths: . 21 | requirementsFiles: "requirements.txt" # We don't need extra test requirements for linting 22 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies and run MyPy 2 | 3 | name: MyPy Type Checking 4 | 5 | on: [pull_request] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 3.10 17 | id: setup_python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: "3.10" 21 | - name: Restore Virtualenv 22 | uses: actions/cache/restore@v4 23 | id: cache-venv-restore 24 | with: 25 | path: ./.venv/ 26 | key: ${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-venv-${{ hashFiles('*requirements.txt') }} 27 | restore-keys: | 28 | ${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-venv- 29 | - name: Install dependencies 30 | run: | 31 | # Create the virtual environment 32 | python -m venv .venv 33 | . 
./.venv/bin/activate 34 | 35 | # Install the dependencies 36 | # In case of a cache hit on the primary key, this will be a no-op 37 | # In case of a cache miss, but hit on a secondary key, this will update what's changed 38 | python -m pip install --upgrade pip 39 | pip install -r test-requirements.txt 40 | 41 | # Enables the virtual env for following steps 42 | echo "$VIRTUAL_ENV/bin" >> $GITHUB_PATH 43 | echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV 44 | 45 | - name: Test with mypy 46 | run: | 47 | # Install ibm-fms from the main branch for testing purposes 48 | # Use -I to ignore the existing install and actually install 49 | # the version on main 50 | pip install -I ibm-fms@git+https://github.com/foundation-model-stack/foundation-model-stack@main 51 | 52 | # No type stubs available for "fire" and "transformers" 53 | mypy --exclude fms_to_hf.py --exclude main_training.py --exclude setup.py . 54 | 55 | - name: Save Virtualenv 56 | id: cache-venv-save 57 | uses: actions/cache/save@v4 58 | with: 59 | path: ./.venv/ 60 | key: ${{ steps.cache-venv-restore.outputs.cache-primary-key }} 61 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install dependencies and run pytest 2 | 3 | name: Pytest 4 | 5 | on: [pull_request] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 3.10 17 | id: setup_python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: "3.10" 21 | - name: Restore Virtualenv 22 | uses: actions/cache/restore@v4 23 | id: cache-venv-restore 24 | with: 25 | path: ./.venv/ 26 | key: ${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-venv-${{ hashFiles('*requirements.txt') }} 27 | restore-keys: | 28 | ${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-venv- 29 | - name: Install dependencies 30 | run: | 31 | # Create the virtual environment 32 | python -m venv .venv 33 | . ./.venv/bin/activate 34 | 35 | # Install the dependencies 36 | # In case of a cache hit on the primary key, this will be a no-op 37 | # In case of a cache miss, but hit on a secondary key, this will update what's changed 38 | python -m pip install --upgrade pip 39 | pip install -r test-requirements.txt 40 | 41 | # Enables the virtual env for following steps 42 | echo "$VIRTUAL_ENV/bin" >> $GITHUB_PATH 43 | echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV 44 | 45 | - name: Test with pytest 46 | run: | 47 | # Install ibm-fms from the main branch for testing purposes 48 | # Use -I to ignore the existing install and actually install 49 | # the version on main 50 | pip install -I ibm-fms@git+https://github.com/foundation-model-stack/foundation-model-stack@main 51 | 52 | # Install fms-fsdp project 53 | pip install -e . 
54 | 55 | # No type stubs available for "fire" and "transformers" 56 | pytest tests/ 57 | 58 | - name: Save Virtualenv 59 | id: cache-venv-save 60 | uses: actions/cache/save@v4 61 | with: 62 | path: ./.venv/ 63 | key: ${{ steps.cache-venv-restore.outputs.cache-primary-key }} 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | 3 | .DS_Store 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | ensure_newline_before_comments = True 3 | force_grid_wrap = 0 4 | include_trailing_comma = True 5 | lines_after_imports = 2 6 | multi_line_output = 3 7 | use_parentheses = True 8 | profile = black 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FMS FSDP - (Pre)Training FMS with FSDP 2 | 3 | The “fms-fsdp” repo is a companion to the [Foundation Model Stack](https://github.com/foundation-model-stack/foundation-model-stack). 
4 | The goal of this repo is to provide a (pre)training example for efficiently training 5 | FMS models, in particular Llama2, by leveraging native PyTorch features - FSDP for training and the SDPA implementation of Flash Attention v2. While there are many exemplar repositories that can perform pretraining at scale (e.g., [MegatronLM](), [DeepSpeed]()), this work captures what IBM has been doing with the PyTorch community on using FSDP for training and how to do so efficiently. It is not meant to be an end-to-end framework for training models, which would also include data preparation (pre) and alignment/tuning of the base model (post). 6 | 7 | For an end-to-end framework, we recommend [OLMo](https://github.com/allenai/OLMo) from AllenAI, which provides datasets, data preprocessing frameworks, FSDP training on AMD GPUs, and a tuning/alignment framework. 8 | 9 | ## Training throughput benchmarks 10 | **_numbers are updated with `torch.compile`, as our fms models are fully compatible with torch compile_** 11 | 12 | We benchmark the best possible throughput and the strategies we employ in the table below, sharing the throughput obtained on 128 A100 GPUs as well as 96 H100 GPUs; we use the exact same scripts and configurations for both sets of GPUs. 13 | 14 | | Model Size | Sharding Strategy | Compile | Activation Checkpointing | Batch Size | Training Throughput
tokens/sec/GPU
A100 80G 128 GPUs with 400 Gbps | Training Throughput
tokens/sec/GPU
H100 96 GPUs with 800 Gbps | 15 | |------------|-------------------|---------|--------------------------|------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------| 16 | | 7b | HSDP | Y | No AC | 2 | 4550 | 9600 | 17 | | 13b | FSDP | Y | Selective AC | 2 | 2150 | 4850 | 18 | | 34b | FSDP | Y | Selective AC | 2 | 820 | 1830 | 19 | | 70b | FSDP | Y | Selective AC | 2 | 410 | 890 | 20 | 21 | HFU numbers are computed using the [PyTorch FLOP counter](https://github.com/pytorch/pytorch/blob/2240018c03744ee34ea14ad53481db934c37e384/torch/utils/flop_counter.py#L336) and the theoretical bf16 performance of 22 | A100 and H100 GPUs, whereas MFU numbers are computed using the methodology outlined in 23 | [NanoGPT](https://github.com/karpathy/nanoGPT) and the [PaLM](https://arxiv.org/pdf/2204.02311.pdf) paper. 24 | 25 | | Model Size | Compile | Batch size | MFU (A100 80G) | HFU (A100 80G) | MFU (H100 80G) | HFU (H100 80G) | 26 | |------------|---------|------------|----------------|----------------|----------------|----------------| 27 | | 7B | Y | 2 | 0.68 | 0.68 | 0.46 | 0.46 | 28 | | 13B | Y | 2 | 0.61 | 0.69 | 0.43 | 0.46 | 29 | | 34B | Y | 2 | 0.55 | 0.74 | 0.38 | 0.49 | 30 | | 70B | Y | 2 | 0.55 | 0.74 | 0.38 | 0.47 | 31 | 32 | A few points to note here, on the A100s, we note that for 13B we are not utilizing the hardware as well (only 0.48 MFU) because of smaller batch size. We can dial up the MFU by turning on activation checkpointing, however the throughput falls to 1600 tokens/sec/GPU. Whereas, note that the gaps here are more glaring with H100s where the MFU for 7 and 13B falls below 0.40. 33 | 34 | Another point to note here is that for the larger models, we could increase the throughput by a few percentage points when we increase the batch size. However, we have left the batches to be smaller to allow for scaling to 1024 GPUs without introducing tensor parallelism. 35 | 36 | ## Installation 37 | You need to install the required packages by running the following command. 38 | We recommend running the latest [PyTorch nightlies](https://pytorch.org/) and latest [ibm-fms](https://github.com/foundation-model-stack/foundation-model-stack). 39 | ```bash 40 | pip install -r requirements.txt 41 | ``` 42 | 43 | ## Training 44 | 45 | ### Model 46 | We trained one model, a replica of Llama2 7B as an exemplar on IBM curated data. This model was trained to 2.2T tokens with a 4k context length on 128 A100 GPUs for a total of 162k GPU hours, achieving an efficiency of 3700 tokens/sec/GPU (~40B tokens/day), which is roughly 20% faster than the Llama2 published training time. These speedups were possible by combining multiple techniques - SDPA Flash v2 implementation, FSDP with overlap in computation and communication, and selective activation checkpointing. 47 | The generated model has a good performance on various metrics as evaluated by [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness), with MMLU score of 0.5. We share further [scores](docs/evaluation.md) in the details of the model for completeness. 48 | 49 | ### Dataset 50 | We use an internally curated dataset for training the model. We use sampling ratios similar to what Llama1 paper proposed with minor changes (e.g., no C4 dataset). Since the goal of this repo is to demonstrate the feasibility of training using PyTorch components at scale, we omit the details of the sampling ratios. 
The overall dataset is roughly 1.5T tokens and the model has seen all the tokens in the dataset at least once. 51 | 52 | For this dataset, we designed a large-scale workload dataloader, details can be found [here](docs/dataloader.md). 53 | 54 | ### Train Config 55 | 56 | Below assumes running with Slurm, but same can be easily adopted 57 | if running with other clusters. 58 | 59 | 1. modify Training Config in [scripts/train.sh](scripts/train.sh) (for the full 60 | list of training configs and best practices, refer to [Configuration Doc](docs/configurations.md)). 61 | 2. modify Run Config in [scripts/train.slurm](scripts/train.slurm) 62 | 63 | ### Run 64 | ```bash 65 | sbatch ./scripts/train.slurm 66 | ``` 67 | For other cluster setup, we can simply use the *torchrun* commands inside `train.sh`. 68 | 69 | ### Training Details and Lessons learnt 70 | Details on training stability, loss curve, LR curve, etc., as well as what 71 | we have learnt from this journey can be found in [Training Details](docs/train_details.md). 72 | 73 | ## Post Training 74 | 75 | ### Convert to Hugging Face format 76 | 77 | The model trained with this repo is in FMS format, and you might want to convert it 78 | to Huggingface format so that you can load it natively with Huggingface and leverage Huggingface ecosystem: 79 | ```bash 80 | python fms_to_hf.py --model_variant 7b --nocompiled --load_path /path/to/trained/checkpoints --save_path /output/path --tokenizer_name_or_path /path/to/llama/tokenizer 81 | ``` 82 | > [!Note] 83 | > This repo consumes pre-tokenized data thus does not require a tokenizer. However, 84 | > Huggingface checkpoint requires a paired tokenizer thus you need to pass a tokenizer 85 | > here so it can be copied over to the save dir. Just download the HF Llama tokenizer 86 | > and pass the path here. 87 | 88 | ## Fine tuning 89 | 90 | We have performed preliminary fine-tuning on our base model and details can be found [here](docs/fine_tuning.md). 91 | -------------------------------------------------------------------------------- /code-of-conduct.md: -------------------------------------------------------------------------------- 1 | # Foundation Model Stack Community Code of Conduct 2 | 3 | Please refer to [Foundation Model Stack Community Code of Conduct](https://github.com/foundation-model-stack/foundation-model-stack/blob/main/code-of-conduct.md). 4 | -------------------------------------------------------------------------------- /docs/configurations.md: -------------------------------------------------------------------------------- 1 | # Configurations 2 | 3 | All configurations in [scripts/train.sh](scripts/train.sh) will be passed into 4 | [training configs](../pretraining/config/training.py). 5 | 6 | ## Full list of configurations 7 | 8 | ### Model 9 | - **model_variant**: the llama variant, values in "7b", "13b", "34b" and "70b". 10 | - **ckpt_load_path**: the path from where checkpoint will be loaded for continued training. 11 | - **ckpt_save_path**: the path to which checkpoint will be saved. 12 | 13 | ### Dataset and Dataloader 14 | - **use_dummy_dataset**: set this to `True`` to use dummy dataset for quick testing and performance benchmarking. 15 | - **data_path**: Data path. 16 | - **seq_length**: Sequence/context length to build when preparing model input. 17 | - **sep_token**: Separator token in the tokenized dataset. 18 | - **datasets**: Subfolders under `datapath` that contains different datasets. 19 | - **weights**: Proportion of each dataset when training. 
20 | - **logical_shards**: Number of logical shards when building dataloader. This is an advanced setting and we will go into detail in a future update. 21 | 22 | ### FSDP policies 23 | - **sharding_strategy**: "FSDP" (Fully Sharded) or "HSDP" (Hybrid Sharded), HSDP allows for sharding within a node and DP across nodes. 24 | - **mixed_precision**: Whether to use `bf16` mixed precision for training 25 | - **fsdp_activation_checkpointing**: whether to turn on activation checkpointing 26 | - **selective_checkpointing**: How many blocks to checkpoint the activation. 1 is the default setting, experiment with this number to trade off between memory and compute requirements. 27 | - **low_cpu_fsdp**: Whether to load the model in low cpu mode. This is useful when loading large models like 70b. 28 | 29 | ### Training spec 30 | - **seed**: random seed for reproduction. 31 | - **batch_size**: batch size per gpu. 32 | - **num_steps**: total number of steps to train. 33 | - **learning_rate**: learning rate. This is the max learning rate for model to warm up to. We default to the `cosine` schedule which is popular for pretraining workloads. 34 | 35 | ### Profiling and reporting 36 | - **use_profiler**: whether to turn on PyTorch profiler to generate profiling traces 37 | - **report_interval**: how many steps to report training metrics 38 | - **checkpoint_interval**: how many steps to save a checkpoint 39 | 40 | ### Compile 41 | - **use_torch_compile**: whether to turn on compile. It is recommended to NOT use compile at this stage due to some known issues with compile-training. 42 | 43 | 44 | ## Deep Dive into FSDP Configs 45 | You can skip this section if you are already familiar with FSDP. 46 | 47 | ### Basics 48 | In case you are new to FSDP, here are some basic references: 49 | [FSDP Intro](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) 50 | [FSDP API](https://pytorch.org/docs/stable/fsdp.html) 51 | 52 | ### Tune the configs 53 | The key to achieve the best performance for FSDP is to fully overlap the communication 54 | time with computation time so that the GPUs are busy all the time as if there is no 55 | communication cost/gap. 56 | 57 | **sharding_strategy** is the first thing you want to decide. FSDP will shard your model 58 | across all devices (GPUs) while HSDP will only shard your model across devices within 59 | the same node. E.g. if you are training with 128 nodes (i.e. 128 * 8 = 1024 total gpus), 60 | FSDP will shard your model across all 1024 gpus while HSDP will shard your model across 61 | 8 gpus in each node. Therefore, FSDP will save more memory while HSDP will introduce 62 | lesser communications. For smaller models like 7b and 13b, HSDP is preferred as it 63 | will make communication shorter thus easier to be overlapped with computation; while 64 | for larger models like 34b and 70b, FSDP will be a necessity as the model is too large 65 | to be fitted into only 8 gpus. 66 | 67 | **fsdp_activation_checkpointing** controls if activation checkpointing is enabled. 68 | Enabling it will greatly save the memory but also increase the computation time due 69 | to activation re-computation. For large models you would typically have to set this 70 | to `True` as activations consume large amount of memory and without checkpointing it 71 | you will face OOM. For smaller models, depending on your batch size, you may or may 72 | not enable it. As a companion, **selective_checkpointing** controls how "often" 73 | to checkpoint the activation (i.e. 
checkpoint activation only every k steps), the 74 | smaller the value, the more often it checkpoints and thus the more memory will 75 | be saved. default value is 1 meaning checkpoint every block. 76 | -------------------------------------------------------------------------------- /docs/dataloader.md: -------------------------------------------------------------------------------- 1 | # Data Loader 2 | 3 | We design a data loader as part of our pretraining that can provide shuffling in realtime while ensuring that there is no drop in GPU utilization. We design it to be scalable to multiple nodes (tested to 128 nodes), streaming, rescaling the number of GPUs during a single training run, and allows for restart from a given state. 4 | 5 | ## Details 6 | 7 | The current distributed dataloader is designed to meet two important needs of data scientists running large-scale training workloads: seamless resumption of an interrupted job, and rapid iteration on dataset composition and handling. 8 | 9 | We address the first by maintaining a checkpointable state that allows the user to restart model training from checkpoint, mid-epoch, while keeping a guarantee that each document will be viewed exactly once in any given epoch, up to oversampling (i.e. no revisiting stale data). The user is also free to scale the job up or down to different numbers of gpus from phase to phase, while still maintaining this guarantee. 10 | 11 | To address the second concern, we enforce a rigid format on our input data (tokenized documents, in arrow shard files, organized into dataset subdirectories, with a single unified metadata file of document counts per shard) but construct the specific dataset combinations and mixes dynamically at runtime, based on user inputs. This is accomplished separately on each worker process, with no communication needed between devices. Each worker then streams through the ordered documents and shard files according to its constructed plan, pulling files from disk or cloud on demand as training proceeds. This allows the user to add or eliminate subdatasets, adjust subdataset sampling rates, change BOS/EOS/SEP tokens, toggle padding or packing on or off, adjust sequence lengths, or swap out the training task, for example, without having to build any new training datasets on disk from run to run (a potentially long and expensive process for Terabyte-scale data). 12 | 13 | Because each worker is streaming documents and files sequentially, shuffling is required, and this is accomplished via an internal buffer which ensures that in expectation, two consecutive lines in the stream will appear 10,000 steps apart (this can be adjusted higher or lower as desired). Finally, the dataloader is implemented as modular extensions of PyTorch Datasets, allowing the user to add or remove custom data pipeline functionality as needed. 14 | 15 | Further technical details can be found in the `fms-fsdp/utils/dataset_utils.py` file. 16 | -------------------------------------------------------------------------------- /docs/evaluation.md: -------------------------------------------------------------------------------- 1 | # Preliminary Evaluation 2 | 3 | We convert the sharded FSDP checkpoint to a Hugging Face checkpoint and run [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) 4 | without any changes. 
The key scores at the 2T checkpoint are summarized below: 5 | 6 | | Evaluation metric | Llama2-7B (baseline) | LlamaT-7B | 7 | |----------------------------|----------------------|---------------| 8 | | MMLU (zero shot) | 0.41 | 0.43 | 9 | | MMLU (5-shot weighted avg) | 0.47 | 0.50 | 10 | | Arc challenge | 0.46 | 0.44 | 11 | | Arc easy | 0.74 | 0.71 | 12 | | Boolq | 0.78 | 0.76 | 13 | | Copa | 0.87 | 0.83 | 14 | | Hellaswag | 0.76 | 0.74 | 15 | | Openbookqa | 0.44 | 0.42 | 16 | | Piqa | 0.79 | 0.79 | 17 | | Sciq | 0.91 | 0.91 | 18 | | Winogrande | 0.69 | 0.67 | 19 | | Truthfulqa | 0.39 | 0.39 | 20 | | GSM8k (8-shot) | 0.13 | 0.11 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/fine_tuning.md: -------------------------------------------------------------------------------- 1 | # Fine Tuning 2 | 3 | To validate that the stack produces a model which is as 4 | easy to fine-tune as Llama models, we convert the FSDP checkpoint into a HF checkpoint and fine tune it 5 | using popular fine-tuning configurations and data mixes. 6 | 7 | Specifically, we follow Allen AI’s [open-instruct](https://github.com/allenai/open-instruct) framework, leveraging the TULU v2 stack as-is 8 | (DeepSpeed, TULU v2 mixture and recommended configuration for Llama 2 models). The tuned model 9 | scores are presented below and we note improvements in several tasks. We did not do a hyperparameter 10 | exploration for the best parameters to fine-tune LlamaT. We note that optimal hyperparameter for 11 | LlamaT tuning could be different from Llama 2 as they are likely to have followed different learning 12 | rate schedules. 13 | 14 | | Evaluation metric | Llama2-7B (baseline) | LlamaT-7B | 15 | |----------------------------|----------------------|---------------| 16 | | MMLU (5-shot weighted avg) | 0.53 | 0.49 | 17 | | Arc challenge | 0.48 | 0.43 | 18 | | Arc easy | 0.73 | 0.67 | 19 | | Boolq | 0.82 | 0.82 | 20 | | Copa | 0.89 | 0.86 | 21 | | Hellaswag | 0.76 | 0.75 | 22 | | Openbookqa | 0.47 | 0.42 | 23 | | Piqa | 0.79 | 0.79 | 24 | | Sciq | 0.93 | 0.91 | 25 | | Winogrande | 0.71 | 0.65 | 26 | | Truthfulqa | 0.45 | 0.46 | 27 | | GSM8k (8-shot) | 0 | 0 | 28 | -------------------------------------------------------------------------------- /docs/train_details.md: -------------------------------------------------------------------------------- 1 | # Training Details 2 | 3 | The model training used PyTorch FSDP with no activation recomputation, hybrid sharding with model 4 | weights and optimizer state sharded within a node and data parallel across nodes, per GPU batch size of 5 | 2 (effective batch size of 1M tokens/batch), AdamW optimizer with beta1 of 0.9 and beta2 of 0.95, weight 6 | decay of 0.1, and a learning rate ending at 3e-5 with a warmup to max learning rate of 3e-4 and a cosine 7 | schedule to reduce to 3e-5 over 2T tokens. The loss curve tracks that of Llama2 paper and reaches a lower 8 | loss than Llama2 7B does, which we believe is the characteristic of the dataset. 9 | 10 | ### Loss Curve 11 | ![](../images/loss_curve.png) 12 | 13 | ### Learning Rate 14 | ![](../images/lr.png) 15 | 16 | ## Lesson learned 17 | 18 | ### Stability 19 | 20 | Training was stable with no crashes. We had a few hiccups as outlined below. 21 | 22 | **0-200B tokens**: We observed a slowdown in the iteration time (time taken to execute one training step). 
We stopped the job (freeing up GPUs for other workloads) to ensure that the data loader was not causing any slowdowns and that checkpointing was performant and accurate. We did not find any issues. By this time, HSDP checkpointing code was available in PyTorch, and we took this opportunity to switch to the PyTorch checkpointing code. 23 | 24 | **200B-1.9T tokens**: We did not do any manual intervention in the job and forgot that it was running during the winter break. When we came back in early January, the disk space had been exhausted and checkpoints were failing to be written, although the training job continued. The last known checkpoint was at 1.5T tokens. 25 | 26 | **1.5T-1.7T tokens**: We evaluated the 1.5T checkpoint with lm-evaluation-harness and discovered that the model had been trained with an extra special token between documents, because the Hugging Face tokenizer introduced a separator token and our dataloader also appended its own document separator. We modified the dataloader to eliminate the extra special token and continued training with the modified dataloader from 1.7T tokens onwards. 27 | 28 | **1.7T-2T tokens**: The loss initially spiked due to the change in special tokens, but it recovered within a few billion tokens. The training finished without any other manual intervention! 29 | 30 | ### Speedups 31 | 32 | There are two approaches to speeding up performance even further. With our recent work on 33 | improving inference speeds, we fused several layers, which reduced inference latencies. We 34 | expect these techniques to benefit training as well. 35 | 36 | Further, with the release of similar training code by OLMo, the issue we had raised with PyTorch to 37 | get compile working for FSDP increased in priority. We are currently engaged with the PyTorch team on enabling 38 | compile, which can provide a further boost to training speed. 
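### Learning rate schedule sketch

As a concrete companion to the schedule described above (AdamW with beta1 0.9 / beta2 0.95, weight decay 0.1, warmup to a max learning rate of 3e-4, then cosine decay to 3e-5 over 2T tokens), below is a minimal, self-contained sketch of such a warmup + cosine schedule using a plain PyTorch `LambdaLR`. The function name, the warmup length, and the exact step counts are illustrative assumptions, not the implementation used in this repo's training utilities.

```python
import math

import torch


def warmup_cosine_schedule(optimizer, warmup_steps, total_steps, min_ratio=0.1):
    # Scales the optimizer's base LR: linear warmup to 1.0, then cosine decay
    # down to min_ratio (0.1 of 3e-4 gives the 3e-5 floor described above).
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps  # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # 1 -> 0
        return min_ratio + (1.0 - min_ratio) * cosine  # 1 -> min_ratio

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# Illustrative usage with the hyperparameters described above; the model and
# warmup length are placeholders. 2T tokens at ~1M tokens per effective batch
# corresponds to roughly 2M optimizer steps.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1
)
scheduler = warmup_cosine_schedule(optimizer, warmup_steps=2_000, total_steps=2_000_000)
```

Calling `scheduler.step()` after each training step advances the schedule; the resulting LR curve has the same shape as the plot shown above.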
39 | -------------------------------------------------------------------------------- /fms_fsdp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/fms_fsdp/__init__.py -------------------------------------------------------------------------------- /fms_fsdp/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .training import train_config 2 | -------------------------------------------------------------------------------- /fms_fsdp/config/training.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union 3 | 4 | 5 | @dataclass 6 | class train_config: 7 | # model 8 | model_variant: str = "7b" 9 | ckpt_load_path: str = "/fsx/output/ckpt" 10 | ckpt_save_path: str = "/fsx/output/ckpt" 11 | 12 | # dataset and dataloader 13 | use_dummy_dataset: bool = False 14 | data_path: str = "/fsx/data" 15 | file_type: str = "arrow" 16 | col_name: str = "tokens" 17 | tokenizer_path: str = "/fsx/tokenizer" 18 | datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange" 19 | weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100" 20 | seq_length: int = 4096 21 | vocab_size: int = 32000 22 | bos_token: Optional[int] = None 23 | eos_token: int = 0 24 | bol_token: Optional[int] = None 25 | eol_token: Optional[int] = None 26 | strip_tokens: str = "" 27 | logical_shards: int = 1024 28 | num_workers: int = 1 29 | 30 | # fsdp policies 31 | sharding_strategy: str = "hsdp" 32 | fsdp_activation_checkpointing: bool = False 33 | selective_checkpointing: Union[float, str] = 1 # percentage of blocks to apply ac 34 | mixed_precision: bool = True 35 | low_cpu_fsdp: bool = False 36 | 37 | # training spec 38 | batch_size: int = 2 39 | num_steps: int = 1000000 40 | training_stage: str = "initial" 41 | learning_rate: float = 3e-4 42 | grad_clip_thresh: float = 1.0 43 | seed: int = 2023 44 | 45 | # continued training spec 46 | resuming_dataset: bool = False 47 | 48 | # profiling 49 | use_profiler: bool = False 50 | profiler_rank0_only: bool = True 51 | 52 | # logging 53 | report_interval: int = 100 54 | checkpoint_interval: int = 10000 55 | tracker: Optional[str] = None # None, "wandb", "aim" 56 | tracker_dir: str = "/fsx/aim_logs/llama" 57 | tracker_project_name: str = "llama" # project name for a group of runs 58 | tracker_run_id: Optional[str] = None # run id, for job resume purpose 59 | 60 | # compile 61 | use_torch_compile: bool = True 62 | 63 | # speculator training 64 | tp_size: int = 8 65 | model_arch: str = "embedllama" 66 | model_path: str = "/path/to/model/" 67 | n_speculator_heads: int = 3 68 | speculator_width: int = 4096 69 | speculator_tie_weights: bool = True 70 | speculator_scale_input: bool = True 71 | stage2_start_step: int = 15000 72 | stage2_prompt_length: int = 64 73 | stage2_batch_size: int = 96 74 | stage2_seq_length: int = 256 75 | -------------------------------------------------------------------------------- /fms_fsdp/policies/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .ac_handler import apply_fsdp_checkpointing 2 | from .mixed_precision import * 3 | from .param_init import param_init_function 4 | from .wrapping import get_wrapper 5 | -------------------------------------------------------------------------------- /fms_fsdp/policies/ac_handler.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 4 | CheckpointImpl, 5 | apply_activation_checkpointing, 6 | checkpoint_wrapper, 7 | ) 8 | 9 | 10 | non_reentrant_wrapper = partial( 11 | checkpoint_wrapper, 12 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 13 | ) 14 | 15 | 16 | def apply_fsdp_checkpointing(model, block, p): 17 | """ 18 | Apply selective activation checkpointing. 19 | 20 | Selectivity is defined as a percentage p, which means we apply ac 21 | on p of the total blocks. p is a floating number in the range of 22 | [0, 1]. 23 | 24 | Some examples: 25 | p = 0: no ac for all blocks. same as `fsdp_activation_checkpointing=False` 26 | p = 1: apply ac on every block. i.e. "full ac". 27 | p = 1/2: [ac, no-ac, ac, no-ac, ...] 28 | p = 1/3: [no-ac, ac, no-ac, no-ac, ac, no-ac, ...] 29 | p = 2/3: [ac, no-ac, ac, ac, no-ac, ac, ...] 30 | Since blocks are homogeneous, we make ac blocks evenly spaced among 31 | all blocks. 32 | 33 | Implementation: 34 | For a given ac ratio p, we should essentially apply ac on every "1/p" 35 | blocks. The first ac block can be as early as the 0th block, or as 36 | late as the "1/p"th block, and we pick the middle one: (0.5p)th block. 37 | Therefore, we are essentially to apply ac on: 38 | (0.5/p)th block, (1.5/p)th block, (2.5/p)th block, etc., and of course, 39 | with these values rounding to integers. 40 | Since ac is applied recursively, we can simply use the following math 41 | in the code to apply ac on corresponding blocks. 42 | """ 43 | block_idx = 0 44 | cut_off = 1 / 2 45 | # when passing p as a fraction number (e.g. 1/3), it will be interpreted 46 | # as a string in argv, thus we need eval("1/3") here for fractions. 
47 | p = eval(p) if isinstance(p, str) else p 48 | 49 | def selective_checkpointing(submodule): 50 | nonlocal block_idx 51 | nonlocal cut_off 52 | 53 | if isinstance(submodule, block): 54 | block_idx += 1 55 | if block_idx * p >= cut_off: 56 | cut_off += 1 57 | return True 58 | return False 59 | 60 | apply_activation_checkpointing( 61 | model, 62 | checkpoint_wrapper_fn=non_reentrant_wrapper, 63 | check_fn=selective_checkpointing, 64 | ) 65 | -------------------------------------------------------------------------------- /fms_fsdp/policies/mixed_precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributed.fsdp import MixedPrecision 3 | 4 | 5 | fpSixteen = MixedPrecision( 6 | param_dtype=torch.float16, 7 | reduce_dtype=torch.float16, 8 | buffer_dtype=torch.float16, 9 | ) 10 | 11 | bfSixteen = MixedPrecision( 12 | param_dtype=torch.bfloat16, 13 | reduce_dtype=torch.bfloat16, 14 | buffer_dtype=torch.bfloat16, 15 | ) 16 | 17 | bfSixteen_working = MixedPrecision( 18 | param_dtype=torch.float32, 19 | reduce_dtype=torch.bfloat16, 20 | buffer_dtype=torch.bfloat16, 21 | ) 22 | 23 | fp32_policy = MixedPrecision( 24 | param_dtype=torch.float32, 25 | reduce_dtype=torch.float32, 26 | buffer_dtype=torch.float32, 27 | ) 28 | -------------------------------------------------------------------------------- /fms_fsdp/policies/param_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fms.modules.attention import MultiHeadAttention 3 | from fms.modules.embedding import WordEmbedding 4 | from fms.modules.feedforward import GatedLinearUnit 5 | from fms.modules.layernorm import LayerNormParameterized 6 | 7 | 8 | # for details, read https://github.com/foundation-model-stack/fms-fsdp/issues/64 9 | def param_init_function(module): 10 | if ( 11 | isinstance(module, MultiHeadAttention) 12 | or isinstance(module, WordEmbedding) 13 | or isinstance(module, GatedLinearUnit) 14 | or isinstance(module, LayerNormParameterized) 15 | ): 16 | module.to_empty(device=torch.cuda.current_device()) 17 | with torch.no_grad(): 18 | module.reset_parameters() 19 | -------------------------------------------------------------------------------- /fms_fsdp/policies/wrapping.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 4 | 5 | 6 | def get_wrapper(block): 7 | auto_wrap_policy = functools.partial( 8 | transformer_auto_wrap_policy, 9 | transformer_layer_cls={ 10 | block, 11 | }, 12 | ) 13 | 14 | return auto_wrap_policy 15 | -------------------------------------------------------------------------------- /fms_fsdp/readme.md: -------------------------------------------------------------------------------- 1 | # LLAMA branch 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /fms_fsdp/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/fms_fsdp/requirements.txt -------------------------------------------------------------------------------- /fms_fsdp/utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/fms_fsdp/utils/__init__.py -------------------------------------------------------------------------------- /fms_fsdp/utils/checkpointing_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from pathlib import Path 5 | 6 | import torch 7 | from torch.distributed._shard.checkpoint import ( 8 | FileSystemReader, 9 | FileSystemWriter, 10 | load_state_dict, 11 | save_state_dict, 12 | ) 13 | from torch.distributed.checkpoint.default_planner import ( 14 | DefaultLoadPlanner, 15 | DefaultSavePlanner, 16 | ) 17 | from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict 18 | from torch.distributed.fsdp import FullStateDictConfig 19 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 20 | from torch.distributed.fsdp import StateDictType 21 | 22 | 23 | def get_latest(targdir, qualifier=lambda x: True, key=os.path.getctime): 24 | """ 25 | Fetch the full path of the latest file or folder written to target directory, 26 | subject to name passing the qualifier fn. 27 | Optional key fn can be used for custom sorting. 28 | Both functions take full path arguments. 29 | If directory is empty or nonexistent or no items qualify, return None. 30 | """ 31 | if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: 32 | latest = max( 33 | [ 34 | os.path.join(targdir, x) 35 | for x in os.listdir(targdir) 36 | if qualifier(os.path.join(targdir, x)) 37 | ], 38 | key=key, 39 | ) 40 | return latest 41 | return None 42 | 43 | 44 | def get_oldest(targdir, qualifier=lambda x: True, key=os.path.getctime): 45 | """ 46 | Fetch the full path of the oldest file or folder written to target directory, 47 | subject to name passing the qualifier fn. 48 | Optional key fn can be used for custom sorting. 49 | Both functions take full path arguments. 50 | If directory is empty or nonexistent or no items qualify, return None. 51 | """ 52 | if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: 53 | oldest = min( 54 | [ 55 | os.path.join(targdir, x) 56 | for x in os.listdir(targdir) 57 | if qualifier(os.path.join(targdir, x)) 58 | ], 59 | key=key, 60 | ) 61 | return oldest 62 | return None 63 | 64 | 65 | class Checkpointer: 66 | """ 67 | Manages the checkpoint directory. Saves new checkpoints and deletes old ones after the specified number are written. 68 | Also handles loading and saving of checkpoints in sharded and unsharded formats. 69 | Assumes model and optimizer inputs are in FSDP. 70 | ... 71 | Args 72 | ---- 73 | ckpdir : str 74 | Absolute path to desired save location. Creates a new 'checkpoints/' subfolder at that location. 75 | n_to_save : int 76 | Number of volatile checkpoints to maintain at any given time. 77 | parallel_mode : str 78 | Write sharded folder ckps (when sharded: 'fsdp' or 'hsdp') or unsharded file ckps (when sharded: 'ddp') 79 | report_fn : Callable or None 80 | Optional function for reporting or logging status updates. Expected to handle arbitrary *args, **kwargs. 81 | Defaults to self._selective_print(). 82 | model_auto_placement : bool 83 | Optional; If True, auto detect GPU device to move model to, as set in device mesh init 84 | 85 | Methods 86 | ------- 87 | save : keyword args -> str | None 88 | Saves dictionary of keyword arg key/value pairs to specified checkpoint directory, deleting old checkpoints 89 | as necessary. 
If a checkpoint is deleted, returns the filename of that checkpoint. 90 | load : 91 | See docstring for individual function below 92 | """ 93 | 94 | def __init__( 95 | self, 96 | ckpdir, 97 | n_to_save, 98 | parallel_mode, 99 | rank, 100 | local_rank, 101 | report_fn=None, 102 | model_auto_placement=False, 103 | ): 104 | self.max_ckps = n_to_save 105 | self.rank = rank 106 | self.local_rank = local_rank 107 | self.ckp_path = os.path.join(ckpdir, "checkpoints/") 108 | os.makedirs(self.ckp_path, exist_ok=True) 109 | self.p_mode = parallel_mode 110 | assert parallel_mode in ["fsdp", "hsdp", "ddp"] 111 | self.report = self._selective_print if report_fn is None else report_fn 112 | self.model_auto_placement = model_auto_placement 113 | 114 | def _selective_print(self, *args, **kwargs): 115 | if self.rank == 0: 116 | print(*args) 117 | for k, v in kwargs.items(): 118 | print(k, "=", v) 119 | 120 | def _cleanup(self): 121 | # Clean old checkpoints. Barrier to keep synchronization correct. 122 | file_to_remove = None 123 | if ( 124 | self.rank == 0 125 | and len([x for x in os.listdir(self.ckp_path) if "tmp" in x]) 126 | > self.max_ckps 127 | ): 128 | ckp_to_remove = Path( 129 | get_oldest(self.ckp_path, qualifier=lambda x: "tmp" in x) 130 | ) 131 | if os.path.isfile(ckp_to_remove): 132 | ckp_to_remove.unlink() 133 | else: 134 | shutil.rmtree(ckp_to_remove) 135 | return file_to_remove 136 | 137 | def _do_save(self, rank, local_rank): # , shard_group, replicate_group): 138 | if self.p_mode == "hsdp": 139 | return rank == local_rank 140 | else: 141 | return True 142 | # TODO: Distributed writing contingent upon the following fix: https://github.com/pytorch/pytorch/issues/104081 143 | # if not is_dist: 144 | # return (rank == local_rank) 145 | # else: 146 | # a = rank % shard_group.size() 147 | # b = rank // shard_group.size() 148 | # return True if a % replicate_group.size() == b else False 149 | # shard_group = model.process_group 150 | # replicate_group = model.__inter_node_state.process_group 151 | 152 | def _write(self, state_dict, loader_state, process_group, save_name, rank): 153 | os.makedirs(save_name, exist_ok=True) 154 | writer = FileSystemWriter(save_name, single_file_per_rank=True) 155 | if state_dict is not None: 156 | save_state_dict( 157 | state_dict=state_dict, 158 | storage_writer=writer, 159 | process_group=process_group, 160 | planner=DefaultSavePlanner(), 161 | ) 162 | if loader_state is not None: 163 | loader_state.save_to_path(save_name) 164 | 165 | def _validate_ckp_path(self, path): 166 | """Interpret path to appropriate checkpoint. If found, return modified path. If not found, return None.""" 167 | # Does path exist and is it non-empty? 168 | if os.path.exists(path): 169 | # Is this a file? 170 | if os.path.isfile(path): 171 | return path 172 | # Is this a sharded directory? 173 | elif "metadata.pth" in os.listdir(path): 174 | return path 175 | # Is this a path to a set of checkpoints? 176 | elif len(os.listdir(path)) > 0: 177 | latest = get_latest(path) 178 | if os.path.isfile(latest): 179 | return latest 180 | elif "metadata.pth" in os.listdir(latest): 181 | return latest 182 | return None 183 | 184 | def load( 185 | self, 186 | model, 187 | optimizer, 188 | dataloader, 189 | path="", 190 | reset_stepcount=False, 191 | strict=True, 192 | is_compiled=False, 193 | ): 194 | """ 195 | Handle checkpoint loading for model/optimizer/dataloader from given path, according to arguments. 196 | Defaults to save path for locating an appropriate checkpoint. 
If a path is provided, will use 197 | it only if no appropriate checkpoint is found in the save path (in which case it's a job restart). 198 | Reset_stepcount manually resets optimizer and dataloader states, and stat tracking. 199 | Strict determines whether to use strict loading or not FOR SINGLEFILE LOADING ONLY. 200 | Returns model, optimizer, dataloader, current step, and current tokens seen. 201 | """ 202 | is_resuming = False 203 | if self._validate_ckp_path(self.ckp_path) is not None: 204 | path = self.ckp_path 205 | is_resuming = True 206 | load_path = self._validate_ckp_path(path) 207 | if load_path is None: 208 | self.report( 209 | f"No valid checkpoint detected at {path}, starting from scratch." 210 | ) 211 | return model, optimizer, dataloader, 0, 0, False 212 | else: 213 | self.report(f"Prior checkpoint {load_path} detected.") 214 | model_load_time = time.time() 215 | if os.path.isfile(load_path): 216 | checkpoint_data = torch.load(load_path, map_location="cpu") 217 | if is_compiled: 218 | model._orig_mod.load_state_dict( 219 | checkpoint_data.get("model_state"), strict=strict 220 | ) 221 | else: 222 | model.load_state_dict( 223 | checkpoint_data.get("model_state"), strict=strict 224 | ) 225 | if self.model_auto_placement: 226 | model.to("cuda") 227 | else: 228 | model.to(self.local_rank) 229 | self.report( 230 | f"Checkpoint {load_path} is a single-file checkpoint containing only a model. Optimizer and dataloader are from scratch.", 231 | model_load_time=time.time() - model_load_time, 232 | ) 233 | return model, optimizer, dataloader, 0, 0, is_resuming 234 | else: 235 | # Load model 236 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 237 | state_dict = model.state_dict() 238 | model_ckp = {"model_state": state_dict} 239 | load_state_dict( 240 | state_dict=model_ckp, 241 | storage_reader=FileSystemReader(load_path), 242 | planner=DefaultLoadPlanner(), 243 | ) 244 | model.load_state_dict(model_ckp["model_state"]) 245 | if self.model_auto_placement: 246 | model.to("cuda") 247 | else: 248 | model.to(self.local_rank) 249 | self.report(model_load_time=time.time() - model_load_time) 250 | step = 0 251 | ntok = 0 252 | # Load metadata 253 | if is_resuming: 254 | metadata = torch.load(os.path.join(load_path, "metadata.pth")) 255 | step = metadata.get("step", 0) 256 | ntok = metadata.get("tokens_seen", 0) 257 | self.report("Metadata loaded", start_step=step, n_tokens_seen=ntok) 258 | # Load optimizer 259 | if optimizer is not None: 260 | optim_load_time = time.time() 261 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 262 | optim_state = load_sharded_optimizer_state_dict( 263 | model_state_dict=model.state_dict(), 264 | optimizer_key="optimizer_state", 265 | storage_reader=FileSystemReader(load_path), 266 | ) 267 | flattened_osd = FSDP.optim_state_dict_to_load( 268 | model, optimizer, optim_state["optimizer_state"] 269 | ) 270 | optimizer.load_state_dict(flattened_osd) 271 | self.report(optimizer_load_time=time.time() - optim_load_time) 272 | else: 273 | self.report("Skipping optimizer load, no optimizer provided.") 274 | # Load dataset 275 | if dataloader is not None: 276 | data_load_time = time.time() 277 | dataloader.dataset.load_from_path(path) 278 | self.report(dataset_load_time=time.time() - data_load_time) 279 | else: 280 | self.report("Skipping dataset load, no dataloader provided.") 281 | return model, optimizer, dataloader, step, ntok, is_resuming 282 | 283 | def save( 284 | self, 285 | step, 286 | model, 287 | optimizer, 288 | 
dataloader, 289 | **kwargs, 290 | ): 291 | # Note: metadata kwargs cannot contain any of: 292 | # (step, model, optimizer, dataloader) 293 | rank = self.rank 294 | save_time = time.time() 295 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 296 | model_state = model.state_dict() 297 | optim_state = FSDP.sharded_optim_state_dict(model, optimizer) 298 | dataloader_state = None if dataloader is None else dataloader.dataset 299 | 300 | save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp") 301 | state_dict = {"model_state": model_state, "optimizer_state": optim_state} 302 | if self._do_save(rank, self.local_rank): 303 | self._write( 304 | state_dict, dataloader_state, model.process_group, save_name, rank 305 | ) 306 | else: 307 | self._write(None, dataloader_state, None, save_name, rank) 308 | if rank == 0: 309 | metadata = kwargs 310 | metadata["step"] = step 311 | torch.save(metadata, os.path.join(save_name, "metadata.pth")) 312 | self.report( 313 | f"Checkpoint saved in {save_name}", model_save_time=time.time() - save_time 314 | ) 315 | 316 | return self._cleanup() 317 | -------------------------------------------------------------------------------- /fms_fsdp/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | from fms.models.llama import LLaMAConfig 2 | 3 | from fms_fsdp.config import train_config 4 | 5 | 6 | def update_config(config, **kwargs): 7 | if isinstance(config, (tuple, list)): 8 | for c in config: 9 | update_config(c, **kwargs) 10 | else: 11 | for k, v in kwargs.items(): 12 | if hasattr(config, k): 13 | setattr(config, k, v) 14 | elif "." in k: 15 | config_name, param_name = k.split(".") 16 | if type(config).__name__ == config_name: 17 | if hasattr(config, param_name): 18 | setattr(config, param_name, v) 19 | else: 20 | print(f"Warning: {config_name} does not accept parameter: {k}") 21 | elif isinstance(config, train_config): 22 | print(f"Warning: unknown parameter {k}") 23 | 24 | 25 | def get_model_config(model_variant): 26 | if model_variant == "llama2_70b": 27 | model_config = LLaMAConfig( 28 | emb_dim=8192, 29 | multiple_of=4096, 30 | nheads=64, 31 | kvheads=8, 32 | nlayers=80, 33 | hidden_grow_factor=28672 / 8192, 34 | ) 35 | elif model_variant == "llama2_34b": 36 | model_config = LLaMAConfig( 37 | emb_dim=8192, 38 | nheads=64, 39 | kvheads=8, 40 | nlayers=48, 41 | hidden_grow_factor=22016 / 8192, 42 | max_expected_seq_len=16384, 43 | rope_theta=1000000.0, 44 | ) 45 | elif model_variant == "llama2_13b": 46 | model_config = LLaMAConfig( 47 | emb_dim=5120, 48 | nheads=40, 49 | nlayers=40, 50 | hidden_grow_factor=13824 / 5120, 51 | ) 52 | elif model_variant == "llama2_7b": 53 | model_config = LLaMAConfig( 54 | hidden_grow_factor=11008 / 4096, 55 | kvheads=32, 56 | ) 57 | elif model_variant == "llama2_1.4b": 58 | model_config = LLaMAConfig( 59 | emb_dim=2048, 60 | nheads=16, 61 | nlayers=24, 62 | hidden_grow_factor=3, 63 | kvheads=4, 64 | ) 65 | elif model_variant == "llama3_8b": 66 | model_config = LLaMAConfig( 67 | src_vocab_size=128256, 68 | emb_dim=4096, 69 | nheads=32, 70 | kvheads=8, 71 | nlayers=32, 72 | hidden_grow_factor=3.5, 73 | max_expected_seq_len=8192, 74 | rope_theta=500000.0, 75 | ) 76 | elif model_variant == "llama3_8b_4k": 77 | model_config = LLaMAConfig( 78 | src_vocab_size=128256, 79 | emb_dim=4096, 80 | nheads=32, 81 | kvheads=8, 82 | nlayers=32, 83 | hidden_grow_factor=3.5, 84 | max_expected_seq_len=4096, 85 | rope_theta=500000.0, 86 | ) 87 | elif model_variant == 
"llama3_1.8b": 88 | model_config = LLaMAConfig( 89 | src_vocab_size=128256, 90 | emb_dim=2048, 91 | nheads=16, 92 | kvheads=8, 93 | nlayers=24, 94 | hidden_grow_factor=3.5, 95 | max_expected_seq_len=8192, 96 | rope_theta=500000.0, 97 | ) 98 | elif model_variant == "llama3_1.8b_4k": 99 | model_config = LLaMAConfig( 100 | src_vocab_size=128256, 101 | emb_dim=2048, 102 | nheads=16, 103 | kvheads=8, 104 | nlayers=24, 105 | hidden_grow_factor=3.5, 106 | max_expected_seq_len=4096, 107 | rope_theta=500000.0, 108 | ) 109 | elif model_variant == "llama3_3.2b": 110 | model_config = LLaMAConfig( 111 | src_vocab_size=128256, 112 | emb_dim=3072, 113 | nheads=24, 114 | kvheads=8, 115 | nlayers=24, 116 | hidden_grow_factor=8 / 3, 117 | max_expected_seq_len=8192, 118 | rope_theta=500000.0, 119 | ) 120 | elif model_variant == "llama3_3.2b_4k": 121 | model_config = LLaMAConfig( 122 | src_vocab_size=128256, 123 | emb_dim=3072, 124 | nheads=24, 125 | kvheads=8, 126 | nlayers=24, 127 | hidden_grow_factor=8 / 3, 128 | max_expected_seq_len=4096, 129 | rope_theta=500000.0, 130 | ) 131 | elif model_variant == "llama3_70b": 132 | model_config = LLaMAConfig( 133 | src_vocab_size=128256, 134 | emb_dim=8192, 135 | nheads=64, 136 | kvheads=8, 137 | nlayers=80, 138 | hidden_grow_factor=3.5, 139 | max_expected_seq_len=8192, 140 | rope_theta=500000.0, 141 | ) 142 | elif model_variant == "llama3_70b_4k": 143 | model_config = LLaMAConfig( 144 | src_vocab_size=128256, 145 | emb_dim=8192, 146 | nheads=64, 147 | kvheads=8, 148 | nlayers=80, 149 | hidden_grow_factor=3.5, 150 | max_expected_seq_len=4096, 151 | rope_theta=500000.0, 152 | ) 153 | elif model_variant == "llama3_194m_4k": 154 | model_config = LLaMAConfig( 155 | src_vocab_size=128256, 156 | emb_dim=1024, 157 | nheads=8, 158 | nlayers=10, 159 | max_expected_seq_len=4096, 160 | rope_theta=500000.0, 161 | ) 162 | elif model_variant == "mamba_9.8b": 163 | model_config = { 164 | "d_model": 4096, 165 | "d_intermediate": 14336, 166 | "n_layer": 32, 167 | "vocab_size": 128256, 168 | "ssm_cfg": {"layer": "Mamba2"}, 169 | "attn_layer_idx": [9, 18, 27], 170 | "attn_cfg": { 171 | "causal": True, 172 | "d_conv": 0, 173 | "head_dim": 128, 174 | "num_heads": 32, 175 | "num_heads_kv": 8, 176 | "out_proj_bias": False, 177 | "qkv_proj_bias": False, 178 | "rotary_emb_dim": 64, 179 | }, 180 | "rms_norm": True, 181 | "residual_in_fp32": True, 182 | "fused_add_norm": True, 183 | "pad_vocab_size_multiple": 16, 184 | "tie_embeddings": False, 185 | } 186 | else: 187 | raise ValueError(f"model variant {model_variant} not supported.") 188 | 189 | return model_config 190 | -------------------------------------------------------------------------------- /fms_fsdp/utils/dataloader_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fms_fsdp.utils.dataset_utils import ( 4 | ArrowHandler, 5 | AutoHandler, 6 | BufferDataset, 7 | CheckpointDataset, 8 | ParquetHandler, 9 | PreloadBufferDataset, 10 | PreprocessDataset, 11 | SamplingDataset, 12 | ScalableShardDataset, 13 | StreamingDocDataset, 14 | ) 15 | 16 | 17 | _handler_map = { 18 | "arrow": ArrowHandler, 19 | "hf_parquet": ParquetHandler, 20 | "auto": AutoHandler, 21 | } 22 | 23 | 24 | def causal_lm(data_seq, prompt_len=1): 25 | """ 26 | Perform causal language modeling by right-shifting the input sequence. 27 | Sets first prompt_len tokens to be ignored by the loss. 
28 | """ 29 | data_seq = torch.tensor(data_seq, dtype=torch.int) 30 | t = data_seq.clone()[1:] 31 | data_seq = data_seq[:-1] 32 | t[:prompt_len] = -100 33 | return data_seq, t 34 | 35 | 36 | def get_dummy_loader(cfg, rank, world_size): 37 | """ 38 | A simple dummy dataloader yielding incrementing vocab indices in an infinite loop 39 | """ 40 | 41 | class SteadyCounter(torch.utils.data.IterableDataset): 42 | # Spit out incremental counts of constant length l, modulo vocab size v 43 | def __init__(self, l, v): 44 | self.i = 0 45 | self.l = l 46 | self.v = v 47 | 48 | def __iter__(self): 49 | while True: 50 | out = torch.IntTensor( 51 | [x % self.v for x in range(self.i, self.i + self.l)] 52 | ) 53 | yield out, out 54 | self.i += self.l 55 | 56 | data = SteadyCounter(cfg.seq_length, cfg.vocab_size) 57 | return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size) 58 | 59 | 60 | def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): 61 | """ 62 | Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training. 63 | Assumes underlying data is sequences of integer values. 64 | ... 65 | Args 66 | ---- 67 | cfg : dataclass 68 | Training config containing seq len, dataset, dataset weight, datapath, etc. arguments 69 | rank : int 70 | Rank of current distributed worker. Used for handling dataset sharding logic. 71 | world_size : int 72 | Number of distributed workers. Used for handling dataset sharding logic. 73 | postprocess : List[Callable] 74 | Any task-specific postprocessing to apply before handing over data. Steps will apply in 75 | the order provided by the user. For CLM training, use postprocess=[causal_lm]. 76 | """ 77 | 78 | datasets, weights = parse_data_args(cfg.datasets, cfg.weights) 79 | 80 | # Base streaming dataset. Returns doc chunks in sequence. 81 | # Implements dataset sampling and rescalability. 82 | droplist = [ 83 | int(x.strip()) for x in cfg.strip_tokens.split(",") if len(x.strip()) > 0 84 | ] 85 | droplist = droplist + [cfg.bos_token, cfg.eos_token, cfg.bol_token, cfg.eol_token] 86 | assert ( 87 | cfg.file_type in _handler_map 88 | ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})" 89 | if cfg.file_type == "hf_parquet" or cfg.file_type == "auto": 90 | filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name) 91 | else: 92 | filehandler = _handler_map[cfg.file_type] 93 | # Base reader layer 94 | data = StreamingDocDataset( 95 | cfg.data_path, 96 | rank, 97 | world_size, 98 | filehandler, 99 | cfg.eos_token, 100 | bos_token=cfg.bos_token, 101 | strip_tokens=set(droplist), 102 | min_length=3, 103 | seed=cfg.seed, 104 | ) 105 | # Add rescaling/resharding 106 | data = ScalableShardDataset( 107 | data, 108 | cfg.eos_token, 109 | n_logical_shards=cfg.logical_shards, 110 | ) 111 | # Add multi-dataset handling 112 | data = SamplingDataset( 113 | cfg.data_path, 114 | data, 115 | cfg.eos_token, 116 | datasets=datasets, 117 | weights=weights, 118 | verbose=(rank == 0), 119 | ) 120 | # Wrap above dataset in packing logic to form constant-length lines. 121 | data = BufferDataset( 122 | data, 123 | cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1, 124 | bos_token=cfg.bol_token, 125 | eos_token=cfg.eol_token, 126 | pack_hard=True, 127 | ) 128 | # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average. 
129 | data = PreloadBufferDataset(data, 10000) 130 | 131 | # Apply desired postprocessing steps in sequence 132 | data = PreprocessDataset(data, torch.IntTensor) 133 | for p in postprocess: 134 | data = PreprocessDataset(data, p) 135 | 136 | # Enable auto-saving 137 | data = CheckpointDataset( 138 | data, 139 | cfg.ckpt_load_path if cfg.resuming_dataset else cfg.ckpt_save_path, 140 | cfg.checkpoint_interval, 141 | cfg.batch_size, 142 | cfg.ckpt_save_path, 143 | ) 144 | return torch.utils.data.DataLoader( 145 | data, num_workers=cfg.num_workers, batch_size=cfg.batch_size 146 | ) 147 | 148 | 149 | def parse_data_args(datas, weights): 150 | # Convert csv inputs into corresponding lists of values 151 | def splitstrip(x): 152 | if isinstance(x, str): 153 | return [item.strip() for item in x.split(",")] 154 | elif isinstance(x, (list, tuple)): 155 | return list(x) 156 | elif isinstance(x, (int, float, complex)): 157 | return [x] 158 | else: 159 | raise ValueError(f"arg input {x} cannot be parsed.") 160 | 161 | datas = splitstrip(datas) 162 | weights = [float(x) for x in splitstrip(weights)] 163 | return datas, weights 164 | -------------------------------------------------------------------------------- /fms_fsdp/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict 3 | from functools import partial 4 | 5 | 6 | try: 7 | import packaging.version 8 | except ImportError: 9 | from pkg_resources import packaging # type: ignore 10 | 11 | import time 12 | from datetime import timedelta 13 | 14 | import torch.cuda.nccl as nccl 15 | import torch.distributed as dist 16 | from torch.distributed.fsdp import ShardingStrategy 17 | 18 | from fms_fsdp.policies import * 19 | 20 | 21 | def train( 22 | cfg, 23 | model, 24 | local_rank, 25 | rank, 26 | train_loader, 27 | optimizer, 28 | scheduler, 29 | profiler, 30 | checkpointer, 31 | start_step, 32 | tokens_seen, 33 | ): 34 | if cfg.tracker: 35 | if cfg.tracker not in ["wandb", "aim"]: 36 | raise ValueError(f"tracker {cfg.tracker} not supported.") 37 | tracker_dir = cfg.tracker_dir 38 | project_name = cfg.tracker_project_name 39 | run_id = cfg.tracker_run_id 40 | 41 | if cfg.tracker == "wandb": 42 | try: 43 | import wandb # type: ignore 44 | except ImportError: 45 | raise ImportError("tracker is set to wandb but wandb is not installed.") 46 | if rank == 0: 47 | print(f"--> wandb is enabled!") 48 | try: 49 | wandb.init( 50 | project=project_name, 51 | dir=tracker_dir, 52 | resume="allow", 53 | id=run_id, 54 | ) 55 | except wandb.errors.UsageError: 56 | raise ValueError( 57 | "wandb failed to init, did you pass your wandb api key via WANDB_API_KEY?" 
58 | ) 59 | wandb.config = asdict(cfg) 60 | 61 | if cfg.tracker == "aim": 62 | try: 63 | from aim import Run # type: ignore 64 | except ImportError: 65 | raise ImportError("tracker is set to aim but aim is not installed.") 66 | if rank == 0: 67 | print(f"--> aim is enabled!") 68 | run = Run( 69 | experiment=project_name, 70 | repo=tracker_dir, 71 | run_hash=run_id, 72 | ) 73 | run["hparams"] = asdict(cfg) 74 | 75 | model.train() 76 | ddp_stats = torch.zeros(3).to(local_rank) 77 | 78 | start = time.time() 79 | loop_start = time.time() 80 | train_loss = -1 81 | for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1): 82 | if batch_idx > cfg.num_steps: 83 | break 84 | input = input.to(local_rank) 85 | label = label.to(local_rank) 86 | 87 | optimizer.zero_grad() 88 | output = model(input) 89 | output = output.logits if hasattr(output, "logits") else output 90 | ce_loss = torch.nn.CrossEntropyLoss() 91 | loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) 92 | 93 | loss.backward() 94 | ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() 95 | optimizer.step() 96 | scheduler.step() 97 | 98 | ddp_stats[0] += loss.item() 99 | ddp_stats[2] += 1 100 | 101 | if profiler: 102 | profiler.step() 103 | 104 | if batch_idx % cfg.report_interval == 0: 105 | dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM) 106 | train_loss = ddp_stats[0] / ddp_stats[2] 107 | g_norm = ddp_stats[1] / ddp_stats[2] 108 | elapsed_time = time.time() - loop_start 109 | world_size = int(os.environ["WORLD_SIZE"]) 110 | new_tokens_seen = ( 111 | (batch_idx - start_step) * world_size * cfg.batch_size * cfg.seq_length 112 | ) 113 | if rank == 0: 114 | total_tokens_seen = tokens_seen + new_tokens_seen 115 | current_loss = train_loss.item() 116 | current_lr = scheduler.get_last_lr()[0] 117 | current_gnorm = g_norm.item() 118 | current_step_time = (time.time() - start) / cfg.report_interval 119 | overall_step_time = elapsed_time / (batch_idx - start_step) 120 | current_throughput = int( 121 | cfg.batch_size * cfg.seq_length / current_step_time 122 | ) 123 | overall_throughput = int( 124 | cfg.batch_size * cfg.seq_length / overall_step_time 125 | ) 126 | reserved_mem = torch.cuda.max_memory_reserved( 127 | device=torch.cuda.current_device() 128 | ) 129 | allocated_mem = torch.cuda.max_memory_allocated( 130 | device=torch.cuda.current_device() 131 | ) 132 | 133 | print("step:", batch_idx) 134 | print("loss:", current_loss) 135 | print("LR:", current_lr) 136 | print("tokens seen:", total_tokens_seen) 137 | print("gradient norm:", current_gnorm) 138 | print("reserved memory:", reserved_mem) 139 | print("allocated memory:", allocated_mem) 140 | print("current step time:", current_step_time) 141 | print("overall step time:", overall_step_time) 142 | print("current token per gpu per sec:", current_throughput) 143 | print("overall token per gpu per sec:", overall_throughput) 144 | print( 145 | "overall token per day:", 146 | int(new_tokens_seen / elapsed_time * 3600 * 24), 147 | ) 148 | if cfg.tracker: 149 | vals_to_track = { 150 | "learning rate": current_lr, 151 | "loss": current_loss, 152 | "gradient norm": current_gnorm, 153 | "token seen": total_tokens_seen, 154 | "current throughput (token per gpu per sec)": current_throughput, 155 | "overall throughput (token per gpu per sec)": overall_throughput, 156 | "gpu reserved memory": reserved_mem, 157 | "gpu allocated memory": allocated_mem, 158 | } 159 | if cfg.tracker == "wandb": 160 | tracker_fn = wandb.log 161 | elif cfg.tracker == "aim": 
162 | tracker_fn = run.track 163 | tracker_fn(vals_to_track, step=batch_idx) 164 | 165 | start = time.time() 166 | ddp_stats.zero_() 167 | torch.cuda.reset_peak_memory_stats(device=torch.cuda.current_device()) 168 | 169 | if batch_idx % cfg.checkpoint_interval == 0 or batch_idx == cfg.num_steps: 170 | checkpointer.save( 171 | batch_idx, 172 | model, 173 | optimizer, 174 | None, 175 | tokens_seen=tokens_seen + new_tokens_seen, 176 | ) 177 | 178 | return train_loss 179 | 180 | 181 | def setup(): 182 | dist.init_process_group("nccl", timeout=timedelta(seconds=60 * 60)) 183 | 184 | 185 | def setup_environ_flags(): 186 | os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) 187 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) 188 | 189 | 190 | def get_mixed_precision_policy(cfg, rank): 191 | verify_bfloat_support = ( 192 | torch.version.cuda 193 | and torch.cuda.is_bf16_supported() 194 | and packaging.version.parse(torch.version.cuda).release >= (11, 0) 195 | and dist.is_nccl_available() 196 | and nccl.version() >= (2, 10) 197 | ) 198 | 199 | if cfg.mixed_precision: 200 | bf16_ready = verify_bfloat_support 201 | if bf16_ready: 202 | mixed_precision_policy = bfSixteen 203 | if rank == 0: 204 | print(f"bFloat16 enabled for mixed precision - using bfSixteen policy") 205 | else: 206 | mixed_precision_policy = fpSixteen 207 | if rank == 0: 208 | print(f"FP16 enabled") 209 | else: 210 | mixed_precision_policy = None 211 | 212 | return mixed_precision_policy 213 | 214 | 215 | def get_policies(cfg, rank, block): 216 | """Get policies for mixed precision, wrapping, sharding, ac and param init function.""" 217 | 218 | # mixed precision 219 | mixed_precision_policy = get_mixed_precision_policy(cfg, rank) 220 | 221 | # wrapping policy 222 | wrapping_policy = get_wrapper(block) 223 | 224 | # sharding strategy 225 | if cfg.sharding_strategy == "fsdp": 226 | sharding_strategy = ShardingStrategy.FULL_SHARD 227 | elif cfg.sharding_strategy == "hsdp": 228 | sharding_strategy = ShardingStrategy.HYBRID_SHARD 229 | elif cfg.sharding_strategy == "ddp": 230 | sharding_strategy = ShardingStrategy.NO_SHARD 231 | else: 232 | sharding_strategy = ShardingStrategy.FULL_SHARD 233 | if rank == 0: 234 | print(f"Sharding strategy = {cfg.sharding_strategy}") 235 | 236 | # ac handler 237 | apply_selective_ac = partial(apply_fsdp_checkpointing, block=block) 238 | 239 | # param init function 240 | if cfg.low_cpu_fsdp: 241 | param_init_fn = param_init_function 242 | else: 243 | param_init_fn = None 244 | 245 | return ( 246 | mixed_precision_policy, 247 | wrapping_policy, 248 | sharding_strategy, 249 | apply_selective_ac, 250 | param_init_fn, 251 | ) 252 | 253 | 254 | def get_profiler(cfg, rank): 255 | if not cfg.use_profiler: 256 | return 257 | if cfg.profiler_rank0_only and rank != 0: 258 | return 259 | return torch.profiler.profile( 260 | activities=[ 261 | torch.profiler.ProfilerActivity.CPU, 262 | torch.profiler.ProfilerActivity.CUDA, 263 | ], 264 | schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1), 265 | on_trace_ready=torch.profiler.tensorboard_trace_handler("profile_traces"), 266 | profile_memory=True, 267 | with_stack=False, 268 | record_shapes=True, 269 | ) 270 | -------------------------------------------------------------------------------- /fms_to_hf_llama.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import torch 3 | from fms.models.hf.utils import to_hf_api 4 | from fms.models.llama import LLaMA 5 | from torch.distributed._shard.checkpoint 
import FileSystemReader, load_state_dict 6 | from transformers import LlamaConfig, LlamaForCausalLM 7 | 8 | from fms_fsdp.utils.config_utils import get_model_config 9 | 10 | 11 | def convert_to_hf(model: LLaMA, model_variant, is_old_fms) -> LlamaForCausalLM: 12 | fms_hf_model = to_hf_api(model) 13 | hf_config = fms_hf_model.config 14 | if "llama3" in model_variant: 15 | hf_config.bos_token_id = 128000 16 | hf_config.eos_token_id = 128001 17 | oss_hf_model = LlamaForCausalLM( 18 | LlamaConfig( 19 | vocab_size=hf_config.vocab_size, 20 | hidden_size=hf_config.hidden_size, 21 | rms_norm_eps=hf_config.norm_eps, 22 | num_attention_heads=hf_config.nheads, 23 | num_key_value_heads=None if hf_config.kvheads == 0 else hf_config.kvheads, 24 | num_hidden_layers=hf_config.nlayers, 25 | intermediate_size=hf_config.multiple_of 26 | * ( 27 | ( 28 | int(hf_config.hidden_grow_factor * hf_config.hidden_size) 29 | + hf_config.multiple_of 30 | - 1 31 | ) 32 | // hf_config.multiple_of 33 | ), 34 | pad_token_id=( 35 | None if hf_config.pad_token_id == -1 else hf_config.pad_token_id 36 | ), 37 | bos_token_id=hf_config.bos_token_id, 38 | eos_token_id=hf_config.eos_token_id, 39 | max_position_embeddings=hf_config.max_expected_seq_len, 40 | ) 41 | ) 42 | 43 | # compute the freq from rot_emb since it is gathered lazily 44 | rot_emb = fms_hf_model.decoder.model.rot_emb 45 | max_seq_len = rot_emb.max_seq_len 46 | alpha = rot_emb._alpha(max_seq_len) 47 | ratio = rot_emb.ratio 48 | dim = rot_emb.dim 49 | if rot_emb.ntk_scaling: 50 | ratio = ratio * alpha ** (dim / (dim - 2)) 51 | freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 52 | 53 | with torch.no_grad(): 54 | oss_hf_model.model.embed_tokens.weight.copy_(fms_hf_model.embedding.weight) 55 | i = 0 56 | for oss_hf_layer in oss_hf_model.model.layers: 57 | fms_hf_layer = fms_hf_model.decoder.model.layers[i] 58 | 59 | # self attn 60 | if is_old_fms: 61 | oss_hf_layer.self_attn.q_proj.weight.copy_( 62 | fms_hf_layer.attn.query.weight 63 | ) 64 | oss_hf_layer.self_attn.k_proj.weight.copy_(fms_hf_layer.attn.key.weight) 65 | oss_hf_layer.self_attn.v_proj.weight.copy_( 66 | fms_hf_layer.attn.value.weight 67 | ) 68 | else: 69 | q, k, v = torch.split( 70 | fms_hf_layer.attn.in_proj.qkv_fused.weight, 71 | fms_hf_layer.attn.in_proj.splits, 72 | dim=0, 73 | ) 74 | oss_hf_layer.self_attn.q_proj.weight.copy_(q) 75 | oss_hf_layer.self_attn.k_proj.weight.copy_(k) 76 | oss_hf_layer.self_attn.v_proj.weight.copy_(v) 77 | oss_hf_layer.self_attn.o_proj.weight.copy_(fms_hf_layer.attn.dense.weight) 78 | oss_hf_layer.self_attn.rotary_emb.inv_freqs = freqs 79 | 80 | # mlp 81 | if is_old_fms: 82 | oss_hf_layer.mlp.gate_proj.weight.copy_( 83 | fms_hf_layer.ff_sub_layer.wg.weight 84 | ) 85 | oss_hf_layer.mlp.up_proj.weight.copy_( 86 | fms_hf_layer.ff_sub_layer.w1.weight 87 | ) 88 | else: 89 | wg1_fused = fms_hf_layer.ff_sub_layer.wg1_fused.weight 90 | wg_splits = [wg1_fused.size(0) // 2, wg1_fused.size(0) // 2] 91 | wg, w1 = torch.split( 92 | fms_hf_layer.ff_sub_layer.wg1_fused.weight, wg_splits, dim=0 93 | ) 94 | oss_hf_layer.mlp.gate_proj.weight.copy_(wg) 95 | oss_hf_layer.mlp.up_proj.weight.copy_(w1) 96 | oss_hf_layer.mlp.down_proj.weight.copy_(fms_hf_layer.ff_sub_layer.w2.weight) 97 | 98 | # layer norm 99 | oss_hf_layer.input_layernorm.weight.copy_(fms_hf_layer.ln.weight) 100 | oss_hf_layer.post_attention_layernorm.weight.copy_( 101 | fms_hf_layer.ff_ln.weight 102 | ) 103 | 104 | # adjust q, k 105 | q = oss_hf_layer.self_attn.q_proj.weight.data 106 | q = ( 107 | 
q.view(hf_config.nheads, -1, 2, q.size(1)) 108 | .transpose(1, 2) 109 | .reshape(*q.size()) 110 | ) 111 | oss_hf_layer.self_attn.q_proj.weight.copy_(q) 112 | 113 | k = oss_hf_layer.self_attn.k_proj.weight.data 114 | k = ( 115 | k.view( 116 | hf_config.nheads if hf_config.kvheads == 0 else hf_config.kvheads, 117 | -1, 118 | 2, 119 | k.size(1), 120 | ) 121 | .transpose(1, 2) 122 | .reshape(*k.size()) 123 | ) 124 | oss_hf_layer.self_attn.k_proj.weight.copy_(k) 125 | 126 | i = i + 1 127 | oss_hf_model.model.norm.weight = fms_hf_model.decoder.model.dec_norm.weight 128 | oss_hf_model.lm_head.weight = fms_hf_model.lm_head.weight 129 | 130 | return oss_hf_model 131 | 132 | 133 | def main( 134 | model_variant, compiled, is_old_fms, load_path, save_path, tokenizer_name_or_path 135 | ): 136 | print("Initializing model...") 137 | llama_config = get_model_config(model_variant) 138 | with torch.device("meta"): 139 | model = LLaMA(llama_config) 140 | model.to_empty(device="cpu") 141 | 142 | print(f"Reading state dict from {load_path}") 143 | if not compiled: 144 | state_dict = {"model_state": model.state_dict()} 145 | else: 146 | state_dict = {"model_state": {"_orig_mod": model.state_dict()}} 147 | load_state_dict( 148 | state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True 149 | ) 150 | 151 | print("Loading state dict into the model...") 152 | if not compiled: 153 | model.load_state_dict(state_dict["model_state"]) 154 | else: 155 | model.load_state_dict(state_dict["model_state"]["_orig_mod"]) 156 | 157 | print("Converting to HF model..") 158 | hf_model = convert_to_hf(model, model_variant, is_old_fms) 159 | hf_model.save_pretrained(save_path) 160 | 161 | print("Copying tokenizer...") 162 | from transformers import AutoTokenizer 163 | 164 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) 165 | tokenizer.save_pretrained(save_path) 166 | 167 | print(f"Model converted to HF model, saving at {save_path}") 168 | 169 | 170 | if __name__ == "__main__": 171 | fire.Fire(main) 172 | -------------------------------------------------------------------------------- /fms_to_hf_mamba.py: -------------------------------------------------------------------------------- 1 | import fire 2 | from mamba_ssm.models.config_mamba import MambaConfig 3 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 4 | from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict 5 | 6 | from fms_fsdp.utils.config_utils import get_model_config 7 | 8 | 9 | def main(model_variant, load_path, save_path, tokenizer_name_or_path): 10 | print("Initializing model...") 11 | config_data = get_model_config(model_variant) 12 | mamba_config = MambaConfig(**config_data) 13 | model = MambaLMHeadModel(mamba_config) 14 | 15 | print(f"Reading state dict from {load_path}") 16 | state_dict = {"model_state": model.state_dict()} 17 | load_state_dict( 18 | state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True 19 | ) 20 | 21 | print("Loading state dict into the model...") 22 | model.load_state_dict(state_dict["model_state"]) 23 | 24 | print("Saving model to HF-compatible format...") 25 | model.save_pretrained(save_path) 26 | 27 | print("Copying tokenizer...") 28 | from transformers import AutoTokenizer 29 | 30 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) 31 | tokenizer.save_pretrained(save_path) 32 | 33 | print(f"Model saving at {save_path}") 34 | 35 | 36 | if __name__ == "__main__": 37 | fire.Fire(main) 38 | 
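A minimal usage sketch for the converter above, assuming placeholder paths and a model variant registered in fms_fsdp/utils/config_utils.py; it mirrors what the fire CLI does when the script is invoked with the same flags:

from fms_to_hf_mamba import main as convert_mamba_to_hf

# Hypothetical paths; load_path points at a sharded checkpoint directory written by Checkpointer.
convert_mamba_to_hf(
    model_variant="mamba_9.8b",
    load_path="/path/to/ckpt/checkpoints/step_21000_ckp",
    save_path="/path/to/hf_export",
    tokenizer_name_or_path="/path/to/tokenizer",
)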
-------------------------------------------------------------------------------- /images/loss_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/images/loss_curve.png -------------------------------------------------------------------------------- /images/lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/images/lr.png -------------------------------------------------------------------------------- /main_training_llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import fire 5 | import torch 6 | import torch.optim as optim 7 | from fms.models.llama import LLaMA, LLaMABlock 8 | from torch import distributed as dist 9 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 10 | from torch.optim.lr_scheduler import LambdaLR 11 | 12 | from fms_fsdp import config 13 | from fms_fsdp.utils.checkpointing_utils import Checkpointer 14 | from fms_fsdp.utils.config_utils import get_model_config, update_config 15 | from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader 16 | from fms_fsdp.utils.train_utils import ( 17 | get_policies, 18 | get_profiler, 19 | setup, 20 | setup_environ_flags, 21 | train, 22 | ) 23 | 24 | 25 | def main(**kwargs): 26 | # get configs 27 | cfg = config.train_config() 28 | update_config(cfg, **kwargs) 29 | 30 | # ensure reproducibility 31 | torch.cuda.manual_seed(cfg.seed) 32 | torch.manual_seed(cfg.seed) 33 | 34 | # torchrun specific 35 | local_rank = int(os.environ["LOCAL_RANK"]) 36 | rank = int(os.environ["RANK"]) 37 | world_size = int(os.environ["WORLD_SIZE"]) 38 | 39 | if rank == 0: 40 | print(f"--> running with these configs {cfg}") 41 | 42 | # some setups 43 | setup() 44 | torch.cuda.set_device(local_rank) 45 | torch.cuda.empty_cache() 46 | setup_environ_flags() 47 | 48 | # get policy 49 | block = LLaMABlock 50 | ( 51 | mixed_precision_policy, 52 | wrapping_policy, 53 | sharding_strategy_policy, 54 | apply_selective_ac, 55 | param_init_fn, 56 | ) = get_policies(cfg, rank, block) 57 | 58 | # get fms model 59 | llama_config = get_model_config(cfg.model_variant) 60 | if cfg.low_cpu_fsdp: 61 | with torch.device("meta"): 62 | model = LLaMA(llama_config) 63 | else: 64 | model = LLaMA(llama_config) 65 | model.reset_parameters() 66 | 67 | if rank == 0: 68 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 69 | print(f"\n--> model has {total_params / 1e6} Million params\n") 70 | 71 | # get data loader 72 | if rank == 0: 73 | print("Constructing datasets...") 74 | if not cfg.use_dummy_dataset: 75 | train_loader = get_data_loader(cfg, rank, world_size) 76 | else: 77 | train_loader = get_dummy_loader(cfg, rank, world_size) 78 | if rank == 0: 79 | print("Datasets constructed!") 80 | 81 | # FSDP 82 | model = FSDP( 83 | model, 84 | auto_wrap_policy=wrapping_policy, 85 | mixed_precision=mixed_precision_policy, 86 | sharding_strategy=sharding_strategy_policy, 87 | use_orig_params=cfg.use_torch_compile, 88 | device_id=torch.cuda.current_device(), 89 | limit_all_gathers=True, 90 | param_init_fn=param_init_fn, 91 | ) 92 | # we need this post-fsdp call to avoid graph break with torch.compile, until we figure out a better solution. 
93 | model.rot_emb.compute_freqs_cis( 94 | torch.device("cuda", torch.cuda.current_device()), 95 | model.config.max_expected_seq_len, 96 | ) 97 | 98 | # fsdp activation checkpointing 99 | if cfg.fsdp_activation_checkpointing: 100 | if rank == 0: 101 | print(f"--> applying FSDP activation checkpointing...") 102 | apply_selective_ac(model, p=cfg.selective_checkpointing) 103 | 104 | # torch compile 105 | if cfg.use_torch_compile: 106 | if rank == 0: 107 | print(f"--> enabling torch compile...") 108 | # the default accumulated_cache_size_limit=64 is not enough for 70b model, so we make it 128 here 109 | torch._dynamo.config.accumulated_cache_size_limit = 128 110 | model = torch.compile(model) 111 | 112 | # Optimizer 113 | optimizer = optim.AdamW( 114 | model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1 115 | ) 116 | 117 | # optionally load from checkpoint (when continue pretraining) 118 | checkpointer = Checkpointer( 119 | cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank 120 | ) 121 | model, optimizer, _, start_step, tokens_seen, is_resuming = checkpointer.load( 122 | model, 123 | optimizer, 124 | None, 125 | path=os.path.join(cfg.ckpt_load_path, "checkpoints/") 126 | if not os.path.isfile(cfg.ckpt_load_path) 127 | else cfg.ckpt_load_path, 128 | strict=False, 129 | ) 130 | if not is_resuming: 131 | start_step = 0 132 | # Override loaded optim hyperparams with the current values 133 | for g in optimizer.param_groups: 134 | g["initial_lr"] = cfg.learning_rate 135 | 136 | # LR schedule 137 | if cfg.training_stage == "annealing": 138 | schedule = lambda x: 1 - x / cfg.num_steps 139 | else: 140 | warmup_interval = min(2000, cfg.num_steps // 20) 141 | schedule = lambda x: min( 142 | 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, 143 | 0.1 144 | + 0.5 145 | * (1 - 0.1) 146 | * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), 147 | ) 148 | scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) 149 | 150 | # profiler 151 | profiler = get_profiler(cfg, rank) 152 | 153 | # Train 154 | if rank == 0: 155 | print(f"Training for {cfg.num_steps} steps") 156 | train( 157 | cfg, 158 | model, 159 | local_rank, 160 | rank, 161 | train_loader, 162 | optimizer, 163 | scheduler, 164 | profiler, 165 | checkpointer, 166 | start_step, 167 | tokens_seen, 168 | ) 169 | 170 | dist.barrier() 171 | dist.destroy_process_group() 172 | 173 | 174 | if __name__ == "__main__": 175 | fire.Fire(main) 176 | -------------------------------------------------------------------------------- /main_training_mamba.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from pathlib import Path 4 | 5 | import fire 6 | import torch 7 | import torch.optim as optim 8 | from mamba_ssm.models.config_mamba import MambaConfig 9 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 10 | from mamba_ssm.modules.block import Block 11 | from torch import distributed as dist 12 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 13 | from torch.optim.lr_scheduler import LambdaLR 14 | 15 | from fms_fsdp import config 16 | from fms_fsdp.utils.checkpointing_utils import Checkpointer 17 | from fms_fsdp.utils.config_utils import get_model_config, update_config 18 | from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader 19 | from fms_fsdp.utils.train_utils import ( 20 | get_policies, 21 | get_profiler, 22 | setup, 23 | setup_environ_flags, 24 | train, 25 | ) 26 | 
27 | 28 | def main(**kwargs): 29 | # get configs 30 | cfg = config.train_config() 31 | update_config(cfg, **kwargs) 32 | 33 | # ensure reproducibility 34 | torch.cuda.manual_seed(cfg.seed) 35 | torch.manual_seed(cfg.seed) 36 | 37 | # torchrun specific 38 | local_rank = int(os.environ["LOCAL_RANK"]) 39 | rank = int(os.environ["RANK"]) 40 | world_size = int(os.environ["WORLD_SIZE"]) 41 | 42 | if rank == 0: 43 | print(f"--> running with these configs {cfg}") 44 | 45 | # some setups 46 | setup() 47 | torch.cuda.set_device(local_rank) 48 | torch.cuda.empty_cache() 49 | setup_environ_flags() 50 | os.environ["TRITON_CACHE_DIR"] = os.path.join( 51 | Path.home(), ".triton", "cache", str(local_rank) 52 | ) 53 | 54 | # get policy 55 | block = Block 56 | ( 57 | mixed_precision_policy, 58 | wrapping_policy, 59 | sharding_strategy_policy, 60 | apply_selective_ac, 61 | param_init_fn, 62 | ) = get_policies(cfg, rank, block) 63 | 64 | # get model 65 | config_data = get_model_config(cfg.model_variant) 66 | mamba_config = MambaConfig(**config_data) 67 | model = MambaLMHeadModel(mamba_config) 68 | 69 | if rank == 0: 70 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 71 | print(f"\n--> model has {total_params / 1e6} Million params\n") 72 | 73 | # get data loader 74 | if rank == 0: 75 | print("Constructing datasets...") 76 | if not cfg.use_dummy_dataset: 77 | train_loader = get_data_loader(cfg, rank, world_size) 78 | else: 79 | train_loader = get_dummy_loader(cfg, rank, world_size) 80 | if rank == 0: 81 | print("Datasets constructed!") 82 | 83 | # FSDP 84 | model = FSDP( 85 | model, 86 | auto_wrap_policy=wrapping_policy, 87 | mixed_precision=mixed_precision_policy, 88 | sharding_strategy=sharding_strategy_policy, 89 | use_orig_params=cfg.use_torch_compile, 90 | device_id=torch.cuda.current_device(), 91 | limit_all_gathers=True, 92 | param_init_fn=param_init_fn, 93 | ) 94 | 95 | # fsdp activation checkpointing 96 | if cfg.fsdp_activation_checkpointing: 97 | if rank == 0: 98 | print(f"--> applying FSDP activation checkpointing...") 99 | apply_selective_ac(model, p=cfg.selective_checkpointing) 100 | 101 | # torch compile 102 | if cfg.use_torch_compile: 103 | if rank == 0: 104 | print(f"--> enabling torch compile...") 105 | # the default accumulated_cache_size_limit=64 is not enough for 70b model, so we make it 128 here 106 | torch._dynamo.config.accumulated_cache_size_limit = 128 107 | model = torch.compile(model) 108 | 109 | # Optimizer 110 | optimizer = optim.AdamW( 111 | model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1 112 | ) 113 | 114 | # optionally load from checkpoint (when continue pretraining) 115 | checkpointer = Checkpointer( 116 | cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank 117 | ) 118 | model, optimizer, _, start_step, tokens_seen, is_resuming = checkpointer.load( 119 | model, 120 | optimizer, 121 | None, 122 | path=os.path.join(cfg.ckpt_load_path, "checkpoints/") 123 | if not os.path.isfile(cfg.ckpt_load_path) 124 | else cfg.ckpt_load_path, 125 | strict=False, 126 | ) 127 | if not is_resuming: 128 | start_step = 0 129 | # Override loaded optim hyperparams with the current values 130 | for g in optimizer.param_groups: 131 | g["initial_lr"] = cfg.learning_rate 132 | 133 | # LR schedule 134 | # linear decay for annealing 135 | if cfg.training_stage == "annealing": 136 | schedule = lambda x: 1 - x / cfg.num_steps 137 | else: 138 | # cosine decay 139 | warmup_interval = min(2000, cfg.num_steps // 20) 140 | schedule = lambda 
x: min( 141 | 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, 142 | 0.1 143 | + 0.5 144 | * (1 - 0.1) 145 | * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), 146 | ) 147 | 148 | scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) 149 | 150 | # profiler 151 | profiler = get_profiler(cfg, rank) 152 | 153 | # Train 154 | if rank == 0: 155 | print(f"Training for {cfg.num_steps} steps") 156 | train( 157 | cfg, 158 | model, 159 | local_rank, 160 | rank, 161 | train_loader, 162 | optimizer, 163 | scheduler, 164 | profiler, 165 | checkpointer, 166 | start_step, 167 | tokens_seen, 168 | ) 169 | 170 | dist.barrier() 171 | dist.destroy_process_group() 172 | 173 | 174 | if __name__ == "__main__": 175 | fire.Fire(main) 176 | -------------------------------------------------------------------------------- /requirements-speculator.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | fms-extras @ git+https://github.com/foundation-model-stack/fms-extras@main 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.2.0 2 | fire==0.5.0 3 | pyarrow==15.0.0 4 | transformers==4.40.2 5 | ibm-fms>=0.0.3 6 | -------------------------------------------------------------------------------- /scripts/README_SPECULATOR.md: -------------------------------------------------------------------------------- 1 | ### The following parameters are relevant for speculator training: 2 | 3 | - *model_arch*: architecture of the base model (one of: embedllama, embedmixtral, embedgpt_bigcode; FMS implementations extending the base arch to also emit an embedding vector together with the model output. See 'EmbedLLaMA' in train_speculator_utils.py) 4 | 5 | - *model_variant*: identifier with which a specific variant (e.g., 7b) is registered for the model architecture. See 'example model registrations' in train_speculator_utils.py. 6 | 7 | - *model_path*: path to dir containing base model weights 8 | 9 | - *ckpt_save_path*: path to dir for storing intermediate checkpoints during speculator training 10 | 11 | - *ckpt_load_path*: path to dir for loading an intermediate speculator checkpoint to resume training 12 | 13 | - *sharding_strategy*: how to shard the model across the process group: tp / fsdp / hsdp 14 | 15 | - *tp_size*: if loading the base model with tensor parallelism, the no. of GPUs/ranks to split the model across 16 | 17 | - *seq_length*: sequence length of the base model 18 | 19 | - *batch_size*: batch size for stage 1 training for aligning the speculator to base model input behavior 20 | 21 | - *report_interval*: no. of steps after which to report training stats 22 | 23 | - *checkpoint_interval*: no. of steps after which to save an intermediate speculator checkpoint 24 | 25 | - *num_steps*: total no. of speculator training steps (stage 1 + stage 2) 26 | 27 | - *stage2_start_step*: no. of steps after which to switch to stage 2 training 28 | 29 | - *stage2_batch_size*: batch size for stage 2 training for aligning the speculator to base model output behavior 30 | 31 | - *n_speculator_heads*: no. of lookahead tokens to train the speculator for 32 | 33 | - *speculator_width*: embedding dimension of the speculator MLP 34 | 35 | - *use_torch_compile*: whether to compile the base model and speculator; may speed up training.
36 | 37 | - *learning_rate*: learning rate for speculator training 38 | 39 | - *seed*: random seed to use for training dataset shuffling 40 | 41 | - *data_path*: path to dir containing the training dataset. Expects directory to contain subfolders, which in turn contain shard files. 42 | 43 | - *datasets*: a list of subdatasets (e.g., commoncrawl, github, etc.) to draw from. If None, draws from all subfolders of data_path. 44 | 45 | - *weights*: list of weights reflecting the percentage of tokens to be used from each subdataset during training 46 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # On AWS, the EFA and OFI paths enable NCCL to use optimized networking. 4 | export LD_LIBRARY_PATH=/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/opt/aws-ofi-nccl/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda:/usr/local/cuda/targets/x86_64-linux/lib/:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/lib:$LD_LIBRARY_PATH 5 | 6 | export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 7 | 8 | MODEL_ARGS="\ 9 | --use_dummy_dataset=False 10 | --ckpt_load_path=/lustre/pretrain/ckpt 11 | --ckpt_save_path=/lustre/pretrain/ckpt 12 | --data_path=/lustre/bluepile-processing/rel0_7/tokens/llama2/high_quality_rerun_fuzzy_deduped 13 | --fsdp_activation_checkpointing=False 14 | --selective_checkpointing=1 15 | --sharding_strategy=hsdp 16 | --low_cpu_fsdp=False 17 | --batch_size=2 18 | --report_interval=200 19 | --checkpoint_interval=20000 20 | --use_torch_compile=False 21 | --use_profiler=False 22 | " 23 | 24 | torchrun \ 25 | --nnodes=$SLURM_NTASKS \ 26 | --node_rank=$SLURM_NODEID \ 27 | --nproc_per_node=8 \ 28 | --master_addr=`scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1` \ 29 | --master_port="12234" \ 30 | main_training.py \ 31 | ${MODEL_ARGS} 32 | 33 | -------------------------------------------------------------------------------- /scripts/train.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --nodes=16 4 | #SBATCH --gres=gpu:8 5 | #SBATCH --cpus-per-task=64 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --wait-all-nodes=1 8 | #SBATCH --exclusive 9 | ##SBATCH --contiguous 10 | 11 | srun ./scripts/train.sh 12 | -------------------------------------------------------------------------------- /scripts/train_speculator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # On AWS, the EFA and OFI paths enable NCCL to use optimized networking. 
4 | export LD_LIBRARY_PATH=/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/opt/aws-ofi-nccl/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda:/usr/local/cuda/targets/x86_64-linux/lib/:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/lib:$LD_LIBRARY_PATH 5 | 6 | export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 7 | 8 | MODEL_ARGS="\ 9 | --model_path=/path/to/models/meta-llama/Llama-2-7b-hf 10 | --model_arch=embedllama 11 | --model_variant=7b 12 | --ckpt_load_path=/path/to/checkpoints/llama2-7b 13 | --ckpt_save_path=/path/to/checkpoints/llama2-7b 14 | --logical_shards=768 15 | --sharding_strategy=hsdp 16 | --seq_length=4096 17 | --batch_size=8 18 | --report_interval=10 19 | --checkpoint_interval=3000 20 | --num_steps=21000 21 | --stage2_start_step=15000 22 | --stage2_batch_size=96 23 | --n_speculator_heads=3 24 | --speculator_width=4096 25 | --use_torch_compile=False 26 | --learning_rate=1e-3 27 | --seed=42 28 | --data_path=/path/to/dataset/ 29 | --datasets="'dataset=commoncrawl'" 30 | --weights="'1'" 31 | " 32 | 33 | torchrun \ 34 | --nproc_per_node=8 \ 35 | speculator/train_speculator.py \ 36 | ${MODEL_ARGS} 37 | 38 | 39 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | setup( 5 | name="fms_fsdp", 6 | version="0.0.1", 7 | author="Linsong Chu, Davis Wertheimer, Brian Vaughan, Andrea Frittoli, Joshua Rosenkranz, Antoni Viros i Martin, Raghu Kiran Ganti", 8 | author_email="lchu@us.ibm.com", 9 | description="Pretraining scripts using FSDP and IBM Foundation Model Stack", 10 | url="https://github.com/foundation-model-stack/fms-fsdp", 11 | packages=find_packages(), 12 | install_requires=["ibm-fms >= 0.0.3", "torch >= 2.1"], 13 | license="Apache License 2.0", 14 | classifiers=[ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /speculator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/speculator/__init__.py -------------------------------------------------------------------------------- /speculator/train_speculator.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import time 4 | 5 | import fire # type: ignore 6 | import torch 7 | import torch.optim as optim 8 | from fms.models import get_model 9 | from fms.models.llama import LLaMABlock 10 | from fms.utils import generation, tokenizers 11 | from fms_extras.models.speculator import MLPSpeculator # type: ignore 12 | from torch import distributed as dist 13 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 14 | from torch.distributed.fsdp import ShardingStrategy 15 | from torch.optim.lr_scheduler import LambdaLR 16 | 17 | from fms_fsdp import config 18 | from fms_fsdp.utils.checkpointing_utils import Checkpointer 19 | from fms_fsdp.utils.config_utils import update_config 20 | from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader 21 | from fms_fsdp.utils.train_utils import ( 22 | get_mixed_precision_policy, 23 | get_profiler, 24 | setup, 25 | setup_environ_flags, 26 | ) 27 | from speculator.train_speculator_utils import train_speculator 28 | 
29 | 30 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 31 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 32 | 33 | 34 | def test_model(rank, model, arch, cfg, prompt_type="chat"): 35 | if rank == 0: 36 | print("testing model output") 37 | tokenizer = tokenizers.get_tokenizer(cfg.model_path) 38 | if prompt_type == "chat": 39 | template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:" 40 | prompt = template.format( 41 | "Provide a list of instructions for preparing chicken soup." 42 | ) 43 | else: 44 | template = "[INST] Write code to solve the following coding problem that obeys the constraints and passes the example test cases. Please wrap your code answer using ```:\n{}\n[/INST]" 45 | prompt = template.format("Write a bubble sort function in python.") 46 | 47 | tokens = tokenizer.tokenize(prompt) 48 | ids = tokenizer.convert_tokens_to_ids(tokens) 49 | if "llama" in arch: 50 | ids = [tokenizer.bos_token_id] + ids 51 | ids = torch.tensor(ids, dtype=torch.long, device="cuda") 52 | result = generation.generate( 53 | model, 54 | ids, 55 | max_new_tokens=100, 56 | use_cache=True, 57 | do_sample=False, 58 | max_seq_len=8192, 59 | ) 60 | result = generation.truncate_after_eos(result, tokenizer.eos_token_id) 61 | if rank == 0: 62 | print(f"{rank}: quick test of base model") 63 | print( 64 | tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(result)) 65 | ) 66 | 67 | 68 | def get_emb_dim(model): 69 | if hasattr(model.config, "emb_dim"): 70 | emb_dim = model.config.emb_dim 71 | elif hasattr(model.config, "dim"): # Mixtral 72 | emb_dim = model.config.dim 73 | elif hasattr(model.config, "hidden_size"): # HF 74 | emb_dim = model.config.hidden_size 75 | else: 76 | raise Exception("config missing embedding dimension") 77 | return emb_dim 78 | 79 | 80 | def get_vocab_size(model): 81 | if hasattr(model.config, "src_vocab_size"): # FMS 82 | vocab_size = model.config.src_vocab_size 83 | elif hasattr(model.config, "vocab_size"): # HF 84 | vocab_size = model.config.vocab_size 85 | else: 86 | raise Exception("config missing vocab size config") 87 | return vocab_size 88 | 89 | 90 | def get_training_data_loader(rank, cfg, world_size, speculator_mesh): 91 | if rank == 0: 92 | print(f"{time.time()} Constructing datasets...") 93 | if not cfg.use_dummy_dataset: 94 | if cfg.sharding_strategy == "tp": 95 | train_loader = get_data_loader( 96 | cfg, speculator_mesh.get_rank(), speculator_mesh.size(), postprocess=[] 97 | ) 98 | else: 99 | train_loader = get_data_loader(cfg, rank, world_size, postprocess=[]) 100 | else: 101 | train_loader = get_dummy_loader(cfg, rank, world_size) 102 | if rank == 0: 103 | print(f"{time.time()} Datasets constructed!") 104 | return train_loader 105 | 106 | 107 | def main(**kwargs): 108 | # get configs 109 | cfg = config.train_config() 110 | update_config(cfg, **kwargs) 111 | cfg.seq_length = cfg.seq_length + cfg.n_speculator_heads + 1 112 | 113 | # ensure reproducibility 114 | torch.cuda.manual_seed(cfg.seed) 115 | torch.manual_seed(cfg.seed) 116 | 117 | # torchrun specific 118 | local_rank = int(os.environ["LOCAL_RANK"]) 119 | rank = int(os.environ["RANK"]) 120 | world_size = int(os.environ["WORLD_SIZE"]) 121 | 122 | if rank == 0: 123 | print(f"{time.time()} running with these configs {cfg}") 124 | 125 | # some setups 126 | torch.cuda.set_device(local_rank) 127 | 128 | if cfg.sharding_strategy != "tp": 129 | setup() 130 | 
torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) 131 | base_model_mesh = None 132 | speculator_mesh = None 133 | else: 134 | base_model_mesh = dist.device_mesh.init_device_mesh( 135 | "cuda", 136 | (world_size // cfg.tp_size, cfg.tp_size), 137 | mesh_dim_names=("dp", "tp"), 138 | ) 139 | speculator_mesh = dist.device_mesh.init_device_mesh("cuda", (world_size,)) 140 | torch._C._distributed_c10d._register_process_group( 141 | "default", base_model_mesh["tp"].get_group() 142 | ) 143 | 144 | torch.cuda.empty_cache() 145 | setup_environ_flags() 146 | torch.set_default_dtype(torch.bfloat16) 147 | 148 | mixed_precision_policy = get_mixed_precision_policy(cfg, rank) 149 | 150 | model = get_model( 151 | cfg.model_arch, 152 | cfg.model_variant, 153 | model_path=cfg.model_path, 154 | device_type="cuda", 155 | source="hf", 156 | distributed_strategy=cfg.sharding_strategy, 157 | group=( 158 | base_model_mesh["tp"].get_group() if cfg.sharding_strategy == "tp" else None 159 | ), 160 | ) 161 | 162 | if rank == 0: 163 | print(f"{time.time()}") 164 | print(model.config) 165 | print(model) 166 | 167 | model.eval() 168 | with torch.no_grad(): 169 | test_model(rank, model, cfg.model_arch, cfg) 170 | 171 | emb_dim = get_emb_dim(model) 172 | vocab_size = get_vocab_size(model) 173 | 174 | # get speculator 175 | if rank == 0: 176 | print(f"{time.time()} Loading speculator") 177 | speculator = MLPSpeculator( 178 | emb_dim, 179 | cfg.speculator_width, 180 | vocab_size, 181 | cfg.n_speculator_heads, 182 | tie_weights=cfg.speculator_tie_weights, 183 | scale_input=cfg.speculator_scale_input, 184 | ) 185 | speculator.reset_parameters() 186 | 187 | if rank == 0: 188 | total_params = sum( 189 | p.numel() for p in speculator.parameters() if p.requires_grad 190 | ) 191 | print(f"\n{time.time()} speculator has {total_params / 1e6} Million params\n") 192 | 193 | # get data loader 194 | train_loader = get_training_data_loader(rank, cfg, world_size, speculator_mesh) 195 | 196 | # FSDP 197 | speculator = FSDP( 198 | speculator, 199 | auto_wrap_policy=None, 200 | mixed_precision=mixed_precision_policy, 201 | sharding_strategy=ShardingStrategy.NO_SHARD, 202 | use_orig_params=cfg.use_torch_compile, 203 | device_id=torch.cuda.current_device(), 204 | limit_all_gathers=True, 205 | sync_module_states=cfg.low_cpu_fsdp, 206 | param_init_fn=lambda module: ( 207 | module.to_empty(device=torch.device("cuda"), recurse=False) 208 | if cfg.low_cpu_fsdp 209 | else None 210 | ), 211 | device_mesh=speculator_mesh if cfg.sharding_strategy == "tp" else None, 212 | ) 213 | 214 | # torch compile 215 | if cfg.use_torch_compile: 216 | if rank == 0: 217 | print(f"enabling torch compile...") 218 | if cfg.fsdp_activation_checkpointing: 219 | raise ValueError( 220 | "Compile does not yet work well with llama+ac, please" 221 | "either use it without activation checkpointing, or disable" 222 | "compile." 
223 | ) 224 | # we need this post-fsdp call to avoid graph break with torch.compile, 225 | if cfg.sharding_strategy != "tp" and hasattr(model, "rot_emb"): 226 | model.rot_emb.compute_freqs_cis( 227 | torch.device("cuda", torch.cuda.current_device()), 228 | model.config.max_expected_seq_len + 10, 229 | ) 230 | model = torch.compile(model) 231 | speculator = torch.compile(speculator) 232 | 233 | # Optimizer 234 | optimizer = optim.AdamW( 235 | speculator.parameters(), 236 | lr=cfg.learning_rate, 237 | betas=(0.9, 0.95), 238 | weight_decay=0.1, 239 | ) 240 | 241 | # optionally load from checkpoint (when continue pretraining) 242 | if cfg.sharding_strategy == "tp": 243 | checkpointer = Checkpointer( 244 | cfg.ckpt_save_path, 245 | 1000, 246 | "ddp", 247 | speculator_mesh.get_rank(), 248 | speculator_mesh.get_local_rank(), 249 | model_auto_placement=True, 250 | ) 251 | else: 252 | checkpointer = Checkpointer(cfg.ckpt_save_path, 1000, "ddp", rank, local_rank) 253 | speculator, optimizer, train_loader, start_step, tokens_seen, _ = checkpointer.load( 254 | speculator, 255 | optimizer, 256 | train_loader, 257 | path=os.path.join(cfg.ckpt_load_path, "checkpoints/"), 258 | is_compiled=cfg.use_torch_compile, 259 | ) 260 | 261 | # LR schedule 262 | # These functions map step count to LR scaling factor in [0,1]. 263 | # Stage 1: warm up over first 2k or 5% of steps, whichever is smaller. 264 | # Then cosine anneal to 10% of max LR. 265 | warmup_interval1 = min(2000, cfg.stage2_start_step // 20) 266 | stage1_schedule = lambda x: min( 267 | # Parabolic warmup 268 | 1 - (1 - min(x, warmup_interval1) / warmup_interval1) ** 2, 269 | # Final .1 scaling factor 270 | 0.1 271 | # Cosine anneal from 1 to .1 over stage2_start_step steps 272 | + 0.5 * (1 - 0.1) * (1 + math.cos(x / cfg.stage2_start_step * math.pi)), 273 | ) 274 | # Stage 2: warm up over first 2k or 5% of steps, whichever is smaller. 275 | # Then cosine anneal to 10% of stage 1's final LR. 276 | warmup_interval2 = min(2000, (cfg.num_steps - cfg.stage2_start_step) // 20) 277 | stage2_schedule = lambda x: min( 278 | # Parabolic warmup to stage2's max LR (10% of stage1's max LR) 279 | 0.1 * (1 - (1 - min(x, warmup_interval2) / warmup_interval2) ** 2), 280 | # Final 10% of 10% scaling factor 281 | 0.01 282 | # Cosine anneal from .1 to .01 over remaining stage2 steps 283 | + 0.05 284 | * (1 - 0.1) 285 | * ( 286 | 1 287 | + math.cos( 288 | min(x, cfg.num_steps - cfg.stage2_start_step) 289 | / (cfg.num_steps - cfg.stage2_start_step) 290 | * math.pi 291 | ) 292 | ), 293 | ) 294 | # Assemble full scheduling function with correct step offsets. 
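# (Illustrative numbers, assuming a hypothetical run with stage2_start_step=40000 and num_steps=50000:
# stage 1 warms up over min(2000, 40000 // 20) = 2000 steps and cosine-decays to 0.1x of the max LR
# by step 40000; stage 2 then re-warms to that 0.1x level over min(2000, 10000 // 20) = 500 steps
# and decays to 0.01x by step 50000.)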
295 | schedule = lambda x: ( 296 | stage1_schedule(x) 297 | if x <= cfg.stage2_start_step 298 | else stage2_schedule(x - cfg.stage2_start_step) 299 | ) 300 | scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) 301 | 302 | # profiler 303 | profiler = get_profiler(cfg, rank) 304 | 305 | # Train 306 | if rank == 0: 307 | print(f"{time.time()} Training for {cfg.num_steps} steps") 308 | torch.cuda.empty_cache() 309 | train_speculator( 310 | cfg, 311 | model, 312 | speculator, 313 | local_rank, 314 | rank, 315 | train_loader, 316 | optimizer, 317 | scheduler, 318 | checkpointer, 319 | start_step, 320 | tokens_seen, 321 | profiler, 322 | base_model_mesh, 323 | ) 324 | 325 | dist.barrier() 326 | dist.destroy_process_group() 327 | 328 | 329 | if __name__ == "__main__": 330 | fire.Fire(main) 331 | -------------------------------------------------------------------------------- /speculator/train_speculator_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from fms.models import register_model 11 | from fms.models.gpt_bigcode import GPTBigCode 12 | from fms.models.gpt_bigcode import _20b_config as _gpt_bigcode_20b_config 13 | from fms.models.gpt_bigcode import _hf_to_fms_names as _gptbigcode_hf_sd_to_fms_sd 14 | from fms.models.llama import LLaMA 15 | from fms.models.llama import _hf_to_fms_names as _llama_hf_sd_to_fms_sd 16 | from fms.models.mixtral import Mixtral, MixtralConfig 17 | from fms.models.mixtral import _hf_to_fms_names as _mixtral_hf_sd_to_fms_sd 18 | from fms.utils import serialization, tokenizers 19 | from fms.utils.generation import _make_cache_contiguous 20 | from torch.nn import CrossEntropyLoss 21 | from torch.utils.data import DataLoader 22 | 23 | from fms_fsdp.config import train_config 24 | from fms_fsdp.utils.checkpointing_utils import Checkpointer 25 | from fms_fsdp.utils.config_utils import get_model_config 26 | 27 | 28 | def generate( 29 | model: Union[Callable, torch.nn.Module], 30 | input_ids: torch.Tensor, 31 | max_seq_len: int = 2048, 32 | max_new_tokens: int = 256, 33 | temperature: float = 1.0, 34 | top_k: int = 10, 35 | do_sample: bool = True, 36 | num_beams: int = 1, 37 | use_cache: bool = False, 38 | contiguous_cache: bool = False, 39 | include_embeds: bool = True, 40 | ): 41 | """ 42 | A straightforward copy of the generate method in fms.utils.generation. 43 | The only change is the include_embeds flag, which when true also returns 44 | the embedding vectors corresponding to the tokens in the output sequence. 
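Returns the generated token ids (with the batch dimension dropped again for unbatched input),
or a (token ids, embeddings) tuple when include_embeds is True.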
45 | """ 46 | batched = False 47 | if num_beams != 1: 48 | raise NotImplementedError("generate() does not yet support beam search") 49 | if type(input_ids) == torch.Tensor: 50 | if input_ids.dim() != 1: 51 | batched = True 52 | else: 53 | raise RuntimeError("generate() requires a tensor of token ids as the prefix") 54 | 55 | if not batched: 56 | input_ids = input_ids.unsqueeze(0) 57 | 58 | embeds = None 59 | result = input_ids 60 | next_input = input_ids 61 | kwargs: MutableMapping[str, Any] = dict() 62 | kwargs["past_key_value_states"] = None 63 | kwargs["use_cache"] = use_cache 64 | kwargs["include_embeds"] = include_embeds 65 | 66 | for _ in range(max_new_tokens): 67 | input_ids = next_input[:, -max_seq_len:] 68 | output = model(input_ids, **kwargs) 69 | if not use_cache and not include_embeds: 70 | logits = output 71 | else: 72 | logits = output[0] 73 | if include_embeds: 74 | z = output[-1] 75 | if use_cache: 76 | past_key_value_states = output[1] 77 | # TODO: this should go away when reduce-overhead issues are fixed, or 78 | # maybe could be moved into model code to be more portable. 79 | if contiguous_cache: 80 | kwargs["past_key_value_states"] = _make_cache_contiguous( 81 | past_key_value_states 82 | ) 83 | else: 84 | kwargs["past_key_value_states"] = past_key_value_states 85 | logits = logits[:, -1, :] 86 | 87 | if do_sample: 88 | # get logits from last value in sequence and scale 89 | logits = logits / temperature 90 | if top_k: 91 | v, _ = torch.topk(logits, top_k) 92 | logits[logits < v[:, [-1]]] = -float("inf") 93 | 94 | probs = F.softmax(logits, dim=-1) 95 | next_val = torch.multinomial(probs, num_samples=1) 96 | else: 97 | next_val = torch.argmax(logits, dim=-1).unsqueeze(0).t() 98 | 99 | result = torch.cat((result, next_val), dim=-1) 100 | 101 | if use_cache: 102 | next_input = next_val 103 | else: 104 | next_input = result 105 | 106 | if include_embeds: 107 | if embeds is None: 108 | embeds = z 109 | else: 110 | embeds = torch.cat((embeds, z), dim=-2) 111 | 112 | if not batched: 113 | result = result[0] 114 | 115 | if include_embeds: 116 | return result, embeds 117 | 118 | return result 119 | 120 | 121 | # Stage 1 training 122 | def stage1_loss( 123 | cfg, model, speculator, base_model_input, input, loss_fn, ddp_stats, base_model_mesh 124 | ): 125 | """ 126 | Perform a forward pass for stage 1 training and calculate the loss. 127 | Given the sequence of embeddings produced in parallel by the base model, 128 | get n+2,n+3,... speculator predictions and compare to ground truth tokens. 129 | ... 130 | Args 131 | ---- 132 | cfg: train_config 133 | Set of training parameters. 134 | model: nn.Module 135 | The frozen base model. Must return output logits AND corresponding embedding vectors. 136 | speculator: nn.Module 137 | The speculator to be trained. Takes as input sequence of embeddings and token indices, 138 | and returns token prediction logits for each head. 139 | input: torch.IntTensor 140 | The ground truth token indices. If using TP, this is per TP rank, 141 | with 'base_model_input' containing all-gathered input across all TP ranks 142 | loss_fn: Callable 143 | Torch loss function comparing logits to indices, e.g. CrossEntropyLoss() 144 | ddp_stats: torch.FloatTensor 145 | Aggregate stat tracking buffer. 146 | Entries are: grad norm, accumulation steps, head 1 loss, head 2 loss, etc.
147 | base_model_mesh: torch.distributed.device_mesh.DeviceMesh 148 | Device layout of the participating process group ranks 149 | ---- 150 | Returns: scalar loss value, updated ddp stats, number of tokens in input 151 | """ 152 | with torch.no_grad(): 153 | _, embeds = model( 154 | base_model_input[:, : -speculator.n_predict - 1], 155 | include_embeds=True, 156 | use_cache=False, 157 | ) 158 | if cfg.sharding_strategy == "tp": 159 | embeds = embeds.chunk(base_model_mesh["tp"].size())[ 160 | base_model_mesh["tp"].get_local_rank() 161 | ] 162 | 163 | preds = speculator(embeds.detach(), input[:, 1:]) 164 | losses = [] 165 | for i in range(preds.size(0)): 166 | targ = input[:, i + 2 : preds.size(2) + i + 2] # b n 167 | loss = loss_fn(preds[i].reshape(-1, preds.size(3)), targ.long().reshape(-1)) 168 | losses.append(loss) 169 | ddp_stats[2 + i] += loss.item() 170 | loss = sum(losses) 171 | return loss, ddp_stats, input.numel() 172 | 173 | 174 | # Stage 2 training: more heavyweight than stage 1; will take longer 175 | def stage2_loss( 176 | cfg, model, speculator, base_model_input, input, loss_fn, ddp_stats, base_model_mesh 177 | ): 178 | """ 179 | Perform a forward pass for stage 2 training and calculate the loss. 180 | Given the sequence of embeddings produced in serial by the base model, 181 | get n+1,n+2,... speculator predictions and compare to base model's generated tokens. 182 | Reshapes input to more entries / shorter sequences, for more efficient generation. 183 | ... 184 | Args 185 | ---- 186 | cfg: train_config 187 | Set of training parameters. Used here for reshaping input batches. 188 | model: nn.Module 189 | The frozen base model. Must return output logits AND corresponding embedding vectors. 190 | speculator: nn.Module 191 | The speculator to be trained. Takes as input sequence of embeddings and token indices, 192 | and returns token prediction logits for each head. 193 | input: torch.IntTensor 194 | The ground truth token indices. If using TP, this is per TP rank, 195 | with 'base_model_input' containing all-gathered input across all TP ranks 196 | loss_fn: Callable 197 | Torch loss function comparing logits to indices, e.g. CrossEntropyLoss() 198 | ddp_stats: torch.FloatTensor 199 | Aggregate stat tracking buffer. 200 | Entries are: grad norm, accumulation steps, head 1 loss, head 2 loss, etc.
201 | base_model_mesh: torch.distributed.device_mesh.DeviceMesh 202 | Device layout of the particiapting process group ranks 203 | ---- 204 | Returns: scalar loss value, updated ddp stats, number of tokens in input 205 | """ 206 | with torch.no_grad(): 207 | grow_factor = cfg.stage2_batch_size // cfg.batch_size 208 | assert ( 209 | cfg.stage2_prompt_length * grow_factor <= cfg.seq_length 210 | ), "Error: batch is too small for specified partition" 211 | base_model_input = base_model_input[ 212 | :, : cfg.stage2_prompt_length * grow_factor 213 | ].reshape(base_model_input.size(0) * grow_factor, cfg.stage2_prompt_length) 214 | targs, embeds = generate( 215 | model, 216 | base_model_input, 217 | cfg.seq_length, 218 | cfg.stage2_seq_length, 219 | do_sample=True, 220 | use_cache=True, 221 | include_embeds=True, 222 | ) 223 | 224 | if cfg.sharding_strategy == "tp": 225 | targs = targs.chunk(base_model_mesh["tp"].size())[ 226 | base_model_mesh["tp"].get_local_rank() 227 | ] 228 | embeds = embeds.chunk(base_model_mesh["tp"].size())[ 229 | base_model_mesh["tp"].get_local_rank() 230 | ] 231 | targs = targs[:, -cfg.stage2_seq_length :] 232 | embeds = embeds[:, -cfg.stage2_seq_length : -speculator.n_predict] 233 | preds = speculator(embeds.detach(), targs[:, :-1].detach()) 234 | 235 | losses = [] 236 | for i in range(preds.size(0)): 237 | targ = targs[:, i + 1 : preds.size(2) + i + 1] # b n 238 | loss = loss_fn(preds[i].reshape(-1, preds.size(3)), targ.long().reshape(-1)) 239 | losses.append(loss) 240 | ddp_stats[2 + i] += loss.item() 241 | loss = sum(losses) 242 | return loss, ddp_stats, targs.numel() 243 | 244 | 245 | # on demand checkpointing: echo 1 > /path/to/model_ckpt_dir/do_ckpt 246 | def do_ckpt(ckpt_save_path, reset=False): 247 | ckpt_cmd_file = ckpt_save_path + "/do_ckpt" 248 | if not os.path.exists(ckpt_cmd_file): 249 | return False 250 | 251 | if reset: 252 | with open(ckpt_cmd_file, "w") as fd: 253 | fd.write("0") 254 | return False 255 | 256 | with open(ckpt_cmd_file) as fd: 257 | if fd.read().strip() == "1": 258 | return True 259 | 260 | return False 261 | 262 | 263 | def train_speculator( 264 | cfg: train_config, 265 | model: nn.Module, 266 | speculator: nn.Module, 267 | local_rank: int, 268 | rank: int, 269 | train_loader: DataLoader, 270 | optimizer: torch.optim.Optimizer, 271 | scheduler: torch.optim.lr_scheduler.LRScheduler, 272 | checkpointer: Checkpointer, 273 | start_step: int = 0, 274 | n_tok: int = 0, 275 | profiler: Optional[Union[torch.profiler.profile, None]] = None, 276 | base_model_mesh=None, 277 | ): 278 | """ 279 | The training loop for speculator training. Handles at a high level: data loading, 280 | forward and backward passes, model updates, stat tracking, reporting, and checkpointing. 281 | ... 282 | Args 283 | ---- 284 | cfg: train_config 285 | The set of training parameters 286 | model: nn.Module 287 | The frozen base model. Must return output logits AND corresponding embedding vectors. 288 | speculator: nn.Module 289 | The speculator to be trained. Takes as input sequence of embeddings and token indices, 290 | and returns token prediction logits for each head. 291 | local_rank: int 292 | The local rank of the current process. Used for stat tracking / aggregation across ranks. 293 | rank: int 294 | The global rank of the current process. Used for reporting. 295 | train_loader: torch.utils.data.DataLoader 296 | The dataloader used for reading in ground truth token sequences. 
Train_loader.dataset must 297 | support save_to_path() for distributed checkpointing via checkpointer. 298 | optimizer: torch.optim.Optimizer 299 | The optimizer associated with the speculator's weights 300 | scheduler: torch.optim.lr_scheduler.LRScheduler 301 | A scheduler for the optimizer's LR. Scheduler.step() is called on every optimizer step. 302 | checkpointer: fms_fsdp.utils.checkpointing_utils.Checkpointer 303 | A checkpointer tied to the save directory. Used for saving distributed checkpoints. 304 | start_step: optional[int] 305 | If resuming from checkpoint, resume step count from this value. 306 | n_tok: optional[int] 307 | If resuming from checkpoint, resume token count from this value. 308 | profiler: optional[torch.profiler.profile] 309 | Optional torch profiler for performance benchmarking. 310 | base_model_mesh: DeviceMesh 311 | Device layout of the particiapting process group ranks 312 | """ 313 | model.eval() 314 | speculator.train() 315 | ddp_stats = torch.zeros(2 + speculator.n_predict).to(local_rank) 316 | 317 | start = time.time() 318 | loop_start = time.time() 319 | loss_fn = CrossEntropyLoss() 320 | elapsed_tokens = 0 321 | for batch_idx, input in enumerate(train_loader, start=start_step + 1): 322 | if batch_idx > cfg.num_steps: 323 | break 324 | 325 | input = input.to(local_rank) 326 | 327 | if cfg.sharding_strategy == "tp": 328 | base_model_input = torch.zeros( 329 | base_model_mesh["tp"].size() * input.size(0), 330 | input.size(1), 331 | dtype=input.dtype, 332 | device=input.device, 333 | ) 334 | dist.all_gather_into_tensor( 335 | base_model_input, input, group=base_model_mesh["tp"].get_group() 336 | ) 337 | else: 338 | base_model_input = input 339 | 340 | optimizer.zero_grad() 341 | 342 | if batch_idx <= cfg.stage2_start_step: 343 | loss, ddp_stats, step_tok = stage1_loss( 344 | cfg, 345 | model, 346 | speculator, 347 | base_model_input, 348 | input, 349 | loss_fn, 350 | ddp_stats, 351 | base_model_mesh, 352 | ) 353 | else: 354 | loss, ddp_stats, step_tok = stage2_loss( 355 | cfg, 356 | model, 357 | speculator, 358 | base_model_input, 359 | input, 360 | loss_fn, 361 | ddp_stats, 362 | base_model_mesh, 363 | ) 364 | 365 | loss.backward() 366 | ddp_stats[0] += speculator.clip_grad_norm_(cfg.grad_clip_thresh).item() 367 | optimizer.step() 368 | scheduler.step() 369 | 370 | ddp_stats[1] += 1 371 | 372 | if profiler: 373 | profiler.step() 374 | 375 | if batch_idx % cfg.report_interval == 0: 376 | dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM) 377 | train_loss = ddp_stats[2:] / ddp_stats[1] 378 | g_norm = ddp_stats[0] / ddp_stats[1] 379 | elapsed_time = time.time() - loop_start 380 | world_size = int(os.environ["WORLD_SIZE"]) 381 | elapsed_tokens += cfg.report_interval * world_size * step_tok 382 | if rank == 0: 383 | print(f"{time.time()}") 384 | print("step:", batch_idx) 385 | print("tokens seen:", n_tok + elapsed_tokens) 386 | for i in range(len(train_loss)): 387 | print(f"loss {i+1}:", train_loss[i].item()) 388 | print("gradient norm:", g_norm.item()) 389 | print( 390 | f"speed for these {cfg.report_interval} steps:", 391 | (time.time() - start) / cfg.report_interval, 392 | ) 393 | print("overall speed:", elapsed_time / (batch_idx - start_step)) 394 | print("LR:", scheduler.get_last_lr()) 395 | print( 396 | "reserved memory:", 397 | torch.cuda.max_memory_reserved(device=torch.cuda.current_device()), 398 | ) 399 | print( 400 | "active memory:", 401 | torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), 402 | ) 403 | print( 404 | "overall token 
per gpu per sec:", 405 | int(elapsed_tokens / world_size / elapsed_time), 406 | ) 407 | print("token per day:", int(elapsed_tokens / elapsed_time * 3600 * 24)) 408 | print() 409 | start = time.time() 410 | ddp_stats.zero_() 411 | torch.cuda.reset_peak_memory_stats(device=torch.cuda.current_device()) 412 | 413 | if ( 414 | batch_idx % cfg.checkpoint_interval == 0 415 | or batch_idx == cfg.num_steps 416 | or do_ckpt(cfg.ckpt_save_path) is True 417 | ): 418 | torch.cuda.empty_cache() 419 | checkpointer.save( 420 | batch_idx, 421 | speculator, 422 | optimizer, 423 | train_loader, 424 | tokens_seen=elapsed_tokens + n_tok, 425 | ) 426 | torch.cuda.empty_cache() 427 | do_ckpt(cfg.ckpt_save_path, reset=True) 428 | 429 | 430 | class EmbedGPTBigCode(GPTBigCode): 431 | # Overrides the forward function of GPTBigCode to allow returning embedding vectors 432 | def forward( 433 | self, 434 | x: torch.Tensor, 435 | mask: Optional[torch.Tensor] = None, 436 | position_ids: Optional[torch.Tensor] = None, 437 | past_key_value_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, 438 | use_cache: bool = False, 439 | only_last_token: bool = False, 440 | attn_algorithm: Optional[str] = None, 441 | include_embeds: bool = False, 442 | ): 443 | output, cache = self.base_model( 444 | x, 445 | mask, 446 | position_ids=position_ids, 447 | past_key_value_states=past_key_value_states, 448 | use_cache=use_cache, 449 | attn_algorithm=attn_algorithm, 450 | ) 451 | 452 | preds = self.head(output) 453 | 454 | out = [preds] 455 | if use_cache: 456 | out.append(cache) 457 | if include_embeds: 458 | out.append(output) 459 | if len(out) == 1: 460 | return out[0] 461 | return out 462 | 463 | 464 | class EmbedLLaMA(LLaMA): 465 | # Overrides the forward function of LLaMA to allow returning embedding vectors 466 | def forward( 467 | self, 468 | x, 469 | mask=None, 470 | position_ids=None, 471 | past_key_value_states=None, 472 | use_cache=False, 473 | only_last_token=False, 474 | attn_algorithm=None, 475 | include_embeds=False, 476 | ): 477 | output, cache = self._helper( 478 | x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm 479 | ) 480 | 481 | if only_last_token: 482 | output = output[:, -1, :] 483 | preds = self.shared(output, reverse=True) 484 | 485 | out = [preds] 486 | if use_cache: 487 | out.append(cache) 488 | if include_embeds: 489 | out.append(output) 490 | if len(out) == 1: 491 | return out[0] 492 | return out 493 | 494 | 495 | class EmbedMixtral(Mixtral): # FMS impl of Mixtral 496 | # Overrides the forward function of Mixtral to allow returning embedding vectors 497 | def forward( 498 | self, 499 | x, 500 | mask=None, 501 | position_ids=None, 502 | past_key_value_states=None, 503 | use_cache=False, 504 | only_last_token=False, 505 | attn_algorithm=None, 506 | include_embeds=False, 507 | ): 508 | output, cache = self.base_model( 509 | x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm 510 | ) 511 | 512 | if only_last_token: 513 | output = output[:, -1, :] 514 | preds = self.head(output) 515 | 516 | out = [preds] 517 | if use_cache: 518 | out.append(cache) 519 | if include_embeds: 520 | out.append(output) 521 | if len(out) == 1: 522 | return out[0] 523 | return out 524 | 525 | 526 | def _gpt_bigcode_factory_factory(config): 527 | def factory(**kwargs): 528 | return EmbedGPTBigCode(config, **kwargs) 529 | 530 | return factory 531 | 532 | 533 | def _llama_factory_factory(config): 534 | def factory(**kwargs): 535 | return EmbedLLaMA(config, **kwargs) 536 | 537 | return 
factory 538 | 539 | 540 | def _mixtral_factory_factory(config): 541 | def factory(**kwargs): 542 | return EmbedMixtral(config, **kwargs) 543 | 544 | return factory 545 | 546 | 547 | # example model registrations 548 | register_model( 549 | "embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config) 550 | ) 551 | serialization.register_adapter_step( 552 | "embedgpt_bigcode", "hf_to_fms", _gptbigcode_hf_sd_to_fms_sd 553 | ) 554 | serialization.register_adapter("embedgpt_bigcode", "hf", ["hf_to_fms"]) 555 | 556 | register_model( 557 | "embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b")) 558 | ) 559 | register_model( 560 | "embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b")) 561 | ) 562 | serialization.register_adapter_step("embedllama", "hf_to_fms", _llama_hf_sd_to_fms_sd) 563 | serialization.register_adapter("embedllama", "hf", ["hf_to_fms"]) 564 | 565 | register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig())) 566 | serialization.register_adapter_step( 567 | "embedmixtral", "hf_to_fms", _mixtral_hf_sd_to_fms_sd 568 | ) 569 | serialization.register_adapter("embedmixtral", "hf", ["hf_to_fms"]) 570 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | # Used to install pinned test dependencies 2 | # Useful for dev/test jobs caches 3 | 4 | -r requirements.txt 5 | 6 | # Test tools 7 | black==24.1.1 8 | mypy==1.8.0 9 | mypy-extensions==1.0.0 10 | pytest==8.1.1 11 | 12 | # Types packages 13 | pyarrow-stubs==10.0.1.7 14 | types-requests==2.31.0.20240125 15 | types-setuptools==69.0.0.20240125 16 | 17 | # Local import libraries that we don't want to put in global requirements.txt 18 | wandb==0.16.4 19 | aim==3.19.1 -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fms.models.llama import LLaMA, LLaMAConfig 3 | 4 | 5 | @pytest.fixture 6 | def narrow_model(request): 7 | return LLaMA( 8 | LLaMAConfig(src_vocab_size=1, emb_dim=1, nheads=1, nlayers=request.param) 9 | ) 10 | 11 | 12 | @pytest.fixture 13 | def narrow_model_factory(request): 14 | class NarrowModelFactory: 15 | def create(self): 16 | return LLaMA( 17 | LLaMAConfig( 18 | src_vocab_size=1, emb_dim=1, nheads=1, nlayers=request.param 19 | ) 20 | ) 21 | 22 | return NarrowModelFactory() 23 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import tempfile 4 | from collections import Counter 5 | from copy import deepcopy 6 | from itertools import chain 7 | 8 | import pyarrow as pa 9 | import torch 10 | 11 | from fms_fsdp.utils.dataset_utils import * 12 | 13 | 14 | # Generates test data in a temp directory, and returns that tempdir object. 
15 | # (file path can be retrieved via tempdir.name) 16 | # Two dataset folders: one has a large shardfile (100x100), other has two small shardfiles (50x50) 17 | def generate_sequential_multidata(): 18 | tmpdir = tempfile.TemporaryDirectory() 19 | schema = pa.schema([pa.field("tokens", pa.uint32())]) 20 | 21 | os.mkdir(os.path.join(tmpdir.name, "dataset_1")) 22 | os.mkdir(os.path.join(tmpdir.name, "dataset_2")) 23 | os.mkdir(os.path.join(tmpdir.name, "dataset_2", "subfolder")) 24 | with pa.ipc.new_file( 25 | os.path.join(tmpdir.name, "dataset_1/fullshard.arrow"), schema 26 | ) as writer: 27 | for i in range(100): 28 | out = list(range(i * 100, i * 100 + 100)) 29 | writer.write(pa.record_batch([out], schema=schema)) 30 | 31 | with pa.ipc.new_file( 32 | os.path.join(tmpdir.name, "dataset_2/quartershard_1.arrow"), schema 33 | ) as writer: 34 | for i in range(50): 35 | out = list(range(i * 50, i * 50 + 50)) 36 | writer.write(pa.record_batch([out], schema=schema)) 37 | 38 | with pa.ipc.new_file( 39 | os.path.join(tmpdir.name, "dataset_2/subfolder/quartershard_2.arrow"), schema 40 | ) as writer: 41 | for i in range(50): 42 | out = list(range(2500 + i * 50, 2500 + i * 50 + 50)) 43 | writer.write(pa.record_batch([out], schema=schema)) 44 | 45 | # Make metadata file 46 | os.mkdir(os.path.join(tmpdir.name, "meta")) 47 | f = open(os.path.join(tmpdir.name, "meta", "combined_counts.csv"), "w") 48 | f.write("dataset/filename,documents,tokens\n") 49 | f.write("/dataset_1/fullshard.arrow,100,10000\n") 50 | f.write("/dataset_2/quartershard_1.arrow,50,2500\n") 51 | f.write("/dataset_2/subfolder/quartershard_2.arrow,50,2500\n") 52 | f.close() 53 | 54 | return tmpdir 55 | 56 | 57 | # Make mock data for re-use. Returns directory path. 58 | tmpdir = generate_sequential_multidata() 59 | 60 | 61 | # REPEATED CHECKS 62 | # Checks take a dataset definition (and any other args), instantiate it, and perform a single unit test 63 | # For X_check see corresponding test_X 64 | 65 | 66 | def count_check(d, ntok, alldoc, allpercent): 67 | # Check that tokens tracked matches tokens seen, and docs tracked matches docs seen 68 | # d is a lambda for a fully-defined dataset (i.e. d() instantiates the dataset) 69 | assert ( 70 | d.tokens_seen == ntok 71 | ), f"Tokens tracked {d.tokens_seen} failed to match target {ntok}" 72 | assert ( 73 | d.docs_seen == alldoc 74 | ), f"Total document count {d.docs_seen} does not match target {alldoc}" 75 | coverage = d.percent_seen 76 | assert ( 77 | abs(coverage - allpercent) < 1e-4 78 | ), f"Percent coverage {coverage} is not within 1e-4 of {allpercent}" 79 | 80 | 81 | def multi_reload_stress_check(d): 82 | # Perform the reload stress test for different numbers of steps before and after checkpoint 83 | # d is a lambda for a fully-defined dataset (i.e. 
d() instantiates the dataset) 84 | 85 | def reload_stress(datasets, datasets2, steps1, steps2): 86 | # Perform the 5-step reload stress test (see test_multi_reload_stress) 87 | 88 | loaders = [iter(d) for d in datasets] 89 | 90 | for i in range(steps1): 91 | [next(l) for l in loaders] 92 | 93 | states = [deepcopy(d.state_dict()) for d in datasets] 94 | 95 | [d.load_state_dict(states) for d in datasets2] 96 | 97 | loaders2 = [iter(d) for d in datasets2] 98 | 99 | for k in range(steps2): 100 | for i in range(3): 101 | out1 = next(loaders[i]) 102 | out2 = next(loaders2[i]) 103 | assert len(out1) == len( 104 | out2 105 | ), f"Dataloader {i} in step {k} has mismatched length: {len(out1)} vs {len(out2)}" 106 | for j in range(len(out1)): 107 | assert ( 108 | out1[j] == out2[j] 109 | ), f"Dataloader {i} in step {k} has mismatched token in position {j}: {out1[j]} vs {out2[j]}" 110 | 111 | steps1 = [0, 1, 10, 100, 1000] 112 | steps2 = [100, 200, 300, 400, 500] 113 | for i in range(len(steps1)): 114 | # Reset between tests (instantiate fresh datasets) 115 | reload_stress(d(), d(), steps1[i], steps2[i]) 116 | 117 | 118 | def single_epoch_check(d, do_countcheck=False): 119 | # For a single loader on dataset_1, check that every doc appears once per epoch 120 | # d is a lambda for a dataset (i.e. d() instantiates the dataset) 121 | dataset = d(datasets=["dataset_1"]) 122 | ins = [] 123 | loader = iter(dataset) 124 | for i in range(100): 125 | out = next(loader) 126 | ins.append(out[0]) 127 | 128 | for i in range(100): 129 | assert ( 130 | i * 100 in ins 131 | ), f"Line starting with {i * 100} failed to appear in generated data" 132 | 133 | if do_countcheck: 134 | # Check state flags tracking correctly 135 | count_check(dataset, 100 * 100, 100, 100) 136 | 137 | 138 | def two_epoch_check(d, do_countcheck=False): 139 | # For a single loader on dataset_1, check that every doc appears twice per two epochs 140 | # d is a lambda for a dataset (i.e. d() instantiates the dataset) 141 | dataset = d(datasets=["dataset_1"]) 142 | ins = [] 143 | loader = iter(dataset) 144 | for i in range(100 * 2): 145 | out = next(loader) 146 | ins.append(out[0]) 147 | 148 | for i in range(100): 149 | key = ins.pop(0) 150 | assert ( 151 | key in ins 152 | ), f"Line starting with {key} failed to appear a second time in generated data" 153 | 154 | if do_countcheck: 155 | # Check state flags tracking correctly 156 | count_check(dataset, 100 * 100 * 2, 200, 200) 157 | 158 | 159 | def chunk_check(d, do_countcheck=False): 160 | # For a single loader on dataset_1, check that every doc chunks properly and that all chunks appear in one epoch 161 | # d is a lambda for a dataset (i.e. d() instantiates the dataset) 162 | dataset = d(datasets=["dataset_1"], max_chunksize=50) 163 | ins = [] 164 | loader = iter(dataset) 165 | for i in range(300): 166 | out = next(loader) 167 | if i % 3 != 2: 168 | assert ( 169 | len(out) == 50 170 | ), f"Line length {len(out)} does not match chunk size 50" 171 | else: 172 | assert ( 173 | out[0] == -1 174 | ), f"Chunk 3 of document {i} is not delimiter token, but {out}" 175 | ins.append(out[0]) 176 | 177 | for i in range(200): 178 | assert ( 179 | i * 50 in ins 180 | ), f"Chunk starting with {i * 50} failed to appear in generated data" 181 | 182 | if do_countcheck: 183 | count_check(dataset, 100 * 100, 100, 100) 184 | 185 | 186 | def two_loader_check(d, do_countcheck=False): 187 | # For two loaders on dataset_1, check that every doc appears once per epoch, collectively 188 | # d is a lambda for a dataset (i.e. 
d() instantiates the dataset) 189 | dataset1 = d(datasets=["dataset_1"], worldsize=2, rank=0) 190 | dataset2 = d(datasets=["dataset_1"], worldsize=2, rank=1) 191 | ins = [] 192 | loader = iter(dataset1) 193 | for i in range(50): 194 | out = next(loader) 195 | ins.append(out[0]) 196 | loader = iter(dataset2) 197 | for i in range(50): 198 | out = next(loader) 199 | ins.append(out[0]) 200 | 201 | for i in range(100): 202 | assert ( 203 | i * 100 in ins 204 | ), f"Line starting with {i * 100} failed to appear in generated data" 205 | 206 | if do_countcheck: 207 | count_check(dataset1, 50 * 100, 50, 100) 208 | count_check(dataset2, 50 * 100, 50, 100) 209 | 210 | 211 | def multi_file_check(d, do_countcheck=False): 212 | # For a single loader on dataset 2, check that every doc appears once per epoch 213 | # d is a lambda for a dataset (i.e. d() instantiates the dataset) 214 | dataset = d(datasets=["dataset_2"]) 215 | ins = [] 216 | loader = iter(dataset) 217 | for i in range(100): 218 | out = next(loader) 219 | ins.append(out[0]) 220 | 221 | for i in range(100): 222 | assert ( 223 | i * 50 in ins 224 | ), f"Line starting with {i * 50} failed to appear in generated data" 225 | 226 | if do_countcheck: 227 | count_check(dataset, 100 * 50, 100, 100) 228 | 229 | 230 | def chunk_weight_check(w1, w2, d, do_countcheck=False): 231 | # For a single loader on combined datasets, with given oversamples, chunksize 50: check that chunks appear the proper number of times 232 | # d is a lambda for a dataset (i.e. d() instantiates the dataset) 233 | dataset = d(datasets=["dataset_1", "dataset_2"], weights=[w1, w2], max_chunksize=50) 234 | ins = [] 235 | loader = iter(dataset) 236 | for i in range(3 * w1 * 100 + 2 * w2 * 100): 237 | out = next(loader) 238 | if len(out) > 1: 239 | ins.append(out[0]) 240 | 241 | check = Counter(ins) 242 | for i in range(200): 243 | if i < 100: 244 | assert ( 245 | check[i * 50] == w1 + w2 246 | ), f"Chunk starting with {i * 50} appeared {check[i*50]} times rather than {w1+w2}" 247 | else: 248 | assert ( 249 | check[i * 50] == w1 250 | ), f"Chunk starting with {i * 50} appeared {check[i*50]} times rather than {w1}" 251 | 252 | 253 | def reload_epoch_check(loader): 254 | # Single shard, two loaders: do exactly 1/3 of an epoch, checkpoint, reload to same number of workers. 255 | # Complete the epoch and verify that no loaded chunks are revisiting old chunks. 256 | datasets = [ 257 | loader( 258 | rank=i, 259 | worldsize=2, 260 | max_chunksize=40, 261 | ) 262 | for i in range(2) 263 | ] # Length 300 264 | loaders = [iter(d) for d in datasets] 265 | 266 | ins = [] 267 | for _ in range(50): 268 | out = next(loaders[0]) 269 | ins.append(out[0]) 270 | for _ in range(50): 271 | out = next(loaders[1]) 272 | ins.append(out[0]) 273 | 274 | states = [d.state_dict() for d in datasets] 275 | 276 | datasets2 = [ 277 | loader( 278 | rank=i, 279 | worldsize=2, 280 | max_chunksize=40, 281 | ) 282 | for i in range(2) 283 | ] # Length 300 284 | [d.load_state_dict(states) for d in datasets2] 285 | loaders2 = [iter(d) for d in datasets2] 286 | 287 | for j in range(100): 288 | for i in range(2): 289 | out = next(loaders2[i]) 290 | assert ( 291 | out[0] not in ins 292 | ), f"Step {j+1}, dataset {i+1}: chunk starting with {out[0]} has already appeared in the epoch" 293 | 294 | 295 | def reload_single_epoch_check(loader): 296 | # Single shard, two loaders: advance 37 steps, checkpoint, reload to same number of workers. 297 | # Run a full epoch and verify that all data appears once and only once. 
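# (With max_chunksize=40, each 100-token doc in dataset_1 plus its delimiter yields 3 chunks, so a
# full epoch is 300 chunks split 150 per rank; checkpointing at step 37 leaves both ranks mid-epoch.)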
298 | datasets = [ 299 | loader( 300 | rank=i, 301 | worldsize=2, 302 | max_chunksize=40, 303 | ) 304 | for i in range(2) 305 | ] # Length 300 306 | loaders = [iter(d) for d in datasets] 307 | 308 | for _ in range(37): 309 | out = next(loaders[0]) 310 | for _ in range(37): 311 | out = next(loaders[1]) 312 | 313 | states = [d.state_dict() for d in datasets] 314 | 315 | datasets2 = [ 316 | loader( 317 | rank=i, 318 | worldsize=2, 319 | max_chunksize=40, 320 | ) 321 | for i in range(2) 322 | ] # Length 300 323 | [d.load_state_dict(states) for d in datasets2] 324 | loaders2 = [iter(d) for d in datasets2] 325 | 326 | ins = [] 327 | for _ in range(150): 328 | out = next(loaders2[0]) 329 | assert out[0] not in ins, (ins, out[0]) 330 | ins.append(out[0]) 331 | for _ in range(150): 332 | out = next(loaders2[1]) 333 | ins.append(out[0]) 334 | 335 | assert len(ins) == len( 336 | set(ins) 337 | ), f"Full epoch output contains {len(ins)} values but only {len(set(ins))} unique" 338 | 339 | 340 | def single_doc_bos_eos_check(loader, do_bos): 341 | # Single shard, single loader: load two chunks, verify that sizes match when BOS is on/off 342 | expected_vals = ( 343 | [ 344 | [99, 3], 345 | [100, 2], 346 | [101, 1], 347 | [102, 102], 348 | [102, 102], 349 | ] 350 | if do_bos 351 | else [ 352 | [99, 2], 353 | [100, 1], 354 | [101, 101], 355 | [101, 101], 356 | [101, 101], 357 | ] 358 | ) 359 | for i, c in enumerate([99, 100, 101, 102, 103]): 360 | dataset = loader( 361 | rank=0, worldsize=1, max_chunksize=c, bos_token=100 if do_bos else None 362 | ) 363 | d = iter(dataset) 364 | for _ in range(10): 365 | c1 = next(d) 366 | c2 = next(d) 367 | assert ( 368 | len(c1) == expected_vals[i][0] 369 | ), f"Expected size {expected_vals[i][0]} in first chunk, got {len(c1)}" 370 | assert ( 371 | len(c2) == expected_vals[i][1] 372 | ), f"Expected size {expected_vals[i][1]} in second chunk, got {len(c2)}" 373 | if c == 99: 374 | assert ( 375 | c1[-1] == c2[0] - 1 376 | ), f"Expected chunk 2 to follow chunk1, got {c1[-1]} and {c2[0]}" 377 | 378 | 379 | def single_epoch_loader_worker_check(d, n_workers=0): 380 | # For dataset_1 partitioned over logical shards / workers / ranks, 381 | # check that every doc appears once per epoch 382 | loaders = [ 383 | torch.utils.data.DataLoader(x, num_workers=n_workers, batch_size=1) for x in d 384 | ] 385 | loaders = [iter(l) for l in loaders] 386 | n_steps = 100 // len(loaders) 387 | ins = [] 388 | for _ in range(n_steps): 389 | for l in loaders: 390 | out = next(l) 391 | ins.append(out[0].item()) 392 | 393 | for i in range(100): 394 | assert ( 395 | i * 100 in ins 396 | ), f"Line starting with {i * 100} failed to appear in generated data: worldsize {len(loaders)}, n_workers {n_workers}" 397 | 398 | 399 | # BASE DATASET TESTS 400 | 401 | 402 | def basic_loader( 403 | rank=0, 404 | worldsize=1, 405 | datasets=["dataset_1"], 406 | max_chunksize=1000, 407 | bos_token=None, 408 | ): 409 | assert len(datasets) == 1, "Basic loader takes only 1 dataset" 410 | return StreamingDocDataset( 411 | os.path.join(tmpdir.name, datasets[0]), 412 | rank, 413 | worldsize, 414 | ArrowHandler(), 415 | -1, 416 | max_chunksize=max_chunksize, 417 | bos_token=bos_token, 418 | ) 419 | 420 | 421 | def basic_sampler( 422 | rank=0, worldsize=1, datasets=["dataset_1"], weights=[1], max_chunksize=1000 423 | ): 424 | return SamplingDataset( 425 | tmpdir.name, 426 | basic_loader(rank, worldsize, datasets[:1], max_chunksize, None), 427 | -1, 428 | datasets, 429 | weights, 430 | ) 431 | 432 | 433 | def basic_scalable( 
434 | rank=0, 435 | worldsize=1, 436 | datasets=["dataset_1"], 437 | max_chunksize=1000, 438 | n_logical_shards=7, 439 | bos_token=None, 440 | ): 441 | assert len(datasets) == 1, "Basic loader takes only 1 dataset" 442 | return ScalableShardDataset( 443 | basic_loader(rank, worldsize, datasets, max_chunksize, bos_token), 444 | -1, 445 | n_logical_shards, 446 | ) 447 | 448 | 449 | def basic_sampler_scalable( 450 | rank=0, 451 | worldsize=1, 452 | datasets=["dataset_1"], 453 | weights=[1], 454 | max_chunksize=1000, 455 | n_logical_shards=7, 456 | ): 457 | return SamplingDataset( 458 | tmpdir.name, 459 | basic_scalable( 460 | rank, worldsize, datasets[:1], max_chunksize, n_logical_shards, None 461 | ), 462 | -1, 463 | datasets, 464 | weights, 465 | ) 466 | 467 | 468 | def test_single_epoch(): 469 | # Single shard, single loader: every line appears once in an epoch 470 | single_epoch_check(basic_loader, True) 471 | single_epoch_check(basic_scalable) 472 | single_epoch_check(basic_sampler) 473 | single_epoch_check(basic_sampler_scalable) 474 | 475 | 476 | def test_two_epoch(): 477 | # Single shard, single loader: every line appears twice in two epochs 478 | two_epoch_check(basic_loader, True) 479 | two_epoch_check(basic_scalable) 480 | two_epoch_check(basic_sampler) 481 | two_epoch_check(basic_sampler_scalable) 482 | 483 | 484 | def test_chunk(): 485 | # Single shard, single loader, two chunks/doc plus a delimiter token: every chunk appears once in an epoch 486 | chunk_check(functools.partial(basic_loader, max_chunksize=50), True) 487 | chunk_check(functools.partial(basic_scalable, max_chunksize=50)) 488 | chunk_check(functools.partial(basic_sampler, max_chunksize=50)) 489 | chunk_check(functools.partial(basic_sampler_scalable, max_chunksize=50)) 490 | 491 | 492 | def test_two_loader(): 493 | # Single shard, two loaders: every line appears once per epoch, collectively 494 | two_loader_check(basic_loader, True) 495 | two_loader_check(functools.partial(basic_scalable, n_logical_shards=8)) 496 | two_loader_check(basic_sampler) 497 | two_loader_check(functools.partial(basic_sampler_scalable, n_logical_shards=8)) 498 | 499 | 500 | def test_multi_file(): 501 | # Multiple shard files, single loader: every line appears once in an epoch 502 | multi_file_check(basic_loader, True) 503 | multi_file_check(basic_scalable) 504 | multi_file_check(basic_sampler) 505 | multi_file_check(basic_sampler_scalable) 506 | 507 | 508 | def test_reload_epoch(): 509 | # Single shard, two loaders: check that reloading mid-epoch does not cause data to repeat while finishing the epoch 510 | reload_epoch_check(basic_loader) 511 | reload_epoch_check(functools.partial(basic_scalable, n_logical_shards=8)) 512 | reload_epoch_check(basic_sampler) 513 | reload_epoch_check(functools.partial(basic_sampler_scalable, n_logical_shards=8)) 514 | 515 | 516 | def test_reload_complete_epoch(): 517 | # Single shard, two loaders: check that reloading mid-epoch can still complete a full epoch 518 | reload_single_epoch_check(basic_loader) 519 | reload_single_epoch_check(functools.partial(basic_scalable, n_logical_shards=8)) 520 | reload_single_epoch_check(basic_sampler) 521 | reload_single_epoch_check( 522 | functools.partial(basic_sampler_scalable, n_logical_shards=8) 523 | ) 524 | 525 | 526 | def test_eos_bos_chunking(): 527 | # Single shard, single loader: check that enabling/disabling bos tokens maintains correct chunking behavior 528 | single_doc_bos_eos_check(basic_loader, False) 529 | single_doc_bos_eos_check(basic_loader, True) 530 | 
single_doc_bos_eos_check(basic_scalable, False) 531 | single_doc_bos_eos_check(basic_scalable, True) 532 | 533 | 534 | # SUBDATASET WEIGHTING CHECKS 535 | 536 | 537 | def test_sampler_rates(): 538 | """ 539 | A test for SamplingDataset with Streaming_ and Scalable_ subdatasets. 540 | On the full dataset, with varying weights, on a single worker: verify that loaders pull subdatasets at regular intervals 541 | (verifying that they're regularly picking the most-underviewed subdataset at each step). 542 | """ 543 | weights = [[1, 1], [2, 1], [2, 3], [2, 5]] 544 | target_rate = [3, 2, 4, 6] 545 | burnin = [3, 0, 4, 6] 546 | 547 | # Dataset1 docs are twice the length of dataset2. Burnin required to reach equilibrium. 548 | # Expected sequences for each case are: 549 | # 2 1 2 1 2 2 (1 2 2)... 550 | # 1 2 1 2 (1 2)... 551 | # 2 1 2 2 1 2 2 2 (1 2 2 2)... 552 | # 2 1 2 2 2 2 1 2 2 2 2 2 (1 2 2 2 2 2)... 553 | 554 | def check_rates(w, t, b, m): 555 | s = [] 556 | d = m(datasets=["dataset_1", "dataset_2"], weights=w) 557 | l = iter(d) 558 | for i in range(b): 559 | s.append(len(next(l))) 560 | for i in range(100): 561 | out = next(l) 562 | s.append(len(out)) 563 | if i % t == 0: 564 | assert ( 565 | len(out) == 101 566 | ), f"Output {i} length {len(out)} does not match expected 101. Sequence so far: {s}" 567 | else: 568 | assert ( 569 | len(out) == 51 570 | ), f"Output {i} length {len(out)} does not match expected 51. Sequence so far: {s}" 571 | 572 | for i in range(3): 573 | for m in [basic_sampler, basic_sampler_scalable]: 574 | check_rates(weights[i], target_rate[i], burnin[i], m) 575 | 576 | 577 | # STRESS TEST 578 | 579 | 580 | def test_multi_reload_stress(): 581 | """ 582 | For each nontrivial layer of the dataset pipeline: 583 | For each combo of steps: 584 | Initialize two identical datasets 585 | Take n steps with the first one 586 | Save checkpoint 587 | Load checkpoint into second dataset 588 | Take k steps with both datasets, check that outputs are identical 589 | Parameters are chosen to ensure messy states (non-divisible chunk sizes, shard numbers, n_workers, buffer_length, etc.) 
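Pipelines tested below range from a bare StreamingDocDataset up to the fully nested
PreloadBufferDataset / BufferDataset / SamplingDataset / ScalableShardDataset / StreamingDocDataset stack.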
590 | """ 591 | # Shard doc dataset 592 | d1 = lambda: [ 593 | StreamingDocDataset( 594 | os.path.join(tmpdir.name, "dataset_2"), 595 | i, 596 | 3, 597 | ArrowHandler(), 598 | -1, 599 | max_chunksize=17, 600 | ) 601 | for i in range(3) 602 | ] 603 | multi_reload_stress_check(d1) 604 | 605 | # Scalable shard dataset 606 | d2 = lambda x: [ScalableShardDataset(d, -1, n_logical_shards=15) for d in x] 607 | multi_reload_stress_check(lambda: d2(d1())) 608 | 609 | # Sampling dataset 610 | d3 = lambda x: [ 611 | SamplingDataset( 612 | tmpdir.name, 613 | d, 614 | -1, 615 | datasets=["dataset_1", "dataset_2"], 616 | weights=[3, 5], 617 | ) 618 | for d in x 619 | ] 620 | multi_reload_stress_check(lambda: d3(d1())) 621 | 622 | # Nested scalable sampling dataset 623 | d4 = lambda: d3(d2(d1())) 624 | multi_reload_stress_check(d4) 625 | 626 | # Add buffer dataset 627 | d5 = lambda x: [BufferDataset(d, 73, pack_hard=True, bos_token=-1) for d in x] 628 | multi_reload_stress_check(lambda: d5(d4())) 629 | 630 | # Add preload buffer dataset 631 | d6 = lambda x: [PreloadBufferDataset(d, 99) for d in x] 632 | # preload / sample / scale / doc pipeline 633 | multi_reload_stress_check(lambda: d6(d5(d4()))) 634 | 635 | 636 | # SCALABLEDATASET TESTS 637 | 638 | 639 | def test_scalable_partitioning(): 640 | """ 641 | Test that partitioning occurs correctly when rescaling up or down, including to non-multiples of the original 642 | physical worker count. Start with 4 workers with 12 logical shards, and for each of [1,2,3,6,12], verify that: 643 | 1) no overlap exists between workers and 2) in over one epoch's worth of steps, each data point appears at least once 644 | """ 645 | l1 = lambda r, w: basic_scalable(r, w, max_chunksize=200, n_logical_shards=12) 646 | l2 = lambda r, w: basic_sampler_scalable( 647 | r, w, max_chunksize=200, n_logical_shards=12 648 | ) 649 | for layer in [l1, l2]: 650 | datasets = [layer(i, 4) for i in range(4)] # 25 steps per epoch 651 | loaders = [iter(d) for d in datasets] 652 | 653 | for _ in range(50): 654 | [next(l) for l in loaders] 655 | 656 | states = [d.state_dict() for d in datasets] 657 | 658 | for worldsize in [1, 2, 3, 6, 12]: 659 | datasets = [layer(i, worldsize) for i in range(worldsize)] 660 | [d.load_state_dict(states) for d in datasets] 661 | loaders = [iter(d) for d in datasets] 662 | outs = [[] for _ in datasets] 663 | steps = int(100 / worldsize * 1.25) 664 | for i in range(steps): 665 | for j, l in enumerate(loaders): 666 | outs[j].append(next(l)[0]) 667 | 668 | # Check for non-overlap 669 | for i in range(len(datasets)): 670 | for j in range(i + 1, len(datasets)): 671 | outi = set(outs[i]) 672 | outj = set(outs[j]) 673 | for t in outi: 674 | assert ( 675 | t not in outj 676 | ), f"Overlapping value {t} detected in worker {i} and {j}: {outi}, {outj}" 677 | for t in outj: 678 | assert ( 679 | t not in outi 680 | ), f"Overlapping value {t} detected in worker {i} and {j}: {outi}, {outj}" 681 | 682 | # Check for completion 683 | allout = set(chain(*outs)) 684 | for i in range(100): 685 | assert i * 100 in allout, f"Token {i*100} missing from outputs {allout}" 686 | 687 | 688 | def test_scalable_shard_reload_scale(): 689 | """ 690 | As test_reload_epoch, but in this case we scale from 2 workers to 4 (complete 1/3 epoch, reload, finish without duplication). 691 | Because logical shards won't all be the exact same length when checkpointed, we complete the epoch of the shortest of the new workers. 
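With 8 logical shards, each of the 4 new workers inherits 2 shards, so n_docs_remaining can differ
slightly between workers after the reload; hence the shortened check below.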
692 | """ 693 | datasets = [ 694 | basic_scalable(i, 2, max_chunksize=40, n_logical_shards=8) for i in range(2) 695 | ] # Length 300 696 | loaders = [iter(d) for d in datasets] 697 | 698 | ins = [] 699 | for _ in range(50): 700 | out = next(loaders[0]) 701 | ins.append(out[0]) 702 | for _ in range(50): 703 | out = next(loaders[1]) 704 | ins.append(out[0]) 705 | 706 | states = [d.state_dict() for d in datasets] 707 | 708 | datasets2 = [ 709 | basic_scalable(i, 4, max_chunksize=40, n_logical_shards=8) for i in range(4) 710 | ] # Length 300 711 | [d.load_state_dict(states) for d in datasets2] 712 | ndocs = [sum(d.n_docs_remaining) for d in datasets] 713 | print("n_docs_remaining from old loader:", ndocs) 714 | ndocs = [sum(d.n_docs_remaining) for d in datasets2] 715 | print("n_docs_remaining per new loader:", ndocs) 716 | 717 | loaders2 = [iter(d) for d in datasets2] 718 | 719 | print("Checking only", min(ndocs) * 3, "steps instead of full 50") 720 | for j in range(min(ndocs) * 3): 721 | for i in range(4): 722 | out = next(loaders2[i]) 723 | assert ( 724 | out[0] not in ins 725 | ), f"Step {j+1}, dataset {i+1}: chunk starting with {out[0]} has already appeared in the epoch" 726 | 727 | 728 | def test_scalable_sampler_reload_scale(): 729 | """ 730 | As test_reload_epoch, but in this case we scale from 2 workers to 4 (complete 1/3 epoch, reload, finish without duplication). 731 | Because logical shards and sampling ratios won't be exact, take a few extra steps then check that epoch is complete. 732 | """ 733 | datasets = [ 734 | basic_sampler_scalable(i, 2, max_chunksize=40, n_logical_shards=8) 735 | for i in range(2) 736 | ] # Length 300 737 | loaders = [iter(d) for d in datasets] 738 | 739 | ins = [] 740 | for _ in range(50): 741 | out = next(loaders[0]) 742 | ins.append(out[0]) 743 | for _ in range(50): 744 | out = next(loaders[1]) 745 | ins.append(out[0]) 746 | 747 | states = [d.state_dict() for d in datasets] 748 | 749 | datasets2 = [ 750 | basic_sampler_scalable(i, 4, max_chunksize=40, n_logical_shards=8) 751 | for i in range(4) 752 | ] # Length 300 753 | [d.load_state_dict(states) for d in datasets2] 754 | loaders2 = [iter(d) for d in datasets2] 755 | 756 | for i in range(4): 757 | for _ in range(55): 758 | out = next(loaders2[i]) 759 | ins.append(out[0]) 760 | 761 | for suf in [0, 40, 80]: 762 | for i in range(100): 763 | assert ( 764 | i * 100 + suf in ins 765 | ), f"Expected value {i*100+suf} not found in output set {ins}" 766 | 767 | 768 | # BUFFERDATASET TESTS 769 | 770 | 771 | class RandCounter: 772 | # Spit out incremental counts of random length, uniformly sampled from 1 to 50 773 | def __init__(self): 774 | self.i = 0 775 | self.rank = 0 776 | self.worldsize = 1 777 | self.datapath = tmpdir.name 778 | 779 | def __iter__(self): 780 | while True: 781 | l = torch.randint(1, 50, [1]).item() 782 | yield list(range(self.i, self.i + l)) 783 | self.i += l 784 | 785 | 786 | def test_buffer_format(): 787 | # Using the RandCounter, verify that streams are reformed into correct-length buffers, 788 | # that final tokens match the predicted count, and that BOS/EOS add correctly 789 | 790 | for _ in range(100): 791 | # 100 trials of random length inputs 792 | base = RandCounter() 793 | dataset = BufferDataset(base, 100, pack_hard=True) 794 | loader = iter(dataset) 795 | for _ in range(100): 796 | out = next(loader) 797 | assert ( 798 | len(out) == 100 799 | ), f"Length of output {len(out)} does not match specified 100" 800 | assert ( 801 | out[-1] == 100 * 100 - 1 802 | ), f"Final token 
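(The counter emits values in order, so a preload buffer of size 200 drawing roughly uniformly should
surface nearly all of 0-99 within 1000 steps; a strongly biased buffer would keep holding some back.)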
{out[-1]} does not match expected value {100*100-1}" 803 | 804 | # As above, but now with EOS tokens 805 | for _ in range(100): 806 | base = RandCounter() 807 | dataset = BufferDataset(base, 100, pack_hard=True, eos_token=-1) 808 | loader = iter(dataset) 809 | for i in range(100): 810 | out = next(loader) 811 | assert ( 812 | len(out) == 100 813 | ), f"Length of output {len(out)} does not match specified 100" 814 | assert out[-1] == -1, f"Output {out} does not end in EOS" 815 | assert ( 816 | out[-2] == 100 * 99 - 1 817 | ), f"Penultimate token {out[-2]} does not match expected value {100*99-1}" 818 | 819 | # As above, but now with BOS tokens 820 | for _ in range(100): 821 | base = RandCounter() 822 | dataset = BufferDataset(base, 100, pack_hard=True, bos_token=-1) 823 | loader = iter(dataset) 824 | for i in range(100): 825 | out = next(loader) 826 | assert ( 827 | len(out) == 100 828 | ), f"Length of output {len(out)} does not match specified 100" 829 | assert out[0] == -1, f"Output {out} does not begin with BOS" 830 | assert ( 831 | out[-1] == 100 * 99 - 1 832 | ), f"Final token {out[-1]} does not match expected value {100*99-1}" 833 | 834 | 835 | def test_buffer_delimiter_overlap(): 836 | """ 837 | Check that BOS adds correctly when absent, and refrains when present. 838 | Because doc delimiter token is also -1, BOS will add in the first instance, which shunts the delimiter token 839 | into the first slot in the next (and all subsequent) outputs. BOS should then refrain from adding. 840 | """ 841 | dataset = basic_loader(max_chunksize=101) 842 | dataset = BufferDataset(dataset, 101, pack_hard=True, bos_token=-1) 843 | loader = iter(dataset) 844 | for _ in range(100): 845 | out = next(loader) 846 | assert ( 847 | len(out) == 101 848 | ), f"Length of output {len(out)} does not match specified 101" 849 | assert out[0] == -1, f"Output {out} does not begin with BOS" 850 | assert ( 851 | out[-1] % 100 == 99 852 | ), f"Final token {out[-1]} does not end in expected value 99" 853 | 854 | 855 | # PRELOADBUFFERDATASET TESTS 856 | 857 | 858 | class SteadyCounter: 859 | # Spit out incremental counts of constant length l 860 | def __init__(self, l): 861 | self.i = 0 862 | self.rank = 0 863 | self.worldsize = 1 864 | self.datapath = tmpdir.name 865 | self.l = l 866 | 867 | def __iter__(self): 868 | while True: 869 | yield list(range(self.i, self.i + self.l)) 870 | self.i += self.l 871 | 872 | 873 | def test_preload_buffer_uniformity(): 874 | """ 875 | With underlying SteadyCounter and window size 200, take 1000 steps. 876 | Ensure 95% of values between 0 and 100 are emitted. 
877 | """ 878 | dataset = PreloadBufferDataset(SteadyCounter(1), 200) 879 | loader = iter(dataset) 880 | outs = [] 881 | 882 | for _ in range(1000): 883 | x = next(loader)[0] 884 | if x < 100: 885 | outs.append(x) 886 | 887 | assert len(outs) > 95, f"Only {len(outs)} values <100 detected" 888 | 889 | 890 | # CHECKPOINTDATASET TESTS 891 | 892 | 893 | def test_checkpoint_reload_match(): 894 | """ 895 | Check that the auto-checkpointer saves and loads correctly, and that loaded checkpoints 896 | resume properly (matching the continued behavior of the saved ones) 897 | """ 898 | datasets = [ 899 | basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3, 5], max_chunksize=17) 900 | for i in range(3) 901 | ] 902 | datasets = [BufferDataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets] 903 | datasets = [ 904 | CheckpointDataset(x, os.path.join(tmpdir.name, "ckp_test"), 100, 2) 905 | for x in datasets 906 | ] 907 | loaders = [ 908 | torch.utils.data.DataLoader( 909 | x, num_workers=1, batch_size=2, prefetch_factor=1, persistent_workers=True 910 | ) 911 | for x in datasets 912 | ] 913 | loaders = [iter(x) for x in loaders] 914 | for _ in range(100): 915 | for loader in loaders: 916 | next(loader) 917 | 918 | # Assert checkpoint exists and is properly formatted 919 | ckps = os.listdir(os.path.join(tmpdir.name, "ckp_test", "checkpoints")) 920 | assert len(ckps) == 1, f"Expected only one checkpoint (found {len(ckps)})" 921 | ckp_shards = os.listdir( 922 | os.path.join(tmpdir.name, "ckp_test", "checkpoints", ckps[0]) 923 | ) 924 | assert ( 925 | len(ckp_shards) == 3 926 | ), f"Expected three checkpoint shards (found {len(ckp_shards)})" 927 | 928 | # Create a second loader, pointing to first's checkpoint 929 | datasets2 = [ 930 | basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3, 5], max_chunksize=17) 931 | for i in range(3) 932 | ] 933 | datasets2 = [BufferDataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets2] 934 | datasets2 = [ 935 | CheckpointDataset(x, os.path.join(tmpdir.name, "ckp_test"), 1000, 2) 936 | for x in datasets2 937 | ] 938 | [d.setup() for d in datasets2] 939 | 940 | # Assert checkpoints have loaded correctly 941 | for d in datasets2: 942 | assert d.step == 100, f"Expected to load back to step 100, got {d.step}" 943 | 944 | # Continue iterating, verify matching behavior 945 | loaders2 = [ 946 | torch.utils.data.DataLoader( 947 | x, num_workers=1, batch_size=2, prefetch_factor=1, persistent_workers=True 948 | ) 949 | for x in datasets2 950 | ] 951 | loaders2 = [iter(x) for x in loaders2] 952 | for _ in range(300): 953 | for loader, loader2 in zip(loaders, loaders2): 954 | out = sum(next(loader2)) 955 | targ = sum(next(loader)) 956 | assert len(out) == len( 957 | targ 958 | ), f"Expected same output lengths, got {len(out)}, {len(targ)}" 959 | for i, (x, y) in enumerate(zip(out, targ)): 960 | assert x == y, f"Mismatch in position {i}: got {x}, {y}" 961 | 962 | 963 | # MULTIPROCESS DATALOADER WORKER TESTS 964 | 965 | 966 | def test_multiprocess_epoch(): 967 | """ 968 | Check that ScalableShardDataset partitions correctly over various worldsize / n_worker 969 | combinations. A single epoch should contain each datapoint exactly once. 
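Covers n_workers in {0, 2} (in-process iteration vs. multiprocess DataLoader workers) crossed with
worldsizes {2, 5}, over 20 logical shards.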
970 | """ 971 | n_workers = [0, 2] 972 | worldsizes = [2, 5] 973 | for n in n_workers: 974 | for w in worldsizes: 975 | d = [basic_scalable(i, w, n_logical_shards=20) for i in range(w)] 976 | # Add a dummy wrapper (append some pads) to test correct wrapper behavior 977 | d = [BufferDataset(x, 110, False, pad_token=-1) for x in d] 978 | single_epoch_loader_worker_check(d, n) 979 | -------------------------------------------------------------------------------- /tests/test_selective_ac.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import pytest 4 | from fms.models.llama import LLaMABlock 5 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 6 | CheckpointWrapper, 7 | ) 8 | 9 | from fms_fsdp.policies import apply_fsdp_checkpointing 10 | 11 | 12 | @pytest.mark.parametrize("narrow_model_factory", [15], indirect=True) 13 | def test_selective_ac(narrow_model_factory): 14 | apply_ac = partial(apply_fsdp_checkpointing, block=LLaMABlock) 15 | 16 | model = narrow_model_factory.create() 17 | apply_ac(model, p=0) 18 | expected = [False] * 15 19 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 20 | 21 | model = narrow_model_factory.create() 22 | apply_ac(model, p=1 / 100) 23 | expected = [False] * 15 24 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 25 | 26 | model = narrow_model_factory.create() 27 | apply_ac(model, p=1 / 5) 28 | expected = [False, False, True, False, False] * 3 29 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 30 | 31 | model = narrow_model_factory.create() 32 | apply_ac(model, p=1 / 3) 33 | expected = [False, True, False] * 5 34 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 35 | 36 | model = narrow_model_factory.create() 37 | apply_ac(model, p=1 / 2) 38 | expected = [True, False] * 7 + [True] 39 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 40 | 41 | model = narrow_model_factory.create() 42 | apply_ac(model, p=3 / 5) 43 | expected = [True, False, True, False, True] * 3 44 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 45 | 46 | model = narrow_model_factory.create() 47 | apply_ac(model, p=2 / 3) 48 | expected = [True, False, True] * 5 49 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 50 | 51 | model = narrow_model_factory.create() 52 | apply_ac(model, p=1) 53 | expected = [True] * 15 54 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 55 | 56 | model = narrow_model_factory.create() 57 | apply_ac(model, p=5 / 3) 58 | expected = [True] * 15 59 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 60 | 61 | model = narrow_model_factory.create() 62 | apply_ac(model, p=-1) 63 | expected = [False] * 15 64 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected 65 | --------------------------------------------------------------------------------
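For reference, a minimal sketch of how the selective activation-checkpointing policy exercised above could be driven outside the test suite. It is an illustration only: it assumes the apply_fsdp_checkpointing(model, block, p) call signature and the model.layers / CheckpointWrapper conventions used in tests/test_selective_ac.py and tests/conftest.py, and the wrap_fraction value is hypothetical.

from functools import partial

from fms.models.llama import LLaMA, LLaMABlock, LLaMAConfig
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
)

from fms_fsdp.policies import apply_fsdp_checkpointing

# Hypothetical fraction of transformer blocks to checkpoint; per the test above,
# p=1/2 wraps 8 of 15 blocks, trading activation memory for recompute.
wrap_fraction = 1 / 2

# Small debug-sized model, mirroring the narrow_model fixture in conftest.py.
model = LLaMA(LLaMAConfig(src_vocab_size=1, emb_dim=1, nheads=1, nlayers=15))

# Bind the block type the policy should target, then apply selective AC.
apply_ac = partial(apply_fsdp_checkpointing, block=LLaMABlock)
apply_ac(model, p=wrap_fraction)

# Count the wrapped blocks the same way the test does.
n_wrapped = sum(isinstance(block, CheckpointWrapper) for block in model.layers)
print(f"{n_wrapped}/{len(model.layers)} blocks wrapped for activation checkpointing")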