├── .gitignore ├── LICENSE ├── README.md └── ttt.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 test-time-training 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to (Learn at Test Time): RNNs with Expressive Hidden States 2 | 3 | [**Paper**](https://arxiv.org/abs/2407.04620) 4 | | [**JAX Codebase**](https://github.com/test-time-training/ttt-lm-jax) 5 | | [**Setup**](#environment-setup) 6 | | [**Quick Start**](#quick-start) 7 | | [**Inference Benchmark**](https://github.com/test-time-training/ttt-lm-kernels) 8 | 9 | This is the official PyTorch model implementation of [Learning to (Learn at Test Time): RNNs with Expressive Hidden States](https://arxiv.org/abs/2407.04620). 10 | We **do not recommend training** with this codebase, because it is written in pure PyTorch without any systems optimization, so training will be slow, especially when the per-device batch size is small. 
11 | 12 | 13 | For training code, or to replicate results from our paper, please see our [JAX codebase](https://github.com/test-time-training/ttt-lm-jax). For inference kernels, or to replicate speed benchmarks from our paper, please see our [kernel implementations](https://github.com/test-time-training/ttt-lm-kernels). 14 | 15 | ## Abstract 16 | 17 | Self-attention performs well in long context but has quadratic complexity. Existing RNN layers 18 | have linear complexity, but their performance in long context is limited by the expressive power 19 | of their hidden state. We propose a new class of sequence modeling layers with linear complexity 20 | and an expressive hidden state. The key idea is to make the hidden state a machine learning 21 | model itself, and the update rule a step of self-supervised learning. 22 | 23 | Since the hidden state is updated by training even on test sequences, our layers are called **Test-Time Training (TTT) layers**. 24 | We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model 25 | and a two-layer MLP, respectively. 26 | 27 | ## Environment Setup 28 | 29 | ```bash 30 | pip install "transformers[torch]" 31 | ``` 32 | 33 | ## Quick Start 34 | 35 | Our implementation is based on Hugging Face Transformers. Use the following code to load the model and generate text. 36 | 37 | ```python 38 | from transformers import AutoTokenizer 39 | from ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS 40 | 41 | # Initializing a TTT ttt-1b style configuration 42 | # configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following 43 | configuration = TTTConfig() 44 | 45 | # Initializing a model from the ttt-1b style configuration 46 | model = TTTForCausalLM(configuration) 47 | model.eval() 48 | 49 | # Accessing the model configuration 50 | configuration = model.config 51 | 52 | # Tokenizer 53 | tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf') 54 | 55 | # Prefill 56 | input_ids = tokenizer("Greeting from TTT!", return_tensors="pt").input_ids 57 | outputs = model(input_ids=input_ids) 58 | print(outputs.logits) 59 | 60 | # Decoding 61 | out_ids = model.generate(input_ids=input_ids, max_length=50) 62 | out_str = tokenizer.batch_decode(out_ids, skip_special_tokens=True) 63 | print(out_str) 64 | ``` 65 | 66 | **Note: This is a naive implementation of TTT layers for tutorial purposes.** This model can be trained with Hugging Face Accelerate or custom training loops. We have released our faster inference kernel and its speed benchmark [here](https://github.com/test-time-training/ttt-lm-kernels).
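As a rough illustration of the custom-loop option mentioned in the note above, here is a minimal training sketch. It assumes `TTTForCausalLM` follows the standard `transformers` causal-LM interface (a `labels` argument and a `loss` field on the output, as the imports in `ttt.py` suggest), uses the small `125m` entry from `TTT_STANDARD_CONFIGS`, and feeds random token ids as a stand-in for a real tokenized dataset; it is not an optimized training setup.

```python
import torch
from torch.optim import AdamW

from ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS

# Small configuration to keep the sketch cheap; any entry of TTT_STANDARD_CONFIGS works.
config = TTTConfig(**TTT_STANDARD_CONFIGS["125m"])
model = TTTForCausalLM(config)
model.train()

optimizer = AdamW(model.parameters(), lr=3e-4)

# Random token ids stand in for a real dataloader.
# Sequence length 128 is a multiple of the default mini_batch_size (16).
batch = torch.randint(0, config.vocab_size, (2, 128))

for step in range(5):
    # Assumes the usual Hugging Face convention: labels are shifted internally
    # and the next-token cross-entropy loss is returned on the output.
    outputs = model(input_ids=batch, labels=batch)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"step {step}: loss = {outputs.loss.item():.4f}")
```

For anything beyond a toy run, the same loop can be wrapped with Hugging Face Accelerate for device placement and mixed precision, but serious training should use the JAX codebase linked above.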
67 | -------------------------------------------------------------------------------- /ttt.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.utils.checkpoint 8 | from torch import nn 9 | from torch.nn import CrossEntropyLoss 10 | from torch.utils._pytree import tree_map 11 | 12 | from transformers import PretrainedConfig 13 | from transformers.activations import ACT2FN 14 | from transformers.modeling_outputs import ( 15 | BaseModelOutputWithPast, 16 | CausalLMOutputWithPast, 17 | ) 18 | from transformers.modeling_utils import PreTrainedModel 19 | from transformers.utils import ModelOutput, logging 20 | from transformers.utils.import_utils import is_causal_conv1d_available 21 | 22 | if is_causal_conv1d_available(): 23 | from causal_conv1d import causal_conv1d_fn, causal_conv1d_update 24 | else: 25 | causal_conv1d_update, causal_conv1d_fn = None, None 26 | 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | TTT_STANDARD_CONFIGS = { 31 | "125m": { 32 | "hidden_size": 768, 33 | "intermediate_size": 2048, 34 | "num_hidden_layers": 12, 35 | "num_attention_heads": 12, 36 | }, 37 | "350m": { 38 | "hidden_size": 1024, 39 | "intermediate_size": 2736, 40 | "num_hidden_layers": 24, 41 | "num_attention_heads": 16, 42 | }, 43 | "760m": { 44 | "hidden_size": 1536, 45 | "intermediate_size": 4096, 46 | "num_hidden_layers": 24, 47 | "num_attention_heads": 16, 48 | }, 49 | "1b": { 50 | "hidden_size": 2048, 51 | "intermediate_size": 5504, 52 | "num_hidden_layers": 24, 53 | "num_attention_heads": 32, 54 | }, 55 | } 56 | 57 | 58 | class TTTConfig(PretrainedConfig): 59 | r""" 60 | This is the configuration class to store the configuration of a [`TTTModel`]. It is used to instantiate an TTT 61 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 62 | defaults will yield a similar configuration to that of the TTT-1B. 63 | 64 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 65 | documentation from [`PretrainedConfig`] for more information. 66 | 67 | 68 | Args: 69 | vocab_size (`int`, *optional*, defaults to 32000): 70 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the 71 | `inputs_ids` passed when calling [`LlamaModel`] 72 | hidden_size (`int`, *optional*, defaults to 4096): 73 | Dimension of the hidden representations. 74 | intermediate_size (`int`, *optional*, defaults to 11008): 75 | Dimension of the MLP representations. 76 | num_hidden_layers (`int`, *optional*, defaults to 32): 77 | Number of hidden layers in the Transformer decoder. 78 | num_attention_heads (`int`, *optional*, defaults to 32): 79 | Number of attention heads for each attention layer in the Transformer decoder. 80 | num_key_value_heads (`int`, *optional*): 81 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 82 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 83 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When 84 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 85 | by meanpooling all the original heads within that group. For more details checkout [this 86 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 87 | `num_attention_heads`. 88 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 89 | The non-linear activation function (function or string) in the decoder. 90 | max_position_embeddings (`int`, *optional*, defaults to 2048): 91 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, 92 | Llama 2 up to 4096, CodeLlama up to 16384. 93 | initializer_range (`float`, *optional*, defaults to 0.02): 94 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 95 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 96 | The epsilon used by the rms normalization layers. 97 | use_cache (`bool`, *optional*, defaults to `True`): 98 | Whether or not the model should return the last key/values attentions (not used by all models). Only 99 | relevant if `config.is_decoder=True`. 100 | pad_token_id (`int`, *optional*): 101 | Padding token id. 102 | bos_token_id (`int`, *optional*, defaults to 1): 103 | Beginning of stream token id. 104 | eos_token_id (`int`, *optional*, defaults to 2): 105 | End of stream token id. 106 | pretraining_tp (`int`, *optional*, defaults to 1): 107 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 108 | document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is 109 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 110 | issue](https://github.com/pytorch/pytorch/issues/76232). 111 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 112 | Whether to tie weight embeddings 113 | rope_theta (`float`, *optional*, defaults to 10000.0): 114 | The base period of the RoPE embeddings. 115 | rope_scaling (`Dict`, *optional*): 116 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 117 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is 118 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 119 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 120 | these scaling strategies behave: 121 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 122 | experimental feature, subject to breaking API changes in future versions. 123 | use_gate (`bool`, *optional*, defaults to `False`): whether use gating in Mamba backbone 124 | share_qk (`bool`, *optional*, defaults to `False`): whether share Q/K projection matrix 125 | ttt_layer_type (`str`, *optional*, defaults to `"linear"`): ttt block type, "linear" or "mlp", stands for TTT-Linear and TTT-MLP 126 | ttt_base_lr (`float`, *optional*, defaults to 1.0): base learning rate for TTT learner 127 | pre_conv (`bool`, *optional*, defaults to `False`): whether use conv before TTT 128 | conv_kernel (`int`, *optional*, defaults to 4): kernel size of the conv layer 129 | scan_checkpoint_group_size (`int`, *optional*, defaults to 0): 130 | gradient checkpoint group size on seq dimension, 0 means no checkpointing. 
131 | In JAX implementation, we set it 4, which means we group 4 mini-batches together in 1 gradient checkpointg to save memory. 132 | 133 | 134 | ```python 135 | >>> from . import TTTModel, TTTConfig 136 | 137 | >>> # Initializing a TTT ttt-1b style configuration 138 | >>> configuration = TTTConfig() 139 | 140 | >>> # Initializing a model from the ttt-1b style configuration 141 | >>> model = TTTModel(configuration) 142 | 143 | >>> # Accessing the model configuration 144 | >>> configuration = model.config 145 | ```""" 146 | 147 | model_type = "ttt" 148 | 149 | def __init__( 150 | self, 151 | vocab_size=32000, 152 | hidden_size=2048, 153 | intermediate_size=5504, 154 | num_hidden_layers=24, 155 | num_attention_heads=32, 156 | hidden_act="silu", 157 | max_position_embeddings=2048, 158 | initializer_range=0.02, 159 | rms_norm_eps=1e-6, 160 | use_cache=False, 161 | pad_token_id=None, 162 | bos_token_id=1, 163 | eos_token_id=2, 164 | pretraining_tp=1, 165 | tie_word_embeddings=True, 166 | rope_theta=10000.0, 167 | use_gate=False, 168 | share_qk=False, 169 | ttt_layer_type="linear", 170 | ttt_base_lr=1.0, 171 | mini_batch_size=16, 172 | pre_conv=False, 173 | conv_kernel=4, 174 | scan_checkpoint_group_size=0, 175 | **kwargs, 176 | ): 177 | self.vocab_size = vocab_size 178 | self.max_position_embeddings = max_position_embeddings 179 | self.hidden_size = hidden_size 180 | self.intermediate_size = intermediate_size 181 | self.num_hidden_layers = num_hidden_layers 182 | self.num_attention_heads = num_attention_heads 183 | 184 | self.hidden_act = hidden_act 185 | self.initializer_range = initializer_range 186 | self.rms_norm_eps = rms_norm_eps 187 | self.pretraining_tp = pretraining_tp 188 | self.use_cache = use_cache 189 | self.rope_theta = rope_theta 190 | 191 | self.use_gate = use_gate 192 | self.share_qk = share_qk 193 | self.ttt_layer_type = ttt_layer_type 194 | self.ttt_base_lr = ttt_base_lr 195 | self.mini_batch_size = mini_batch_size 196 | 197 | self.pre_conv = pre_conv 198 | self.conv_kernel = conv_kernel 199 | self.scan_checkpoint_group_size = scan_checkpoint_group_size 200 | 201 | super().__init__( 202 | pad_token_id=pad_token_id, 203 | bos_token_id=bos_token_id, 204 | eos_token_id=eos_token_id, 205 | tie_word_embeddings=tie_word_embeddings, 206 | **kwargs, 207 | ) 208 | 209 | 210 | ######################## 211 | ### Backbone Modules ### 212 | ######################## 213 | 214 | 215 | def rotate_half(x): 216 | """Rotates half the hidden dims of the input.""" 217 | x1 = x[..., : x.shape[-1] // 2] 218 | x2 = x[..., x.shape[-1] // 2 :] 219 | return torch.cat((-x2, x1), dim=-1) 220 | 221 | 222 | def permute_qk(q, k): 223 | # NOTE: EasyLM and transformers use different method to compute rotary emebdding 224 | # we manually reorder the dim here to match our JAX implementation 225 | # which may not be optimal for speed 226 | # reference: https://github.com/young-geng/EasyLM/blob/981a2ed9630f44258a94b6f44dff2b7bd203ae8d/EasyLM/models/llama/convert_hf_to_easylm.py#L33 227 | bsz, num_head, seq_len, head_dim = q.shape 228 | q = q.reshape(bsz, num_head, seq_len, head_dim // 2, 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim) 229 | k = k.reshape(bsz, num_head, seq_len, head_dim // 2, 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim) 230 | 231 | return q, k 232 | 233 | 234 | def undo_permute_qk(q, k): 235 | # NOTE: EasyLM and transformers use different method to compute rotary emebdding 236 | # we manually undo the reorder the dim here to match our JAX implementation 237 | # which 
may not be optimal for speed 238 | # reference: https://github.com/young-geng/EasyLM/blob/981a2ed9630f44258a94b6f44dff2b7bd203ae8d/EasyLM/models/llama/convert_hf_to_easylm.py#L33 239 | bsz, num_head, seq_len, head_dim = q.shape 240 | q = q.reshape(bsz, num_head, seq_len, 2, head_dim // 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim) 241 | k = k.reshape(bsz, num_head, seq_len, 2, head_dim // 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim) 242 | 243 | return q, k 244 | 245 | 246 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 247 | """Applies Rotary Position Embedding to the query and key tensors. 248 | 249 | Args: 250 | q (`torch.Tensor`): The query tensor. 251 | k (`torch.Tensor`): The key tensor. 252 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 253 | sin (`torch.Tensor`): The sine part of the rotary embedding. 254 | position_ids (`torch.Tensor`, *optional*): 255 | Deprecated and unused. 256 | unsqueeze_dim (`int`, *optional*, defaults to 1): 257 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 258 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 259 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 260 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 261 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 262 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 263 | Returns: 264 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
265 | """ 266 | cos = cos.unsqueeze(unsqueeze_dim) 267 | sin = sin.unsqueeze(unsqueeze_dim) 268 | q_embed = (q * cos) + (rotate_half(q) * sin) 269 | k_embed = (k * cos) + (rotate_half(k) * sin) 270 | return q_embed, k_embed 271 | 272 | 273 | class RMSNorm(nn.Module): 274 | def __init__(self, hidden_size, eps=1e-6): 275 | super().__init__() 276 | self.weight = nn.Parameter(torch.ones(hidden_size)) 277 | self.variance_epsilon = eps 278 | 279 | def forward(self, hidden_states): 280 | input_dtype = hidden_states.dtype 281 | hidden_states = hidden_states.to(torch.float32) 282 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 283 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 284 | return self.weight * hidden_states.to(input_dtype) 285 | 286 | 287 | class SwiGluMLP(nn.Module): 288 | def __init__(self, config): 289 | super().__init__() 290 | self.config = config 291 | self.hidden_size = config.hidden_size 292 | self.intermediate_size = config.intermediate_size 293 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 294 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 295 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 296 | self.act_fn = ACT2FN[config.hidden_act] 297 | 298 | def forward(self, x): 299 | if self.config.pretraining_tp > 1: 300 | slice = self.intermediate_size // self.config.pretraining_tp 301 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 302 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 303 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 304 | 305 | gate_proj = torch.cat( 306 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], 307 | dim=-1, 308 | ) 309 | up_proj = torch.cat( 310 | [F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], 311 | dim=-1, 312 | ) 313 | 314 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 315 | down_proj = [ 316 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) 317 | ] 318 | down_proj = sum(down_proj) 319 | else: 320 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 321 | 322 | return down_proj 323 | 324 | 325 | class RotaryEmbedding(nn.Module): 326 | def __init__( 327 | self, 328 | dim, 329 | max_position_embeddings=16, 330 | base=10000, 331 | device=None, 332 | scaling_factor=1.0, 333 | ): 334 | super().__init__() 335 | self.scaling_factor = scaling_factor 336 | self.dim = dim 337 | self.max_position_embeddings = max_position_embeddings 338 | self.base = base 339 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) 340 | self.register_buffer("inv_freq", inv_freq, persistent=False) 341 | 342 | @torch.no_grad() 343 | def forward(self, x, position_ids): 344 | # x: [bs, num_attention_heads, seq_len, head_size] 345 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) 346 | position_ids_expanded = position_ids[:, None, :].float() 347 | # Force float32 since bfloat16 loses precision on long contexts 348 | # See https://github.com/huggingface/transformers/pull/29285 349 | device_type = x.device.type 350 | device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" 351 | with torch.autocast(device_type=device_type, enabled=False): 352 | freqs = (inv_freq_expanded.float() @ 
position_ids_expanded.float()).transpose(1, 2) 353 | emb = torch.cat((freqs, freqs), dim=-1) 354 | cos = emb.cos() 355 | sin = emb.sin() 356 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 357 | 358 | 359 | class Conv(nn.Module): 360 | def __init__(self, config, layer_idx): 361 | super().__init__() 362 | self.config = config 363 | self.layer_idx = layer_idx 364 | 365 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 366 | self.conv = nn.Conv1d( 367 | config.hidden_size, 368 | config.hidden_size, 369 | bias=True, 370 | kernel_size=config.conv_kernel, 371 | groups=config.hidden_size, 372 | padding=config.conv_kernel - 1, 373 | ) 374 | 375 | def __call__(self, hidden_states, cache_params=None): 376 | seq_len = hidden_states.shape[1] 377 | hidden_states = self.norm(hidden_states) 378 | # [B, C, L] 379 | hidden_states = hidden_states.transpose(1, 2) 380 | 381 | if causal_conv1d_fn is None: 382 | if cache_params is not None: 383 | if cache_params.seqlen_offset > 0: 384 | conv_state = cache_params.conv_states_dic["pre_conv"][self.layer_idx] 385 | conv_state = torch.roll(conv_state, shifts=-1, dims=-1) 386 | conv_state[:, :, -1] = hidden_states[:, :, 0] 387 | cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_state) 388 | hidden_states = torch.sum(conv_state * self.conv.weight[:, 0, :], dim=-1) 389 | hidden_states += self.conv.bias 390 | hidden_states = hidden_states.unsqueeze(-1) 391 | else: 392 | conv_state = nn.functional.pad( 393 | hidden_states, 394 | (self.config.conv_kernel - hidden_states.shape[-1], 0), 395 | ) 396 | cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_state) 397 | hidden_states = self.conv(hidden_states)[..., :seq_len] 398 | else: 399 | hidden_states = self.conv(hidden_states)[..., :seq_len] 400 | else: 401 | conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) 402 | if cache_params is not None and cache_params.seqlen_offset > 0: 403 | hidden_states = causal_conv1d_update( 404 | hidden_states.squeeze(-1), 405 | cache_params.conv_states_dic["pre_conv"][self.layer_idx], 406 | conv_weights, 407 | self.conv.bias, 408 | None, 409 | ) 410 | hidden_states = hidden_states.unsqueeze(-1) 411 | else: 412 | if cache_params is not None: 413 | conv_states = nn.functional.pad( 414 | hidden_states, 415 | (self.config.conv_kernel - hidden_states.shape[-1], 0), 416 | ) 417 | cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_states) 418 | hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv.bias, activation=None) 419 | 420 | # [B, L, C] 421 | hidden_states = hidden_states.transpose(1, 2) 422 | 423 | return hidden_states 424 | 425 | 426 | ######################### 427 | ### TTT Layer Modules ### 428 | ######################### 429 | 430 | 431 | def scan(f, init, xs, out, checkpoint_group=0): 432 | """Minic jax.lax.scan function.""" 433 | carry = init 434 | if isinstance(xs, dict): 435 | num_items = len(next(iter(xs.values()))) 436 | else: 437 | num_items = len(xs[0]) 438 | 439 | def scan_fn(carry, i_start, i_end): 440 | for i in range(i_start, i_end): 441 | if isinstance(xs, dict): 442 | x = {key: tensor[i] for key, tensor in xs.items()} 443 | else: 444 | x = [x[i] for x in xs] 445 | carry, y = f(carry, x) 446 | out[i] = y 447 | return carry 448 | 449 | if checkpoint_group > 0: 450 | ckpt_every_n = num_items // checkpoint_group 451 | for k in range(0, num_items, ckpt_every_n): 452 | carry = torch.utils.checkpoint.checkpoint( 453 | scan_fn, carry, k, min(k + 
ckpt_every_n, num_items), use_reentrant=False 454 | ) 455 | else: 456 | carry = scan_fn(carry, 0, num_items) 457 | 458 | return carry, out 459 | 460 | 461 | def ln_fwd(x, gamma, beta, eps=1e-6): 462 | "Batch forward for LayerNorm." 463 | 464 | # Mean and variance computation 465 | mu = x.mean(dim=-1, keepdim=True) 466 | var = x.var(dim=-1, keepdim=True, unbiased=False) 467 | 468 | # Normalization 469 | std = torch.sqrt(var + eps) 470 | x_hat = (x - mu) / std 471 | 472 | # Scale and shift 473 | y = gamma * x_hat + beta 474 | 475 | return y 476 | 477 | 478 | def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6): 479 | "Batch backward for LayerNorm fused with L2 loss." 480 | D = x.shape[-1] 481 | 482 | # Mean and variance computation 483 | mu = x.mean(dim=-1, keepdim=True) 484 | var = x.var(dim=-1, keepdim=True, unbiased=False) 485 | 486 | # Normalization 487 | std = torch.sqrt(var + eps) 488 | x_hat = (x - mu) / std 489 | 490 | # Scale and shift 491 | y = gamma * x_hat + beta 492 | 493 | grad_output = y - l2_target 494 | grad_x_hat = grad_output * gamma 495 | z = ( 496 | (1.0 / D) 497 | * ( 498 | D * grad_x_hat 499 | - grad_x_hat.sum(dim=-1, keepdim=True) 500 | - x_hat * (grad_x_hat * x_hat).sum(dim=-1, keepdim=True) 501 | ) 502 | / std 503 | ) 504 | 505 | return z 506 | 507 | 508 | # Modified from https://github.com/NVIDIA/Megatron-LM/blob/e33c8f78a35765d5aa37475a144da60e8a2349d1/megatron/core/fusions/fused_bias_gelu.py#L26 509 | def gelu_bwd(x): 510 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 511 | ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) 512 | return ff 513 | 514 | 515 | class TTTCache: 516 | """ 517 | TTTCache is a data structure that holds the last hidden states and gradients for the TTT layer. 518 | 519 | Arguments: 520 | model: TTTModel 521 | batch_size: int 522 | 523 | Attributes: 524 | seqlen_offset: int 525 | mini_batch_size: int 526 | params_dict: Dict[str, Dict[int, torch.Tensor]] *_states, *_grad -> # layer_idx -> [batch_size, ...] 527 | conv_states_dic: Dict[str, Dict[int, torch.Tensor]] *_states -> # layer_idx -> [batch_size, ...] 
528 | 529 | """ 530 | 531 | def __init__(self, model, batch_size: int): 532 | config = model.config 533 | self.seqlen_offset = 0 534 | self.mini_batch_size = config.mini_batch_size 535 | 536 | self.ttt_params_dict = defaultdict(dict) 537 | if "linear" in config.ttt_layer_type: 538 | self.ttt_param_names = ["W1", "b1"] 539 | elif "mlp" in config.ttt_layer_type: 540 | self.ttt_param_names = ["W1", "b1", "W2", "b2"] 541 | else: 542 | raise ValueError(f"TTT Layer Type {config.ttt_layer_type} not supported yet") 543 | 544 | self.conv_states_dic = defaultdict(dict) 545 | logger.info(f"Creating cache of size: {batch_size}") 546 | for layer_idx in range(config.num_hidden_layers): 547 | for name in self.ttt_param_names: 548 | weight = getattr(model.layers[layer_idx].seq_modeling_block, name) 549 | tiled_weight = torch.tile(weight.unsqueeze(0), (batch_size,) + (1,) * weight.dim()).to(model.device) 550 | self.ttt_params_dict[f"{name}_states"][layer_idx] = tiled_weight 551 | # for decoding, we need to store the gradients as well 552 | self.ttt_params_dict[f"{name}_grad"][layer_idx] = torch.zeros_like(tiled_weight) 553 | 554 | if config.pre_conv: 555 | self.conv_states_dic["pre_conv"][layer_idx] = torch.zeros( 556 | batch_size, 557 | config.hidden_size, 558 | config.conv_kernel, 559 | device=model.device, 560 | ) 561 | if config.share_qk: 562 | self.conv_states_dic["ttt_conv_q"][layer_idx] = torch.zeros( 563 | batch_size, 564 | config.hidden_size, 565 | config.conv_kernel, 566 | device=model.device, 567 | ) 568 | self.conv_states_dic["ttt_conv_k"][layer_idx] = torch.zeros( 569 | batch_size, 570 | config.hidden_size, 571 | config.conv_kernel, 572 | device=model.device, 573 | ) 574 | 575 | def update(self, py_tree, layer_idx, seq_len): 576 | if seq_len % self.mini_batch_size == 0: 577 | # copy last mini-batch states, clear gradients 578 | for name in self.ttt_param_names: 579 | self.ttt_params_dict[f"{name}_states"][layer_idx].copy_(py_tree[f"{name}_states"]) 580 | self.ttt_params_dict[f"{name}_grad"][layer_idx].zero_() 581 | elif seq_len < self.mini_batch_size: 582 | if seq_len != 1 and self.seqlen_offset > 0 and self.seqlen_offset % self.mini_batch_size != 0: 583 | raise ValueError("fractional update not supported yet.") 584 | if (seq_len + self.seqlen_offset) % self.mini_batch_size == 0: 585 | # copy last mini-batch states, clear gradients 586 | for name in self.ttt_param_names: 587 | self.ttt_params_dict[f"{name}_states"][layer_idx].copy_(py_tree[f"{name}_states"]) 588 | self.ttt_params_dict[f"{name}_grad"][layer_idx].zero_() 589 | else: 590 | # copy gradients for the next update 591 | for name in self.ttt_param_names: 592 | self.ttt_params_dict[f"{name}_grad"][layer_idx].copy_(py_tree[f"{name}_grad"]) 593 | else: 594 | raise ValueError(f"seq_len {seq_len} is a partial update not supported yet") 595 | 596 | def ttt_params_to_dict(self, layer_idx): 597 | return {name: self.ttt_params_dict[name][layer_idx] for name in self.ttt_params_dict} 598 | 599 | 600 | class TTTBase(nn.Module): 601 | def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None): 602 | super().__init__() 603 | self.config = config 604 | self.layer_idx = layer_idx 605 | if layer_idx is None: 606 | logger.warning_once( 607 | f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " 608 | "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " 609 | "when creating this class." 
610 | ) 611 | 612 | self.width = config.hidden_size 613 | self.hidden_size = config.hidden_size 614 | self.num_heads = config.num_attention_heads 615 | self.head_dim = self.width // self.num_heads 616 | self.mini_batch_size = config.mini_batch_size 617 | 618 | # token_idx is a scale factor that scale the summation in Eqn. 4 619 | token_idx = 1.0 / torch.arange(1, self.mini_batch_size + 1) 620 | self.register_buffer("token_idx", token_idx, persistent=False) 621 | # make the scale factor learnable 622 | self.learnable_token_idx = nn.Parameter(torch.zeros((self.mini_batch_size,))) 623 | 624 | self.share_qk = config.share_qk 625 | self.conv_kernel = config.conv_kernel 626 | self._init_qkvo_proj() 627 | self._init_rope() 628 | # Learnable eta in Sec. 2.7 629 | self._init_ttt_lr_gate() 630 | self._init_ttt_ln() 631 | 632 | # use gating as in Mamba backbone 633 | self.use_gate = config.use_gate 634 | if self.use_gate: 635 | self.g_proj = nn.Linear(self.width, self.width, bias=False) 636 | 637 | self.post_norm = nn.LayerNorm(self.width, eps=1e-6) 638 | 639 | def _init_qkvo_proj(self): 640 | self.q_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False) 641 | # we share Q/K projection when using Mamba backbone 642 | if not self.share_qk: 643 | self.k_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False) 644 | self.v_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False) 645 | self.o_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False) 646 | 647 | # after share Q/K projection, we use different conv layers for Q and K 648 | if self.share_qk: 649 | self.conv_q = nn.Conv1d( 650 | self.hidden_size, 651 | self.hidden_size, 652 | bias=True, 653 | kernel_size=self.conv_kernel, 654 | groups=self.hidden_size, 655 | padding=self.conv_kernel - 1, 656 | ) 657 | self.conv_k = nn.Conv1d( 658 | self.hidden_size, 659 | self.hidden_size, 660 | bias=True, 661 | kernel_size=self.conv_kernel, 662 | groups=self.hidden_size, 663 | padding=self.conv_kernel - 1, 664 | ) 665 | 666 | def _init_rope(self): 667 | self.rope_theta = self.config.rope_theta 668 | self.rotary_emb = RotaryEmbedding( 669 | self.head_dim, 670 | max_position_embeddings=self.mini_batch_size, 671 | base=self.rope_theta, 672 | ) 673 | 674 | def _init_ttt_lr_gate(self): 675 | # [width, 1] 676 | linear_weight_data = nn.Linear(self.width, 1, bias=True).weight.data 677 | # prepending head dim -> [num_heads, width, 1] 678 | self.learnable_ttt_lr_weight = nn.Parameter( 679 | torch.stack( 680 | [torch.normal(0, 0.02, size=linear_weight_data.shape) for _ in range(self.num_heads)], 681 | dim=0, 682 | ) 683 | ) 684 | linear_bias_data = nn.Linear(self.width, 1, bias=True).bias.data 685 | # init bias to 0 following original JAX impl. 
686 | # [num_heads, 1] 687 | self.learnable_ttt_lr_bias = nn.Parameter( 688 | torch.stack( 689 | [torch.zeros_like(linear_bias_data) for _ in range(self.num_heads)], 690 | dim=0, 691 | ) 692 | ) 693 | 694 | def _init_ttt_ln(self): 695 | ln_weight_data = nn.LayerNorm(self.head_dim).weight.data 696 | # prepending head dim -> [num_heads, width] 697 | self.ttt_norm_weight = nn.Parameter(torch.tile(ln_weight_data.unsqueeze(0), (self.num_heads, 1))) 698 | ln_bias_data = nn.LayerNorm(self.head_dim).bias.data 699 | self.ttt_norm_bias = nn.Parameter(torch.tile(ln_bias_data.unsqueeze(0), (self.num_heads, 1))) 700 | 701 | def get_qkv_projections(self, hidden_states, cache_params: Optional[TTTCache] = None): 702 | if self.share_qk: 703 | xq, XV = self.q_proj(hidden_states), self.v_proj(hidden_states) 704 | seq_len = xq.shape[1] 705 | xq = xq.transpose(1, 2) 706 | if causal_conv1d_fn is None: 707 | if cache_params is not None: 708 | if cache_params.seqlen_offset > 0: 709 | conv_q_state = cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx] 710 | conv_q_state = torch.roll(conv_q_state, shifts=-1, dims=-1) 711 | conv_q_state[:, :, -1] = xq[:, :, 0] 712 | cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state) 713 | XQ = torch.sum(conv_q_state * self.conv_q.weight[:, 0, :], dim=-1) 714 | XQ += self.conv_q.bias 715 | XQ = XQ.unsqueeze(-1) 716 | 717 | conv_k_state = cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx] 718 | conv_k_state = torch.roll(conv_k_state, shifts=-1, dims=-1) 719 | conv_k_state[:, :, -1] = xq[:, :, 0] 720 | cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state) 721 | XK = torch.sum(conv_k_state * self.conv_k.weight[:, 0, :], dim=-1) 722 | XK += self.conv_k.bias 723 | XK = XK.unsqueeze(-1) 724 | else: 725 | conv_q_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0)) 726 | cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state) 727 | XQ = self.conv_q(xq)[..., :seq_len] 728 | conv_k_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0)) 729 | cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state) 730 | XK = self.conv_k(xq)[..., :seq_len] 731 | else: 732 | XQ = self.conv_q(xq)[..., :seq_len] 733 | XK = self.conv_k(xq)[..., :seq_len] 734 | else: 735 | conv_q_weights = self.conv_q.weight.view(self.conv_q.weight.size(0), self.conv_q.weight.size(2)) 736 | conv_k_weights = self.conv_k.weight.view(self.conv_k.weight.size(0), self.conv_k.weight.size(2)) 737 | if cache_params is not None and cache_params.seqlen_offset > 0: 738 | XQ = causal_conv1d_update( 739 | xq.squeeze(-1), 740 | cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx], 741 | conv_q_weights, 742 | self.conv_q.bias, 743 | None, 744 | ) 745 | XQ = XQ.unsqueeze(-1) 746 | XK = causal_conv1d_update( 747 | xq.squeeze(-1), 748 | cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx], 749 | conv_k_weights, 750 | self.conv_k.bias, 751 | None, 752 | ) 753 | XK = XK.unsqueeze(-1) 754 | else: 755 | if cache_params is not None: 756 | conv_q_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0)) 757 | cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_states) 758 | conv_k_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0)) 759 | cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_states) 760 | XQ = causal_conv1d_fn(xq, conv_q_weights, self.conv_q.bias, activation=None) 761 | XK = causal_conv1d_fn(xq, 
conv_k_weights, self.conv_k.bias, activation=None) 762 | 763 | XQ = XQ.transpose(1, 2) 764 | XK = XK.transpose(1, 2) 765 | else: 766 | XQ, XK, XV = ( 767 | self.q_proj(hidden_states), 768 | self.k_proj(hidden_states), 769 | self.v_proj(hidden_states), 770 | ) 771 | return XQ, XK, XV 772 | 773 | def _split_heads(self, hidden_states): 774 | return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) 775 | 776 | def get_eta(self, X, mini_batch_step_offset, mini_batch_size): 777 | # [B, num_heads, num_mini_batch, mini_batch_size, 1] 778 | ttt_lr = torch.einsum("bnkc,hdc->bhnkd", X, self.learnable_ttt_lr_weight) + self.learnable_ttt_lr_bias.reshape( 779 | 1, -1, 1, 1, 1 780 | ) 781 | ttt_lr = F.sigmoid(ttt_lr) 782 | 783 | # [B, num_heads, num_mini_batch, 1, mini_batch_size] 784 | ttt_lr = ttt_lr.permute(0, 1, 2, 4, 3) 785 | ttt_lr_eta = self.config.ttt_base_lr * ttt_lr / self.head_dim 786 | 787 | # [B, L] 788 | token_idx = self.token_idx + self.learnable_token_idx 789 | token_idx = token_idx[mini_batch_step_offset : mini_batch_step_offset + mini_batch_size] 790 | 791 | # token idx should be greast than 0 792 | token_idx = torch.clamp_min(token_idx, 0.0) 793 | 794 | # NOTE: token_eta is a scale factor that applies to each token in the mini-batch 795 | # [B, num_heads, num_mini_batch, mini_batch_size, 1] 796 | token_eta = torch.broadcast_to( 797 | token_idx.reshape(1, 1, 1, mini_batch_size, 1), 798 | (X.shape[0], self.num_heads, X.shape[1], mini_batch_size, 1), 799 | ) 800 | 801 | return token_eta, ttt_lr_eta 802 | 803 | def apply_gate(self, hidden_states, ttt_output): 804 | y = self.g_proj(hidden_states) 805 | # use 'tanh' approximation for matching JAX impl. 806 | y = F.gelu(y, approximate="tanh") 807 | output = y * ttt_output 808 | return output 809 | 810 | def get_ttt_inputs(self, inputs, mini_batch_size, cache_params): 811 | XQ = inputs["XQ"] 812 | XK = inputs["XK"] 813 | XV = inputs["XV"] 814 | X = inputs["X"] 815 | B, L, C = X.shape 816 | num_mini_batch = L // mini_batch_size 817 | # [B ,num_mini_batch, mini_batch_size, C] 818 | X = X.reshape(B, num_mini_batch, mini_batch_size, self.width) 819 | 820 | XQ = XQ.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim) 821 | XK = XK.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim) 822 | XV = XV.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim) 823 | 824 | if cache_params is not None: 825 | mini_batch_step_offset = cache_params.seqlen_offset % self.mini_batch_size 826 | else: 827 | mini_batch_step_offset = 0 828 | token_eta, ttt_lr_eta = self.get_eta(X, mini_batch_step_offset, mini_batch_size) 829 | eta = token_eta * ttt_lr_eta 830 | # decouple token_coeff and ilr_coeff for decoding 831 | inputs = { 832 | "XQ": XQ, 833 | "XK": XK, 834 | "XV": XV, 835 | "eta": eta, 836 | "token_eta": token_eta, 837 | "ttt_lr_eta": ttt_lr_eta, 838 | } 839 | return inputs 840 | 841 | def ttt( 842 | self, 843 | inputs, 844 | mini_batch_size, 845 | last_mini_batch_params_dict, 846 | cache_params: Optional[TTTCache] = None, 847 | ): 848 | raise NotImplementedError("ttt method must be implemented in TTTBase subclasses.") 849 | 850 | def forward( 851 | self, 852 | hidden_states: torch.Tensor, 853 | attention_mask: Optional[torch.Tensor] = None, 854 | position_ids: Optional[torch.LongTensor] = None, 855 | cache_params: Optional[TTTCache] = None, 856 | ): 857 | B, L = hidden_states.shape[:2] 858 | reminder_len = L % self.mini_batch_size 859 | num_mini_batch 
= L // self.mini_batch_size 860 | last_mini_batch_params_dict = None 861 | 862 | XQ, XK, XV = self.get_qkv_projections(hidden_states, cache_params=cache_params) 863 | 864 | # [B, L, C] -> [B, L, num_heads, head_dim] -> [B, num_heads, L, head_dim] 865 | XQ = XQ.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2) 866 | XK = XK.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2) 867 | XV = XV.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2) 868 | 869 | cos, sin = self.rotary_emb(XV, position_ids % self.mini_batch_size) 870 | 871 | # permute_qk and undo_permute_qk is just for aligning pytorch with jax pre-training 872 | XQ, XK = permute_qk(XQ, XK) 873 | XQ, XK = apply_rotary_pos_emb(XQ, XK, cos, sin) 874 | XQ, XK = undo_permute_qk(XQ, XK) 875 | 876 | output_hidden_states = [] 877 | # when input sequence length is not a multiple of mini_batch_size 878 | # we need to compute them seperately, when computing the reminder, 879 | # we will need the last_mini_batch_params_dict to continue TTT learning 880 | if num_mini_batch > 0: 881 | inputs = { 882 | "XQ": XQ[:, :, : num_mini_batch * self.mini_batch_size], 883 | "XK": XK[:, :, : num_mini_batch * self.mini_batch_size], 884 | "XV": XV[:, :, : num_mini_batch * self.mini_batch_size], 885 | "X": hidden_states[:, : num_mini_batch * self.mini_batch_size], 886 | } 887 | output_mod, last_mini_batch_params_dict = self.ttt( 888 | self.get_ttt_inputs(inputs, self.mini_batch_size, cache_params), 889 | mini_batch_size=self.mini_batch_size, 890 | last_mini_batch_params_dict=last_mini_batch_params_dict, 891 | cache_params=cache_params, 892 | ) 893 | output_hidden_states.append(output_mod) 894 | if reminder_len > 0: 895 | inputs = { 896 | "XQ": XQ[:, :, -reminder_len:], 897 | "XK": XK[:, :, -reminder_len:], 898 | "XV": XV[:, :, -reminder_len:], 899 | "X": hidden_states[:, -reminder_len:], 900 | } 901 | output_reminder, _ = self.ttt( 902 | self.get_ttt_inputs(inputs, reminder_len, cache_params), 903 | mini_batch_size=reminder_len, 904 | last_mini_batch_params_dict=last_mini_batch_params_dict, 905 | cache_params=cache_params, 906 | ) 907 | output_hidden_states.append(output_reminder) 908 | 909 | output_hidden_states = torch.cat(output_hidden_states, dim=1) 910 | output_hidden_states = self.post_norm(output_hidden_states) 911 | if self.use_gate: 912 | output_hidden_states = self.apply_gate(hidden_states, output_hidden_states) 913 | output_hidden_states = self.o_proj(output_hidden_states) 914 | 915 | return output_hidden_states 916 | 917 | 918 | class TTTLinear(TTTBase): 919 | def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None): 920 | super().__init__(config, layer_idx) 921 | # TTT model initialization for TTT-Linear 922 | self.W1 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, self.head_dim, self.head_dim))) 923 | self.b1 = nn.Parameter(torch.zeros(self.num_heads, 1, self.head_dim)) 924 | 925 | def ttt( 926 | self, 927 | inputs, 928 | mini_batch_size, 929 | last_mini_batch_params_dict, 930 | cache_params: Optional[TTTCache] = None, 931 | ): 932 | if mini_batch_size is None: 933 | mini_batch_size = self.mini_batch_size 934 | 935 | # in this case, we are decoding 936 | if last_mini_batch_params_dict is None and cache_params is not None: 937 | last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx) 938 | 939 | # [B, num_heads, num_mini_batch, mini_batch_size, head_dim] 940 | B = inputs["XV"].shape[0] 941 | num_mini_batch = inputs["XV"].shape[2] 942 | L = inputs["XV"].shape[2] * 
inputs["XV"].shape[3] 943 | device = inputs["XV"].device 944 | dtype = inputs["XV"].dtype 945 | 946 | # NOTE: 947 | # for prefilling, we will always use dual form for faster computation 948 | # we need to use primal form if mini_batch_size is not a multiple of self.mini_batch_size 949 | # since we need store the gradient for the next mini-batch computation 950 | use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0 951 | 952 | def compute_mini_batch(params_dict, inputs): 953 | # [B, nh, f, f], nh=num_heads, f=head_dim 954 | W1_init = params_dict["W1_states"] 955 | # [B, nh, 1, f] 956 | b1_init = params_dict["b1_states"] 957 | 958 | # [B,nh,K,f], K=mini_batch_size 959 | XQ_mini_batch = inputs["XQ"] 960 | XV_mini_batch = inputs["XV"] 961 | XK_mini_batch = inputs["XK"] 962 | # [B, nh, K, 1] 963 | eta_mini_batch = inputs["eta"] 964 | token_eta_mini_batch = inputs["token_eta"] 965 | ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"] 966 | 967 | X1 = XK_mini_batch 968 | # [B,nh,K,f] @ [B,nh,f,f] -> [B,nh,K,f] 969 | Z1 = X1 @ W1_init + b1_init 970 | reconstruction_target = XV_mini_batch - XK_mini_batch 971 | 972 | ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim) 973 | ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim) 974 | # [B,nh,K,f] 975 | grad_l_wrt_Z1 = ln_fused_l2_bwd(Z1, reconstruction_target, ln_weight, ln_bias) 976 | 977 | if use_dual_form: 978 | # [B,nh,K,K] 979 | Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1)) 980 | # [B,nh,1,f] - [B,nh,K,K] @ [B,nh,K,f] -> [B,nh,K,f] 981 | b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1 982 | # [B,nh,K,f] @ [B,nh,f,f] - ([B,nh,K,1] * [B,nh,K,K]) @ [B,nh,K,f] + [B,nh,K,f] 983 | Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar 984 | 985 | last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None] 986 | # [B,nh,f,f] - [B,nh,f,K] @ [B,nh,K,f] 987 | W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1 988 | # [B,nh,1,f] 989 | b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True) 990 | grad_W1_last = torch.zeros_like(W1_last) 991 | grad_b1_last = torch.zeros_like(b1_last) 992 | else: 993 | ttt_lr_eta_mini_batch = torch.broadcast_to( 994 | ttt_lr_eta_mini_batch, 995 | ( 996 | *ttt_lr_eta_mini_batch.shape[:2], 997 | mini_batch_size, 998 | mini_batch_size, 999 | ), 1000 | ) 1001 | 1002 | # [B, nh, K, f, f] 1003 | grad_W1 = torch.einsum("bhki,bhkj->bhkij", X1, grad_l_wrt_Z1) 1004 | grad_W1 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W1) 1005 | grad_W1 = grad_W1 + params_dict["W1_grad"].unsqueeze(2) 1006 | # [B, nh, K, f] 1007 | grad_b1 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z1) 1008 | grad_b1 = grad_b1 + params_dict["b1_grad"] 1009 | 1010 | W1_bar = W1_init.unsqueeze(2) - grad_W1 * token_eta_mini_batch.unsqueeze(-1) 1011 | b1_bar = b1_init - grad_b1 * token_eta_mini_batch 1012 | 1013 | # [B, nh, K, 1, f] @ [B, nh, K, f, f] 1014 | Z1_bar = (XQ_mini_batch.unsqueeze(3) @ W1_bar).squeeze(3) + b1_bar 1015 | 1016 | W1_last = W1_bar[:, :, -1] 1017 | b1_last = b1_bar[:, :, -1:] 1018 | grad_W1_last = grad_W1[:, :, -1] 1019 | grad_b1_last = grad_b1[:, :, -1:] 1020 | 1021 | Z1_bar = ln_fwd(Z1_bar, ln_weight, ln_bias) 1022 | 1023 | XQW_mini_batch = XQ_mini_batch + Z1_bar 1024 | 1025 | last_param_dict = { 1026 | "W1_states": W1_last, 1027 | "b1_states": b1_last, 1028 | "W1_grad": grad_W1_last, 1029 | "b1_grad": 
grad_b1_last, 1030 | } 1031 | return last_param_dict, XQW_mini_batch 1032 | 1033 | if last_mini_batch_params_dict is not None: 1034 | init_params_dict = last_mini_batch_params_dict 1035 | else: 1036 | init_params_dict = { 1037 | "W1_states": torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1)), 1038 | "b1_states": torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1)), 1039 | } 1040 | init_params_dict.update(W1_grad=torch.zeros_like(init_params_dict["W1_states"])) 1041 | init_params_dict.update(b1_grad=torch.zeros_like(init_params_dict["b1_states"])) 1042 | 1043 | # [B,num_heads, num_mini_batch, mini_batch_size, f] -> [num_mini_batch, B, num_heads, mini_batch_size, f] 1044 | inputs = tree_map(lambda x: x.permute(2, 0, 1, 3, 4), inputs) 1045 | 1046 | # allocate output tensor 1047 | XQW_batch = torch.empty( 1048 | (num_mini_batch, B, self.num_heads, mini_batch_size, self.head_dim), 1049 | device=device, 1050 | dtype=dtype, 1051 | ) 1052 | # XQW_batch: [num_mini_batch, B, num_heads, mini_batch_size, head_dim] 1053 | batch_params_dict, XQW_batch = scan( 1054 | compute_mini_batch, 1055 | init_params_dict, 1056 | inputs, 1057 | XQW_batch, 1058 | self.config.scan_checkpoint_group_size if self.training else 0, 1059 | ) 1060 | 1061 | # [B, num_heads, L, C] 1062 | if cache_params is not None: 1063 | cache_params.update(batch_params_dict, self.layer_idx, L) 1064 | 1065 | # [num_mini_batch, B, num_heads, mini_batch_size, head_dim] -> [B, num_mini_batch, mini_batch_size, num_heads, head_dim] 1066 | XQW_batch = XQW_batch.permute(1, 0, 3, 2, 4) 1067 | # [B, L, C] 1068 | XQW_batch = XQW_batch.reshape(B, L, self.width) 1069 | return XQW_batch, batch_params_dict 1070 | 1071 | 1072 | class TTTMLP(TTTBase): 1073 | def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None): 1074 | super().__init__(config, layer_idx) 1075 | # TTT model initialization for TTT-MLP 1076 | self.W1 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, self.head_dim, 4 * self.head_dim))) 1077 | self.b1 = nn.Parameter(torch.zeros(self.num_heads, 1, 4 * self.head_dim)) 1078 | self.W2 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, 4 * self.head_dim, self.head_dim))) 1079 | self.b2 = nn.Parameter(torch.zeros(self.num_heads, 1, self.head_dim)) 1080 | 1081 | def ttt( 1082 | self, 1083 | inputs, 1084 | mini_batch_size, 1085 | last_mini_batch_params_dict, 1086 | cache_params: Optional[TTTCache] = None, 1087 | ): 1088 | if mini_batch_size is None: 1089 | mini_batch_size = self.mini_batch_size 1090 | 1091 | # in this case, we are decoding 1092 | if last_mini_batch_params_dict is None and cache_params is not None: 1093 | last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx) 1094 | 1095 | # [B, num_heads, num_mini_batch, mini_batch_size, head_dim] 1096 | B = inputs["XV"].shape[0] 1097 | num_mini_batch = inputs["XV"].shape[2] 1098 | L = inputs["XV"].shape[2] * inputs["XV"].shape[3] 1099 | device = inputs["XV"].device 1100 | dtype = inputs["XV"].dtype 1101 | # NOTE: 1102 | # for prefilling, we will always use dual form for faster computation 1103 | # we need to use primal form if mini_batch_size is not a multiple of self.mini_batch_size 1104 | # since we need store the gradient for the next mini-batch computation 1105 | use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0 1106 | 1107 | def compute_mini_batch(params_dict, inputs): 1108 | # [B, nh, f, 4f] 1109 | W1_init = params_dict["W1_states"] 1110 | # [B, nh, 1, 4f] 1111 | b1_init = params_dict["b1_states"] 1112 | # 
[B, nh, 4f, f] 1113 | W2_init = params_dict["W2_states"] 1114 | # [B, nh, 1, f] 1115 | b2_init = params_dict["b2_states"] 1116 | 1117 | # [B,nh,K,f] 1118 | XQ_mini_batch = inputs["XQ"] 1119 | XV_mini_batch = inputs["XV"] 1120 | XK_mini_batch = inputs["XK"] 1121 | # [B,nh,K,1] 1122 | eta_mini_batch = inputs["eta"] 1123 | token_eta_mini_batch = inputs["token_eta"] 1124 | ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"] 1125 | 1126 | X1 = XK_mini_batch 1127 | # [B,nh,K,f] @ [B,nh,f,4f] -> [B,nh,K,4f] 1128 | Z1 = X1 @ W1_init + b1_init 1129 | X2 = F.gelu(Z1, approximate="tanh") 1130 | # [B,nh,K,4f] @ [B,nh,4f,f] -> [B,nh,K,f] 1131 | Z2 = X2 @ W2_init + b2_init 1132 | reconstruction_target = XV_mini_batch - XK_mini_batch 1133 | 1134 | ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim) 1135 | ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim) 1136 | # [B, nh, K, f] 1137 | grad_l_wrt_Z2 = ln_fused_l2_bwd(Z2, reconstruction_target, ln_weight, ln_bias) 1138 | # [B, nh, K, 4f] 1139 | grad_l_wrt_Z1 = grad_l_wrt_Z2 @ W2_init.transpose(-2, -1) * gelu_bwd(Z1) 1140 | 1141 | if use_dual_form: 1142 | Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1)) # [B,nh,K,K] 1143 | # [B,nh,1,f] - [B,nh,K,K] @ [B,nh,K,4f] -> [B,nh,K,4f] 1144 | b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1 1145 | # [B,nh,K,f] @ [B,nh,f,4f] - ([B,nh,K,1] * [B,nh,K,K]) @ [B,nh,K,4f] + [B,nh,K,4f] 1146 | Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar 1147 | X2_bar = F.gelu(Z1_bar, approximate="tanh") 1148 | 1149 | # [B,nh,K,K] 1150 | Attn2 = torch.tril(X2_bar @ X2.transpose(-2, -1)) 1151 | # [B,nh,1,f] - [B,nh,K,1] * [B,nh,K,f] -> [B,nh,K,f] 1152 | b2_bar = b2_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z2 1153 | # [B,nh,K,f] @ [1,nh,4f,f] - ([B,nh,K,1] * [B,nh,K,K]) @ [B,nh,K,f] + [B,nh,K,f] 1154 | Z2_bar = X2_bar @ W2_init - (eta_mini_batch * Attn2) @ grad_l_wrt_Z2 + b2_bar 1155 | 1156 | last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None] 1157 | # [B,nh,f,4f] - [B,nh,f,K] @ [B,nh,K,4f] 1158 | W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1 1159 | # [B,nh,1,4f] 1160 | b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True) 1161 | # [B,nh,4f,f] - [B,nh,4f,K] @ [B,nh,K,f] 1162 | W2_last = W2_init - (last_eta_mini_batch * X2).transpose(-1, -2) @ grad_l_wrt_Z2 1163 | # [B,nh,1,f] 1164 | b2_last = b2_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z2, dim=-2, keepdim=True) 1165 | grad_W1_last = torch.zeros_like(W1_last) 1166 | grad_b1_last = torch.zeros_like(b1_last) 1167 | grad_W2_last = torch.zeros_like(W2_last) 1168 | grad_b2_last = torch.zeros_like(b2_last) 1169 | 1170 | else: 1171 | ttt_lr_eta_mini_batch = torch.broadcast_to( 1172 | ttt_lr_eta_mini_batch, 1173 | ( 1174 | *ttt_lr_eta_mini_batch.shape[:2], 1175 | mini_batch_size, 1176 | mini_batch_size, 1177 | ), 1178 | ) 1179 | 1180 | # [B, nh, K, 4f, f] 1181 | grad_W2 = torch.einsum("bhki,bhkj->bhkij", X2, grad_l_wrt_Z2) 1182 | grad_W2 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W2) 1183 | grad_W2 = grad_W2 + params_dict["W2_grad"].unsqueeze(2) 1184 | # [B, nh, K, f] 1185 | grad_b2 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z2) 1186 | grad_b2 = grad_b2 + params_dict["b2_grad"] 1187 | 1188 | # [B, nh, K, f, 4f] 1189 | grad_W1 = torch.einsum("bhki,bhkj->bhkij", X1, grad_l_wrt_Z1) 1190 | grad_W1 = torch.einsum("bhnk,bhkij->bhnij", 
torch.tril(ttt_lr_eta_mini_batch), grad_W1) 1191 | grad_W1 = grad_W1 + params_dict["W1_grad"].unsqueeze(2) 1192 | # [B, nh, K, 4f] 1193 | grad_b1 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z1) 1194 | grad_b1 = grad_b1 + params_dict["b1_grad"] 1195 | 1196 | W1_bar = W1_init.unsqueeze(2) - grad_W1 * token_eta_mini_batch.unsqueeze(-1) 1197 | b1_bar = b1_init - grad_b1 * token_eta_mini_batch 1198 | W2_bar = W2_init.unsqueeze(2) - grad_W2 * token_eta_mini_batch.unsqueeze(-1) 1199 | b2_bar = b2_init - grad_b2 * token_eta_mini_batch 1200 | 1201 | # [B, nh, K, 1, f] @ [B, nh, K, f, 4f] -> [B, nh, K, 4f] 1202 | Z1_bar = (XQ_mini_batch.unsqueeze(3) @ W1_bar).squeeze(3) + b1_bar 1203 | X2_bar = F.gelu(Z1_bar, approximate="tanh") 1204 | Z2_bar = (X2_bar.unsqueeze(3) @ W2_bar).squeeze(3) + b2_bar 1205 | 1206 | W1_last = W1_bar[:, :, -1] 1207 | b1_last = b1_bar[:, :, -1:] 1208 | W2_last = W2_bar[:, :, -1] 1209 | b2_last = b2_bar[:, :, -1:] 1210 | grad_W1_last = grad_W1[:, :, -1] 1211 | grad_b1_last = grad_b1[:, :, -1:] 1212 | grad_W2_last = grad_W2[:, :, -1] 1213 | grad_b2_last = grad_b2[:, :, -1:] 1214 | 1215 | Z2_bar = ln_fwd(Z2_bar, ln_weight, ln_bias) 1216 | 1217 | XQW_mini_batch = XQ_mini_batch + Z2_bar 1218 | 1219 | last_param_dict = { 1220 | "W1_states": W1_last, 1221 | "b1_states": b1_last, 1222 | "W2_states": W2_last, 1223 | "b2_states": b2_last, 1224 | "W1_grad": grad_W1_last, 1225 | "b1_grad": grad_b1_last, 1226 | "W2_grad": grad_W2_last, 1227 | "b2_grad": grad_b2_last, 1228 | } 1229 | return last_param_dict, XQW_mini_batch 1230 | 1231 | if last_mini_batch_params_dict is not None: 1232 | init_params_dict = last_mini_batch_params_dict 1233 | else: 1234 | init_params_dict = { 1235 | "W1_states": torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1)), 1236 | "b1_states": torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1)), 1237 | "W2_states": torch.tile(self.W2.unsqueeze(0), dims=(B, 1, 1, 1)), 1238 | "b2_states": torch.tile(self.b2.unsqueeze(0), dims=(B, 1, 1, 1)), 1239 | } 1240 | init_params_dict.update(W1_grad=torch.zeros_like(init_params_dict["W1_states"])) 1241 | init_params_dict.update(b1_grad=torch.zeros_like(init_params_dict["b1_states"])) 1242 | init_params_dict.update(W2_grad=torch.zeros_like(init_params_dict["W2_states"])) 1243 | init_params_dict.update(b2_grad=torch.zeros_like(init_params_dict["b2_states"])) 1244 | inputs = tree_map(lambda x: x.permute(2, 0, 1, 3, 4), inputs) # [B,nh,NC,CS,f] -> [NC,B,nh,CS,f] 1245 | # allocate output tensor 1246 | XQW_batch = torch.empty( 1247 | (num_mini_batch, B, self.num_heads, mini_batch_size, self.head_dim), 1248 | device=device, 1249 | dtype=dtype, 1250 | ) 1251 | # XQW_batch: [num_mini_batch, B, num_heads, mini_batch_size, head_dim] 1252 | batch_params_dict, XQW_batch = scan( 1253 | compute_mini_batch, 1254 | init_params_dict, 1255 | inputs, 1256 | XQW_batch, 1257 | self.config.scan_checkpoint_group_size if self.training else 0, 1258 | ) 1259 | 1260 | # [B, num_heads, L, C] 1261 | if cache_params is not None: 1262 | cache_params.update(batch_params_dict, self.layer_idx, L) 1263 | 1264 | # [num_mini_batch, B, num_heads, mini_batch_size, head_dim] -> [B, num_mini_batch, mini_batch_size, num_heads, head_dim] 1265 | XQW_batch = XQW_batch.permute(1, 0, 3, 2, 4) 1266 | # [B, L, C] 1267 | XQW_batch = XQW_batch.reshape(B, L, self.width) 1268 | return XQW_batch, batch_params_dict 1269 | 1270 | 1271 | ################################ 1272 | ### E2E Architecture Modules ### 1273 | ################################ 1274 | 
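# Informal overview of how the modules defined below fit together (summary of the
# code that follows; kept schematic rather than exact):
#
#   TTTForCausalLM : TTTModel -> lm_head (nn.Linear, hidden_size -> vocab_size)
#   TTTModel       : embed_tokens -> [Block x num_hidden_layers] -> RMSNorm
#   Block          : (optional Conv residual when config.pre_conv)
#                    -> RMSNorm -> TTTLinear / TTTMLP -> residual
#                    -> RMSNorm -> SwiGluMLP          -> residual
#
# In other words, each Block is a pre-norm residual block in which the TTT layer
# plays the role of the sequence-modeling (attention) sublayer.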
1275 | 1276 | class Block(nn.Module): 1277 | def __init__(self, config: TTTConfig, layer_idx: int): 1278 | super().__init__() 1279 | self.hidden_size = config.hidden_size 1280 | self.pre_conv = config.pre_conv 1281 | 1282 | if config.ttt_layer_type == "linear": 1283 | ttt_layer = TTTLinear 1284 | elif config.ttt_layer_type == "mlp": 1285 | ttt_layer = TTTMLP 1286 | else: 1287 | raise ValueError(f"Invalid ttt_layer_type: {config.ttt_layer_type}") 1288 | 1289 | self.seq_modeling_block = ttt_layer(config=config, layer_idx=layer_idx) 1290 | 1291 | self.mlp = SwiGluMLP(config) 1292 | if self.pre_conv: 1293 | self.conv = Conv(config, layer_idx) 1294 | 1295 | self.seq_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 1296 | self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 1297 | self.layer_idx = layer_idx 1298 | 1299 | def forward( 1300 | self, 1301 | hidden_states: torch.Tensor, 1302 | attention_mask: Optional[torch.Tensor] = None, 1303 | position_ids: Optional[torch.LongTensor] = None, 1304 | cache_params: Optional[TTTCache] = None, 1305 | ): 1306 | if self.pre_conv: 1307 | residual = hidden_states 1308 | hidden_states = self.conv(hidden_states, cache_params=cache_params) 1309 | hidden_states = residual + hidden_states 1310 | 1311 | residual = hidden_states 1312 | 1313 | hidden_states = self.seq_norm(hidden_states) 1314 | 1315 | # TTT Layer 1316 | hidden_states = self.seq_modeling_block( 1317 | hidden_states=hidden_states, 1318 | attention_mask=attention_mask, 1319 | position_ids=position_ids, 1320 | cache_params=cache_params, 1321 | ) 1322 | hidden_states = residual + hidden_states 1323 | 1324 | # Feed-Forward-Network 1325 | residual = hidden_states 1326 | hidden_states = self.ffn_norm(hidden_states) 1327 | hidden_states = self.mlp(hidden_states) 1328 | hidden_states = residual + hidden_states 1329 | 1330 | return hidden_states 1331 | 1332 | 1333 | class TTTPreTrainedModel(PreTrainedModel): 1334 | config_class = TTTConfig 1335 | base_model_prefix = "model" 1336 | supports_gradient_checkpointing = True 1337 | _no_split_modules = ["Block"] 1338 | 1339 | def _init_weights(self, module): 1340 | std = self.config.initializer_range 1341 | if isinstance(module, nn.Linear): 1342 | module.weight.data.normal_(mean=0.0, std=std) 1343 | if module.bias is not None: 1344 | module.bias.data.zero_() 1345 | elif isinstance(module, nn.Embedding): 1346 | module.weight.data.normal_(mean=0.0, std=std) 1347 | if module.padding_idx is not None: 1348 | module.weight.data[module.padding_idx].zero_() 1349 | 1350 | 1351 | @dataclass 1352 | class TTTOutput(ModelOutput): 1353 | """ 1354 | Class for the TTT model outputs. 1355 | 1356 | Args: 1357 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 1358 | Sequence of hidden-states at the output of the last layer of the model. 1359 | cache_params (`TTTCache`): 1360 | The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to 1361 | avoid providing the old `input_ids`. 1362 | """ 1363 | 1364 | last_hidden_state: Optional[torch.FloatTensor] = None 1365 | cache_params: Optional[TTTCache] = None 1366 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 1367 | 1368 | 1369 | @dataclass 1370 | class TTTCausalLMOutput(ModelOutput): 1371 | """ 1372 | Base class for causal language model (or autoregressive) outputs. 
1373 | 1374 | Args: 1375 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 1376 | Language modeling loss (for next-token prediction). 1377 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 1378 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 1379 | cache_params (`TTTCache`): 1380 | The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to 1381 | avoid providing the old `input_ids`. 1382 | """ 1383 | 1384 | loss: Optional[torch.FloatTensor] = None 1385 | logits: Optional[torch.FloatTensor] = None 1386 | cache_params: Optional[TTTCache] = None 1387 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 1388 | 1389 | 1390 | class TTTModel(TTTPreTrainedModel): 1391 | """ 1392 | Decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Block`] 1393 | 1394 | Args: 1395 | config: TTTConfig 1396 | """ 1397 | 1398 | def __init__(self, config: TTTConfig): 1399 | super().__init__(config) 1400 | self.padding_idx = config.pad_token_id 1401 | self.vocab_size = config.vocab_size 1402 | 1403 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 1404 | self.layers = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) 1405 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 1406 | self.gradient_checkpointing = False 1407 | 1408 | # Initialize weights and apply final processing 1409 | self.post_init() 1410 | 1411 | def get_input_embeddings(self): 1412 | return self.embed_tokens 1413 | 1414 | def set_input_embeddings(self, value): 1415 | self.embed_tokens = value 1416 | 1417 | def forward( 1418 | self, 1419 | input_ids: torch.LongTensor = None, 1420 | attention_mask: Optional[torch.Tensor] = None, 1421 | position_ids: Optional[torch.LongTensor] = None, 1422 | inputs_embeds: Optional[torch.FloatTensor] = None, 1423 | cache_params: Optional[TTTCache] = None, 1424 | output_hidden_states: Optional[bool] = None, 1425 | return_dict: Optional[bool] = None, 1426 | use_cache: Optional[bool] = None, 1427 | ) -> Union[Tuple, BaseModelOutputWithPast]: 1428 | output_hidden_states = ( 1429 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1430 | ) 1431 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1432 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1433 | 1434 | if (input_ids is None) ^ (inputs_embeds is not None): 1435 | raise ValueError( 1436 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 1437 | ) 1438 | 1439 | if self.gradient_checkpointing and self.training and use_cache: 1440 | logger.warning_once( 1441 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
1442 | ) 1443 | use_cache = False 1444 | 1445 | if inputs_embeds is None: 1446 | inputs_embeds = self.embed_tokens(input_ids) 1447 | 1448 | if cache_params is None and use_cache: 1449 | cache_params = TTTCache(self, inputs_embeds.size(0)) 1450 | 1451 | seqlen_offset = 0 1452 | if cache_params is not None: 1453 | seqlen_offset = cache_params.seqlen_offset 1454 | position_ids = torch.arange( 1455 | seqlen_offset, 1456 | seqlen_offset + inputs_embeds.shape[1], 1457 | dtype=torch.long, 1458 | device=inputs_embeds.device, 1459 | ).unsqueeze(0) 1460 | 1461 | hidden_states = inputs_embeds 1462 | 1463 | if attention_mask is None: 1464 | attention_mask = torch.ones_like(input_ids) 1465 | 1466 | # decoder layers 1467 | all_hidden_states = () if output_hidden_states else None 1468 | 1469 | for decoder_layer in self.layers: 1470 | if self.gradient_checkpointing and self.training: 1471 | hidden_states = self._gradient_checkpointing_func( 1472 | decoder_layer.__call__, 1473 | hidden_states, 1474 | attention_mask, 1475 | position_ids, 1476 | cache_params, 1477 | ) 1478 | else: 1479 | hidden_states = decoder_layer( 1480 | hidden_states, 1481 | attention_mask=attention_mask, 1482 | position_ids=position_ids, 1483 | cache_params=cache_params, 1484 | ) 1485 | 1486 | if output_hidden_states: 1487 | all_hidden_states = all_hidden_states + (hidden_states,) 1488 | 1489 | if use_cache: 1490 | cache_params.seqlen_offset += inputs_embeds.shape[1] 1491 | 1492 | hidden_states = self.norm(hidden_states) 1493 | 1494 | # add hidden states from the last decoder layer 1495 | if output_hidden_states: 1496 | all_hidden_states += (hidden_states,) 1497 | 1498 | if not return_dict: 1499 | return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) 1500 | 1501 | return TTTOutput( 1502 | last_hidden_state=hidden_states, 1503 | cache_params=cache_params if use_cache else None, 1504 | hidden_states=all_hidden_states, 1505 | ) 1506 | 1507 | 1508 | class TTTForCausalLM(TTTPreTrainedModel): 1509 | _tied_weights_keys = ["lm_head.weight"] 1510 | 1511 | def __init__(self, config): 1512 | super().__init__(config) 1513 | self.model = TTTModel(config) 1514 | self.vocab_size = config.vocab_size 1515 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1516 | 1517 | # Initialize weights and apply final processing 1518 | self.post_init() 1519 | 1520 | def get_input_embeddings(self): 1521 | return self.model.embed_tokens 1522 | 1523 | def set_input_embeddings(self, value): 1524 | self.model.embed_tokens = value 1525 | 1526 | def get_output_embeddings(self): 1527 | return self.lm_head 1528 | 1529 | def set_output_embeddings(self, new_embeddings): 1530 | self.lm_head = new_embeddings 1531 | 1532 | def set_decoder(self, decoder): 1533 | self.model = decoder 1534 | 1535 | def get_decoder(self): 1536 | return self.model 1537 | 1538 | def _update_model_kwargs_for_generation( 1539 | self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs 1540 | ) -> Dict[str, Any]: 1541 | model_kwargs["cache_params"] = outputs.get("cache_params", None) 1542 | # update attention mask 1543 | if "attention_mask" in model_kwargs: 1544 | attention_mask = model_kwargs["attention_mask"] 1545 | model_kwargs["attention_mask"] = torch.cat( 1546 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], 1547 | dim=-1, 1548 | ) 1549 | return model_kwargs 1550 | 1551 | def prepare_inputs_for_generation( 1552 | self, 1553 | input_ids, 1554 | attention_mask=None, 1555 | cache_params: 
Optional[TTTCache] = None, 1556 | inputs_embeds=None, 1557 | **kwargs, 1558 | ): 1559 | # only last token for inputs_ids if the state is passed along. 1560 | if cache_params is not None: 1561 | input_ids = input_ids[:, -1].unsqueeze(-1) 1562 | attention_mask = attention_mask[:, -1].unsqueeze(-1) if attention_mask is not None else None 1563 | 1564 | if inputs_embeds is not None and cache_params is None: 1565 | model_inputs = {"inputs_embeds": inputs_embeds} 1566 | else: 1567 | model_inputs = {"input_ids": input_ids} 1568 | 1569 | model_inputs.update( 1570 | { 1571 | "cache_params": cache_params, 1572 | "use_cache": kwargs.get("use_cache"), 1573 | "attention_mask": attention_mask, 1574 | } 1575 | ) 1576 | 1577 | return model_inputs 1578 | 1579 | def forward( 1580 | self, 1581 | input_ids: torch.LongTensor = None, 1582 | attention_mask: Optional[torch.Tensor] = None, 1583 | position_ids: Optional[torch.LongTensor] = None, 1584 | inputs_embeds: Optional[torch.FloatTensor] = None, 1585 | cache_params: Optional[TTTCache] = None, 1586 | labels: Optional[torch.LongTensor] = None, 1587 | output_hidden_states: Optional[bool] = None, 1588 | return_dict: Optional[bool] = None, 1589 | use_cache: Optional[bool] = None, 1590 | *, 1591 | output_attentions: Optional[bool] = None, 1592 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1593 | """ 1594 | Args: 1595 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1596 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1597 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1598 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1599 | """ 1600 | output_hidden_states = ( 1601 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1602 | ) 1603 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1604 | assert not output_attentions, "output_attentions is not available in TTTForCausalLM" 1605 | 1606 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1607 | outputs = self.model( 1608 | input_ids=input_ids, 1609 | attention_mask=attention_mask, 1610 | position_ids=position_ids, 1611 | cache_params=cache_params, 1612 | inputs_embeds=inputs_embeds, 1613 | output_hidden_states=output_hidden_states, 1614 | return_dict=return_dict, 1615 | use_cache=use_cache, 1616 | ) 1617 | 1618 | hidden_states = outputs[0] 1619 | if self.config.pretraining_tp > 1: 1620 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 1621 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 1622 | logits = torch.cat(logits, dim=-1) 1623 | else: 1624 | logits = self.lm_head(hidden_states) 1625 | logits = logits.float() 1626 | 1627 | loss = None 1628 | if labels is not None: 1629 | # Shift so that tokens < n predict n 1630 | shift_logits = logits[..., :-1, :].contiguous() 1631 | shift_labels = labels[..., 1:].contiguous() 1632 | # Flatten the tokens 1633 | loss_fct = CrossEntropyLoss() 1634 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1635 | shift_labels = shift_labels.view(-1) 1636 | # Enable model parallelism 1637 | shift_labels = shift_labels.to(shift_logits.device) 1638 | loss = loss_fct(shift_logits, shift_labels) 1639 | 1640 | if not return_dict: 1641 | output = (logits,) + outputs[1:] 1642 | return 
(loss,) + output if loss is not None else output 1643 | 1644 | return TTTCausalLMOutput( 1645 | loss=loss, 1646 | logits=logits, 1647 | cache_params=outputs.cache_params, 1648 | hidden_states=outputs.hidden_states, 1649 | ) 1650 | --------------------------------------------------------------------------------
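A minimal usage sketch of the classes defined in ttt.py above. It assumes ttt.py sits at the repository root and is importable as `ttt`, and that `TTTConfig` (defined earlier in the file) can be constructed with its default hyperparameters and a `ttt_layer_type` argument; shapes in the comments are illustrative, and the exact `generate` behavior depends on the installed `transformers` version. This is a sketch, not the repository's official quick start.

import torch
from ttt import TTTConfig, TTTForCausalLM  # assumes ttt.py is on the Python path

# Randomly initialized TTT-Linear model built from default hyperparameters
# (pass ttt_layer_type="mlp" to use the TTT-MLP layer instead).
config = TTTConfig(ttt_layer_type="linear")
model = TTTForCausalLM(config).eval()

# Toy batch of token ids; reusing the inputs as labels yields the shifted
# next-token language-modeling loss computed in TTTForCausalLM.forward.
input_ids = torch.randint(0, config.vocab_size, (1, 32))
with torch.no_grad():
    out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.logits.shape)  # scalar loss, logits of shape [1, 32, vocab_size]

# Greedy decoding: prepare_inputs_for_generation feeds one token at a time,
# carrying the per-layer TTT states in cache_params.
generated = model.generate(input_ids, max_new_tokens=8, use_cache=True)
print(generated.shape)  # [1, 40] unless an EOS token is produced early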