├── LICENSE ├── README.md ├── demo.ipynb └── model.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Tri Dao, Albert Gu 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## mamba-minimal 2 | 3 | Simple, minimal implementation of Mamba in one file of PyTorch. 4 | 5 | Featuring: 6 | * Numerical output equivalent to the official implementation for both the forward and backward passes 7 | * Simplified, readable, annotated code 8 | 9 | Does NOT include: 10 | * Speed. The official implementation is heavily optimized, and these optimizations are core contributions of the Mamba paper. I kept the implementation simple for readability. 11 | * Proper parameter initialization (though this could be added without sacrificing readability) 12 | 13 | ## Demo 14 | 15 | See [demo.ipynb](demo.ipynb) for examples of prompt completions; the `generate` helper used below is defined there. 16 | 17 | ```python 18 | from model import Mamba 19 | from transformers import AutoTokenizer 20 | 21 | model = Mamba.from_pretrained('state-spaces/mamba-370m') 22 | tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') 23 | 24 | generate(model, tokenizer, 'Mamba is the') 25 | ``` 26 | > Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite) 27 | 28 | 150 meters... 🫢 scary! 29 | 30 | ## References 31 | 32 | The Mamba architecture was introduced in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by [Albert Gu](https://twitter.com/_albertgu?lang=en) and [Tri Dao](https://twitter.com/tri_dao?ref_src=twsrc%5Egoogle%7Ctwcamp%5Eserp%7Ctwgr%5Eauthor).
33 | 34 | The official implementation is here: https://github.com/state-spaces/mamba/tree/main 35 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "531467a2-5160-4073-a990-0d81d574b014", 6 | "metadata": {}, 7 | "source": [ 8 | "## (1) Load model" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 12, 14 | "id": "d9337043-4e7a-4b20-9d89-6c6257245334", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from model import Mamba, ModelArgs\n", 19 | "from transformers import AutoTokenizer\n", 20 | "\n", 21 | "# One of:\n", 22 | "# 'state-spaces/mamba-2.8b-slimpj'\n", 23 | "# 'state-spaces/mamba-2.8b'\n", 24 | "# 'state-spaces/mamba-1.4b'\n", 25 | "# 'state-spaces/mamba-790m'\n", 26 | "# 'state-spaces/mamba-370m'\n", 27 | "# 'state-spaces/mamba-130m'\n", 28 | "pretrained_model_name = 'state-spaces/mamba-370m'\n", 29 | "\n", 30 | "model = Mamba.from_pretrained(pretrained_model_name)\n", 31 | "tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "0b2efb17-37ad-472b-b029-9567acf17629", 37 | "metadata": {}, 38 | "source": [ 39 | "## (2) Generate Text" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import torch\n", 50 | "import torch.nn.functional as F\n", 51 | "\n", 52 | "\n", 53 | "def generate(model,\n", 54 | " tokenizer,\n", 55 | " prompt: str,\n", 56 | " n_tokens_to_gen: int = 50,\n", 57 | " sample: bool = True,\n", 58 | " top_k: int = 40):\n", 59 | " model.eval()\n", 60 | " \n", 61 | " input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n", 62 | " \n", 63 | " for token_n in range(n_tokens_to_gen):\n", 64 | " with torch.no_grad():\n", 65 | " indices_to_input = input_ids\n", 66 | " next_token_logits = model(indices_to_input)[:, -1]\n", 67 | " \n", 68 | " probs = F.softmax(next_token_logits, dim=-1)\n", 69 | " (batch, vocab_size) = probs.shape\n", 70 | " \n", 71 | " if top_k is not None:\n", 72 | " (values, indices) = torch.topk(probs, k=top_k)\n", 73 | " probs[probs < values[:, -1, None]] = 0\n", 74 | " probs = probs / probs.sum(axis=1, keepdims=True)\n", 75 | " \n", 76 | " if sample:\n", 77 | " next_indices = torch.multinomial(probs, num_samples=1)\n", 78 | " else:\n", 79 | " next_indices = torch.argmax(probs, dim=-1)[:, None]\n", 80 | " \n", 81 | " input_ids = torch.cat([input_ids, next_indices], dim=1)\n", 82 | "\n", 83 | " output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n", 84 | " \n", 85 | " return output_completions" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 10, 91 | "id": "ee877143-2042-4579-8042-a96db6200517", 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "Mamba is the world's longest venomous snake with an estimated length of over 150 m. 
With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "print(generate(model, tokenizer, 'Mamba is the'))" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 9, 109 | "id": "65d70549-597f-49ca-9185-2184d2576f7d", 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "John: Hi!\n", 117 | "Sally: Hey!\n", 118 | "John: So, when's the wedding?\n", 119 | "Sally: We haven't decided.\n", 120 | "John: It's in September.\n", 121 | "Sally: Yeah, we were thinking July or\n", 122 | "August.\n", 123 | "John: I'm not too\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "print(generate(model, tokenizer, 'John: Hi!\\nSally:'))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 8, 134 | "id": "6d419fc9-066b-4818-812c-2f1952528bc6", 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "The meaning of life is \n", 142 | "just this: It is the best you can do.\n", 143 | "\n", 144 | "--K.J.\n", 145 | "\n", 146 | "And finally: How to handle your emotions. \n", 147 | "\n", 148 | "<|endoftext|>Q:\n", 149 | "\n", 150 | "Error creating an EntityManager instance in JavaEE 7\n", 151 | "\n", 152 | "This is\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "print(generate(model, tokenizer, 'The meaning of life is '))" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 11, 163 | "id": "2b189e6e-6a96-4770-88cf-7c5de22cb321", 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "def reverse_string(text, result):\n", 171 | " # find the position of the start of the string.\n", 172 | " start = text.index(text[0:-1])\n", 173 | " # find the position where the string begins changing.\n", 174 | " end = text.index\n" 175 | ] 176 | } 177 | ], 178 | "source": [ 179 | "print(generate(model, tokenizer, 'def reverse_string('))" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "be3afb51-5093-4c64-ac3f-43c2e6b20b10", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "id": "6531acc0-b18f-472a-8e99-cee64dd51cd8", 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "d0efe197-891a-4ab8-8cea-413d1fb1acda", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "2e99509b-df7b-4bac-b6a2-669f601ec1c8", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [] 213 | } 214 | ], 215 | "metadata": { 216 | "kernelspec": { 217 | "display_name": "Python 3 (ipykernel)", 218 | "language": "python", 219 | "name": "python3" 220 | }, 221 | "language_info": { 222 | "codemirror_mode": { 223 | "name": "ipython", 224 | "version": 3 225 | }, 226 | "file_extension": ".py", 227 | "mimetype": "text/x-python", 228 | "name": "python", 229 | "nbconvert_exporter": "python", 230 | "pygments_lexer": "ipython3", 231 | "version": "3.11.5" 232 | } 233 | }, 234 | "nbformat": 4, 235 | "nbformat_minor": 5 236 | } 237 | -------------------------------------------------------------------------------- 
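The only non-obvious step in the notebook's `generate` function above is the top-k filtering before sampling: probabilities below the k-th largest value are zeroed and the remainder renormalized. A minimal, self-contained sketch of just that step (illustrative only, not part of the repository):

```python
import torch
import torch.nn.functional as F

# Toy next-token distribution over a 5-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
probs = F.softmax(logits, dim=-1)

# Keep only the top-k probabilities, zero the rest, then renormalize.
k = 3
values, _ = torch.topk(probs, k=k)
probs[probs < values[:, -1, None]] = 0
probs = probs / probs.sum(dim=-1, keepdim=True)

# Sample the next token id from the filtered distribution.
next_token = torch.multinomial(probs, num_samples=1)
print(probs)        # ≈ tensor([[0.63, 0.23, 0.14, 0.00, 0.00]])
print(next_token)   # one of the three surviving indices
```

This mirrors the `top_k` branch in `generate`; the notebook additionally repeats this step `n_tokens_to_gen` times, appending each sampled id to `input_ids`.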
/model.py: -------------------------------------------------------------------------------- 1 | """Simple, minimal implementation of Mamba in one file of PyTorch. 2 | 3 | Suggest reading the following before/while reading the code: 4 | [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao) 5 | https://arxiv.org/abs/2312.00752 6 | [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti) 7 | https://srush.github.io/annotated-s4 8 | 9 | Glossary: 10 | b: batch size (`B` in Mamba paper [1] Algorithm 2) 11 | l: sequence length (`L` in [1] Algorithm 2) 12 | d or d_model: hidden dim 13 | n or d_state: latent state dim (`N` in [1] Algorithm 2) 14 | expand: expansion factor (`E` in [1] Section 3.4) 15 | d_in or d_inner: d * expand (`D` in [1] Algorithm 2) 16 | A, B, C, D: state space parameters (See any state space representation formula) 17 | (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not) 18 | Δ or delta: input-dependent step size 19 | dt_rank: rank of Δ (See [1] Section 3.6 "Parameterization of ∆") 20 | 21 | """ 22 | from __future__ import annotations 23 | import math 24 | import json 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | from dataclasses import dataclass 29 | from einops import rearrange, repeat, einsum 30 | 31 | 32 | @dataclass 33 | class ModelArgs: 34 | d_model: int 35 | n_layer: int 36 | vocab_size: int 37 | d_state: int = 16 38 | expand: int = 2 39 | dt_rank: Union[int, str] = 'auto' 40 | d_conv: int = 4 41 | pad_vocab_size_multiple: int = 8 42 | conv_bias: bool = True 43 | bias: bool = False 44 | 45 | def __post_init__(self): 46 | self.d_inner = int(self.expand * self.d_model) 47 | 48 | if self.dt_rank == 'auto': 49 | self.dt_rank = math.ceil(self.d_model / 16) 50 | 51 | if self.vocab_size % self.pad_vocab_size_multiple != 0: 52 | self.vocab_size += (self.pad_vocab_size_multiple 53 | - self.vocab_size % self.pad_vocab_size_multiple) 54 | 55 | 56 | class Mamba(nn.Module): 57 | def __init__(self, args: ModelArgs): 58 | """Full Mamba model.""" 59 | super().__init__() 60 | self.args = args 61 | 62 | self.embedding = nn.Embedding(args.vocab_size, args.d_model) 63 | self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)]) 64 | self.norm_f = RMSNorm(args.d_model) 65 | 66 | self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) 67 | self.lm_head.weight = self.embedding.weight # Tie output projection to embedding weights. 68 | # See "Weight Tying" paper 69 | 70 | 71 | def forward(self, input_ids): 72 | """ 73 | Args: 74 | input_ids (long tensor): shape (b, l) (See Glossary at top for definitions of b, l, d_in, n...) 75 | 76 | Returns: 77 | logits: shape (b, l, vocab_size) 78 | 79 | Official Implementation: 80 | class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173 81 | 82 | """ 83 | x = self.embedding(input_ids) 84 | 85 | for layer in self.layers: 86 | x = layer(x) 87 | 88 | x = self.norm_f(x) 89 | logits = self.lm_head(x) 90 | 91 | return logits 92 | 93 | 94 | @staticmethod 95 | def from_pretrained(pretrained_model_name: str): 96 | """Load pretrained weights from HuggingFace into model. 
97 | 98 | Args: 99 | pretrained_model_name: One of 100 | * 'state-spaces/mamba-2.8b-slimpj' 101 | * 'state-spaces/mamba-2.8b' 102 | * 'state-spaces/mamba-1.4b' 103 | * 'state-spaces/mamba-790m' 104 | * 'state-spaces/mamba-370m' 105 | * 'state-spaces/mamba-130m' 106 | 107 | Returns: 108 | model: Mamba model with weights loaded 109 | 110 | """ 111 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 112 | from transformers.utils.hub import cached_file 113 | 114 | def load_config_hf(model_name): 115 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, 116 | _raise_exceptions_for_missing_entries=False) 117 | return json.load(open(resolved_archive_file)) 118 | 119 | 120 | def load_state_dict_hf(model_name, device=None, dtype=None): 121 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, 122 | _raise_exceptions_for_missing_entries=False) 123 | return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True) 124 | 125 | config_data = load_config_hf(pretrained_model_name) 126 | args = ModelArgs( 127 | d_model=config_data['d_model'], 128 | n_layer=config_data['n_layer'], 129 | vocab_size=config_data['vocab_size'] 130 | ) 131 | model = Mamba(args) 132 | 133 | state_dict = load_state_dict_hf(pretrained_model_name) 134 | new_state_dict = {} 135 | for key in state_dict: 136 | new_key = key.replace('backbone.', '') 137 | new_state_dict[new_key] = state_dict[key] 138 | model.load_state_dict(new_state_dict) 139 | 140 | return model 141 | 142 | 143 | class ResidualBlock(nn.Module): 144 | def __init__(self, args: ModelArgs): 145 | """Simple block wrapping Mamba block with normalization and residual connection.""" 146 | super().__init__() 147 | self.args = args 148 | self.mixer = MambaBlock(args) 149 | self.norm = RMSNorm(args.d_model) 150 | 151 | 152 | def forward(self, x): 153 | """ 154 | Args: 155 | x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...) 156 | 157 | Returns: 158 | output: shape (b, l, d) 159 | 160 | Official Implementation: 161 | Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297 162 | 163 | Note: the official repo chains residual blocks that look like 164 | [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ... 165 | where the first Add is a no-op. This is purely for performance reasons as this 166 | allows them to fuse the Add->Norm. 167 | 168 | We instead implement our blocks as the more familiar, simpler, and numerically equivalent 169 | [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> .... 
170 | 171 | """ 172 | output = self.mixer(self.norm(x)) + x 173 | 174 | return output 175 | 176 | 177 | class MambaBlock(nn.Module): 178 | def __init__(self, args: ModelArgs): 179 | """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" 180 | super().__init__() 181 | self.args = args 182 | 183 | self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias) 184 | 185 | self.conv1d = nn.Conv1d( 186 | in_channels=args.d_inner, 187 | out_channels=args.d_inner, 188 | bias=args.conv_bias, 189 | kernel_size=args.d_conv, 190 | groups=args.d_inner, 191 | padding=args.d_conv - 1, 192 | ) 193 | 194 | # x_proj takes in `x` and outputs the input-specific Δ, B, C 195 | self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False) 196 | 197 | # dt_proj projects Δ from dt_rank to d_in 198 | self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) 199 | 200 | A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner) 201 | self.A_log = nn.Parameter(torch.log(A)) 202 | self.D = nn.Parameter(torch.ones(args.d_inner)) 203 | self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias) 204 | 205 | 206 | def forward(self, x): 207 | """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1]. 208 | 209 | Args: 210 | x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...) 211 | 212 | Returns: 213 | output: shape (b, l, d) 214 | 215 | Official Implementation: 216 | class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119 217 | mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 218 | 219 | """ 220 | (b, l, d) = x.shape 221 | 222 | x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in) 223 | (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1) 224 | 225 | x = rearrange(x, 'b l d_in -> b d_in l') 226 | x = self.conv1d(x)[:, :, :l] 227 | x = rearrange(x, 'b d_in l -> b l d_in') 228 | 229 | x = F.silu(x) 230 | 231 | y = self.ssm(x) 232 | 233 | y = y * F.silu(res) 234 | 235 | output = self.out_proj(y) 236 | 237 | return output 238 | 239 | 240 | def ssm(self, x): 241 | """Runs the SSM. See: 242 | - Algorithm 2 in Section 3.2 in the Mamba paper [1] 243 | - run_SSM(A, B, C, u) in The Annotated S4 [2] 244 | 245 | Args: 246 | x: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...) 247 | 248 | Returns: 249 | output: shape (b, l, d_in) 250 | 251 | Official Implementation: 252 | mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 253 | 254 | """ 255 | (d_in, n) = self.A_log.shape 256 | 257 | # Compute ∆ A B C D, the state space parameters. 258 | # A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) 259 | # ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, 260 | # and is why Mamba is called **selective** state spaces) 261 | 262 | A = -torch.exp(self.A_log.float()) # shape (d_in, n) 263 | D = self.D.float() 264 | 265 | x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n) 266 | 267 | (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). 
B, C: (b, l, n) 268 | delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) 269 | 270 | y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2] 271 | 272 | return y 273 | 274 | 275 | def selective_scan(self, u, delta, A, B, C, D): 276 | """Does selective scan algorithm. See: 277 | - Section 2 State Space Models in the Mamba paper [1] 278 | - Algorithm 2 in Section 3.2 in the Mamba paper [1] 279 | - run_SSM(A, B, C, u) in The Annotated S4 [2] 280 | 281 | This is the classic discrete state space formula: 282 | x(t + 1) = Ax(t) + Bu(t) 283 | y(t) = Cx(t) + Du(t) 284 | except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t). 285 | 286 | Args: 287 | u: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...) 288 | delta: shape (b, l, d_in) 289 | A: shape (d_in, n) 290 | B: shape (b, l, n) 291 | C: shape (b, l, n) 292 | D: shape (d_in,) 293 | 294 | Returns: 295 | output: shape (b, l, d_in) 296 | 297 | Official Implementation: 298 | selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86 299 | Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly. 300 | 301 | """ 302 | (b, l, d_in) = u.shape 303 | n = A.shape[1] 304 | 305 | # Discretize continuous parameters (A, B) 306 | # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1]) 307 | # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: 308 | # "A is the more important term and the performance doesn't change much with the simplification on B" 309 | deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n')) 310 | deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n') 311 | 312 | # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) 313 | # Note that the below is sequential, while the official implementation does a much faster parallel scan that 314 | # is additionally hardware-aware (like FlashAttention). 315 | x = torch.zeros((b, d_in, n), device=deltaA.device) 316 | ys = [] 317 | for i in range(l): 318 | x = deltaA[:, i] * x + deltaB_u[:, i] 319 | y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') 320 | ys.append(y) 321 | y = torch.stack(ys, dim=1) # shape (b, l, d_in) 322 | 323 | y = y + u * D 324 | 325 | return y 326 | 327 | 328 | class RMSNorm(nn.Module): 329 | def __init__(self, 330 | d_model: int, 331 | eps: float = 1e-5): 332 | super().__init__() 333 | self.eps = eps 334 | self.weight = nn.Parameter(torch.ones(d_model)) 335 | 336 | 337 | def forward(self, x): 338 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight 339 | 340 | return output 341 | 342 | --------------------------------------------------------------------------------
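To exercise `model.py` end to end without downloading pretrained weights, a small randomly initialized model can be run on dummy token ids. The following is a minimal smoke-test sketch (not part of the repository), assuming `torch` and `einops` are installed and `model.py` is on the path; it also shows the vocab padding performed in `ModelArgs.__post_init__`.

```python
import torch

from model import Mamba, ModelArgs

# Tiny, randomly initialized configuration (no pretrained weights needed).
args = ModelArgs(d_model=64, n_layer=2, vocab_size=100)
print(args.vocab_size)   # 104: padded up to a multiple of pad_vocab_size_multiple=8
print(args.d_inner)      # 128 = expand * d_model
print(args.dt_rank)      # 4   = ceil(d_model / 16)

model = Mamba(args)

# A batch of 2 sequences of 16 random token ids.
input_ids = torch.randint(0, args.vocab_size, (2, 16))
logits = model(input_ids)
print(logits.shape)      # torch.Size([2, 16, 104]) -> (b, l, vocab_size)
```

Since the weights are random, the logits are meaningless; the point is only that the shapes and the sequential scan behave as described in the docstrings. For real completions, use `Mamba.from_pretrained(...)` together with the `generate` helper from demo.ipynb.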