├── .gitignore ├── LICENSE ├── README.md ├── docs └── flash-stu.webp ├── flash_stu ├── __init__.py ├── config.py ├── layers │ ├── __init__.py │ ├── attention_layer.py │ └── stu_layer.py ├── model.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── stu.py │ └── swiglu.py └── utils │ ├── __init__.py │ ├── numerics.py │ ├── stu_utils.py │ └── training.py ├── pyproject.toml └── training ├── config.json ├── data.py ├── dataloader.py ├── distributed.py ├── example.py └── job.slurm /.gitignore: -------------------------------------------------------------------------------- 1 | # =============================== 2 | # Byte-compiled / Optimized / DLL files 3 | # =============================== 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # =============================== 9 | # Artifacts 10 | # =============================== 11 | .empty/ 12 | **/*.out 13 | log*/ 14 | ignore/ 15 | 16 | # =============================== 17 | # Datasets 18 | # =============================== 19 | data/ 20 | 21 | # =============================== 22 | # Scripts and tests 23 | # =============================== 24 | **/*test*.* 25 | 26 | # =============================== 27 | # Plotting (To be added in the next commit) 28 | # =============================== 29 | plot.py 30 | 31 | # =============================== 32 | # C Extensions 33 | # =============================== 34 | *.so 35 | 36 | # =============================== 37 | # Distribution / Packaging 38 | # =============================== 39 | .Python 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | share/python-wheels/ 53 | *.egg-info/ 54 | .installed.cfg 55 | *.egg 56 | MANIFEST 57 | 58 | # =============================== 59 | # PyInstaller 60 | # =============================== 61 | # Usually these files are written by a Python script from a template before PyInstaller builds 62 | # the exe, so as to inject date/other info into it. 
63 | *.manifest 64 | *.spec 65 | 66 | # =============================== 67 | # Installer Logs 68 | # =============================== 69 | pip-log.txt 70 | pip-delete-this-directory.txt 71 | 72 | # =============================== 73 | # Unit Test / Coverage Reports 74 | # =============================== 75 | htmlcov/ 76 | .tox/ 77 | .nox/ 78 | .coverage 79 | .coverage.* 80 | .cache 81 | nosetests.xml 82 | coverage.xml 83 | *.cover 84 | *.py,cover 85 | .hypothesis/ 86 | .pytest_cache/ 87 | cover/ 88 | 89 | # =============================== 90 | # Translations 91 | # =============================== 92 | *.mo 93 | *.pot 94 | 95 | # =============================== 96 | # Django 97 | # =============================== 98 | *.log 99 | local_settings.py 100 | db.sqlite3 101 | db.sqlite3-journal 102 | 103 | # =============================== 104 | # Flask 105 | # =============================== 106 | instance/ 107 | .webassets-cache 108 | 109 | # =============================== 110 | # Scrapy 111 | # =============================== 112 | .scrapy 113 | 114 | # =============================== 115 | # Sphinx Documentation 116 | # =============================== 117 | docs/_build/ 118 | 119 | # =============================== 120 | # PyBuilder 121 | # =============================== 122 | .pybuilder/ 123 | target/ 124 | 125 | # =============================== 126 | # Jupyter Notebook 127 | # =============================== 128 | .ipynb_checkpoints 129 | 130 | # =============================== 131 | # IPython 132 | # =============================== 133 | profile_default/ 134 | ipython_config.py 135 | 136 | # =============================== 137 | # pyenv 138 | # =============================== 139 | # For a library or package, you might want to ignore these files since the code is 140 | # intended to run in multiple environments; otherwise, check them in. 141 | # .python-version 142 | 143 | # =============================== 144 | # pipenv 145 | # =============================== 146 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 147 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 148 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 149 | # install all needed dependencies. 150 | #Pipfile.lock 151 | 152 | # =============================== 153 | # poetry 154 | # =============================== 155 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 156 | # Especially recommended for binary packages to ensure reproducibility. 
157 | #poetry.lock 158 | 159 | # =============================== 160 | # pdm 161 | # =============================== 162 | .pdm.toml 163 | .pdm-python 164 | .pdm-build/ 165 | 166 | # =============================== 167 | # PEP 582 (Used by tools like PyFlow, pdm) 168 | # =============================== 169 | __pypackages__/ 170 | 171 | # =============================== 172 | # Celery 173 | # =============================== 174 | celerybeat-schedule 175 | celerybeat.pid 176 | 177 | # =============================== 178 | # SageMath 179 | # =============================== 180 | *.sage.py 181 | 182 | # =============================== 183 | # Environments 184 | # =============================== 185 | .env 186 | .venv 187 | env/ 188 | venv/ 189 | ENV/ 190 | env.bak/ 191 | venv.bak/ 192 | 193 | # Artifact of uv 194 | hello.py 195 | 196 | # =============================== 197 | # Spyder 198 | # =============================== 199 | .spyderproject 200 | .spyproject 201 | 202 | # =============================== 203 | # Rope 204 | # =============================== 205 | .ropeproject 206 | 207 | # =============================== 208 | # mkdocs 209 | # =============================== 210 | /site 211 | 212 | # =============================== 213 | # mypy 214 | # =============================== 215 | .mypy_cache/ 216 | .dmypy.json 217 | dmypy.json 218 | 219 | # =============================== 220 | # Pyre 221 | # =============================== 222 | .pyre/ 223 | 224 | # =============================== 225 | # pytype 226 | # =============================== 227 | .pytype/ 228 | 229 | # =============================== 230 | # Cython 231 | # =============================== 232 | cython_debug/ 233 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Windsor Nguyen 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ⚡️ Flash STU ⚡️ 2 | 3 |
4 | <img src="docs/flash-stu.webp" alt="Flash STU Logo" /> 5 |
6 | 7 | ## Table of Contents 8 | 1. [Introduction](#introduction) 9 | 2. [Features](#features) 10 | 3. [Installation](#installation) 11 | 4. [Usage](#usage) 12 | 5. [Configuration](#configuration) 13 | 6. [Contributing](#contributing) 14 | 7. [License](#license) 15 | 8. [Acknowledgments](#acknowledgments) 16 | 17 | ## Introduction 18 | 19 | This repository complements the [Flash STU: Fast Spectral Transform Units](https://arxiv.org/abs/2409.10489) paper and contains an optimized, open-source PyTorch implementation of the Spectral Transform Unit (STU) as proposed in [*Spectral State Space Models*](https://arxiv.org/abs/2312.06837) by Agarwal et al. (2024). 20 | 21 | The [STU](stu.py) module is a fast and flexible building block that can be adapted into a wide range of neural network architectures, especially those that aim to solve tasks with long-range dependencies. 22 | 23 | ## Features 24 | 25 | - ⚡️ Fast convolutions using [Flash FFT](https://github.com/HazyResearch/flash-fft-conv) 26 | - 🚀 Fast, local attention using (sliding window) [Flash Attention](https://github.com/Dao-AILab/flash-attention) 27 | - 🌐 Support for distributed training using [DDP](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) and [FSDP](https://pytorch.org/docs/stable/fsdp.html) 28 | 29 | ## Installation 30 | 31 | > **Note**: CUDA is required to run code from this repository. 32 | 33 | This repository was tested with: 34 | - Python 3.12.5 35 | - PyTorch 2.4.1 36 | - Triton 3.0.0 37 | - CUDA 12.4 38 | 39 | and may be incompatible with other versions. 40 | 41 | 1. Install PyTorch with CUDA support: 42 | ```bash 43 | pip install torch --index-url https://download.pytorch.org/whl/cu124 44 | ``` 45 | 46 | 2. Install required packages: 47 | ```bash 48 | pip install -e . 49 | ``` 50 | 51 | 2. Install Flash Attention: 52 | ```bash 53 | MAX_JOBS=4 pip install flash-attn --no-build-isolation 54 | ``` 55 | 56 | 3. Install Flash FFT: 57 | ```bash 58 | pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv 59 | pip install git+https://github.com/HazyResearch/flash-fft-conv.git 60 | ``` 61 | 62 | Or from source: 63 | ``` 64 | pip install git+https://github.com/windsornguyen/flash-stu.git 65 | ``` 66 | 67 | ## Usage 68 | 69 | ### Using Flash STU 70 | 71 | Here is an example of how to import and use Flash STU: 72 | ``` python 73 | from flash_stu import FlashSTU, FlashSTUConfig, get_spectral_filters 74 | import torch 75 | 76 | device = torch.device('cuda') # Flash STU requires CUDA 77 | 78 | config = FlashSTUConfig( 79 | MODIFY_YOUR_ARGS_HERE, 80 | ) 81 | 82 | phi = get_spectral_filters( 83 | config.seq_len, 84 | config.num_eigh, 85 | config.use_hankel_L, 86 | device, 87 | config.torch_dtype 88 | ) 89 | 90 | model = FlashSTU( 91 | config, 92 | phi 93 | ) 94 | 95 | y = model(x) 96 | ``` 97 | 98 | ### Training 99 | 100 | An example LLM pretraining script is provided in [`example.py`](training/example.py) for you to test out the repository. 101 | 102 | If your compute cluster does not have internet access, you will need to pre-download the entire dataset before running the example training script. 103 | 104 | To download the dataset, run: 105 | ```bash 106 | cd training 107 | python data.py 108 | ``` 109 | 110 | > **Note**: The FineWeb-Edu 10B-token sample is a relatively large dataset. It can be swapped out for something smaller, e.g. [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) (476.6M tokens). 
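For reference, swapping the corpus only requires changing the download step near the top of [`data.py`](training/data.py). Below is a minimal sketch of that change, assuming TinyStories exposes a `text` field the way FineWeb-Edu does; `tokenize` and the shard-writing loop can stay as they are.

```python
# Hypothetical edit to training/data.py: download TinyStories instead of FineWeb-Edu.
# Only the dataset id and the local cache directory change; tokenize() and the
# sharding loop are untouched, since they only read doc["text"].
from datasets import load_dataset  # already imported in data.py

local_dir = "data/tinystories"  # assumed cache path; any local directory works

fw = load_dataset("roneneldan/TinyStories", split="train")
```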
111 | 112 | To begin training, make sure you are in the `training` directory and run the following command in your terminal: 113 | 114 | ```bash 115 | torchrun example.py 116 | ``` 117 | 118 | If you are in a compute cluster that uses Slurm and [environment modules](https://modules.readthedocs.io/en/latest/index.html), you can submit a job using the following command: 119 | ```bash 120 | sbatch job.slurm 121 | ``` 122 | 123 | Model configurations can be adjusted as needed in [`config.json`](training/config.json). Be sure to adjust the configurations of the [Slurm job](training/job.slurm) based on your cluster's constraints. 124 | 125 | > **Note**: PyTorch's `torch.compile` currently does not have great support for distributed wrapper modules like DDP or FSDP. If you encounter errors during training, try disabling `torch.compile`. For more information on `torch.compile`, see this [informal manual](https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.ivdr7fmrbeab). 126 | 127 | 128 | ## Contributing 129 | 130 | Contributions are welcomed! Writing performant distributed code is always tricky. We welcome contributors to: 131 | 132 | - Submit pull requests 133 | - Report issues 134 | - Help improve the project overall 135 | 136 | ## License 137 | 138 | Apache 2.0 License 139 | 140 | You can freely use, modify, and distribute the software, **even in proprietary products**, as long as you: 141 | - Include proper attribution 142 | - Include a copy of the license 143 | - Mention any changes made 144 | 145 | It also provides an express grant of patent rights from contributors. 146 | 147 | See the [LICENSE](LICENSE) file for more details. 148 | 149 | ## Acknowledgments 150 | 151 | Special thanks to (in no particular order): 152 | - Elad Hazan and the authors of the [Spectral State Space Models](https://arxiv.org/abs/2312.06837) paper 153 | - Isabel Liu, Yagiz Devre, Evan Dogariu 154 | - The Flash Attention team 155 | - The Flash FFT team 156 | - The PyTorch team 157 | - Princeton Research Computing and Princeton Language and Intelligence, for supplying compute 158 | - Andrej Karpathy, for his awesome [NanoGPT](https://github.com/karpathy/build-nanogpt) repository 159 | 160 | ## Citation 161 | 162 | If you use this repository, or otherwise find our work valuable, please cite Flash STU: 163 | ``` 164 | @article{flashstu, 165 | title={Flash STU: Fast Spectral Transform Units}, 166 | author={Y. 
Isabel Liu, Windsor Nguyen, Yagiz Devre, Evan Dogariu, Anirudha Majumdar, Elad Hazan}, 167 | journal={arXiv preprint arXiv:2409.10489}, 168 | year={2024}, 169 | url={https://arxiv.org/abs/2409.10489} 170 | } 171 | -------------------------------------------------------------------------------- /docs/flash-stu.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hazan-lab/flash-stu/54f78d4b91cc068ccc23554172beab1277aa8ab1/docs/flash-stu.webp -------------------------------------------------------------------------------- /flash_stu/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import FlashSTU, FlashSTUConfig 2 | from .modules.stu import STU 3 | from .utils.stu_utils import get_spectral_filters 4 | -------------------------------------------------------------------------------- /flash_stu/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import PretrainedConfig 4 | 5 | 6 | class FlashSTUConfig(PretrainedConfig): 7 | model_type = "FlashSTU" 8 | 9 | def __init__( 10 | self, 11 | bsz: int = 1, 12 | n_embd: int = 1536, 13 | n_heads: int = 8, 14 | n_layers: int = 26, 15 | seq_len: int = 8192, 16 | window_size: int = 1024, 17 | vocab_size: int = 200064, 18 | mlp_scale: int = 12, 19 | bias: bool = False, 20 | dropout: float = 0.0, 21 | num_eigh: int = 24, 22 | use_hankel_L: bool = False, 23 | use_flash_fft: bool = True, 24 | use_approx: bool = True, 25 | use_attn: bool = True, 26 | softcap: float = 50.0, 27 | torch_dtype: torch.dtype = torch.bfloat16, 28 | **kwargs, 29 | ): 30 | super().__init__(**kwargs) 31 | self.bsz = bsz 32 | self.n_embd = n_embd 33 | self.n_heads = n_heads 34 | self.n_layers = n_layers 35 | self.seq_len = seq_len 36 | self.window_size = window_size 37 | self.vocab_size = vocab_size 38 | self.hidden_size = n_embd 39 | self.intermediate_size = n_embd * mlp_scale 40 | self.hidden_act = "swish" 41 | self.bias = bias 42 | self.dropout = dropout 43 | self.num_eigh = num_eigh 44 | self.use_hankel_L = use_hankel_L 45 | self.use_flash_fft = use_flash_fft 46 | self.use_approx = use_approx 47 | self.use_attn = use_attn 48 | self.softcap = softcap 49 | self.torch_dtype = torch_dtype 50 | -------------------------------------------------------------------------------- /flash_stu/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hazan-lab/flash-stu/54f78d4b91cc068ccc23554172beab1277aa8ab1/flash_stu/layers/__init__.py -------------------------------------------------------------------------------- /flash_stu/layers/attention_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from flash_stu.modules.attention import Attention 5 | from flash_stu.modules.swiglu import MLP 6 | 7 | try: 8 | from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP 9 | 10 | triton_mlp = True 11 | except ImportError as e: 12 | print( 13 | f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead." 14 | ) 15 | triton_mlp = False 16 | 17 | try: 18 | from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm 19 | 20 | triton_norm = True 21 | except ImportError as e: 22 | print( 23 | f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation." 
24 | ) 25 | from torch.nn import RMSNorm 26 | 27 | triton_norm = False 28 | 29 | 30 | class AttentionLayer(nn.Module): 31 | def __init__(self, config) -> None: 32 | super(AttentionLayer, self).__init__() 33 | self.attn_norm = ( 34 | TritonNorm(config.n_embd) 35 | if triton_norm 36 | else RMSNorm(config.n_embd, dtype=config.torch_dtype) 37 | ) 38 | self.attn = Attention(config) 39 | self.mlp_norm = ( 40 | TritonNorm(config.n_embd) 41 | if triton_norm 42 | else RMSNorm(config.n_embd, dtype=config.torch_dtype) 43 | ) 44 | self.mlp = ( 45 | TritonMLP(config) if triton_mlp else MLP(config, dtype=config.torch_dtype) 46 | ) 47 | 48 | # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP 49 | self.attn_norm = self.attn_norm.to(dtype=config.torch_dtype) 50 | self.mlp = self.mlp.to(dtype=config.torch_dtype) 51 | self.mlp_norm = self.mlp_norm.to(dtype=config.torch_dtype) 52 | 53 | def forward(self, x: torch.Tensor) -> torch.Tensor: 54 | x = x + self.attn(self.attn_norm(x)) 55 | x = x + self.mlp(self.mlp_norm(x)) 56 | return x 57 | -------------------------------------------------------------------------------- /flash_stu/layers/stu_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from flash_stu.modules.stu import STU 5 | from flash_stu.modules.swiglu import MLP 6 | 7 | try: 8 | from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP 9 | triton_mlp = True 10 | except ImportError as e: 11 | print( 12 | f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead." 13 | ) 14 | triton_mlp = False 15 | 16 | try: 17 | from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm 18 | triton_norm = True 19 | except ImportError as e: 20 | print( 21 | f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation." 
22 | ) 23 | from torch.nn import RMSNorm 24 | triton_norm = False 25 | 26 | 27 | class STULayer(nn.Module): 28 | def __init__(self, config, phi, n): 29 | super(STULayer, self).__init__() 30 | self.stu_norm = ( 31 | TritonNorm(config.n_embd) 32 | if triton_norm 33 | else RMSNorm(config.n_embd, dtype=config.torch_dtype) 34 | ) 35 | self.stu = STU(config, phi, n) 36 | self.mlp_norm = ( 37 | TritonNorm(config.n_embd) 38 | if triton_norm 39 | else RMSNorm(config.n_embd, dtype=config.torch_dtype) 40 | ) 41 | self.mlp = ( 42 | TritonMLP(config) if triton_mlp else MLP(config, dtype=config.torch_dtype) 43 | ) 44 | 45 | # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP 46 | self.stu_norm = self.stu_norm.to(dtype=config.torch_dtype) 47 | self.mlp = self.mlp.to(dtype=config.torch_dtype) 48 | self.mlp_norm = self.mlp_norm.to(dtype=config.torch_dtype) 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | x = x + self.stu(self.stu_norm(x)) 52 | x = x + self.mlp(self.mlp_norm(x)) 53 | return x 54 | -------------------------------------------------------------------------------- /flash_stu/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import PreTrainedModel 5 | 6 | from flash_stu.modules.stu import STU 7 | from flash_stu.modules.attention import Attention 8 | from flash_stu.utils.numerics import nearest_power_of_two 9 | from flash_stu.config import FlashSTUConfig 10 | from flash_stu.layers.stu_layer import STULayer 11 | from flash_stu.layers.attention_layer import AttentionLayer 12 | 13 | try: 14 | from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm 15 | triton_norm = True 16 | except ImportError as e: 17 | print( 18 | f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation." 
19 | ) 20 | from torch.nn import RMSNorm 21 | 22 | triton_norm = False 23 | 24 | 25 | class FlashSTU(PreTrainedModel): 26 | config_class = FlashSTUConfig 27 | 28 | def __init__(self, config, phi) -> None: 29 | super(FlashSTU, self).__init__(config) 30 | self.n_layers = config.n_layers 31 | self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True) 32 | self.phi = phi 33 | self.use_approx = config.use_approx 34 | self.use_hankel_L = config.use_hankel_L 35 | 36 | # TODO: Add support for Liger-Kernel Embedding once no longer experimental 37 | self.tok_emb = nn.Embedding( 38 | config.vocab_size, config.n_embd, dtype=config.torch_dtype 39 | ) 40 | self.dropout = nn.Dropout(config.dropout) 41 | 42 | self.layers = nn.ModuleList() 43 | for layer_idx in range(self.n_layers): 44 | # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887 45 | if layer_idx % 2 == 0: 46 | self.layers.append(STULayer(config, self.phi, self.n)) 47 | else: 48 | self.layers.append( 49 | AttentionLayer(config) 50 | if config.use_attn 51 | else STULayer(config, self.phi, self.n) 52 | ) 53 | 54 | self.norm = ( 55 | TritonNorm(config.n_embd) 56 | if triton_norm 57 | else RMSNorm(config.n_embd, dtype=config.torch_dtype) 58 | ) 59 | # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for RMS Norm 60 | self.norm = self.norm.to(dtype=config.torch_dtype) 61 | self.lm_head = nn.Linear( 62 | config.n_embd, config.vocab_size, bias=config.bias, dtype=config.torch_dtype 63 | ) 64 | self.tok_emb.weight = self.lm_head.weight 65 | 66 | self.std = (config.n_embd) ** -0.5 67 | self.apply(self._init_weights) 68 | print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,)) 69 | 70 | def forward(self, x: torch.Tensor) -> torch.tensor: 71 | tok_emb = self.tok_emb(x) 72 | x = self.dropout(tok_emb) 73 | 74 | for layer in self.layers: 75 | x = layer(x) 76 | 77 | x = self.norm(x) 78 | y_hat = self.lm_head(x) 79 | 80 | return y_hat 81 | 82 | def _get_num_params(self): 83 | n_params = sum(p.numel() for p in self.parameters()) 84 | if hasattr(self, "pos_emb") and self.pos_emb is not None: 85 | n_params -= self.pos_emb.weight.numel() 86 | if self.tok_emb.weight is not self.lm_head.weight: 87 | n_params -= self.tok_emb.weight.numel() 88 | return n_params 89 | 90 | def _init_weights(self, module): 91 | if isinstance(module, nn.Linear): 92 | if hasattr(module, "SCALE_INIT"): 93 | self.std *= (2 * self.n_layers) ** -0.5 94 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.std) 95 | if module.bias is not None: 96 | torch.nn.init.zeros_(module.bias) 97 | elif isinstance(module, nn.Embedding): 98 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.std) 99 | elif isinstance(module, STU): 100 | if self.use_approx: 101 | torch.nn.init.xavier_normal_(module.M_inputs) 102 | torch.nn.init.xavier_normal_(module.M_filters) 103 | else: 104 | torch.nn.init.xavier_normal_(module.M_phi_plus) 105 | if not self.use_hankel_L: 106 | torch.nn.init.xavier_normal_(module.M_phi_minus) 107 | elif isinstance(module, Attention): 108 | torch.nn.init.xavier_normal_(module.c_attn.weight) 109 | torch.nn.init.xavier_normal_(module.c_proj.weight) 110 | if module.c_attn.bias is not None: 111 | torch.nn.init.zeros_(module.c_attn.bias) 112 | if module.c_proj.bias is not None: 113 | torch.nn.init.zeros_(module.c_proj.bias) 114 | -------------------------------------------------------------------------------- /flash_stu/modules/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hazan-lab/flash-stu/54f78d4b91cc068ccc23554172beab1277aa8ab1/flash_stu/modules/__init__.py -------------------------------------------------------------------------------- /flash_stu/modules/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from flash_stu.utils.numerics import nearest_power_of_two 7 | 8 | try: 9 | from flash_attn import flash_attn_func as fa2 10 | except ImportError as e: 11 | print( 12 | f"Unable to import Triton-based flash attention: {e}. No alternative currently available." 13 | ) 14 | # TODO: Add FlexAttention + local attention mask when it's in stable release 15 | 16 | class Attention(nn.Module): 17 | def __init__(self, config): 18 | super(Attention, self).__init__() 19 | assert torch.cuda.is_available(), "CUDA is required." 20 | assert config.n_embd % config.n_heads == 0 21 | self.n_heads = config.n_heads 22 | 23 | self.device = torch.device("cuda") 24 | self.bsz = config.bsz 25 | self.c_attn = nn.Linear( 26 | config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=config.torch_dtype 27 | ) 28 | self.c_proj = nn.Linear( 29 | config.n_embd, config.n_embd, bias=config.bias, dtype=config.torch_dtype 30 | ) 31 | self.c_proj.SCALE_INIT = 1 32 | self.dropout = config.dropout 33 | self.resid_dropout = nn.Dropout(self.dropout) 34 | self.alibi_slopes = self._get_alibi_slopes(self.n_heads) 35 | self.window_size = config.window_size 36 | self.softcap = config.softcap 37 | 38 | def _generate_slopes(self, n: int): 39 | start = 2 ** (-(2 ** -(math.log2(n) - 3))) 40 | return [start * (start**i) for i in range(n)] 41 | 42 | def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25): 43 | # If n_heads is a power of 2, generate slopes directly 44 | if math.log2(n_heads).is_integer(): 45 | slopes = self._generate_slopes(n_heads) 46 | else: 47 | # Get slopes for the nearest power of two 48 | n = nearest_power_of_two(n_heads, round_up=False) 49 | slopes_power_of_two = self._generate_slopes(n) 50 | 51 | # Generate extra slopes 52 | extra_slopes = self._generate_slopes(2 * n) 53 | extra_slopes_trunc = extra_slopes[0::2][: n_heads - n] 54 | slopes = slopes_power_of_two + extra_slopes_trunc 55 | slopes = torch.tensor(slopes, device=self.device) 56 | slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017 57 | return slopes 58 | 59 | def forward(self, x): 60 | bsz, seq_len, d_in = x.size() 61 | 62 | qkv = self.c_attn(x) 63 | q, k, v = torch.chunk(qkv, 3, dim=2) 64 | 65 | q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads) 66 | k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads) 67 | v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads) 68 | y = fa2( # https://arxiv.org/pdf/2307.08691 69 | q, 70 | k, 71 | v, 72 | dropout_p=self.dropout if self.training else 0.0, 73 | causal=True, 74 | window_size=(self.window_size, 0), 75 | alibi_slopes=self.alibi_slopes, # https://arxiv.org/pdf/2108.12409 76 | softcap=self.softcap, # https://arxiv.org/pdf/2408.00118 77 | ) 78 | y = y.contiguous().view(bsz, seq_len, d_in) 79 | y = self.resid_dropout(self.c_proj(y)) 80 | return y 81 | -------------------------------------------------------------------------------- /flash_stu/modules/stu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 
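# STU overview: the forward pass convolves the input with the K fixed spectral
# filters `phi` (via FlashFFTConv when available, otherwise the rFFT-based
# convolve() from stu_utils) and mixes the result with the learned M_* parameters.
# In the approximate mode the inputs and filters are first projected by M_inputs
# and M_filters; otherwise the convolution outputs are contracted with M_phi_plus
# (and M_phi_minus when use_hankel_L is disabled).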
from flash_stu.utils.stu_utils import convolve, flash_convolve 5 | 6 | try: 7 | from flashfftconv import FlashFFTConv 8 | 9 | flash_fft_available = True 10 | except ImportError as e: 11 | print( 12 | f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation." 13 | ) 14 | flash_fft_available = False 15 | 16 | 17 | class STU(nn.Module): 18 | def __init__(self, config, phi, n) -> None: 19 | super(STU, self).__init__() 20 | self.config = config 21 | self.phi = phi 22 | self.n = n 23 | self.K = config.num_eigh 24 | self.d_in = config.n_embd 25 | self.d_out = config.n_embd 26 | self.use_hankel_L = config.use_hankel_L 27 | self.use_approx = config.use_approx 28 | self.flash_fft = ( 29 | FlashFFTConv(self.n, dtype=torch.bfloat16) 30 | if config.use_flash_fft and flash_fft_available 31 | else None 32 | ) 33 | if self.use_approx: 34 | self.M_inputs = nn.Parameter( 35 | torch.empty(self.d_in, self.d_out, dtype=config.torch_dtype) 36 | ) 37 | self.M_filters = nn.Parameter( 38 | torch.empty(self.K, self.d_in, dtype=config.torch_dtype) 39 | ) 40 | else: 41 | self.M_phi_plus = nn.Parameter( 42 | torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype) 43 | ) 44 | if not self.use_hankel_L: 45 | self.M_phi_minus = nn.Parameter( 46 | torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype) 47 | ) 48 | 49 | def forward(self, x: torch.Tensor) -> torch.Tensor: 50 | if self.use_approx: 51 | # Contract inputs and filters over the K and d_in dimensions, then convolve 52 | x_proj = x @ self.M_inputs 53 | phi_proj = self.phi @ self.M_filters 54 | if self.flash_fft: 55 | spectral_plus, spectral_minus = flash_convolve( 56 | x_proj, phi_proj, self.flash_fft, self.use_approx 57 | ) 58 | else: 59 | spectral_plus, spectral_minus = convolve( 60 | x_proj, phi_proj, self.n, self.use_approx 61 | ) 62 | else: 63 | # Convolve inputs and filters, 64 | if self.flash_fft: 65 | U_plus, U_minus = flash_convolve( 66 | x, self.phi, self.flash_fft, self.use_approx 67 | ) 68 | else: 69 | U_plus, U_minus = convolve(x, self.phi, self.n, self.use_approx) 70 | # Then, contract over the K and d_in dimensions 71 | spectral_plus = torch.tensordot( 72 | U_plus, self.M_phi_plus, dims=([2, 3], [0, 1]) 73 | ) 74 | if not self.use_hankel_L: 75 | spectral_minus = torch.tensordot( 76 | U_minus, self.M_phi_minus, dims=([2, 3], [0, 1]) 77 | ) 78 | 79 | return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus 80 | -------------------------------------------------------------------------------- /flash_stu/modules/swiglu.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | 4 | class MLP(nn.Module): 5 | def __init__(self, config, dtype=None): 6 | # https://arxiv.org/pdf/2002.05202 7 | super().__init__() 8 | dtype = dtype if dtype is not None else config.torch_dtype 9 | self.hidden_size = config.n_embd 10 | self.intermediate_size = config.n_embd * config.mlp_scale 11 | self.gate_proj = nn.Linear( 12 | self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype 13 | ) 14 | self.up_proj = nn.Linear( 15 | self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype 16 | ) 17 | self.down_proj = nn.Linear( 18 | self.intermediate_size, self.hidden_size, bias=config.bias, dtype=dtype 19 | ) 20 | self.dropout = nn.Dropout( 21 | config.dropout 22 | ) # TODO: Write Issue in Liger-Kernel repo to support Dropout 23 | 24 | def forward(self, x): 25 | gate = self.gate_proj(x) 26 | 
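# Note: this fallback gates with a tanh-approximated GELU (GEGLU-style), whereas
# the Triton LigerSwiGLUMLP path selected in the layer modules gates with
# SiLU/Swish, matching hidden_act = "swish" in FlashSTUConfig.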
gate = F.gelu(gate, approximate="tanh") 27 | up = self.up_proj(x) 28 | fuse = gate * up 29 | outputs = self.down_proj(fuse) 30 | outputs = self.dropout(outputs) 31 | return outputs 32 | -------------------------------------------------------------------------------- /flash_stu/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hazan-lab/flash-stu/54f78d4b91cc068ccc23554172beab1277aa8ab1/flash_stu/utils/__init__.py -------------------------------------------------------------------------------- /flash_stu/utils/numerics.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def nearest_power_of_two(x: int, round_up: bool = False) -> int: 4 | return ( 5 | 1 << math.floor(math.log2(x)) if not round_up else 1 << math.ceil(math.log2(x)) 6 | ) 7 | -------------------------------------------------------------------------------- /flash_stu/utils/stu_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from flashfftconv import FlashFFTConv 6 | 7 | from flash_stu.utils.numerics import nearest_power_of_two 8 | 9 | 10 | def get_hankel(seq_len: int, use_hankel_L: bool = False) -> np.ndarray: 11 | entries = np.arange(1, seq_len + 1, dtype=np.float64) 12 | i_plus_j = entries[:, None] + entries[None, :] 13 | 14 | if use_hankel_L: 15 | sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0 16 | denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0) 17 | Z = sgn * (8.0 / denom) 18 | elif not use_hankel_L: 19 | Z = 2.0 / (i_plus_j**3 - i_plus_j) 20 | else: 21 | raise ValueError("use_hankel_L must be a boolean") 22 | 23 | return Z 24 | 25 | def get_spectral_filters( 26 | seq_len: int, 27 | K: int, 28 | use_hankel_L: bool = False, 29 | device: torch.device = None, 30 | dtype: torch.dtype = torch.bfloat16, 31 | ) -> torch.Tensor: 32 | assert torch.cuda.is_available(), "CUDA is required." 
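# Build the Hankel matrix Z, eigendecompose it, and keep the K largest eigenpairs
# (np.linalg.eigh returns eigenvalues in ascending order, so the last K columns).
# Each retained eigenvector is scaled by sigma**0.25 before being cast to the
# requested device and dtype.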
33 | Z = get_hankel(seq_len, use_hankel_L) 34 | sigma, phi = np.linalg.eigh(Z) 35 | sigma, phi = sigma[-K:], phi[:, -K:] 36 | phi *= sigma ** 0.25 37 | return torch.tensor(phi, device=device, dtype=dtype) 38 | 39 | def convolve(u: torch.Tensor, v: torch.Tensor, n: int, use_approx: bool = True) -> tuple[torch.Tensor, torch.Tensor]: 40 | bsz, seq_len, d_in = u.shape 41 | 42 | sgn = torch.full((1, seq_len, 1), 1, device=u.device) 43 | sgn[:, 1::2] *= -1 44 | if use_approx: 45 | _, d_out = v.shape 46 | v = v.view(1, -1, d_out, 1).to(torch.float32) 47 | else: 48 | _, K = v.shape 49 | sgn = sgn.unsqueeze(-1) 50 | v = v.view(1, -1, K, 1, 1).to(torch.float32) # (bsz, seq_len, K, d_in, stack) 51 | u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in) 52 | 53 | v = torch.fft.rfft(v, n=n, dim=1) 54 | U = torch.stack([u, u * sgn], dim=-1).to(torch.float32) 55 | U = torch.fft.rfft(U, n=n, dim=1) 56 | U_conv = torch.fft.irfft(v * U, n=n, dim=1)[:, :seq_len] 57 | U_plus, U_minus = torch.unbind(U_conv, dim=-1) 58 | U_minus = U_minus * sgn 59 | 60 | return U_plus, U_minus 61 | 62 | def flash_convolve( 63 | u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True, 64 | ) -> tuple[torch.Tensor, torch.Tensor]: 65 | bsz, seq_len, d_in = u.shape 66 | _, K = v.shape 67 | 68 | padded_len = nearest_power_of_two(seq_len, round_up=True) 69 | pad_len = padded_len - seq_len 70 | 71 | sgn = torch.full((1, 1, padded_len), 1, device=u.device) 72 | sgn[:, :, 1::2] = -1 73 | 74 | if use_approx: 75 | u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16).contiguous() 76 | v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).contiguous() 77 | u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len) 78 | else: 79 | u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16).repeat_interleave(K, dim=1).contiguous() 80 | v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).repeat(d_in, 1).contiguous() 81 | u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len) 82 | 83 | U_conv = flash_fft(u_conv, v_padded) 84 | 85 | # Trim the output back to the original sequence length 86 | U_conv = U_conv[..., :seq_len] 87 | 88 | u_plus, u_minus = torch.chunk(U_conv, 2, dim=0) 89 | 90 | if use_approx: 91 | u_minus = u_minus * sgn[:, :, :seq_len] 92 | U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2) 93 | else: 94 | sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2) 95 | U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() 96 | U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn 97 | 98 | return U_plus, U_minus 99 | -------------------------------------------------------------------------------- /flash_stu/utils/training.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | class GarbageCollection: 4 | def __init__(self, gc_freq=1000): 5 | assert gc_freq > 0, "gc_freq must be a positive integer" 6 | self.gc_freq = gc_freq 7 | gc.disable() 8 | gc.collect(1) 9 | 10 | def run(self, step_count): 11 | if step_count > 1 and step_count % self.gc_freq == 0: 12 | gc.collect(1) 13 | 14 | def linear_decay_with_warmup( # https://arxiv.org/pdf/2310.07831 15 | current_step: int, 16 | warmup_steps: int, 17 | num_steps: int, 18 | max_lr: float = 3e-4, 19 | min_lr: float = 3e-5, 20 | ) -> float: 21 | if current_step < warmup_steps: 22 | return min_lr + (max_lr - min_lr) * 
float(current_step) / float(max(warmup_steps, 1)) 23 | else: 24 | return max_lr - (max_lr - min_lr) * float(current_step - warmup_steps) / float(max(num_steps - warmup_steps, 1)) 25 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "flash-stu" 3 | version = "1.0.0" 4 | description = "An optimized, open-source PyTorch implementation of the Spectral Transform Unit (STU)" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | license = {file = "LICENSE"} 8 | 9 | authors = [ 10 | { name = "Isabel Liu", email = "isabel.liu@princeton.edu" }, 11 | { name = "Windsor Nguyen", email = "windsor@princetonalignment.org" }, 12 | { name = "Yagiz Devre", email = "yagiz.devre@princeton.edu" }, 13 | { name = "Evan Dogariu", email = "ed2719@nyu.edu" }, 14 | { name = "Anirudha Majumdar", email = "ani.majumdar@princeton.edu" }, 15 | { name = "Elad Hazan", email = "ehazan@princeton.edu" } 16 | ] 17 | 18 | maintainers = [ 19 | { name = "Windsor Nguyen", email = "windsor@princetonalignment.org" } 20 | ] 21 | 22 | keywords = [ 23 | "pytorch", 24 | "state space model", 25 | "spectral filtering", 26 | "deep learning", 27 | "machine learning", 28 | "time series", 29 | "dynamical systems", 30 | "language modeling", 31 | ] 32 | 33 | classifiers = [ 34 | "Programming Language :: Python :: 3", 35 | "License :: OSI Approved :: Apache License", 36 | "Operating System :: Unix" 37 | ] 38 | 39 | dependencies = [ 40 | "packaging==24.1", 41 | "wheel==0.44.0", 42 | "numpy==1.26.4", 43 | "ninja==1.11.1.1", 44 | "torch==2.4.1", 45 | "triton==3.0.0", 46 | "einops==0.8.0", 47 | "liger-kernel==0.2.1", 48 | "tiktoken==0.7.0", 49 | "transformers==4.44.0", 50 | "datasets==2.20.0", 51 | "safetensors==0.4.4", 52 | "psutil==6.0.0", 53 | "matplotlib==3.9.2", 54 | ] 55 | 56 | [tool.setuptools] 57 | packages = ["flash_stu"] 58 | 59 | # =============================== 60 | # Project URLs 61 | # =============================== 62 | [project.urls] 63 | Homepage = "https://arxiv.org/abs/2409.10489" 64 | Repository = "https://github.com/windsornguyen/flash-stu" 65 | 66 | # =============================== 67 | # Ruff Configuration 68 | # =============================== 69 | [tool.ruff] 70 | # Exclude commonly ignored directories from linting 71 | exclude = [ 72 | ".bzr", 73 | ".direnv", 74 | ".eggs", 75 | ".git", 76 | ".git-rewrite", 77 | ".hg", 78 | ".mypy_cache", 79 | ".nox", 80 | ".pants.d", 81 | ".pytype", 82 | ".ruff_cache", 83 | ".svn", 84 | ".tox", 85 | ".venv", 86 | "__pypackages__", 87 | "_build", 88 | "buck-out", 89 | "build", 90 | "dist", 91 | "node_modules", 92 | "venv", 93 | ] 94 | 95 | # General Ruff settings 96 | line-length = 119 97 | indent-width = 4 98 | target-version = "py311" 99 | 100 | # ------------------------------- 101 | # Ruff Linting Rules 102 | # ------------------------------- 103 | [tool.ruff.lint] 104 | # Enable specific linting rules 105 | select = ["E4", "E7", "E9", "F", "B", "Q"] 106 | ignore = [] 107 | 108 | # Define fixable and unfixable rules 109 | fixable = ["ALL"] 110 | unfixable = [] 111 | 112 | # ------------------------------- 113 | # Ruff Flake8-Quotes Settings 114 | # ------------------------------- 115 | [tool.ruff.lint.flake8-quotes] 116 | docstring-quotes = "double" # Enforce double quotes for docstrings 117 | inline-quotes = "double" # Enforce double quotes for inline strings 118 | 119 | # ------------------------------- 120 | # Ruff 
Formatting Settings 121 | # ------------------------------- 122 | [tool.ruff.format] 123 | quote-style = "double" # Use double quotes for auto-formatting 124 | indent-style = "space" # Indent with spaces instead of tabs 125 | skip-magic-trailing-comma = false # Respect magic trailing commas 126 | line-ending = "auto" # Automatically detect line endings 127 | 128 | [build-system] 129 | requires = ["setuptools", "wheel"] 130 | build-backend = "setuptools.build_meta" 131 | -------------------------------------------------------------------------------- /training/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "FlashSTU", 3 | "n_embd": 1536, 4 | "n_heads": 8, 5 | "n_layers": 26, 6 | "seq_len": 8192, 7 | "window_size": 1024, 8 | "vocab_size": 200064, 9 | "mlp_scale": 12, 10 | "bias": false, 11 | "dropout": 0.0, 12 | "num_eigh": 24, 13 | "use_hankel_L": false, 14 | "num_epochs": 1, 15 | "global_bsz": 524288, 16 | "bsz": 1, 17 | "warmup_steps": 19073, 18 | "eval_period": 25, 19 | "save_period": 19000, 20 | "max_lr": 3.0e-4, 21 | "min_lr": 3.0e-5, 22 | "max_norm": 1.0, 23 | "dilation": 1, 24 | "fsdp": true, 25 | "ddp": false, 26 | "mixed_precision": true, 27 | "torch_dtype": "bfloat16", 28 | "use_cpu_offload": false, 29 | "sharding_strategy": "full_shard", 30 | "state_dict_type": "full", 31 | "auto_wrap_policy": "partial", 32 | "backward_prefetch": "backward_pre", 33 | "forward_prefetch": false, 34 | "sync_module_states": true, 35 | "use_orig_params": true, 36 | "device_id": null, 37 | "precision": { 38 | "param": "bfloat16", 39 | "reduce": "bfloat16", 40 | "buffer": "bfloat16" 41 | }, 42 | "fsdp_modules": [ 43 | "STU", 44 | "Attention", 45 | "MLP" 46 | ], 47 | "use_activation_checkpointing": true, 48 | "use_flash_fft": true, 49 | "use_approx": true, 50 | "use_attn": true, 51 | "softcap": 50.0, 52 | "torch_compile": false 53 | } 54 | -------------------------------------------------------------------------------- /training/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import multiprocessing as mp 3 | 4 | import torch 5 | import tiktoken 6 | from datasets import load_dataset 7 | from tqdm import tqdm 8 | 9 | 10 | """Adapted from https://github.com/karpathy/build-nanogpt/blob/master/fineweb.py""" 11 | 12 | # Configuration 13 | local_dir = "data/fineweb-edu-10B" 14 | remote_name = "sample-10BT" 15 | shard_size = int(1e8) # 100M tokens per shard 16 | 17 | # Create the cache directory if it doesn't exist 18 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir) 19 | os.makedirs(DATA_CACHE_DIR, exist_ok=True) 20 | 21 | # Download the dataset 22 | fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train") 23 | 24 | # Initialize the tokenizer 25 | enc = tiktoken.get_encoding("o200k_base") 26 | eot = enc._special_tokens['<|endoftext|>'] 27 | 28 | def tokenize(doc): 29 | tokens = [eot] 30 | tokens.extend(enc.encode_ordinary(doc["text"])) 31 | tokens = torch.tensor(tokens, dtype=torch.int32) 32 | return tokens 33 | 34 | def write_datafile(file, tokens): 35 | torch.save(tokens, file + '.pt') 36 | 37 | # Tokenize all documents and write output shards 38 | nprocs = 8 39 | with mp.Pool(nprocs) as pool: 40 | shard_idx = 0 41 | all_tokens_tensor = torch.empty((shard_size,), dtype=torch.int32) 42 | token_count = 0 43 | progress_bar = None 44 | 45 | for tokens in pool.imap(tokenize, fw, chunksize=16): 46 | # If there is enough space in current shard 
for new tokens, 47 | if token_count + len(tokens) < shard_size: 48 | all_tokens_tensor[token_count:token_count+len(tokens)] = tokens 49 | token_count += len(tokens) 50 | if progress_bar is None: 51 | progress_bar = tqdm(total=shard_size, unit=" toks", desc=f"Shard {shard_idx}") 52 | progress_bar.update(len(tokens)) 53 | else: 54 | # Else, write the current shard and start a new one 55 | split = "val" if shard_idx == 0 else "train" 56 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb-edu-10B_{split}_{shard_idx:06d}") 57 | 58 | # Split the document into whatever fits in this shard, remainder goes to next one 59 | remainder = shard_size - token_count 60 | progress_bar.update(remainder) 61 | all_tokens_tensor[token_count:token_count+remainder] = tokens[:remainder] 62 | write_datafile(filename, all_tokens_tensor) 63 | shard_idx += 1 64 | progress_bar = None 65 | 66 | # Populate the next shard with the leftovers of the current doc 67 | all_tokens_tensor[0:len(tokens)-remainder] = tokens[remainder:] 68 | token_count = len(tokens)-remainder 69 | 70 | # Write any remaining tokens as the last shard 71 | if token_count != 0: 72 | split = "val" if shard_idx == 0 else "train" 73 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb-edu-10B_{split}_{shard_idx:06d}") 74 | write_datafile(filename, all_tokens_tensor[:token_count]) 75 | -------------------------------------------------------------------------------- /training/dataloader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | 7 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def load_tokens(filename): 12 | try: 13 | ext = os.path.splitext(filename)[1] 14 | 15 | if ext == ".npy": 16 | npt = np.load(filename) 17 | npt = npt.astype(np.int32) 18 | ptt = torch.tensor(npt, dtype=torch.long) 19 | return ptt 20 | elif ext == ".pt": 21 | return torch.load(filename, weights_only=True) 22 | else: 23 | raise ValueError(f"Unsupported file extension: {ext}") 24 | 25 | except Exception as e: 26 | logger.error(f"Error loading file {filename}: {str(e)}") 27 | raise 28 | 29 | class DistributedDataloader: 30 | def __init__( 31 | self, 32 | bsz: int, 33 | seq_len: int, 34 | rank: int, 35 | world_size: int, 36 | dataset: str, 37 | split: str, 38 | main_process: bool = False, 39 | ): 40 | self.bsz = bsz 41 | self.seq_len = seq_len 42 | self.rank = rank 43 | self.world_size = world_size 44 | assert split in {'train', 'val', 'test'}, f"Invalid split: {split}" 45 | 46 | data_root = dataset 47 | shards = [s for s in os.listdir(data_root) if split in s and (s.endswith('.pt') or s.endswith('.npy'))] 48 | self.shards = [os.path.join(data_root, s) for s in sorted(shards)] 49 | assert len(self.shards) > 0, f'No shards found for split {split}' 50 | if main_process: 51 | logger.info(f'Found {len(self.shards)} shards for split {split}') 52 | 53 | self.reset() 54 | 55 | def reset(self): 56 | self.current_shard = 0 57 | self.tokens = load_tokens(self.shards[self.current_shard]) 58 | self.current_position = (self.bsz * self.seq_len * self.rank) % len(self.tokens) 59 | 60 | def set_epoch(self, epoch): 61 | self.generator = torch.Generator() 62 | self.generator.manual_seed(epoch) 63 | self.shard_order = torch.randperm(len(self.shards), generator=self.generator).tolist() 64 | self.current_shard = self.shard_order[self.rank % len(self.shards)] 65 | self.tokens = 
load_tokens(self.shards[self.current_shard]) 66 | self.current_position = (self.bsz * self.seq_len * self.rank) % len(self.tokens) 67 | 68 | def __iter__(self): 69 | return self 70 | 71 | def __next__(self): 72 | if self.current_position + (self.bsz * self.seq_len + 1) > len(self.tokens): 73 | self.current_shard = (self.current_shard + 1) % len(self.shards) 74 | self.tokens = load_tokens(self.shards[self.current_shard]) 75 | self.current_position = (self.bsz * self.seq_len * self.rank) % len(self.tokens) 76 | 77 | buf = self.tokens[self.current_position : self.current_position + self.bsz * self.seq_len + 1] 78 | x = buf[:-1].view(self.bsz, self.seq_len) 79 | y = buf[1:].view(self.bsz, self.seq_len) 80 | 81 | self.current_position += self.bsz * self.seq_len * self.world_size 82 | return x, y.to(torch.long) 83 | -------------------------------------------------------------------------------- /training/distributed.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import socket 5 | import sys 6 | from packaging.version import parse as version_parse 7 | from functools import partial 8 | from safetensors.torch import save_file, load_file 9 | import numpy as np 10 | import psutil 11 | import torch 12 | import torch.cuda.nccl as nccl 13 | import torch.distributed as dist 14 | import torch.nn as nn 15 | from glob import glob 16 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 17 | CheckpointImpl, 18 | apply_activation_checkpointing, 19 | checkpoint_wrapper, 20 | ) 21 | from torch.distributed.fsdp import ( 22 | BackwardPrefetch, 23 | CPUOffload, 24 | FullyShardedDataParallel as FSDP, 25 | MixedPrecision, 26 | ShardingStrategy, 27 | ) 28 | from torch.distributed.fsdp.wrap import _module_wrap_policy, size_based_auto_wrap_policy 29 | 30 | from flash_stu import STU 31 | from flash_stu.modules.attention import Attention 32 | from flash_stu.modules.swiglu import MLP 33 | 34 | 35 | logging.basicConfig( 36 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 37 | ) 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | def set_seeds(seed: int, cuda_deterministic: bool = False) -> None: 42 | random.seed(seed) 43 | np.random.seed(seed) 44 | torch.manual_seed(seed) 45 | if torch.cuda.is_available(): 46 | torch.cuda.manual_seed(seed) 47 | torch.cuda.manual_seed_all(seed) 48 | 49 | if cuda_deterministic: 50 | torch.backends.cudnn.deterministic = True 51 | torch.backends.cudnn.benchmark = False 52 | 53 | logger.info(f"Random seeds set to {seed}") 54 | 55 | 56 | def setup_distributed(seed: int = 1337) -> tuple[torch.device, int, int, int, bool]: 57 | if not dist.is_available(): 58 | raise RuntimeError("Distributed package not available!") 59 | 60 | if not torch.cuda.is_available(): 61 | raise RuntimeError("CUDA is required for distributed training!") 62 | 63 | dist.init_process_group(backend="nccl") 64 | local_rank = int(os.environ["LOCAL_RANK"]) 65 | rank = dist.get_rank() 66 | world_size = dist.get_world_size() 67 | main_process = rank == 0 68 | 69 | # Set up devices 70 | device = torch.device(f"cuda:{local_rank}") 71 | torch.cuda.set_device(device) 72 | set_seeds(seed + rank) 73 | 74 | if main_process: 75 | logger.info(f"Main process initialized on {socket.gethostname()}") 76 | logger.info(f"PyTorch version: {torch.__version__}") 77 | logger.info(f"Python version: {sys.version}") 78 | logger.info(f"CUDA version: {torch.version.cuda}") 79 | logger.info(f"cuDNN version: 
{torch.backends.cudnn.version()}") 80 | logger.info( 81 | f"World info: size={world_size}, rank={rank}, local_rank={local_rank}" 82 | ) 83 | log_system_info(world_size, rank) 84 | 85 | return device, local_rank, rank, world_size, main_process 86 | 87 | 88 | def log_system_info(world_size: int, rank: int): 89 | logger.info(f"System info for rank {rank}:") 90 | logger.info(f"CPU count: {psutil.cpu_count()}") 91 | logger.info(f"Total RAM: {psutil.virtual_memory().total / (1024**3):.2f} GB") 92 | logger.info( 93 | f"Available RAM: {psutil.virtual_memory().available / (1024**3):.2f} GB" 94 | ) 95 | 96 | # Log local GPU count and world size for clarity 97 | local_gpu_count = torch.cuda.device_count() 98 | logger.info(f"Local GPU count (rank {rank}): {local_gpu_count}") 99 | logger.info(f"Total GPU count across all nodes: {world_size * local_gpu_count}") 100 | 101 | # Log specific GPU properties for this node (local GPUs) 102 | for i in range(local_gpu_count): 103 | logger.info(f"GPU {i} (rank {rank}) name: {torch.cuda.get_device_name(i)}") 104 | logger.info( 105 | f"GPU {i} (rank {rank}) memory: {torch.cuda.get_device_properties(i).total_memory / (1024**3):.2f} GB" 106 | ) 107 | 108 | 109 | def find_checkpoint(log_dir: str) -> str: 110 | model_pattern = os.path.join(log_dir, "model_*.safetensors") 111 | misc_pattern = os.path.join(log_dir, "other_checkpoints_*.pt") 112 | model_checkpoints = glob(model_pattern) 113 | misc_checkpoints = glob(misc_pattern) 114 | if not model_checkpoints or not misc_checkpoints: 115 | return None 116 | latest_checkpoint = max( 117 | model_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0]) 118 | ) 119 | misc_checkpoint = max( 120 | misc_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0]) 121 | ) 122 | return latest_checkpoint, misc_checkpoint 123 | 124 | 125 | def load_checkpoint(model_path: str, misc_path: str, model, optimizer, device): 126 | model_checkpoint = load_file(model_path) 127 | model.load_state_dict(model_checkpoint) 128 | model.to(device) 129 | 130 | misc_checkpoint = torch.load(misc_path, map_location=device, weights_only=True) 131 | model.config = misc_checkpoint["config"] 132 | optimizer.load_state_dict(misc_checkpoint["optimizer"]) 133 | 134 | step = misc_checkpoint["step"] 135 | val_loss = misc_checkpoint["val_loss"] 136 | 137 | return model, optimizer, step, val_loss 138 | 139 | 140 | def save_checkpoint( 141 | model_checkpoint, optim_checkpoint, config, step, best_val_loss, log_dir 142 | ): 143 | model_checkpoint_path = os.path.join(log_dir, f"model_{step:05d}.safetensors") 144 | other_checkpoint_path = os.path.join(log_dir, f"other_checkpoints_{step:05d}.pt") 145 | 146 | save_file(model_checkpoint, model_checkpoint_path) 147 | 148 | other_checkpoint = { 149 | "config": config, 150 | "optimizer": optim_checkpoint, 151 | "step": step, 152 | "val_loss": best_val_loss, 153 | } 154 | torch.save(other_checkpoint, other_checkpoint_path) 155 | 156 | logging.info( 157 | f"Validation loss improved at step {step}! Save the model to {model_checkpoint_path}, misc data to {other_checkpoint_path}." 
158 |     )
159 | 
160 | 
161 | def setup_fsdp(
162 |     model: nn.Module,
163 |     mixed_precision: bool = True,
164 |     use_cpu_offload: bool = False,
165 |     sharding_strategy: str = "full_shard",
166 |     auto_wrap_policy: str = "partial",
167 |     backward_prefetch: str = "backward_pre",
168 |     forward_prefetch: bool = False,
169 |     sync_module_states: bool = True,
170 |     use_orig_params: bool = True,
171 |     device_id: int = None,
172 |     precision: dict = None,
173 |     fsdp_modules: list = None,
174 |     use_activation_checkpointing: bool = True,
175 | ) -> FSDP:
176 |     if not torch.cuda.is_available() or not dist.is_nccl_available():
177 |         raise RuntimeError("CUDA and NCCL must be available for FSDP setup")
178 | 
179 |     fsdp_params = {}
180 | 
181 |     bf16_ready = (
182 |         torch.version.cuda
183 |         and torch.cuda.is_bf16_supported()
184 |         and version_parse(torch.version.cuda) >= version_parse("11.0")
185 |         and dist.is_nccl_available()
186 |         and nccl.version() >= (2, 10)
187 |     )
188 | 
189 |     # Set up mixed precision
190 |     if mixed_precision and precision:
191 |         param_dtype = precision.get(
192 |             "param", torch.bfloat16 if bf16_ready else torch.float32
193 |         )
194 |         reduce_dtype = precision.get(
195 |             "reduce", torch.bfloat16 if bf16_ready else torch.float32
196 |         )
197 |         buffer_dtype = precision.get(
198 |             "buffer", torch.bfloat16 if bf16_ready else torch.float32
199 |         )
200 | 
201 |         if isinstance(param_dtype, str):
202 |             param_dtype = getattr(torch, param_dtype)
203 |         if isinstance(reduce_dtype, str):
204 |             reduce_dtype = getattr(torch, reduce_dtype)
205 |         if isinstance(buffer_dtype, str):
206 |             buffer_dtype = getattr(torch, buffer_dtype)
207 | 
208 |         fsdp_params["mixed_precision"] = MixedPrecision(
209 |             param_dtype=param_dtype,
210 |             reduce_dtype=reduce_dtype,
211 |             buffer_dtype=buffer_dtype,
212 |         )
213 | 
214 |     # Set up CPU offloading (FSDP expects this under the `cpu_offload` keyword)
215 |     if use_cpu_offload:
216 |         fsdp_params["cpu_offload"] = CPUOffload(offload_params=True)
217 | 
218 |     # Set up sharding strategy
219 |     if sharding_strategy == "full_shard":
220 |         fsdp_params["sharding_strategy"] = ShardingStrategy.FULL_SHARD
221 |     elif sharding_strategy == "shard_grad_op":
222 |         fsdp_params["sharding_strategy"] = ShardingStrategy.SHARD_GRAD_OP
223 |     elif sharding_strategy == "no_shard":
224 |         fsdp_params["sharding_strategy"] = ShardingStrategy.NO_SHARD
225 |     else:
226 |         raise ValueError(f"Invalid sharding strategy: {sharding_strategy}")
227 | 
228 |     # Set up backward prefetch
229 |     if backward_prefetch == "backward_pre":
230 |         fsdp_params["backward_prefetch"] = BackwardPrefetch.BACKWARD_PRE
231 |     elif backward_prefetch == "backward_post":
232 |         fsdp_params["backward_prefetch"] = BackwardPrefetch.BACKWARD_POST
233 |     elif backward_prefetch is not None:
234 |         raise ValueError(f"Invalid backward prefetch option: {backward_prefetch}")
235 | 
236 |     # Set up other parameters
237 |     fsdp_params["forward_prefetch"] = forward_prefetch
238 |     fsdp_params["sync_module_states"] = sync_module_states
239 |     fsdp_params["use_orig_params"] = use_orig_params
240 | 
241 |     if device_id is None:
242 |         device_id = torch.cuda.current_device()
243 |     fsdp_params["device_id"] = device_id
244 | 
245 |     # Set up auto wrap policy (module names from the config resolve to the imported STU, Attention, and MLP classes)
246 |     fsdp_modules_set = set(
247 |         eval(module) if isinstance(module, str) else module for module in fsdp_modules
248 |     )
249 |     if auto_wrap_policy == "partial":
250 |         fsdp_params["auto_wrap_policy"] = partial(
251 |             _module_wrap_policy, module_classes=fsdp_modules_set
252 |         )
253 |     elif auto_wrap_policy == "size_based":
254 |         fsdp_params["auto_wrap_policy"] = 
size_based_auto_wrap_policy 255 | else: 256 | raise ValueError(f"Invalid auto wrap policy: {auto_wrap_policy}") 257 | 258 | # Apply activation checkpointing 259 | if use_activation_checkpointing: 260 | check_fn = lambda submodule: isinstance(submodule, tuple(fsdp_modules_set)) 261 | apply_activation_checkpointing( 262 | model, 263 | checkpoint_wrapper_fn=partial( 264 | checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT 265 | ), 266 | check_fn=check_fn, 267 | ) 268 | 269 | # Wrap the model with FSDP 270 | fsdp_model = FSDP(model, **fsdp_params) 271 | 272 | return fsdp_model 273 | 274 | 275 | def cleanup_distributed(rank: int): 276 | if dist.is_initialized(): 277 | logging.info(f"[Rank {rank}]: Finished training.") 278 | logging.info(f"[Rank {rank}]: Waiting for other processes to finish...") 279 | dist.barrier() 280 | dist.destroy_process_group() 281 | -------------------------------------------------------------------------------- /training/example.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import time 5 | from contextlib import nullcontext 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torch.amp import autocast, GradScaler 10 | from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict 11 | from torch.distributed.elastic.multiprocessing.errors import record 12 | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.optim import AdamW 15 | 16 | from dataloader import DistributedDataloader 17 | from distributed import ( 18 | cleanup_distributed, 19 | find_checkpoint, 20 | load_checkpoint, 21 | save_checkpoint, 22 | setup_distributed, 23 | setup_fsdp 24 | ) 25 | from flash_stu.config import FlashSTUConfig 26 | from flash_stu.model import FlashSTU 27 | from flash_stu.utils.stu_utils import get_spectral_filters 28 | from flash_stu.utils.training import GarbageCollection, linear_decay_with_warmup 29 | 30 | try: 31 | from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss as CrossEntropyLoss 32 | except ImportError as e: 33 | print(f"Unable to import Triton-based cross entropy loss: {e}. 
Falling back to PyTorch implementation.") 34 | from torch.nn import CrossEntropyLoss 35 | 36 | 37 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") 38 | logger = logging.getLogger(__name__) 39 | 40 | @record 41 | def main(): 42 | torch.set_float32_matmul_precision("high") 43 | device, local_rank, rank, world_size, main_process = setup_distributed(seed=1337) 44 | gc_handler = GarbageCollection() 45 | 46 | with open("config.json", "r") as file: 47 | config = json.load(file) 48 | 49 | # Model configurations 50 | n_embd = config['n_embd'] 51 | n_heads = config['n_heads'] 52 | n_layers = config['n_layers'] 53 | seq_len = config['seq_len'] 54 | window_size = config['window_size'] 55 | vocab_size = config['vocab_size'] 56 | mlp_scale = config['mlp_scale'] 57 | bias = config['bias'] 58 | dropout = config['dropout'] 59 | num_eigh = config['num_eigh'] 60 | use_hankel_L = config['use_hankel_L'] 61 | 62 | # Optimizations 63 | use_flash_fft = config['use_flash_fft'] 64 | use_approx = config['use_approx'] 65 | use_attn = config['use_attn'] 66 | softcap = config['softcap'] 67 | torch_compile = config['torch_compile'] 68 | 69 | # Training configurations 70 | dilation = config['dilation'] 71 | warmup_steps = config['warmup_steps'] // dilation 72 | eval_period = config['eval_period'] // dilation 73 | save_period = config['save_period'] // dilation 74 | num_epochs = config['num_epochs'] 75 | max_lr = config['max_lr'] 76 | min_lr = config['min_lr'] 77 | max_norm = config['max_norm'] 78 | 79 | global_bsz = config['global_bsz'] 80 | bsz = config['bsz'] 81 | assert ( 82 | global_bsz % (bsz * seq_len * world_size) == 0 83 | ), f"global_bsz ({global_bsz}) must be divisible by bsz * seq_len * world_size ({bsz * seq_len * world_size})," 84 | f" got {global_bsz % (bsz * seq_len * world_size)}" 85 | gradient_accumulation_steps = global_bsz // (bsz * seq_len * world_size) 86 | 87 | # Distributed 88 | fsdp = config['fsdp'] 89 | ddp = config['ddp'] 90 | assert not (fsdp and ddp), "FSDP and DDP are both enabled which is not allowed" 91 | 92 | distributed = (fsdp or ddp) and world_size > 1 93 | cache_enabled = not ddp 94 | 95 | mixed_precision = config['mixed_precision'] 96 | torch_dtype = getattr(torch, config['torch_dtype']) 97 | use_cpu_offload = config['use_cpu_offload'] 98 | sharding_strategy = config['sharding_strategy'] 99 | auto_wrap_policy = config['auto_wrap_policy'] 100 | backward_prefetch = config['backward_prefetch'] 101 | forward_prefetch = config['forward_prefetch'] 102 | sync_module_states = config['sync_module_states'] 103 | use_orig_params = config['use_orig_params'] 104 | device_id = config['device_id'] 105 | precision = config['precision'] 106 | fsdp_modules = config['fsdp_modules'] 107 | use_activation_checkpointing = config['use_activation_checkpointing'] 108 | 109 | if main_process: 110 | logging.info(f"Training config: {config}\n") 111 | 112 | if world_size == 1 and fsdp: 113 | if main_process: 114 | logging.info("World size is 1, disabling sharding.") 115 | sharding_strategy = "no_shard" 116 | 117 | config = FlashSTUConfig( 118 | bsz=bsz, 119 | n_embd=n_embd, 120 | n_heads=n_heads, 121 | n_layers=n_layers, 122 | seq_len=seq_len, 123 | window_size=window_size, 124 | vocab_size=vocab_size, 125 | mlp_scale=mlp_scale, 126 | bias=bias, 127 | dropout=dropout, 128 | num_eigh=num_eigh, 129 | use_hankel_L=use_hankel_L, 130 | use_flash_fft=use_flash_fft, 131 | use_approx=use_approx, 132 | use_attn=use_attn, 133 | softcap=softcap, 134 | torch_dtype=torch_dtype, 135 
| ) 136 | fsdp_params = { 137 | "mixed_precision": mixed_precision, 138 | "use_cpu_offload": use_cpu_offload, 139 | "sharding_strategy": sharding_strategy, 140 | "auto_wrap_policy": auto_wrap_policy, 141 | "backward_prefetch": backward_prefetch, 142 | "forward_prefetch": forward_prefetch, 143 | "sync_module_states": sync_module_states, 144 | "use_orig_params": use_orig_params, 145 | "device_id": device_id, 146 | "precision": precision, 147 | "fsdp_modules": fsdp_modules, 148 | "use_activation_checkpointing": use_activation_checkpointing, 149 | } 150 | 151 | phi = get_spectral_filters(seq_len, num_eigh, use_hankel_L, device, torch_dtype) 152 | model = FlashSTU(config, phi) 153 | if torch_compile: 154 | model = torch.compile(model) 155 | if main_process: 156 | logging.info(f"PyTorch Compiler Enabled?: {torch_compile}") 157 | model = model.to(device) 158 | 159 | if fsdp: 160 | model = setup_fsdp(model, **fsdp_params) 161 | if ddp: 162 | model = DDP(model, device_ids=[local_rank], gradient_as_bucket_view=True) 163 | state_dict_options = StateDictOptions( 164 | full_state_dict=True, 165 | cpu_offload=True, 166 | ) 167 | torch.serialization.add_safe_globals([FlashSTUConfig]) 168 | 169 | use_scaler = torch_dtype == torch.float16 170 | if use_scaler: 171 | scaler = ShardedGradScaler() if fsdp else GradScaler() 172 | optimizer = AdamW( 173 | model.parameters(), 174 | lr=max_lr, 175 | fused=torch.cuda.is_available(), 176 | ) 177 | loss_fn = CrossEntropyLoss() 178 | 179 | # Create the log directory to write checkpoints to and log to 180 | log_dir = "log" 181 | os.makedirs(log_dir, exist_ok=True) 182 | 183 | checkpoints = find_checkpoint(log_dir) 184 | if checkpoints: 185 | model_checkpoint, misc_checkpoint = checkpoints 186 | model, optimizer, start_step, best_val_loss = load_checkpoint( 187 | model_checkpoint, misc_checkpoint, model, optimizer, device 188 | ) 189 | log_mode = "a" # Append to the log if resuming 190 | if main_process: 191 | logging.info(f"Resuming from checkpoint: {model_checkpoint}") 192 | logging.info(f"Starting from step: {start_step}") 193 | logging.info(f"Best validation loss: {best_val_loss}") 194 | else: 195 | start_step = 0 196 | best_val_loss = float('inf') 197 | log_mode = "w" # Create a new log if starting fresh 198 | 199 | if main_process: 200 | log_file = os.path.join(log_dir, "log.txt") 201 | with open(log_file, log_mode) as f: 202 | pass 203 | 204 | # Data loader section 205 | # 19,073 steps is ~1 epoch, if data is 10B tokens and batch size ~0.5M tokens 206 | # This is dataset and batch size dependent. 
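# Worked example with the values in config.json: 10,000,000,000 training tokens
# / 524,288 tokens per global batch ≈ 19,073 optimizer steps per epoch, which
# matches the warmup_steps entry in config.json.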
207 | dataset = "data/fineweb-edu-10B" 208 | total_tokens = 10_000_000_000 209 | num_steps = total_tokens // global_bsz # Number of steps for one epoch 210 | max_steps = num_steps * num_epochs 211 | 212 | if main_process: 213 | logging.info(f"Total (desired) batch size: {global_bsz}") 214 | logging.info( 215 | f"=> Calculated gradient accumulation steps: {gradient_accumulation_steps}" 216 | ) 217 | logging.info(f"Training on {max_steps} steps") 218 | 219 | train_loader = DistributedDataloader( 220 | bsz=bsz, 221 | seq_len=seq_len, 222 | rank=rank, 223 | world_size=world_size, 224 | dataset=dataset, 225 | split="train", 226 | main_process=main_process, 227 | ) 228 | val_loader = DistributedDataloader( 229 | bsz=bsz, 230 | seq_len=seq_len, 231 | rank=rank, 232 | world_size=world_size, 233 | dataset=dataset, 234 | split="val", 235 | main_process=main_process, 236 | ) 237 | 238 | for step in range(start_step + 1, max_steps + 1): 239 | gc_handler.run(step) 240 | torch.cuda.empty_cache() 241 | epoch = step // num_steps 242 | last_step = step % num_steps == 0 243 | 244 | if step == 1 or step % num_steps == 1: 245 | if main_process: 246 | logging.info(f"Starting epoch {epoch}") 247 | train_loader.set_epoch(epoch) 248 | 249 | t0 = time.perf_counter() 250 | 251 | if step == 1 or step % (eval_period // dilation) == 0 or last_step: 252 | val_loss = 0.0 253 | val_steps = 20 # Arbitrarily set to reduce long evaluations 254 | model.eval() 255 | val_loader.reset() 256 | with torch.no_grad(): 257 | for i, batch in zip(range(val_steps), val_loader, strict=False): 258 | inputs, targets = batch 259 | inputs, targets = inputs.to(device), targets.to(device) 260 | if torch_dtype != torch.float32: 261 | with autocast(device_type=device.type, dtype=torch_dtype, cache_enabled=cache_enabled): 262 | preds = model(inputs) 263 | else: 264 | preds = model(inputs) 265 | 266 | loss = loss_fn(preds.flatten(0, 1), targets.flatten(0, 1)) 267 | loss = loss / val_steps 268 | val_loss += loss.detach().float() 269 | 270 | if distributed: 271 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) 272 | 273 | if main_process: 274 | logging.info(f"Validation loss: {val_loss.item():.4f}") 275 | with open(log_file, "a") as f: 276 | f.write(f"{step} val {val_loss.item():.4f}\n") 277 | 278 | if step > 0 and (step % save_period == 0 or last_step): 279 | if val_loss.item() < best_val_loss: 280 | best_val_loss = val_loss.item() 281 | dist.barrier() 282 | model_checkpoint, optim_checkpoint = get_state_dict( 283 | model, optimizer, options=state_dict_options 284 | ) 285 | if main_process: 286 | save_checkpoint(model_checkpoint, optim_checkpoint, config, step, best_val_loss, log_dir) 287 | 288 | model.train() 289 | train_loss = 0.0 290 | for micro_step, batch in zip(range(gradient_accumulation_steps), train_loader, strict=False): 291 | inputs, targets = batch 292 | inputs, targets = inputs.to(device), targets.to(device) 293 | 294 | last_micro_step = micro_step == gradient_accumulation_steps - 1 295 | 296 | if (fsdp or ddp) and not last_micro_step: 297 | context = model.no_sync() 298 | else: 299 | context = nullcontext() 300 | 301 | with context: 302 | if torch_dtype != torch.float32: 303 | with autocast(device_type=device.type, dtype=torch_dtype, cache_enabled=cache_enabled): 304 | preds = model(inputs) 305 | else: 306 | preds = model(inputs) 307 | 308 | loss = loss_fn(preds.flatten(0, 1), targets.flatten(0, 1)) 309 | loss = loss / gradient_accumulation_steps 310 | train_loss += loss.detach().float() 311 | scaler.scale(loss).backward() if 
use_scaler else loss.backward()
312 | 
313 |         if distributed:
314 |             dist.all_reduce(train_loss, op=dist.ReduceOp.AVG)
315 | 
316 |         if use_scaler:
317 |             scaler.unscale_(optimizer)
318 | 
319 |         if fsdp:
320 |             norm = model.clip_grad_norm_(max_norm)
321 |         else:
322 |             norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
323 | 
324 |         lr = linear_decay_with_warmup(step, warmup_steps, max_steps, max_lr, min_lr)
325 |         for param_group in optimizer.param_groups:
326 |             param_group['lr'] = lr
327 | 
328 |         if use_scaler:
329 |             scaler.step(optimizer)
330 |             scaler.update()
331 |         else:
332 |             optimizer.step()
333 |         optimizer.zero_grad()
334 | 
335 |         torch.cuda.synchronize()
336 |         t1 = time.perf_counter()
337 |         dt = t1 - t0
338 |         toks_processed = (
339 |             train_loader.bsz
340 |             * train_loader.seq_len
341 |             * gradient_accumulation_steps
342 |             * world_size
343 |         )
344 |         toks_per_sec = toks_processed / dt
345 | 
346 |         if main_process:
347 |             log_message = (
348 |                 f"step {step:5d} | "
349 |                 f"loss: {train_loss:.6f} | "
350 |                 f"lr {lr:.4e} | "
351 |                 f"norm: {norm:.4f} | "
352 |                 f"dt: {dt*1000:.2f}ms | "
353 |                 f"tok/s: {toks_per_sec:.2f}"
354 |             )
355 |             print(log_message)
356 | 
357 |             with open(log_file, "a") as f:
358 |                 f.write(f"{step} train {train_loss:.6f} lr {lr:.4e} norm {norm:.4f} dt {dt*1000:.2f} tok/s {toks_per_sec:.2f}\n")
359 | 
360 |     cleanup_distributed(rank)
361 | 
362 | if __name__ == "__main__":
363 |     main()
364 | 
--------------------------------------------------------------------------------
/training/job.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=stu               # Name of the job
3 | #SBATCH --nodes=1                    # Number of nodes
4 | #SBATCH --ntasks-per-node=1          # Each node runs 1 task that manages all GPUs
5 | #SBATCH --gpus-per-task=8            # Number of GPUs to allocate per task
6 | #SBATCH --cpus-per-task=8            # Should be >= the number of GPUs per task
7 | #SBATCH --mem=48G                    # Total memory for the job
8 | #SBATCH --time=15:59:00              # Max time limit
9 | 
10 | #SBATCH --error=stu_%j.err
11 | #SBATCH --output=stu_%j.out
12 | 
13 | # Logging
14 | log_info() {
15 |     echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
16 | }
17 | 
18 | # Error handling
19 | set -e
20 | trap 'log_info "Error on line $LINENO"; exit 1' ERR
21 | 
22 | # Activate your virtual environment accordingly
23 | source .venv/bin/activate
24 | 
25 | # Get the first node (master node) from the SLURM_JOB_NODELIST
26 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
27 | 
28 | # Get the IP address of the master node
29 | MASTER_NODE_ADDR=$(srun --nodes=1 --ntasks=1 -w "$MASTER_NODE" hostname --ip-address)
30 | 
31 | # Find an available port
32 | RDZV_PORT=$(python3 -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
33 | 
34 | # Log start of training run
35 | log_info "Starting training run..."
36 | 
37 | # Run the script using torchrun
38 | torchrun \
39 |     --nnodes 1 \
40 |     --nproc_per_node 8 \
41 |     --rdzv_id $SLURM_JOB_ID \
42 |     --rdzv_backend c10d \
43 |     --rdzv_endpoint $MASTER_NODE_ADDR:$RDZV_PORT \
44 |     --max-restarts 16 \
45 |     example.py
46 | 
47 | # Log end of training run
48 | log_info "Job finished."
49 | 
50 | ##################### HELPFUL SLURM COMMANDS #####################
51 | 
52 | # Disk usage and quota info; request additional space via link.
53 | # `checkquota`
54 | 
55 | # Operating system details.
56 | # `cat /etc/os-release`
57 | 
58 | # CPU specifications on the current node.
59 | # `lscpu`
60 | 
61 | # Compute node information (standard and easy-to-read formats).
62 | # `snodes`
63 | # `shownodes`
64 | 
65 | # Cluster nodes usage overview; check status (idle, down, busy).
66 | # `sinfo`
67 | # GPU usage specifics.
68 | # `sinfo -p gpu`
69 | 
70 | # Quality of Service insights: job partitions and limits.
71 | # `qos`
72 | 
73 | # Current processor activity; exit with 'q'.
74 | # `top`
75 | # `htop`
76 | 
77 | # Overview of group-assigned cluster shares.
78 | # `sshare`
79 | 
80 | # Job priority mechanics: factors and weights.
81 | # `sprio -w`
82 | 
83 | # Performance of a completed job by job ID.
84 | # `seff <job_id>`
85 | 
86 | # Your historical job records.
87 | # `shistory`
88 | 
89 | # Detailed job statistics (memory, CPU, GPU).
90 | # `jobstats <job_id>`
91 | 
92 | # Additional commands for GPU details:
93 | 
94 | # Details about GPUs on the cluster.
95 | # `snodes`
96 | 
97 | # Number of available GPUs.
98 | # `shownodes -p gpu,mig`
99 | 
100 | # GPU utilization, refreshed every 10 min.
101 | # `gpudash`
102 | # Specific to your jobs.
103 | # `gpudash -u $USER`
104 | 
105 | # Real-time GPU status on active jobs.
106 | # `nvidia-smi` OR `watch nvidia-smi`
107 | 
108 | # Your queue status.
109 | # `squeue -u $USER`
110 | 
--------------------------------------------------------------------------------
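Usage note: save_checkpoint in training/distributed.py writes a model_{step}.safetensors file alongside an other_checkpoints_{step}.pt file. Below is a minimal, illustrative sketch (not part of the repository) of reloading such a model checkpoint outside the training loop; it simply mirrors load_checkpoint and the model construction in training/example.py. The checkpoint path and step number are hypothetical, and the sketch assumes the state dict was gathered with full_state_dict=True as configured above, that config.json matches the checkpointed run, and that any optional kernels enabled there (flash FFT, flash attention) are installed.

import json

import torch
from safetensors.torch import load_file

from flash_stu.config import FlashSTUConfig
from flash_stu.model import FlashSTU
from flash_stu.utils.stu_utils import get_spectral_filters

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the model exactly as example.py does, using the training config.
with open("training/config.json", "r") as f:
    cfg = json.load(f)
torch_dtype = getattr(torch, cfg["torch_dtype"])

config = FlashSTUConfig(
    bsz=cfg["bsz"],
    n_embd=cfg["n_embd"],
    n_heads=cfg["n_heads"],
    n_layers=cfg["n_layers"],
    seq_len=cfg["seq_len"],
    window_size=cfg["window_size"],
    vocab_size=cfg["vocab_size"],
    mlp_scale=cfg["mlp_scale"],
    bias=cfg["bias"],
    dropout=cfg["dropout"],
    num_eigh=cfg["num_eigh"],
    use_hankel_L=cfg["use_hankel_L"],
    use_flash_fft=cfg["use_flash_fft"],
    use_approx=cfg["use_approx"],
    use_attn=cfg["use_attn"],
    softcap=cfg["softcap"],
    torch_dtype=torch_dtype,
)

# Spectral filters are recomputed the same way they are before training.
phi = get_spectral_filters(cfg["seq_len"], cfg["num_eigh"], cfg["use_hankel_L"], device, torch_dtype)
model = FlashSTU(config, phi)

# Hypothetical checkpoint path; load_file returns a flat {name: tensor} state dict,
# which should match a freshly constructed (non-FSDP) module when the checkpoint was
# saved with full_state_dict=True and use_orig_params=True.
state_dict = load_file("log/model_19000.safetensors")
model.load_state_dict(state_dict)
model = model.to(device).eval()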