├── .gitignore ├── LICENSE ├── README.md ├── demo.ipynb ├── mamba2.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | __pycache__/ 3 | .ipynb_checkpoints/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
 We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2024 Thomas Ip 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mamba2-minimal 2 | 3 | A minimal, single-file implementation of the Mamba-2 model in PyTorch. 4 | 5 | ![Mamba-2](https://github.com/state-spaces/mamba/blob/f9dbb4fdb2705d71282e0db184d177c6375623f0/assets/ssd_algorithm.png) 6 | > **Transformers are SSMs: Generalized Models and Efficient Algorithms**\ 7 | > **Through Structured State Space Duality**\ 8 | > Tri Dao*, Albert Gu*\ 9 | > Paper: https://arxiv.org/abs/2405.21060 10 | 11 | Mamba is a new class of foundation models, most notable for _not_ being based on the Transformer architecture. Instead, it belongs to the family of State Space Models (SSMs), which map a sequence through a hidden state in the fashion of RNNs. This approach enables linear scaling in computation and memory with respect to sequence length during training (unlike the quadratic complexity of Transformer attention), as well as constant time per step during inference.
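The sketch below illustrates the recurrence by running a single-head selective SSM one step at a time in plain Python. It assumes the Mamba-2 form used in `mamba2.py`: a scalar decay `exp(A * dt)` per head, an input scaled by `dt`, and a skip term `D * x`. The helper names `ssm_step` and `ssm_scan` are hypothetical and are not part of this repo. The state `h` has a fixed `(headdim, d_state)` shape, so each step costs the same no matter how many tokens came before it.

```py
import math


def ssm_step(h, x, dt, A, B, C, D=0.0):
    """One recurrence step for a single head (hypothetical helper, illustration only).

    h: P x N state as nested lists, x: length-P input, dt: positive step size,
    A: negative scalar, B and C: length-N vectors, D: scalar skip weight.
    """
    decay = math.exp(A * dt)  # lies in (0, 1) because A < 0 and dt > 0
    # h_t = exp(A * dt) * h_{t-1} + dt * outer(x_t, B_t)
    h = [[decay * h_pn + dt * x_p * b for h_pn, b in zip(h_p, B)] for h_p, x_p in zip(h, x)]
    # y_t = h_t @ C_t + D * x_t
    y = [sum(h_pn * c for h_pn, c in zip(h_p, C)) + D * x_p for h_p, x_p in zip(h, x)]
    return h, y


def ssm_scan(xs, dts, A, Bs, Cs, D=0.0):
    """Apply ssm_step over a whole sequence, starting from a zero state."""
    P, N = len(xs[0]), len(Bs[0])
    h = [[0.0] * N for _ in range(P)]  # fixed-size state, independent of sequence length
    ys = []
    for x, dt, B, C in zip(xs, dts, Bs, Cs):
        h, y = ssm_step(h, x, dt, A, B, C, D)
        ys.append(y)
    return ys, h


# An impulse decays geometrically: with A * dt = -ln 2, each step halves the state.
ys, h = ssm_scan([[1.0], [0.0], [0.0]], [1.0] * 3, -math.log(2), [[1.0]] * 3, [[1.0]] * 3)
print([round(y[0], 6) for y in ys])  # [1.0, 0.5, 0.25]
```

This sequential loop is only a slow reference form. `mamba2.py` instead passes whole sequences to `ssd(...)` in chunks of `chunk_size`, which should compute the same recurrence for efficient training. It falls back to a per-token `step` only when an inference cache is supplied.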
 Mamba-2 builds upon Mamba-1 by restricting the SSM state matrix A to a scalar (times the identity) per head, which allows much larger state dimensions and significantly faster training. 12 | 13 | This implementation is device-agnostic and has been tested on the CPU and MPS (Metal Performance Shaders) backends. The model's output logits follow the same distribution as those of the reference implementation but are not numerically identical. 14 | 15 | ## Usage 16 | 17 | Install dependencies (`torch`, `einops` and `transformers`): 18 | 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | **See [demo.ipynb](./demo.ipynb) for using Mamba-2 as part of an end-to-end language model with pretrained weights for text generation.** 24 | 25 | The core Mamba-2 model can be used as follows: 26 | 27 | ```py 28 | import torch 29 | 30 | from mamba2 import Mamba2, Mamba2Config 31 | 32 | config = Mamba2Config(d_model=768) 33 | model = Mamba2(config) # note: A_log, dt_bias and D start uninitialized (torch.empty); load or initialize them before relying on the output 34 | 35 | x = torch.randn(2, 64, 768) # (batch, seqlen, d_model); seqlen must be a multiple of chunk_size (default 64) 36 | y, h = model(x) # y has the same shape as x; h is the InferenceCache 37 | ``` 38 | 39 | ## TODOs 40 | 41 | - [x] Constant time (wrt sequence length) autoregressive inference 42 | - [ ] Remove dependency on `einops` (depends on whether resulting code is still readable) 43 | 44 | ## Credits 45 | 46 | * [Albert Gu], [Tri Dao] - authors of the Mamba-2 architecture 47 | * [John Ma] - author of [johnma2006/mamba-minimal], who inspired this repo 48 | 49 | ## Resources 50 | 51 | Some resources to understand Mamba and SSMs.
52 | 53 | * [Mamba-1/2 reference implementation] 54 | * [Mamba-1 paper] 55 | * [Mamba-2 paper] 56 | * [The Annotated S4] (literate programming for the S4 model) 57 | * [Mamba-2 blog post] 58 | 59 | [Albert Gu]: https://github.com/albertfgu 60 | [Tri Dao]: https://github.com/tridao 61 | [John Ma]: https://github.com/johnma2006 62 | [johnma2006/mamba-minimal]: https://github.com/johnma2006/mamba-minimal 63 | [Mamba-1 paper]: https://arxiv.org/abs/2312.00752 64 | [Mamba-2 paper]: https://arxiv.org/abs/2405.21060 65 | [The Annotated S4]: https://srush.github.io/annotated-s4/ 66 | [Mamba-2 blog post]: https://tridao.me/blog/2024/mamba2-part1-model/ 67 | [Mamba-1/2 reference implementation]: https://github.com/state-spaces/mamba 68 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "64b6c94f-3e79-46b7-b116-7027966777f8", 6 | "metadata": {}, 7 | "source": [ 8 | "# Mamba-2 Language Model demo" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "fa052505-d91c-4e87-8daa-2b00ad8cc881", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "id": "a74ee1cb-b4b2-46a8-98a4-1dd845c1e5ab", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import time\n", 30 | "\n", 31 | "import torch\n", 32 | "from transformers import AutoTokenizer\n", 33 | "\n", 34 | "from mamba2 import Mamba2LMHeadModel\n", 35 | "\n", 36 | "if torch.cuda.is_available():\n", 37 | " device = torch.device('cuda')\n", 38 | "elif torch.backends.mps.is_available():\n", 39 | " device = torch.device('mps')\n", 40 | "else:\n", 41 | " device = torch.device('cpu')" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": 
"a59ab109-2cbe-4f7a-b5ce-a58b8860f98c", 47 | "metadata": {}, 48 | "source": [ 49 | "Official pretrained models on [Hugging Face](https://huggingface.co/state-spaces):\n", 50 | "* `state-spaces/mamba2-130m`\n", 51 | "* `state-spaces/mamba2-370m`\n", 52 | "* `state-spaces/mamba2-780m`\n", 53 | "* `state-spaces/mamba2-1.3b`\n", 54 | "* `state-spaces/mamba2-2.7b`\n", 55 | "\n", 56 | "Choose a model depending on available system RAM (for CPU or systems with unified memory) or VRAM.\n", 57 | "\n", 58 | "Note that these are base models without fine-tuning for downstream tasks such as chat or instruction following." 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "id": "b6569ffd-993f-4d5b-9094-902801fe6c14", 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stderr", 69 | "output_type": "stream", 70 | "text": [ 71 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "model = Mamba2LMHeadModel.from_pretrained(\"state-spaces/mamba2-1.3b\", device=device)\n", 77 | "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neox-20b\")\n", 78 | "tokenizer.pad_token_id = tokenizer.eos_token_id" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "id": "bb837263-8a1f-40bf-a9b1-fce72225a674", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "generation_config = dict(\n", 89 | " max_new_length=200,\n", 90 | " temperature=1.0,\n", 91 | " top_k=30,\n", 92 | " top_p=1.0,\n", 93 | ")" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "id": "87006a5d-7992-4026-9b40-36cbc3ebf8dc", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "def generate(prompt: str, seed: int = 0, show_perf: bool = True):\n", 104 | " \"\"\"Generate streaming completion\"\"\"\n", 105 | " torch.manual_seed(seed)\n", 106 | "\n", 107 | " input_ids = tokenizer(prompt, 
return_tensors='pt').input_ids.to(device)[0]\n", 108 | " print(prompt, end=\"\")\n", 109 | "\n", 110 | " start = time.process_time()\n", 111 | " n_generated = 0\n", 112 | " for i, (token_id, _hidden_state) in enumerate(model.generate(input_ids, **generation_config)):\n", 113 | " token = tokenizer.decode([token_id])\n", 114 | " if i == 0:\n", 115 | " now = time.process_time()\n", 116 | " prompt_eval_elapsed, start = now - start, now\n", 117 | " else:\n", 118 | " n_generated += 1\n", 119 | " print(token, end=\"\", flush=True)\n", 120 | " if show_perf:\n", 121 | " elapsed = time.process_time() - start\n", 122 | " print('\\n\\n---')\n", 123 | " print(f'Prompt eval | tokens: {input_ids.shape[0]} | elapsed: {prompt_eval_elapsed:.2f}s | tok/s: {input_ids.shape[0] / prompt_eval_elapsed:.2f}')\n", 124 | " print(f'Generation | tokens: {n_generated} | elapsed: {elapsed:.2f}s | tok/s: {n_generated / elapsed:.2f}')" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 6, 130 | "id": "4b926b16-2883-4eef-9459-3718498409e6", 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "Mamba is a new state space model architecture that enables the modeling of discrete events in humanoid robots with simple and intuitive syntax.\n", 138 | "\n", 139 | "The Mamba state model is based on the state space model architecture of the state machine.\n", 140 | "Mamba enables fast and intuitive specification of the state transitions, without requiring any experience with formal modeling.\n", 141 | "The states are described on a per-event basis and they are not tied to an explicit representation of the robot world or\n", 142 | "the physics of the physical robot.\n", 143 | "\n", 144 | "Mamba is a free and open-source state space model software.\n", 145 | "\n", 146 | "For information on Mamba, visit mamba-robots.org\n", 147 | "\n", 148 | "What is a state machine?\n", 149 | "\n", 150 | "State machine modeling 
was pioneered by J.R. Walker and his colleagues in the early 1960s at MIT, who showed that\n", 151 | "continuous-time systems can be well represented by a simple discrete state machine. They also used this idea to build\n", 152 | "the first model of the humanoid robotic system known as Quoogle. Over the\n", 153 | "\n", 154 | "---\n", 155 | "Prompt eval | tokens: 9 | elapsed: 1.11s | tok/s: 8.08\n", 156 | "Generation | tokens: 199 | elapsed: 12.94s | tok/s: 15.38\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "generate(\"Mamba is a new state space model architecture\")" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 7, 167 | "id": "608ccece-9a11-47bc-bafd-7b47fc6383c7", 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "The meaning of life is death. But there is always a possibility that people may believe the opposite, as many have in various parts of the world, such as in India. The idea of God being the one who decides everything and life meaning is not decided by our thoughts, but by events. Life is not a fairytale and even if death is the only real possibility that people do not think of.\n", 175 | "\n", 176 | "India is the birthplace of Hinduism and the country has a history of several ancient civilizations. But what has remained unknown is the fact that Hinduism was not a religious system to worship in the past. It was more of a system of beliefs to live a better life. The most important point that can be ascertained is that life is all about choice and free will. The one who chooses the path, chooses the future that will be his.\n", 177 | "\n", 178 | "The Hindu way of life has been influenced by ancient Hindu traditions and beliefs. 
While the major tenets remain the same, the practices and rituals have\n", 179 | "\n", 180 | "---\n", 181 | "Prompt eval | tokens: 5 | elapsed: 0.33s | tok/s: 14.95\n", 182 | "Generation | tokens: 199 | elapsed: 12.53s | tok/s: 15.88\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "generate(\"The meaning of life is\")" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 8, 193 | "id": "f0bc8a2b-b2bc-4d30-bf4c-213baec7441a", 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "name": "stdout", 198 | "output_type": "stream", 199 | "text": [ 200 | "CUDA is Nvidia's biggest moat on graphics hardware, and it's one that the gaming and PC markets have both been fighting to maintain for the last decade. However, Nvidia's Pascal architecture is on the horizon. And that could be a big opportunity for AMD.\n", 201 | "\n", 202 | "\n", 203 | "When Nvidia first released its Turing architecture at GTC back in February it was only on the cards; we only got an early taste of it. And so it has taken AMD quite some time to start taking a look at all that Nvidia-Turing design. However, AMD still has plenty of time to get to Nvidia before it's too late, and it needs to keep its eye on the Pascal architecture.\n", 204 | "\n", 205 | "AMD is also working on a new GPU called Vega which is due to go live this year; the first Vega GPU is rumored to be a reimagined Polaris architecture that features an enhanced memory hierarchy which could significantly speed up the graphics pipeline. 
If Vega is anything like Polaris,\n", 206 | "\n", 207 | "---\n", 208 | "Prompt eval | tokens: 9 | elapsed: 0.57s | tok/s: 15.67\n", 209 | "Generation | tokens: 199 | elapsed: 12.67s | tok/s: 15.71\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "generate(\"CUDA is Nvidia's biggest moat\")" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "kernelspec": { 220 | "display_name": "Python 3 (ipykernel)", 221 | "language": "python", 222 | "name": "python3" 223 | }, 224 | "language_info": { 225 | "codemirror_mode": { 226 | "name": "ipython", 227 | "version": 3 228 | }, 229 | "file_extension": ".py", 230 | "mimetype": "text/x-python", 231 | "name": "python", 232 | "nbconvert_exporter": "python", 233 | "pygments_lexer": "ipython3", 234 | "version": "3.12.2" 235 | } 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 5 239 | } 240 | -------------------------------------------------------------------------------- /mamba2.py: -------------------------------------------------------------------------------- 1 | """ 2 | mamba2-minimal 3 | ============== 4 | 5 | A minimal, single-file implementation of the Mamba-2 model in PyTorch. 
6 | 7 | > **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality** 8 | > Authors: Tri Dao, Albert Gu 9 | > Paper: https://arxiv.org/abs/2405.21060 10 | """ 11 | 12 | import json 13 | from dataclasses import dataclass 14 | from typing import Iterable, NamedTuple, TypeAlias, cast 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from einops import rearrange, repeat 19 | from torch import LongTensor, Tensor, nn 20 | 21 | Device: TypeAlias = str | torch.device | None 22 | 23 | 24 | @dataclass 25 | class Mamba2Config: 26 | d_model: int # model dimension (D) 27 | n_layer: int = 24 # number of Mamba-2 layers in the language model 28 | d_state: int = 128 # state dimension (N) 29 | d_conv: int = 4 # convolution kernel size 30 | expand: int = 2 # expansion factor (E) 31 | headdim: int = 64 # head dimension (P) 32 | chunk_size: int = 64 # matrix partition size (Q) 33 | vocab_size: int = 50277 34 | pad_vocab_size_multiple: int = 16 35 | 36 | def __post_init__(self): 37 | self.d_inner = self.expand * self.d_model 38 | assert self.d_inner % self.headdim == 0 39 | self.nheads = self.d_inner // self.headdim 40 | if self.vocab_size % self.pad_vocab_size_multiple != 0: 41 | self.vocab_size += ( 42 | self.pad_vocab_size_multiple 43 | - self.vocab_size % self.pad_vocab_size_multiple 44 | ) 45 | 46 | 47 | class InferenceCache(NamedTuple): 48 | conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv) 49 | ssm_state: Tensor # (batch, nheads, headdim, d_state) 50 | 51 | @staticmethod 52 | def alloc(batch_size: int, args: Mamba2Config, device: Device = None): 53 | return InferenceCache( 54 | torch.zeros( 55 | batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device 56 | ), 57 | torch.zeros( 58 | batch_size, args.nheads, args.headdim, args.d_state, device=device 59 | ), 60 | ) 61 | 62 | 63 | class Mamba2LMHeadModel(nn.Module): 64 | def __init__(self, args: Mamba2Config, device: Device = None): 65 | 
super().__init__() 66 | self.args = args 67 | self.device = device 68 | 69 | self.backbone = nn.ModuleDict( 70 | dict( 71 | embedding=nn.Embedding(args.vocab_size, args.d_model, device=device), 72 | layers=nn.ModuleList( 73 | [ 74 | nn.ModuleDict( 75 | dict( 76 | mixer=Mamba2(args, device=device), 77 | norm=RMSNorm(args.d_model, device=device), 78 | ) 79 | ) 80 | for _ in range(args.n_layer) 81 | ] 82 | ), 83 | norm_f=RMSNorm(args.d_model, device=device), 84 | ) 85 | ) 86 | self.lm_head = nn.Linear( 87 | args.d_model, args.vocab_size, bias=False, device=device 88 | ) 89 | self.lm_head.weight = self.backbone.embedding.weight 90 | 91 | @staticmethod 92 | def from_pretrained(huggingface_model_id: str, device: Device = None): 93 | from transformers.utils import CONFIG_NAME, WEIGHTS_NAME 94 | from transformers.utils.hub import cached_file 95 | 96 | config_path = cached_file(huggingface_model_id, CONFIG_NAME) 97 | assert config_path, "Failed to get huggingface config file" 98 | state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME) 99 | assert state_dict_path, "Failed to get huggingface state dict file" 100 | 101 | config = json.load(open(config_path)) 102 | args = Mamba2Config( 103 | d_model=config["d_model"], 104 | n_layer=config["n_layer"], 105 | vocab_size=config["vocab_size"], 106 | pad_vocab_size_multiple=config["pad_vocab_size_multiple"], 107 | ) 108 | 109 | map_location = "cpu" if device is None else device 110 | state_dict = torch.load( 111 | state_dict_path, weights_only=True, map_location=map_location, mmap=True 112 | ) 113 | model = Mamba2LMHeadModel(args, device=device) 114 | model.load_state_dict(state_dict) 115 | model.eval() 116 | return model 117 | 118 | def forward( 119 | self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None 120 | ) -> tuple[Tensor, list[InferenceCache]]: 121 | """ 122 | Arguments 123 | input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer 124 | h: hidden states for inference
step. If present, the constant-time 125 | (wrt sequence length) inference path is taken, and input_ids 126 | should have shape (batch, 1), holding the next token for each 127 | sequence in the batch. 128 | 129 | Return (logits, h) 130 | logits: (batch, seqlen, vocab_size) 131 | h: updated inference cache after processing `input_ids` 132 | """ 133 | seqlen = input_ids.shape[1] 134 | 135 | if h is None: 136 | h = [None for _ in range(self.args.n_layer)] 137 | 138 | x = self.backbone.embedding(input_ids) 139 | for i, layer in enumerate(self.backbone.layers): 140 | y, h[i] = layer.mixer(layer.norm(x), h[i]) 141 | x = y + x 142 | 143 | x = self.backbone.norm_f(x) 144 | logits = self.lm_head(x) 145 | return logits[:, :seqlen], cast(list[InferenceCache], h) 146 | 147 | def generate( 148 | self, 149 | input_ids: LongTensor, 150 | max_new_length: int = 20, 151 | temperature: float = 1.0, 152 | top_k: int = 50, 153 | top_p: float = 1.0, 154 | eos_token_id: int = 0, 155 | ) -> Iterable[tuple[int, list[InferenceCache]]]: 156 | prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0) 157 | 158 | # Process prompt 159 | # The input sequence to forward (non-inference path) must have a length that is a multiple of chunk_size. 160 | # We split off the excess tokens so that n_chunked tokens can be processed by one forward call, 161 | # and process the rest one token at a time on the inference path.
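# Worked example (illustrative numbers only, assuming the default chunk_size=64):
# a 150-token prompt gives a 149-token prefix, so n_chunked = (149 // 64) * 64 = 128.
# Tokens 0..127 go through a single chunked forward call, tokens 128..148 (21 tokens)
# are fed one at a time through the inference path, and token 149 is the first
# input of the generation loop.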
162 | n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size 163 | if n_chunked > 0: 164 | _, h = self(prefix[:n_chunked].unsqueeze(0), None) 165 | else: 166 | h = [ 167 | InferenceCache.alloc(1, self.args, device=self.device) 168 | for _ in range(self.args.n_layer) 169 | ] 170 | for i in range(n_chunked, prefix.shape[0]): 171 | _, h = self(prefix[i : i + 1].unsqueeze(0), h) 172 | 173 | # Generate 174 | for _ in range(max_new_length): 175 | with torch.no_grad(): 176 | out, h = self(tokens, h) 177 | logits = out[0, -1] 178 | if temperature != 1.0: 179 | logits = logits / temperature 180 | if top_k > 0: 181 | indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1] 182 | logits[indices_to_remove] = -torch.inf 183 | if top_p < 1.0: 184 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 185 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 186 | sorted_indices_to_remove = cum_probs > top_p 187 | sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() 188 | sorted_indices_to_remove[0] = False 189 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 190 | logits[indices_to_remove] = -torch.inf 191 | probs = F.softmax(logits, dim=-1) 192 | next_token = torch.multinomial(probs, num_samples=1) 193 | if next_token.item() == eos_token_id: 194 | return 195 | tokens = next_token.unsqueeze(0) 196 | yield cast(int, next_token.item()), h 197 | 198 | 199 | class Mamba2(nn.Module): 200 | def __init__(self, args: Mamba2Config, device: Device = None): 201 | super().__init__() 202 | self.args = args 203 | self.device = device 204 | 205 | # Order: (z, x, B, C, dt) 206 | d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads 207 | self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device) 208 | 209 | conv_dim = args.d_inner + 2 * args.d_state 210 | self.conv1d = nn.Conv1d( 211 | in_channels=conv_dim, 212 | out_channels=conv_dim, 213 | kernel_size=args.d_conv, 214 | 
groups=conv_dim, 215 | padding=args.d_conv - 1, 216 | device=device, 217 | ) 218 | 219 | self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device)) 220 | self.A_log = nn.Parameter(torch.empty(args.nheads, device=device)) 221 | self.D = nn.Parameter(torch.empty(args.nheads, device=device)) 222 | self.norm = RMSNorm(args.d_inner, device=device) 223 | self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device) 224 | 225 | def forward(self, u: Tensor, h: InferenceCache | None = None): 226 | """ 227 | Arguments 228 | u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size. 229 | h: inference cache. If given, a single inference step is taken via `step`; otherwise `u` is processed from a zero initial state. 230 | 231 | Return (y, h) 232 | y: (batch, seqlen, d_model) output 233 | h: updated inference cache after processing `u` 234 | """ 235 | if h: 236 | return self.step(u, h) 237 | 238 | A = -torch.exp(self.A_log) # (nheads,) 239 | zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj) 240 | z, xBC, dt = torch.split( 241 | zxbcdt, 242 | [ 243 | self.args.d_inner, 244 | self.args.d_inner + 2 * self.args.d_state, 245 | self.args.nheads, 246 | ], 247 | dim=-1, 248 | ) 249 | dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads) 250 | 251 | # Left-pad or truncate xBC along seqlen to d_conv; this is the conv state cached for inference 252 | conv_state = F.pad( 253 | rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0) 254 | ) 255 | 256 | xBC = silu( 257 | self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :] 258 | ) # (batch, seqlen, d_inner + 2 * d_state) 259 | x, B, C = torch.split( 260 | xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 261 | ) 262 | x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim) 263 | y, ssm_state = ssd( 264 | x * dt.unsqueeze(-1), 265 | A * dt, 266 | rearrange(B, "b l n -> b l 1 n"), 267 | rearrange(C, "b l n -> b l 1 n"), 268 | self.args.chunk_size, 269 | device=self.device, 270 | ) 271 | y = y + x * self.D.unsqueeze(-1) 272 | y = rearrange(y,
"b l h p -> b l (h p)") 273 | y = self.norm(y, z) 274 | y = self.out_proj(y) 275 | 276 | h = InferenceCache(conv_state, ssm_state) 277 | return y, h 278 | 279 | def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]: 280 | """Take a single inference step for the current input and hidden state. 281 | 282 | Unlike attention-based models, RNN-based models (e.g. Mamba) do not need 283 | to look back at all past tokens to generate a new token. Instead, a 284 | hidden state (initialized to zeros) is updated for each input and 285 | passed to the next inference step. Each step therefore takes constant time, so the total inference 286 | time is linear in the sequence length instead of quadratic, as it is 287 | for attention. 288 | 289 | Arguments 290 | u: (batch, 1, d_model) 291 | h: initial/running hidden state 292 | 293 | Return (y, h) 294 | y: (batch, 1, d_model) 295 | h: updated hidden state 296 | """ 297 | assert u.shape[1] == 1, "Only one token can be decoded per inference step" 298 | 299 | zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj) 300 | z, xBC, dt = torch.split( 301 | zxbcdt, 302 | [ 303 | self.args.d_inner, 304 | self.args.d_inner + 2 * self.args.d_state, 305 | self.args.nheads, 306 | ], 307 | dim=-1, 308 | ) 309 | 310 | # Advance convolution input 311 | h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1)) 312 | h.conv_state[:, :, -1] = xBC 313 | # Convolution step 314 | xBC = torch.sum( 315 | h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 316 | ) 317 | xBC += self.conv1d.bias 318 | xBC = silu(xBC) 319 | 320 | x, B, C = torch.split( 321 | xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 322 | ) 323 | A = -torch.exp(self.A_log) # (nheads,) 324 | 325 | # SSM step 326 | dt = F.softplus(dt + self.dt_bias) # (batch, nheads) 327 | dA = torch.exp(dt * A) # (batch, nheads) 328 | x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim) 329 | dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B,
x) 330 | h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) 331 | y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C) 332 | y = y + rearrange(self.D, "h -> h 1") * x 333 | y = rearrange(y, "b h p -> b (h p)") 334 | y = self.norm(y, z) 335 | y = self.out_proj(y) 336 | 337 | return y.unsqueeze(1), h 338 | 339 | 340 | def segsum(x: Tensor, device: Device = None) -> Tensor: 341 | """Stable segment sum calculation. 342 | 343 | `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM. 344 | 345 | Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32 346 | """ 347 | T = x.size(-1) 348 | x = repeat(x, "... d -> ... d e", e=T) 349 | mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1) 350 | x = x.masked_fill(~mask, 0) 351 | x_segsum = torch.cumsum(x, dim=-2) 352 | mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0) 353 | x_segsum = x_segsum.masked_fill(~mask, -torch.inf) 354 | return x_segsum 355 | 356 | 357 | def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None): 358 | """Structured State Space Duality (SSD), the core of Mamba-2 359 | 360 | This is almost exactly the minimal SSD code from the blog post. 361 | 362 | Arguments 363 | x: (batch, seqlen, n_heads, d_head) 364 | A: (batch, seqlen, n_heads) 365 | B: (batch, seqlen, n_heads, d_state) 366 | C: (batch, seqlen, n_heads, d_state) 367 | 368 | Return (y, final_state) 369 | y: (batch, seqlen, n_heads, d_head); final_state: (batch, n_heads, d_head, d_state) 370 | 371 | Source 372 | 1. https://tridao.me/blog/2024/mamba2-part3-algorithm/ 373 | 2.
https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78 374 | """ 375 | assert x.shape[1] % chunk_size == 0 376 | 377 | # Rearrange into chunks 378 | # Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel) 379 | # This is not implemented and left as an exercise for the reader 😜 380 | x, A, B, C = [ 381 | rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C) 382 | ] 383 | 384 | A = rearrange(A, "b c l h -> b h c l") 385 | A_cumsum = torch.cumsum(A, dim=-1) 386 | 387 | # 1. Compute the output for each intra-chunk (diagonal blocks) 388 | L = torch.exp(segsum(A, device=device)) 389 | Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x) 390 | 391 | # 2. Compute the state for each intra-chunk 392 | # (right term of low-rank factorization of off-diagonal blocks; B terms) 393 | decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) 394 | states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x) 395 | 396 | # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries 397 | # (middle term of factorization of off-diag blocks; A terms) 398 | if initial_states is None: 399 | initial_states = torch.zeros_like(states[:, :1]) 400 | states = torch.cat([initial_states, states], dim=1) 401 | decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device)) 402 | new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states) 403 | states, final_state = new_states[:, :-1], new_states[:, -1] 404 | 405 | # 4. 
Compute state -> output conversion per chunk 406 | # (left term of low-rank factorization of off-diagonal blocks; C terms) 407 | state_decay_out = torch.exp(A_cumsum) 408 | Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out) 409 | 410 | # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) 411 | Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") 412 | 413 | return Y, final_state 414 | 415 | 416 | class RMSNorm(nn.Module): 417 | def __init__(self, d: int, eps: float = 1e-5, device: Device = None): 418 | """Gated Root Mean Square Layer Normalization 419 | 420 | Paper: https://arxiv.org/abs/1910.07467 421 | """ 422 | super().__init__() 423 | self.eps = eps 424 | self.weight = nn.Parameter(torch.ones(d, device=device)) 425 | 426 | def forward(self, x, z=None): 427 | if z is not None: 428 | x = x * silu(z) 429 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight 430 | 431 | 432 | def silu(x): 433 | """Applies the Sigmoid Linear Unit (SiLU), element-wise. 434 | 435 | Define this manually since torch's version doesn't seem to work on MPS. 436 | """ 437 | return x * F.sigmoid(x) 438 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | einops 4 | --------------------------------------------------------------------------------
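The top-k and top-p (nucleus) filtering in `generate` can be illustrated without torch. The sketch below is a pure-Python approximation of that logic, not part of the repository; `softmax` and `filter_logits` are hypothetical helper names. Tokens are ranked by logit, and a token is dropped only if the total probability of the tokens ranked above it already exceeds `top_p`, which is what the right shift of `sorted_indices_to_remove` achieves. The cutoff is the `top_p` argument itself.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax; -inf logits get probability 0."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def filter_logits(logits: list[float], top_k: int = 50, top_p: float = 1.0) -> list[float]:
    """Return a copy of `logits` with filtered-out entries set to -inf.

    Hypothetical pure-Python mirror of the top-k / top-p step in `generate`.
    """
    out = list(logits)
    if top_k > 0:
        # Keep every logit >= the k-th largest (ties at the boundary survive,
        # as with `logits < torch.topk(...)[0][-1]`).
        kth = sorted(out, reverse=True)[min(top_k, len(out)) - 1]
        out = [v if v >= kth else -math.inf for v in out]
    if top_p < 1.0:
        order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)
        probs = softmax([out[i] for i in order])
        mass_before = 0.0  # probability of all higher-ranked tokens
        for i, p in zip(order, probs):
            # Equivalent to shifting `sorted_indices_to_remove` right by one:
            # the token that first pushes the mass past top_p is still kept.
            if mass_before > top_p:
                out[i] = -math.inf
            mass_before += p
    return out


logits = [2.0, 1.0, 0.0, -1.0]
print(filter_logits(logits, top_k=0, top_p=0.7))  # [2.0, 1.0, -inf, -inf]
print(filter_logits(logits, top_k=2))             # [2.0, 1.0, -inf, -inf]
```

With `top_p=0.7` the first token alone carries about 0.64 of the mass, so the second token is kept and everything after it is removed.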
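The `segsum` docstring states that `exp(segsum(A))` yields a 1-semiseparable matrix equivalent to a scalar SSM. A small pure-Python check of that claim is sketched below, assuming a single head with a scalar state and B = C = 1; `segsum_py`, `ssm_matrix` and `ssm_recurrence` are illustrative names, not repository functions. Entry (i, j) of the matrix is exp(a[j+1] + ... + a[i]) for j <= i and 0 above the diagonal, so multiplying it by x should reproduce the recurrence h_t = exp(a_t) h_{t-1} + x_t.

```python
import math


def segsum_py(a: list[float]) -> list[list[float]]:
    """out[i][j] = a[j+1] + ... + a[i] for j <= i (0 on the diagonal), -inf for j > i."""
    T = len(a)
    out = [[-math.inf] * T for _ in range(T)]
    for i in range(T):
        acc = 0.0
        out[i][i] = 0.0
        for j in range(i - 1, -1, -1):
            acc += a[j + 1]
            out[i][j] = acc
    return out


def ssm_matrix(a: list[float], x: list[float]) -> list[float]:
    """Quadratic ("attention-like") form: y = exp(segsum(a)) @ x."""
    L = [[math.exp(v) for v in row] for row in segsum_py(a)]  # exp(-inf) == 0.0
    return [sum(L[i][j] * x[j] for j in range(len(x))) for i in range(len(x))]


def ssm_recurrence(a: list[float], x: list[float]) -> list[float]:
    """Linear (RNN) form: h_t = exp(a_t) * h_{t-1} + x_t, y_t = h_t."""
    h, ys = 0.0, []
    for a_t, x_t in zip(a, x):
        h = math.exp(a_t) * h + x_t
        ys.append(h)
    return ys


a = [-0.5, -1.0, -0.2, -0.3]
x = [1.0, 2.0, -1.0, 0.5]
print(max(abs(p - q) for p, q in zip(ssm_matrix(a, x), ssm_recurrence(a, x))))
```

The printed difference should be at the level of floating-point round-off. The quadratic form costs O(T^2) but parallelizes well, which is why `ssd` uses it only inside chunks.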
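The four numbered steps in `ssd` can likewise be traced in a scalar setting. The sketch below assumes one head, head dimension 1 and state dimension 1, with per-step scalars `a` (already multiplied by dt, like `A * dt`), `b`, `c` and `x`; `ssm_scan` and `ssd_chunked_scalar` are hypothetical names. Unlike `ssd`, which computes steps 1, 2 and 4 for all chunks at once with `einsum` and step 3 with the `decay_chunk` matrix, this version visits chunks in a Python loop so each term is easy to follow. Its output should match a plain sequential scan for any chunk size that divides the sequence length.

```python
import itertools
import math


def ssm_scan(a, b, c, x):
    """Reference recurrence: h_t = exp(a_t) h_{t-1} + b_t x_t, y_t = c_t h_t."""
    h, ys = 0.0, []
    for a_t, b_t, c_t, x_t in zip(a, b, c, x):
        h = math.exp(a_t) * h + b_t * x_t
        ys.append(c_t * h)
    return ys, h


def ssd_chunked_scalar(a, b, c, x, chunk_size):
    """Scalar walk-through of the four SSD steps; returns (y, final_state)."""
    assert len(x) % chunk_size == 0
    h = 0.0  # state entering the current chunk (initial_states = 0)
    ys = []
    for start in range(0, len(x), chunk_size):
        a_c = a[start:start + chunk_size]
        b_c = b[start:start + chunk_size]
        c_c = c[start:start + chunk_size]
        x_c = x[start:start + chunk_size]
        cum = list(itertools.accumulate(a_c))  # A_cumsum within the chunk
        for t in range(chunk_size):
            # Step 1: intra-chunk output; exp(cum[t] - cum[s]) is exp(segsum)[t][s]
            y_diag = sum(c_c[t] * math.exp(cum[t] - cum[s]) * b_c[s] * x_c[s] for s in range(t + 1))
            # Step 4: output due to the state carried into this chunk
            y_off = c_c[t] * math.exp(cum[t]) * h
            ys.append(y_diag + y_off)
        # Step 2: state contributed by this chunk alone (decay_states)
        chunk_state = sum(math.exp(cum[-1] - cum[s]) * b_c[s] * x_c[s] for s in range(chunk_size))
        # Step 3: inter-chunk recurrence, decaying the incoming state across the chunk
        h = math.exp(cum[-1]) * h + chunk_state
    return ys, h


a = [-0.1, -0.4, -0.2, -0.7, -0.3, -0.05]
b = [1.0, 0.5, -0.3, 2.0, 1.5, 0.2]
c = [0.3, 1.0, 0.7, -0.5, 1.2, 0.9]
x = [1.0, -2.0, 0.5, 0.25, 3.0, -1.0]
print(ssd_chunked_scalar(a, b, c, x, chunk_size=3))
print(ssm_scan(a, b, c, x))
```

Both calls should agree up to round-off; with `chunk_size=1` the chunked version degenerates into the sequential scan, and with `chunk_size=len(x)` it becomes the pure quadratic form.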
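`Mamba2.step` replaces the causal depthwise convolution with a rolling window: `conv_state` holds the last `d_conv` inputs, is shifted left by one, receives the new input in its last slot and is then multiplied by the kernel and summed. For a single channel this should reproduce the corresponding output of the full-sequence path, `nn.Conv1d(..., padding=d_conv - 1)` truncated to the first seqlen outputs; PyTorch's `Conv1d` computes a cross-correlation, so the kernel is not flipped. The pure-Python sketch below checks that equivalence; `causal_conv` and `conv_step` are hypothetical names, the bias mirrors `xBC += self.conv1d.bias`, and the following `silu` is omitted.

```python
def causal_conv(xs, w, bias=0.0):
    """Full-sequence path for one channel: left-pad with len(w) - 1 zeros, then cross-correlate."""
    K = len(w)
    padded = [0.0] * (K - 1) + list(xs)
    return [sum(w[k] * padded[t + k] for k in range(K)) + bias for t in range(len(xs))]


def conv_step(state, x_t, w, bias=0.0):
    """Inference path for one channel: roll the window left, append x_t, then dot with w."""
    state[:] = state[1:] + [x_t]  # like torch.roll(shifts=-1) followed by state[-1] = x_t
    return sum(s * wk for s, wk in zip(state, w)) + bias


xs = [1.0, -2.0, 0.5, 3.0, 0.25]
w = [0.1, -0.4, 0.7, 0.2]  # d_conv = 4
state = [0.0] * len(w)     # zero-initialized conv_state
stepped = [conv_step(state, x_t, w, bias=0.05) for x_t in xs]
print(all(abs(p - q) < 1e-12 for p, q in zip(stepped, causal_conv(xs, w, bias=0.05))))  # True
```

After a prompt is processed by `forward`, the cached `conv_state` is the last `d_conv` raw inputs (left-padded with zeros when the prompt is shorter), which is exactly the window `conv_step` expects.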