├── .github
│   └── workflows
│       ├── lint.yml
│       ├── mypy.yml
│       └── pytest.yml
├── .gitignore
├── .isort.cfg
├── LICENSE
├── README.md
├── code-of-conduct.md
├── docs
│   ├── configurations.md
│   ├── dataloader.md
│   ├── evaluation.md
│   ├── fine_tuning.md
│   └── train_details.md
├── fms_fsdp
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   └── training.py
│   ├── policies
│   │   ├── __init__.py
│   │   ├── ac_handler.py
│   │   ├── mixed_precision.py
│   │   ├── param_init.py
│   │   └── wrapping.py
│   ├── readme.md
│   ├── requirements.txt
│   └── utils
│       ├── __init__.py
│       ├── checkpointing_utils.py
│       ├── config_utils.py
│       ├── dataloader_utils.py
│       ├── dataset_utils.py
│       └── train_utils.py
├── fms_to_hf_llama.py
├── fms_to_hf_mamba.py
├── images
│   ├── loss_curve.png
│   └── lr.png
├── main_training_llama.py
├── main_training_mamba.py
├── requirements-speculator.txt
├── requirements.txt
├── scripts
│   ├── README_SPECULATOR.md
│   ├── train.sh
│   ├── train.slurm
│   └── train_speculator.sh
├── setup.py
├── speculator
│   ├── __init__.py
│   ├── train_speculator.py
│   └── train_speculator_utils.py
├── test-requirements.txt
└── tests
    ├── conftest.py
    ├── test_datasets.py
    └── test_selective_ac.py
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on: [pull_request]
4 |
5 | permissions:
6 | contents: read
7 |
8 | jobs:
9 | lint:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - uses: psf/black@stable
14 | with:
15 | options: "--check --diff --color"
16 | src: "."
17 | version: "~= 23.3.0"
18 | - uses: isort/isort-action@master
19 | with:
20 | sort-paths: .
21 | requirementsFiles: "requirements.txt" # We don't need extra test requirements for linting
22 |
--------------------------------------------------------------------------------
/.github/workflows/mypy.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies and run MyPy
2 |
3 | name: MyPy Type Checking
4 |
5 | on: [pull_request]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | build:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v4
16 | - name: Set up Python 3.10
17 | id: setup_python
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: "3.10"
21 | - name: Restore Virtualenv
22 | uses: actions/cache/restore@v4
23 | id: cache-venv-restore
24 | with:
25 | path: ./.venv/
26 | key: ${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-venv-${{ hashFiles('*requirements.txt') }}
27 | restore-keys: |
28 | ${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-venv-
29 | - name: Install dependencies
30 | run: |
31 | # Create the virtual environment
32 | python -m venv .venv
33 | . ./.venv/bin/activate
34 |
35 | # Install the dependencies
36 | # In case of a cache hit on the primary key, this will be a no-op
37 | # In case of a cache miss, but hit on a secondary key, this will update what's changed
38 | python -m pip install --upgrade pip
39 | pip install -r test-requirements.txt
40 |
41 | # Enables the virtual env for following steps
42 | echo "$VIRTUAL_ENV/bin" >> $GITHUB_PATH
43 | echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV
44 |
45 | - name: Test with mypy
46 | run: |
47 | # Install ibm-fms from the main branch for testing purposes
48 | # Use -I to ignore the existing install and actually install
49 | # the version on main
50 | pip install -I ibm-fms@git+https://github.com/foundation-model-stack/foundation-model-stack@main
51 |
52 | # No type stubs available for "fire" and "transformers"
53 | mypy --exclude fms_to_hf.py --exclude main_training.py --exclude setup.py .
54 |
55 | - name: Save Virtualenv
56 | id: cache-venv-save
57 | uses: actions/cache/save@v4
58 | with:
59 | path: ./.venv/
60 | key: ${{ steps.cache-venv-restore.outputs.cache-primary-key }}
61 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install dependencies and run pytest
2 |
3 | name: Pytest
4 |
5 | on: [pull_request]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | build:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v4
16 | - name: Set up Python 3.10
17 | id: setup_python
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: "3.10"
21 | - name: Restore Virtualenv
22 | uses: actions/cache/restore@v4
23 | id: cache-venv-restore
24 | with:
25 | path: ./.venv/
26 | key: ${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-venv-${{ hashFiles('*requirements.txt') }}
27 | restore-keys: |
28 | ${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-venv-
29 | - name: Install dependencies
30 | run: |
31 | # Create the virtual environment
32 | python -m venv .venv
33 | . ./.venv/bin/activate
34 |
35 | # Install the dependencies
36 | # In case of a cache hit on the primary key, this will be a no-op
37 | # In case of a cache miss, but hit on a secondary key, this will update what's changed
38 | python -m pip install --upgrade pip
39 | pip install -r test-requirements.txt
40 |
41 | # Enables the virtual env for following steps
42 | echo "$VIRTUAL_ENV/bin" >> $GITHUB_PATH
43 | echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV
44 |
45 | - name: Test with pytest
46 | run: |
47 | # Install ibm-fms from the main branch for testing purposes
48 | # Use -I to ignore the existing install and actually install
49 | # the version on main
50 | pip install -I ibm-fms@git+https://github.com/foundation-model-stack/foundation-model-stack@main
51 |
52 | # Install fms-fsdp project
53 | pip install -e .
54 |
55 | # No type stubs available for "fire" and "transformers"
56 | pytest tests/
57 |
58 | - name: Save Virtualenv
59 | id: cache-venv-save
60 | uses: actions/cache/save@v4
61 | with:
62 | path: ./.venv/
63 | key: ${{ steps.cache-venv-restore.outputs.cache-primary-key }}
64 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 |
3 | .DS_Store
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/#use-with-ide
114 | .pdm.toml
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
165 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | ensure_newline_before_comments = True
3 | force_grid_wrap = 0
4 | include_trailing_comma = True
5 | lines_after_imports = 2
6 | multi_line_output = 3
7 | use_parentheses = True
8 | profile = black
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FMS FSDP - (Pre)Training FMS with FSDP
2 |
3 | The “fms-fsdp” repo is a companion to the [Foundation Model Stack](https://github.com/foundation-model-stack/foundation-model-stack).
4 | The goal of this repo is to provide a (pre)training example to efficiently train
 5 | FMS models, in particular Llama2, by leveraging native PyTorch features - FSDP for training and the SDPA implementation of Flash Attention v2. While there are many exemplar repositories that can perform pretraining at scale (e.g., [MegatronLM](https://github.com/NVIDIA/Megatron-LM), [DeepSpeed](https://github.com/microsoft/DeepSpeed)), this work reflects what IBM has been doing with the PyTorch community on using FSDP for training and how to do that efficiently. It is not meant to be an end-to-end framework for training models, which would include data preparation (pre) and alignment/tuning of the base model (post).
6 |
 7 | For an end-to-end framework, we recommend [OLMo](https://github.com/allenai/OLMo) from AllenAI, which provides datasets and data preprocessing frameworks, leverages FSDP on AMD GPUs for training, and provides a tuning/alignment framework.
8 |
9 | ## Training throughput benchmarks
10 | **_numbers are updated with `torch.compile`, as our fms models are fully compatible with torch compile_**
11 |
 12 | We benchmark the best possible throughput and the strategies we employ in the table below, sharing the throughput obtained on 128 A100 GPUs as well as 96 H100 GPUs. We use the exact same scripts and configurations for both sets of GPUs.
13 |
 14 | | Model Size | Sharding Strategy | Compile | Activation Checkpointing | Batch Size | Training Throughput (tokens/sec/GPU) A100 80G, 128 GPUs with 400Gbps | Training Throughput (tokens/sec/GPU) H100, 96 GPUs with 800Gbps |
 15 | |------------|-------------------|---------|--------------------------|------------|----------------------------------------------------------------------|------------------------------------------------------------------|
16 | | 7b | HSDP | Y | No AC | 2 | 4550 | 9600 |
17 | | 13b | FSDP | Y | Selective AC | 2 | 2150 | 4850 |
18 | | 34b | FSDP | Y | Selective AC | 2 | 820 | 1830 |
19 | | 70b | FSDP | Y | Selective AC | 2 | 410 | 890 |
20 |
21 | HFU numbers are computed using the [PyTorch FLOP counter](https://github.com/pytorch/pytorch/blob/2240018c03744ee34ea14ad53481db934c37e384/torch/utils/flop_counter.py#L336) and the theoretical bf16 performance of
22 | A100 and H100 GPUs, whereas MFU numbers are computed using the methodology outlined in
23 | [NanoGPT](https://github.com/karpathy/nanoGPT) and the [PaLM](https://arxiv.org/pdf/2204.02311.pdf) paper.
24 |
25 | | Model Size | Compile | Batch size | MFU (A100 80G) | HFU (A100 80G) | MFU (H100 80G) | HFU (H100 80G) |
26 | |------------|---------|------------|----------------|----------------|----------------|----------------|
27 | | 7B | Y | 2 | 0.68 | 0.68 | 0.46 | 0.46 |
28 | | 13B | Y | 2 | 0.61 | 0.69 | 0.43 | 0.46 |
29 | | 34B | Y | 2 | 0.55 | 0.74 | 0.38 | 0.49 |
30 | | 70B | Y | 2 | 0.55 | 0.74 | 0.38 | 0.47 |
31 |
 32 | A few points to note here: on the A100s, for 13B we are not utilizing the hardware as well (only 0.48 MFU) because of the smaller batch size. We can dial up the MFU by turning on activation checkpointing, but the throughput then falls to 1600 tokens/sec/GPU. The gaps are more glaring with the H100s, where the MFU for 7B and 13B falls below 0.40.
33 |
 34 | Another point to note is that for the larger models, we could increase the throughput by a few percentage points by increasing the batch size. However, we have kept the batch size small to allow for scaling to 1024 GPUs without introducing tensor parallelism.
35 |
36 | ## Installation
37 | You need to install the required packages by running the following command.
38 | We recommend running the latest [PyTorch nightlies](https://pytorch.org/) and latest [ibm-fms](https://github.com/foundation-model-stack/foundation-model-stack).
39 | ```bash
40 | pip install -r requirements.txt
41 | ```
42 |
43 | ## Training
44 |
45 | ### Model
 46 | We trained one model, a replica of Llama2 7B, as an exemplar on IBM-curated data. This model was trained to 2.2T tokens with a 4k context length on 128 A100 GPUs for a total of 162k GPU hours, achieving an efficiency of 3700 tokens/sec/GPU (~40B tokens/day), which is roughly 20% faster than the published Llama2 training time. These speedups were made possible by combining multiple techniques - the SDPA Flash Attention v2 implementation, FSDP with overlap between computation and communication, and selective activation checkpointing.
 47 | The resulting model performs well on various metrics as evaluated by [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness), with an MMLU score of 0.5. We share further [scores](docs/evaluation.md) in the model details for completeness.
48 |
49 | ### Dataset
 50 | We use an internally curated dataset for training the model. We use sampling ratios similar to those the Llama1 paper proposed, with minor changes (e.g., no C4 dataset). Since the goal of this repo is to demonstrate the feasibility of training using PyTorch components at scale, we omit the details of the sampling ratios. The overall dataset is roughly 1.5T tokens, and the model has seen all the tokens in the dataset at least once.
51 |
 52 | For this dataset, we designed a dataloader for large-scale workloads; details can be found [here](docs/dataloader.md).
53 |
54 | ### Train Config
55 |
 56 | The steps below assume running with Slurm, but the same approach can easily be adapted
 57 | to other clusters.
58 |
59 | 1. modify Training Config in [scripts/train.sh](scripts/train.sh) (for the full
60 | list of training configs and best practices, refer to [Configuration Doc](docs/configurations.md)).
61 | 2. modify Run Config in [scripts/train.slurm](scripts/train.slurm)
62 |
63 | ### Run
64 | ```bash
65 | sbatch ./scripts/train.slurm
66 | ```
 67 | For other cluster setups, simply use the *torchrun* command inside `train.sh`.
68 |
69 | ### Training Details and Lessons learnt
70 | Details on training stability, loss curve, LR curve, etc., as well as what
71 | we have learnt from this journey can be found in [Training Details](docs/train_details.md).
72 |
73 | ## Post Training
74 |
75 | ### Convert to Hugging Face format
76 |
77 | The model trained with this repo is in FMS format, and you might want to convert it
 78 | to Hugging Face format so that you can load it natively with Hugging Face and leverage the Hugging Face ecosystem:
79 | ```bash
80 | python fms_to_hf.py --model_variant 7b --nocompiled --load_path /path/to/trained/checkpoints --save_path /output/path --tokenizer_name_or_path /path/to/llama/tokenizer
81 | ```
82 | > [!Note]
 83 | > This repo consumes pre-tokenized data and thus does not require a tokenizer. However,
 84 | > a Hugging Face checkpoint requires a paired tokenizer, so you need to pass a tokenizer
 85 | > here so it can be copied over to the save directory. Just download the HF Llama tokenizer
 86 | > and pass its path here.
87 |
88 | ## Fine tuning
89 |
90 | We have performed preliminary fine-tuning on our base model and details can be found [here](docs/fine_tuning.md).
91 |
--------------------------------------------------------------------------------
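For orientation, the configuration flow from `scripts/train.sh` into the training config can be pictured with a small sketch. This is not the repo's `main_training_llama.py`; it is a hypothetical illustration that assumes a `fire`-style CLI (the workflows above hint at `fire`), showing only how `--key=value` overrides could land on the `train_config` dataclass defined in `fms_fsdp/config/training.py`. The repo's own `fms_fsdp/utils/config_utils.py` presumably handles this for real.

```python
# Hypothetical sketch only -- not the repo's entry point. Assumes a fire-style CLI.
import fire

from fms_fsdp.config import train_config


def main(**kwargs):
    cfg = train_config()
    # Apply any --key=value overrides passed on the command line, e.g.
    #   python this_sketch.py --model_variant=7b --sharding_strategy=hsdp --learning_rate=3e-4
    for key, value in kwargs.items():
        if not hasattr(cfg, key):
            raise ValueError(f"Unknown train_config field: {key}")
        setattr(cfg, key, value)
    print(cfg)


if __name__ == "__main__":
    fire.Fire(main)
```

Any field name listed in the configuration doc maps one-to-one onto an attribute of `train_config`, which is why the overrides in `train.sh` can stay plain `--key=value` flags.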
/code-of-conduct.md:
--------------------------------------------------------------------------------
1 | # Foundation Model Stack Community Code of Conduct
2 |
3 | Please refer to [Foundation Model Stack Community Code of Conduct](https://github.com/foundation-model-stack/foundation-model-stack/blob/main/code-of-conduct.md).
4 |
--------------------------------------------------------------------------------
/docs/configurations.md:
--------------------------------------------------------------------------------
1 | # Configurations
2 |
 3 | All configurations in [scripts/train.sh](scripts/train.sh) will be passed into
 4 | the [training config](../fms_fsdp/config/training.py).
5 |
6 | ## Full list of configurations
7 |
8 | ### Model
 9 | - **model_variant**: the Llama variant; one of "7b", "13b", "34b", or "70b".
 10 | - **ckpt_load_path**: the path from which a checkpoint will be loaded for continued training.
 11 | - **ckpt_save_path**: the path to which checkpoints will be saved.
12 |
13 | ### Dataset and Dataloader
 14 | - **use_dummy_dataset**: set this to `True` to use a dummy dataset for quick testing and performance benchmarking.
15 | - **data_path**: Data path.
16 | - **seq_length**: Sequence/context length to build when preparing model input.
17 | - **sep_token**: Separator token in the tokenized dataset.
 18 | - **datasets**: Subfolders under `data_path` that contain the different datasets.
19 | - **weights**: Proportion of each dataset when training.
 20 | - **logical_shards**: Number of logical shards when building the dataloader. This is an advanced setting; we will go into detail in a future update.
21 |
22 | ### FSDP policies
 23 | - **sharding_strategy**: "FSDP" (Fully Sharded) or "HSDP" (Hybrid Sharded). HSDP shards within a node and applies data parallelism across nodes.
 24 | - **mixed_precision**: Whether to use `bf16` mixed precision for training.
 25 | - **fsdp_activation_checkpointing**: Whether to turn on activation checkpointing.
 26 | - **selective_checkpointing**: The fraction of blocks on which to checkpoint activations. The default is 1 (checkpoint every block); experiment with this value to trade off memory against compute.
 27 | - **low_cpu_fsdp**: Whether to load the model in low-CPU mode. This is useful when loading large models like 70b.
28 |
29 | ### Training spec
 30 | - **seed**: Random seed for reproducibility.
 31 | - **batch_size**: Batch size per GPU.
 32 | - **num_steps**: Total number of training steps.
 33 | - **learning_rate**: Learning rate. This is the max learning rate the model warms up to. We default to a `cosine` schedule, which is popular for pretraining workloads.
34 |
35 | ### Profiling and reporting
 36 | - **use_profiler**: Whether to turn on the PyTorch profiler to generate profiling traces.
 37 | - **report_interval**: How often (in steps) to report training metrics.
 38 | - **checkpoint_interval**: How often (in steps) to save a checkpoint.
39 |
40 | ### Compile
41 | - **use_torch_compile**: whether to turn on compile. It is recommended to NOT use compile at this stage due to some known issues with compile-training.
42 |
43 |
44 | ## Deep Dive into FSDP Configs
45 | You can skip this section if you are already familiar with FSDP.
46 |
47 | ### Basics
48 | In case you are new to FSDP, here are some basic references:
 49 | - [FSDP Intro](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/)
 50 | - [FSDP API](https://pytorch.org/docs/stable/fsdp.html)
51 |
52 | ### Tune the configs
 53 | The key to achieving the best performance with FSDP is to fully overlap communication
 54 | with computation so that the GPUs stay busy all the time, as if there were no
 55 | communication cost.
56 |
 57 | **sharding_strategy** is the first thing you want to decide. FSDP will shard your model
 58 | across all devices (GPUs), while HSDP will only shard your model across devices within
 59 | the same node. E.g., if you are training with 128 nodes (i.e. 128 * 8 = 1024 total GPUs),
 60 | FSDP will shard your model across all 1024 GPUs, while HSDP will shard your model across
 61 | the 8 GPUs in each node. Therefore, FSDP saves more memory while HSDP requires
 62 | less communication. For smaller models like 7b and 13b, HSDP is preferred as it
 63 | keeps communication shorter and thus easier to overlap with computation, while
 64 | for larger models like 34b and 70b, FSDP is a necessity as the model is too large
 65 | to fit on only 8 GPUs.
66 |
 67 | **fsdp_activation_checkpointing** controls whether activation checkpointing is enabled.
 68 | Enabling it greatly reduces memory usage but also increases computation time due
 69 | to activation re-computation. For large models you would typically have to set this
 70 | to `True`, as activations consume a large amount of memory and without checkpointing
 71 | you will face OOM. For smaller models, depending on your batch size, you may or may
 72 | not enable it. As a companion, **selective_checkpointing** controls what fraction of
 73 | blocks get activation checkpointing (see `fms_fsdp/policies/ac_handler.py`): the
 74 | larger the value, the more blocks are checkpointed and thus the more memory is
 75 | saved. The default value is 1, meaning checkpoint every block.
76 |
--------------------------------------------------------------------------------
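As a concrete illustration of the **sharding_strategy** discussion above, the sketch below shows how the config strings `"fsdp"` and `"hsdp"` could map onto PyTorch's `ShardingStrategy` enum. `get_sharding_strategy` is a hypothetical helper written for illustration; the repo's own utilities may handle this differently.

```python
# Minimal sketch: mapping the sharding_strategy config string to PyTorch FSDP enums.
# Hypothetical helper for illustration only.
from torch.distributed.fsdp import ShardingStrategy


def get_sharding_strategy(name: str) -> ShardingStrategy:
    if name.lower() == "hsdp":
        # Hybrid sharding: shard within a node, replicate (data parallel) across nodes.
        return ShardingStrategy.HYBRID_SHARD
    if name.lower() == "fsdp":
        # Full sharding: shard parameters, gradients, and optimizer state across all ranks.
        return ShardingStrategy.FULL_SHARD
    raise ValueError(f"Unknown sharding strategy: {name}")
```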
/docs/dataloader.md:
--------------------------------------------------------------------------------
1 | # Data Loader
2 |
 3 | We designed a data loader as part of our pretraining stack that provides shuffling in real time while ensuring there is no drop in GPU utilization. It is scalable to multiple nodes (tested up to 128 nodes), supports streaming, allows rescaling the number of GPUs during a single training run, and allows restarting from a given state.
4 |
5 | ## Details
6 |
7 | The current distributed dataloader is designed to meet two important needs of data scientists running large-scale training workloads: seamless resumption of an interrupted job, and rapid iteration on dataset composition and handling.
8 |
 9 | We address the first by maintaining a checkpointable state that allows the user to restart model training from a checkpoint, mid-epoch, while guaranteeing that each document will be viewed exactly once in any given epoch, up to oversampling (i.e. no revisiting stale data). The user is also free to scale the job up or down to different numbers of GPUs from phase to phase, while still maintaining this guarantee.
10 |
11 | To address the second concern, we enforce a rigid format on our input data (tokenized documents, in arrow shard files, organized into dataset subdirectories, with a single unified metadata file of document counts per shard) but construct the specific dataset combinations and mixes dynamically at runtime, based on user inputs. This is accomplished separately on each worker process, with no communication needed between devices. Each worker then streams through the ordered documents and shard files according to its constructed plan, pulling files from disk or cloud on demand as training proceeds. This allows the user to add or eliminate subdatasets, adjust subdataset sampling rates, change BOS/EOS/SEP tokens, toggle padding or packing on or off, adjust sequence lengths, or swap out the training task, for example, without having to build any new training datasets on disk from run to run (a potentially long and expensive process for Terabyte-scale data).
12 |
13 | Because each worker is streaming documents and files sequentially, shuffling is required, and this is accomplished via an internal buffer which ensures that in expectation, two consecutive lines in the stream will appear 10,000 steps apart (this can be adjusted higher or lower as desired). Finally, the dataloader is implemented as modular extensions of PyTorch Datasets, allowing the user to add or remove custom data pipeline functionality as needed.
14 |
 15 | Further technical details can be found in the `fms_fsdp/utils/dataset_utils.py` file.
16 |
--------------------------------------------------------------------------------
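The shuffle buffer described in the dataloader doc can be illustrated with a small, self-contained sketch. This is not the implementation in `fms_fsdp/utils/dataset_utils.py`; it only shows the idea of a fixed-size buffer that randomizes the order of a sequential stream, where the buffer size (10,000 in the text) controls how far apart neighboring stream items land in expectation.

```python
# Illustrative sketch of a streaming shuffle buffer; not the repo's dataset_utils code.
import random
from typing import Iterable, Iterator


def buffered_shuffle(stream: Iterable, buffer_size: int = 10_000, seed: int = 0) -> Iterator:
    """Yield items from a sequential stream in a randomized order using a fixed-size buffer.

    Each incoming item replaces a random slot in the buffer and the displaced item is
    emitted, so neighboring stream items end up roughly `buffer_size` positions apart
    in expectation.
    """
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Drain whatever is left at the end of the stream.
    rng.shuffle(buffer)
    yield from buffer


if __name__ == "__main__":
    # Toy example: shuffle a small "document" stream with a tiny buffer.
    print(list(buffered_shuffle(range(20), buffer_size=5)))
```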
/docs/evaluation.md:
--------------------------------------------------------------------------------
1 | # Preliminary Evaluation
2 |
3 | We convert the sharded FSDP checkpoint to a Hugging Face checkpoint and run [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
4 | without any changes. The key scores at the 2T checkpoint are summarized below:
5 |
6 | | Evaluation metric | Llama2-7B (baseline) | LlamaT-7B |
7 | |----------------------------|----------------------|---------------|
8 | | MMLU (zero shot) | 0.41 | 0.43 |
9 | | MMLU (5-shot weighted avg) | 0.47 | 0.50 |
10 | | Arc challenge | 0.46 | 0.44 |
11 | | Arc easy | 0.74 | 0.71 |
12 | | Boolq | 0.78 | 0.76 |
13 | | Copa | 0.87 | 0.83 |
14 | | Hellaswag | 0.76 | 0.74 |
15 | | Openbookqa | 0.44 | 0.42 |
16 | | Piqa | 0.79 | 0.79 |
17 | | Sciq | 0.91 | 0.91 |
18 | | Winogrande | 0.69 | 0.67 |
19 | | Truthfulqa | 0.39 | 0.39 |
20 | | GSM8k (8-shot) | 0.13 | 0.11 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/docs/fine_tuning.md:
--------------------------------------------------------------------------------
1 | # Fine Tuning
2 |
 3 | To validate that the stack produces a model that is as
 4 | easy to fine-tune as Llama models, we convert the FSDP checkpoint into an HF checkpoint and fine-tune it
5 | using popular fine-tuning configurations and data mixes.
6 |
7 | Specifically, we follow Allen AI’s [open-instruct](https://github.com/allenai/open-instruct) framework, leveraging the TULU v2 stack as-is
8 | (DeepSpeed, TULU v2 mixture and recommended configuration for Llama 2 models). The tuned model
 9 | scores are presented below, and we note improvements in several tasks. We did not do a hyperparameter
 10 | exploration to find the best parameters for fine-tuning LlamaT. We note that the optimal hyperparameters for
 11 | LlamaT tuning could differ from those for Llama 2, as the two models likely followed different learning
 12 | rate schedules.
13 |
14 | | Evaluation metric | Llama2-7B (baseline) | LlamaT-7B |
15 | |----------------------------|----------------------|---------------|
16 | | MMLU (5-shot weighted avg) | 0.53 | 0.49 |
17 | | Arc challenge | 0.48 | 0.43 |
18 | | Arc easy | 0.73 | 0.67 |
19 | | Boolq | 0.82 | 0.82 |
20 | | Copa | 0.89 | 0.86 |
21 | | Hellaswag | 0.76 | 0.75 |
22 | | Openbookqa | 0.47 | 0.42 |
23 | | Piqa | 0.79 | 0.79 |
24 | | Sciq | 0.93 | 0.91 |
25 | | Winogrande | 0.71 | 0.65 |
26 | | Truthfulqa | 0.45 | 0.46 |
27 | | GSM8k (8-shot) | 0 | 0 |
28 |
--------------------------------------------------------------------------------
/docs/train_details.md:
--------------------------------------------------------------------------------
1 | # Training Details
2 |
 3 | The model training used PyTorch FSDP with no activation recomputation, hybrid sharding with model
 4 | weights and optimizer state sharded within a node and data parallelism across nodes, a per-GPU batch size of
 5 | 2 (effective batch size of 1M tokens/batch), and the AdamW optimizer with beta1 of 0.9, beta2 of 0.95, and weight
 6 | decay of 0.1. The learning rate warmed up to a maximum of 3e-4 and followed a cosine
 7 | schedule down to 3e-5 over 2T tokens. The loss curve tracks that of the Llama2 paper and reaches a lower
 8 | loss than Llama2 7B does, which we believe is a characteristic of the dataset.
9 |
10 | ### Loss Curve
11 | 
12 |
13 | ### Learning Rate
14 | 
15 |
 16 | ## Lessons learned
17 |
18 | ### Stability
19 |
20 | Training was stable with no crashes. We had a few hiccups as outlined below.
21 |
22 | **0-200B tokens**: We observed a slowdown in the iteration time (time taken to execute one training step). We stopped the job (freeing up GPUs for other workloads) to ensure that the data loader was not causing any slowdowns and the checkpointing was performant and accurate. We did not find any issues. By this time, HSDP checkpointing code was available in PyTorch, and we took this opportunity to make the switch to PyTorch checkpointing code.
23 |
 24 | **200B tokens-1.9T**: We did not do any manual intervention in the job and forgot that it was running during the winter break. When we came back in early January, disk space had been exceeded and checkpoints were failing to be written, although the training job continued. The last known checkpoint was at 1.5T tokens.
25 |
 26 | **1.5T-1.7T**: We evaluated the 1.5T checkpoint with lm-evaluation-harness and discovered that the model had been trained with an extra special token between two documents, because the Hugging Face tokenizer introduced a separator token and our dataloader also appended its own document separator. We modified the dataloader to eliminate the extra special token, and continued training with the modified dataloader from the 1.7T-token mark onwards.
27 |
 28 | **1.7T-2T**: The loss initially spiked due to the change in special tokens but recovered quickly within a few billion tokens. The training finished without any other manual intervention!
29 |
30 | ### Speedups
31 |
32 | There are two approaches to speeding up the performance even further. With our recent work on
33 | improving inference speeds, we fused several layers that resulted in reduced inference latencies. We
34 | expect these techniques to benefit training as well.
35 |
 36 | Further, with the release of similar training code by OLMo, the issue that we had raised with PyTorch to
 37 | get compile working for FSDP increased in priority. We are currently engaged with the PyTorch team on enabling
 38 | compile, which can provide a further boost to training speed.
39 |
--------------------------------------------------------------------------------
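The learning-rate schedule described above (warmup to a maximum of 3e-4, then cosine decay to 3e-5 over 2T tokens) can be written out as a short function for clarity. This is an illustrative sketch only, not the scheduler in `fms_fsdp/utils/train_utils.py`; the warmup length and total step count here are hypothetical placeholders.

```python
# Illustrative warmup + cosine schedule; warmup_steps and max_steps are placeholders.
import math


def lr_at_step(step: int, warmup_steps: int = 2000, max_steps: int = 1_000_000,
               max_lr: float = 3e-4, min_lr: float = 3e-5) -> float:
    """Linear warmup to max_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


# e.g. lr_at_step(2000) == 3e-4 (end of warmup), lr_at_step(1_000_000) == 3e-5 (end of decay)
```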
/fms_fsdp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/fms_fsdp/__init__.py
--------------------------------------------------------------------------------
/fms_fsdp/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .training import train_config
2 |
--------------------------------------------------------------------------------
/fms_fsdp/config/training.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Union
3 |
4 |
5 | @dataclass
6 | class train_config:
7 | # model
8 | model_variant: str = "7b"
9 | ckpt_load_path: str = "/fsx/output/ckpt"
10 | ckpt_save_path: str = "/fsx/output/ckpt"
11 |
12 | # dataset and dataloader
13 | use_dummy_dataset: bool = False
14 | data_path: str = "/fsx/data"
15 | file_type: str = "arrow"
16 | col_name: str = "tokens"
17 | tokenizer_path: str = "/fsx/tokenizer"
18 | datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
19 | weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
20 | seq_length: int = 4096
21 | vocab_size: int = 32000
22 | bos_token: Optional[int] = None
23 | eos_token: int = 0
24 | bol_token: Optional[int] = None
25 | eol_token: Optional[int] = None
26 | strip_tokens: str = ""
27 | logical_shards: int = 1024
28 | num_workers: int = 1
29 |
30 | # fsdp policies
31 | sharding_strategy: str = "hsdp"
32 | fsdp_activation_checkpointing: bool = False
33 | selective_checkpointing: Union[float, str] = 1 # percentage of blocks to apply ac
34 | mixed_precision: bool = True
35 | low_cpu_fsdp: bool = False
36 |
37 | # training spec
38 | batch_size: int = 2
39 | num_steps: int = 1000000
40 | training_stage: str = "initial"
41 | learning_rate: float = 3e-4
42 | grad_clip_thresh: float = 1.0
43 | seed: int = 2023
44 |
45 | # continued training spec
46 | resuming_dataset: bool = False
47 |
48 | # profiling
49 | use_profiler: bool = False
50 | profiler_rank0_only: bool = True
51 |
52 | # logging
53 | report_interval: int = 100
54 | checkpoint_interval: int = 10000
55 | tracker: Optional[str] = None # None, "wandb", "aim"
56 | tracker_dir: str = "/fsx/aim_logs/llama"
57 | tracker_project_name: str = "llama" # project name for a group of runs
58 | tracker_run_id: Optional[str] = None # run id, for job resume purpose
59 |
60 | # compile
61 | use_torch_compile: bool = True
62 |
63 | # speculator training
64 | tp_size: int = 8
65 | model_arch: str = "embedllama"
66 | model_path: str = "/path/to/model/"
67 | n_speculator_heads: int = 3
68 | speculator_width: int = 4096
69 | speculator_tie_weights: bool = True
70 | speculator_scale_input: bool = True
71 | stage2_start_step: int = 15000
72 | stage2_prompt_length: int = 64
73 | stage2_batch_size: int = 96
74 | stage2_seq_length: int = 256
75 |
--------------------------------------------------------------------------------
/fms_fsdp/policies/__init__.py:
--------------------------------------------------------------------------------
1 | from .ac_handler import apply_fsdp_checkpointing
2 | from .mixed_precision import *
3 | from .param_init import param_init_function
4 | from .wrapping import get_wrapper
5 |
--------------------------------------------------------------------------------
/fms_fsdp/policies/ac_handler.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
4 | CheckpointImpl,
5 | apply_activation_checkpointing,
6 | checkpoint_wrapper,
7 | )
8 |
9 |
10 | non_reentrant_wrapper = partial(
11 | checkpoint_wrapper,
12 | checkpoint_impl=CheckpointImpl.NO_REENTRANT,
13 | )
14 |
15 |
16 | def apply_fsdp_checkpointing(model, block, p):
17 | """
18 | Apply selective activation checkpointing.
19 |
20 | Selectivity is defined as a percentage p, which means we apply ac
21 | on p of the total blocks. p is a floating number in the range of
22 | [0, 1].
23 |
24 | Some examples:
25 | p = 0: no ac for all blocks. same as `fsdp_activation_checkpointing=False`
26 | p = 1: apply ac on every block. i.e. "full ac".
27 | p = 1/2: [ac, no-ac, ac, no-ac, ...]
28 | p = 1/3: [no-ac, ac, no-ac, no-ac, ac, no-ac, ...]
29 | p = 2/3: [ac, no-ac, ac, ac, no-ac, ac, ...]
30 | Since blocks are homogeneous, we make ac blocks evenly spaced among
31 | all blocks.
32 |
33 | Implementation:
34 | For a given ac ratio p, we should essentially apply ac on every "1/p"
35 | blocks. The first ac block can be as early as the 0th block, or as
 36 |     late as the "1/p"th block, and we pick the middle one: the (0.5/p)th block.
37 | Therefore, we are essentially to apply ac on:
38 | (0.5/p)th block, (1.5/p)th block, (2.5/p)th block, etc., and of course,
39 | with these values rounding to integers.
40 | Since ac is applied recursively, we can simply use the following math
41 | in the code to apply ac on corresponding blocks.
42 | """
43 | block_idx = 0
44 | cut_off = 1 / 2
45 | # when passing p as a fraction number (e.g. 1/3), it will be interpreted
46 | # as a string in argv, thus we need eval("1/3") here for fractions.
47 | p = eval(p) if isinstance(p, str) else p
48 |
49 | def selective_checkpointing(submodule):
50 | nonlocal block_idx
51 | nonlocal cut_off
52 |
53 | if isinstance(submodule, block):
54 | block_idx += 1
55 | if block_idx * p >= cut_off:
56 | cut_off += 1
57 | return True
58 | return False
59 |
60 | apply_activation_checkpointing(
61 | model,
62 | checkpoint_wrapper_fn=non_reentrant_wrapper,
63 | check_fn=selective_checkpointing,
64 | )
65 |
--------------------------------------------------------------------------------
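To visualize which blocks end up checkpointed for a given ratio `p`, the standalone snippet below replays the same predicate logic as `apply_fsdp_checkpointing` above on a toy stack of blocks. It is an illustration only, not an additional utility in the repo.

```python
# Replays the selective-checkpointing predicate from ac_handler.py for illustration.
def ac_pattern(num_blocks: int, p: float) -> list:
    block_idx, cut_off, pattern = 0, 1 / 2, []
    for _ in range(num_blocks):
        block_idx += 1
        if block_idx * p >= cut_off:
            cut_off += 1
            pattern.append("ac")
        else:
            pattern.append("no-ac")
    return pattern


# Matches the docstring examples:
#   p = 1/2 -> ['ac', 'no-ac', 'ac', 'no-ac', 'ac', 'no-ac']
#   p = 1/3 -> ['no-ac', 'ac', 'no-ac', 'no-ac', 'ac', 'no-ac']
print(ac_pattern(6, 1 / 2))
print(ac_pattern(6, 1 / 3))
```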
/fms_fsdp/policies/mixed_precision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.distributed.fsdp import MixedPrecision
3 |
4 |
5 | fpSixteen = MixedPrecision(
6 | param_dtype=torch.float16,
7 | reduce_dtype=torch.float16,
8 | buffer_dtype=torch.float16,
9 | )
10 |
11 | bfSixteen = MixedPrecision(
12 | param_dtype=torch.bfloat16,
13 | reduce_dtype=torch.bfloat16,
14 | buffer_dtype=torch.bfloat16,
15 | )
16 |
17 | bfSixteen_working = MixedPrecision(
18 | param_dtype=torch.float32,
19 | reduce_dtype=torch.bfloat16,
20 | buffer_dtype=torch.bfloat16,
21 | )
22 |
23 | fp32_policy = MixedPrecision(
24 | param_dtype=torch.float32,
25 | reduce_dtype=torch.float32,
26 | buffer_dtype=torch.float32,
27 | )
28 |
--------------------------------------------------------------------------------
/fms_fsdp/policies/param_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fms.modules.attention import MultiHeadAttention
3 | from fms.modules.embedding import WordEmbedding
4 | from fms.modules.feedforward import GatedLinearUnit
5 | from fms.modules.layernorm import LayerNormParameterized
6 |
7 |
8 | # for details, read https://github.com/foundation-model-stack/fms-fsdp/issues/64
9 | def param_init_function(module):
10 | if (
11 | isinstance(module, MultiHeadAttention)
12 | or isinstance(module, WordEmbedding)
13 | or isinstance(module, GatedLinearUnit)
14 | or isinstance(module, LayerNormParameterized)
15 | ):
16 | module.to_empty(device=torch.cuda.current_device())
17 | with torch.no_grad():
18 | module.reset_parameters()
19 |
--------------------------------------------------------------------------------
/fms_fsdp/policies/wrapping.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
4 |
5 |
6 | def get_wrapper(block):
7 | auto_wrap_policy = functools.partial(
8 | transformer_auto_wrap_policy,
9 | transformer_layer_cls={
10 | block,
11 | },
12 | )
13 |
14 | return auto_wrap_policy
15 |
--------------------------------------------------------------------------------
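Putting the pieces of the `policies` package together, a wrapped model might be constructed roughly as in the sketch below: `get_wrapper` supplies the auto-wrap policy, `bfSixteen` the mixed-precision policy, and the sharding strategy comes from the config. This is a hypothetical illustration of how these objects are typically passed to `FullyShardedDataParallel`, not a copy of the repo's training utilities; `wrap_model` and its arguments are made up for the example.

```python
# Illustrative sketch of wiring the policies above into FSDP; not the repo's train_utils.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

from fms_fsdp.policies import get_wrapper
from fms_fsdp.policies.mixed_precision import bfSixteen


def wrap_model(model: nn.Module, block: type, use_hsdp: bool) -> FSDP:
    strategy = ShardingStrategy.HYBRID_SHARD if use_hsdp else ShardingStrategy.FULL_SHARD
    return FSDP(
        model,
        auto_wrap_policy=get_wrapper(block),  # wrap each transformer block
        mixed_precision=bfSixteen,            # bf16 params, reduces, and buffers
        sharding_strategy=strategy,
        use_orig_params=True,                 # commonly needed when combining with torch.compile
    )
```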
/fms_fsdp/readme.md:
--------------------------------------------------------------------------------
1 | # LLAMA branch
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/fms_fsdp/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/fms_fsdp/requirements.txt
--------------------------------------------------------------------------------
/fms_fsdp/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/fms_fsdp/utils/__init__.py
--------------------------------------------------------------------------------
/fms_fsdp/utils/checkpointing_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | from pathlib import Path
5 |
6 | import torch
7 | from torch.distributed._shard.checkpoint import (
8 | FileSystemReader,
9 | FileSystemWriter,
10 | load_state_dict,
11 | save_state_dict,
12 | )
13 | from torch.distributed.checkpoint.default_planner import (
14 | DefaultLoadPlanner,
15 | DefaultSavePlanner,
16 | )
17 | from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
18 | from torch.distributed.fsdp import FullStateDictConfig
19 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
20 | from torch.distributed.fsdp import StateDictType
21 |
22 |
23 | def get_latest(targdir, qualifier=lambda x: True, key=os.path.getctime):
24 | """
25 | Fetch the full path of the latest file or folder written to target directory,
26 | subject to name passing the qualifier fn.
27 | Optional key fn can be used for custom sorting.
28 | Both functions take full path arguments.
29 | If directory is empty or nonexistent or no items qualify, return None.
30 | """
31 | if os.path.exists(targdir) and len(os.listdir(targdir)) > 0:
32 | latest = max(
33 | [
34 | os.path.join(targdir, x)
35 | for x in os.listdir(targdir)
36 | if qualifier(os.path.join(targdir, x))
37 | ],
38 | key=key,
39 | )
40 | return latest
41 | return None
42 |
43 |
44 | def get_oldest(targdir, qualifier=lambda x: True, key=os.path.getctime):
45 | """
46 | Fetch the full path of the oldest file or folder written to target directory,
47 | subject to name passing the qualifier fn.
48 | Optional key fn can be used for custom sorting.
49 | Both functions take full path arguments.
50 | If directory is empty or nonexistent or no items qualify, return None.
51 | """
52 | if os.path.exists(targdir) and len(os.listdir(targdir)) > 0:
53 | oldest = min(
54 | [
55 | os.path.join(targdir, x)
56 | for x in os.listdir(targdir)
57 | if qualifier(os.path.join(targdir, x))
58 | ],
59 | key=key,
60 | )
61 | return oldest
62 | return None
63 |
64 |
65 | class Checkpointer:
66 | """
67 | Manages the checkpoint directory. Saves new checkpoints and deletes old ones after the specified number are written.
68 | Also handles loading and saving of checkpoints in sharded and unsharded formats.
69 | Assumes model and optimizer inputs are in FSDP.
70 | ...
71 | Args
72 | ----
73 | ckpdir : str
74 | Absolute path to desired save location. Creates a new 'checkpoints/' subfolder at that location.
75 | n_to_save : int
76 | Number of volatile checkpoints to maintain at any given time.
77 | parallel_mode : str
 78 |         Write sharded folder checkpoints ('fsdp' or 'hsdp') or unsharded single-file checkpoints ('ddp').
79 | report_fn : Callable or None
80 | Optional function for reporting or logging status updates. Expected to handle arbitrary *args, **kwargs.
81 | Defaults to self._selective_print().
82 | model_auto_placement : bool
83 | Optional; If True, auto detect GPU device to move model to, as set in device mesh init
84 |
85 | Methods
86 | -------
87 | save : keyword args -> str | None
88 | Saves dictionary of keyword arg key/value pairs to specified checkpoint directory, deleting old checkpoints
89 | as necessary. If a checkpoint is deleted, returns the filename of that checkpoint.
90 | load :
91 | See docstring for individual function below
92 | """
93 |
94 | def __init__(
95 | self,
96 | ckpdir,
97 | n_to_save,
98 | parallel_mode,
99 | rank,
100 | local_rank,
101 | report_fn=None,
102 | model_auto_placement=False,
103 | ):
104 | self.max_ckps = n_to_save
105 | self.rank = rank
106 | self.local_rank = local_rank
107 | self.ckp_path = os.path.join(ckpdir, "checkpoints/")
108 | os.makedirs(self.ckp_path, exist_ok=True)
109 | self.p_mode = parallel_mode
110 | assert parallel_mode in ["fsdp", "hsdp", "ddp"]
111 | self.report = self._selective_print if report_fn is None else report_fn
112 | self.model_auto_placement = model_auto_placement
113 |
114 | def _selective_print(self, *args, **kwargs):
115 | if self.rank == 0:
116 | print(*args)
117 | for k, v in kwargs.items():
118 | print(k, "=", v)
119 |
120 | def _cleanup(self):
121 | # Clean old checkpoints. Barrier to keep synchronization correct.
122 | file_to_remove = None
123 | if (
124 | self.rank == 0
125 | and len([x for x in os.listdir(self.ckp_path) if "tmp" in x])
126 | > self.max_ckps
127 | ):
128 | ckp_to_remove = Path(
129 | get_oldest(self.ckp_path, qualifier=lambda x: "tmp" in x)
130 | )
131 | if os.path.isfile(ckp_to_remove):
132 | ckp_to_remove.unlink()
133 | else:
134 | shutil.rmtree(ckp_to_remove)
135 | return file_to_remove
136 |
137 | def _do_save(self, rank, local_rank): # , shard_group, replicate_group):
138 | if self.p_mode == "hsdp":
139 | return rank == local_rank
140 | else:
141 | return True
142 | # TODO: Distributed writing contingent upon the following fix: https://github.com/pytorch/pytorch/issues/104081
143 | # if not is_dist:
144 | # return (rank == local_rank)
145 | # else:
146 | # a = rank % shard_group.size()
147 | # b = rank // shard_group.size()
148 | # return True if a % replicate_group.size() == b else False
149 | # shard_group = model.process_group
150 | # replicate_group = model.__inter_node_state.process_group
151 |
152 | def _write(self, state_dict, loader_state, process_group, save_name, rank):
153 | os.makedirs(save_name, exist_ok=True)
154 | writer = FileSystemWriter(save_name, single_file_per_rank=True)
155 | if state_dict is not None:
156 | save_state_dict(
157 | state_dict=state_dict,
158 | storage_writer=writer,
159 | process_group=process_group,
160 | planner=DefaultSavePlanner(),
161 | )
162 | if loader_state is not None:
163 | loader_state.save_to_path(save_name)
164 |
165 | def _validate_ckp_path(self, path):
166 | """Interpret path to appropriate checkpoint. If found, return modified path. If not found, return None."""
167 | # Does path exist and is it non-empty?
168 | if os.path.exists(path):
169 | # Is this a file?
170 | if os.path.isfile(path):
171 | return path
172 | # Is this a sharded directory?
173 | elif "metadata.pth" in os.listdir(path):
174 | return path
175 | # Is this a path to a set of checkpoints?
176 | elif len(os.listdir(path)) > 0:
177 | latest = get_latest(path)
178 | if os.path.isfile(latest):
179 | return latest
180 | elif "metadata.pth" in os.listdir(latest):
181 | return latest
182 | return None
183 |
184 | def load(
185 | self,
186 | model,
187 | optimizer,
188 | dataloader,
189 | path="",
190 | reset_stepcount=False,
191 | strict=True,
192 | is_compiled=False,
193 | ):
194 | """
195 |         Handle checkpoint loading for model/optimizer/dataloader from the given path, according to arguments.
196 |         The save path is checked first: if it already contains a valid checkpoint, that checkpoint is used
197 |         and the run is treated as a job restart; otherwise the provided path is used (e.g. for continued pretraining).
198 |         Reset_stepcount manually resets optimizer and dataloader states, and stat tracking.
199 |         Strict determines whether to use strict state-dict loading (applies to single-file loading only).
200 |         Returns model, optimizer, dataloader, current step, current tokens seen, and whether training is resuming.
201 | """
202 | is_resuming = False
203 | if self._validate_ckp_path(self.ckp_path) is not None:
204 | path = self.ckp_path
205 | is_resuming = True
206 | load_path = self._validate_ckp_path(path)
207 | if load_path is None:
208 | self.report(
209 | f"No valid checkpoint detected at {path}, starting from scratch."
210 | )
211 | return model, optimizer, dataloader, 0, 0, False
212 | else:
213 | self.report(f"Prior checkpoint {load_path} detected.")
214 | model_load_time = time.time()
215 | if os.path.isfile(load_path):
216 | checkpoint_data = torch.load(load_path, map_location="cpu")
217 | if is_compiled:
218 | model._orig_mod.load_state_dict(
219 | checkpoint_data.get("model_state"), strict=strict
220 | )
221 | else:
222 | model.load_state_dict(
223 | checkpoint_data.get("model_state"), strict=strict
224 | )
225 | if self.model_auto_placement:
226 | model.to("cuda")
227 | else:
228 | model.to(self.local_rank)
229 | self.report(
230 | f"Checkpoint {load_path} is a single-file checkpoint containing only a model. Optimizer and dataloader are from scratch.",
231 | model_load_time=time.time() - model_load_time,
232 | )
233 | return model, optimizer, dataloader, 0, 0, is_resuming
234 | else:
235 | # Load model
236 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
237 | state_dict = model.state_dict()
238 | model_ckp = {"model_state": state_dict}
239 | load_state_dict(
240 | state_dict=model_ckp,
241 | storage_reader=FileSystemReader(load_path),
242 | planner=DefaultLoadPlanner(),
243 | )
244 | model.load_state_dict(model_ckp["model_state"])
245 | if self.model_auto_placement:
246 | model.to("cuda")
247 | else:
248 | model.to(self.local_rank)
249 | self.report(model_load_time=time.time() - model_load_time)
250 | step = 0
251 | ntok = 0
252 | # Load metadata
253 | if is_resuming:
254 | metadata = torch.load(os.path.join(load_path, "metadata.pth"))
255 | step = metadata.get("step", 0)
256 | ntok = metadata.get("tokens_seen", 0)
257 | self.report("Metadata loaded", start_step=step, n_tokens_seen=ntok)
258 | # Load optimizer
259 | if optimizer is not None:
260 | optim_load_time = time.time()
261 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
262 | optim_state = load_sharded_optimizer_state_dict(
263 | model_state_dict=model.state_dict(),
264 | optimizer_key="optimizer_state",
265 | storage_reader=FileSystemReader(load_path),
266 | )
267 | flattened_osd = FSDP.optim_state_dict_to_load(
268 | model, optimizer, optim_state["optimizer_state"]
269 | )
270 | optimizer.load_state_dict(flattened_osd)
271 | self.report(optimizer_load_time=time.time() - optim_load_time)
272 | else:
273 | self.report("Skipping optimizer load, no optimizer provided.")
274 | # Load dataset
275 | if dataloader is not None:
276 | data_load_time = time.time()
277 | dataloader.dataset.load_from_path(path)
278 | self.report(dataset_load_time=time.time() - data_load_time)
279 | else:
280 | self.report("Skipping dataset load, no dataloader provided.")
281 | return model, optimizer, dataloader, step, ntok, is_resuming
282 |
283 | def save(
284 | self,
285 | step,
286 | model,
287 | optimizer,
288 | dataloader,
289 | **kwargs,
290 | ):
291 | # Note: metadata kwargs cannot contain any of:
292 | # (step, model, optimizer, dataloader)
293 | rank = self.rank
294 | save_time = time.time()
295 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
296 | model_state = model.state_dict()
297 | optim_state = FSDP.sharded_optim_state_dict(model, optimizer)
298 | dataloader_state = None if dataloader is None else dataloader.dataset
299 |
300 | save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp")
301 | state_dict = {"model_state": model_state, "optimizer_state": optim_state}
302 | if self._do_save(rank, self.local_rank):
303 | self._write(
304 | state_dict, dataloader_state, model.process_group, save_name, rank
305 | )
306 | else:
307 | self._write(None, dataloader_state, None, save_name, rank)
308 | if rank == 0:
309 | metadata = kwargs
310 | metadata["step"] = step
311 | torch.save(metadata, os.path.join(save_name, "metadata.pth"))
312 | self.report(
313 | f"Checkpoint saved in {save_name}", model_save_time=time.time() - save_time
314 | )
315 |
316 | return self._cleanup()
317 |
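# Example usage (an illustrative sketch mirroring main_training_llama.py; paths, the
# parallel mode, and the step value are placeholders):
#
#   checkpointer = Checkpointer("/path/to/run", 1000, "hsdp", rank, local_rank)
#   model, optimizer, dataloader, start_step, tokens_seen, is_resuming = checkpointer.load(
#       model, optimizer, dataloader, path="/path/to/load/checkpoints/", strict=False
#   )
#   ...
#   checkpointer.save(step, model, optimizer, dataloader, tokens_seen=tokens_seen)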
--------------------------------------------------------------------------------
/fms_fsdp/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | from fms.models.llama import LLaMAConfig
2 |
3 | from fms_fsdp.config import train_config
4 |
5 |
6 | def update_config(config, **kwargs):
7 | if isinstance(config, (tuple, list)):
8 | for c in config:
9 | update_config(c, **kwargs)
10 | else:
11 | for k, v in kwargs.items():
12 | if hasattr(config, k):
13 | setattr(config, k, v)
14 | elif "." in k:
15 | config_name, param_name = k.split(".")
16 | if type(config).__name__ == config_name:
17 | if hasattr(config, param_name):
18 | setattr(config, param_name, v)
19 | else:
20 | print(f"Warning: {config_name} does not accept parameter: {k}")
21 | elif isinstance(config, train_config):
22 | print(f"Warning: unknown parameter {k}")
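
# Example (an illustrative sketch; learning_rate and seed are fields of train_config):
#
#   cfg = train_config()
#   update_config(cfg, learning_rate=3e-4, **{"train_config.seed": 2023})
#
# Plain keys update matching attributes; dotted keys are applied only when the config
# class name matches; unknown keys on a train_config instance just print a warning.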
23 |
24 |
25 | def get_model_config(model_variant):
26 | if model_variant == "llama2_70b":
27 | model_config = LLaMAConfig(
28 | emb_dim=8192,
29 | multiple_of=4096,
30 | nheads=64,
31 | kvheads=8,
32 | nlayers=80,
33 | hidden_grow_factor=28672 / 8192,
34 | )
35 | elif model_variant == "llama2_34b":
36 | model_config = LLaMAConfig(
37 | emb_dim=8192,
38 | nheads=64,
39 | kvheads=8,
40 | nlayers=48,
41 | hidden_grow_factor=22016 / 8192,
42 | max_expected_seq_len=16384,
43 | rope_theta=1000000.0,
44 | )
45 | elif model_variant == "llama2_13b":
46 | model_config = LLaMAConfig(
47 | emb_dim=5120,
48 | nheads=40,
49 | nlayers=40,
50 | hidden_grow_factor=13824 / 5120,
51 | )
52 | elif model_variant == "llama2_7b":
53 | model_config = LLaMAConfig(
54 | hidden_grow_factor=11008 / 4096,
55 | kvheads=32,
56 | )
57 | elif model_variant == "llama2_1.4b":
58 | model_config = LLaMAConfig(
59 | emb_dim=2048,
60 | nheads=16,
61 | nlayers=24,
62 | hidden_grow_factor=3,
63 | kvheads=4,
64 | )
65 | elif model_variant == "llama3_8b":
66 | model_config = LLaMAConfig(
67 | src_vocab_size=128256,
68 | emb_dim=4096,
69 | nheads=32,
70 | kvheads=8,
71 | nlayers=32,
72 | hidden_grow_factor=3.5,
73 | max_expected_seq_len=8192,
74 | rope_theta=500000.0,
75 | )
76 | elif model_variant == "llama3_8b_4k":
77 | model_config = LLaMAConfig(
78 | src_vocab_size=128256,
79 | emb_dim=4096,
80 | nheads=32,
81 | kvheads=8,
82 | nlayers=32,
83 | hidden_grow_factor=3.5,
84 | max_expected_seq_len=4096,
85 | rope_theta=500000.0,
86 | )
87 | elif model_variant == "llama3_1.8b":
88 | model_config = LLaMAConfig(
89 | src_vocab_size=128256,
90 | emb_dim=2048,
91 | nheads=16,
92 | kvheads=8,
93 | nlayers=24,
94 | hidden_grow_factor=3.5,
95 | max_expected_seq_len=8192,
96 | rope_theta=500000.0,
97 | )
98 | elif model_variant == "llama3_1.8b_4k":
99 | model_config = LLaMAConfig(
100 | src_vocab_size=128256,
101 | emb_dim=2048,
102 | nheads=16,
103 | kvheads=8,
104 | nlayers=24,
105 | hidden_grow_factor=3.5,
106 | max_expected_seq_len=4096,
107 | rope_theta=500000.0,
108 | )
109 | elif model_variant == "llama3_3.2b":
110 | model_config = LLaMAConfig(
111 | src_vocab_size=128256,
112 | emb_dim=3072,
113 | nheads=24,
114 | kvheads=8,
115 | nlayers=24,
116 | hidden_grow_factor=8 / 3,
117 | max_expected_seq_len=8192,
118 | rope_theta=500000.0,
119 | )
120 | elif model_variant == "llama3_3.2b_4k":
121 | model_config = LLaMAConfig(
122 | src_vocab_size=128256,
123 | emb_dim=3072,
124 | nheads=24,
125 | kvheads=8,
126 | nlayers=24,
127 | hidden_grow_factor=8 / 3,
128 | max_expected_seq_len=4096,
129 | rope_theta=500000.0,
130 | )
131 | elif model_variant == "llama3_70b":
132 | model_config = LLaMAConfig(
133 | src_vocab_size=128256,
134 | emb_dim=8192,
135 | nheads=64,
136 | kvheads=8,
137 | nlayers=80,
138 | hidden_grow_factor=3.5,
139 | max_expected_seq_len=8192,
140 | rope_theta=500000.0,
141 | )
142 | elif model_variant == "llama3_70b_4k":
143 | model_config = LLaMAConfig(
144 | src_vocab_size=128256,
145 | emb_dim=8192,
146 | nheads=64,
147 | kvheads=8,
148 | nlayers=80,
149 | hidden_grow_factor=3.5,
150 | max_expected_seq_len=4096,
151 | rope_theta=500000.0,
152 | )
153 | elif model_variant == "llama3_194m_4k":
154 | model_config = LLaMAConfig(
155 | src_vocab_size=128256,
156 | emb_dim=1024,
157 | nheads=8,
158 | nlayers=10,
159 | max_expected_seq_len=4096,
160 | rope_theta=500000.0,
161 | )
162 | elif model_variant == "mamba_9.8b":
163 | model_config = {
164 | "d_model": 4096,
165 | "d_intermediate": 14336,
166 | "n_layer": 32,
167 | "vocab_size": 128256,
168 | "ssm_cfg": {"layer": "Mamba2"},
169 | "attn_layer_idx": [9, 18, 27],
170 | "attn_cfg": {
171 | "causal": True,
172 | "d_conv": 0,
173 | "head_dim": 128,
174 | "num_heads": 32,
175 | "num_heads_kv": 8,
176 | "out_proj_bias": False,
177 | "qkv_proj_bias": False,
178 | "rotary_emb_dim": 64,
179 | },
180 | "rms_norm": True,
181 | "residual_in_fp32": True,
182 | "fused_add_norm": True,
183 | "pad_vocab_size_multiple": 16,
184 | "tie_embeddings": False,
185 | }
186 | else:
187 | raise ValueError(f"model variant {model_variant} not supported.")
188 |
189 | return model_config
190 |
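# Note: llama variants return an fms LLaMAConfig object, while the mamba variant returns
# a plain dict that callers unpack via MambaConfig(**config) (see fms_to_hf_mamba.py and
# main_training_mamba.py).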
--------------------------------------------------------------------------------
/fms_fsdp/utils/dataloader_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from fms_fsdp.utils.dataset_utils import (
4 | ArrowHandler,
5 | AutoHandler,
6 | BufferDataset,
7 | CheckpointDataset,
8 | ParquetHandler,
9 | PreloadBufferDataset,
10 | PreprocessDataset,
11 | SamplingDataset,
12 | ScalableShardDataset,
13 | StreamingDocDataset,
14 | )
15 |
16 |
17 | _handler_map = {
18 | "arrow": ArrowHandler,
19 | "hf_parquet": ParquetHandler,
20 | "auto": AutoHandler,
21 | }
22 |
23 |
24 | def causal_lm(data_seq, prompt_len=1):
25 | """
26 | Perform causal language modeling by right-shifting the input sequence.
27 | Sets first prompt_len tokens to be ignored by the loss.
28 | """
29 | data_seq = torch.tensor(data_seq, dtype=torch.int)
30 | t = data_seq.clone()[1:]
31 | data_seq = data_seq[:-1]
32 | t[:prompt_len] = -100
33 | return data_seq, t
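
# Worked example: causal_lm([5, 6, 7, 8], prompt_len=1) returns input [5, 6, 7] and
# target [-100, 7, 8]: each position is trained to predict the next token, and the
# first prompt_len target(s) are masked out of the loss with -100.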
34 |
35 |
36 | def get_dummy_loader(cfg, rank, world_size):
37 | """
38 | A simple dummy dataloader yielding incrementing vocab indices in an infinite loop
39 | """
40 |
41 | class SteadyCounter(torch.utils.data.IterableDataset):
42 | # Spit out incremental counts of constant length l, modulo vocab size v
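        # e.g. with l=4, v=100 it yields ([0, 1, 2, 3], [0, 1, 2, 3]), then ([4, 5, 6, 7], [4, 5, 6, 7]), ...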
43 | def __init__(self, l, v):
44 | self.i = 0
45 | self.l = l
46 | self.v = v
47 |
48 | def __iter__(self):
49 | while True:
50 | out = torch.IntTensor(
51 | [x % self.v for x in range(self.i, self.i + self.l)]
52 | )
53 | yield out, out
54 | self.i += self.l
55 |
56 | data = SteadyCounter(cfg.seq_length, cfg.vocab_size)
57 | return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)
58 |
59 |
60 | def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
61 | """
62 |     PyTorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
63 | Assumes underlying data is sequences of integer values.
64 | ...
65 | Args
66 | ----
67 | cfg : dataclass
68 | Training config containing seq len, dataset, dataset weight, datapath, etc. arguments
69 | rank : int
70 | Rank of current distributed worker. Used for handling dataset sharding logic.
71 | world_size : int
72 | Number of distributed workers. Used for handling dataset sharding logic.
73 | postprocess : List[Callable]
74 | Any task-specific postprocessing to apply before handing over data. Steps will apply in
75 | the order provided by the user. For CLM training, use postprocess=[causal_lm].
76 | """
77 |
78 | datasets, weights = parse_data_args(cfg.datasets, cfg.weights)
79 |
80 | # Base streaming dataset. Returns doc chunks in sequence.
81 | # Implements dataset sampling and rescalability.
82 | droplist = [
83 | int(x.strip()) for x in cfg.strip_tokens.split(",") if len(x.strip()) > 0
84 | ]
85 | droplist = droplist + [cfg.bos_token, cfg.eos_token, cfg.bol_token, cfg.eol_token]
86 | assert (
87 | cfg.file_type in _handler_map
88 | ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
89 | if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
90 | filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name)
91 | else:
92 | filehandler = _handler_map[cfg.file_type]
93 | # Base reader layer
94 | data = StreamingDocDataset(
95 | cfg.data_path,
96 | rank,
97 | world_size,
98 | filehandler,
99 | cfg.eos_token,
100 | bos_token=cfg.bos_token,
101 | strip_tokens=set(droplist),
102 | min_length=3,
103 | seed=cfg.seed,
104 | )
105 | # Add rescaling/resharding
106 | data = ScalableShardDataset(
107 | data,
108 | cfg.eos_token,
109 | n_logical_shards=cfg.logical_shards,
110 | )
111 | # Add multi-dataset handling
112 | data = SamplingDataset(
113 | cfg.data_path,
114 | data,
115 | cfg.eos_token,
116 | datasets=datasets,
117 | weights=weights,
118 | verbose=(rank == 0),
119 | )
120 | # Wrap above dataset in packing logic to form constant-length lines.
121 | data = BufferDataset(
122 | data,
123 | cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
124 | bos_token=cfg.bol_token,
125 | eos_token=cfg.eol_token,
126 | pack_hard=True,
127 | )
128 | # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
129 | data = PreloadBufferDataset(data, 10000)
130 |
131 | # Apply desired postprocessing steps in sequence
132 | data = PreprocessDataset(data, torch.IntTensor)
133 | for p in postprocess:
134 | data = PreprocessDataset(data, p)
135 |
136 | # Enable auto-saving
137 | data = CheckpointDataset(
138 | data,
139 | cfg.ckpt_load_path if cfg.resuming_dataset else cfg.ckpt_save_path,
140 | cfg.checkpoint_interval,
141 | cfg.batch_size,
142 | cfg.ckpt_save_path,
143 | )
144 | return torch.utils.data.DataLoader(
145 | data, num_workers=cfg.num_workers, batch_size=cfg.batch_size
146 | )
147 |
148 |
149 | def parse_data_args(datas, weights):
150 | # Convert csv inputs into corresponding lists of values
151 | def splitstrip(x):
152 | if isinstance(x, str):
153 | return [item.strip() for item in x.split(",")]
154 | elif isinstance(x, (list, tuple)):
155 | return list(x)
156 | elif isinstance(x, (int, float, complex)):
157 | return [x]
158 | else:
159 | raise ValueError(f"arg input {x} cannot be parsed.")
160 |
161 | datas = splitstrip(datas)
162 | weights = [float(x) for x in splitstrip(weights)]
163 | return datas, weights
164 |
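# Example (illustrative): parse_data_args("commoncrawl,github", "7,3") returns
# (["commoncrawl", "github"], [7.0, 3.0]); list/tuple and scalar inputs are passed
# through (scalars are wrapped in a single-element list).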
--------------------------------------------------------------------------------
/fms_fsdp/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import asdict
3 | from functools import partial
4 |
5 |
6 | try:
7 | import packaging.version
8 | except ImportError:
9 | from pkg_resources import packaging # type: ignore
10 |
11 | import time
12 | from datetime import timedelta
13 | import torch
14 | import torch.cuda.nccl as nccl
15 | import torch.distributed as dist
16 | from torch.distributed.fsdp import ShardingStrategy
17 |
18 | from fms_fsdp.policies import *
19 |
20 |
21 | def train(
22 | cfg,
23 | model,
24 | local_rank,
25 | rank,
26 | train_loader,
27 | optimizer,
28 | scheduler,
29 | profiler,
30 | checkpointer,
31 | start_step,
32 | tokens_seen,
33 | ):
34 | if cfg.tracker:
35 | if cfg.tracker not in ["wandb", "aim"]:
36 | raise ValueError(f"tracker {cfg.tracker} not supported.")
37 | tracker_dir = cfg.tracker_dir
38 | project_name = cfg.tracker_project_name
39 | run_id = cfg.tracker_run_id
40 |
41 | if cfg.tracker == "wandb":
42 | try:
43 | import wandb # type: ignore
44 | except ImportError:
45 | raise ImportError("tracker is set to wandb but wandb is not installed.")
46 | if rank == 0:
47 | print(f"--> wandb is enabled!")
48 | try:
49 | wandb.init(
50 | project=project_name,
51 | dir=tracker_dir,
52 | resume="allow",
53 | id=run_id,
54 | )
55 | except wandb.errors.UsageError:
56 | raise ValueError(
57 | "wandb failed to init, did you pass your wandb api key via WANDB_API_KEY?"
58 | )
59 | wandb.config = asdict(cfg)
60 |
61 | if cfg.tracker == "aim":
62 | try:
63 | from aim import Run # type: ignore
64 | except ImportError:
65 | raise ImportError("tracker is set to aim but aim is not installed.")
66 | if rank == 0:
67 | print(f"--> aim is enabled!")
68 | run = Run(
69 | experiment=project_name,
70 | repo=tracker_dir,
71 | run_hash=run_id,
72 | )
73 | run["hparams"] = asdict(cfg)
74 |
75 | model.train()
76 | ddp_stats = torch.zeros(3).to(local_rank)
77 |
78 | start = time.time()
79 | loop_start = time.time()
80 | train_loss = -1
81 | for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1):
82 | if batch_idx > cfg.num_steps:
83 | break
84 | input = input.to(local_rank)
85 | label = label.to(local_rank)
86 |
87 | optimizer.zero_grad()
88 | output = model(input)
89 | output = output.logits if hasattr(output, "logits") else output
90 | ce_loss = torch.nn.CrossEntropyLoss()
91 | loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())
92 |
93 | loss.backward()
94 | ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
95 | optimizer.step()
96 | scheduler.step()
97 |
98 | ddp_stats[0] += loss.item()
99 | ddp_stats[2] += 1
100 |
101 | if profiler:
102 | profiler.step()
103 |
104 | if batch_idx % cfg.report_interval == 0:
105 | dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM)
106 | train_loss = ddp_stats[0] / ddp_stats[2]
107 | g_norm = ddp_stats[1] / ddp_stats[2]
108 | elapsed_time = time.time() - loop_start
109 | world_size = int(os.environ["WORLD_SIZE"])
110 | new_tokens_seen = (
111 | (batch_idx - start_step) * world_size * cfg.batch_size * cfg.seq_length
112 | )
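            # new_tokens_seen counts tokens processed across all ranks since this run
            # (re)started; tokens from any resumed prior run are added in via tokens_seen below.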
113 | if rank == 0:
114 | total_tokens_seen = tokens_seen + new_tokens_seen
115 | current_loss = train_loss.item()
116 | current_lr = scheduler.get_last_lr()[0]
117 | current_gnorm = g_norm.item()
118 | current_step_time = (time.time() - start) / cfg.report_interval
119 | overall_step_time = elapsed_time / (batch_idx - start_step)
120 | current_throughput = int(
121 | cfg.batch_size * cfg.seq_length / current_step_time
122 | )
123 | overall_throughput = int(
124 | cfg.batch_size * cfg.seq_length / overall_step_time
125 | )
126 | reserved_mem = torch.cuda.max_memory_reserved(
127 | device=torch.cuda.current_device()
128 | )
129 | allocated_mem = torch.cuda.max_memory_allocated(
130 | device=torch.cuda.current_device()
131 | )
132 |
133 | print("step:", batch_idx)
134 | print("loss:", current_loss)
135 | print("LR:", current_lr)
136 | print("tokens seen:", total_tokens_seen)
137 | print("gradient norm:", current_gnorm)
138 | print("reserved memory:", reserved_mem)
139 | print("allocated memory:", allocated_mem)
140 | print("current step time:", current_step_time)
141 | print("overall step time:", overall_step_time)
142 | print("current token per gpu per sec:", current_throughput)
143 | print("overall token per gpu per sec:", overall_throughput)
144 | print(
145 | "overall token per day:",
146 | int(new_tokens_seen / elapsed_time * 3600 * 24),
147 | )
148 | if cfg.tracker:
149 | vals_to_track = {
150 | "learning rate": current_lr,
151 | "loss": current_loss,
152 | "gradient norm": current_gnorm,
153 | "token seen": total_tokens_seen,
154 | "current throughput (token per gpu per sec)": current_throughput,
155 | "overall throughput (token per gpu per sec)": overall_throughput,
156 | "gpu reserved memory": reserved_mem,
157 | "gpu allocated memory": allocated_mem,
158 | }
159 | if cfg.tracker == "wandb":
160 | tracker_fn = wandb.log
161 | elif cfg.tracker == "aim":
162 | tracker_fn = run.track
163 | tracker_fn(vals_to_track, step=batch_idx)
164 |
165 | start = time.time()
166 | ddp_stats.zero_()
167 | torch.cuda.reset_peak_memory_stats(device=torch.cuda.current_device())
168 |
169 | if batch_idx % cfg.checkpoint_interval == 0 or batch_idx == cfg.num_steps:
170 | checkpointer.save(
171 | batch_idx,
172 | model,
173 | optimizer,
174 | None,
175 | tokens_seen=tokens_seen + new_tokens_seen,
176 | )
177 |
178 | return train_loss
179 |
180 |
181 | def setup():
182 | dist.init_process_group("nccl", timeout=timedelta(seconds=60 * 60))
183 |
184 |
185 | def setup_environ_flags():
186 | os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
187 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
188 |
189 |
190 | def get_mixed_precision_policy(cfg, rank):
191 | verify_bfloat_support = (
192 | torch.version.cuda
193 | and torch.cuda.is_bf16_supported()
194 | and packaging.version.parse(torch.version.cuda).release >= (11, 0)
195 | and dist.is_nccl_available()
196 | and nccl.version() >= (2, 10)
197 | )
198 |
199 | if cfg.mixed_precision:
200 | bf16_ready = verify_bfloat_support
201 | if bf16_ready:
202 | mixed_precision_policy = bfSixteen
203 | if rank == 0:
204 | print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
205 | else:
206 | mixed_precision_policy = fpSixteen
207 | if rank == 0:
208 | print(f"FP16 enabled")
209 | else:
210 | mixed_precision_policy = None
211 |
212 | return mixed_precision_policy
213 |
214 |
215 | def get_policies(cfg, rank, block):
216 | """Get policies for mixed precision, wrapping, sharding, ac and param init function."""
217 |
218 | # mixed precision
219 | mixed_precision_policy = get_mixed_precision_policy(cfg, rank)
220 |
221 | # wrapping policy
222 | wrapping_policy = get_wrapper(block)
223 |
224 | # sharding strategy
225 | if cfg.sharding_strategy == "fsdp":
226 | sharding_strategy = ShardingStrategy.FULL_SHARD
227 | elif cfg.sharding_strategy == "hsdp":
228 | sharding_strategy = ShardingStrategy.HYBRID_SHARD
229 | elif cfg.sharding_strategy == "ddp":
230 | sharding_strategy = ShardingStrategy.NO_SHARD
231 | else:
232 | sharding_strategy = ShardingStrategy.FULL_SHARD
233 | if rank == 0:
234 | print(f"Sharding strategy = {cfg.sharding_strategy}")
235 |
236 | # ac handler
237 | apply_selective_ac = partial(apply_fsdp_checkpointing, block=block)
238 |
239 | # param init function
240 | if cfg.low_cpu_fsdp:
241 | param_init_fn = param_init_function
242 | else:
243 | param_init_fn = None
244 |
245 | return (
246 | mixed_precision_policy,
247 | wrapping_policy,
248 | sharding_strategy,
249 | apply_selective_ac,
250 | param_init_fn,
251 | )
252 |
253 |
254 | def get_profiler(cfg, rank):
255 | if not cfg.use_profiler:
256 | return
257 | if cfg.profiler_rank0_only and rank != 0:
258 | return
259 | return torch.profiler.profile(
260 | activities=[
261 | torch.profiler.ProfilerActivity.CPU,
262 | torch.profiler.ProfilerActivity.CUDA,
263 | ],
264 | schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
265 | on_trace_ready=torch.profiler.tensorboard_trace_handler("profile_traces"),
266 | profile_memory=True,
267 | with_stack=False,
268 | record_shapes=True,
269 | )
270 |
--------------------------------------------------------------------------------
/fms_to_hf_llama.py:
--------------------------------------------------------------------------------
1 | import fire
2 | import torch
3 | from fms.models.hf.utils import to_hf_api
4 | from fms.models.llama import LLaMA
5 | from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
6 | from transformers import LlamaConfig, LlamaForCausalLM
7 |
8 | from fms_fsdp.utils.config_utils import get_model_config
9 |
10 |
11 | def convert_to_hf(model: LLaMA, model_variant, is_old_fms) -> LlamaForCausalLM:
12 | fms_hf_model = to_hf_api(model)
13 | hf_config = fms_hf_model.config
14 | if "llama3" in model_variant:
15 | hf_config.bos_token_id = 128000
16 | hf_config.eos_token_id = 128001
17 | oss_hf_model = LlamaForCausalLM(
18 | LlamaConfig(
19 | vocab_size=hf_config.vocab_size,
20 | hidden_size=hf_config.hidden_size,
21 | rms_norm_eps=hf_config.norm_eps,
22 | num_attention_heads=hf_config.nheads,
23 | num_key_value_heads=None if hf_config.kvheads == 0 else hf_config.kvheads,
24 | num_hidden_layers=hf_config.nlayers,
25 | intermediate_size=hf_config.multiple_of
26 | * (
27 | (
28 | int(hf_config.hidden_grow_factor * hf_config.hidden_size)
29 | + hf_config.multiple_of
30 | - 1
31 | )
32 | // hf_config.multiple_of
33 | ),
34 | pad_token_id=(
35 | None if hf_config.pad_token_id == -1 else hf_config.pad_token_id
36 | ),
37 | bos_token_id=hf_config.bos_token_id,
38 | eos_token_id=hf_config.eos_token_id,
39 | max_position_embeddings=hf_config.max_expected_seq_len,
40 | )
41 | )
42 |
43 | # compute the freq from rot_emb since it is gathered lazily
44 | rot_emb = fms_hf_model.decoder.model.rot_emb
45 | max_seq_len = rot_emb.max_seq_len
46 | alpha = rot_emb._alpha(max_seq_len)
47 | ratio = rot_emb.ratio
48 | dim = rot_emb.dim
49 | if rot_emb.ntk_scaling:
50 | ratio = ratio * alpha ** (dim / (dim - 2))
51 | freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
52 |
53 | with torch.no_grad():
54 | oss_hf_model.model.embed_tokens.weight.copy_(fms_hf_model.embedding.weight)
55 | i = 0
56 | for oss_hf_layer in oss_hf_model.model.layers:
57 | fms_hf_layer = fms_hf_model.decoder.model.layers[i]
58 |
59 | # self attn
60 | if is_old_fms:
61 | oss_hf_layer.self_attn.q_proj.weight.copy_(
62 | fms_hf_layer.attn.query.weight
63 | )
64 | oss_hf_layer.self_attn.k_proj.weight.copy_(fms_hf_layer.attn.key.weight)
65 | oss_hf_layer.self_attn.v_proj.weight.copy_(
66 | fms_hf_layer.attn.value.weight
67 | )
68 | else:
69 | q, k, v = torch.split(
70 | fms_hf_layer.attn.in_proj.qkv_fused.weight,
71 | fms_hf_layer.attn.in_proj.splits,
72 | dim=0,
73 | )
74 | oss_hf_layer.self_attn.q_proj.weight.copy_(q)
75 | oss_hf_layer.self_attn.k_proj.weight.copy_(k)
76 | oss_hf_layer.self_attn.v_proj.weight.copy_(v)
77 | oss_hf_layer.self_attn.o_proj.weight.copy_(fms_hf_layer.attn.dense.weight)
78 | oss_hf_layer.self_attn.rotary_emb.inv_freqs = freqs
79 |
80 | # mlp
81 | if is_old_fms:
82 | oss_hf_layer.mlp.gate_proj.weight.copy_(
83 | fms_hf_layer.ff_sub_layer.wg.weight
84 | )
85 | oss_hf_layer.mlp.up_proj.weight.copy_(
86 | fms_hf_layer.ff_sub_layer.w1.weight
87 | )
88 | else:
89 | wg1_fused = fms_hf_layer.ff_sub_layer.wg1_fused.weight
90 | wg_splits = [wg1_fused.size(0) // 2, wg1_fused.size(0) // 2]
91 | wg, w1 = torch.split(
92 | fms_hf_layer.ff_sub_layer.wg1_fused.weight, wg_splits, dim=0
93 | )
94 | oss_hf_layer.mlp.gate_proj.weight.copy_(wg)
95 | oss_hf_layer.mlp.up_proj.weight.copy_(w1)
96 | oss_hf_layer.mlp.down_proj.weight.copy_(fms_hf_layer.ff_sub_layer.w2.weight)
97 |
98 | # layer norm
99 | oss_hf_layer.input_layernorm.weight.copy_(fms_hf_layer.ln.weight)
100 | oss_hf_layer.post_attention_layernorm.weight.copy_(
101 | fms_hf_layer.ff_ln.weight
102 | )
103 |
104 | # adjust q, k
105 | q = oss_hf_layer.self_attn.q_proj.weight.data
106 | q = (
107 | q.view(hf_config.nheads, -1, 2, q.size(1))
108 | .transpose(1, 2)
109 | .reshape(*q.size())
110 | )
111 | oss_hf_layer.self_attn.q_proj.weight.copy_(q)
112 |
113 | k = oss_hf_layer.self_attn.k_proj.weight.data
114 | k = (
115 | k.view(
116 | hf_config.nheads if hf_config.kvheads == 0 else hf_config.kvheads,
117 | -1,
118 | 2,
119 | k.size(1),
120 | )
121 | .transpose(1, 2)
122 | .reshape(*k.size())
123 | )
124 | oss_hf_layer.self_attn.k_proj.weight.copy_(k)
125 |
126 | i = i + 1
127 | oss_hf_model.model.norm.weight = fms_hf_model.decoder.model.dec_norm.weight
128 | oss_hf_model.lm_head.weight = fms_hf_model.lm_head.weight
129 |
130 | return oss_hf_model
131 |
132 |
133 | def main(
134 | model_variant, compiled, is_old_fms, load_path, save_path, tokenizer_name_or_path
135 | ):
136 | print("Initializing model...")
137 | llama_config = get_model_config(model_variant)
138 | with torch.device("meta"):
139 | model = LLaMA(llama_config)
140 | model.to_empty(device="cpu")
141 |
142 | print(f"Reading state dict from {load_path}")
143 | if not compiled:
144 | state_dict = {"model_state": model.state_dict()}
145 | else:
146 | state_dict = {"model_state": {"_orig_mod": model.state_dict()}}
147 | load_state_dict(
148 | state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True
149 | )
150 |
151 | print("Loading state dict into the model...")
152 | if not compiled:
153 | model.load_state_dict(state_dict["model_state"])
154 | else:
155 | model.load_state_dict(state_dict["model_state"]["_orig_mod"])
156 |
157 |     print("Converting to HF model...")
158 | hf_model = convert_to_hf(model, model_variant, is_old_fms)
159 | hf_model.save_pretrained(save_path)
160 |
161 | print("Copying tokenizer...")
162 | from transformers import AutoTokenizer
163 |
164 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
165 | tokenizer.save_pretrained(save_path)
166 |
167 | print(f"Model converted to HF model, saving at {save_path}")
168 |
169 |
170 | if __name__ == "__main__":
171 | fire.Fire(main)
172 |
--------------------------------------------------------------------------------
/fms_to_hf_mamba.py:
--------------------------------------------------------------------------------
1 | import fire
2 | from mamba_ssm.models.config_mamba import MambaConfig
3 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
4 | from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
5 |
6 | from fms_fsdp.utils.config_utils import get_model_config
7 |
8 |
9 | def main(model_variant, load_path, save_path, tokenizer_name_or_path):
10 | print("Initializing model...")
11 | config_data = get_model_config(model_variant)
12 | mamba_config = MambaConfig(**config_data)
13 | model = MambaLMHeadModel(mamba_config)
14 |
15 | print(f"Reading state dict from {load_path}")
16 | state_dict = {"model_state": model.state_dict()}
17 | load_state_dict(
18 | state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True
19 | )
20 |
21 | print("Loading state dict into the model...")
22 | model.load_state_dict(state_dict["model_state"])
23 |
24 | print("Saving model to HF-compatible format...")
25 | model.save_pretrained(save_path)
26 |
27 | print("Copying tokenizer...")
28 | from transformers import AutoTokenizer
29 |
30 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
31 | tokenizer.save_pretrained(save_path)
32 |
33 |     print(f"Model saved at {save_path}")
34 |
35 |
36 | if __name__ == "__main__":
37 | fire.Fire(main)
38 |
--------------------------------------------------------------------------------
/images/loss_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/images/loss_curve.png
--------------------------------------------------------------------------------
/images/lr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/images/lr.png
--------------------------------------------------------------------------------
/main_training_llama.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 |
4 | import fire
5 | import torch
6 | import torch.optim as optim
7 | from fms.models.llama import LLaMA, LLaMABlock
8 | from torch import distributed as dist
9 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
10 | from torch.optim.lr_scheduler import LambdaLR
11 |
12 | from fms_fsdp import config
13 | from fms_fsdp.utils.checkpointing_utils import Checkpointer
14 | from fms_fsdp.utils.config_utils import get_model_config, update_config
15 | from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader
16 | from fms_fsdp.utils.train_utils import (
17 | get_policies,
18 | get_profiler,
19 | setup,
20 | setup_environ_flags,
21 | train,
22 | )
23 |
24 |
25 | def main(**kwargs):
26 | # get configs
27 | cfg = config.train_config()
28 | update_config(cfg, **kwargs)
29 |
30 | # ensure reproducibility
31 | torch.cuda.manual_seed(cfg.seed)
32 | torch.manual_seed(cfg.seed)
33 |
34 | # torchrun specific
35 | local_rank = int(os.environ["LOCAL_RANK"])
36 | rank = int(os.environ["RANK"])
37 | world_size = int(os.environ["WORLD_SIZE"])
38 |
39 | if rank == 0:
40 | print(f"--> running with these configs {cfg}")
41 |
42 | # some setups
43 | setup()
44 | torch.cuda.set_device(local_rank)
45 | torch.cuda.empty_cache()
46 | setup_environ_flags()
47 |
48 | # get policy
49 | block = LLaMABlock
50 | (
51 | mixed_precision_policy,
52 | wrapping_policy,
53 | sharding_strategy_policy,
54 | apply_selective_ac,
55 | param_init_fn,
56 | ) = get_policies(cfg, rank, block)
57 |
58 | # get fms model
59 | llama_config = get_model_config(cfg.model_variant)
60 | if cfg.low_cpu_fsdp:
61 | with torch.device("meta"):
62 | model = LLaMA(llama_config)
63 | else:
64 | model = LLaMA(llama_config)
65 | model.reset_parameters()
66 |
67 | if rank == 0:
68 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
69 | print(f"\n--> model has {total_params / 1e6} Million params\n")
70 |
71 | # get data loader
72 | if rank == 0:
73 | print("Constructing datasets...")
74 | if not cfg.use_dummy_dataset:
75 | train_loader = get_data_loader(cfg, rank, world_size)
76 | else:
77 | train_loader = get_dummy_loader(cfg, rank, world_size)
78 | if rank == 0:
79 | print("Datasets constructed!")
80 |
81 | # FSDP
82 | model = FSDP(
83 | model,
84 | auto_wrap_policy=wrapping_policy,
85 | mixed_precision=mixed_precision_policy,
86 | sharding_strategy=sharding_strategy_policy,
87 | use_orig_params=cfg.use_torch_compile,
88 | device_id=torch.cuda.current_device(),
89 | limit_all_gathers=True,
90 | param_init_fn=param_init_fn,
91 | )
92 | # we need this post-fsdp call to avoid graph break with torch.compile, until we figure out a better solution.
93 | model.rot_emb.compute_freqs_cis(
94 | torch.device("cuda", torch.cuda.current_device()),
95 | model.config.max_expected_seq_len,
96 | )
97 |
98 | # fsdp activation checkpointing
99 | if cfg.fsdp_activation_checkpointing:
100 | if rank == 0:
101 | print(f"--> applying FSDP activation checkpointing...")
102 | apply_selective_ac(model, p=cfg.selective_checkpointing)
103 |
104 | # torch compile
105 | if cfg.use_torch_compile:
106 | if rank == 0:
107 | print(f"--> enabling torch compile...")
108 | # the default accumulated_cache_size_limit=64 is not enough for 70b model, so we make it 128 here
109 | torch._dynamo.config.accumulated_cache_size_limit = 128
110 | model = torch.compile(model)
111 |
112 | # Optimizer
113 | optimizer = optim.AdamW(
114 | model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
115 | )
116 |
117 |     # optionally load from checkpoint (when continuing pretraining)
118 | checkpointer = Checkpointer(
119 | cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank
120 | )
121 | model, optimizer, _, start_step, tokens_seen, is_resuming = checkpointer.load(
122 | model,
123 | optimizer,
124 | None,
125 | path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
126 | if not os.path.isfile(cfg.ckpt_load_path)
127 | else cfg.ckpt_load_path,
128 | strict=False,
129 | )
130 | if not is_resuming:
131 | start_step = 0
132 | # Override loaded optim hyperparams with the current values
133 | for g in optimizer.param_groups:
134 | g["initial_lr"] = cfg.learning_rate
135 |
136 | # LR schedule
137 | if cfg.training_stage == "annealing":
138 | schedule = lambda x: 1 - x / cfg.num_steps
139 | else:
140 | warmup_interval = min(2000, cfg.num_steps // 20)
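        # Warmup followed by cosine decay: the first term ramps the multiplier
        # quadratically from 0 to 1 over warmup_interval steps; the second follows a
        # cosine from 1.0 down to 0.1 over num_steps. min() picks whichever is smaller
        # at the current step, so warmup governs the start and the cosine the rest.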
141 | schedule = lambda x: min(
142 | 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
143 | 0.1
144 | + 0.5
145 | * (1 - 0.1)
146 | * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
147 | )
148 | scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step))
149 |
150 | # profiler
151 | profiler = get_profiler(cfg, rank)
152 |
153 | # Train
154 | if rank == 0:
155 | print(f"Training for {cfg.num_steps} steps")
156 | train(
157 | cfg,
158 | model,
159 | local_rank,
160 | rank,
161 | train_loader,
162 | optimizer,
163 | scheduler,
164 | profiler,
165 | checkpointer,
166 | start_step,
167 | tokens_seen,
168 | )
169 |
170 | dist.barrier()
171 | dist.destroy_process_group()
172 |
173 |
174 | if __name__ == "__main__":
175 | fire.Fire(main)
176 |
--------------------------------------------------------------------------------
/main_training_mamba.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from pathlib import Path
4 |
5 | import fire
6 | import torch
7 | import torch.optim as optim
8 | from mamba_ssm.models.config_mamba import MambaConfig
9 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
10 | from mamba_ssm.modules.block import Block
11 | from torch import distributed as dist
12 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
13 | from torch.optim.lr_scheduler import LambdaLR
14 |
15 | from fms_fsdp import config
16 | from fms_fsdp.utils.checkpointing_utils import Checkpointer
17 | from fms_fsdp.utils.config_utils import get_model_config, update_config
18 | from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader
19 | from fms_fsdp.utils.train_utils import (
20 | get_policies,
21 | get_profiler,
22 | setup,
23 | setup_environ_flags,
24 | train,
25 | )
26 |
27 |
28 | def main(**kwargs):
29 | # get configs
30 | cfg = config.train_config()
31 | update_config(cfg, **kwargs)
32 |
33 | # ensure reproducibility
34 | torch.cuda.manual_seed(cfg.seed)
35 | torch.manual_seed(cfg.seed)
36 |
37 | # torchrun specific
38 | local_rank = int(os.environ["LOCAL_RANK"])
39 | rank = int(os.environ["RANK"])
40 | world_size = int(os.environ["WORLD_SIZE"])
41 |
42 | if rank == 0:
43 | print(f"--> running with these configs {cfg}")
44 |
45 | # some setups
46 | setup()
47 | torch.cuda.set_device(local_rank)
48 | torch.cuda.empty_cache()
49 | setup_environ_flags()
50 | os.environ["TRITON_CACHE_DIR"] = os.path.join(
51 | Path.home(), ".triton", "cache", str(local_rank)
52 | )
53 |
54 | # get policy
55 | block = Block
56 | (
57 | mixed_precision_policy,
58 | wrapping_policy,
59 | sharding_strategy_policy,
60 | apply_selective_ac,
61 | param_init_fn,
62 | ) = get_policies(cfg, rank, block)
63 |
64 | # get model
65 | config_data = get_model_config(cfg.model_variant)
66 | mamba_config = MambaConfig(**config_data)
67 | model = MambaLMHeadModel(mamba_config)
68 |
69 | if rank == 0:
70 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
71 | print(f"\n--> model has {total_params / 1e6} Million params\n")
72 |
73 | # get data loader
74 | if rank == 0:
75 | print("Constructing datasets...")
76 | if not cfg.use_dummy_dataset:
77 | train_loader = get_data_loader(cfg, rank, world_size)
78 | else:
79 | train_loader = get_dummy_loader(cfg, rank, world_size)
80 | if rank == 0:
81 | print("Datasets constructed!")
82 |
83 | # FSDP
84 | model = FSDP(
85 | model,
86 | auto_wrap_policy=wrapping_policy,
87 | mixed_precision=mixed_precision_policy,
88 | sharding_strategy=sharding_strategy_policy,
89 | use_orig_params=cfg.use_torch_compile,
90 | device_id=torch.cuda.current_device(),
91 | limit_all_gathers=True,
92 | param_init_fn=param_init_fn,
93 | )
94 |
95 | # fsdp activation checkpointing
96 | if cfg.fsdp_activation_checkpointing:
97 | if rank == 0:
98 | print(f"--> applying FSDP activation checkpointing...")
99 | apply_selective_ac(model, p=cfg.selective_checkpointing)
100 |
101 | # torch compile
102 | if cfg.use_torch_compile:
103 | if rank == 0:
104 | print(f"--> enabling torch compile...")
105 | # the default accumulated_cache_size_limit=64 is not enough for 70b model, so we make it 128 here
106 | torch._dynamo.config.accumulated_cache_size_limit = 128
107 | model = torch.compile(model)
108 |
109 | # Optimizer
110 | optimizer = optim.AdamW(
111 | model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
112 | )
113 |
114 |     # optionally load from checkpoint (when continuing pretraining)
115 | checkpointer = Checkpointer(
116 | cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank
117 | )
118 | model, optimizer, _, start_step, tokens_seen, is_resuming = checkpointer.load(
119 | model,
120 | optimizer,
121 | None,
122 | path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
123 | if not os.path.isfile(cfg.ckpt_load_path)
124 | else cfg.ckpt_load_path,
125 | strict=False,
126 | )
127 | if not is_resuming:
128 | start_step = 0
129 | # Override loaded optim hyperparams with the current values
130 | for g in optimizer.param_groups:
131 | g["initial_lr"] = cfg.learning_rate
132 |
133 | # LR schedule
134 | # linear decay for annealing
135 | if cfg.training_stage == "annealing":
136 | schedule = lambda x: 1 - x / cfg.num_steps
137 | else:
138 | # cosine decay
139 | warmup_interval = min(2000, cfg.num_steps // 20)
140 | schedule = lambda x: min(
141 | 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
142 | 0.1
143 | + 0.5
144 | * (1 - 0.1)
145 | * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
146 | )
147 |
148 | scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step))
149 |
150 | # profiler
151 | profiler = get_profiler(cfg, rank)
152 |
153 | # Train
154 | if rank == 0:
155 | print(f"Training for {cfg.num_steps} steps")
156 | train(
157 | cfg,
158 | model,
159 | local_rank,
160 | rank,
161 | train_loader,
162 | optimizer,
163 | scheduler,
164 | profiler,
165 | checkpointer,
166 | start_step,
167 | tokens_seen,
168 | )
169 |
170 | dist.barrier()
171 | dist.destroy_process_group()
172 |
173 |
174 | if __name__ == "__main__":
175 | fire.Fire(main)
176 |
--------------------------------------------------------------------------------
/requirements-speculator.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 | fms-extras @ git+https://github.com/foundation-model-stack/fms-extras@main
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.2.0
2 | fire==0.5.0
3 | pyarrow==15.0.0
4 | transformers==4.40.2
5 | ibm-fms>=0.0.3
6 |
--------------------------------------------------------------------------------
/scripts/README_SPECULATOR.md:
--------------------------------------------------------------------------------
1 | ### The following parameters are relevant for speculator training:
2 |
3 | - *model_arch*: architecture of the base model (one of: embedllama, embedmixtral, embedgpt_bigcode; FMS implementations extending the base arch to also emit the embedding vector together with the model output. See 'EmbedLLaMA' in train_speculator_utils.py)
4 |
5 | - *model_variant*: identifier with which a specific variant (e.g., 7b) is registered for the model architecture. See 'example model registrations' in train_speculator_utils.py.
6 |
7 | - *model_path*: path to dir containing base model weights
8 |
9 | - *ckpt_save_path*: path to dir for storing intermediate checkpoints during speculator training
10 |
11 | - *ckpt_load_path*: path to dir for loading intermediate speculator checkpoint to resume training
12 |
13 | - *sharding_strategy*: how to shard the model across process group: tp / fsdp / hsdp
14 |
15 | - *tp_size*: If loading base model using tensor parallel, no. of GPUs/ranks to split the model across
16 |
17 | - *seq_length*: sequence length of the base model
18 |
19 | - *batch_size*: batch size for stage 1 training for aligning speculator to base model input behavior
20 |
21 | - *report_interval*: no. of steps after which to report training stats
22 |
23 | - *checkpoint_interval*: no. of steps after which to save an intermediate speculator checkpoint
24 |
25 | - *num_steps*: total no. of speculator training steps (stage 1 + stage 2)
26 |
27 | - *stage2_start_step*: no. of steps after which to switch to stage 2 training
28 |
29 | - *stage2_batch_size*: batch size for stage 2 training for aligning speculator to base model output behavior
30 |
31 | - *n_speculator_heads*: no. of lookahead tokens to train the speculator for
32 |
33 | - *speculator_width*: embedding dimension of the speculator MLP
34 |
35 | - *use_torch_compile*: whether to compile the base model and speculator; may speed up training.
36 |
37 | - *learning_rate*: learning rate for speculator training
38 |
39 | - *seed*: random seed to use for training dataset shuffling
40 |
41 | - *data_path*: path to dir containing the training dataset. Expects directory to contain subfolders, which in turn contain shard files.
42 |
43 | - *datasets*: a list of subdatasets (e.g., commoncrawl, github, etc.) to draw from. If None, draws from all subfolders of data_path.
44 |
45 | - *weights*: list of weights reflecting the percentage of tokens to be used from each subdataset during training
46 |
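An example invocation that sets these parameters for a Llama-2 7B base model is provided in `scripts/train_speculator.sh`.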
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # On AWS, the EFA and OFI paths enable NCCL to use optimized networking.
4 | export LD_LIBRARY_PATH=/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/opt/aws-ofi-nccl/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda:/usr/local/cuda/targets/x86_64-linux/lib/:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/lib:$LD_LIBRARY_PATH
5 |
6 | export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
7 |
8 | MODEL_ARGS="\
9 | --use_dummy_dataset=False
10 | --ckpt_load_path=/lustre/pretrain/ckpt
11 | --ckpt_save_path=/lustre/pretrain/ckpt
12 | --data_path=/lustre/bluepile-processing/rel0_7/tokens/llama2/high_quality_rerun_fuzzy_deduped
13 | --fsdp_activation_checkpointing=False
14 | --selective_checkpointing=1
15 | --sharding_strategy=hsdp
16 | --low_cpu_fsdp=False
17 | --batch_size=2
18 | --report_interval=200
19 | --checkpoint_interval=20000
20 | --use_torch_compile=False
21 | --use_profiler=False
22 | "
23 |
24 | torchrun \
25 | --nnodes=$SLURM_NTASKS \
26 | --node_rank=$SLURM_NODEID \
27 | --nproc_per_node=8 \
28 | --master_addr=`scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1` \
29 | --master_port="12234" \
30 |     main_training_llama.py \
31 | ${MODEL_ARGS}
32 |
33 |
--------------------------------------------------------------------------------
/scripts/train.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --nodes=16
4 | #SBATCH --gres=gpu:8
5 | #SBATCH --cpus-per-task=64
6 | #SBATCH --ntasks-per-node=1
7 | #SBATCH --wait-all-nodes=1
8 | #SBATCH --exclusive
9 | ##SBATCH --contiguous
10 |
11 | srun ./scripts/train.sh
12 |
--------------------------------------------------------------------------------
/scripts/train_speculator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # On AWS, the EFA and OFI paths enable NCCL to use optimized networking.
4 | export LD_LIBRARY_PATH=/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/opt/aws-ofi-nccl/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda:/usr/local/cuda/targets/x86_64-linux/lib/:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/lib:$LD_LIBRARY_PATH
5 |
6 | export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
7 |
8 | MODEL_ARGS="\
9 | --model_path=/path/to/models/meta-llama/Llama-2-7b-hf
10 | --model_arch=embedllama
11 | --model_variant=7b
12 | --ckpt_load_path=/path/to/checkpoints/llama2-7b
13 | --ckpt_save_path=/path/to/checkpoints/llama2-7b
14 | --logical_shards=768
15 | --sharding_strategy=hsdp
16 | --seq_length=4096
17 | --batch_size=8
18 | --report_interval=10
19 | --checkpoint_interval=3000
20 | --num_steps=21000
21 | --stage2_start_step=15000
22 | --stage2_batch_size=96
23 | --n_speculator_heads=3
24 | --speculator_width=4096
25 | --use_torch_compile=False
26 | --learning_rate=1e-3
27 | --seed=42
28 | --data_path=/path/to/dataset/
29 | --datasets="'dataset=commoncrawl'"
30 | --weights="'1'"
31 | "
32 |
33 | torchrun \
34 | --nproc_per_node=8 \
35 | speculator/train_speculator.py \
36 | ${MODEL_ARGS}
37 |
38 |
39 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 |
4 | setup(
5 | name="fms_fsdp",
6 | version="0.0.1",
7 | author="Linsong Chu, Davis Wertheimer, Brian Vaughan, Andrea Frittoli, Joshua Rosenkranz, Antoni Viros i Martin, Raghu Kiran Ganti",
8 | author_email="lchu@us.ibm.com",
9 | description="Pretraining scripts using FSDP and IBM Foundation Model Stack",
10 | url="https://github.com/foundation-model-stack/fms-fsdp",
11 | packages=find_packages(),
12 | install_requires=["ibm-fms >= 0.0.3", "torch >= 2.1"],
13 | license="Apache License 2.0",
14 | classifiers=[
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: Apache Software License",
17 | ],
18 | )
19 |
--------------------------------------------------------------------------------
/speculator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foundation-model-stack/fms-fsdp/503da7ede354e1ffdabc80fae8bbd211cb2174c8/speculator/__init__.py
--------------------------------------------------------------------------------
/speculator/train_speculator.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import time
4 |
5 | import fire # type: ignore
6 | import torch
7 | import torch.optim as optim
8 | from fms.models import get_model
9 | from fms.models.llama import LLaMABlock
10 | from fms.utils import generation, tokenizers
11 | from fms_extras.models.speculator import MLPSpeculator # type: ignore
12 | from torch import distributed as dist
13 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
14 | from torch.distributed.fsdp import ShardingStrategy
15 | from torch.optim.lr_scheduler import LambdaLR
16 |
17 | from fms_fsdp import config
18 | from fms_fsdp.utils.checkpointing_utils import Checkpointer
19 | from fms_fsdp.utils.config_utils import update_config
20 | from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader
21 | from fms_fsdp.utils.train_utils import (
22 | get_mixed_precision_policy,
23 | get_profiler,
24 | setup,
25 | setup_environ_flags,
26 | )
27 | from speculator.train_speculator_utils import train_speculator
28 |
29 |
30 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
31 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
32 |
33 |
34 | def test_model(rank, model, arch, cfg, prompt_type="chat"):
35 | if rank == 0:
36 | print("testing model output")
37 | tokenizer = tokenizers.get_tokenizer(cfg.model_path)
38 | if prompt_type == "chat":
39 | template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
40 | prompt = template.format(
41 | "Provide a list of instructions for preparing chicken soup."
42 | )
43 | else:
44 | template = "[INST] Write code to solve the following coding problem that obeys the constraints and passes the example test cases. Please wrap your code answer using ```:\n{}\n[/INST]"
45 | prompt = template.format("Write a bubble sort function in python.")
46 |
47 | tokens = tokenizer.tokenize(prompt)
48 | ids = tokenizer.convert_tokens_to_ids(tokens)
49 | if "llama" in arch:
50 | ids = [tokenizer.bos_token_id] + ids
51 | ids = torch.tensor(ids, dtype=torch.long, device="cuda")
52 | result = generation.generate(
53 | model,
54 | ids,
55 | max_new_tokens=100,
56 | use_cache=True,
57 | do_sample=False,
58 | max_seq_len=8192,
59 | )
60 | result = generation.truncate_after_eos(result, tokenizer.eos_token_id)
61 | if rank == 0:
62 | print(f"{rank}: quick test of base model")
63 | print(
64 | tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(result))
65 | )
66 |
67 |
68 | def get_emb_dim(model):
69 | if hasattr(model.config, "emb_dim"):
70 | emb_dim = model.config.emb_dim
71 | elif hasattr(model.config, "dim"): # Mixtral
72 | emb_dim = model.config.dim
73 | elif hasattr(model.config, "hidden_size"): # HF
74 | emb_dim = model.config.hidden_size
75 | else:
76 | raise Exception("config missing embedding dimension")
77 | return emb_dim
78 |
79 |
80 | def get_vocab_size(model):
81 | if hasattr(model.config, "src_vocab_size"): # FMS
82 | vocab_size = model.config.src_vocab_size
83 | elif hasattr(model.config, "vocab_size"): # HF
84 | vocab_size = model.config.vocab_size
85 | else:
86 | raise Exception("config missing vocab size config")
87 | return vocab_size
88 |
89 |
90 | def get_training_data_loader(rank, cfg, world_size, speculator_mesh):
91 | if rank == 0:
92 | print(f"{time.time()} Constructing datasets...")
93 | if not cfg.use_dummy_dataset:
94 | if cfg.sharding_strategy == "tp":
95 | train_loader = get_data_loader(
96 | cfg, speculator_mesh.get_rank(), speculator_mesh.size(), postprocess=[]
97 | )
98 | else:
99 | train_loader = get_data_loader(cfg, rank, world_size, postprocess=[])
100 | else:
101 | train_loader = get_dummy_loader(cfg, rank, world_size)
102 | if rank == 0:
103 | print(f"{time.time()} Datasets constructed!")
104 | return train_loader
105 |
106 |
107 | def main(**kwargs):
108 | # get configs
109 | cfg = config.train_config()
110 | update_config(cfg, **kwargs)
111 | cfg.seq_length = cfg.seq_length + cfg.n_speculator_heads + 1
112 |
113 | # ensure reproducibility
114 | torch.cuda.manual_seed(cfg.seed)
115 | torch.manual_seed(cfg.seed)
116 |
117 | # torchrun specific
118 | local_rank = int(os.environ["LOCAL_RANK"])
119 | rank = int(os.environ["RANK"])
120 | world_size = int(os.environ["WORLD_SIZE"])
121 |
122 | if rank == 0:
123 | print(f"{time.time()} running with these configs {cfg}")
124 |
125 | # some setups
126 | torch.cuda.set_device(local_rank)
127 |
128 | if cfg.sharding_strategy != "tp":
129 | setup()
130 | torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
131 | base_model_mesh = None
132 | speculator_mesh = None
133 | else:
134 | base_model_mesh = dist.device_mesh.init_device_mesh(
135 | "cuda",
136 | (world_size // cfg.tp_size, cfg.tp_size),
137 | mesh_dim_names=("dp", "tp"),
138 | )
139 | speculator_mesh = dist.device_mesh.init_device_mesh("cuda", (world_size,))
140 | torch._C._distributed_c10d._register_process_group(
141 | "default", base_model_mesh["tp"].get_group()
142 | )
143 |
144 | torch.cuda.empty_cache()
145 | setup_environ_flags()
146 | torch.set_default_dtype(torch.bfloat16)
147 |
148 | mixed_precision_policy = get_mixed_precision_policy(cfg, rank)
149 |
150 | model = get_model(
151 | cfg.model_arch,
152 | cfg.model_variant,
153 | model_path=cfg.model_path,
154 | device_type="cuda",
155 | source="hf",
156 | distributed_strategy=cfg.sharding_strategy,
157 | group=(
158 | base_model_mesh["tp"].get_group() if cfg.sharding_strategy == "tp" else None
159 | ),
160 | )
161 |
162 | if rank == 0:
163 | print(f"{time.time()}")
164 | print(model.config)
165 | print(model)
166 |
167 | model.eval()
168 | with torch.no_grad():
169 | test_model(rank, model, cfg.model_arch, cfg)
170 |
171 | emb_dim = get_emb_dim(model)
172 | vocab_size = get_vocab_size(model)
173 |
174 | # get speculator
175 | if rank == 0:
176 | print(f"{time.time()} Loading speculator")
177 | speculator = MLPSpeculator(
178 | emb_dim,
179 | cfg.speculator_width,
180 | vocab_size,
181 | cfg.n_speculator_heads,
182 | tie_weights=cfg.speculator_tie_weights,
183 | scale_input=cfg.speculator_scale_input,
184 | )
185 | speculator.reset_parameters()
186 |
187 | if rank == 0:
188 | total_params = sum(
189 | p.numel() for p in speculator.parameters() if p.requires_grad
190 | )
191 | print(f"\n{time.time()} speculator has {total_params / 1e6} Million params\n")
192 |
193 | # get data loader
194 | train_loader = get_training_data_loader(rank, cfg, world_size, speculator_mesh)
195 |
196 | # FSDP
197 | speculator = FSDP(
198 | speculator,
199 | auto_wrap_policy=None,
200 | mixed_precision=mixed_precision_policy,
201 | sharding_strategy=ShardingStrategy.NO_SHARD,
202 | use_orig_params=cfg.use_torch_compile,
203 | device_id=torch.cuda.current_device(),
204 | limit_all_gathers=True,
205 | sync_module_states=cfg.low_cpu_fsdp,
206 | param_init_fn=lambda module: (
207 | module.to_empty(device=torch.device("cuda"), recurse=False)
208 | if cfg.low_cpu_fsdp
209 | else None
210 | ),
211 | device_mesh=speculator_mesh if cfg.sharding_strategy == "tp" else None,
212 | )
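    # NO_SHARD keeps a full replica of the speculator on every rank (plain data parallelism under
    # the FSDP wrapper); the speculator is small, so parameter sharding is likely unnecessary.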
213 |
214 | # torch compile
215 | if cfg.use_torch_compile:
216 | if rank == 0:
217 | print(f"enabling torch compile...")
218 | if cfg.fsdp_activation_checkpointing:
219 | raise ValueError(
220 | "Compile does not yet work well with llama+ac, please"
221 | "either use it without activation checkpointing, or disable"
222 | "compile."
223 | )
224 |     # we need this post-FSDP call to avoid a graph break with torch.compile
225 | if cfg.sharding_strategy != "tp" and hasattr(model, "rot_emb"):
226 | model.rot_emb.compute_freqs_cis(
227 | torch.device("cuda", torch.cuda.current_device()),
228 | model.config.max_expected_seq_len + 10,
229 | )
230 | model = torch.compile(model)
231 | speculator = torch.compile(speculator)
232 |
233 | # Optimizer
234 | optimizer = optim.AdamW(
235 | speculator.parameters(),
236 | lr=cfg.learning_rate,
237 | betas=(0.9, 0.95),
238 | weight_decay=0.1,
239 | )
240 |
241 |     # optionally load from checkpoint (when continuing pretraining)
242 | if cfg.sharding_strategy == "tp":
243 | checkpointer = Checkpointer(
244 | cfg.ckpt_save_path,
245 | 1000,
246 | "ddp",
247 | speculator_mesh.get_rank(),
248 | speculator_mesh.get_local_rank(),
249 | model_auto_placement=True,
250 | )
251 | else:
252 | checkpointer = Checkpointer(cfg.ckpt_save_path, 1000, "ddp", rank, local_rank)
253 | speculator, optimizer, train_loader, start_step, tokens_seen, _ = checkpointer.load(
254 | speculator,
255 | optimizer,
256 | train_loader,
257 | path=os.path.join(cfg.ckpt_load_path, "checkpoints/"),
258 | is_compiled=cfg.use_torch_compile,
259 | )
260 |
261 | # LR schedule
262 | # These functions map step count to LR scaling factor in [0,1].
263 | # Stage 1: warm up over first 2k or 5% of steps, whichever is smaller.
264 | # Then cosine anneal to 10% of max LR.
265 | warmup_interval1 = min(2000, cfg.stage2_start_step // 20)
266 | stage1_schedule = lambda x: min(
267 | # Parabolic warmup
268 | 1 - (1 - min(x, warmup_interval1) / warmup_interval1) ** 2,
269 | # Final .1 scaling factor
270 | 0.1
271 | # Cosine anneal from 1 to .1 over stage2_start_step steps
272 | + 0.5 * (1 - 0.1) * (1 + math.cos(x / cfg.stage2_start_step * math.pi)),
273 | )
274 | # Stage 2: warm up over first 2k or 5% of steps, whichever is smaller.
275 | # Then cosine anneal to 10% of stage 1's final LR.
276 | warmup_interval2 = min(2000, (cfg.num_steps - cfg.stage2_start_step) // 20)
277 | stage2_schedule = lambda x: min(
278 | # Parabolic warmup to stage2's max LR (10% of stage1's max LR)
279 | 0.1 * (1 - (1 - min(x, warmup_interval2) / warmup_interval2) ** 2),
280 | # Final 10% of 10% scaling factor
281 | 0.01
282 | # Cosine anneal from .1 to .01 over remaining stage2 steps
283 | + 0.05
284 | * (1 - 0.1)
285 | * (
286 | 1
287 | + math.cos(
288 | min(x, cfg.num_steps - cfg.stage2_start_step)
289 | / (cfg.num_steps - cfg.stage2_start_step)
290 | * math.pi
291 | )
292 | ),
293 | )
294 | # Assemble full scheduling function with correct step offsets.
295 | schedule = lambda x: (
296 | stage1_schedule(x)
297 | if x <= cfg.stage2_start_step
298 | else stage2_schedule(x - cfg.stage2_start_step)
299 | )
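    # Rough illustration with hypothetical settings (stage2_start_step=100_000, num_steps=150_000):
    # warmup_interval1 = min(2000, 5000) = 2000, so schedule(0) ~ 0, schedule(2000) ~ 1.0, and
    # schedule(100_000) = 0.1 at the end of stage 1; stage 2 then re-warms to ~0.1 and decays
    # to ~0.01 by schedule(150_000).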
300 | scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step))
301 |
302 | # profiler
303 | profiler = get_profiler(cfg, rank)
304 |
305 | # Train
306 | if rank == 0:
307 | print(f"{time.time()} Training for {cfg.num_steps} steps")
308 | torch.cuda.empty_cache()
309 | train_speculator(
310 | cfg,
311 | model,
312 | speculator,
313 | local_rank,
314 | rank,
315 | train_loader,
316 | optimizer,
317 | scheduler,
318 | checkpointer,
319 | start_step,
320 | tokens_seen,
321 | profiler,
322 | base_model_mesh,
323 | )
324 |
325 | dist.barrier()
326 | dist.destroy_process_group()
327 |
328 |
329 | if __name__ == "__main__":
330 | fire.Fire(main)
331 |
--------------------------------------------------------------------------------
/speculator/train_speculator_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 | from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
5 |
6 | import torch
7 | import torch.distributed as dist
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from fms.models import register_model
11 | from fms.models.gpt_bigcode import GPTBigCode
12 | from fms.models.gpt_bigcode import _20b_config as _gpt_bigcode_20b_config
13 | from fms.models.gpt_bigcode import _hf_to_fms_names as _gptbigcode_hf_sd_to_fms_sd
14 | from fms.models.llama import LLaMA
15 | from fms.models.llama import _hf_to_fms_names as _llama_hf_sd_to_fms_sd
16 | from fms.models.mixtral import Mixtral, MixtralConfig
17 | from fms.models.mixtral import _hf_to_fms_names as _mixtral_hf_sd_to_fms_sd
18 | from fms.utils import serialization, tokenizers
19 | from fms.utils.generation import _make_cache_contiguous
20 | from torch.nn import CrossEntropyLoss
21 | from torch.utils.data import DataLoader
22 |
23 | from fms_fsdp.config import train_config
24 | from fms_fsdp.utils.checkpointing_utils import Checkpointer
25 | from fms_fsdp.utils.config_utils import get_model_config
26 |
27 |
28 | def generate(
29 | model: Union[Callable, torch.nn.Module],
30 | input_ids: torch.Tensor,
31 | max_seq_len: int = 2048,
32 | max_new_tokens: int = 256,
33 | temperature: float = 1.0,
34 | top_k: int = 10,
35 | do_sample: bool = True,
36 | num_beams: int = 1,
37 | use_cache: bool = False,
38 | contiguous_cache: bool = False,
39 | include_embeds: bool = True,
40 | ):
41 | """
42 | A straightforward copy of the generate method in fms.utils.generation.
43 | The only change is the include_embeds flag, which when true also returns
44 | the embedding vectors corresponding to the tokens in the output sequence.
45 | """
46 | batched = False
47 | if num_beams != 1:
48 |         raise NotImplementedError("generate() does not yet support beam search")
49 | if type(input_ids) == torch.Tensor:
50 | if input_ids.dim() != 1:
51 | batched = True
52 | else:
53 | raise RuntimeError("generate() requires a tensor of token ids as the prefix")
54 |
55 | if not batched:
56 | input_ids = input_ids.unsqueeze(0)
57 |
58 | embeds = None
59 | result = input_ids
60 | next_input = input_ids
61 | kwargs: MutableMapping[str, Any] = dict()
62 | kwargs["past_key_value_states"] = None
63 | kwargs["use_cache"] = use_cache
64 | kwargs["include_embeds"] = include_embeds
65 |
66 | for _ in range(max_new_tokens):
67 | input_ids = next_input[:, -max_seq_len:]
68 | output = model(input_ids, **kwargs)
69 | if not use_cache and not include_embeds:
70 | logits = output
71 | else:
72 | logits = output[0]
73 | if include_embeds:
74 | z = output[-1]
75 | if use_cache:
76 | past_key_value_states = output[1]
77 | # TODO: this should go away when reduce-overhead issues are fixed, or
78 | # maybe could be moved into model code to be more portable.
79 | if contiguous_cache:
80 | kwargs["past_key_value_states"] = _make_cache_contiguous(
81 | past_key_value_states
82 | )
83 | else:
84 | kwargs["past_key_value_states"] = past_key_value_states
85 | logits = logits[:, -1, :]
86 |
87 | if do_sample:
88 |             # get logits from last value in sequence and scale
89 | logits = logits / temperature
90 | if top_k:
91 | v, _ = torch.topk(logits, top_k)
92 | logits[logits < v[:, [-1]]] = -float("inf")
93 |
94 | probs = F.softmax(logits, dim=-1)
95 | next_val = torch.multinomial(probs, num_samples=1)
96 | else:
97 | next_val = torch.argmax(logits, dim=-1).unsqueeze(0).t()
98 |
99 | result = torch.cat((result, next_val), dim=-1)
100 |
101 | if use_cache:
102 | next_input = next_val
103 | else:
104 | next_input = result
105 |
106 | if include_embeds:
107 | if embeds is None:
108 | embeds = z
109 | else:
110 | embeds = torch.cat((embeds, z), dim=-2)
111 |
112 | if not batched:
113 | result = result[0]
114 |
115 | if include_embeds:
116 | return result, embeds
117 |
118 | return result
119 |
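# Usage sketch (illustrative; `prompt_ids` is a placeholder tensor of token ids): with
# include_embeds=True the function returns a (tokens, embeds) pair, which is how stage2_loss
# below obtains serially generated targets together with their embeddings, e.g.
#   targs, embeds = generate(model, prompt_ids, max_seq_len, max_new_tokens,
#                            do_sample=True, use_cache=True, include_embeds=True)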
120 |
121 | # Stage 1 training
122 | def stage1_loss(
123 | cfg, model, speculator, base_model_input, input, loss_fn, ddp_stats, base_model_mesh
124 | ):
125 | """
126 | Perform a forward pass for stage 1 training and calculate the loss.
127 | Given the sequence of embeddings produced in parallel by the base model,
128 | get n+2,n+3,... speculator predictions and compare to ground truth tokens.
129 | ...
130 | Args
131 | ----
132 | cfg: train_config
133 | Set of training parameters.
134 | model: nn.Module
135 | The frozen base model. Must return output logits AND corresponding embedding vectors.
136 | speculator: nn.Module
137 |         The speculator to be trained. Takes as input a sequence of embeddings and token indices,
138 |         and returns token prediction logits for each head.
139 | input: torch.IntTensor
140 | The ground truth token indices. If using TP, this is per TP rank,
141 | with 'base_model_input' containing all-gathered input across all TP ranks
142 | loss_fn: Callable
143 |         Torch loss function comparing logits to indices, e.g. CrossEntropyLoss()
144 | ddp_stats: torch.FloatTensor
145 | Aggregate stat tracking buffer.
146 | Entries are: grad norm, accumulation steps, head 1 loss, head 2 loss, etc.
147 | base_model_mesh: torch.distributed.device_mesh.DeviceMesh
148 |         Device layout of the participating process group ranks
149 | ----
150 | Returns: scalar loss value, updated ddp stats, number of tokens in input
151 | """
152 | with torch.no_grad():
153 | _, embeds = model(
154 | base_model_input[:, : -speculator.n_predict - 1],
155 | include_embeds=True,
156 | use_cache=False,
157 | )
158 | if cfg.sharding_strategy == "tp":
159 | embeds = embeds.chunk(base_model_mesh["tp"].size())[
160 | base_model_mesh["tp"].get_local_rank()
161 | ]
162 |
163 | preds = speculator(embeds.detach(), input[:, 1:])
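    # Alignment of the loop below: preds[i][:, j] is head i's prediction for input[:, j + i + 2],
    # hence the i + 2 offset in the target slice.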
164 | losses = []
165 | for i in range(preds.size(0)):
166 | targ = input[:, i + 2 : preds.size(2) + i + 2] # b n
167 | loss = loss_fn(preds[i].reshape(-1, preds.size(3)), targ.long().reshape(-1))
168 | losses.append(loss)
169 | ddp_stats[2 + i] += loss.item()
170 | loss = sum(losses)
171 | return loss, ddp_stats, input.numel()
172 |
173 |
174 | # Stage 2 training: more heavyweight than stage 1; will take longer
175 | def stage2_loss(
176 | cfg, model, speculator, base_model_input, input, loss_fn, ddp_stats, base_model_mesh
177 | ):
178 | """
179 | Perform a forward pass for stage 2 training and calculate the loss.
180 | Given the sequence of embeddings produced in serial by the base model,
181 | get n+1,n+2,... speculator predictions and compare to base model's generated tokens.
182 | Reshapes input to more entries / shorter sequences, for more efficient generation.
183 | ...
184 | Args
185 | ----
186 | cfg: train_config
187 | Set of training parameters. Used here for reshaping input batches.
188 | model: nn.Module
189 | The frozen base model. Must return output logits AND corresponding embedding vectors.
190 | speculator: nn.Module
191 |         The speculator to be trained. Takes as input a sequence of embeddings and token indices,
192 |         and returns token prediction logits for each head.
193 | input: torch.IntTensor
194 | The ground truth token indices. If using TP, this is per TP rank,
195 | with 'base_model_input' containing all-gathered input across all TP ranks
196 | loss_fn: Callable
197 |         Torch loss function comparing logits to indices, e.g. CrossEntropyLoss()
198 | ddp_stats: torch.FloatTensor
199 | Aggregate stat tracking buffer.
200 | Entries are: grad norm, accumulation steps, head 1 loss, head 2 loss, etc.
201 | base_model_mesh: torch.distributed.device_mesh.DeviceMesh
202 |         Device layout of the participating process group ranks
203 | ----
204 | Returns: scalar loss value, updated ddp stats, number of tokens in input
205 | """
206 | with torch.no_grad():
207 | grow_factor = cfg.stage2_batch_size // cfg.batch_size
208 | assert (
209 | cfg.stage2_prompt_length * grow_factor <= cfg.seq_length
210 | ), "Error: batch is too small for specified partition"
211 | base_model_input = base_model_input[
212 | :, : cfg.stage2_prompt_length * grow_factor
213 | ].reshape(base_model_input.size(0) * grow_factor, cfg.stage2_prompt_length)
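        # Worked example with hypothetical values: batch_size=2, stage2_batch_size=8,
        # stage2_prompt_length=64 -> grow_factor=4, so a [2, >=256] batch is truncated to
        # [2, 256] and reshaped into [8, 64] prompts before generation.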
214 | targs, embeds = generate(
215 | model,
216 | base_model_input,
217 | cfg.seq_length,
218 | cfg.stage2_seq_length,
219 | do_sample=True,
220 | use_cache=True,
221 | include_embeds=True,
222 | )
223 |
224 | if cfg.sharding_strategy == "tp":
225 | targs = targs.chunk(base_model_mesh["tp"].size())[
226 | base_model_mesh["tp"].get_local_rank()
227 | ]
228 | embeds = embeds.chunk(base_model_mesh["tp"].size())[
229 | base_model_mesh["tp"].get_local_rank()
230 | ]
231 | targs = targs[:, -cfg.stage2_seq_length :]
232 | embeds = embeds[:, -cfg.stage2_seq_length : -speculator.n_predict]
233 | preds = speculator(embeds.detach(), targs[:, :-1].detach())
234 |
235 | losses = []
236 | for i in range(preds.size(0)):
237 | targ = targs[:, i + 1 : preds.size(2) + i + 1] # b n
238 | loss = loss_fn(preds[i].reshape(-1, preds.size(3)), targ.long().reshape(-1))
239 | losses.append(loss)
240 | ddp_stats[2 + i] += loss.item()
241 | loss = sum(losses)
242 | return loss, ddp_stats, targs.numel()
243 |
244 |
245 | # on demand checkpointing: echo 1 > /path/to/model_ckpt_dir/do_ckpt
246 | def do_ckpt(ckpt_save_path, reset=False):
247 | ckpt_cmd_file = ckpt_save_path + "/do_ckpt"
248 | if not os.path.exists(ckpt_cmd_file):
249 | return False
250 |
251 | if reset:
252 | with open(ckpt_cmd_file, "w") as fd:
253 | fd.write("0")
254 | return False
255 |
256 | with open(ckpt_cmd_file) as fd:
257 | if fd.read().strip() == "1":
258 | return True
259 |
260 | return False
261 |
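# Example trigger (sketch): from outside the job, `echo 1 > <ckpt_save_path>/do_ckpt` requests an
# on-demand checkpoint; after saving, train_speculator() calls do_ckpt(..., reset=True) to write
# "0" back into the file.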
262 |
263 | def train_speculator(
264 | cfg: train_config,
265 | model: nn.Module,
266 | speculator: nn.Module,
267 | local_rank: int,
268 | rank: int,
269 | train_loader: DataLoader,
270 | optimizer: torch.optim.Optimizer,
271 | scheduler: torch.optim.lr_scheduler.LRScheduler,
272 | checkpointer: Checkpointer,
273 | start_step: int = 0,
274 | n_tok: int = 0,
275 |     profiler: Optional[torch.profiler.profile] = None,
276 | base_model_mesh=None,
277 | ):
278 | """
279 | The training loop for speculator training. Handles at a high level: data loading,
280 | forward and backward passes, model updates, stat tracking, reporting, and checkpointing.
281 | ...
282 | Args
283 | ----
284 | cfg: train_config
285 | The set of training parameters
286 | model: nn.Module
287 | The frozen base model. Must return output logits AND corresponding embedding vectors.
288 | speculator: nn.Module
289 |         The speculator to be trained. Takes as input a sequence of embeddings and token indices,
290 | and returns token prediction logits for each head.
291 | local_rank: int
292 | The local rank of the current process. Used for stat tracking / aggregation across ranks.
293 | rank: int
294 | The global rank of the current process. Used for reporting.
295 | train_loader: torch.utils.data.DataLoader
296 | The dataloader used for reading in ground truth token sequences. Train_loader.dataset must
297 | support save_to_path() for distributed checkpointing via checkpointer.
298 | optimizer: torch.optim.Optimizer
299 | The optimizer associated with the speculator's weights
300 | scheduler: torch.optim.lr_scheduler.LRScheduler
301 | A scheduler for the optimizer's LR. Scheduler.step() is called on every optimizer step.
302 | checkpointer: fms_fsdp.utils.checkpointing_utils.Checkpointer
303 | A checkpointer tied to the save directory. Used for saving distributed checkpoints.
304 | start_step: optional[int]
305 | If resuming from checkpoint, resume step count from this value.
306 | n_tok: optional[int]
307 | If resuming from checkpoint, resume token count from this value.
308 | profiler: optional[torch.profiler.profile]
309 | Optional torch profiler for performance benchmarking.
310 | base_model_mesh: DeviceMesh
311 |         Device layout of the participating process group ranks
312 | """
313 | model.eval()
314 | speculator.train()
315 | ddp_stats = torch.zeros(2 + speculator.n_predict).to(local_rank)
316 |
317 | start = time.time()
318 | loop_start = time.time()
319 | loss_fn = CrossEntropyLoss()
320 | elapsed_tokens = 0
321 | for batch_idx, input in enumerate(train_loader, start=start_step + 1):
322 | if batch_idx > cfg.num_steps:
323 | break
324 |
325 | input = input.to(local_rank)
326 |
327 | if cfg.sharding_strategy == "tp":
328 | base_model_input = torch.zeros(
329 | base_model_mesh["tp"].size() * input.size(0),
330 | input.size(1),
331 | dtype=input.dtype,
332 | device=input.device,
333 | )
334 | dist.all_gather_into_tensor(
335 | base_model_input, input, group=base_model_mesh["tp"].get_group()
336 | )
337 | else:
338 | base_model_input = input
339 |
340 | optimizer.zero_grad()
341 |
342 | if batch_idx <= cfg.stage2_start_step:
343 | loss, ddp_stats, step_tok = stage1_loss(
344 | cfg,
345 | model,
346 | speculator,
347 | base_model_input,
348 | input,
349 | loss_fn,
350 | ddp_stats,
351 | base_model_mesh,
352 | )
353 | else:
354 | loss, ddp_stats, step_tok = stage2_loss(
355 | cfg,
356 | model,
357 | speculator,
358 | base_model_input,
359 | input,
360 | loss_fn,
361 | ddp_stats,
362 | base_model_mesh,
363 | )
364 |
365 | loss.backward()
366 | ddp_stats[0] += speculator.clip_grad_norm_(cfg.grad_clip_thresh).item()
367 | optimizer.step()
368 | scheduler.step()
369 |
370 | ddp_stats[1] += 1
371 |
372 | if profiler:
373 | profiler.step()
374 |
375 | if batch_idx % cfg.report_interval == 0:
376 | dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM)
377 | train_loss = ddp_stats[2:] / ddp_stats[1]
378 | g_norm = ddp_stats[0] / ddp_stats[1]
379 | elapsed_time = time.time() - loop_start
380 | world_size = int(os.environ["WORLD_SIZE"])
381 | elapsed_tokens += cfg.report_interval * world_size * step_tok
382 | if rank == 0:
383 | print(f"{time.time()}")
384 | print("step:", batch_idx)
385 | print("tokens seen:", n_tok + elapsed_tokens)
386 | for i in range(len(train_loss)):
387 | print(f"loss {i+1}:", train_loss[i].item())
388 | print("gradient norm:", g_norm.item())
389 | print(
390 | f"speed for these {cfg.report_interval} steps:",
391 | (time.time() - start) / cfg.report_interval,
392 | )
393 | print("overall speed:", elapsed_time / (batch_idx - start_step))
394 | print("LR:", scheduler.get_last_lr())
395 | print(
396 | "reserved memory:",
397 | torch.cuda.max_memory_reserved(device=torch.cuda.current_device()),
398 | )
399 | print(
400 | "active memory:",
401 | torch.cuda.max_memory_allocated(device=torch.cuda.current_device()),
402 | )
403 | print(
404 | "overall token per gpu per sec:",
405 | int(elapsed_tokens / world_size / elapsed_time),
406 | )
407 | print("token per day:", int(elapsed_tokens / elapsed_time * 3600 * 24))
408 | print()
409 | start = time.time()
410 | ddp_stats.zero_()
411 | torch.cuda.reset_peak_memory_stats(device=torch.cuda.current_device())
412 |
413 | if (
414 | batch_idx % cfg.checkpoint_interval == 0
415 | or batch_idx == cfg.num_steps
416 | or do_ckpt(cfg.ckpt_save_path) is True
417 | ):
418 | torch.cuda.empty_cache()
419 | checkpointer.save(
420 | batch_idx,
421 | speculator,
422 | optimizer,
423 | train_loader,
424 | tokens_seen=elapsed_tokens + n_tok,
425 | )
426 | torch.cuda.empty_cache()
427 | do_ckpt(cfg.ckpt_save_path, reset=True)
428 |
429 |
430 | class EmbedGPTBigCode(GPTBigCode):
431 | # Overrides the forward function of GPTBigCode to allow returning embedding vectors
432 | def forward(
433 | self,
434 | x: torch.Tensor,
435 | mask: Optional[torch.Tensor] = None,
436 | position_ids: Optional[torch.Tensor] = None,
437 | past_key_value_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
438 | use_cache: bool = False,
439 | only_last_token: bool = False,
440 | attn_algorithm: Optional[str] = None,
441 | include_embeds: bool = False,
442 | ):
443 | output, cache = self.base_model(
444 | x,
445 | mask,
446 | position_ids=position_ids,
447 | past_key_value_states=past_key_value_states,
448 | use_cache=use_cache,
449 | attn_algorithm=attn_algorithm,
450 | )
451 |
452 | preds = self.head(output)
453 |
454 | out = [preds]
455 | if use_cache:
456 | out.append(cache)
457 | if include_embeds:
458 | out.append(output)
459 | if len(out) == 1:
460 | return out[0]
461 | return out
462 |
463 |
464 | class EmbedLLaMA(LLaMA):
465 | # Overrides the forward function of LLaMA to allow returning embedding vectors
466 | def forward(
467 | self,
468 | x,
469 | mask=None,
470 | position_ids=None,
471 | past_key_value_states=None,
472 | use_cache=False,
473 | only_last_token=False,
474 | attn_algorithm=None,
475 | include_embeds=False,
476 | ):
477 | output, cache = self._helper(
478 | x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
479 | )
480 |
481 | if only_last_token:
482 | output = output[:, -1, :]
483 | preds = self.shared(output, reverse=True)
484 |
485 | out = [preds]
486 | if use_cache:
487 | out.append(cache)
488 | if include_embeds:
489 | out.append(output)
490 | if len(out) == 1:
491 | return out[0]
492 | return out
493 |
494 |
495 | class EmbedMixtral(Mixtral): # FMS impl of Mixtral
496 | # Overrides the forward function of Mixtral to allow returning embedding vectors
497 | def forward(
498 | self,
499 | x,
500 | mask=None,
501 | position_ids=None,
502 | past_key_value_states=None,
503 | use_cache=False,
504 | only_last_token=False,
505 | attn_algorithm=None,
506 | include_embeds=False,
507 | ):
508 | output, cache = self.base_model(
509 | x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
510 | )
511 |
512 | if only_last_token:
513 | output = output[:, -1, :]
514 | preds = self.head(output)
515 |
516 | out = [preds]
517 | if use_cache:
518 | out.append(cache)
519 | if include_embeds:
520 | out.append(output)
521 | if len(out) == 1:
522 | return out[0]
523 | return out
524 |
525 |
526 | def _gpt_bigcode_factory_factory(config):
527 | def factory(**kwargs):
528 | return EmbedGPTBigCode(config, **kwargs)
529 |
530 | return factory
531 |
532 |
533 | def _llama_factory_factory(config):
534 | def factory(**kwargs):
535 | return EmbedLLaMA(config, **kwargs)
536 |
537 | return factory
538 |
539 |
540 | def _mixtral_factory_factory(config):
541 | def factory(**kwargs):
542 | return EmbedMixtral(config, **kwargs)
543 |
544 | return factory
545 |
546 |
547 | # example model registrations
548 | register_model(
549 | "embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config)
550 | )
551 | serialization.register_adapter_step(
552 | "embedgpt_bigcode", "hf_to_fms", _gptbigcode_hf_sd_to_fms_sd
553 | )
554 | serialization.register_adapter("embedgpt_bigcode", "hf", ["hf_to_fms"])
555 |
556 | register_model(
557 | "embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b"))
558 | )
559 | register_model(
560 | "embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b"))
561 | )
562 | serialization.register_adapter_step("embedllama", "hf_to_fms", _llama_hf_sd_to_fms_sd)
563 | serialization.register_adapter("embedllama", "hf", ["hf_to_fms"])
564 |
565 | register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig()))
566 | serialization.register_adapter_step(
567 | "embedmixtral", "hf_to_fms", _mixtral_hf_sd_to_fms_sd
568 | )
569 | serialization.register_adapter("embedmixtral", "hf", ["hf_to_fms"])
570 |
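# Usage sketch (assumption, mirroring main() in train_speculator.py): once registered above, the
# embedding-returning variants are constructed through the FMS model factory, e.g.
#   model = get_model("embedllama", "8b", model_path=..., source="hf", device_type="cuda")
# driven by cfg.model_arch / cfg.model_variant in the training config.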
--------------------------------------------------------------------------------
/test-requirements.txt:
--------------------------------------------------------------------------------
1 | # Used to install pinned test dependencies
2 | # Useful for dev/test jobs caches
3 |
4 | -r requirements.txt
5 |
6 | # Test tools
7 | black==24.1.1
8 | mypy==1.8.0
9 | mypy-extensions==1.0.0
10 | pytest==8.1.1
11 |
12 | # Types packages
13 | pyarrow-stubs==10.0.1.7
14 | types-requests==2.31.0.20240125
15 | types-setuptools==69.0.0.20240125
16 |
17 | # Local import libraries that we don't want to put in global requirements.txt
18 | wandb==0.16.4
19 | aim==3.19.1
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from fms.models.llama import LLaMA, LLaMAConfig
3 |
4 |
5 | @pytest.fixture
6 | def narrow_model(request):
7 | return LLaMA(
8 | LLaMAConfig(src_vocab_size=1, emb_dim=1, nheads=1, nlayers=request.param)
9 | )
10 |
11 |
12 | @pytest.fixture
13 | def narrow_model_factory(request):
14 | class NarrowModelFactory:
15 | def create(self):
16 | return LLaMA(
17 | LLaMAConfig(
18 | src_vocab_size=1, emb_dim=1, nheads=1, nlayers=request.param
19 | )
20 | )
21 |
22 | return NarrowModelFactory()
23 |
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | import tempfile
4 | from collections import Counter
5 | from copy import deepcopy
6 | from itertools import chain
7 |
8 | import pyarrow as pa
9 | import torch
10 |
11 | from fms_fsdp.utils.dataset_utils import *
12 |
13 |
14 | # Generates test data in a temp directory, and returns that tempdir object.
15 | # (file path can be retrieved via tempdir.name)
16 | # Two dataset folders: one has a large shardfile (100x100), the other has two small shardfiles (50x50)
17 | def generate_sequential_multidata():
18 | tmpdir = tempfile.TemporaryDirectory()
19 | schema = pa.schema([pa.field("tokens", pa.uint32())])
20 |
21 | os.mkdir(os.path.join(tmpdir.name, "dataset_1"))
22 | os.mkdir(os.path.join(tmpdir.name, "dataset_2"))
23 | os.mkdir(os.path.join(tmpdir.name, "dataset_2", "subfolder"))
24 | with pa.ipc.new_file(
25 | os.path.join(tmpdir.name, "dataset_1/fullshard.arrow"), schema
26 | ) as writer:
27 | for i in range(100):
28 | out = list(range(i * 100, i * 100 + 100))
29 | writer.write(pa.record_batch([out], schema=schema))
30 |
31 | with pa.ipc.new_file(
32 | os.path.join(tmpdir.name, "dataset_2/quartershard_1.arrow"), schema
33 | ) as writer:
34 | for i in range(50):
35 | out = list(range(i * 50, i * 50 + 50))
36 | writer.write(pa.record_batch([out], schema=schema))
37 |
38 | with pa.ipc.new_file(
39 | os.path.join(tmpdir.name, "dataset_2/subfolder/quartershard_2.arrow"), schema
40 | ) as writer:
41 | for i in range(50):
42 | out = list(range(2500 + i * 50, 2500 + i * 50 + 50))
43 | writer.write(pa.record_batch([out], schema=schema))
44 |
45 | # Make metadata file
46 | os.mkdir(os.path.join(tmpdir.name, "meta"))
47 | f = open(os.path.join(tmpdir.name, "meta", "combined_counts.csv"), "w")
48 | f.write("dataset/filename,documents,tokens\n")
49 | f.write("/dataset_1/fullshard.arrow,100,10000\n")
50 | f.write("/dataset_2/quartershard_1.arrow,50,2500\n")
51 | f.write("/dataset_2/subfolder/quartershard_2.arrow,50,2500\n")
52 | f.close()
53 |
54 | return tmpdir
55 |
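# Resulting layout under tmpdir.name, for reference:
#   dataset_1/fullshard.arrow                  100 docs x 100 tokens
#   dataset_2/quartershard_1.arrow              50 docs x  50 tokens
#   dataset_2/subfolder/quartershard_2.arrow    50 docs x  50 tokens
#   meta/combined_counts.csv                    per-shard doc/token counts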
56 |
57 | # Make mock data for re-use. Returns the TemporaryDirectory object (path available via tmpdir.name).
58 | tmpdir = generate_sequential_multidata()
59 |
60 |
61 | # REPEATED CHECKS
62 | # Checks take a dataset definition (and any other args), instantiate it, and perform a single unit test
63 | # For X_check see corresponding test_X
64 |
65 |
66 | def count_check(d, ntok, alldoc, allpercent):
67 |     # Check that tracked token and document counts match the expected totals
68 |     # d is an instantiated dataset (unlike the other checks, which take a dataset-constructing lambda)
69 | assert (
70 | d.tokens_seen == ntok
71 | ), f"Tokens tracked {d.tokens_seen} failed to match target {ntok}"
72 | assert (
73 | d.docs_seen == alldoc
74 | ), f"Total document count {d.docs_seen} does not match target {alldoc}"
75 | coverage = d.percent_seen
76 | assert (
77 | abs(coverage - allpercent) < 1e-4
78 | ), f"Percent coverage {coverage} is not within 1e-4 of {allpercent}"
79 |
80 |
81 | def multi_reload_stress_check(d):
82 | # Perform the reload stress test for different numbers of steps before and after checkpoint
83 | # d is a lambda for a fully-defined dataset (i.e. d() instantiates the dataset)
84 |
85 | def reload_stress(datasets, datasets2, steps1, steps2):
86 | # Perform the 5-step reload stress test (see test_multi_reload_stress)
87 |
88 | loaders = [iter(d) for d in datasets]
89 |
90 | for i in range(steps1):
91 | [next(l) for l in loaders]
92 |
93 | states = [deepcopy(d.state_dict()) for d in datasets]
94 |
95 | [d.load_state_dict(states) for d in datasets2]
96 |
97 | loaders2 = [iter(d) for d in datasets2]
98 |
99 | for k in range(steps2):
100 | for i in range(3):
101 | out1 = next(loaders[i])
102 | out2 = next(loaders2[i])
103 | assert len(out1) == len(
104 | out2
105 | ), f"Dataloader {i} in step {k} has mismatched length: {len(out1)} vs {len(out2)}"
106 | for j in range(len(out1)):
107 | assert (
108 | out1[j] == out2[j]
109 | ), f"Dataloader {i} in step {k} has mismatched token in position {j}: {out1[j]} vs {out2[j]}"
110 |
111 | steps1 = [0, 1, 10, 100, 1000]
112 | steps2 = [100, 200, 300, 400, 500]
113 | for i in range(len(steps1)):
114 | # Reset between tests (instantiate fresh datasets)
115 | reload_stress(d(), d(), steps1[i], steps2[i])
116 |
117 |
118 | def single_epoch_check(d, do_countcheck=False):
119 | # For a single loader on dataset_1, check that every doc appears once per epoch
120 | # d is a lambda for a dataset (i.e. d() instantiates the dataset)
121 | dataset = d(datasets=["dataset_1"])
122 | ins = []
123 | loader = iter(dataset)
124 | for i in range(100):
125 | out = next(loader)
126 | ins.append(out[0])
127 |
128 | for i in range(100):
129 | assert (
130 | i * 100 in ins
131 | ), f"Line starting with {i * 100} failed to appear in generated data"
132 |
133 | if do_countcheck:
134 | # Check state flags tracking correctly
135 | count_check(dataset, 100 * 100, 100, 100)
136 |
137 |
138 | def two_epoch_check(d, do_countcheck=False):
139 | # For a single loader on dataset_1, check that every doc appears twice per two epochs
140 | # d is a lambda for a dataset (i.e. d() instantiates the dataset)
141 | dataset = d(datasets=["dataset_1"])
142 | ins = []
143 | loader = iter(dataset)
144 | for i in range(100 * 2):
145 | out = next(loader)
146 | ins.append(out[0])
147 |
148 | for i in range(100):
149 | key = ins.pop(0)
150 | assert (
151 | key in ins
152 | ), f"Line starting with {key} failed to appear a second time in generated data"
153 |
154 | if do_countcheck:
155 | # Check state flags tracking correctly
156 | count_check(dataset, 100 * 100 * 2, 200, 200)
157 |
158 |
159 | def chunk_check(d, do_countcheck=False):
160 | # For a single loader on dataset_1, check that every doc chunks properly and that all chunks appear in one epoch
161 | # d is a lambda for a dataset (i.e. d() instantiates the dataset)
162 | dataset = d(datasets=["dataset_1"], max_chunksize=50)
163 | ins = []
164 | loader = iter(dataset)
165 | for i in range(300):
166 | out = next(loader)
167 | if i % 3 != 2:
168 | assert (
169 | len(out) == 50
170 | ), f"Line length {len(out)} does not match chunk size 50"
171 | else:
172 | assert (
173 | out[0] == -1
174 | ), f"Chunk 3 of document {i} is not delimiter token, but {out}"
175 | ins.append(out[0])
176 |
177 | for i in range(200):
178 | assert (
179 | i * 50 in ins
180 | ), f"Chunk starting with {i * 50} failed to appear in generated data"
181 |
182 | if do_countcheck:
183 | count_check(dataset, 100 * 100, 100, 100)
184 |
185 |
186 | def two_loader_check(d, do_countcheck=False):
187 | # For two loaders on dataset_1, check that every doc appears once per epoch, collectively
188 | # d is a lambda for a dataset (i.e. d() instantiates the dataset)
189 | dataset1 = d(datasets=["dataset_1"], worldsize=2, rank=0)
190 | dataset2 = d(datasets=["dataset_1"], worldsize=2, rank=1)
191 | ins = []
192 | loader = iter(dataset1)
193 | for i in range(50):
194 | out = next(loader)
195 | ins.append(out[0])
196 | loader = iter(dataset2)
197 | for i in range(50):
198 | out = next(loader)
199 | ins.append(out[0])
200 |
201 | for i in range(100):
202 | assert (
203 | i * 100 in ins
204 | ), f"Line starting with {i * 100} failed to appear in generated data"
205 |
206 | if do_countcheck:
207 | count_check(dataset1, 50 * 100, 50, 100)
208 | count_check(dataset2, 50 * 100, 50, 100)
209 |
210 |
211 | def multi_file_check(d, do_countcheck=False):
212 | # For a single loader on dataset 2, check that every doc appears once per epoch
213 | # d is a lambda for a dataset (i.e. d() instantiates the dataset)
214 | dataset = d(datasets=["dataset_2"])
215 | ins = []
216 | loader = iter(dataset)
217 | for i in range(100):
218 | out = next(loader)
219 | ins.append(out[0])
220 |
221 | for i in range(100):
222 | assert (
223 | i * 50 in ins
224 | ), f"Line starting with {i * 50} failed to appear in generated data"
225 |
226 | if do_countcheck:
227 | count_check(dataset, 100 * 50, 100, 100)
228 |
229 |
230 | def chunk_weight_check(w1, w2, d, do_countcheck=False):
231 | # For a single loader on combined datasets, with the given oversampling weights and chunk size 50: check that chunks appear the proper number of times
232 | # d is a lambda for a dataset (i.e. d() instantiates the dataset)
233 | dataset = d(datasets=["dataset_1", "dataset_2"], weights=[w1, w2], max_chunksize=50)
234 | ins = []
235 | loader = iter(dataset)
236 | for i in range(3 * w1 * 100 + 2 * w2 * 100):
237 | out = next(loader)
238 | if len(out) > 1:
239 | ins.append(out[0])
240 |
241 | check = Counter(ins)
242 | for i in range(200):
243 | if i < 100:
244 | assert (
245 | check[i * 50] == w1 + w2
246 | ), f"Chunk starting with {i * 50} appeared {check[i*50]} times rather than {w1+w2}"
247 | else:
248 | assert (
249 | check[i * 50] == w1
250 | ), f"Chunk starting with {i * 50} appeared {check[i*50]} times rather than {w1}"
251 |
252 |
253 | def reload_epoch_check(loader):
254 | # Single shard, two loaders: do exactly 1/3 of an epoch, checkpoint, reload to same number of workers.
255 | # Complete the epoch and verify that no loaded chunks are revisiting old chunks.
256 | datasets = [
257 | loader(
258 | rank=i,
259 | worldsize=2,
260 | max_chunksize=40,
261 | )
262 | for i in range(2)
263 | ] # Length 300
264 | loaders = [iter(d) for d in datasets]
265 |
266 | ins = []
267 | for _ in range(50):
268 | out = next(loaders[0])
269 | ins.append(out[0])
270 | for _ in range(50):
271 | out = next(loaders[1])
272 | ins.append(out[0])
273 |
274 | states = [d.state_dict() for d in datasets]
275 |
276 | datasets2 = [
277 | loader(
278 | rank=i,
279 | worldsize=2,
280 | max_chunksize=40,
281 | )
282 | for i in range(2)
283 | ] # Length 300
284 | [d.load_state_dict(states) for d in datasets2]
285 | loaders2 = [iter(d) for d in datasets2]
286 |
287 | for j in range(100):
288 | for i in range(2):
289 | out = next(loaders2[i])
290 | assert (
291 | out[0] not in ins
292 | ), f"Step {j+1}, dataset {i+1}: chunk starting with {out[0]} has already appeared in the epoch"
293 |
294 |
295 | def reload_single_epoch_check(loader):
296 | # Single shard, two loaders: advance 37 steps, checkpoint, reload to same number of workers.
297 | # Run a full epoch and verify that all data appears once and only once.
298 | datasets = [
299 | loader(
300 | rank=i,
301 | worldsize=2,
302 | max_chunksize=40,
303 | )
304 | for i in range(2)
305 | ] # Length 300
306 | loaders = [iter(d) for d in datasets]
307 |
308 | for _ in range(37):
309 | out = next(loaders[0])
310 | for _ in range(37):
311 | out = next(loaders[1])
312 |
313 | states = [d.state_dict() for d in datasets]
314 |
315 | datasets2 = [
316 | loader(
317 | rank=i,
318 | worldsize=2,
319 | max_chunksize=40,
320 | )
321 | for i in range(2)
322 | ] # Length 300
323 | [d.load_state_dict(states) for d in datasets2]
324 | loaders2 = [iter(d) for d in datasets2]
325 |
326 | ins = []
327 | for _ in range(150):
328 | out = next(loaders2[0])
329 | assert out[0] not in ins, (ins, out[0])
330 | ins.append(out[0])
331 | for _ in range(150):
332 | out = next(loaders2[1])
333 | ins.append(out[0])
334 |
335 | assert len(ins) == len(
336 | set(ins)
337 | ), f"Full epoch output contains {len(ins)} values but only {len(set(ins))} unique"
338 |
339 |
340 | def single_doc_bos_eos_check(loader, do_bos):
341 | # Single shard, single loader: load two chunks, verify that sizes match when BOS is on/off
342 | expected_vals = (
343 | [
344 | [99, 3],
345 | [100, 2],
346 | [101, 1],
347 | [102, 102],
348 | [102, 102],
349 | ]
350 | if do_bos
351 | else [
352 | [99, 2],
353 | [100, 1],
354 | [101, 101],
355 | [101, 101],
356 | [101, 101],
357 | ]
358 | )
359 | for i, c in enumerate([99, 100, 101, 102, 103]):
360 | dataset = loader(
361 | rank=0, worldsize=1, max_chunksize=c, bos_token=100 if do_bos else None
362 | )
363 | d = iter(dataset)
364 | for _ in range(10):
365 | c1 = next(d)
366 | c2 = next(d)
367 | assert (
368 | len(c1) == expected_vals[i][0]
369 | ), f"Expected size {expected_vals[i][0]} in first chunk, got {len(c1)}"
370 | assert (
371 | len(c2) == expected_vals[i][1]
372 | ), f"Expected size {expected_vals[i][1]} in second chunk, got {len(c2)}"
373 | if c == 99:
374 | assert (
375 | c1[-1] == c2[0] - 1
376 | ), f"Expected chunk 2 to follow chunk1, got {c1[-1]} and {c2[0]}"
377 |
378 |
379 | def single_epoch_loader_worker_check(d, n_workers=0):
380 | # For dataset_1 partitioned over logical shards / workers / ranks,
381 | # check that every doc appears once per epoch
382 | loaders = [
383 | torch.utils.data.DataLoader(x, num_workers=n_workers, batch_size=1) for x in d
384 | ]
385 | loaders = [iter(l) for l in loaders]
386 | n_steps = 100 // len(loaders)
387 | ins = []
388 | for _ in range(n_steps):
389 | for l in loaders:
390 | out = next(l)
391 | ins.append(out[0].item())
392 |
393 | for i in range(100):
394 | assert (
395 | i * 100 in ins
396 | ), f"Line starting with {i * 100} failed to appear in generated data: worldsize {len(loaders)}, n_workers {n_workers}"
397 |
398 |
399 | # BASE DATASET TESTS
400 |
401 |
402 | def basic_loader(
403 | rank=0,
404 | worldsize=1,
405 | datasets=["dataset_1"],
406 | max_chunksize=1000,
407 | bos_token=None,
408 | ):
409 | assert len(datasets) == 1, "Basic loader takes only 1 dataset"
410 | return StreamingDocDataset(
411 | os.path.join(tmpdir.name, datasets[0]),
412 | rank,
413 | worldsize,
414 | ArrowHandler(),
415 | -1,
416 | max_chunksize=max_chunksize,
417 | bos_token=bos_token,
418 | )
419 |
420 |
421 | def basic_sampler(
422 | rank=0, worldsize=1, datasets=["dataset_1"], weights=[1], max_chunksize=1000
423 | ):
424 | return SamplingDataset(
425 | tmpdir.name,
426 | basic_loader(rank, worldsize, datasets[:1], max_chunksize, None),
427 | -1,
428 | datasets,
429 | weights,
430 | )
431 |
432 |
433 | def basic_scalable(
434 | rank=0,
435 | worldsize=1,
436 | datasets=["dataset_1"],
437 | max_chunksize=1000,
438 | n_logical_shards=7,
439 | bos_token=None,
440 | ):
441 | assert len(datasets) == 1, "Basic loader takes only 1 dataset"
442 | return ScalableShardDataset(
443 | basic_loader(rank, worldsize, datasets, max_chunksize, bos_token),
444 | -1,
445 | n_logical_shards,
446 | )
447 |
448 |
449 | def basic_sampler_scalable(
450 | rank=0,
451 | worldsize=1,
452 | datasets=["dataset_1"],
453 | weights=[1],
454 | max_chunksize=1000,
455 | n_logical_shards=7,
456 | ):
457 | return SamplingDataset(
458 | tmpdir.name,
459 | basic_scalable(
460 | rank, worldsize, datasets[:1], max_chunksize, n_logical_shards, None
461 | ),
462 | -1,
463 | datasets,
464 | weights,
465 | )
466 |
467 |
468 | def test_single_epoch():
469 | # Single shard, single loader: every line appears once in an epoch
470 | single_epoch_check(basic_loader, True)
471 | single_epoch_check(basic_scalable)
472 | single_epoch_check(basic_sampler)
473 | single_epoch_check(basic_sampler_scalable)
474 |
475 |
476 | def test_two_epoch():
477 | # Single shard, single loader: every line appears twice in two epochs
478 | two_epoch_check(basic_loader, True)
479 | two_epoch_check(basic_scalable)
480 | two_epoch_check(basic_sampler)
481 | two_epoch_check(basic_sampler_scalable)
482 |
483 |
484 | def test_chunk():
485 | # Single shard, single loader, two chunks/doc plus a delimiter token: every chunk appears once in an epoch
486 | chunk_check(functools.partial(basic_loader, max_chunksize=50), True)
487 | chunk_check(functools.partial(basic_scalable, max_chunksize=50))
488 | chunk_check(functools.partial(basic_sampler, max_chunksize=50))
489 | chunk_check(functools.partial(basic_sampler_scalable, max_chunksize=50))
490 |
491 |
492 | def test_two_loader():
493 | # Single shard, two loaders: every line appears once per epoch, collectively
494 | two_loader_check(basic_loader, True)
495 | two_loader_check(functools.partial(basic_scalable, n_logical_shards=8))
496 | two_loader_check(basic_sampler)
497 | two_loader_check(functools.partial(basic_sampler_scalable, n_logical_shards=8))
498 |
499 |
500 | def test_multi_file():
501 | # Multiple shard files, single loader: every line appears once in an epoch
502 | multi_file_check(basic_loader, True)
503 | multi_file_check(basic_scalable)
504 | multi_file_check(basic_sampler)
505 | multi_file_check(basic_sampler_scalable)
506 |
507 |
508 | def test_reload_epoch():
509 | # Single shard, two loaders: check that reloading mid-epoch does not cause data to repeat while finishing the epoch
510 | reload_epoch_check(basic_loader)
511 | reload_epoch_check(functools.partial(basic_scalable, n_logical_shards=8))
512 | reload_epoch_check(basic_sampler)
513 | reload_epoch_check(functools.partial(basic_sampler_scalable, n_logical_shards=8))
514 |
515 |
516 | def test_reload_complete_epoch():
517 | # Single shard, two loaders: check that reloading mid-epoch can still complete a full epoch
518 | reload_single_epoch_check(basic_loader)
519 | reload_single_epoch_check(functools.partial(basic_scalable, n_logical_shards=8))
520 | reload_single_epoch_check(basic_sampler)
521 | reload_single_epoch_check(
522 | functools.partial(basic_sampler_scalable, n_logical_shards=8)
523 | )
524 |
525 |
526 | def test_eos_bos_chunking():
527 | # Single shard, single loader: check that enabling/disabling bos tokens maintains correct chunking behavior
528 | single_doc_bos_eos_check(basic_loader, False)
529 | single_doc_bos_eos_check(basic_loader, True)
530 | single_doc_bos_eos_check(basic_scalable, False)
531 | single_doc_bos_eos_check(basic_scalable, True)
532 |
533 |
534 | # SUBDATASET WEIGHTING CHECKS
535 |
536 |
537 | def test_sampler_rates():
538 | """
539 | A test for SamplingDataset with Streaming_ and Scalable_ subdatasets.
540 | On the full dataset, with varying weights, on a single worker: verify that loaders pull subdatasets at regular intervals
541 | (verifying that they're regularly picking the most-underviewed subdataset at each step).
542 | """
543 | weights = [[1, 1], [2, 1], [2, 3], [2, 5]]
544 | target_rate = [3, 2, 4, 6]
545 | burnin = [3, 0, 4, 6]
546 |
547 | # Dataset1 docs are twice the length of dataset2. Burnin required to reach equilibrium.
548 | # Expected sequences for each case are:
549 | # 2 1 2 1 2 2 (1 2 2)...
550 | # 1 2 1 2 (1 2)...
551 | # 2 1 2 2 1 2 2 2 (1 2 2 2)...
552 | # 2 1 2 2 2 2 1 2 2 2 2 2 (1 2 2 2 2 2)...
553 |
554 | def check_rates(w, t, b, m):
555 | s = []
556 | d = m(datasets=["dataset_1", "dataset_2"], weights=w)
557 | l = iter(d)
558 | for i in range(b):
559 | s.append(len(next(l)))
560 | for i in range(100):
561 | out = next(l)
562 | s.append(len(out))
563 | if i % t == 0:
564 | assert (
565 | len(out) == 101
566 | ), f"Output {i} length {len(out)} does not match expected 101. Sequence so far: {s}"
567 | else:
568 | assert (
569 | len(out) == 51
570 | ), f"Output {i} length {len(out)} does not match expected 51. Sequence so far: {s}"
571 |
572 | for i in range(3):
573 | for m in [basic_sampler, basic_sampler_scalable]:
574 | check_rates(weights[i], target_rate[i], burnin[i], m)
575 |
576 |
577 | # STRESS TEST
578 |
579 |
580 | def test_multi_reload_stress():
581 | """
582 | For each nontrivial layer of the dataset pipeline:
583 | For each combo of steps:
584 | Initialize two identical datasets
585 | Take n steps with the first one
586 | Save checkpoint
587 | Load checkpoint into second dataset
588 | Take k steps with both datasets, check that outputs are identical
589 | Parameters are chosen to ensure messy states (non-divisible chunk sizes, shard numbers, n_workers, buffer_length, etc.)
590 | """
591 | # Shard doc dataset
592 | d1 = lambda: [
593 | StreamingDocDataset(
594 | os.path.join(tmpdir.name, "dataset_2"),
595 | i,
596 | 3,
597 | ArrowHandler(),
598 | -1,
599 | max_chunksize=17,
600 | )
601 | for i in range(3)
602 | ]
603 | multi_reload_stress_check(d1)
604 |
605 | # Scalable shard dataset
606 | d2 = lambda x: [ScalableShardDataset(d, -1, n_logical_shards=15) for d in x]
607 | multi_reload_stress_check(lambda: d2(d1()))
608 |
609 | # Sampling dataset
610 | d3 = lambda x: [
611 | SamplingDataset(
612 | tmpdir.name,
613 | d,
614 | -1,
615 | datasets=["dataset_1", "dataset_2"],
616 | weights=[3, 5],
617 | )
618 | for d in x
619 | ]
620 | multi_reload_stress_check(lambda: d3(d1()))
621 |
622 | # Nested scalable sampling dataset
623 | d4 = lambda: d3(d2(d1()))
624 | multi_reload_stress_check(d4)
625 |
626 | # Add buffer dataset
627 | d5 = lambda x: [BufferDataset(d, 73, pack_hard=True, bos_token=-1) for d in x]
628 | multi_reload_stress_check(lambda: d5(d4()))
629 |
630 | # Add preload buffer dataset
631 | d6 = lambda x: [PreloadBufferDataset(d, 99) for d in x]
632 | # preload / sample / scale / doc pipeline
633 | multi_reload_stress_check(lambda: d6(d5(d4())))
634 |
635 |
636 | # SCALABLEDATASET TESTS
637 |
638 |
639 | def test_scalable_partitioning():
640 | """
641 | Test that partitioning occurs correctly when rescaling up or down, including to non-multiples of the original
642 | physical worker count. Start with 4 workers with 12 logical shards, and for each of [1,2,3,6,12], verify that:
643 | 1) no overlap exists between workers and 2) in over one epoch's worth of steps, each data point appears at least once
644 | """
645 | l1 = lambda r, w: basic_scalable(r, w, max_chunksize=200, n_logical_shards=12)
646 | l2 = lambda r, w: basic_sampler_scalable(
647 | r, w, max_chunksize=200, n_logical_shards=12
648 | )
649 | for layer in [l1, l2]:
650 | datasets = [layer(i, 4) for i in range(4)] # 25 steps per epoch
651 | loaders = [iter(d) for d in datasets]
652 |
653 | for _ in range(50):
654 | [next(l) for l in loaders]
655 |
656 | states = [d.state_dict() for d in datasets]
657 |
658 | for worldsize in [1, 2, 3, 6, 12]:
659 | datasets = [layer(i, worldsize) for i in range(worldsize)]
660 | [d.load_state_dict(states) for d in datasets]
661 | loaders = [iter(d) for d in datasets]
662 | outs = [[] for _ in datasets]
663 | steps = int(100 / worldsize * 1.25)
664 | for i in range(steps):
665 | for j, l in enumerate(loaders):
666 | outs[j].append(next(l)[0])
667 |
668 | # Check for non-overlap
669 | for i in range(len(datasets)):
670 | for j in range(i + 1, len(datasets)):
671 | outi = set(outs[i])
672 | outj = set(outs[j])
673 | for t in outi:
674 | assert (
675 | t not in outj
676 | ), f"Overlapping value {t} detected in worker {i} and {j}: {outi}, {outj}"
677 | for t in outj:
678 | assert (
679 | t not in outi
680 | ), f"Overlapping value {t} detected in worker {i} and {j}: {outi}, {outj}"
681 |
682 | # Check for completion
683 | allout = set(chain(*outs))
684 | for i in range(100):
685 | assert i * 100 in allout, f"Token {i*100} missing from outputs {allout}"
686 |
687 |
688 | def test_scalable_shard_reload_scale():
689 | """
690 | As test_reload_epoch, but in this case we scale from 2 workers to 4 (complete 1/3 epoch, reload, finish without duplication).
691 | Because logical shards won't all be the exact same length when checkpointed, we complete the epoch of the shortest of the new workers.
692 | """
693 | datasets = [
694 | basic_scalable(i, 2, max_chunksize=40, n_logical_shards=8) for i in range(2)
695 | ] # Length 300
696 | loaders = [iter(d) for d in datasets]
697 |
698 | ins = []
699 | for _ in range(50):
700 | out = next(loaders[0])
701 | ins.append(out[0])
702 | for _ in range(50):
703 | out = next(loaders[1])
704 | ins.append(out[0])
705 |
706 | states = [d.state_dict() for d in datasets]
707 |
708 | datasets2 = [
709 | basic_scalable(i, 4, max_chunksize=40, n_logical_shards=8) for i in range(4)
710 | ] # Length 300
711 | [d.load_state_dict(states) for d in datasets2]
712 | ndocs = [sum(d.n_docs_remaining) for d in datasets]
713 | print("n_docs_remaining from old loader:", ndocs)
714 | ndocs = [sum(d.n_docs_remaining) for d in datasets2]
715 | print("n_docs_remaining per new loader:", ndocs)
716 |
717 | loaders2 = [iter(d) for d in datasets2]
718 |
719 | print("Checking only", min(ndocs) * 3, "steps instead of full 50")
720 | for j in range(min(ndocs) * 3):
721 | for i in range(4):
722 | out = next(loaders2[i])
723 | assert (
724 | out[0] not in ins
725 | ), f"Step {j+1}, dataset {i+1}: chunk starting with {out[0]} has already appeared in the epoch"
726 |
727 |
728 | def test_scalable_sampler_reload_scale():
729 | """
730 | As test_reload_epoch, but in this case we scale from 2 workers to 4 (complete 1/3 epoch, reload, finish without duplication).
731 | Because logical shards and sampling ratios won't be exact, take a few extra steps then check that epoch is complete.
732 | """
733 | datasets = [
734 | basic_sampler_scalable(i, 2, max_chunksize=40, n_logical_shards=8)
735 | for i in range(2)
736 | ] # Length 300
737 | loaders = [iter(d) for d in datasets]
738 |
739 | ins = []
740 | for _ in range(50):
741 | out = next(loaders[0])
742 | ins.append(out[0])
743 | for _ in range(50):
744 | out = next(loaders[1])
745 | ins.append(out[0])
746 |
747 | states = [d.state_dict() for d in datasets]
748 |
749 | datasets2 = [
750 | basic_sampler_scalable(i, 4, max_chunksize=40, n_logical_shards=8)
751 | for i in range(4)
752 | ] # Length 300
753 | [d.load_state_dict(states) for d in datasets2]
754 | loaders2 = [iter(d) for d in datasets2]
755 |
756 | for i in range(4):
757 | for _ in range(55):
758 | out = next(loaders2[i])
759 | ins.append(out[0])
760 |
761 | for suf in [0, 40, 80]:
762 | for i in range(100):
763 | assert (
764 | i * 100 + suf in ins
765 | ), f"Expected value {i*100+suf} not found in output set {ins}"
766 |
767 |
768 | # BUFFERDATASET TESTS
769 |
770 |
771 | class RandCounter:
772 |     # Spit out incremental counts of random length, uniformly sampled from 1 to 49
773 | def __init__(self):
774 | self.i = 0
775 | self.rank = 0
776 | self.worldsize = 1
777 | self.datapath = tmpdir.name
778 |
779 | def __iter__(self):
780 | while True:
781 | l = torch.randint(1, 50, [1]).item()
782 | yield list(range(self.i, self.i + l))
783 | self.i += l
784 |
785 |
786 | def test_buffer_format():
787 | # Using the RandCounter, verify that streams are reformed into correct-length buffers,
788 | # that final tokens match the predicted count, and that BOS/EOS add correctly
789 |
790 | for _ in range(100):
791 | # 100 trials of random length inputs
792 | base = RandCounter()
793 | dataset = BufferDataset(base, 100, pack_hard=True)
794 | loader = iter(dataset)
795 | for _ in range(100):
796 | out = next(loader)
797 | assert (
798 | len(out) == 100
799 | ), f"Length of output {len(out)} does not match specified 100"
800 | assert (
801 | out[-1] == 100 * 100 - 1
802 | ), f"Final token {out[-1]} does not match expected value {100*100-1}"
803 |
804 | # As above, but now with EOS tokens
805 | for _ in range(100):
806 | base = RandCounter()
807 | dataset = BufferDataset(base, 100, pack_hard=True, eos_token=-1)
808 | loader = iter(dataset)
809 | for i in range(100):
810 | out = next(loader)
811 | assert (
812 | len(out) == 100
813 | ), f"Length of output {len(out)} does not match specified 100"
814 | assert out[-1] == -1, f"Output {out} does not end in EOS"
815 | assert (
816 | out[-2] == 100 * 99 - 1
817 | ), f"Penultimate token {out[-2]} does not match expected value {100*99-1}"
818 |
819 | # As above, but now with BOS tokens
820 | for _ in range(100):
821 | base = RandCounter()
822 | dataset = BufferDataset(base, 100, pack_hard=True, bos_token=-1)
823 | loader = iter(dataset)
824 | for i in range(100):
825 | out = next(loader)
826 | assert (
827 | len(out) == 100
828 | ), f"Length of output {len(out)} does not match specified 100"
829 | assert out[0] == -1, f"Output {out} does not begin with BOS"
830 | assert (
831 | out[-1] == 100 * 99 - 1
832 | ), f"Final token {out[-1]} does not match expected value {100*99-1}"
833 |
834 |
835 | def test_buffer_delimiter_overlap():
836 | """
837 |     Check that BOS is added when absent, and skipped when already present.
838 |     Because the doc delimiter token is also -1, BOS is added to the first output, which shunts the delimiter token
839 |     into the first slot of the next (and all subsequent) outputs. BOS should then not be added again.
840 | """
841 | dataset = basic_loader(max_chunksize=101)
842 | dataset = BufferDataset(dataset, 101, pack_hard=True, bos_token=-1)
843 | loader = iter(dataset)
844 | for _ in range(100):
845 | out = next(loader)
846 | assert (
847 | len(out) == 101
848 | ), f"Length of output {len(out)} does not match specified 101"
849 | assert out[0] == -1, f"Output {out} does not begin with BOS"
850 | assert (
851 | out[-1] % 100 == 99
852 | ), f"Final token {out[-1]} does not end in expected value 99"
853 |
854 |
855 | # PRELOADBUFFERDATASET TESTS
856 |
857 |
858 | class SteadyCounter:
859 | # Spit out incremental counts of constant length l
860 | def __init__(self, l):
861 | self.i = 0
862 | self.rank = 0
863 | self.worldsize = 1
864 | self.datapath = tmpdir.name
865 | self.l = l
866 |
867 | def __iter__(self):
868 | while True:
869 | yield list(range(self.i, self.i + self.l))
870 | self.i += self.l
871 |
872 |
873 | def test_preload_buffer_uniformity():
874 | """
875 | With underlying SteadyCounter and window size 200, take 1000 steps.
876 | Ensure 95% of values between 0 and 100 are emitted.
877 | """
878 | dataset = PreloadBufferDataset(SteadyCounter(1), 200)
879 | loader = iter(dataset)
880 | outs = []
881 |
882 | for _ in range(1000):
883 | x = next(loader)[0]
884 | if x < 100:
885 | outs.append(x)
886 |
887 | assert len(outs) > 95, f"Only {len(outs)} values <100 detected"
888 |
889 |
890 | # CHECKPOINTDATASET TESTS
891 |
892 |
893 | def test_checkpoint_reload_match():
894 | """
895 | Check that the auto-checkpointer saves and loads correctly, and that a reloaded dataset
896 | resumes where the original left off (its continued outputs match those of the original)
897 | """
898 | datasets = [
899 | basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3, 5], max_chunksize=17)
900 | for i in range(3)
901 | ]
902 | datasets = [BufferDataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets]
903 | datasets = [
904 | CheckpointDataset(x, os.path.join(tmpdir.name, "ckp_test"), 100, 2)
905 | for x in datasets
906 | ]
907 | loaders = [
908 | torch.utils.data.DataLoader(
909 | x, num_workers=1, batch_size=2, prefetch_factor=1, persistent_workers=True
910 | )
911 | for x in datasets
912 | ]
913 | loaders = [iter(x) for x in loaders]
914 | for _ in range(100):
915 | for loader in loaders:
916 | next(loader)
917 |
918 | # Assert checkpoint exists and is properly formatted
919 | ckps = os.listdir(os.path.join(tmpdir.name, "ckp_test", "checkpoints"))
920 | assert len(ckps) == 1, f"Expected only one checkpoint (found {len(ckps)})"
921 | ckp_shards = os.listdir(
922 | os.path.join(tmpdir.name, "ckp_test", "checkpoints", ckps[0])
923 | )
924 | assert (
925 | len(ckp_shards) == 3
926 | ), f"Expected three checkpoint shards (found {len(ckp_shards)})"
927 |
928 | # Create a second loader, pointing to first's checkpoint
929 | datasets2 = [
930 | basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3, 5], max_chunksize=17)
931 | for i in range(3)
932 | ]
933 | datasets2 = [BufferDataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets2]
934 | datasets2 = [
935 | CheckpointDataset(x, os.path.join(tmpdir.name, "ckp_test"), 1000, 2)
936 | for x in datasets2
937 | ]
938 | [d.setup() for d in datasets2]
939 |
940 | # Assert checkpoints have loaded correctly
941 | for d in datasets2:
942 | assert d.step == 100, f"Expected to load back to step 100, got {d.step}"
943 |
944 | # Continue iterating, verify matching behavior
945 | loaders2 = [
946 | torch.utils.data.DataLoader(
947 | x, num_workers=1, batch_size=2, prefetch_factor=1, persistent_workers=True
948 | )
949 | for x in datasets2
950 | ]
951 | loaders2 = [iter(x) for x in loaders2]
952 | for _ in range(300):
953 | for loader, loader2 in zip(loaders, loaders2):
954 | out = sum(next(loader2))  # batch from the reloaded pipeline, summed over the batch dimension
955 | targ = sum(next(loader))  # corresponding batch from the original pipeline
956 | assert len(out) == len(
957 | targ
958 | ), f"Expected same output lengths, got {len(out)}, {len(targ)}"
959 | for i, (x, y) in enumerate(zip(out, targ)):
960 | assert x == y, f"Mismatch at position {i}: got {x}, expected {y}"
961 |
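# The save/resume pattern above in miniature (a hypothetical stateful iterator, not the
# CheckpointDataset API, which writes shards to disk): the iterator exposes its position as a
# state dict, and a fresh instance loaded from that state must continue exactly where the
# original left off.
class ResumableCounter:
    def __init__(self):
        self.i = 0

    def __iter__(self):
        while True:
            out = self.i
            self.i += 1
            yield out

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, sd):
        self.i = sd["i"]


a = ResumableCounter()
it_a = iter(a)
for _ in range(100):
    next(it_a)                # advance the "original" pipeline 100 steps
saved = a.state_dict()        # checkpoint its position

b = ResumableCounter()
b.load_state_dict(saved)      # build a fresh pipeline from the checkpoint
it_b = iter(b)
assert [next(it_a) for _ in range(10)] == [next(it_b) for _ in range(10)]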
962 |
963 | # MULTIPROCESS DATALOADER WORKER TESTS
964 |
965 |
966 | def test_multiprocess_epoch():
967 | """
968 | Check that ScalableShardDataset partitions correctly over various worldsize / n_worker
969 | combinations. A single epoch should contain each datapoint exactly once.
970 | """
971 | n_workers = [0, 2]
972 | worldsizes = [2, 5]
973 | for n in n_workers:
974 | for w in worldsizes:
975 | d = [basic_scalable(i, w, n_logical_shards=20) for i in range(w)]
976 | # Add a dummy wrapper (padding each output to length 110) to test correct wrapper behavior
977 | d = [BufferDataset(x, 110, False, pad_token=-1) for x in d]
978 | single_epoch_loader_worker_check(d, n)
979 |
--------------------------------------------------------------------------------
/tests/test_selective_ac.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import pytest
4 | from fms.models.llama import LLaMABlock
5 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
6 | CheckpointWrapper,
7 | )
8 |
9 | from fms_fsdp.policies import apply_fsdp_checkpointing
10 |
11 |
12 | @pytest.mark.parametrize("narrow_model_factory", [15], indirect=True)
13 | def test_selective_ac(narrow_model_factory):
14 | apply_ac = partial(apply_fsdp_checkpointing, block=LLaMABlock)  # p: fraction of blocks to checkpoint
15 |
16 | model = narrow_model_factory.create()
17 | apply_ac(model, p=0)
18 | expected = [False] * 15
19 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
20 |
21 | model = narrow_model_factory.create()
22 | apply_ac(model, p=1 / 100)
23 | expected = [False] * 15
24 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
25 |
26 | model = narrow_model_factory.create()
27 | apply_ac(model, p=1 / 5)
28 | expected = [False, False, True, False, False] * 3
29 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
30 |
31 | model = narrow_model_factory.create()
32 | apply_ac(model, p=1 / 3)
33 | expected = [False, True, False] * 5
34 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
35 |
36 | model = narrow_model_factory.create()
37 | apply_ac(model, p=1 / 2)
38 | expected = [True, False] * 7 + [True]
39 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
40 |
41 | model = narrow_model_factory.create()
42 | apply_ac(model, p=3 / 5)
43 | expected = [True, False, True, False, True] * 3
44 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
45 |
46 | model = narrow_model_factory.create()
47 | apply_ac(model, p=2 / 3)
48 | expected = [True, False, True] * 5
49 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
50 |
51 | model = narrow_model_factory.create()
52 | apply_ac(model, p=1)
53 | expected = [True] * 15
54 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
55 |
56 | model = narrow_model_factory.create()
57 | apply_ac(model, p=5 / 3)
58 | expected = [True] * 15
59 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
60 |
61 | model = narrow_model_factory.create()
62 | apply_ac(model, p=-1)
63 | expected = [False] * 15
64 | assert [isinstance(block, CheckpointWrapper) for block in model.layers] == expected
65 |
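# The expected masks above are consistent with a simple threshold rule; the sketch below is a
# hypothetical re-derivation, not the apply_fsdp_checkpointing implementation. Walking blocks
# in order, block k (1-indexed) is checkpointed whenever k * p crosses the next half-integer
# threshold, which spreads roughly p * n_blocks checkpointed blocks evenly across the stack
# and degenerates to "none" for p <= 0 and "all" for p >= 1.
def selective_ac_mask(n_blocks, p):
    mask, threshold = [], 0.5
    for k in range(1, n_blocks + 1):
        if k * p >= threshold:
            mask.append(True)
            threshold += 1
        else:
            mask.append(False)
    return mask


assert selective_ac_mask(15, 1 / 5) == [False, False, True, False, False] * 3
assert selective_ac_mask(15, 1 / 2) == [True, False] * 7 + [True]
assert selective_ac_mask(15, 2 / 3) == [True, False, True] * 5
assert selective_ac_mask(15, 1) == [True] * 15
assert selective_ac_mask(15, -1) == [False] * 15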
--------------------------------------------------------------------------------