├── LICENSE ├── README.md └── src ├── configuration_mamba.py ├── modeling_mamba.py └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mamba-hf 2 | 3 | 4 | Implementation of the Mamba SSM with Hugging Face integration. 5 | 6 | # Usage: 7 | To use **mamba-hf**, follow these steps: 8 | 9 | 1. Clone the repository to your local machine. 10 | 11 | ```bash 12 | git clone https://github.com/LegallyCoder/mamba-hf 13 | ``` 14 | 2. Open a terminal or command prompt and navigate to the source directory. 15 | ```bash 16 | cd src 17 | ``` 18 | 19 | 3. Install the required packages using this command: 20 | 21 | ```bash 22 | pip3 install -r requirements.txt 23 | ``` 24 | 25 | 4. Create a new Python file in the source directory and run the following: 26 | ```python 27 | from modeling_mamba import MambaForCausalLM 28 | from transformers import AutoTokenizer 29 | 30 | model = MambaForCausalLM.from_pretrained('Q-bert/Mamba-130M') 31 | tokenizer = AutoTokenizer.from_pretrained('Q-bert/Mamba-130M') 32 | 33 | text = "Hi" 34 | 35 | input_ids = tokenizer.encode(text, return_tensors="pt") 36 | 37 | output = model.generate(input_ids, max_length=20, num_beams=5, no_repeat_ngram_size=2) 38 | 39 | generated_text = tokenizer.decode(output[0], skip_special_tokens=True) 40 | 41 | print(generated_text) 42 | 43 | ``` 44 | > Hi, I'm looking for a new job. I've been working at a company for about a year now. 45 | 46 | # For more: 47 | You can find pretrained checkpoints in the 48 | [Mamba Models Collection](https://huggingface.co/collections/Q-bert/mamba-65869481595e25821853d20d) 49 | ## References and Credits: 50 | 51 | The Mamba architecture was introduced in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by [Albert Gu](https://twitter.com/_albertgu?lang=en) and [Tri Dao](https://twitter.com/tri_dao?ref_src=twsrc%5Egoogle%7Ctwcamp%5Eserp%7Ctwgr%5Eauthor). 
52 | 53 | Thanks to [mamba-minimal](https://github.com/johnma2006/mamba-minimal) for the simple reference implementation. 54 | 55 | The official implementation is here: https://github.com/state-spaces/mamba/tree/main 56 | -------------------------------------------------------------------------------- /src/configuration_mamba.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union 3 | 4 | from transformers import PretrainedConfig 5 | class MambaConfig(PretrainedConfig): 6 | model_type = "mamba" 7 | def __init__( 8 | self, 9 | vocab_size=50277, 10 | d_state=16, 11 | d_model=2560, 12 | d_conv=4, 13 | expand=2, 14 | conv_bias=True, 15 | bias=False, 16 | n_layer=64, 17 | dt_rank: Union[int, str] = "auto", 18 | pad_vocab_size_multiple=8, 19 | initializer_range=0.02, 20 | **kwargs, 21 | ): 22 | self.vocab_size = vocab_size 23 | self.n_layer = n_layer 24 | self.conv_bias = conv_bias 25 | self.expand = expand 26 | self.pad_vocab_size_multiple = pad_vocab_size_multiple 27 | self.d_conv = d_conv 28 | self.d_model = d_model 29 | self.d_state = d_state 30 | self.d_inner = int(self.expand * self.d_model) 31 | self.dt_rank = dt_rank 32 | self.initializer_range = initializer_range 33 | self.bias = bias 34 | 35 | if self.dt_rank == 'auto': 36 | self.dt_rank = math.ceil(self.d_model / 16) 37 | 38 | if self.vocab_size % self.pad_vocab_size_multiple != 0: 39 | self.vocab_size += (self.pad_vocab_size_multiple 40 | - self.vocab_size % self.pad_vocab_size_multiple) 41 | super().__init__( 42 | **kwargs, 43 | ) -------------------------------------------------------------------------------- /src/modeling_mamba.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from configuration_mamba import MambaConfig 4 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 5 | from transformers.modeling_utils import PreTrainedModel 6 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 7 | import math 8 | import json 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from dataclasses import dataclass 13 | from einops import rearrange, repeat, einsum 14 | from typing import Optional, Union, Tuple 15 | 16 | # Thanks to the contributors of the https://github.com/johnma2006/mamba-minimal/tree/master repository, and special thanks to Albert Gu and Tri Dao for their paper 
(https://arxiv.org/abs/2312.00752) 17 | 18 | 19 | class MambaRMSNorm(nn.Module): 20 | def __init__(self, 21 | d_model: int, 22 | eps: float = 1e-5): 23 | super().__init__() 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(d_model)) 26 | def forward(self, x): 27 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight 28 | return output 29 | 30 | 31 | class MambaBlock(nn.Module): 32 | def __init__(self, config: MambaConfig): 33 | """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" 34 | super().__init__() 35 | self.config = config 36 | 37 | self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias) 38 | 39 | self.conv1d = nn.Conv1d( 40 | in_channels=config.d_inner, 41 | out_channels=config.d_inner, 42 | bias=config.conv_bias, 43 | kernel_size=config.d_conv, 44 | groups=config.d_inner, 45 | padding=config.d_conv - 1, 46 | ) 47 | 48 | # x_proj takes in `x` and outputs the input-specific Δ, B, C 49 | self.x_proj = nn.Linear(config.d_inner, config.dt_rank + config.d_state * 2, bias=False) 50 | 51 | # dt_proj projects Δ from dt_rank to d_in 52 | self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True) 53 | 54 | A = repeat(torch.arange(1, config.d_state + 1), 'n -> d n', d=config.d_inner) 55 | self.A_log = nn.Parameter(torch.log(A)) 56 | self.D = nn.Parameter(torch.ones(config.d_inner)) 57 | self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias) 58 | self.norm = MambaRMSNorm(config.d_model) 59 | 60 | def forward(self, x): 61 | """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1]. 62 | 63 | Args: 64 | x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...) 65 | 66 | Returns: 67 | output: shape (b, l, d) 68 | 69 | Official Implementation: 70 | class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119 71 | mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 72 | 73 | """ 74 | 75 | (b, l, d) = x.shape 76 | x_copy = x # There was a separate class for residual, I deleted that part and added it here. 77 | x = self.norm(x) 78 | x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in) 79 | (x, res) = x_and_res.split(split_size=[self.config.d_inner, self.config.d_inner], dim=-1) 80 | 81 | x = rearrange(x, 'b l d_in -> b d_in l') 82 | x = self.conv1d(x)[:, :, :l] 83 | x = rearrange(x, 'b d_in l -> b l d_in') 84 | 85 | x = F.silu(x) 86 | 87 | y = self.ssm(x) 88 | 89 | y = y * F.silu(res) 90 | 91 | output = self.out_proj(y) + x_copy 92 | 93 | return output 94 | 95 | 96 | def ssm(self, x): 97 | """Runs the SSM. See: 98 | - Algorithm 2 in Section 3.2 in the Mamba paper [1] 99 | - run_SSM(A, B, C, u) in The Annotated S4 [2] 100 | 101 | Args: 102 | x: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...) 103 | 104 | Returns: 105 | output: shape (b, l, d_in) 106 | 107 | Official Implementation: 108 | mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 109 | 110 | """ 111 | (d_in, n) = self.A_log.shape 112 | 113 | # Compute ∆ A B C D, the state space parameters. 
114 | # A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) 115 | # ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, 116 | # and is why Mamba is called **selective** state spaces) 117 | 118 | A = -torch.exp(self.A_log.float()) # shape (d_in, n) 119 | D = self.D.float() 120 | 121 | x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n) 122 | 123 | (delta, B, C) = x_dbl.split(split_size=[self.config.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n) 124 | delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) 125 | 126 | y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2] 127 | 128 | return y 129 | 130 | 131 | def selective_scan(self, u, delta, A, B, C, D): 132 | """Runs the selective scan algorithm. See: 133 | - Section 2 State Space Models in the Mamba paper [1] 134 | - Algorithm 2 in Section 3.2 in the Mamba paper [1] 135 | - run_SSM(A, B, C, u) in The Annotated S4 [2] 136 | 137 | This is the classic discrete state space formula: 138 | x(t + 1) = Ax(t) + Bu(t) 139 | y(t) = Cx(t) + Du(t) 140 | except B and C (and the step size delta, which is used for discretization) depend on the input u(t). 141 | 142 | Args: 143 | u: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...) 144 | delta: shape (b, l, d_in) 145 | A: shape (d_in, n) 146 | B: shape (b, l, n) 147 | C: shape (b, l, n) 148 | D: shape (d_in,) 149 | 150 | Returns: 151 | output: shape (b, l, d_in) 152 | 153 | Official Implementation: 154 | selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86 155 | Note: some parts of `selective_scan_ref` were refactored out, so the functionality doesn't match exactly. 156 | 157 | """ 158 | (b, l, d_in) = u.shape 159 | n = A.shape[1] 160 | 161 | # Discretize continuous parameters (A, B) 162 | # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1]) 163 | # - B is discretized using a simplified Euler discretization instead of ZOH. 
From a discussion with authors: 164 | # "A is the more important term and the performance doesn't change much with the simplification on B" 165 | deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) 166 | deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') 167 | 168 | # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) 169 | x = torch.zeros((b, d_in, n), device=deltaA.device) 170 | ys = [] 171 | for i in range(l): 172 | x = deltaA[:, :, i] * x + deltaB_u[:, :, i] 173 | y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') 174 | ys.append(y) 175 | y = torch.stack(ys, dim=1) # shape (b, l, d_in) 176 | 177 | y = y + u * D 178 | 179 | return y 180 | 181 | class MambaPreTrainedModel(PreTrainedModel): 182 | config_class = MambaConfig 183 | base_model_prefix = "model" 184 | supports_gradient_checkpointing = True 185 | _no_split_modules = ["MambaBlock"] 186 | 187 | def _init_weights(self, module): 188 | std = 0.02 189 | if isinstance(module, (nn.Linear, nn.Conv1d)): 190 | module.weight.data.normal_(mean=0.0, std=std) 191 | if module.bias is not None: 192 | module.bias.data.zero_() 193 | elif isinstance(module, nn.Embedding): 194 | module.weight.data.normal_(mean=0.0, std=std) 195 | if module.padding_idx is not None: 196 | module.weight.data[module.padding_idx].zero_() 197 | 198 | class MambaModel(MambaPreTrainedModel): 199 | def __init__(self, config: MambaConfig): 200 | """Full Mamba model. 201 | Mamba model decoder consisting of *config.n_layer* layers. Each layer is a [`MambaBlock`] 202 | 203 | Args: 204 | config: MambaConfig 205 | """ 206 | super().__init__(config) 207 | self.config = config 208 | 209 | self.embedding = nn.Embedding(config.vocab_size, config.d_model) 210 | self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)]) 211 | self.norm_f = MambaRMSNorm(config.d_model) 212 | 213 | self.gradient_checkpointing = False 214 | self.post_init() 215 | 216 | def get_input_embeddings(self): 217 | return self.embedding 218 | 219 | def set_input_embeddings(self, value): 220 | self.embedding = value 221 | 222 | def forward(self, 223 | input_ids: torch.LongTensor = None, 224 | return_dict: Optional[bool] = None, 225 | ) -> Union[Tuple, BaseModelOutputWithPast]: 226 | x = self.embedding(input_ids) 227 | all_hidden_states = list() 228 | for layer in self.layers: 229 | x = layer(x) 230 | all_hidden_states.append(x) 231 | 232 | hidden_states = self.norm_f(x) 233 | 234 | return BaseModelOutputWithPast( 235 | last_hidden_state=hidden_states, 236 | hidden_states=all_hidden_states, 237 | ) 238 | class MambaForCausalLM(MambaPreTrainedModel): 239 | _tied_weights_keys = ["lm_head.weight"] 240 | 241 | def __init__(self, config): 242 | super().__init__(config) 243 | self.model = MambaModel(config) 244 | self.vocab_size = config.vocab_size 245 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 246 | self.lm_head.weight = self.model.embedding.weight 247 | self.post_init() 248 | 249 | def get_input_embeddings(self): 250 | return self.model.embedding 251 | 252 | def set_input_embeddings(self, value): 253 | self.model.embedding = value 254 | 255 | def get_output_embeddings(self): 256 | return self.lm_head 257 | 258 | def set_output_embeddings(self, new_embeddings): 259 | self.lm_head = new_embeddings 260 | 261 | def set_decoder(self, decoder): 262 | self.model = decoder 263 | 264 | def get_decoder(self): 265 | return self.model 266 | 267 | def forward(self, 268 | input_ids: torch.LongTensor = None, 269 | labels: 
Optional[torch.LongTensor] = None, 270 | output_attentions: Optional[bool] = None, 271 | output_hidden_states: Optional[bool] = None, 272 | return_dict: Optional[bool] = None, 273 | ) -> Union[Tuple, CausalLMOutputWithPast]: 274 | outputs = self.model( 275 | input_ids=input_ids, 276 | return_dict=return_dict, 277 | ) 278 | hidden_states = outputs[0] 279 | logits = self.lm_head(hidden_states) 280 | logits = logits.float() 281 | loss = None 282 | if labels is not None: 283 | shift_logits = logits[..., :-1, :].contiguous() 284 | shift_labels = labels[..., 1:].contiguous() 285 | loss_fct = CrossEntropyLoss() 286 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 287 | shift_labels = shift_labels.view(-1) 288 | 289 | shift_labels = shift_labels.to(shift_logits.device) 290 | loss = loss_fct(shift_logits, shift_labels) 291 | 292 | if not return_dict: 293 | output = (logits,) + outputs[1:] 294 | return (loss,) + output if loss is not None else output 295 | 296 | return CausalLMOutputWithPast( 297 | loss=loss, 298 | logits=logits, 299 | hidden_states=outputs.hidden_states, 300 | ) 301 | 302 | def prepare_inputs_for_generation( 303 | self, input_ids, **kwargs 304 | ): 305 | model_inputs = {"input_ids": input_ids} 306 | return model_inputs 307 | 308 | 309 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | torch 3 | einops 4 | --------------------------------------------------------------------------------
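
As a quick sanity check of the classes above, the sketch below builds a small, randomly initialized model and confirms that passing `labels` to `MambaForCausalLM.forward` yields a shifted cross-entropy loss. It is illustrative only: the tiny `d_model`/`n_layer` values and the GPT-NeoX tokenizer (`EleutherAI/gpt-neox-20b`) are assumptions, not part of this repository, and it is meant to be run from the `src/` directory so the local modules import. Note that `return_dict=True` is passed explicitly, because this forward pass returns a plain tuple otherwise.

```python
# Minimal sketch (assumptions: run from src/; the GPT-NeoX tokenizer stands in
# for the real one, and a tiny random config stands in for a real checkpoint).
from transformers import AutoTokenizer

from configuration_mamba import MambaConfig
from modeling_mamba import MambaForCausalLM

# Small config for a quick CPU test; vocab_size is padded to 50280 internally.
config = MambaConfig(vocab_size=50277, d_model=256, n_layer=4)
model = MambaForCausalLM(config)

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # assumed tokenizer
batch = tokenizer(["Hello Mamba"], return_tensors="pt")

# With labels supplied, forward() shifts logits/labels by one position and
# computes a cross-entropy language-modeling loss.
out = model(input_ids=batch["input_ids"], labels=batch["input_ids"], return_dict=True)
print(out.loss)          # scalar LM loss
print(out.logits.shape)  # (batch, seq_len, config.vocab_size)
```

Because `lm_head.weight` is tied to the input embedding in `MambaForCausalLM.__init__`, the same matrix serves as both input embedding and output projection during such a training step.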