├── .gitignore
├── LICENSE
├── README.md
├── docs
│   └── flash-stu.webp
├── flash_stu
│   ├── __init__.py
│   ├── config.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── attention_layer.py
│   │   └── stu_layer.py
│   ├── model.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── stu.py
│   │   └── swiglu.py
│   └── utils
│       ├── __init__.py
│       ├── numerics.py
│       ├── stu_utils.py
│       └── training.py
├── pyproject.toml
└── training
    ├── config.json
    ├── data.py
    ├── dataloader.py
    ├── distributed.py
    ├── example.py
    └── job.slurm
/.gitignore:
--------------------------------------------------------------------------------
1 | # ===============================
2 | # Byte-compiled / Optimized / DLL files
3 | # ===============================
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # ===============================
9 | # Artifacts
10 | # ===============================
11 | .empty/
12 | **/*.out
13 | log*/
14 | ignore/
15 |
16 | # ===============================
17 | # Datasets
18 | # ===============================
19 | data/
20 |
21 | # ===============================
22 | # Scripts and tests
23 | # ===============================
24 | **/*test*.*
25 |
26 | # ===============================
27 | # Plotting (To be added in the next commit)
28 | # ===============================
29 | plot.py
30 |
31 | # ===============================
32 | # C Extensions
33 | # ===============================
34 | *.so
35 |
36 | # ===============================
37 | # Distribution / Packaging
38 | # ===============================
39 | .Python
40 | build/
41 | develop-eggs/
42 | dist/
43 | downloads/
44 | eggs/
45 | .eggs/
46 | lib/
47 | lib64/
48 | parts/
49 | sdist/
50 | var/
51 | wheels/
52 | share/python-wheels/
53 | *.egg-info/
54 | .installed.cfg
55 | *.egg
56 | MANIFEST
57 |
58 | # ===============================
59 | # PyInstaller
60 | # ===============================
61 | # Usually these files are written by a Python script from a template before PyInstaller builds
62 | # the exe, so as to inject date/other info into it.
63 | *.manifest
64 | *.spec
65 |
66 | # ===============================
67 | # Installer Logs
68 | # ===============================
69 | pip-log.txt
70 | pip-delete-this-directory.txt
71 |
72 | # ===============================
73 | # Unit Test / Coverage Reports
74 | # ===============================
75 | htmlcov/
76 | .tox/
77 | .nox/
78 | .coverage
79 | .coverage.*
80 | .cache
81 | nosetests.xml
82 | coverage.xml
83 | *.cover
84 | *.py,cover
85 | .hypothesis/
86 | .pytest_cache/
87 | cover/
88 |
89 | # ===============================
90 | # Translations
91 | # ===============================
92 | *.mo
93 | *.pot
94 |
95 | # ===============================
96 | # Django
97 | # ===============================
98 | *.log
99 | local_settings.py
100 | db.sqlite3
101 | db.sqlite3-journal
102 |
103 | # ===============================
104 | # Flask
105 | # ===============================
106 | instance/
107 | .webassets-cache
108 |
109 | # ===============================
110 | # Scrapy
111 | # ===============================
112 | .scrapy
113 |
114 | # ===============================
115 | # Sphinx Documentation
116 | # ===============================
117 | docs/_build/
118 |
119 | # ===============================
120 | # PyBuilder
121 | # ===============================
122 | .pybuilder/
123 | target/
124 |
125 | # ===============================
126 | # Jupyter Notebook
127 | # ===============================
128 | .ipynb_checkpoints
129 |
130 | # ===============================
131 | # IPython
132 | # ===============================
133 | profile_default/
134 | ipython_config.py
135 |
136 | # ===============================
137 | # pyenv
138 | # ===============================
139 | # For a library or package, you might want to ignore these files since the code is
140 | # intended to run in multiple environments; otherwise, check them in.
141 | # .python-version
142 |
143 | # ===============================
144 | # pipenv
145 | # ===============================
146 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
147 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
148 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
149 | # install all needed dependencies.
150 | #Pipfile.lock
151 |
152 | # ===============================
153 | # poetry
154 | # ===============================
155 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
156 | # Especially recommended for binary packages to ensure reproducibility.
157 | #poetry.lock
158 |
159 | # ===============================
160 | # pdm
161 | # ===============================
162 | .pdm.toml
163 | .pdm-python
164 | .pdm-build/
165 |
166 | # ===============================
167 | # PEP 582 (Used by tools like PyFlow, pdm)
168 | # ===============================
169 | __pypackages__/
170 |
171 | # ===============================
172 | # Celery
173 | # ===============================
174 | celerybeat-schedule
175 | celerybeat.pid
176 |
177 | # ===============================
178 | # SageMath
179 | # ===============================
180 | *.sage.py
181 |
182 | # ===============================
183 | # Environments
184 | # ===============================
185 | .env
186 | .venv
187 | env/
188 | venv/
189 | ENV/
190 | env.bak/
191 | venv.bak/
192 |
193 | # Artifact of uv
194 | hello.py
195 |
196 | # ===============================
197 | # Spyder
198 | # ===============================
199 | .spyderproject
200 | .spyproject
201 |
202 | # ===============================
203 | # Rope
204 | # ===============================
205 | .ropeproject
206 |
207 | # ===============================
208 | # mkdocs
209 | # ===============================
210 | /site
211 |
212 | # ===============================
213 | # mypy
214 | # ===============================
215 | .mypy_cache/
216 | .dmypy.json
217 | dmypy.json
218 |
219 | # ===============================
220 | # Pyre
221 | # ===============================
222 | .pyre/
223 |
224 | # ===============================
225 | # pytype
226 | # ===============================
227 | .pytype/
228 |
229 | # ===============================
230 | # Cython
231 | # ===============================
232 | cython_debug/
233 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2024 Windsor Nguyen
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ⚡️ Flash STU ⚡️
2 |
3 |
4 | ![Flash STU](docs/flash-stu.webp)
5 |
6 |
7 | ## Table of Contents
8 | 1. [Introduction](#introduction)
9 | 2. [Features](#features)
10 | 3. [Installation](#installation)
11 | 4. [Usage](#usage)
12 | 5. [Configuration](#configuration)
13 | 6. [Contributing](#contributing)
14 | 7. [License](#license)
15 | 8. [Acknowledgments](#acknowledgments)
16 |
17 | ## Introduction
18 |
19 | This repository complements the [Flash STU: Fast Spectral Transform Units](https://arxiv.org/abs/2409.10489) paper and contains an optimized, open-source PyTorch implementation of the Spectral Transform Unit (STU) as proposed in [*Spectral State Space Models*](https://arxiv.org/abs/2312.06837) by Agarwal et al. (2024).
20 |
21 | The [STU](flash_stu/modules/stu.py) module is a fast and flexible building block that can be adapted into a wide range of neural network architectures, especially those that aim to solve tasks with long-range dependencies.
22 |
23 | ## Features
24 |
25 | - ⚡️ Fast convolutions using [Flash FFT](https://github.com/HazyResearch/flash-fft-conv)
26 | - 🚀 Fast, local attention using (sliding window) [Flash Attention](https://github.com/Dao-AILab/flash-attention)
27 | - 🌐 Support for distributed training using [DDP](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) and [FSDP](https://pytorch.org/docs/stable/fsdp.html)
28 |
29 | ## Installation
30 |
31 | > **Note**: CUDA is required to run code from this repository.
32 |
33 | This repository was tested with:
34 | - Python 3.12.5
35 | - PyTorch 2.4.1
36 | - Triton 3.0.0
37 | - CUDA 12.4
38 |
39 | and may be incompatible with other versions.
40 |
41 | 1. Install PyTorch with CUDA support:
42 | ```bash
43 | pip install torch --index-url https://download.pytorch.org/whl/cu124
44 | ```
45 |
46 | 2. Install required packages:
47 | ```bash
48 | pip install -e .
49 | ```
50 |
51 | 3. Install Flash Attention:
52 | ```bash
53 | MAX_JOBS=4 pip install flash-attn --no-build-isolation
54 | ```
55 |
56 | 4. Install Flash FFT:
57 | ```bash
58 | pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
59 | pip install git+https://github.com/HazyResearch/flash-fft-conv.git
60 | ```
61 |
62 | Alternatively, you can install Flash STU directly from the GitHub repository:
63 | ```bash
64 | pip install git+https://github.com/windsornguyen/flash-stu.git
65 | ```
66 |
67 | ## Usage
68 |
69 | ### Using Flash STU
70 |
71 | Here is an example of how to import and use Flash STU:
72 | ```python
73 | from flash_stu import FlashSTU, FlashSTUConfig, get_spectral_filters
74 | import torch
75 |
76 | device = torch.device('cuda') # Flash STU requires CUDA
77 |
78 | config = FlashSTUConfig(
79 | MODIFY_YOUR_ARGS_HERE,
80 | )
81 |
82 | phi = get_spectral_filters(
83 | config.seq_len,
84 | config.num_eigh,
85 | config.use_hankel_L,
86 | device,
87 | config.torch_dtype
88 | )
89 |
90 | model = FlashSTU(
91 | config,
92 | phi
93 | )
94 |
95 | y = model(x)  # x: token IDs of shape (bsz, seq_len), on the same device as the model
96 | ```
97 |
98 | ### Training
99 |
100 | An example LLM pretraining script is provided in [`example.py`](training/example.py) for you to test out the repository.
101 |
102 | If your compute cluster does not have internet access, you will need to pre-download the entire dataset before running the example training script.
103 |
104 | To download the dataset, run:
105 | ```bash
106 | cd training
107 | python data.py
108 | ```
109 |
110 | > **Note**: The FineWeb-Edu 10B-token sample is a relatively large dataset. It can be swapped out for something smaller, e.g. [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) (476.6M tokens).
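
If you do swap datasets, the main change in [`data.py`](training/data.py) is the `load_dataset` call (plus, optionally, the output directory and shard filename prefix). A minimal sketch, assuming the TinyStories dataset linked above (it also exposes a `text` field, so the tokenization code should work unchanged):

```python
from datasets import load_dataset

# Sketch: replace the FineWeb-Edu download in training/data.py with TinyStories.
fw = load_dataset("roneneldan/TinyStories", split="train")
```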
111 |
112 | To begin training, make sure you are in the `training` directory and run the following command in your terminal (for multi-GPU runs, add the usual `torchrun` flags, e.g. `--standalone --nproc_per_node=<num_gpus>`):
113 |
114 | ```bash
115 | torchrun example.py
116 | ```
117 |
118 | If you are on a compute cluster that uses Slurm and [environment modules](https://modules.readthedocs.io/en/latest/index.html), you can submit a job using the following command:
119 | ```bash
120 | sbatch job.slurm
121 | ```
122 |
123 | Model configurations can be adjusted as needed in [`config.json`](training/config.json). Be sure to adjust the configurations of the [Slurm job](training/job.slurm) based on your cluster's constraints.
124 |
125 | > **Note**: PyTorch's `torch.compile` currently does not have great support for distributed wrapper modules like DDP or FSDP. If you encounter errors during training, try disabling `torch.compile`. For more information on `torch.compile`, see this [informal manual](https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.ivdr7fmrbeab).
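
If you would rather keep compilation configurable than edit code, a minimal sketch (assuming `config` is the parsed [`config.json`](training/config.json) dictionary and `model` is the instantiated `FlashSTU`):

```python
# Only compile when the config enables it; set "torch_compile" to false in config.json to disable.
if config.get("torch_compile", False):
    model = torch.compile(model)
```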
126 |
127 |
128 | ## Contributing
129 |
130 | Contributions are welcome! Writing performant distributed code is always tricky. Feel free to:
131 |
132 | - Submit pull requests
133 | - Report issues
134 | - Help improve the project overall
135 |
136 | ## License
137 |
138 | Apache 2.0 License
139 |
140 | You can freely use, modify, and distribute the software, **even in proprietary products**, as long as you:
141 | - Include proper attribution
142 | - Include a copy of the license
143 | - Mention any changes made
144 |
145 | It also provides an express grant of patent rights from contributors.
146 |
147 | See the [LICENSE](LICENSE) file for more details.
148 |
149 | ## Acknowledgments
150 |
151 | Special thanks to (in no particular order):
152 | - Elad Hazan and the authors of the [Spectral State Space Models](https://arxiv.org/abs/2312.06837) paper
153 | - Isabel Liu, Yagiz Devre, Evan Dogariu
154 | - The Flash Attention team
155 | - The Flash FFT team
156 | - The PyTorch team
157 | - Princeton Research Computing and Princeton Language and Intelligence, for supplying compute
158 | - Andrej Karpathy, for his awesome [NanoGPT](https://github.com/karpathy/build-nanogpt) repository
159 |
160 | ## Citation
161 |
162 | If you use this repository, or otherwise find our work valuable, please cite Flash STU:
163 | ```bibtex
164 | @article{flashstu,
165 |   title={Flash STU: Fast Spectral Transform Units},
166 |   author={Y. Isabel Liu and Windsor Nguyen and Yagiz Devre and Evan Dogariu and Anirudha Majumdar and Elad Hazan},
167 |   journal={arXiv preprint arXiv:2409.10489},
168 |   year={2024},
169 |   url={https://arxiv.org/abs/2409.10489}
170 | }
171 | ```
--------------------------------------------------------------------------------
/docs/flash-stu.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hazan-lab/flash-stu/54f78d4b91cc068ccc23554172beab1277aa8ab1/docs/flash-stu.webp
--------------------------------------------------------------------------------
/flash_stu/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import FlashSTU, FlashSTUConfig
2 | from .modules.stu import STU
3 | from .utils.stu_utils import get_spectral_filters
4 |
--------------------------------------------------------------------------------
/flash_stu/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformers import PretrainedConfig
4 |
5 |
6 | class FlashSTUConfig(PretrainedConfig):
7 | model_type = "FlashSTU"
8 |
9 | def __init__(
10 | self,
11 | bsz: int = 1,
12 | n_embd: int = 1536,
13 | n_heads: int = 8,
14 | n_layers: int = 26,
15 | seq_len: int = 8192,
16 | window_size: int = 1024,
17 | vocab_size: int = 200064,
18 | mlp_scale: int = 12,
19 | bias: bool = False,
20 | dropout: float = 0.0,
21 | num_eigh: int = 24,
22 | use_hankel_L: bool = False,
23 | use_flash_fft: bool = True,
24 | use_approx: bool = True,
25 | use_attn: bool = True,
26 | softcap: float = 50.0,
27 | torch_dtype: torch.dtype = torch.bfloat16,
28 | **kwargs,
29 | ):
30 | super().__init__(**kwargs)
31 | self.bsz = bsz
32 | self.n_embd = n_embd
33 | self.n_heads = n_heads
34 | self.n_layers = n_layers
35 | self.seq_len = seq_len
36 | self.window_size = window_size
37 |         self.vocab_size = vocab_size
38 |         self.mlp_scale = mlp_scale  # stored so the fallback SwiGLU MLP can read config.mlp_scale
39 |         self.hidden_size = n_embd
40 |         self.intermediate_size = n_embd * mlp_scale
41 |         self.hidden_act = "swish"
42 |         self.bias = bias
43 |         self.dropout = dropout
44 |         self.num_eigh = num_eigh
45 |         self.use_hankel_L = use_hankel_L
46 |         self.use_flash_fft = use_flash_fft
47 |         self.use_approx = use_approx
48 |         self.use_attn = use_attn
49 |         self.softcap = softcap
50 |         self.torch_dtype = torch_dtype
51 |
--------------------------------------------------------------------------------
/flash_stu/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hazan-lab/flash-stu/54f78d4b91cc068ccc23554172beab1277aa8ab1/flash_stu/layers/__init__.py
--------------------------------------------------------------------------------
/flash_stu/layers/attention_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from flash_stu.modules.attention import Attention
5 | from flash_stu.modules.swiglu import MLP
6 |
7 | try:
8 | from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
9 |
10 | triton_mlp = True
11 | except ImportError as e:
12 | print(
13 | f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
14 | )
15 | triton_mlp = False
16 |
17 | try:
18 | from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
19 |
20 | triton_norm = True
21 | except ImportError as e:
22 | print(
23 | f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
24 | )
25 | from torch.nn import RMSNorm
26 |
27 | triton_norm = False
28 |
29 |
30 | class AttentionLayer(nn.Module):
31 | def __init__(self, config) -> None:
32 | super(AttentionLayer, self).__init__()
33 | self.attn_norm = (
34 | TritonNorm(config.n_embd)
35 | if triton_norm
36 | else RMSNorm(config.n_embd, dtype=config.torch_dtype)
37 | )
38 | self.attn = Attention(config)
39 | self.mlp_norm = (
40 | TritonNorm(config.n_embd)
41 | if triton_norm
42 | else RMSNorm(config.n_embd, dtype=config.torch_dtype)
43 | )
44 | self.mlp = (
45 | TritonMLP(config) if triton_mlp else MLP(config, dtype=config.torch_dtype)
46 | )
47 |
48 | # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
49 | self.attn_norm = self.attn_norm.to(dtype=config.torch_dtype)
50 | self.mlp = self.mlp.to(dtype=config.torch_dtype)
51 | self.mlp_norm = self.mlp_norm.to(dtype=config.torch_dtype)
52 |
53 | def forward(self, x: torch.Tensor) -> torch.Tensor:
54 | x = x + self.attn(self.attn_norm(x))
55 | x = x + self.mlp(self.mlp_norm(x))
56 | return x
57 |
--------------------------------------------------------------------------------
/flash_stu/layers/stu_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from flash_stu.modules.stu import STU
5 | from flash_stu.modules.swiglu import MLP
6 |
7 | try:
8 | from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
9 | triton_mlp = True
10 | except ImportError as e:
11 | print(
12 | f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
13 | )
14 | triton_mlp = False
15 |
16 | try:
17 | from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
18 | triton_norm = True
19 | except ImportError as e:
20 | print(
21 | f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
22 | )
23 | from torch.nn import RMSNorm
24 | triton_norm = False
25 |
26 |
27 | class STULayer(nn.Module):
28 | def __init__(self, config, phi, n):
29 | super(STULayer, self).__init__()
30 | self.stu_norm = (
31 | TritonNorm(config.n_embd)
32 | if triton_norm
33 | else RMSNorm(config.n_embd, dtype=config.torch_dtype)
34 | )
35 | self.stu = STU(config, phi, n)
36 | self.mlp_norm = (
37 | TritonNorm(config.n_embd)
38 | if triton_norm
39 | else RMSNorm(config.n_embd, dtype=config.torch_dtype)
40 | )
41 | self.mlp = (
42 | TritonMLP(config) if triton_mlp else MLP(config, dtype=config.torch_dtype)
43 | )
44 |
45 | # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
46 | self.stu_norm = self.stu_norm.to(dtype=config.torch_dtype)
47 | self.mlp = self.mlp.to(dtype=config.torch_dtype)
48 | self.mlp_norm = self.mlp_norm.to(dtype=config.torch_dtype)
49 |
50 | def forward(self, x: torch.Tensor) -> torch.Tensor:
51 | x = x + self.stu(self.stu_norm(x))
52 | x = x + self.mlp(self.mlp_norm(x))
53 | return x
54 |
--------------------------------------------------------------------------------
/flash_stu/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import PreTrainedModel
5 |
6 | from flash_stu.modules.stu import STU
7 | from flash_stu.modules.attention import Attention
8 | from flash_stu.utils.numerics import nearest_power_of_two
9 | from flash_stu.config import FlashSTUConfig
10 | from flash_stu.layers.stu_layer import STULayer
11 | from flash_stu.layers.attention_layer import AttentionLayer
12 |
13 | try:
14 | from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
15 | triton_norm = True
16 | except ImportError as e:
17 | print(
18 | f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
19 | )
20 | from torch.nn import RMSNorm
21 |
22 | triton_norm = False
23 |
24 |
25 | class FlashSTU(PreTrainedModel):
26 | config_class = FlashSTUConfig
27 |
28 | def __init__(self, config, phi) -> None:
29 | super(FlashSTU, self).__init__(config)
30 | self.n_layers = config.n_layers
31 | self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)
32 | self.phi = phi
33 | self.use_approx = config.use_approx
34 | self.use_hankel_L = config.use_hankel_L
35 |
36 | # TODO: Add support for Liger-Kernel Embedding once no longer experimental
37 | self.tok_emb = nn.Embedding(
38 | config.vocab_size, config.n_embd, dtype=config.torch_dtype
39 | )
40 | self.dropout = nn.Dropout(config.dropout)
41 |
42 | self.layers = nn.ModuleList()
43 | for layer_idx in range(self.n_layers):
44 | # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
45 | if layer_idx % 2 == 0:
46 | self.layers.append(STULayer(config, self.phi, self.n))
47 | else:
48 | self.layers.append(
49 | AttentionLayer(config)
50 | if config.use_attn
51 | else STULayer(config, self.phi, self.n)
52 | )
53 |
54 | self.norm = (
55 | TritonNorm(config.n_embd)
56 | if triton_norm
57 | else RMSNorm(config.n_embd, dtype=config.torch_dtype)
58 | )
59 | # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for RMS Norm
60 | self.norm = self.norm.to(dtype=config.torch_dtype)
61 | self.lm_head = nn.Linear(
62 | config.n_embd, config.vocab_size, bias=config.bias, dtype=config.torch_dtype
63 | )
64 | self.tok_emb.weight = self.lm_head.weight
65 |
66 | self.std = (config.n_embd) ** -0.5
67 | self.apply(self._init_weights)
68 | print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
69 |
70 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
71 | tok_emb = self.tok_emb(x)
72 | x = self.dropout(tok_emb)
73 |
74 | for layer in self.layers:
75 | x = layer(x)
76 |
77 | x = self.norm(x)
78 | y_hat = self.lm_head(x)
79 |
80 | return y_hat
81 |
82 | def _get_num_params(self):
83 | n_params = sum(p.numel() for p in self.parameters())
84 | if hasattr(self, "pos_emb") and self.pos_emb is not None:
85 | n_params -= self.pos_emb.weight.numel()
86 | if self.tok_emb.weight is not self.lm_head.weight:
87 | n_params -= self.tok_emb.weight.numel()
88 | return n_params
89 |
90 | def _init_weights(self, module):
91 |         if isinstance(module, nn.Linear):
92 |             # Use a local std so repeated SCALE_INIT hits don't permanently shrink self.std
93 |             std = self.std
94 |             if hasattr(module, "SCALE_INIT"):
95 |                 std *= (2 * self.n_layers) ** -0.5
96 |             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
97 |             if module.bias is not None:
98 |                 torch.nn.init.zeros_(module.bias)
99 |         elif isinstance(module, nn.Embedding):
100 |             torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
101 |         elif isinstance(module, STU):
102 |             if self.use_approx:
103 |                 torch.nn.init.xavier_normal_(module.M_inputs)
104 |                 torch.nn.init.xavier_normal_(module.M_filters)
105 |             else:
106 |                 torch.nn.init.xavier_normal_(module.M_phi_plus)
107 |                 if not self.use_hankel_L:
108 |                     torch.nn.init.xavier_normal_(module.M_phi_minus)
109 |         elif isinstance(module, Attention):
110 |             torch.nn.init.xavier_normal_(module.c_attn.weight)
111 |             torch.nn.init.xavier_normal_(module.c_proj.weight)
112 |             if module.c_attn.bias is not None:
113 |                 torch.nn.init.zeros_(module.c_attn.bias)
114 |             if module.c_proj.bias is not None:
115 |                 torch.nn.init.zeros_(module.c_proj.bias)
116 |
--------------------------------------------------------------------------------
/flash_stu/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hazan-lab/flash-stu/54f78d4b91cc068ccc23554172beab1277aa8ab1/flash_stu/modules/__init__.py
--------------------------------------------------------------------------------
/flash_stu/modules/attention.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from flash_stu.utils.numerics import nearest_power_of_two
7 |
8 | try:
9 | from flash_attn import flash_attn_func as fa2
10 | except ImportError as e:
11 | print(
12 | f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
13 | )
14 | # TODO: Add FlexAttention + local attention mask when it's in stable release
15 |
16 | class Attention(nn.Module):
17 | def __init__(self, config):
18 | super(Attention, self).__init__()
19 | assert torch.cuda.is_available(), "CUDA is required."
20 | assert config.n_embd % config.n_heads == 0
21 | self.n_heads = config.n_heads
22 |
23 | self.device = torch.device("cuda")
24 | self.bsz = config.bsz
25 | self.c_attn = nn.Linear(
26 | config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=config.torch_dtype
27 | )
28 | self.c_proj = nn.Linear(
29 | config.n_embd, config.n_embd, bias=config.bias, dtype=config.torch_dtype
30 | )
31 | self.c_proj.SCALE_INIT = 1
32 | self.dropout = config.dropout
33 | self.resid_dropout = nn.Dropout(self.dropout)
34 | self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
35 | self.window_size = config.window_size
36 | self.softcap = config.softcap
37 |
38 | def _generate_slopes(self, n: int):
39 | start = 2 ** (-(2 ** -(math.log2(n) - 3)))
40 | return [start * (start**i) for i in range(n)]
41 |
42 | def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
43 | # If n_heads is a power of 2, generate slopes directly
44 | if math.log2(n_heads).is_integer():
45 | slopes = self._generate_slopes(n_heads)
46 | else:
47 | # Get slopes for the nearest power of two
48 | n = nearest_power_of_two(n_heads, round_up=False)
49 | slopes_power_of_two = self._generate_slopes(n)
50 |
51 | # Generate extra slopes
52 | extra_slopes = self._generate_slopes(2 * n)
53 | extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
54 | slopes = slopes_power_of_two + extra_slopes_trunc
55 | slopes = torch.tensor(slopes, device=self.device)
56 | slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
57 | return slopes
58 |
59 | def forward(self, x):
60 | bsz, seq_len, d_in = x.size()
61 |
62 | qkv = self.c_attn(x)
63 | q, k, v = torch.chunk(qkv, 3, dim=2)
64 |
65 | q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
66 | k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
67 | v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
68 | y = fa2( # https://arxiv.org/pdf/2307.08691
69 | q,
70 | k,
71 | v,
72 | dropout_p=self.dropout if self.training else 0.0,
73 | causal=True,
74 | window_size=(self.window_size, 0),
75 | alibi_slopes=self.alibi_slopes, # https://arxiv.org/pdf/2108.12409
76 | softcap=self.softcap, # https://arxiv.org/pdf/2408.00118
77 | )
78 | y = y.contiguous().view(bsz, seq_len, d_in)
79 | y = self.resid_dropout(self.c_proj(y))
80 | return y
81 |
--------------------------------------------------------------------------------
/flash_stu/modules/stu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from flash_stu.utils.stu_utils import convolve, flash_convolve
5 |
6 | try:
7 | from flashfftconv import FlashFFTConv
8 |
9 | flash_fft_available = True
10 | except ImportError as e:
11 | print(
12 | f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation."
13 | )
14 | flash_fft_available = False
15 |
16 |
17 | class STU(nn.Module):
18 | def __init__(self, config, phi, n) -> None:
19 | super(STU, self).__init__()
20 | self.config = config
21 | self.phi = phi
22 | self.n = n
23 | self.K = config.num_eigh
24 | self.d_in = config.n_embd
25 | self.d_out = config.n_embd
26 | self.use_hankel_L = config.use_hankel_L
27 | self.use_approx = config.use_approx
28 | self.flash_fft = (
29 | FlashFFTConv(self.n, dtype=torch.bfloat16)
30 | if config.use_flash_fft and flash_fft_available
31 | else None
32 | )
33 | if self.use_approx:
34 | self.M_inputs = nn.Parameter(
35 | torch.empty(self.d_in, self.d_out, dtype=config.torch_dtype)
36 | )
37 | self.M_filters = nn.Parameter(
38 | torch.empty(self.K, self.d_in, dtype=config.torch_dtype)
39 | )
40 | else:
41 | self.M_phi_plus = nn.Parameter(
42 | torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype)
43 | )
44 | if not self.use_hankel_L:
45 | self.M_phi_minus = nn.Parameter(
46 | torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype)
47 | )
48 |
49 | def forward(self, x: torch.Tensor) -> torch.Tensor:
50 | if self.use_approx:
51 | # Contract inputs and filters over the K and d_in dimensions, then convolve
52 | x_proj = x @ self.M_inputs
53 | phi_proj = self.phi @ self.M_filters
54 | if self.flash_fft:
55 | spectral_plus, spectral_minus = flash_convolve(
56 | x_proj, phi_proj, self.flash_fft, self.use_approx
57 | )
58 | else:
59 | spectral_plus, spectral_minus = convolve(
60 | x_proj, phi_proj, self.n, self.use_approx
61 | )
62 | else:
63 | # Convolve inputs and filters,
64 | if self.flash_fft:
65 | U_plus, U_minus = flash_convolve(
66 | x, self.phi, self.flash_fft, self.use_approx
67 | )
68 | else:
69 | U_plus, U_minus = convolve(x, self.phi, self.n, self.use_approx)
70 | # Then, contract over the K and d_in dimensions
71 | spectral_plus = torch.tensordot(
72 | U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
73 | )
74 | if not self.use_hankel_L:
75 | spectral_minus = torch.tensordot(
76 | U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
77 | )
78 |
79 | return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
80 |
--------------------------------------------------------------------------------
/flash_stu/modules/swiglu.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn import functional as F
3 |
4 | class MLP(nn.Module):
5 | def __init__(self, config, dtype=None):
6 | # https://arxiv.org/pdf/2002.05202
7 | super().__init__()
8 | dtype = dtype if dtype is not None else config.torch_dtype
9 | self.hidden_size = config.n_embd
10 | self.intermediate_size = config.n_embd * config.mlp_scale
11 | self.gate_proj = nn.Linear(
12 | self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype
13 | )
14 | self.up_proj = nn.Linear(
15 | self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype
16 | )
17 | self.down_proj = nn.Linear(
18 | self.intermediate_size, self.hidden_size, bias=config.bias, dtype=dtype
19 | )
20 | self.dropout = nn.Dropout(
21 | config.dropout
22 | ) # TODO: Write Issue in Liger-Kernel repo to support Dropout
23 |
24 | def forward(self, x):
25 | gate = self.gate_proj(x)
26 |         gate = F.gelu(gate, approximate="tanh")  # gate activation: tanh-approximated GELU (the Liger SwiGLU path gates with SiLU)
27 | up = self.up_proj(x)
28 | fuse = gate * up
29 | outputs = self.down_proj(fuse)
30 | outputs = self.dropout(outputs)
31 | return outputs
32 |
--------------------------------------------------------------------------------
/flash_stu/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hazan-lab/flash-stu/54f78d4b91cc068ccc23554172beab1277aa8ab1/flash_stu/utils/__init__.py
--------------------------------------------------------------------------------
/flash_stu/utils/numerics.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | def nearest_power_of_two(x: int, round_up: bool = False) -> int:
4 | return (
5 | 1 << math.floor(math.log2(x)) if not round_up else 1 << math.ceil(math.log2(x))
6 | )
7 |
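8 | # Example: nearest_power_of_two(5, round_up=False) == 4, nearest_power_of_two(5, round_up=True) == 8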
--------------------------------------------------------------------------------
/flash_stu/utils/stu_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from flashfftconv import FlashFFTConv
6 |
7 | from flash_stu.utils.numerics import nearest_power_of_two
8 |
9 |
10 | def get_hankel(seq_len: int, use_hankel_L: bool = False) -> np.ndarray:
11 | entries = np.arange(1, seq_len + 1, dtype=np.float64)
12 | i_plus_j = entries[:, None] + entries[None, :]
13 |
14 | if use_hankel_L:
15 | sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
16 | denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
17 | Z = sgn * (8.0 / denom)
18 | elif not use_hankel_L:
19 | Z = 2.0 / (i_plus_j**3 - i_plus_j)
20 | else:
21 | raise ValueError("use_hankel_L must be a boolean")
22 |
23 | return Z
24 |
25 | def get_spectral_filters(
26 | seq_len: int,
27 | K: int,
28 | use_hankel_L: bool = False,
29 | device: torch.device = None,
30 | dtype: torch.dtype = torch.bfloat16,
31 | ) -> torch.Tensor:
32 | assert torch.cuda.is_available(), "CUDA is required."
33 | Z = get_hankel(seq_len, use_hankel_L)
34 | sigma, phi = np.linalg.eigh(Z)
35 | sigma, phi = sigma[-K:], phi[:, -K:]
36 | phi *= sigma ** 0.25
37 | return torch.tensor(phi, device=device, dtype=dtype)
38 |
39 | def convolve(u: torch.Tensor, v: torch.Tensor, n: int, use_approx: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
40 | bsz, seq_len, d_in = u.shape
41 |
42 |     sgn = torch.full((1, seq_len, 1), 1, device=u.device)  # alternating +1/-1 pattern along the sequence, used to form the U_minus branch
43 | sgn[:, 1::2] *= -1
44 | if use_approx:
45 | _, d_out = v.shape
46 | v = v.view(1, -1, d_out, 1).to(torch.float32)
47 | else:
48 | _, K = v.shape
49 | sgn = sgn.unsqueeze(-1)
50 | v = v.view(1, -1, K, 1, 1).to(torch.float32) # (bsz, seq_len, K, d_in, stack)
51 | u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)
52 |
53 | v = torch.fft.rfft(v, n=n, dim=1)
54 | U = torch.stack([u, u * sgn], dim=-1).to(torch.float32)
55 | U = torch.fft.rfft(U, n=n, dim=1)
56 | U_conv = torch.fft.irfft(v * U, n=n, dim=1)[:, :seq_len]
57 | U_plus, U_minus = torch.unbind(U_conv, dim=-1)
58 | U_minus = U_minus * sgn
59 |
60 | return U_plus, U_minus
61 |
62 | def flash_convolve(
63 | u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
64 | ) -> tuple[torch.Tensor, torch.Tensor]:
65 | bsz, seq_len, d_in = u.shape
66 | _, K = v.shape
67 |
68 | padded_len = nearest_power_of_two(seq_len, round_up=True)
69 | pad_len = padded_len - seq_len
70 |
71 | sgn = torch.full((1, 1, padded_len), 1, device=u.device)
72 | sgn[:, :, 1::2] = -1
73 |
74 | if use_approx:
75 | u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16).contiguous()
76 | v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).contiguous()
77 | u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
78 | else:
79 | u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16).repeat_interleave(K, dim=1).contiguous()
80 | v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).repeat(d_in, 1).contiguous()
81 | u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)
82 |
83 | U_conv = flash_fft(u_conv, v_padded)
84 |
85 | # Trim the output back to the original sequence length
86 | U_conv = U_conv[..., :seq_len]
87 |
88 | u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)
89 |
90 | if use_approx:
91 | u_minus = u_minus * sgn[:, :, :seq_len]
92 | U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
93 | else:
94 | sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
95 | U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
96 | U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn
97 |
98 | return U_plus, U_minus
99 |
--------------------------------------------------------------------------------
/flash_stu/utils/training.py:
--------------------------------------------------------------------------------
1 | import gc
2 |
3 | class GarbageCollection:
4 | def __init__(self, gc_freq=1000):
5 | assert gc_freq > 0, "gc_freq must be a positive integer"
6 | self.gc_freq = gc_freq
7 | gc.disable()
8 | gc.collect(1)
9 |
10 | def run(self, step_count):
11 | if step_count > 1 and step_count % self.gc_freq == 0:
12 | gc.collect(1)
13 |
14 | def linear_decay_with_warmup( # https://arxiv.org/pdf/2310.07831
15 | current_step: int,
16 | warmup_steps: int,
17 | num_steps: int,
18 | max_lr: float = 3e-4,
19 | min_lr: float = 3e-5,
20 | ) -> float:
21 | if current_step < warmup_steps:
22 | return min_lr + (max_lr - min_lr) * float(current_step) / float(max(warmup_steps, 1))
23 | else:
24 | return max_lr - (max_lr - min_lr) * float(current_step - warmup_steps) / float(max(num_steps - warmup_steps, 1))
25 |
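26 | # Example usage (sketch, assuming a standard PyTorch optimizer training loop):
27 | #     lr = linear_decay_with_warmup(step, warmup_steps, num_steps, max_lr, min_lr)
28 | #     for param_group in optimizer.param_groups:
29 | #         param_group["lr"] = lr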
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "flash-stu"
3 | version = "1.0.0"
4 | description = "An optimized, open-source PyTorch implementation of the Spectral Transform Unit (STU)"
5 | readme = "README.md"
6 | requires-python = ">=3.11"
7 | license = {file = "LICENSE"}
8 |
9 | authors = [
10 | { name = "Isabel Liu", email = "isabel.liu@princeton.edu" },
11 | { name = "Windsor Nguyen", email = "windsor@princetonalignment.org" },
12 | { name = "Yagiz Devre", email = "yagiz.devre@princeton.edu" },
13 | { name = "Evan Dogariu", email = "ed2719@nyu.edu" },
14 | { name = "Anirudha Majumdar", email = "ani.majumdar@princeton.edu" },
15 | { name = "Elad Hazan", email = "ehazan@princeton.edu" }
16 | ]
17 |
18 | maintainers = [
19 | { name = "Windsor Nguyen", email = "windsor@princetonalignment.org" }
20 | ]
21 |
22 | keywords = [
23 | "pytorch",
24 | "state space model",
25 | "spectral filtering",
26 | "deep learning",
27 | "machine learning",
28 | "time series",
29 | "dynamical systems",
30 | "language modeling",
31 | ]
32 |
33 | classifiers = [
34 | "Programming Language :: Python :: 3",
35 |     "License :: OSI Approved :: Apache Software License",
36 | "Operating System :: Unix"
37 | ]
38 |
39 | dependencies = [
40 | "packaging==24.1",
41 | "wheel==0.44.0",
42 | "numpy==1.26.4",
43 | "ninja==1.11.1.1",
44 | "torch==2.4.1",
45 | "triton==3.0.0",
46 | "einops==0.8.0",
47 | "liger-kernel==0.2.1",
48 | "tiktoken==0.7.0",
49 | "transformers==4.44.0",
50 | "datasets==2.20.0",
51 | "safetensors==0.4.4",
52 | "psutil==6.0.0",
53 | "matplotlib==3.9.2",
54 | ]
55 |
56 | [tool.setuptools.packages.find]
57 | include = ["flash_stu*"]
58 |
59 | # ===============================
60 | # Project URLs
61 | # ===============================
62 | [project.urls]
63 | Homepage = "https://arxiv.org/abs/2409.10489"
64 | Repository = "https://github.com/windsornguyen/flash-stu"
65 |
66 | # ===============================
67 | # Ruff Configuration
68 | # ===============================
69 | [tool.ruff]
70 | # Exclude commonly ignored directories from linting
71 | exclude = [
72 | ".bzr",
73 | ".direnv",
74 | ".eggs",
75 | ".git",
76 | ".git-rewrite",
77 | ".hg",
78 | ".mypy_cache",
79 | ".nox",
80 | ".pants.d",
81 | ".pytype",
82 | ".ruff_cache",
83 | ".svn",
84 | ".tox",
85 | ".venv",
86 | "__pypackages__",
87 | "_build",
88 | "buck-out",
89 | "build",
90 | "dist",
91 | "node_modules",
92 | "venv",
93 | ]
94 |
95 | # General Ruff settings
96 | line-length = 119
97 | indent-width = 4
98 | target-version = "py311"
99 |
100 | # -------------------------------
101 | # Ruff Linting Rules
102 | # -------------------------------
103 | [tool.ruff.lint]
104 | # Enable specific linting rules
105 | select = ["E4", "E7", "E9", "F", "B", "Q"]
106 | ignore = []
107 |
108 | # Define fixable and unfixable rules
109 | fixable = ["ALL"]
110 | unfixable = []
111 |
112 | # -------------------------------
113 | # Ruff Flake8-Quotes Settings
114 | # -------------------------------
115 | [tool.ruff.lint.flake8-quotes]
116 | docstring-quotes = "double" # Enforce double quotes for docstrings
117 | inline-quotes = "double" # Enforce double quotes for inline strings
118 |
119 | # -------------------------------
120 | # Ruff Formatting Settings
121 | # -------------------------------
122 | [tool.ruff.format]
123 | quote-style = "double" # Use double quotes for auto-formatting
124 | indent-style = "space" # Indent with spaces instead of tabs
125 | skip-magic-trailing-comma = false # Respect magic trailing commas
126 | line-ending = "auto" # Automatically detect line endings
127 |
128 | [build-system]
129 | requires = ["setuptools", "wheel"]
130 | build-backend = "setuptools.build_meta"
131 |
--------------------------------------------------------------------------------
/training/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "FlashSTU",
3 | "n_embd": 1536,
4 | "n_heads": 8,
5 | "n_layers": 26,
6 | "seq_len": 8192,
7 | "window_size": 1024,
8 | "vocab_size": 200064,
9 | "mlp_scale": 12,
10 | "bias": false,
11 | "dropout": 0.0,
12 | "num_eigh": 24,
13 | "use_hankel_L": false,
14 | "num_epochs": 1,
15 | "global_bsz": 524288,
16 | "bsz": 1,
17 | "warmup_steps": 19073,
18 | "eval_period": 25,
19 | "save_period": 19000,
20 | "max_lr": 3.0e-4,
21 | "min_lr": 3.0e-5,
22 | "max_norm": 1.0,
23 | "dilation": 1,
24 | "fsdp": true,
25 | "ddp": false,
26 | "mixed_precision": true,
27 | "torch_dtype": "bfloat16",
28 | "use_cpu_offload": false,
29 | "sharding_strategy": "full_shard",
30 | "state_dict_type": "full",
31 | "auto_wrap_policy": "partial",
32 | "backward_prefetch": "backward_pre",
33 | "forward_prefetch": false,
34 | "sync_module_states": true,
35 | "use_orig_params": true,
36 | "device_id": null,
37 | "precision": {
38 | "param": "bfloat16",
39 | "reduce": "bfloat16",
40 | "buffer": "bfloat16"
41 | },
42 | "fsdp_modules": [
43 | "STU",
44 | "Attention",
45 | "MLP"
46 | ],
47 | "use_activation_checkpointing": true,
48 | "use_flash_fft": true,
49 | "use_approx": true,
50 | "use_attn": true,
51 | "softcap": 50.0,
52 | "torch_compile": false
53 | }
54 |
--------------------------------------------------------------------------------
/training/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import multiprocessing as mp
3 |
4 | import torch
5 | import tiktoken
6 | from datasets import load_dataset
7 | from tqdm import tqdm
8 |
9 |
10 | """Adapted from https://github.com/karpathy/build-nanogpt/blob/master/fineweb.py"""
11 |
12 | # Configuration
13 | local_dir = "data/fineweb-edu-10B"
14 | remote_name = "sample-10BT"
15 | shard_size = int(1e8) # 100M tokens per shard
16 |
17 | # Create the cache directory if it doesn't exist
18 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
19 | os.makedirs(DATA_CACHE_DIR, exist_ok=True)
20 |
21 | # Download the dataset
22 | fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
23 |
24 | # Initialize the tokenizer
25 | enc = tiktoken.get_encoding("o200k_base")
26 | eot = enc._special_tokens['<|endoftext|>']
27 |
28 | def tokenize(doc):
29 | tokens = [eot]
30 | tokens.extend(enc.encode_ordinary(doc["text"]))
31 | tokens = torch.tensor(tokens, dtype=torch.int32)
32 | return tokens
33 |
34 | def write_datafile(file, tokens):
35 | torch.save(tokens, file + '.pt')
36 |
37 | # Tokenize all documents and write output shards
38 | nprocs = 8
39 | with mp.Pool(nprocs) as pool:
40 | shard_idx = 0
41 | all_tokens_tensor = torch.empty((shard_size,), dtype=torch.int32)
42 | token_count = 0
43 | progress_bar = None
44 |
45 | for tokens in pool.imap(tokenize, fw, chunksize=16):
46 | # If there is enough space in current shard for new tokens,
47 | if token_count + len(tokens) < shard_size:
48 | all_tokens_tensor[token_count:token_count+len(tokens)] = tokens
49 | token_count += len(tokens)
50 | if progress_bar is None:
51 | progress_bar = tqdm(total=shard_size, unit=" toks", desc=f"Shard {shard_idx}")
52 | progress_bar.update(len(tokens))
53 | else:
54 | # Else, write the current shard and start a new one
55 | split = "val" if shard_idx == 0 else "train"
56 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb-edu-10B_{split}_{shard_idx:06d}")
57 |
58 | # Split the document into whatever fits in this shard, remainder goes to next one
59 | remainder = shard_size - token_count
60 | progress_bar.update(remainder)
61 | all_tokens_tensor[token_count:token_count+remainder] = tokens[:remainder]
62 | write_datafile(filename, all_tokens_tensor)
63 | shard_idx += 1
64 | progress_bar = None
65 |
66 | # Populate the next shard with the leftovers of the current doc
67 | all_tokens_tensor[0:len(tokens)-remainder] = tokens[remainder:]
68 | token_count = len(tokens)-remainder
69 |
70 | # Write any remaining tokens as the last shard
71 | if token_count != 0:
72 | split = "val" if shard_idx == 0 else "train"
73 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb-edu-10B_{split}_{shard_idx:06d}")
74 | write_datafile(filename, all_tokens_tensor[:token_count])
75 |
--------------------------------------------------------------------------------
/training/dataloader.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 |
7 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def load_tokens(filename):
12 | try:
13 | ext = os.path.splitext(filename)[1]
14 |
15 | if ext == ".npy":
16 | npt = np.load(filename)
17 | npt = npt.astype(np.int32)
18 | ptt = torch.tensor(npt, dtype=torch.long)
19 | return ptt
20 | elif ext == ".pt":
21 | return torch.load(filename, weights_only=True)
22 | else:
23 | raise ValueError(f"Unsupported file extension: {ext}")
24 |
25 | except Exception as e:
26 | logger.error(f"Error loading file {filename}: {str(e)}")
27 | raise
28 |
29 | class DistributedDataloader:
30 | def __init__(
31 | self,
32 | bsz: int,
33 | seq_len: int,
34 | rank: int,
35 | world_size: int,
36 | dataset: str,
37 | split: str,
38 | main_process: bool = False,
39 | ):
40 | self.bsz = bsz
41 | self.seq_len = seq_len
42 | self.rank = rank
43 | self.world_size = world_size
44 | assert split in {'train', 'val', 'test'}, f"Invalid split: {split}"
45 |
46 | data_root = dataset
47 | shards = [s for s in os.listdir(data_root) if split in s and (s.endswith('.pt') or s.endswith('.npy'))]
48 | self.shards = [os.path.join(data_root, s) for s in sorted(shards)]
49 | assert len(self.shards) > 0, f'No shards found for split {split}'
50 | if main_process:
51 | logger.info(f'Found {len(self.shards)} shards for split {split}')
52 |
53 | self.reset()
54 |
55 | def reset(self):
56 | self.current_shard = 0
57 | self.tokens = load_tokens(self.shards[self.current_shard])
58 | self.current_position = (self.bsz * self.seq_len * self.rank) % len(self.tokens)
59 |
60 | def set_epoch(self, epoch):
61 | self.generator = torch.Generator()
62 | self.generator.manual_seed(epoch)
63 | self.shard_order = torch.randperm(len(self.shards), generator=self.generator).tolist()
64 | self.current_shard = self.shard_order[self.rank % len(self.shards)]
65 | self.tokens = load_tokens(self.shards[self.current_shard])
66 | self.current_position = (self.bsz * self.seq_len * self.rank) % len(self.tokens)
67 |
68 | def __iter__(self):
69 | return self
70 |
71 | def __next__(self):
72 | if self.current_position + (self.bsz * self.seq_len + 1) > len(self.tokens):
73 | self.current_shard = (self.current_shard + 1) % len(self.shards)
74 | self.tokens = load_tokens(self.shards[self.current_shard])
75 | self.current_position = (self.bsz * self.seq_len * self.rank) % len(self.tokens)
76 |
77 | buf = self.tokens[self.current_position : self.current_position + self.bsz * self.seq_len + 1]
78 | x = buf[:-1].view(self.bsz, self.seq_len)
79 | y = buf[1:].view(self.bsz, self.seq_len)
80 |
81 | self.current_position += self.bsz * self.seq_len * self.world_size
82 |         return x, y.to(torch.long)  # targets must be int64 for the cross-entropy loss
83 |
--------------------------------------------------------------------------------
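
For a quick single-process smoke test of the DistributedDataloader above, a sketch along these lines should work when run from the training/ directory, assuming data.py has already produced shards (the bsz and seq_len values are illustrative):

    from dataloader import DistributedDataloader

    # Rank 0 of a world of size 1
    loader = DistributedDataloader(
        bsz=4,
        seq_len=1024,
        rank=0,
        world_size=1,
        dataset="data/fineweb-edu-10B",
        split="val",
        main_process=True,
    )

    x, y = next(loader)
    print(x.shape, y.shape, y.dtype)  # torch.Size([4, 1024]) torch.Size([4, 1024]) torch.int64
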
/training/distributed.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | import socket
5 | import sys
6 | from packaging.version import parse as version_parse
7 | from functools import partial
8 | from safetensors.torch import save_file, load_file
9 | import numpy as np
10 | import psutil
11 | import torch
12 | import torch.cuda.nccl as nccl
13 | import torch.distributed as dist
14 | import torch.nn as nn
15 | from glob import glob
16 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
17 | CheckpointImpl,
18 | apply_activation_checkpointing,
19 | checkpoint_wrapper,
20 | )
21 | from torch.distributed.fsdp import (
22 | BackwardPrefetch,
23 | CPUOffload,
24 | FullyShardedDataParallel as FSDP,
25 | MixedPrecision,
26 | ShardingStrategy,
27 | )
28 | from torch.distributed.fsdp.wrap import _module_wrap_policy, size_based_auto_wrap_policy
29 |
30 | from flash_stu import STU
31 | from flash_stu.modules.attention import Attention
32 | from flash_stu.modules.swiglu import MLP
33 |
34 |
35 | logging.basicConfig(
36 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
37 | )
38 | logger = logging.getLogger(__name__)
39 |
40 |
41 | def set_seeds(seed: int, cuda_deterministic: bool = False) -> None:
42 | random.seed(seed)
43 | np.random.seed(seed)
44 | torch.manual_seed(seed)
45 | if torch.cuda.is_available():
46 | torch.cuda.manual_seed(seed)
47 | torch.cuda.manual_seed_all(seed)
48 |
49 | if cuda_deterministic:
50 | torch.backends.cudnn.deterministic = True
51 | torch.backends.cudnn.benchmark = False
52 |
53 | logger.info(f"Random seeds set to {seed}")
54 |
55 |
56 | def setup_distributed(seed: int = 1337) -> tuple[torch.device, int, int, int, bool]:
57 | if not dist.is_available():
58 | raise RuntimeError("Distributed package not available!")
59 |
60 | if not torch.cuda.is_available():
61 | raise RuntimeError("CUDA is required for distributed training!")
62 |
63 | dist.init_process_group(backend="nccl")
64 | local_rank = int(os.environ["LOCAL_RANK"])
65 | rank = dist.get_rank()
66 | world_size = dist.get_world_size()
67 | main_process = rank == 0
68 |
69 | # Set up devices
70 | device = torch.device(f"cuda:{local_rank}")
71 | torch.cuda.set_device(device)
72 | set_seeds(seed + rank)
73 |
74 | if main_process:
75 | logger.info(f"Main process initialized on {socket.gethostname()}")
76 | logger.info(f"PyTorch version: {torch.__version__}")
77 | logger.info(f"Python version: {sys.version}")
78 | logger.info(f"CUDA version: {torch.version.cuda}")
79 | logger.info(f"cuDNN version: {torch.backends.cudnn.version()}")
80 | logger.info(
81 | f"World info: size={world_size}, rank={rank}, local_rank={local_rank}"
82 | )
83 | log_system_info(world_size, rank)
84 |
85 | return device, local_rank, rank, world_size, main_process
86 |
87 |
88 | def log_system_info(world_size: int, rank: int):
89 | logger.info(f"System info for rank {rank}:")
90 | logger.info(f"CPU count: {psutil.cpu_count()}")
91 | logger.info(f"Total RAM: {psutil.virtual_memory().total / (1024**3):.2f} GB")
92 | logger.info(
93 | f"Available RAM: {psutil.virtual_memory().available / (1024**3):.2f} GB"
94 | )
95 |
96 | # Log local GPU count and world size for clarity
97 | local_gpu_count = torch.cuda.device_count()
98 | logger.info(f"Local GPU count (rank {rank}): {local_gpu_count}")
99 | logger.info(f"Total GPU count across all nodes: {world_size * local_gpu_count}")
100 |
101 | # Log specific GPU properties for this node (local GPUs)
102 | for i in range(local_gpu_count):
103 | logger.info(f"GPU {i} (rank {rank}) name: {torch.cuda.get_device_name(i)}")
104 | logger.info(
105 | f"GPU {i} (rank {rank}) memory: {torch.cuda.get_device_properties(i).total_memory / (1024**3):.2f} GB"
106 | )
107 |
108 |
109 | def find_checkpoint(log_dir: str) -> tuple[str, str] | None:
110 | model_pattern = os.path.join(log_dir, "model_*.safetensors")
111 | misc_pattern = os.path.join(log_dir, "other_checkpoints_*.pt")
112 | model_checkpoints = glob(model_pattern)
113 | misc_checkpoints = glob(misc_pattern)
114 | if not model_checkpoints or not misc_checkpoints:
115 | return None
116 | latest_checkpoint = max(
117 | model_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0])
118 | )
119 | misc_checkpoint = max(
120 | misc_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0])
121 | )
122 | return latest_checkpoint, misc_checkpoint
123 |
124 |
125 | def load_checkpoint(model_path: str, misc_path: str, model, optimizer, device):
126 | model_checkpoint = load_file(model_path)
127 | model.load_state_dict(model_checkpoint)
128 | model.to(device)
129 |
130 | misc_checkpoint = torch.load(misc_path, map_location=device, weights_only=True)
131 | model.config = misc_checkpoint["config"]
132 | optimizer.load_state_dict(misc_checkpoint["optimizer"])
133 |
134 | step = misc_checkpoint["step"]
135 | val_loss = misc_checkpoint["val_loss"]
136 |
137 | return model, optimizer, step, val_loss
138 |
139 |
140 | def save_checkpoint(
141 | model_checkpoint, optim_checkpoint, config, step, best_val_loss, log_dir
142 | ):
143 | model_checkpoint_path = os.path.join(log_dir, f"model_{step:05d}.safetensors")
144 | other_checkpoint_path = os.path.join(log_dir, f"other_checkpoints_{step:05d}.pt")
145 |
146 | save_file(model_checkpoint, model_checkpoint_path)
147 |
148 | other_checkpoint = {
149 | "config": config,
150 | "optimizer": optim_checkpoint,
151 | "step": step,
152 | "val_loss": best_val_loss,
153 | }
154 | torch.save(other_checkpoint, other_checkpoint_path)
155 |
156 | logging.info(
157 | f"Validation loss improved at step {step}! Save the model to {model_checkpoint_path}, misc data to {other_checkpoint_path}."
158 | )
159 |
160 |
161 | def setup_fsdp(
162 | model: nn.Module,
163 | mixed_precision: bool = True,
164 | use_cpu_offload: bool = False,
165 | sharding_strategy: str = "full_shard",
166 | auto_wrap_policy: str = "partial",
167 | backward_prefetch: str = "backward_pre",
168 | forward_prefetch: bool = False,
169 | sync_module_states: bool = True,
170 | use_orig_params: bool = True,
171 | device_id: int = None,
172 | precision: dict = None,
173 | fsdp_modules: list = None,
174 | use_activation_checkpointing: bool = True,
175 | ) -> FSDP:
176 | if not torch.cuda.is_available() or not dist.is_nccl_available():
177 | raise RuntimeError("CUDA and NCCL must be available for FSDP setup")
178 |
179 | fsdp_params = {}
180 |
181 | bf16_ready = (
182 | torch.version.cuda
183 | and torch.cuda.is_bf16_supported()
184 | and version_parse(torch.version.cuda) >= version_parse("11.0")
185 | and dist.is_nccl_available()
186 | and nccl.version() >= (2, 10)
187 | )
188 |
189 | # Set up mixed precision
190 | if mixed_precision and precision:
191 | param_dtype = precision.get(
192 | "param", torch.bfloat16 if bf16_ready else torch.float32
193 | )
194 | reduce_dtype = precision.get(
195 | "reduce", torch.bfloat16 if bf16_ready else torch.float32
196 | )
197 | buffer_dtype = precision.get(
198 | "buffer", torch.bfloat16 if bf16_ready else torch.float32
199 | )
200 |
201 | if isinstance(param_dtype, str):
202 | param_dtype = getattr(torch, param_dtype)
203 | if isinstance(reduce_dtype, str):
204 | reduce_dtype = getattr(torch, reduce_dtype)
205 | if isinstance(buffer_dtype, str):
206 | buffer_dtype = getattr(torch, buffer_dtype)
207 |
208 | fsdp_params["mixed_precision"] = MixedPrecision(
209 | param_dtype=param_dtype,
210 | reduce_dtype=reduce_dtype,
211 | buffer_dtype=buffer_dtype,
212 | )
213 |
214 | # Set up CPU offloading
215 | if use_cpu_offload:
216 | fsdp_params["use_cpu_offload"] = CPUOffload(offload_params=True)
217 |
218 | # Set up sharding strategy
219 | if sharding_strategy == "full_shard":
220 | fsdp_params["sharding_strategy"] = ShardingStrategy.FULL_SHARD
221 | elif sharding_strategy == "shard_grad_op":
222 | fsdp_params["sharding_strategy"] = ShardingStrategy.SHARD_GRAD_OP
223 | elif sharding_strategy == "no_shard":
224 | fsdp_params["sharding_strategy"] = ShardingStrategy.NO_SHARD
225 | else:
226 | raise ValueError(f"Invalid sharding strategy: {sharding_strategy}")
227 |
228 | # Set up backward prefetch
229 | if backward_prefetch == "backward_pre":
230 | fsdp_params["backward_prefetch"] = BackwardPrefetch.BACKWARD_PRE
231 | elif backward_prefetch == "backward_post":
232 | fsdp_params["backward_prefetch"] = BackwardPrefetch.BACKWARD_POST
233 | elif backward_prefetch is not None:
234 | raise ValueError(f"Invalid backward prefetch option: {backward_prefetch}")
235 |
236 | # Set up other parameters
237 | fsdp_params["forward_prefetch"] = forward_prefetch
238 | fsdp_params["sync_module_states"] = sync_module_states
239 | fsdp_params["use_orig_params"] = use_orig_params
240 |
241 | if device_id is None:
242 | device_id = torch.cuda.current_device()
243 | fsdp_params["device_id"] = device_id
244 |
245 | # Set up auto wrap policy
246 | fsdp_modules_set = set(
247 | eval(module) if isinstance(module, str) else module for module in fsdp_modules
248 | )
249 | if auto_wrap_policy == "partial":
250 | fsdp_params["auto_wrap_policy"] = partial(
251 | _module_wrap_policy, module_classes=fsdp_modules_set
252 | )
253 | elif auto_wrap_policy == "size_based":
254 | fsdp_params["auto_wrap_policy"] = size_based_auto_wrap_policy
255 | else:
256 | raise ValueError(f"Invalid auto wrap policy: {auto_wrap_policy}")
257 |
258 | # Apply activation checkpointing
259 | if use_activation_checkpointing:
260 | check_fn = lambda submodule: isinstance(submodule, tuple(fsdp_modules_set))
261 | apply_activation_checkpointing(
262 | model,
263 | checkpoint_wrapper_fn=partial(
264 | checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
265 | ),
266 | check_fn=check_fn,
267 | )
268 |
269 | # Wrap the model with FSDP
270 | fsdp_model = FSDP(model, **fsdp_params)
271 |
272 | return fsdp_model
273 |
274 |
275 | def cleanup_distributed(rank: int):
276 | if dist.is_initialized():
277 | logging.info(f"[Rank {rank}]: Finished training.")
278 | logging.info(f"[Rank {rank}]: Waiting for other processes to finish...")
279 | dist.barrier()
280 | dist.destroy_process_group()
281 |
--------------------------------------------------------------------------------
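
Note that setup_fsdp accepts dtype names as strings (resolved via getattr(torch, ...)) and module classes either directly or by name (resolved against the imports at the top of the file). A minimal sketch of a call, assuming the process group has already been initialized via setup_distributed and model is a FlashSTU instance on the current device:

    fsdp_model = setup_fsdp(
        model,
        mixed_precision=True,
        precision={"param": "bfloat16", "reduce": "bfloat16", "buffer": "bfloat16"},
        sharding_strategy="full_shard",
        auto_wrap_policy="partial",
        backward_prefetch="backward_pre",
        fsdp_modules=["STU", "Attention", "MLP"],  # these blocks get wrapped and activation-checkpointed
        use_activation_checkpointing=True,
    )
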
/training/example.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import time
5 | from contextlib import nullcontext
6 |
7 | import torch
8 | import torch.distributed as dist
9 | from torch.amp import autocast, GradScaler
10 | from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict
11 | from torch.distributed.elastic.multiprocessing.errors import record
12 | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.optim import AdamW
15 |
16 | from dataloader import DistributedDataloader
17 | from distributed import (
18 | cleanup_distributed,
19 | find_checkpoint,
20 | load_checkpoint,
21 | save_checkpoint,
22 | setup_distributed,
23 | setup_fsdp
24 | )
25 | from flash_stu.config import FlashSTUConfig
26 | from flash_stu.model import FlashSTU
27 | from flash_stu.utils.stu_utils import get_spectral_filters
28 | from flash_stu.utils.training import GarbageCollection, linear_decay_with_warmup
29 |
30 | try:
31 | from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss as CrossEntropyLoss
32 | except ImportError as e:
33 | print(f"Unable to import Triton-based cross entropy loss: {e}. Falling back to PyTorch implementation.")
34 | from torch.nn import CrossEntropyLoss
35 |
36 |
37 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
38 | logger = logging.getLogger(__name__)
39 |
40 | @record
41 | def main():
42 | torch.set_float32_matmul_precision("high")
43 | device, local_rank, rank, world_size, main_process = setup_distributed(seed=1337)
44 | gc_handler = GarbageCollection()
45 |
46 | with open("config.json", "r") as file:
47 | config = json.load(file)
48 |
49 | # Model configurations
50 | n_embd = config['n_embd']
51 | n_heads = config['n_heads']
52 | n_layers = config['n_layers']
53 | seq_len = config['seq_len']
54 | window_size = config['window_size']
55 | vocab_size = config['vocab_size']
56 | mlp_scale = config['mlp_scale']
57 | bias = config['bias']
58 | dropout = config['dropout']
59 | num_eigh = config['num_eigh']
60 | use_hankel_L = config['use_hankel_L']
61 |
62 | # Optimizations
63 | use_flash_fft = config['use_flash_fft']
64 | use_approx = config['use_approx']
65 | use_attn = config['use_attn']
66 | softcap = config['softcap']
67 | torch_compile = config['torch_compile']
68 |
69 | # Training configurations
70 | dilation = config['dilation']
71 | warmup_steps = config['warmup_steps'] // dilation
72 | eval_period = config['eval_period'] // dilation
73 | save_period = config['save_period'] // dilation
74 | num_epochs = config['num_epochs']
75 | max_lr = config['max_lr']
76 | min_lr = config['min_lr']
77 | max_norm = config['max_norm']
78 |
79 | global_bsz = config['global_bsz']
80 | bsz = config['bsz']
81 |     assert global_bsz % (bsz * seq_len * world_size) == 0, (
82 |         f"global_bsz ({global_bsz}) must be divisible by bsz * seq_len * world_size "
83 |         f"({bsz * seq_len * world_size}), got remainder {global_bsz % (bsz * seq_len * world_size)}"
84 |     )
85 | gradient_accumulation_steps = global_bsz // (bsz * seq_len * world_size)
86 |
87 | # Distributed
88 | fsdp = config['fsdp']
89 | ddp = config['ddp']
90 | assert not (fsdp and ddp), "FSDP and DDP are both enabled which is not allowed"
91 |
92 | distributed = (fsdp or ddp) and world_size > 1
93 | cache_enabled = not ddp
94 |
95 | mixed_precision = config['mixed_precision']
96 | torch_dtype = getattr(torch, config['torch_dtype'])
97 | use_cpu_offload = config['use_cpu_offload']
98 | sharding_strategy = config['sharding_strategy']
99 | auto_wrap_policy = config['auto_wrap_policy']
100 | backward_prefetch = config['backward_prefetch']
101 | forward_prefetch = config['forward_prefetch']
102 | sync_module_states = config['sync_module_states']
103 | use_orig_params = config['use_orig_params']
104 | device_id = config['device_id']
105 | precision = config['precision']
106 | fsdp_modules = config['fsdp_modules']
107 | use_activation_checkpointing = config['use_activation_checkpointing']
108 |
109 | if main_process:
110 | logging.info(f"Training config: {config}\n")
111 |
112 | if world_size == 1 and fsdp:
113 | if main_process:
114 | logging.info("World size is 1, disabling sharding.")
115 | sharding_strategy = "no_shard"
116 |
117 | config = FlashSTUConfig(
118 | bsz=bsz,
119 | n_embd=n_embd,
120 | n_heads=n_heads,
121 | n_layers=n_layers,
122 | seq_len=seq_len,
123 | window_size=window_size,
124 | vocab_size=vocab_size,
125 | mlp_scale=mlp_scale,
126 | bias=bias,
127 | dropout=dropout,
128 | num_eigh=num_eigh,
129 | use_hankel_L=use_hankel_L,
130 | use_flash_fft=use_flash_fft,
131 | use_approx=use_approx,
132 | use_attn=use_attn,
133 | softcap=softcap,
134 | torch_dtype=torch_dtype,
135 | )
136 | fsdp_params = {
137 | "mixed_precision": mixed_precision,
138 | "use_cpu_offload": use_cpu_offload,
139 | "sharding_strategy": sharding_strategy,
140 | "auto_wrap_policy": auto_wrap_policy,
141 | "backward_prefetch": backward_prefetch,
142 | "forward_prefetch": forward_prefetch,
143 | "sync_module_states": sync_module_states,
144 | "use_orig_params": use_orig_params,
145 | "device_id": device_id,
146 | "precision": precision,
147 | "fsdp_modules": fsdp_modules,
148 | "use_activation_checkpointing": use_activation_checkpointing,
149 | }
150 |
151 | phi = get_spectral_filters(seq_len, num_eigh, use_hankel_L, device, torch_dtype)
152 | model = FlashSTU(config, phi)
153 | if torch_compile:
154 | model = torch.compile(model)
155 | if main_process:
156 | logging.info(f"PyTorch Compiler Enabled?: {torch_compile}")
157 | model = model.to(device)
158 |
159 | if fsdp:
160 | model = setup_fsdp(model, **fsdp_params)
161 | if ddp:
162 | model = DDP(model, device_ids=[local_rank], gradient_as_bucket_view=True)
163 | state_dict_options = StateDictOptions(
164 | full_state_dict=True,
165 | cpu_offload=True,
166 | )
167 | torch.serialization.add_safe_globals([FlashSTUConfig])
168 |
169 | use_scaler = torch_dtype == torch.float16
170 | if use_scaler:
171 | scaler = ShardedGradScaler() if fsdp else GradScaler()
172 | optimizer = AdamW(
173 | model.parameters(),
174 | lr=max_lr,
175 | fused=torch.cuda.is_available(),
176 | )
177 | loss_fn = CrossEntropyLoss()
178 |
179 | # Create the log directory to write checkpoints to and log to
180 | log_dir = "log"
181 | os.makedirs(log_dir, exist_ok=True)
182 |
183 | checkpoints = find_checkpoint(log_dir)
184 | if checkpoints:
185 | model_checkpoint, misc_checkpoint = checkpoints
186 | model, optimizer, start_step, best_val_loss = load_checkpoint(
187 | model_checkpoint, misc_checkpoint, model, optimizer, device
188 | )
189 | log_mode = "a" # Append to the log if resuming
190 | if main_process:
191 | logging.info(f"Resuming from checkpoint: {model_checkpoint}")
192 | logging.info(f"Starting from step: {start_step}")
193 | logging.info(f"Best validation loss: {best_val_loss}")
194 | else:
195 | start_step = 0
196 | best_val_loss = float('inf')
197 | log_mode = "w" # Create a new log if starting fresh
198 |
199 | if main_process:
200 | log_file = os.path.join(log_dir, "log.txt")
201 | with open(log_file, log_mode) as f:
202 | pass
203 |
204 | # Data loader section
205 | # 19,073 steps is ~1 epoch, if data is 10B tokens and batch size ~0.5M tokens
206 | # This is dataset and batch size dependent.
207 | dataset = "data/fineweb-edu-10B"
208 | total_tokens = 10_000_000_000
209 | num_steps = total_tokens // global_bsz # Number of steps for one epoch
210 | max_steps = num_steps * num_epochs
211 |
212 | if main_process:
213 | logging.info(f"Total (desired) batch size: {global_bsz}")
214 | logging.info(
215 | f"=> Calculated gradient accumulation steps: {gradient_accumulation_steps}"
216 | )
217 | logging.info(f"Training on {max_steps} steps")
218 |
219 | train_loader = DistributedDataloader(
220 | bsz=bsz,
221 | seq_len=seq_len,
222 | rank=rank,
223 | world_size=world_size,
224 | dataset=dataset,
225 | split="train",
226 | main_process=main_process,
227 | )
228 | val_loader = DistributedDataloader(
229 | bsz=bsz,
230 | seq_len=seq_len,
231 | rank=rank,
232 | world_size=world_size,
233 | dataset=dataset,
234 | split="val",
235 | main_process=main_process,
236 | )
237 |
238 | for step in range(start_step + 1, max_steps + 1):
239 | gc_handler.run(step)
240 | torch.cuda.empty_cache()
241 | epoch = step // num_steps
242 | last_step = step % num_steps == 0
243 |
244 | if step == 1 or step % num_steps == 1:
245 | if main_process:
246 | logging.info(f"Starting epoch {epoch}")
247 | train_loader.set_epoch(epoch)
248 |
249 | t0 = time.perf_counter()
250 |
251 |         if step == 1 or step % eval_period == 0 or last_step:  # eval_period is already divided by dilation above
252 | val_loss = 0.0
253 | val_steps = 20 # Arbitrarily set to reduce long evaluations
254 | model.eval()
255 | val_loader.reset()
256 | with torch.no_grad():
257 | for i, batch in zip(range(val_steps), val_loader, strict=False):
258 | inputs, targets = batch
259 | inputs, targets = inputs.to(device), targets.to(device)
260 | if torch_dtype != torch.float32:
261 | with autocast(device_type=device.type, dtype=torch_dtype, cache_enabled=cache_enabled):
262 | preds = model(inputs)
263 | else:
264 | preds = model(inputs)
265 |
266 | loss = loss_fn(preds.flatten(0, 1), targets.flatten(0, 1))
267 | loss = loss / val_steps
268 | val_loss += loss.detach().float()
269 |
270 | if distributed:
271 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
272 |
273 | if main_process:
274 | logging.info(f"Validation loss: {val_loss.item():.4f}")
275 | with open(log_file, "a") as f:
276 | f.write(f"{step} val {val_loss.item():.4f}\n")
277 |
278 | if step > 0 and (step % save_period == 0 or last_step):
279 | if val_loss.item() < best_val_loss:
280 | best_val_loss = val_loss.item()
281 | dist.barrier()
282 | model_checkpoint, optim_checkpoint = get_state_dict(
283 | model, optimizer, options=state_dict_options
284 | )
285 | if main_process:
286 | save_checkpoint(model_checkpoint, optim_checkpoint, config, step, best_val_loss, log_dir)
287 |
288 | model.train()
289 | train_loss = 0.0
290 | for micro_step, batch in zip(range(gradient_accumulation_steps), train_loader, strict=False):
291 | inputs, targets = batch
292 | inputs, targets = inputs.to(device), targets.to(device)
293 |
294 | last_micro_step = micro_step == gradient_accumulation_steps - 1
295 |
296 | if (fsdp or ddp) and not last_micro_step:
297 | context = model.no_sync()
298 | else:
299 | context = nullcontext()
300 |
301 | with context:
302 | if torch_dtype != torch.float32:
303 | with autocast(device_type=device.type, dtype=torch_dtype, cache_enabled=cache_enabled):
304 | preds = model(inputs)
305 | else:
306 | preds = model(inputs)
307 |
308 | loss = loss_fn(preds.flatten(0, 1), targets.flatten(0, 1))
309 | loss = loss / gradient_accumulation_steps
310 | train_loss += loss.detach().float()
311 | scaler.scale(loss).backward() if use_scaler else loss.backward()
312 |
313 | if distributed:
314 | dist.all_reduce(train_loss, op=dist.ReduceOp.AVG)
315 |
316 | if use_scaler:
317 | scaler.unscale_(optimizer)
318 |
319 | if fsdp:
320 | norm = model.clip_grad_norm_(max_norm)
321 | else:
322 | norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
323 |
324 | lr = linear_decay_with_warmup(step, warmup_steps, max_steps, max_lr, min_lr)
325 | for param_group in optimizer.param_groups:
326 | param_group['lr'] = lr
327 |
328 | if use_scaler:
329 | scaler.step(optimizer)
330 | scaler.update()
331 | else:
332 | optimizer.step()
333 | optimizer.zero_grad()
334 |
335 | torch.cuda.synchronize()
336 | t1 = time.perf_counter()
337 | dt = t1 - t0
338 | toks_processed = (
339 | train_loader.bsz
340 | * train_loader.seq_len
341 | * gradient_accumulation_steps
342 | * world_size
343 | )
344 | toks_per_sec = toks_processed / dt
345 |
346 | if main_process:
347 | log_message = (
348 | f"step {step:5d} | "
349 | f"loss: {train_loss:.6f} | "
350 | f"lr {lr:.4e} | "
351 | f"norm: {norm:.4f} | "
352 | f"dt: {dt*1000:.2f}ms | "
353 | f"tok/s: {toks_per_sec:.2f}"
354 | )
355 | print(log_message)
356 |
357 | with open(log_file, "a") as f:
358 | f.write(f"{step} train {train_loss:.6f} lr {lr:.4e} norm {norm:.4f} dt {dt*1000:.2f} tok/s {toks_per_sec:.2f}\n")
359 |
360 | cleanup_distributed(rank)
361 |
362 | if __name__ == "__main__":
363 | main()
364 |
--------------------------------------------------------------------------------
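
The batch-size bookkeeping in example.py is token-denominated: global_bsz is the desired number of tokens per optimizer step, and gradient accumulation covers whatever the per-rank micro-batches do not. A toy calculation with illustrative numbers (the real values come from training/config.json):

    global_bsz = 524_288                                      # ~0.5M tokens per optimizer step
    bsz, seq_len, world_size = 4, 8192, 8

    tokens_per_micro_step = bsz * seq_len * world_size        # 262,144 tokens per forward/backward
    assert global_bsz % tokens_per_micro_step == 0
    grad_accum_steps = global_bsz // tokens_per_micro_step    # 2 micro-steps per optimizer step
    print(grad_accum_steps)
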
/training/job.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=stu # Name of the job
3 | #SBATCH --nodes=1 # Number of nodes
4 | #SBATCH --ntasks-per-node=1 # Each node runs 1 task that manages all GPUs
5 | #SBATCH --gpus-per-task=8 # Number of GPUs to allocate per task
6 | #SBATCH --cpus-per-task=8 # Should be at least the number of GPUs per task
7 | #SBATCH --mem=48G # Total memory for job
8 | #SBATCH --time=15:59:00 # Max time limit
9 |
10 | #SBATCH --error=stu_%j.err
11 | #SBATCH --output=stu_%j.out
12 |
13 | # Logging
14 | log_info() {
15 | echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
16 | }
17 |
18 | # Error handling
19 | set -e
20 | trap 'log_info "Error on line $LINENO"; exit 1' ERR
21 |
22 | # Activate your virtual environment accordingly
23 | source .venv/bin/activate
24 |
25 | # Get the first node (master node) from the SLURM_JOB_NODELIST
26 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
27 |
28 | # Get the IP address of the master node
29 | MASTER_NODE_ADDR=$(srun --nodes=1 --ntasks=1 -w "$MASTER_NODE" hostname --ip-address)
30 |
31 | # Find an available port
32 | RDZV_PORT=$(python3 -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
33 |
34 | # Log start of training run
35 | log_info "Starting training run..."
36 |
37 | # Run the script using torchrun
38 | torchrun \
39 | --nnodes 1 \
40 | --nproc_per_node 8 \
41 | --rdzv_id $SLURM_JOB_ID \
42 | --rdzv_backend c10d \
43 | --rdzv_endpoint $MASTER_NODE_ADDR:$RDZV_PORT \
44 | --max-restarts 16 \
45 | example.py
46 |
47 | # Log end of training run
48 | log_info "Job finished."
49 |
50 | ##################### HELPFUL SLURM COMMANDS #####################
51 |
52 | # Disk usage and quota info; request additional space via link.
53 | # `checkquota`
54 |
55 | # Operating system details.
56 | # `cat /etc/os-release`
57 |
58 | # CPU specifications on the current node.
59 | # `lscpu`
60 |
61 | # Compute node information (standard and easy-to-read formats).
62 | # `snodes`
63 | # `shownodes`
64 |
65 | # Cluster nodes usage overview; check status (idle, down, busy).
66 | # `sinfo`
67 | # GPU usage specifics.
68 | # `sinfo -p gpu`
69 |
70 | # Quality of Service insights: job partitions and limits.
71 | # `qos`
72 |
73 | # Current processor activity; exit with 'q'.
74 | # `top`
75 | # `htop`
76 |
77 | # Overview of group-assigned cluster shares.
78 | # `sshare`
79 |
80 | # Job priority mechanics: factors and weights.
81 | # `sprio -w`
82 |
83 | # Performance of a completed job by job ID.
84 | # `seff <jobid>`
85 |
86 | # Your historical job records.
87 | # `shistory`
88 |
89 | # Detailed job statistics (memory, CPU, GPU).
90 | # `jobstats <jobid>`
91 |
92 | # Additional commands for GPU details:
93 |
94 | # Details about GPUs on the cluster.
95 | # `snodes`
96 |
97 | # Number of available GPUs.
98 | # `shownodes -p gpu,mig`
99 |
100 | # GPU utilization, refreshed every 10 min.
101 | # `gpudash`
102 | # Specific to your jobs.
103 | # `gpudash -u $USER`
104 |
105 | # Real-time GPU status on active jobs.
106 | # `nvidia-smi` OR `watch nvidia-smi`
107 |
108 | # Your queue status.
109 | # `squeue -u $USER`
110 |
--------------------------------------------------------------------------------