├── LICENSE ├── README.md ├── benchmarks ├── communication │ ├── README.md │ ├── __init__.py │ ├── all_reduce.py │ ├── broadcast.py │ ├── constants.py │ ├── pt2pt.py │ ├── run_all.py │ └── utils.py └── computation │ ├── README.md │ ├── benchmark_flash_attention.py │ ├── benchmark_mamba.py │ ├── benchmark_mamba2.py │ └── utils.py ├── calc ├── README.md ├── calc_mamba_flops.py ├── calc_mamba_params.py └── data │ ├── convert_into_jsonl_partitions.py │ └── tokenize_and_count.py └── imgs ├── annealing-example.png ├── mamba-moe.png ├── mamba.png ├── transformer-moe.png ├── transformer.png ├── zamba-7b.png ├── zamba2-1p2b.png ├── zamba2-2p7b.png └── zcookbook.jpg /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Image Credit: FLUX 2 | 3 | # The Zyphra Cookbook 4 | By Quentin Anthony, Beren Millidge, Paolo Glorioso, and Yury Tokpanov 5 | 6 | Training hybrid models is hard, and papers tend to gloss over the practical engineering work that goes into building good ones. 
The purpose of this cookbook is to enable other technical groups to hit the ground running when building their own hybrid (SSM, Transformer, MoE) models. 7 | 8 | For context, we at Zyphra have built the following hybrid models: 9 | - [BlackMamba](https://arxiv.org/abs/2402.01771) 10 | - [Zamba-7B](https://www.zyphra.com/post/zamba) 11 | - [Zamba2-2.7B](https://www.zyphra.com/post/zamba2-small) 12 | - [Zamba2-1.2B](https://huggingface.co/Zyphra/Zamba2-1.2B) 13 | 14 | The following datasets: 15 | - [Zyda](https://www.zyphra.com/post/zyda) 16 | 17 | And the following engineering optimizations 18 | - [Tree Attention](https://www.zyphra.com/post/tree-attention-topology-aware-decoding-for-long-context-attention-on-gpu-clusters) 19 | 20 | # Introduction: How Zyphra thinks about Hybrid Models 21 | 22 | Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple: 23 | 1. MHA computes exact cross-sequence dependencies, and consists of GEMMs, which are easy to parallelize across many GPU SMs 24 | 2. MLPs mix the heads of MHA and perform per-token processing, and trivially boil down to GEMMs 25 | 26 | Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc) boil down to perform very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are: 27 | - Mixing information across the sequence (MHA, Mamba, RWKV sequence mixers) 28 | - Updating representations per token (MLPs, KANs, Mamba in/out projectors and gates) 29 | 30 | Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance. 31 | 32 | Therefore, potential LLM architectures should be evaluated on whether they: 33 | 1. Have lower FLOP and memory requirements. We believe this is most important at [inference-time](https://arxiv.org/pdf/2401.00448v1), but also helps training. 34 | 2. Maintain the benefits of exact cross-sequence modeling from MHA (can be measured by proxy via [long-context reasoning](https://arxiv.org/abs/2406.07887) and [in-context learning](https://arxiv.org/abs/2402.03170), and general language modelling evaluations) 35 | 36 | The deployment context determines which of these properties is most important, for example: 37 | 38 | 1. Massive (100B-1T+) capabilities-focused models like Grok, Claude, and ChatGPT. These models have high parameter-counts (and therefore require more training tokens to saturate) and are deployed on cloud systems with high-VRAM GPUs (and often split between GPUs). This is why the low-FLOP and high-VRAM tradeoff of MoE is attractive. 39 | 2. Smaller (1B-15B) on-device special-purpose models like Zamba and Phi. These models require the lowest memory and latency at inference-time possible, and are deployed on embedded devices with strict power and memory constraints. Therefore they benefit more from SSM and hybrid architectures. 40 | 41 | For larger models, the primary determinant of performance is [scale](https://arxiv.org/abs/2203.15556) in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. 
the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training, where a superior architecture enables models to compete with standard transformers that are trained on significantly more tokens (requiring significantly more FLOPs), since training far past Chinchilla-optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the Chinchilla-optimal point, as we observe with our Zamba1 and Zamba2 models. 42 | 43 | Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss decrease per parameter and per inference FLOP. 44 | 45 | The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, and RetNet perform well at short-context language modelling while lagging at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers. 46 | 47 | This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention, while just enough full attention is used to maintain performance. Empirically, it appears that full attention is not needed at every single sequence mixer and that substantially less attention can be used, which is what enables hybrids to work. Similar findings have also recently been popularized for transformers, with some recent models, such as those used by [CharacterAI](https://blog.character.ai/optimizing-ai-inference-at-character-ai/), claiming to alternate sliding-window attention over a small window with full attention blocks. This has the equivalent effect of using cheap local sequence mixers and occasional full attention, but is less efficient than Mamba since sliding-window attention is less efficient per FLOP than a Mamba block. The likely reason for this relates to the data distribution. Natural language is often surprisingly predictable from primarily local correlations -- see, for example, the effectiveness of pure N-gram models. Occasionally, however, long-range information retrieval or other in-context learning is required, which a small number of attention layers can handle. In our experiments, we observe that only between 1/6 and 1/4 of the sequence mixer layers need to be full attention, a phenomenon also reported [here](https://arxiv.org/abs/2403.17844). 48 | 49 | ## Reasoning behind the Zamba architecture 50 | 51 | While several other works, such as [Jamba](https://huggingface.co/ai21labs/Jamba-v0.1), have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric.
We have done this by utilizing a parameter-sharing scheme whereby a single transformer block consisting of an attention and an MLP block is re-used multiple times throughout the network. This comprises the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions. 52 | 53 | What the success of this architecture implies is that even when attention is used rarely, there is still great redundancy in the attention parameters -- namely that the vast majority of them are not needed. While sequence mixing via full MHA is regularly necessary, the attention block itself somehow does not need separate parameters at each occurrence. We conjecture that this means attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance. 54 | 55 | An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are, while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful. 56 | 57 | Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks, which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual, and we believe it is an underexplored avenue for creating extremely parameter-efficient models. 58 | 59 | # Model Architectures 60 | 61 | Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs? 62 | 63 | ### Dense Transformers 64 | 65 | Dense transformers are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings: 66 | 1. The attention operation is still not efficient at long sequence lengths, despite recent [single-GPU efforts](https://arxiv.org/abs/2205.14135) and [distributed context efforts](https://arxiv.org/abs/2408.04093). 67 | 2. Attention blocks are correlated across model depth, which is a waste of parameters and FLOPs. 68 | 69 | 70 | ### MoE Architectures 71 | 72 | Mixture of Experts (MoE) architectures introduce a router block that routes the input sequence(s) to appropriate MLP experts on a per-token basis.
While an MoE has the inference latency of only its forward-pass (active) parameters, all of its parameters need to be loaded into VRAM, which often means that inference on large models can only be performed distributed across GPU clusters. 73 | 74 | ### SSM/RNN Architectures 75 | 76 | State Space Models (SSMs) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like [Mamba](https://arxiv.org/abs/2312.00752) and [RWKV](https://arxiv.org/abs/2305.13048) leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements. 77 | 78 | However, the exact cross-sequence dependencies of attention are hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (https://arxiv.org/abs/2406.07887, https://huggingface.co/tiiuae/falcon-mamba-7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question. 79 | 80 | 81 | **Transformer** | **Mamba** | **Transformer-MoE** | **Mamba-MoE** 82 | :-------------------------:|:-------------------------:|:-------------------------:|:-------------------------: 83 | ![transformer](imgs/transformer.png) | ![mamba](imgs/mamba.png) | ![transformer-moe](imgs/transformer-moe.png) | ![mamba-moe](imgs/mamba-moe.png) 84 | 85 | 86 | 87 | ### Hybrid Architectures 88 | 89 | Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs. 90 | 91 | **Zamba-7B** | **Zamba2-2.7B** | **Zamba2-1.2B** 92 | :-------------------------:|:-------------------------:|:-------------------------: 93 | ![zamba-7b](imgs/zamba-7b.png) | ![zamba2-2p7b](imgs/zamba2-2p7b.png) | ![zamba2-1p2b](imgs/zamba2-1p2b.png) 94 | 95 | 96 | # Calculations 97 | 98 | During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the [EleutherAI cookbook](https://github.com/EleutherAI/cookbook) but specialized to SSMs and hybrid models. 99 | 100 | ## SSM and Hybrid Calculations 101 | 102 | We provide calculation scripts for the parameters and FLOPs of Mamba models in https://github.com/Zyphra/zcookbook/tree/main/calc as well as a detailed walkthrough of the calculations performed in these scripts. 103 | 104 | 105 | ## Transformer Calculations 106 | 107 | For dense and MoE transformers, we recommend using the [EleutherAI cookbook](https://github.com/EleutherAI/cookbook) by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman. 108 | 109 | ## Token Calculation 110 | 111 | We provide a script at https://github.com/Zyphra/zcookbook/tree/main/calc/data/tokenize_and_count.py that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset. 112 | 113 | 114 | # Benchmarks 115 | 116 | 117 | ## Block Benchmarks and Sizing 118 | 119 | We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 in https://github.com/Zyphra/zcookbook/tree/main/benchmarks/computation. These are useful for comparing hardware performance and for [efficiently sizing models](https://arxiv.org/abs/2401.14489).
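As a concrete example of the sizing arithmetic these benchmarks support, the sketch below reuses the forward-pass FLOP formulas from `benchmark_flash_attention.py` and `benchmark_mamba.py` in `benchmarks/computation/` to compare how a single attention block and a single Mamba1 block scale with sequence length. The dimensions used (2048 hidden size, 16 heads of dimension 128, state size 16) are illustrative placeholders rather than a recommended configuration, and the attention formula counts only the attention score and value computation (not the QKV/output projections); use the scripts in `calc/` for full-model estimates.

```python
import math

def attn_fwd_flops(batch, seqlen, nheads, headdim, causal=True):
    # Forward-pass FLOPs of one attention block, using the same formula as
    # flops() in benchmarks/computation/benchmark_flash_attention.py
    # (attention scores + weighted values only; QKV/out projections excluded).
    return 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)

def mamba1_fwd_flops(hidden_size, state_size, seqlen, batch, conv_dim=4, expansion=2):
    # Forward-pass FLOPs of one Mamba1 block, using the same formula as
    # flops_mamba() in benchmarks/computation/benchmark_mamba.py (single layer).
    d_inner = hidden_size * expansion
    dt_rank = math.ceil(hidden_size / 16)
    ssm_flops = d_inner * seqlen * batch * (11 * state_size + 4 * dt_rank + 1)
    proj_flops = seqlen * batch * 6 * d_inner * hidden_size
    conv_flops = seqlen * batch * 2 * d_inner * conv_dim
    return ssm_flops + proj_flops + conv_flops

# Attention FLOPs grow quadratically with sequence length; Mamba's grow linearly.
for seqlen in (2048, 8192, 32768):
    attn = attn_fwd_flops(batch=1, seqlen=seqlen, nheads=16, headdim=128)
    mamba = mamba1_fwd_flops(hidden_size=2048, state_size=16, seqlen=seqlen, batch=1)
    print(f"seqlen={seqlen:6d}  attention: {attn / 1e9:9.1f} GFLOPs  mamba1: {mamba / 1e9:9.1f} GFLOPs")
```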
120 | 121 | 122 | ## Communication 123 | 124 | For communication benchmarks, there are two levels of tests: 125 | 1. Microbenchmarks in C/CUDA/C++ such as [OSU-Microbenchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL-tests](https://github.com/NVIDIA/nccl-tests). These are best for checking hardware, low-level communication software and drivers, and low-level communication optimizations (e.g. [SHARP](), communication algorithm tuning, etc). 126 | 2. Framework-level benchmarks in PyTorch/Jax such as those in the [EleutherAI cookbook](https://github.com/EleutherAI/cookbook). These are best for ensuring that framework properties (e.g. synchronization, tensor dtype handling, etc) preserve the performance of microbenchmarks, and for measuring the performance effects of framework-level optimizations (e.g. [tensor fusion/bucketing](https://pytorch.org/docs/stable/notes/ddp.html#internal-design), [CUDA graphs](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/), etc) and of communication in the context of applications (e.g. communication/computation overlap). 127 | 128 | In this cookbook, we provide framework-level benchmarks in Jax at https://github.com/Zyphra/zcookbook/tree/main/benchmarks/communication. Why Jax when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of Jax comms for our [Tree Attention](https://www.zyphra.com/post/tree-attention-topology-aware-decoding-for-long-context-attention-on-gpu-clusters) work! 129 | 130 | # Training 131 | 132 | We perform all our training using PyTorch within our custom internal fork of [MegatronLM](https://arxiv.org/abs/1909.08053). For smaller models, we only need to utilize [ZeRO-1](https://arxiv.org/abs/1910.02054) to shard optimizer states. For larger models such as Zamba-7B, we utilized [tensor-parallelism (TP)](https://arxiv.org/abs/1909.08053), for which we created our own custom implementation for both Mamba and Mamba2. We also utilized [expert-parallelism (EP)](https://arxiv.org/abs/2305.13048) for training [BlackMamba](https://arxiv.org/abs/2402.01771). 133 | 134 | ## Annealing 135 | 136 | ### What is Annealing? 137 | 138 | ![Annealing example](imgs/annealing-example.png) 139 | 140 | We find, following [miniCPM](https://arxiv.org/html/2404.06395v1), that a simple curriculum training approach of increasing the proportion of higher-quality tokens towards the end of training can significantly improve performance. 141 | 142 | 'High quality' is obviously subjective in part, but we find documents containing fact-rich information to be the most performant. Examples include: 143 | - Instruction-following data, which was particularly effective 144 | - Wikipedia and arXiv papers 145 | - Synthetic fact-enhanced textbook-style data such as [cosmopedia](https://huggingface.co/blog/cosmopedia). 146 | 147 | In terms of the amount of annealing data, we find in general that more is better, although we are generally constrained by the amount of available annealing data, so we have not been able to test truly large (>200B tokens) amounts of such data. This fits with the miniCPM finding of setting annealing to be about 10% of the total tokens of a run. We find that multiple epochs of annealing data do not appear to harm performance, yet beyond 2 epochs they give little performance improvement. 148 | 149 |
150 | <details><summary> Annealing and LR Schedule Models/Papers </summary> 151 |
152 | 153 | Model Tech Reports Using Annealing 154 | - [Danube3](https://arxiv.org/abs/2407.09276) 155 | - [miniCPM](https://arxiv.org/html/2404.06395v1) 156 | 157 | Papers on Annealing/LR 158 | - https://arxiv.org/abs/2406.03476 159 | - https://arxiv.org/abs/2403.08763 160 | - https://arxiv.org/abs/2405.16712v1 161 | 162 | </details>
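To make the token-budgeting heuristics above concrete, here is a minimal sketch (our illustration, not one of the repo's calc scripts) that sizes the annealing-data budget given a total training-token budget and the amount of curated high-quality data available, targeting roughly 10% of the run for annealing and capping repetition of the annealing set at about two epochs. The example numbers are hypothetical.

```python
def plan_annealing(total_training_tokens, high_quality_tokens, anneal_fraction=0.10, max_epochs=2):
    """Rough annealing-data sizing: target ~10% of the total token budget and
    avoid more than ~2 epochs over the available high-quality data."""
    target = int(total_training_tokens * anneal_fraction)
    epochs = min(max_epochs, target / max(high_quality_tokens, 1))
    annealing_tokens = min(target, int(epochs * high_quality_tokens))
    return annealing_tokens, epochs

# Example: a 1T-token run with 60B tokens of curated high-quality data.
tokens, epochs = plan_annealing(1_000_000_000_000, 60_000_000_000)
print(f"anneal over ~{tokens / 1e9:.0f}B tokens ({epochs:.2f} epochs of the high-quality set)")
```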
163 | 164 | ### Annealing LR Schedule 165 | 166 | We performed significant ablations to explore the LR schedule. We made the following observations: 167 | - **The precise form of the LR decay (whether linear, cosine, or exponential) has relatively little effect on the final eval performance.** 168 | - The primary determinant of performance was the initial maximum LR of the annealing phase. 169 | - Unlike miniCPM, we found that re-warming up the LR to a large percentage (approximately 75%) of the original LR used in the pre-training phase and then decaying to zero outperformed starting at the final LR of the pre-training phase. We believe rewarmup outperforms because it results in a significantly faster learning rate decay at the beginning of the annealing phase, since the difference between the initial and final learning rate of the annealing phase is larger than if the annealing had begun at the learning rate where the phase 1 pretraining finished. 170 | 171 | ### What is replay? How much do you need? 172 | 173 | When doing annealing, we find it is important to maintain a high 'replay fraction' of tokens from the original pre-training dataset to stabilize training and maintain performance. This is done both to extend the annealing phase so that the model has more optimizer steps to digest the annealing data, and to minimize forgetting of the original pre-training data distribution. 174 | 175 | We typically find that a mix of 50-70% 'replay' tokens from the original pre-training dataset and 30-50% tokens from the annealing datasets is optimal. Within this range, we find that the sensitivity to the exact replay fraction is quite low, yet we hold the intuition that replay should scale with the magnitude of the distribution shift between the pre-training and annealing datasets. In general, we have found annealing to be fairly robust to hyperparameter choices as long as the initial settings are sensible. 176 | 177 | ### Annealing Summary 178 | 179 | Concretely, our recommendations for annealing are: 180 | 181 | 1. Generate or collect as large a dataset of high-quality tokens as possible. We've seen benefits with up to about 10-15% of the total pretraining token budget; going beyond that should still help but is uncharted territory. 182 | 2. Anneal with a replay fraction of 50-70% tokens from the pre-training dataset. 183 | 3. The decay shape does not matter that much (cosine is fine). 184 | 4. Use a max LR of about 75% of the original max LR used in the pre-training phase, and use a linear warmup from 0 to this max annealing LR over a few thousand iterations. 185 | 186 | 187 | ## Contributing 188 | 189 | If you find a bug or typo, or would like to propose an improvement, please don't hesitate to open an [Issue](https://github.com/Zyphra/zcookbook/issues) or contribute a [PR](https://github.com/Zyphra/zcookbook/pulls).
190 | 191 | ## Cite As 192 | 193 | If you found this repository helpful, please consider citing it using 194 | 195 | ```bibtex 196 | @misc{anthony2024zcookbook, 197 | title = {{The Zyphra Cookbook}}, 198 | author = {Anthony, Quentin and Millidge, Beren and Glorioso, Paolo and Tokpanov, Yury}, 199 | howpublished = {GitHub Repo}, 200 | url = {https://github.com/Zyphra/zcookbook}, 201 | year = {2024} 202 | } 203 | ``` 204 | -------------------------------------------------------------------------------- /benchmarks/communication/README.md: -------------------------------------------------------------------------------- 1 | # JAX Communication Benchmarks 2 | 3 | The intent of these benchmarks is to measure communication latency/bandwidth of JAX collective communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL Tests](https://github.com/NVIDIA/nccl-tests) in that users can: 4 | - Easily debug which layer of the communication software stack hangs or performance degradations originate from. 5 | - Measure the expected communication performance of JAX collective operations. 6 | 7 | To run benchmarks, there are two options: 8 | 9 | 1. Run a single communication operation: 10 | 11 | For example, run with a single large message size (calculated to barely fit within GPU mem): 12 |
13 | `python all_reduce.py`
14 | 
15 | 16 | Scan across message sizes: 17 |
18 | `python all_reduce.py --scan`
19 | 
20 | 21 | 2. Run all available communication benchmarks: 22 | 23 |
24 | `python run_all.py`
25 | 
26 | 27 | Like the individual benchmarks, `run_all.py` supports scanning arguments for the max message size, bandwidth-unit, etc. Simply pass the desired arguments to `run_all.py` and they'll be propagated to each comm op. 28 | 29 | Finally, users can choose specific communication operations to run in `run_all.py` by passing them as arguments (all operations are run by default). For example: 30 | 31 |
32 | `python run_all.py --scan --all-reduce --broadcast`
33 | 
34 | 35 | For usage information: 36 | 37 | ``` 38 | usage: run_all.py [-h] [--trials TRIALS] [--warmups WARMUPS] [--maxsize MAXSIZE] 39 | [--bw-unit {Gbps,GBps}] [--scan] [--raw] [--all-reduce] [--broadcast] [--dtype DTYPE] [--mem-factor MEM_FACTOR] [--debug] 40 | 41 | options: 42 | -h, --help show this help message and exit 43 | --trials TRIALS Number of timed iterations 44 | --warmups WARMUPS Number of warmup (non-timed) iterations 45 | --maxsize MAXSIZE Max message size as a power of 2 46 | --bw-unit {Gbps,GBps} 47 | --scan Enables scanning all message sizes 48 | --raw Print the message size and latency without units 49 | --all-reduce Run all_reduce 50 | --broadcast Run broadcast 51 | --dtype DTYPE JAX array dtype 52 | --mem-factor MEM_FACTOR 53 | Proportion of max available GPU memory to use for single-size evals 54 | --debug Enables all_to_all debug prints 55 | ``` 56 | 57 | # Adding Communication Benchmarks 58 | 59 | To add new communication benchmarks, follow this general procedure: 60 | 61 | 1. Copy a similar benchmark file (e.g. to add `reduce_scatter`, copy `all_reduce.py` as a template) 62 | 2. Add a new bandwidth formula in `utils.get_bandwidth`, a new maximum array element formula in `utils.max_numel`, and a new arg in `utils.benchmark_parser` 63 | 3. Replace comm op calls in new file with find-replace 64 | 4. Find a good default `mem_factor` for use in `run_()` function 65 | 5. Add new comm op to `run_all.py` 66 | 67 | Note: This JAX implementation doesn't require MPI or a specific launcher. It uses JAX's built-in multi-device support. Make sure you have JAX with GPU support installed if you're running on GPUs. 68 | -------------------------------------------------------------------------------- /benchmarks/communication/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/benchmarks/communication/__init__.py -------------------------------------------------------------------------------- /benchmarks/communication/all_reduce.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import sys 4 | import os 5 | import time 6 | from functools import partial 7 | 8 | COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../") 9 | sys.path.append(COMMS_BENCH_DIR) 10 | 11 | from communication.utils import * 12 | from communication.constants import * 13 | 14 | def timed_all_reduce(input, args): 15 | def all_reduce(x): 16 | return jax.lax.pmean(x, axis_name='i') 17 | 18 | pmap_all_reduce = jax.pmap(all_reduce, axis_name='i') 19 | 20 | # Warmups 21 | for _ in range(args.warmups): 22 | result = pmap_all_reduce(input) 23 | result.block_until_ready() # Ensure the computation is complete 24 | 25 | # Time the actual comm op 26 | start_time = time.time() 27 | for _ in range(args.trials): 28 | result = pmap_all_reduce(input) 29 | result.block_until_ready() # Ensure the computation is complete 30 | end_time = time.time() 31 | 32 | duration = (end_time - start_time) 33 | 34 | # Maintain and clean performance data 35 | avg_duration = duration / args.trials 36 | size = input.size * input.dtype.itemsize 37 | n = jax.device_count() 38 | tput, busbw = get_bw('all_reduce', size, avg_duration, args) 39 | tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration) 40 | desc = f'{input.shape[1]}x{input.dtype.itemsize}' 41 | 42 | if not args.raw: 43 | size = 
convert_size(size) 44 | 45 | print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}") 46 | 47 | def run_all_reduce(args): 48 | # Prepare benchmark header 49 | print_header(args, 'all_reduce') 50 | 51 | if args.scan: 52 | M_LIST = [2**p for p in range(1, args.maxsize)] 53 | 54 | # Loop over various tensor sizes 55 | for M in M_LIST: 56 | try: 57 | mat = jnp.ones((jax.local_device_count(), M), dtype=getattr(jnp, args.dtype)) 58 | input = jax.pmap(lambda x: x)(mat) # Distribute the data across devices 59 | except RuntimeError as e: 60 | if 'out of memory' in str(e): 61 | print_rank_0('WARNING: Ran out of GPU memory. Exiting comm op.') 62 | break 63 | else: 64 | raise e 65 | timed_all_reduce(input, args) 66 | else: 67 | # Send the biggest message size our GPUs can fit 68 | elements_per_gpu = max_numel(comm_op='all_reduce', 69 | dtype=getattr(jnp, args.dtype), 70 | mem_factor=args.mem_factor * 2, 71 | args=args) 72 | try: 73 | mat = jnp.ones((jax.local_device_count(), elements_per_gpu), dtype=getattr(jnp, args.dtype)) 74 | input = jax.pmap(lambda x: x)(mat) # Distribute the data across devices 75 | except RuntimeError as e: 76 | if 'out of memory' in str(e): 77 | print_rank_0('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!') 78 | return 79 | else: 80 | raise e 81 | timed_all_reduce(input, args) 82 | 83 | if __name__ == "__main__": 84 | args = benchmark_parser().parse_args() 85 | run_all_reduce(args) 86 | -------------------------------------------------------------------------------- /benchmarks/communication/broadcast.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import sys 4 | import os 5 | import time 6 | 7 | COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../") 8 | sys.path.append(COMMS_BENCH_DIR) 9 | 10 | from communication.utils import * 11 | from communication.constants import * 12 | 13 | def timed_broadcast(input, args): 14 | @jax.pmap 15 | def broadcast(x): 16 | return jax.lax.broadcast(x[0], (jax.device_count(),)) 17 | 18 | 19 | # Warmups 20 | for _ in range(args.warmups): 21 | result = broadcast(input) 22 | result.block_until_ready() 23 | 24 | # Time the actual comm op 25 | start_time = time.time() 26 | for _ in range(args.trials): 27 | result = broadcast(input) 28 | result.block_until_ready() 29 | end_time = time.time() 30 | 31 | duration = end_time - start_time 32 | 33 | # Maintain and clean performance data 34 | avg_duration = duration / args.trials 35 | size = input.nbytes 36 | n = jax.device_count() 37 | tput, busbw = get_bw('broadcast', size, avg_duration, args) 38 | tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration) 39 | desc = f'{input.size}x{input.dtype.itemsize}' 40 | 41 | if not args.raw: 42 | size = convert_size(size) 43 | 44 | print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}") 45 | 46 | def run_broadcast(args): 47 | # Prepare benchmark header 48 | print_header(args, 'broadcast') 49 | 50 | if args.scan: 51 | M_LIST = [2**p for p in range(1, args.maxsize)] 52 | 53 | # Loop over various tensor sizes 54 | for M in M_LIST: 55 | try: 56 | mat = jnp.ones((jax.device_count(), M), dtype=getattr(jnp, args.dtype)) 57 | input = jax.pmap(lambda i, x: x * i)(jnp.arange(jax.device_count()), mat) 58 | except RuntimeError as e: 59 | if 'out of memory' in str(e): 60 | print_rank_0('WARNING: Ran out of GPU memory. 
Exiting comm op.') 61 | break 62 | else: 63 | raise e 64 | timed_broadcast(input, args) 65 | else: 66 | # Send the biggest message size our GPUs can fit 67 | elements_per_gpu = max_numel(comm_op='broadcast', 68 | dtype=getattr(jnp, args.dtype), 69 | mem_factor=args.mem_factor * 2, 70 | args=args) 71 | try: 72 | mat = jnp.ones((jax.device_count(), elements_per_gpu), dtype=getattr(jnp, args.dtype)) 73 | input = jax.pmap(lambda i, x: x * i)(jnp.arange(jax.device_count()), mat) 74 | except RuntimeError as e: 75 | if 'out of memory' in str(e): 76 | print_rank_0('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!') 77 | return 78 | else: 79 | raise e 80 | timed_broadcast(input, args) 81 | 82 | if __name__ == "__main__": 83 | args = benchmark_parser().parse_args() 84 | run_broadcast(args) 85 | -------------------------------------------------------------------------------- /benchmarks/communication/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_WARMUPS = 5 2 | DEFAULT_TRIALS = 50 3 | DEFAULT_TYPE = 'float' 4 | DEFAULT_BACKEND = 'nccl' 5 | DEFAULT_UNIT = 'Gbps' 6 | DEFAULT_MAXSIZE = 24 7 | DEFAULT_JAX_TIMEOUT = 300 -------------------------------------------------------------------------------- /benchmarks/communication/pt2pt.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import sys 4 | import os 5 | import time 6 | 7 | COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../") 8 | sys.path.append(COMMS_BENCH_DIR) 9 | 10 | from communication.utils import * 11 | from communication.constants import * 12 | 13 | def timed_pt2pt(input, args): 14 | def send_recv(x): 15 | axis_name = 'i' 16 | device_id = jax.lax.axis_index(axis_name) 17 | 18 | # Step 1: GPU 0 sends to GPU 1 19 | step1 = jax.lax.ppermute(x, axis_name, [(0, 1)]) 20 | 21 | # Step 2: GPU 1 sends to GPU 0 22 | step2 = jax.lax.ppermute(x, axis_name, [(1, 0)]) 23 | 24 | # Combine results 25 | result = jax.lax.cond( 26 | device_id == 0, 27 | lambda _: step2, # GPU 0 receives in step 2 28 | lambda _: step1, # GPU 1 receives in step 1 29 | operand=None 30 | ) 31 | return result 32 | 33 | send_recv_op = jax.pmap(send_recv, axis_name='i') 34 | 35 | # Ensure we're using exactly 2 GPUs 36 | if jax.device_count() != 2: 37 | raise ValueError("This benchmark requires exactly 2 GPUs") 38 | 39 | # Warmups 40 | for _ in range(args.warmups): 41 | result = send_recv_op(input) 42 | result.block_until_ready() 43 | 44 | # Time the actual comm op 45 | start_time = time.time() 46 | for _ in range(args.trials): 47 | send_recv_op(input) 48 | result.block_until_ready() 49 | end_time = time.time() 50 | 51 | duration = end_time - start_time 52 | 53 | # Maintain and clean performance data 54 | avg_duration = duration / args.trials 55 | size = input.nbytes # Only considering the data sent in one direction 56 | tput, busbw = get_bw('pt2pt', size, avg_duration, args) 57 | tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration) 58 | desc = f'{input.shape[1]}x{input.dtype.itemsize}' 59 | 60 | if not args.raw: 61 | size = convert_size(size) 62 | 63 | print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}") 64 | 65 | def run_pt2pt(args): 66 | # Prepare benchmark header 67 | print_header(args, 'pt2pt') 68 | 69 | if args.scan: 70 | M_LIST = [2**p for p in range(1, args.maxsize)] 71 | 72 | # Loop over various tensor sizes 73 | for M in M_LIST: 74 
| try: 75 | mat = jnp.ones((2, M), dtype=getattr(jnp, args.dtype)) 76 | input = jax.pmap(lambda i, x: x * i)(jnp.arange(2), mat) 77 | except RuntimeError as e: 78 | if 'out of memory' in str(e): 79 | print_rank_0('WARNING: Ran out of GPU memory. Exiting comm op.') 80 | break 81 | else: 82 | raise e 83 | timed_pt2pt(input, args) 84 | else: 85 | # Send the biggest message size our GPUs can fit 86 | elements_per_gpu = max_numel(comm_op='pt2pt', 87 | dtype=getattr(jnp, args.dtype), 88 | mem_factor=args.mem_factor * 2, 89 | args=args) 90 | try: 91 | mat = jnp.ones((2, elements_per_gpu), dtype=getattr(jnp, args.dtype)) 92 | input = jax.pmap(lambda i, x: x * i)(jnp.arange(2), mat) 93 | except RuntimeError as e: 94 | if 'out of memory' in str(e): 95 | print_rank_0('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!') 96 | return 97 | else: 98 | raise e 99 | timed_pt2pt(input, args) 100 | 101 | if __name__ == "__main__": 102 | args = benchmark_parser().parse_args() 103 | run_pt2pt(args) 104 | 105 | -------------------------------------------------------------------------------- /benchmarks/communication/run_all.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../") 4 | sys.path.append(COMMS_BENCH_DIR) 5 | 6 | from communication.utils import * 7 | from communication.all_reduce import run_all_reduce 8 | from communication.pt2pt import run_pt2pt 9 | from communication.broadcast import run_broadcast 10 | from communication.constants import * 11 | 12 | 13 | # For importing 14 | def main(args, rank): 15 | 16 | init_processes(local_rank=rank, args=args) 17 | 18 | ops_to_run = [] 19 | if args.all_reduce: 20 | ops_to_run.append('all_reduce') 21 | if args.broadcast: 22 | ops_to_run.append('broadcast') 23 | if args.pt2pt: 24 | ops_to_run.append('pt2pt') 25 | 26 | if len(ops_to_run) == 0: 27 | ops_to_run = ['all_reduce', 'broadcast', 'pt2pt'] 28 | 29 | for comm_op in ops_to_run: 30 | if comm_op == 'all_reduce': 31 | run_all_reduce(local_rank=rank, args=args) 32 | if comm_op == 'pt2pt': 33 | run_pt2pt(local_rank=rank, args=args) 34 | if comm_op == 'broadcast': 35 | run_broadcast(local_rank=rank, args=args) 36 | 37 | 38 | # For directly calling benchmark 39 | if __name__ == "__main__": 40 | args = benchmark_parser().parse_args() 41 | rank = args.local_rank 42 | main(args, rank) 43 | -------------------------------------------------------------------------------- /benchmarks/communication/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import os 4 | import sys 5 | import math 6 | import argparse 7 | import time 8 | 9 | COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../") 10 | sys.path.append(COMMS_BENCH_DIR) 11 | from .constants import * 12 | 13 | os.environ['JAX_THREEFY_TIMEOUT'] = str(DEFAULT_JAX_TIMEOUT) 14 | 15 | def print_rank_0(message): 16 | if jax.process_index() == 0: 17 | print(message) 18 | 19 | def print_header(args, comm_op): 20 | if comm_op == 'pt2pt': 21 | world_size = 2 22 | else: 23 | world_size = jax.device_count() 24 | tput = f'Throughput ({args.bw_unit})' 25 | busbw = f'BusBW ({args.bw_unit})' 26 | header = f"\n---- Performance of {comm_op} on {world_size} devices ---------------------------------------------------------\n" 27 | duration_str = 'Duration' 28 | if args.raw: 29 | duration_str += ' (us)' 30 | header += f"{'Size (Bytes)':20s} 
{'Description':25s} {duration_str:20s} {tput:20s} {busbw:20s}\n" 31 | header += "----------------------------------------------------------------------------------------------------" 32 | print_rank_0(header) 33 | 34 | def get_bw(comm_op, size, duration, args): 35 | n = jax.device_count() 36 | tput = 0 37 | busbw = 0 38 | if comm_op == "all_to_all": 39 | tput = (size / duration) 40 | busbw = (size / duration) * ((n - 1) / n) 41 | elif comm_op == "all_gather": 42 | size *= n 43 | tput = (size / duration) 44 | busbw = (size / duration) * ((n - 1) / n) 45 | elif comm_op == "all_reduce": 46 | tput = (size * 2 / duration) 47 | busbw = (size / duration) * (2 * (n - 1) / n) 48 | elif comm_op == "pt2pt" or comm_op == "broadcast": 49 | tput = (size / duration) 50 | busbw = tput 51 | else: 52 | print_rank_0("wrong comm_op specified") 53 | exit(0) 54 | 55 | if args.bw_unit == 'Gbps': 56 | tput *= 8 57 | busbw *= 8 58 | 59 | return tput, busbw 60 | 61 | def get_metric_strings(args, tput, busbw, duration): 62 | duration_ms = duration * 1e3 63 | duration_us = duration * 1e6 64 | tput = f'{tput / 1e9:.3f}' 65 | busbw = f'{busbw /1e9:.3f}' 66 | 67 | if duration_us < 1e3 or args.raw: 68 | duration = f'{duration_us:.3f}' 69 | if not args.raw: 70 | duration += ' us' 71 | else: 72 | duration = f'{duration_ms:.3f} ms' 73 | return tput, busbw, duration 74 | 75 | def sync_all(): 76 | # In JAX, explicit synchronization is often not necessary 77 | # due to its functional nature, but we can use a barrier if needed 78 | jax.pmap(lambda x: x)(jnp.zeros(jax.local_device_count())).block_until_ready() 79 | 80 | def max_numel(comm_op, dtype, mem_factor, args): 81 | dtype_size = jnp.dtype(dtype).itemsize 82 | max_memory_per_gpu = jax.local_devices()[0].memory_stats()['bytes_limit'] * mem_factor 83 | if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast': 84 | elements_per_gpu = int(max_memory_per_gpu // dtype_size) 85 | elif comm_op == 'all_gather': 86 | elements_per_gpu = int(max_memory_per_gpu // dtype_size // jax.device_count()) 87 | elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2)))) 88 | elif comm_op == 'all_to_all': 89 | elements_per_gpu = int(max_memory_per_gpu // dtype_size) 90 | elements_per_gpu = int(jax.device_count() * round(elements_per_gpu / jax.device_count())) 91 | elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2)))) 92 | else: 93 | print(f"This communication operation: {comm_op} is not supported yet") 94 | exit(0) 95 | return elements_per_gpu 96 | 97 | def convert_size(size_bytes): 98 | if size_bytes == 0: 99 | return "0B" 100 | size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") 101 | i = int(math.floor(math.log(size_bytes, 1024))) 102 | p = math.pow(1024, i) 103 | s = round(size_bytes / p, 2) 104 | return "%s %s" % (s, size_name[i]) 105 | 106 | def benchmark_parser(): 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument("--trials", type=int, default=DEFAULT_TRIALS, help='Number of timed iterations') 109 | parser.add_argument("--warmups", type=int, default=DEFAULT_WARMUPS, help='Number of warmup (non-timed) iterations') 110 | parser.add_argument("--maxsize", type=int, default=24, help='Max message size as a power of 2') 111 | parser.add_argument("--bw-unit", type=str, default=DEFAULT_UNIT, choices=['Gbps', 'GBps']) 112 | parser.add_argument("--scan", action="store_true", help='Enables scanning all message sizes') 113 | parser.add_argument("--raw", action="store_true", help='Print the message size and latency without units') 114 
| parser.add_argument("--all-reduce", action="store_true", help='Run all_reduce') 115 | #parser.add_argument("--all-gather", action="store_true", help='Run all_gather') 116 | #parser.add_argument("--all-to-all", action="store_true", help='Run all_to_all') 117 | parser.add_argument("--pt2pt", action="store_true", help='Run pt2pt') 118 | parser.add_argument("--broadcast", action="store_true", help='Run broadcast') 119 | parser.add_argument("--dtype", type=str, default='float32', 120 | choices=['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'], 121 | help='JAX array dtype') 122 | parser.add_argument("--mem-factor", 123 | type=float, 124 | default=.3, 125 | help='Proportion of max available GPU memory to use for single-size evals') 126 | parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints') 127 | return parser 128 | -------------------------------------------------------------------------------- /benchmarks/computation/README.md: -------------------------------------------------------------------------------- 1 | # Computation Benchmarks 2 | 3 | This directory contains isolated benchmarks for the core blocks of Zamba hybrid models: attention and Mamba. These benchmarks are designed to compare accelerators, find [good model sizes](https://arxiv.org/abs/2401.14489), and test optimizations. 4 | 5 | ## Available Benchmarks 6 | 7 | 1. Mamba1 Benchmark (`benchmark_mamba.py`) 8 | 2. Flash Attention Benchmark (`benchmark_flash_attention.py`) 9 | 3. Mamba2 Benchmark (`benchmark_mamba2.py`) 10 | 11 | ## Running Benchmarks 12 | 13 | To run a benchmark, use the following command: 14 | 15 | ``` 16 | python .py 17 | ``` 18 | 19 | ## License 20 | 21 | This project is licensed under the Apache License 2.0. See the LICENSE file for details. 
22 | -------------------------------------------------------------------------------- /benchmarks/computation/benchmark_flash_attention.py: -------------------------------------------------------------------------------- 1 | # Benchmark the Flash Attention block 2 | # Based on https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py 3 | 4 | import argparse 5 | import torch 6 | import math 7 | from flash_attn import flash_attn_qkvpacked_func 8 | from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward 9 | 10 | def parse_arguments(): 11 | parser = argparse.ArgumentParser(description='Benchmark Flash Attention') 12 | parser.add_argument('--dtype', type=str, default='fp16', choices=['fp16', 'fp32', 'fp64', 'bf16'], help='Data type for torch operations') 13 | parser.add_argument('--causal', type=str, nargs='+', choices=['true', 'false'], default=['false'], help='Enable causal masking') 14 | parser.add_argument('--verbose', '-v', action='store_true', help='Enable verbose output') 15 | parser.add_argument('--dropout_p', type=float, default=0.0, help='Dropout probability') 16 | parser.add_argument('--seqlen', type=int, default=1024, help='Sequence length') 17 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size') 18 | parser.add_argument('--nheads', type=int, default=16, help='Number of attention heads') 19 | parser.add_argument('--head_dim', type=int, default=128, help='Size of each attention head') 20 | parser.add_argument('--hidden_dim', type=int, default=2048, help='Total hidden dimension') 21 | parser.add_argument('--repeats', type=int, default=30, help='Number of repeats for benchmarking') 22 | parser.add_argument('--device', type=str, default='cuda', help='Torch device to run the benchmark on') 23 | return parser.parse_args() 24 | 25 | def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): 26 | assert mode in ["fwd", "bwd", "fwd_bwd"] 27 | f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) 28 | return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) 29 | 30 | def efficiency(flop, time): 31 | return (flop / time / 10**12) if not math.isnan(time) else 0.0 32 | 33 | def pretty_print_latency(latency_ms): 34 | if latency_ms < 1: 35 | return f"{latency_ms*1000:.2f} µs" 36 | elif latency_ms < 1000: 37 | return f"{latency_ms:.2f} ms" 38 | else: 39 | return f"{latency_ms/1000:.2f} s" 40 | 41 | def main(): 42 | args = parse_arguments() 43 | 44 | dtype_map = {'fp16': torch.float16, 'fp32': torch.float32, 'fp64': torch.float64, 'bf16': torch.bfloat16} 45 | dtype = dtype_map[args.dtype] 46 | causal = args.causal[0] == 'true' 47 | 48 | batch_size = args.batch_size 49 | seqlen = args.seqlen 50 | nheads = args.nheads 51 | head_dim = args.head_dim 52 | hidden_dim = args.hidden_dim 53 | 54 | # Ensure hidden_dim is divisible by nheads 55 | assert hidden_dim % nheads == 0, f"hidden_dim {hidden_dim} must be divisible by nheads {nheads}" 56 | 57 | qkv = torch.randn(batch_size, seqlen, 3, nheads, head_dim, device=args.device, dtype=dtype, requires_grad=True) 58 | 59 | # Warmup 60 | for _ in range(10): 61 | output = flash_attn_qkvpacked_func(qkv, args.dropout_p, causal=causal) 62 | output.sum().backward() 63 | 64 | # Benchmark forward pass 65 | fwd_times = benchmark_forward( 66 | flash_attn_qkvpacked_func, qkv, args.dropout_p, causal=causal, 67 | repeats=args.repeats 68 | ) 69 | 70 | # Benchmark backward pass 71 | bwd_times = benchmark_backward( 72 | flash_attn_qkvpacked_func, qkv, 
args.dropout_p, causal=causal, 73 | repeats=args.repeats 74 | ) 75 | 76 | f, b = fwd_times[1], bwd_times[1] 77 | 78 | fwd_flops = flops(batch_size, seqlen, head_dim, nheads, causal, mode="fwd") 79 | bwd_flops = flops(batch_size, seqlen, head_dim, nheads, causal, mode="bwd") 80 | fwd_bwd_flops = flops(batch_size, seqlen, head_dim, nheads, causal, mode="fwd_bwd") 81 | 82 | print(f"### Causal={causal}, head_dim={head_dim}, batch_size={batch_size}, seqlen={seqlen}, hidden_dim={hidden_dim} ###") 83 | print(f"FlashAttention2 FWD Latency: {pretty_print_latency(f)}") 84 | print(f"FlashAttention2 BWD Latency: {pretty_print_latency(b)}") 85 | print(f"FlashAttention2 FWD+BWD Latency: {pretty_print_latency(f + b)}") 86 | print(f"FlashAttention2 FWD Throughput: {efficiency(fwd_flops, f):.2f} TFLOPs/s") 87 | print(f"FlashAttention2 BWD Throughput: {efficiency(bwd_flops, b):.2f} TFLOPs/s") 88 | print(f"FlashAttention2 FWD+BWD Throughput: {efficiency(fwd_bwd_flops, f + b):.2f} TFLOPs/s") 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /benchmarks/computation/benchmark_mamba.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import math 4 | from mamba_ssm import Mamba 5 | from utils import time_fwd_bwd 6 | 7 | def parse_arguments(): 8 | parser = argparse.ArgumentParser(description='Benchmark Mamba') 9 | parser.add_argument('--dtype', type=str, default='fp16', choices=['fp16', 'fp32', 'bf16'], help='Data type for torch operations') 10 | parser.add_argument('--verbose', '-v', action='store_true', help='Enable verbose output') 11 | parser.add_argument('--seqlen', type=int, default=1024, help='Sequence length') 12 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size') 13 | parser.add_argument('--d_conv', type=int, default=4, help='Dimension of convolution kernel') 14 | parser.add_argument('--d_model', type=int, default=2048, help='Model dimension') 15 | parser.add_argument('--d_state', type=int, default=64, help='State dimension') 16 | parser.add_argument('--expand', type=int, default=2, help='Expansion factor') 17 | parser.add_argument('--repeats', type=int, default=30, help='Number of repeats for benchmarking') 18 | parser.add_argument('--device', type=str, default='cuda', help='Torch device to run the benchmark on') 19 | return parser.parse_args() 20 | 21 | def flops_mamba(hidden_size, expansion_factor, state_size, seqlen, batch_size, conv_dimension, num_layers, dt_rank="auto", mode="fwd"): 22 | assert mode in ["fwd", "bwd", "fwd_bwd"] 23 | iter_factor = 1 if mode == "fwd" else (2 if mode == "bwd" else 3) 24 | d_inner = hidden_size * expansion_factor 25 | dt_rank = math.ceil(hidden_size / 16) if dt_rank == "auto" else dt_rank 26 | ssm_flops = iter_factor * d_inner * seqlen * batch_size * (11 * state_size + 4 * dt_rank + 1) * num_layers 27 | mamba_projectors_flops = iter_factor * seqlen * batch_size * 6 * d_inner * hidden_size * num_layers 28 | mamba_conv_flops = iter_factor * seqlen * batch_size * 2 * d_inner * conv_dimension * num_layers 29 | mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops 30 | return mamba_flops 31 | 32 | def efficiency(flop, time): 33 | return (flop / time / 10**12) if not math.isnan(time) else 0.0 34 | 35 | def pretty_print_latency(latency_ms): 36 | if latency_ms < 1: 37 | return f"{latency_ms*1000:.2f} µs" 38 | elif latency_ms < 1000: 39 | return f"{latency_ms:.2f} ms" 40 | else: 41 
| return f"{latency_ms/1000:.2f} s" 42 | 43 | def main(): 44 | args = parse_arguments() 45 | 46 | dtype_map = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16} 47 | dtype = dtype_map[args.dtype] 48 | 49 | batch_size = args.batch_size 50 | seqlen = args.seqlen 51 | d_model = args.d_model 52 | d_state = args.d_state 53 | expand = args.expand 54 | d_conv = args.d_conv 55 | 56 | input = torch.randn(batch_size, seqlen, d_model, device=args.device, dtype=dtype, requires_grad=True) 57 | model = Mamba(d_model=d_model, d_state=d_state, expand=expand, device=args.device, dtype=dtype, d_conv=d_conv) 58 | 59 | num_params = sum(p.numel() for p in model.parameters()) 60 | 61 | # Warmup 62 | for _ in range(10): 63 | output = model(input) 64 | output.sum().backward() 65 | 66 | # Benchmark 67 | f, b = time_fwd_bwd(model, input, repeats=args.repeats, verbose=args.verbose) 68 | 69 | fwd_flops = flops_mamba(d_model, expand, d_state, seqlen, batch_size, d_conv, 1, mode="fwd") 70 | bwd_flops = flops_mamba(d_model, expand, d_state, seqlen, batch_size, d_conv, 1, mode="bwd") 71 | fwd_bwd_flops = flops_mamba(d_model, expand, d_state, seqlen, batch_size, d_conv, 1, mode="fwd_bwd") 72 | 73 | print(f"### d_model={d_model}, d_state={d_state}, expand={expand}, batch_size={batch_size}, seqlen={seqlen} ###") 74 | print(f"Mamba FWD Latency: {pretty_print_latency(f)}") 75 | print(f"Mamba BWD Latency: {pretty_print_latency(b)}") 76 | print(f"Mamba FWD+BWD Latency: {pretty_print_latency(f + b)}") 77 | print(f"Mamba FWD Throughput: {efficiency(fwd_flops, f):.2f} TFLOPs/s") 78 | print(f"Mamba BWD Throughput: {efficiency(bwd_flops, b):.2f} TFLOPs/s") 79 | print(f"Mamba FWD+BWD Throughput: {efficiency(fwd_bwd_flops, f + b):.2f} TFLOPs/s") 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /benchmarks/computation/benchmark_mamba2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import math 4 | from mamba_ssm import Mamba2 5 | from utils import time_fwd_bwd 6 | 7 | def parse_arguments(): 8 | parser = argparse.ArgumentParser(description='Benchmark Mamba2') 9 | parser.add_argument('--dtype', type=str, default='fp16', choices=['fp16', 'fp32', 'bf16'], help='Data type for torch operations') 10 | parser.add_argument('--verbose', '-v', action='store_true', help='Enable verbose output') 11 | parser.add_argument('--seqlen', type=int, default=64, help='Sequence length') 12 | parser.add_argument('--batch_size', type=int, default=2, help='Batch size') 13 | parser.add_argument('--d_conv', type=int, default=4, help='Dimension of convolution kernel') 14 | parser.add_argument('--d_model', type=int, default=1024, help='Model dimension') 15 | parser.add_argument('--d_state', type=int, default=64, help='State dimension') 16 | parser.add_argument('--expand', type=int, default=2, help='Expansion factor') 17 | parser.add_argument('--headdim', type=int, default=128, help='Head dimension') 18 | parser.add_argument('--ngroups', type=int, default=1, help='Number of mamba2 groups') 19 | parser.add_argument('--repeats', type=int, default=30, help='Number of repeats for benchmarking') 20 | parser.add_argument('--device', type=str, default='cuda', help='Torch device to run the benchmark on') 21 | return parser.parse_args() 22 | 23 | def flops_mamba2(hidden_size, expansion_factor, headdim, ngroups, d_state, sequence_length, batch_size, conv_dimension, num_layers, dt_rank="auto", 
mode="fwd"): 24 | assert mode in ["fwd", "bwd", "fwd_bwd"] 25 | iter_factor = 1 if mode == "fwd" else (2 if mode == "bwd" else 3) 26 | d_inner = hidden_size * expansion_factor 27 | Nheads = d_inner // headdim 28 | mamba2_block_flops = 2 * (2 * d_inner + 2 * ngroups * d_state + Nheads) * hidden_size * batch_size * sequence_length # in proj 29 | mamba2_block_flops += 2 * batch_size * sequence_length * (d_inner + 2 * ngroups * d_state) * conv_dimension * d_inner# conv 30 | mamba2_block_flops += 2 * batch_size * sequence_length * d_inner * d_state * d_inner # dtbx 31 | mamba2_block_flops += 2 * batch_size * sequence_length * d_inner * d_state # ssm state rollover 32 | mamba2_block_flops += 2 * batch_size * sequence_length * d_inner * d_state * hidden_size # c-> y 33 | mamba2_block_flops += batch_size * sequence_length * hidden_size # z gate output 34 | return mamba2_block_flops 35 | 36 | def pretty_print_latency(latency_ms): 37 | if latency_ms < 1: 38 | return f"{latency_ms*1000:.2f} µs" 39 | elif latency_ms < 1000: 40 | return f"{latency_ms:.2f} ms" 41 | else: 42 | return f"{latency_ms/1000:.2f} s" 43 | 44 | def efficiency(flop, time): 45 | return (flop / time / 10**12) if not math.isnan(time) else 0.0 46 | 47 | def main(): 48 | args = parse_arguments() 49 | 50 | dtype_map = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16} 51 | dtype = dtype_map[args.dtype] 52 | 53 | batch_size = args.batch_size 54 | seqlen = args.seqlen 55 | d_model = args.d_model 56 | d_state = args.d_state 57 | expand = args.expand 58 | headdim = args.headdim 59 | ngroups = args.ngroups 60 | d_conv = args.d_conv 61 | 62 | 63 | input = torch.randn(batch_size, seqlen, d_model, device=args.device, dtype=dtype, requires_grad=True) 64 | model = Mamba2(d_model=d_model, d_state=d_state, expand=expand, headdim=headdim, device=args.device, dtype=dtype) 65 | 66 | num_params = sum(p.numel() for p in model.parameters()) 67 | 68 | # Warmup 69 | for _ in range(10): 70 | output = model(input) 71 | output.sum().backward() 72 | 73 | # Benchmark 74 | f, b = time_fwd_bwd(model, input, repeats=args.repeats, verbose=args.verbose) 75 | 76 | 77 | fwd_flops = flops_mamba2(d_model, expand, headdim, ngroups, d_state, seqlen, batch_size, d_conv, 1, mode="fwd") 78 | bwd_flops = flops_mamba2(d_model, expand, headdim, ngroups, d_state, seqlen, batch_size, d_conv, 1, mode="bwd") 79 | fwd_bwd_flops = flops_mamba2(d_model, expand, headdim, ngroups, d_state, seqlen, batch_size, d_conv, 1, mode="fwd_bwd") 80 | 81 | 82 | print(f"### d_model={d_model}, d_state={d_state}, expand={expand}, batch_size={batch_size}, seqlen={seqlen} ###") 83 | print(f"Mamba2 FWD Latency: {pretty_print_latency(f)}") 84 | print(f"Mamba2 BWD Latency: {pretty_print_latency(b)}") 85 | print(f"Mamba2 FWD+BWD Latency: {pretty_print_latency(f + b)}") 86 | print(f"Mamba2 FWD Throughput: {efficiency(fwd_flops, f):.2f} TFLOPs/s") 87 | print(f"Mamba2 BWD Throughput: {efficiency(bwd_flops, b):.2f} TFLOPs/s") 88 | print(f"Mamba2 FWD+BWD Throughput: {efficiency(fwd_bwd_flops, f + b):.2f} TFLOPs/s") 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /benchmarks/computation/utils.py: -------------------------------------------------------------------------------- 1 | # Heavily borrowed from 2 | # https://git.gsas.edu.hk/shenguo.wang/flash-attention/-/blob/v2.5.9/flash_attn/utils/benchmark.py 3 | 4 | 5 | import torch 6 | import torch.utils.benchmark as benchmark 7 | import math 8 | 9 | 10 | def 
time_fwd_bwd(func, *args, **kwargs): 11 | time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs) 12 | #print("time_f", time_f) 13 | #print("time_b", time_b) 14 | return time_f[1].mean, time_b[1].mean 15 | 16 | 17 | def benchmark_forward( 18 | fn, 19 | *inputs, 20 | repeats=10, 21 | desc="", 22 | verbose=True, 23 | amp=False, 24 | amp_dtype=torch.float16, 25 | **kwinputs 26 | ): 27 | """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" 28 | if verbose: 29 | print(desc, "- Forward pass") 30 | 31 | def amp_wrapper(*inputs, **kwinputs): 32 | with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): 33 | fn(*inputs, **kwinputs) 34 | 35 | t = benchmark.Timer( 36 | stmt="fn_amp(*inputs, **kwinputs)", 37 | globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, 38 | num_threads=torch.get_num_threads(), 39 | ) 40 | torch.cuda.synchronize() 41 | m = t.timeit(repeats) 42 | torch.cuda.synchronize() 43 | if verbose: 44 | print(m) 45 | return t, m 46 | 47 | 48 | def benchmark_backward( 49 | fn, 50 | *inputs, 51 | grad=None, 52 | repeats=10, 53 | desc="", 54 | verbose=True, 55 | amp=False, 56 | amp_dtype=torch.float16, 57 | **kwinputs, 58 | ): 59 | """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" 60 | if verbose: 61 | print(desc, "- Backward pass") 62 | with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): 63 | y = fn(*inputs, **kwinputs) 64 | if type(y) is tuple: 65 | y = y[0] 66 | if grad is None: 67 | grad = torch.randn_like(y) 68 | else: 69 | if grad.shape != y.shape: 70 | raise RuntimeError("Grad shape does not match output shape") 71 | 72 | def f(*inputs, y, grad): 73 | # Set .grad to None to avoid extra operation of gradient accumulation 74 | for x in inputs: 75 | if isinstance(x, torch.Tensor): 76 | x.grad = None 77 | y.backward(grad, retain_graph=True) 78 | 79 | t = benchmark.Timer( 80 | stmt="f(*inputs, y=y, grad=grad)", 81 | globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, 82 | num_threads=torch.get_num_threads(), 83 | ) 84 | torch.cuda.synchronize() 85 | m = t.timeit(repeats) 86 | torch.cuda.synchronize() 87 | if verbose: 88 | print(m) 89 | return t, m 90 | 91 | 92 | def benchmark_fwd_bwd( 93 | fn, 94 | *inputs, 95 | grad=None, 96 | repeats=10, 97 | desc="", 98 | verbose=True, 99 | amp=False, 100 | amp_dtype=torch.float16, 101 | **kwinputs, 102 | ): 103 | """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" 104 | return ( 105 | benchmark_forward( 106 | fn, 107 | *inputs, 108 | repeats=repeats, 109 | desc=desc, 110 | verbose=verbose, 111 | amp=amp, 112 | amp_dtype=amp_dtype, 113 | **kwinputs, 114 | ), 115 | benchmark_backward( 116 | fn, 117 | *inputs, 118 | grad=grad, 119 | repeats=repeats, 120 | desc=desc, 121 | verbose=verbose, 122 | amp=amp, 123 | amp_dtype=amp_dtype, 124 | **kwinputs, 125 | ), 126 | ) 127 | 128 | 129 | def attn_flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): 130 | assert mode in ["fwd", "bwd", "fwd_bwd"] 131 | f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) 132 | return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) 133 | 134 | 135 | #def mamba1_flops(): 136 | 137 | 138 | 139 | #def mamba2_flops(): 140 | 141 | 142 | def efficiency(flop, time): 143 | return (flop / time / 10**12) if not math.isnan(time) else 0.0 144 | 145 | def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): 146 | assert mode in ["fwd", "bwd", "fwd_bwd"] 147 | f = 4 * batch * seqlen**2 * nheads * headdim // (2 if 
causal else 1) 148 | return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) 149 | 150 | 151 | 152 | # Helper function to pretty-print message sizes 153 | def convert_params(params): 154 | if params == 0: 155 | return "0" 156 | size_name = ("", "K", "M", "B", "T", "P", "E", "Z", "Y") 157 | i = int(math.floor(math.log(params, 1000))) 158 | p = math.pow(1000, i) 159 | s = round(params / p, 2) 160 | return "%s %s" % (s, size_name[i]) 161 | -------------------------------------------------------------------------------- /calc/README.md: -------------------------------------------------------------------------------- 1 | # Parameter and FLOP calculation walkthroughs 2 | 3 | Here we present a walk-through of how to compute the parameters and FLOPs for a variety of architectures such as standard transformers, MoEs, Mamba and Mamba2. 4 | 5 | ## Transformer parameters 6 | 7 | Let us first figure out how many parameters a given transformer has with a certain embedding dimension h, depth, l, vocab size V, and sequence length s. To do so, we walk systematically through every parameter. 8 | 9 | 1.) First we have the embedding layer which has $$V \times h$$ parameters. If the embedding and unembedding weights are untied there are two of these. 10 | 11 | 2.) We now have the QKVO attention matrices. Each one is of dimension $$h \times h$$ for $$4h^2$$ parameters in total. 12 | 13 | 3.) Next we have the MLP layers of which we have two: mlp_in with $$h \times 4h$$ parameters and mlp_out with $$4h \times h$$ parameters (assuming the usual expansion factor of 4. In total, this gives us $$8h^2$$ parameters. 14 | 15 | 4.) We have layernorm parameters (gains and biases) on each of the Q,K,V and mlp_in matrices which gives us $$8h$$ parameters 16 | 17 | Now that we have the parameters per block we simply have to add them up and multiply by the number of blocks l to obtain the total number of parameters. 18 | 19 | Then finally we add in the final layernorm with $$2h$$ parameters and the position embedding with $$sh$$ parameters and we are done. This gives a final equation of: 20 | 21 | $$\begin{align} 22 | \text{total params} = Vh + sh + 12lh^2 + 8hl + 2h 23 | \end{align}$$ 24 | 25 | ## MoE parameters 26 | 27 | The key difference between MoE and dense models is that each MLP layer instead becomes E parallel experts and each token is routed to a single (or many) experts. This means that the total parameter count of the model expands linearly in E but the FLOP cost to train and infer the model stays fixed at that of the forward pass of the original dense. MoEs are thus a possible way to infer much more efficiently than with dense models. The downsides of MoEs is that although they don’t use many FLOPs to infer, they nevertheless have a much larger amount of parameters that need to be kept in memory, and thus suffer from additional communication and storage overhead. 28 | 29 | Since a MoE is otherwise identical to a transformer except each MLP is copied E times, this means that we simply multiply the MLP component $$8lh^2$$ by the number of experts to get $$8Elh^2$$. There is also a routing layer of size $$E \times h$$ at each block to decide which expert to route the token to. This gives a total of $$Ehl$$ parameters due to the router. 
Thus, the updated equation for the MoE transformer reads:

$$\begin{align}
\text{total params} = Vh + sh + 4lh^2 + 8Elh^2 + Ehl + 8hl + 2h
\end{align}$$

## Transformer FLOPs

At the most coarse level, the approximate FLOP count for training is $$6 \times N \times D$$ where N is the number of parameters and D is the amount of data. The reasoning for this is that each pass through the network requires approximately 1 multiply and 1 add per parameter per datapoint, and that there are effectively three passes on every step: a forward pass, a backward pass, and a weight update.

The promise of MoE is that we only conditionally use some of the parameters per datapoint, resulting in an approximate FLOP count of $$6 \times \frac{1}{2}(N + \frac{N}{E}) \times D$$ if we assume that half of the parameters are from the MLP layers with MoEs (in practice for larger models this can be much more).

We can, of course, be much more specific, so let's break things down further. Here we make more use of the batch size b and the sequence length s. For simplicity, let's only consider the FLOPs from the MLP layers and the attention layers, since the additional parts of the transformer (the positional encoding, the embedding and unembedding layers, and the layernorms) become increasingly irrelevant for large models.

MLP layers:

For each MLP layer per token we perform $$2 \times 4h \times h$$ operations (the 2 comes from the multiply + accumulate). There are two MLP layers per MLP block (mlp_in and mlp_out). This gives us $$16bslh^2$$ FLOPs for the forward pass and thus $$32bslh^2$$ for the backward pass and weight update (approximately 2x as much), for a total of $$48bslh^2$$ FLOPs per step for the MLP layers.

QKVO attention FLOPs

The dense matrix operations in attention have a very similar structure to the MLPs. We have four matrices (Q, K, V, O) which are applied per token and each matrix is $$h \times h$$ in size. This gives us $$4 \times 2 \times h^2$$ FLOPs per layer per token and thus $$8bslh^2$$ FLOPs in total for a forward pass and $$24bslh^2$$ FLOPs for a step.

Attention scores and output

To compute the attention matrix, we multiply together two $$b \times s \times h$$ matrices to compute a $$b \times s \times s$$ output. The FLOP cost of this is approximately $$2 \times b \times h \times s \times s$$. The output of the attention (multiplication of V by the attention scores) has an equal cost. We ignore the cost of performing the softmax, although this may be nontrivial. This results in a total attention cost of $$4bhs^2$$ per layer and thus a per-step cost of $$12blhs^2$$ FLOPs across all l layers.

Putting this together, we see that the total FLOP cost of a transformer model can be estimated as:

$$\begin{align}
\text{FLOPs per step} = 12blhs^2 + 72bslh^2
\end{align}$$

Naively, we thus see that the MLP cost dominates at least as long as the embedding dimension h is larger than the sequence length s. For large enough s, however, the cost of attention begins to dominate due to its quadratic dependence on the sequence length.

## MoE FLOPs

The MoE replicates each MLP block into E parallel experts. This expands the number of parameters but does not appreciably change the FLOP cost of a step of the model, so long as only a single expert is used per token per MLP block.
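To make the bookkeeping above concrete, here is a minimal Python sketch of these counting rules. The function names and the example hyperparameters are illustrative placeholders rather than any particular model; the formulas themselves are exactly the ones derived in this section.

```
# Parameter and per-step FLOP counts, using the notation of this walkthrough:
# V = vocab size, h = hidden dim, l = depth, s = sequence length, b = batch size, E = experts.

def transformer_params(V, h, l, s):
    # Vh + sh + 12lh^2 + 8hl + 2h (tied embeddings, MLP expansion factor of 4)
    return V * h + s * h + 12 * l * h**2 + 8 * h * l + 2 * h

def moe_transformer_params(V, h, l, s, E):
    # Vh + sh + 4lh^2 + 8Elh^2 + Ehl + 8hl + 2h
    return V * h + s * h + 4 * l * h**2 + 8 * E * l * h**2 + E * h * l + 8 * h * l + 2 * h

def transformer_step_flops(b, s, h, l):
    # 12blhs^2 (attention scores and output) + 72bslh^2 (QKVO projections and MLPs)
    return 12 * b * l * h * s**2 + 72 * b * s * l * h**2

# Example: a hypothetical 4096-wide, 32-layer model with a 50k vocab and 2048-token context
V, h, l, s, b, E = 50_000, 4096, 32, 2048, 1, 8
print(f"dense params:     {transformer_params(V, h, l, s):.3e}")
print(f"MoE params (E=8): {moe_transformer_params(V, h, l, s, E):.3e}")
print(f"FLOPs per step:   {transformer_step_flops(b, s, h, l):.3e}")
```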
## Mamba parameters

We now consider computing the parameters and FLOPs of a single Mamba1 and Mamba2 block. First we handle parameters. For Mamba1 these calculations can also be found in the Appendix of our [BlackMamba paper](https://arxiv.org/abs/2402.01771).

### Mamba1 parameters

This refers to the original Mamba block as introduced [here](https://arxiv.org/pdf/2312.00752). A Mamba block takes input from the residual stream of size $$D \times L$$ where D is the embedding dimension and L is the sequence length. The Mamba block also has an internal expansion factor (typically of 2), so it operates in a larger internal embedding dimension which we denote I. The Mamba block also contains an internal causal-convolutional layer with kernel width C and an internal SSM state of size S. Finally, the Mamba layer has a dt rank, denoted dt, which controls the size of the small MLP used to set the token-wise time constant for the SSM.

The Mamba layer begins with two input projections of size $$D \times I$$ which map the residual stream input into the inner dimension: one for the SSM input itself and one for the gate input. This gives a total of $$2ID$$ in-projector parameters. After the in-projector there is a convolutional layer prior to the SSM. This requires $$C \times I$$ parameters plus an additional $$I$$ parameters for the convolutional bias.

This is then followed by the matrices producing the A, B, and C matrices of the SSM (similar to the QKV matrices of attention). Each of these is of size $$I \times S$$, resulting in $$3IS$$ parameters. Additionally, there is the dt projector which consists of $$2 \times dt \times I$$ parameters, as well as a dt bias and the D bias vector, both of length $$I$$. Finally, there is the SSM out-projector of size $$I \times D$$ which maps back to the embedding dimension of the residual stream, as well as the input layernorm which contains $$2 \times D$$ parameters since its gain and bias parameters are both of shape $$D$$.

Putting this all together, we obtain the following count of total parameters:

$$\begin{align}
\text{Total parameters} = 3ID + 3IS + 2I \cdot dt + IC + 3I + 2D
\end{align}$$

### Mamba2 parameters

The Mamba2 block (introduced [here](https://arxiv.org/abs/2405.21060)) makes a few modifications to the Mamba1 block to improve FLOP efficiency. These primarily consist of making the A matrix scalar instead of diagonal and making the B, C, and dt projections depend directly on the input from the residual stream instead of first passing through the convolution. It also introduces the notion of heads, similar to attention heads, and of groups, which are similar to the repeated heads in GQA. We denote the number of groups as G and the number of heads as H.

We begin by computing the in-projector as before. This consists of the input and gate projections of shape $$D \times I$$ each, the B and C projections of shape $$D \times GS$$ each, and the dt projection of shape $$D \times H$$ (dt is now a scalar per Mamba head). There are also the A and D matrices of shape $$H$$.

The in-projector is followed by the convolution, which is applied to the x, B, and C matrices. The total parameters for the convolution are thus $$(I + 2GS) \times C$$, plus a convolutional bias of shape $$I + 2GS$$. Following the convolution, unlike Mamba1, there is also an additional SSM internal layernorm which utilizes $$2I$$ parameters.
Following the SSM, there is then the out-projector matrix which is of size $$I \times D$$. Putting this all together, we obtain an expression for the parameters in a Mamba2 layer as:

$$\begin{align}
\text{Total parameters} = 3ID + 2DGS + DH + 2H + (I + 2GS)(1 + C) + I
\end{align}$$

## Mamba FLOPs

In general, given two matrices $$A^{K \times M}$$ and $$B^{M \times J}$$, the total FLOPs of computing their matrix product is $$2KMJ$$, where the 2 comes from the fact that there is both a multiply and an add operation.

Let us consider the in- and out-projectors of Mamba. These are matrices of shape $$I \times D$$ being multiplied with input of shape $$B \times L \times D$$, and there are three such matrix multiplications ($$W_x$$, $$W_z$$, and the out-projector $$W_y$$), resulting in $$6BLID$$ FLOPs. Next is the convolution, which can be treated as a single $$I \times C$$ matrix multiply requiring $$2BLIC$$ FLOPs.

Now, we turn to the SSM block itself. We first compute the input-dependent B and C matrices, each requiring a matrix multiply of shape $$I \times S$$, thus resulting in $$4BLIS$$ FLOPs. The A matrix is not multiplied by the input but goes through an elementwise transform costing $$IS$$ FLOPs. The dt projection first goes through an elementwise operation of order $$BLI \cdot dt$$ FLOPs.
Next, the discretization. The A matrix is multiplied by the dt vector, costing $$BLIS$$ FLOPs. The B matrix is multiplied by dt, costing $$2BLIS$$ FLOPs. The SSM linear state-space step itself is just a matrix multiply and add, so it costs $$2BLIS$$ FLOPs, and the output projection using the C matrix also costs $$2BLIS$$ FLOPs. The out-projector $$W_y$$ has already been counted above together with the in-projectors. Putting this all together, we obtain the following expression:

$$\begin{align}
\text{Total FLOPs} = BLI(6D + 2C + 11S + dt) + IS
\end{align}$$

## Mamba2 FLOPs

Computing the FLOPs of a Mamba2 block involves going through a similar exercise. First we consider the much-enlarged in-projector of Mamba2, which is of shape $$(2I + 2GS + H) \times D$$ and is multiplied by the embedding input of size $$B \times L \times D$$. This results in $$2BL(2I + 2GS + H)D$$ FLOPs. The in-projector output is then split and only the xBC matrix is passed through the conv, at a FLOP cost of $$2BL(I + 2GS)C$$. Following the conv there is the computation of the SSM state matrices and the multiplication by dt, which costs $$2BLIS$$ FLOPs, and the SSM computation itself, which also costs $$2BLIS$$ FLOPs. Finally, there is the multiplication by the C matrix which costs $$2BLIS$$, the multiplication by the gate which costs $$BLI$$, and the multiplication by the out-projector costing $$2BLID$$. Putting this all together we obtain:

$$\begin{align}
\text{Total FLOPs} = BL\Big(6ID + 4GSD + 2HD + 2IC + 4GSC + 6IS + I\Big)
\end{align}$$


## FLOP budgets

The way to think about the FLOP budget is to figure out how much sustained throughput (in TFLOP/s) you can get per GPU running the model and then how many days you can afford to train the model for. That is, we get

$$\begin{align}
\text{FLOP budget} = \text{TFLOPs per GPU} \times \text{NUM GPUs} \times \text{Days} \times 24 \times 60 \times 60
\end{align}$$

where we convert days into seconds since the throughput is measured per second.
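As a quick helper, here is a minimal sketch of this budget calculation; the function name and the cluster numbers are illustrative placeholders, matching the worked example that follows.

```
def flop_budget(tflops_per_gpu, num_gpus, days):
    # Total training budget in FLOPs, given sustained per-GPU throughput in TFLOP/s.
    seconds = days * 24 * 60 * 60
    return tflops_per_gpu * 1e12 * num_gpus * seconds

# e.g. a sustained 400 TFLOP/s on each of 64 H100s for 60 days
print(f"{flop_budget(400, 64, 60):.2e} FLOPs")  # ~1.33e+23 FLOPs
```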
As an example, let's suppose (optimistically) that we get a sustained 400 TFLOP/s per H100, we have 64 H100s, and we train for 60 days. This gives a budget of roughly $$1.3 \times 10^{23}$$ FLOPs.

With this number and our estimates of the FLOP count for a given model and dataset size, we can evaluate what model and dataset sizes are feasible to train given our FLOP budget. This gives us a space of possible models we can train. Within this space, we can use scaling laws (if they exist) to tell us what the optimal model and dataset size would be to achieve the lowest loss.


## Token Calculation

We provide a script at https://github.com/Zyphra/cookbook/tree/main/calc/data/tokenize_and_count.py that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.

### Requirements

- Python 3.6+
- transformers
- datasets

Install the required packages:

```
pip install transformers datasets
```

### Usage

Run the script from the command line with the following arguments:

```
python tokenize_and_count.py --hf-path <dataset_path> --hf-tokenizer <tokenizer_path> [OPTIONS]
```


#### Required Arguments:

- `--hf-path`: Path of the Hugging Face dataset
- `--hf-tokenizer`: Path of the Hugging Face tokenizer

#### Optional Arguments:

- `--hf-dir`: Directory in the Hugging Face dataset (default: None)
- `--key`: Name of the column that contains text to tokenize (default: 'text')
- `--save-path`: Folder to save processed Hugging Face dataset to (default: None)
- `--num-proc`: Number of processes for Hugging Face processing (default: 1)

### Example

```
python tokenize_and_count.py --hf-path "dataset/my_dataset" --hf-tokenizer "bert-base-uncased" --key "content" --save-path "./tokenized_dataset" --num-proc 4
```


This command will:
1. Load the dataset from "dataset/my_dataset"
2. Use the "bert-base-uncased" tokenizer
3. Tokenize the "content" column of the dataset
4. Use 4 processes for parallel processing
5. Save the tokenized dataset to "./tokenized_dataset"

The script will output the total number of tokens in the dataset and save the tokenized dataset if a save path is provided.
--------------------------------------------------------------------------------
/calc/calc_mamba_flops.py:
--------------------------------------------------------------------------------
import argparse
import math

# Helper function to pretty-print FLOP counts
def convert_flops(params):
    if params == 0:
        return "0"
    size_name = ("", "KFLOPs", "MFLOPs", "GFLOPs", "TFLOPs", "PFLOPs", "EFLOPs", "ZFLOPs", "YFLOPs")
    i = int(math.floor(math.log(params, 1000)))
    p = math.pow(1000, i)
    s = round(params / p, 2)
    return "%s %s" % (s, size_name[i])

def config_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab-size", "-v",
                        type=int,
                        default=51200,
                        help='Size of the vocab')
    parser.add_argument("--hidden-size", "-hs",
                        type=int,
                        default=768,
                        help='Dimension of the model\'s hidden size')
    parser.add_argument("--sequence-length", "-s",
                        type=int,
                        default=2048,
                        help='Sequence length used for training')
    # batch size and KV ratio are referenced by the attention/FFN FLOP helpers below
    parser.add_argument("--batch-size", "-b",
                        type=int,
                        default=1,
                        help='Batch size used for training')
    parser.add_argument("--kv-size-ratio", "-kv",
                        type=float,
                        default=1.0,
                        help='Ratio of KV heads to query heads for GQA (1.0 corresponds to standard MHA)')
    parser.add_argument("--num-mamba-layers",
                        type=int,
                        default=0,
                        help='Number of mamba layers used in model')
    parser.add_argument("--state-size",
                        type=int,
                        default=16,
                        help='State dimension')
    parser.add_argument("--expansion-factor",
                        type=int,
                        default=2,
                        help='Expansion factor relating inner dimension and hidden size (or d_model)')
    parser.add_argument("--conv-dimension",
                        type=int,
                        default=4,
                        help='Dimension of convolution kernel')
    parser.add_argument("--dt-rank",
                        type=str,
                        default="auto",
                        help='Rank of dt')
    parser.add_argument("--mamba-ngroups",
                        type=int,
                        default=1,
                        help='Number of Mamba groups')
    parser.add_argument("--mamba-headdim",
                        type=int,
                        default=64,
                        help='Mamba2 head dimension')
    parser.add_argument("--num-moe-layers", "-l",
                        type=int,
                        default=0,
                        help='Number of moe layers used in model')
    parser.add_argument("--num-experts", "-e",
                        type=int,
                        default=0,
                        help='Number of experts for MoE')
    parser.add_argument("--ffn-hidden-size",
                        type=int,
                        default=None,
                        help='Hidden dimension of the MLP')
    parser.add_argument("--topk", "-t",
                        type=int,
                        default=1,
                        help='Top k routing for MoE')
    parser.add_argument("--swiglu",
                        action="store_true",
                        help='Use swiglu MLP.
If set, ffn-hidden-size is defined as the inner dimension of each of the three MLP weights.')
    parser.add_argument("--tokens",
                        type=int,
                        default=None,
                        help='Number of tokens you are training over')
    parser.add_argument("--no-checkpoint-activations", "-ca",
                        action='store_false',
                        help='Whether activation checkpointing is being used',
                        dest='checkpoint_activations')
    parser.add_argument("--mamba_moe_layers",
                        type=str,
                        default="",
                        help='Space-separated layer codes: "r" = Mamba1, "m" = Mamba2, "a" = attention, '
                             'an integer = MoE FFN with that many experts, "g" = Zamba-style shared attention+FFN block')
    return parser

def compute_mamba2_flops(args):
    d_inner = args.hidden_size * args.expansion_factor
    Nheads = d_inner // args.mamba_headdim
    mamba2_block_flops = 2 * (2 * d_inner + 2 * args.mamba_ngroups * args.state_size + Nheads) * args.hidden_size * args.batch_size * args.sequence_length  # in proj
    mamba2_block_flops += 2 * args.batch_size * args.sequence_length * (d_inner + 2 * args.mamba_ngroups * args.state_size) * args.conv_dimension * d_inner  # conv
    mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size * d_inner  # dtbx
    mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size  # ssm state rollover
    mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size * args.hidden_size  # c -> y
    mamba2_block_flops += args.batch_size * args.sequence_length * args.hidden_size  # z gate output
    return mamba2_block_flops

def compute_mamba_flops(args, iter_factor):
    d_inner = args.hidden_size * args.expansion_factor
    dt_rank = math.ceil(args.hidden_size / 16) if args.dt_rank == "auto" else int(args.dt_rank)
    ssm_flops = iter_factor * d_inner * args.tokens * (11 * args.state_size + 4 * dt_rank + 1)
    mamba_projectors_flops = iter_factor * args.tokens * 6 * d_inner * args.hidden_size
    mamba_conv_flops = iter_factor * args.tokens * 2 * d_inner * args.conv_dimension
    mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
    return mamba_flops

def compute_attention_flops(args, iter_factor):
    qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.batch_size * args.hidden_size * args.hidden_size)
    attention_matrix_flops = iter_factor * 2 * args.batch_size * args.sequence_length * args.hidden_size
    attention_over_values_flops = iter_factor * 2 * args.batch_size * args.sequence_length * args.hidden_size
    linear_projection_flops = iter_factor * 2 * args.batch_size * args.hidden_size * args.hidden_size
    return qkv_flops + attention_matrix_flops + attention_over_values_flops + linear_projection_flops

def compute_ffn_flops(args):
    if args.ffn_hidden_size is not None:
        ffn_flops = int(2 * args.batch_size * args.ffn_hidden_size * args.ffn_hidden_size)
    else:
        # assume the standard 4x MLP expansion when no explicit FFN hidden size is given
        ffn_flops = 2 * 4 * args.batch_size * args.hidden_size * args.hidden_size
    return ffn_flops



# calculates the flops of a model given its hparams
def calc_flops(args):
    if args.num_experts > 1:
        assert args.topk <= args.num_experts, "You cannot route to more experts than you have!"


    # An A_(m x k) X B_(k x n) matrix multiplication requires 2m x k x n FLOPs (factor of 2 needed to account for multiplies and adds)

    # determine the flops factor.
    # If no activation checkpointing/recomputation, 1 for fwd and 2 for bwd (because we need to calculate the grads with respect to both the input and weight tensors).
    # If activation checkpointing/recomputation, add 1 more for the next full forward pass
    iter_factor = 3
    if args.checkpoint_activations:
        iter_factor += 1
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size

    mamba_flops = compute_mamba_flops(args, iter_factor)
    mamba2_flops = compute_mamba2_flops(args)
    attention_flops = compute_attention_flops(args, iter_factor)
    ffn_flops = compute_ffn_flops(args)

    total_mamba_flops = 0
    total_mamba2_flops = 0
    total_attention_flops = 0
    total_ffn_flops = 0


    # no activation checkpointing for embeddings
    embedding_flops = 6 * args.tokens * args.hidden_size * args.vocab_size

    if args.mamba_moe_layers == "":
        # assume a pure mamba1 model unless specified otherwise
        total_flops = embedding_flops + (mamba_flops * args.num_mamba_layers)
        total_mamba_flops += mamba_flops * args.num_mamba_layers
        # if MoE layers are present, add these in
        if args.num_moe_layers > 0:
            ffn_flops = iter_factor * args.tokens * 4 * args.ffn_hidden_size * args.num_moe_layers * args.hidden_size
            if args.swiglu:
                ffn_flops = 3 / 2 * ffn_flops
            gating_flops = iter_factor * 2 * args.tokens * args.num_experts * args.num_moe_layers
            total_flops += ffn_flops + gating_flops
            total_ffn_flops += ffn_flops

    else:
        arch_list = args.mamba_moe_layers.split(" ")
        total_flops = 0
        total_flops += embedding_flops
        for el in arch_list:
            if el == "r":
                # mamba1 layer
                total_flops += mamba_flops
                total_mamba_flops += mamba_flops
            elif el == "m":
                # mamba2 layer
                total_flops += mamba2_flops
                total_mamba2_flops += mamba2_flops
            elif el == "a":
                # attention layer
                total_flops += attention_flops
                total_attention_flops += attention_flops
            elif el.isnumeric():
                # MoE FFN layer
                total_flops += ffn_flops
                total_ffn_flops += ffn_flops
            elif el == "g":
                # zamba shared layer: attention + FFN operating on the concatenated (2x) hidden state
                original_hidden_size = args.hidden_size
                args.hidden_size = original_hidden_size * 2
                shared_attention_flops = compute_attention_flops(args, iter_factor)
                shared_ffn_flops = compute_ffn_flops(args)
                total_flops += shared_attention_flops
                total_attention_flops += shared_attention_flops
                total_flops += shared_ffn_flops
                total_ffn_flops += shared_ffn_flops
                args.hidden_size = original_hidden_size
                # final downprojector matrix
                total_flops += 4 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size
            else:
                raise ValueError("Invalid layer string: " + str(el) + " not recognized.")

    total_flops *= iter_factor

    print(f'Calculating number of FLOPs with training configuration: {vars(args)}\n')
    print(f'Total Mamba FLOPs: {convert_flops(total_mamba_flops)}')
    print(f'Total Mamba2 FLOPs: {convert_flops(total_mamba2_flops)}')
    print(f'Total Attention FLOPs: {convert_flops(total_attention_flops)}')
    print(f'Total FFN FLOPs: {convert_flops(total_ffn_flops)}')
    print(f'Embedding FLOPs: {convert_flops(embedding_flops)}')
    print(f'Total FLOPs for the Model: {convert_flops(total_flops)}')
    if args.tokens is not None:
        total_flops_through_training = int(total_flops * (args.tokens // args.batch_size))
        print(f'Total FLOPs through training: {convert_flops(total_flops_through_training)}')

if __name__ == "__main__":
    print('\nExample: python calc_mamba_flops.py --num-mamba-layers 12 -hs 768 --num-experts
8 --num-moe-layers 12 -s 2048 --tokens 300000000000')

    args = config_parser().parse_args()
    calc_flops(args)
--------------------------------------------------------------------------------
/calc/calc_mamba_params.py:
--------------------------------------------------------------------------------
import argparse
import math

# Helper function to pretty-print parameter counts
def convert_params(params):
    if params == 0:
        return "0"
    size_name = ("", "K", "M", "B", "T", "P", "E", "Z", "Y")
    i = int(math.floor(math.log(params, 1000)))
    p = math.pow(1000, i)
    s = round(params / p, 2)
    return "%s %s" % (s, size_name[i])

def config_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab-size", "-v",
                        type=int,
                        default=50277,
                        help='Size of the vocab')
    parser.add_argument("--d-model", "-dm",
                        type=int,
                        default=768,
                        help='Embedding dimension')
    parser.add_argument("--d-state", "-ds",
                        type=int,
                        default=16,
                        help='Hidden state dimension')
    parser.add_argument("--d-conv", "-dc",
                        type=int,
                        default=4,
                        help='conv1d kernel size')
    parser.add_argument("--mamba-ngroups",
                        type=int,
                        default=1,
                        help='Number of Mamba groups')
    parser.add_argument("--mamba-headdim",
                        type=int,
                        default=64,
                        help='Mamba2 head dimension')
    parser.add_argument("--expand", "-ex",
                        type=int,
                        default=2,
                        help='Inner state expansion factor')
    parser.add_argument("--dt-rank", "-dr",
                        type=int,
                        default=-1,
                        help='Rank of the delta. Default is -1, which means auto')
    parser.add_argument("--num-layers", "-l",
                        type=int,
                        default=24,
                        help='Total number of sequential layers used in model (both Mamba and MoE)')
    parser.add_argument("--num-experts", "-e",
                        type=int,
                        default=0,
                        help="Number of experts used in model")
    parser.add_argument("--ffn-expansion-factor", '-fe',
                        type=float,
                        default=4,
                        help="Expansion factor of the ffn hidden dimension")
    parser.add_argument("--expert-interval", "-ei",
                        type=int,
                        default=1,
                        help="Every N blocks to put an MoE block")
    parser.add_argument("--parallel-moe", "-p",
                        type=bool,
                        default=False,
                        help="Run the MoE MLPs in parallel with the mamba blocks like gptj")
    parser.add_argument("--swiglu",
                        action="store_true",
                        help='Use swiglu MLP.
If set, ffn-hidden-size needs to be specified and is defined as the inner dimension of each of the three MLP weights.') 71 | parser.add_argument("--ffn-hidden-size", 72 | type=int, 73 | default=0, 74 | help="Hidden dimension of the MLP") 75 | parser.add_argument("--mamba_moe_layers", type = str, default = "") 76 | parser.add_argument("--use-global-mem", type=bool, default = False) 77 | parser.add_argument("--global-memory-projection-interval", type=int, default = 2) 78 | 79 | return parser 80 | 81 | def compute_mamba_block_params(args): 82 | d_inner = args.d_model * args.expand 83 | dt_rank = math.ceil(args.d_model / 16) if args.dt_rank < 1 else args.dt_rank 84 | 85 | 86 | mamba_block_params = (d_inner * args.d_model) # W_x 87 | mamba_block_params += (d_inner * args.d_model) # W_z 88 | mamba_block_params += (args.d_conv * d_inner) + d_inner # conv1d 89 | mamba_block_params += (args.d_state * d_inner) # W_B 90 | mamba_block_params += (args.d_state * d_inner) # W_C 91 | mamba_block_params += 2 * (dt_rank * d_inner) + d_inner # W_dt 92 | mamba_block_params += (d_inner * args.d_state) # W_A 93 | mamba_block_params += (d_inner) # D 94 | mamba_block_params += (args.d_model * d_inner) # W_y 95 | mamba_block_params += 2 * args.d_model # LayerNorm 96 | return mamba_block_params, dt_rank 97 | 98 | def compute_mamba2_block_params(args): 99 | d_inner = args.d_model * args.expand 100 | n_heads = d_inner / args.mamba_headdim 101 | d_in_proj = (2 * d_inner) + (2 * args.mamba_ngroups * args.d_state) + n_heads 102 | mamba2_block_params = args.d_model * d_in_proj # W_in 103 | mamba2_block_params += 3 * n_heads # A, dt, D 104 | mamba2_block_params += (d_inner + (2 * args.mamba_ngroups * args.d_state)) * args.d_conv # conv weight 105 | mamba2_block_params += d_inner + (2 * args.mamba_ngroups * args.d_state) # conv bias 106 | mamba2_block_params += d_inner # layernorm 107 | mamba2_block_params += d_inner * args.d_model # W_out 108 | return mamba2_block_params 109 | 110 | # calculates the params of a model given their hparams 111 | def calc_params(args): 112 | if args.swiglu: 113 | assert args.ffn_hidden_size > 0, "If args.swiglu=True, ffn-hidden-size needs to be specified." 
114 | # Embedding unembedding weights are tied 115 | embedding_params = args.d_model * args.vocab_size 116 | attention_block_params = 4 * args.d_model * args.d_model 117 | 118 | mamba_block_params, dt_rank = compute_mamba_block_params(args) 119 | mamba2_block_params = compute_mamba2_block_params(args) 120 | 121 | 122 | 123 | ffn_dim = args.d_model * args.ffn_expansion_factor 124 | ffn_block_params = 2 * ffn_dim * args.d_model 125 | 126 | if args.mamba_moe_layers == "": 127 | if args.num_experts == 0: 128 | # pure mamba 129 | total_ffn_params = 0 130 | ffn_block_params = 0 131 | total_expert_params = 0 132 | total_params = args.num_layers * mamba_block_params + embedding_params 133 | mamba_block_params = total_params 134 | 135 | else: 136 | if not args.parallel_moe: 137 | mamba_block_params = int(round((args.num_layers * mamba_block_params) * (1 - (1/args.expert_interval)))) 138 | else: 139 | 140 | mamba_block_params = int(round((args.num_layers * mamba_block_params))) 141 | 142 | if args.swiglu: 143 | ffn_block_params = 3 * args.ffn_hidden_size * args.d_model 144 | if not args.parallel_moe: 145 | total_ffn_params = (args.num_layers // args.expert_interval) * ffn_block_params 146 | else: 147 | total_ffn_params = args.num_layers * ffn_block_params 148 | total_expert_params = total_ffn_params * args.num_experts 149 | total_params = mamba_block_params + total_expert_params + embedding_params 150 | forward_pass_params = mamba_block_params + total_ffn_params + embedding_params 151 | else: 152 | arch_list = args.mamba_moe_layers.split(" ") 153 | len_list = len(arch_list) 154 | assert len_list == args.num_layers, "Length of mamba moe list is not the same as the total number of layers" 155 | total_params = 0 156 | total_params += embedding_params 157 | forward_pass_params = 0 158 | forward_pass_params += embedding_params 159 | total_attention_params = 0 160 | total_mamba_params = 0 161 | total_ffn_params = 0 162 | for el in arch_list: 163 | if el == 'r': 164 | total_params += mamba_block_params 165 | forward_pass_params += mamba_block_params 166 | total_mamba_params += mamba_block_params 167 | elif el == 'm': 168 | total_params += mamba2_block_params 169 | forward_pass_params += mamba2_block_params 170 | total_mamba_params += mamba2_block_params 171 | elif el == 'a': 172 | total_params += attention_block_params 173 | forward_pass_params += attention_block_params 174 | total_attention_params += attention_block_params 175 | elif el.isnumeric(): 176 | num_experts = int(el) 177 | total_params += num_experts * ffn_block_params 178 | forward_pass_params += ffn_block_params 179 | total_ffn_params += num_experts * ffn_block_params 180 | else: 181 | raise ValueError("Invalid layers string") 182 | 183 | if args.use_global_mem: 184 | # we add a transformer layer and an MLP layer to the model plus we add a d^2 linear layer to each layer 185 | global_mem_params = 0 186 | global_mem_params += 3 * (2 * args.d_model)**2 + 2 * args.d_model**2 # qkv act on 2d_model, output proj maps 2d_model -> d_model 187 | global_mem_params += 12 * args.d_model # the 12 is because we use swiglu without the 2/3 resizing of the ffn hidden dimension 188 | global_mem_projection_params = (args.num_layers // args.global_memory_projection_interval) * args.d_model * args.d_model 189 | total_params += global_mem_params 190 | total_params += global_mem_projection_params 191 | 192 | 193 | 194 | if args.mamba_moe_layers == "": 195 | print(f'Calculating number of parameters with training configuration: {vars(args)}\n') 196 | print(f'dt_rank: 
{convert_params(dt_rank)}') 197 | print(f'Embedding parameters: {convert_params(embedding_params)}') 198 | print(f'Single Mamba block params: {convert_params(mamba_block_params)}') 199 | print(f'FFN block params: {convert_params(ffn_block_params)}') 200 | print(f'Total Mamba Params: {convert_params(mamba_block_params)}') 201 | print(f"Total FFN params: {convert_params(total_expert_params)}") 202 | 203 | else: 204 | print(f'Calculating number of parameters with training configuration: {vars(args)}\n') 205 | print(f'Embedding parameters: {convert_params(embedding_params)}') 206 | print(f'Total Mamba Params: {convert_params(total_mamba_params)}') 207 | print(f'Total Attention Params: {convert_params(total_attention_params)}') 208 | print(f"Total FFN params: {convert_params(total_ffn_params)}") 209 | 210 | if args.use_global_mem: 211 | print(f'Global Memory Parameters: {convert_params(global_mem_params)}') 212 | print(f'Global Memory Projection Params: {convert_params(global_mem_projection_params)}') 213 | 214 | print(f'Total params: {convert_params(total_params)}') 215 | print("Aspect Ratio: ", args.d_model / args.num_layers) 216 | if args.num_experts > 0: 217 | print(f'Forward pass params: {convert_params(forward_pass_params)}') 218 | 219 | 220 | 221 | if __name__ == "__main__": 222 | args = config_parser().parse_args() 223 | calc_params(args) 224 | -------------------------------------------------------------------------------- /calc/data/convert_into_jsonl_partitions.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import argparse 3 | import datasets 4 | import os 5 | import more_itertools 6 | 7 | import logging 8 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 9 | 10 | 11 | def save_shard(data, path, total, indices): 12 | for idx in indices: 13 | shard = data.shard(num_shards=total, index=idx, contiguous=True) 14 | save_path = os.path.join(path, f"partition_{idx}.jsonl") 15 | shard.to_json(save_path) 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--hf-path", type=str, required=True, help="Path of HF dataset") 21 | parser.add_argument("--hf-dir", type=str, default=None, help="Dir in HF dataset") 22 | parser.add_argument("--save-path", type=str, required=True, help="Folder to save partitioned JSONL dataset to") 23 | parser.add_argument("--num-proc", type=int, default=1, help="Number of processes for HF processing") 24 | parser.add_argument("--num-partitions", type=int, default=1, help="Number of partitions to split the dataset into") 25 | args = parser.parse_args() 26 | 27 | logging.info("Loading the dataset") 28 | ds = datasets.load_dataset( 29 | path=args.hf_path, 30 | data_dir=args.hf_dir, 31 | num_proc=args.num_proc, 32 | split="train", 33 | trust_remote_code=True 34 | ) 35 | 36 | logging.info("Saving JSONL partitions") 37 | n_proc = min(args.num_proc, args.num_partitions) 38 | inds_distr = more_itertools.distribute(n_proc, range(args.num_partitions)) 39 | processes = [] 40 | for process_inds in inds_distr: 41 | p = mp.Process(target=save_shard, args=(ds, args.save_path, args.num_partitions, process_inds)) 42 | processes.append(p) 43 | p.start() 44 | for p in processes: 45 | p.join() 46 | 47 | logging.info("Done!") 48 | -------------------------------------------------------------------------------- /calc/data/tokenize_and_count.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Token Calculation Script 3 | 4 | This script tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, 5 | and optionally saves the tokenized dataset. 6 | 7 | It uses the Hugging Face Transformers library for tokenization and the Datasets library for data handling. 8 | """ 9 | 10 | from typing import Dict, List 11 | from collections import defaultdict 12 | from transformers import AutoTokenizer 13 | import argparse 14 | import datasets 15 | 16 | import logging 17 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 18 | 19 | def tokenize( 20 | batch, 21 | tokenizer, 22 | key: str = "text", 23 | ) -> Dict[str, List]: 24 | """ 25 | Tokenize a batch of texts using the provided tokenizer. 26 | 27 | Args: 28 | batch: A dictionary containing the batch of data. 29 | tokenizer: The tokenizer to use for encoding the text. 30 | key: The key in the batch dictionary that contains the text to tokenize. 31 | 32 | Returns: 33 | A dictionary with the tokenized texts and their token counts. 34 | """ 35 | texts = batch[key] 36 | features = defaultdict(list) 37 | for text in texts: 38 | tokenized_text = tokenizer.encode(text) 39 | features[f"tokenized_{key}"].append(tokenized_text) 40 | features["n_tokens"].append(len(tokenized_text)) 41 | return features 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser(description="Tokenize text data and calculate token count.") 45 | parser.add_argument("--hf-path", type=str, required=True, help="Path of HF dataset") 46 | parser.add_argument("--hf-dir", type=str, default=None, help="Dir in HF dataset") 47 | parser.add_argument("--hf-tokenizer", type=str, required=True, help="Path of HF tokenizer") 48 | parser.add_argument("--key", type=str, default='text', help="Name of the column that contains text to tokenize") 49 | parser.add_argument("--save-path", type=str, help="Folder to save processed HF dataset to") 50 | parser.add_argument("--num-proc", type=int, default=1, help="Number of processes for HF processing") 51 | args = parser.parse_args() 52 | 53 | logging.info("Loading the dataset") 54 | ds = datasets.load_dataset(path=args.hf_path, data_dir=args.hf_dir) 55 | 56 | logging.info("Loading the tokenizer") 57 | tokenizer = AutoTokenizer.from_pretrained(args.hf_tokenizer) 58 | 59 | logging.info("Tokenizing the dataset") 60 | ds_tok = ds.map( 61 | lambda batch: tokenize(batch, tokenizer, key=args.key), 62 | batched=True, 63 | num_proc=args.num_proc, 64 | ) 65 | 66 | logging.info("Computing total number of tokens") 67 | n_tok = sum(ds_tok["n_tokens"]) 68 | logging.info(f"Total number of tokens: {n_tok}") 69 | 70 | if args.save_path: 71 | logging.info("Saving tokenized dataset") 72 | ds_tok.save_to_disk(args.save_path) 73 | -------------------------------------------------------------------------------- /imgs/annealing-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/imgs/annealing-example.png -------------------------------------------------------------------------------- /imgs/mamba-moe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/imgs/mamba-moe.png -------------------------------------------------------------------------------- /imgs/mamba.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/imgs/mamba.png -------------------------------------------------------------------------------- /imgs/transformer-moe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/imgs/transformer-moe.png -------------------------------------------------------------------------------- /imgs/transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/imgs/transformer.png -------------------------------------------------------------------------------- /imgs/zamba-7b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/imgs/zamba-7b.png -------------------------------------------------------------------------------- /imgs/zamba2-1p2b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/imgs/zamba2-1p2b.png -------------------------------------------------------------------------------- /imgs/zamba2-2p7b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/imgs/zamba2-2p7b.png -------------------------------------------------------------------------------- /imgs/zcookbook.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zyphra/zcookbook/293c813f9e3e428c044f2aefc801641997657683/imgs/zcookbook.jpg --------------------------------------------------------------------------------