├── .gitignore ├── LICENSE ├── MANIFEST.in ├── examples ├── __init__.py └── modinfiniformer.py ├── infini_transformer ├── __init__.py ├── activations.py ├── compressive_memory.py ├── positional_embeddings.py └── transformer.py ├── pyproject.toml ├── readme.md ├── requirements.txt └── tests ├── __init__.py └── test_transformer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Python files 2 | *.pyc 3 | __pycache__/ 4 | 5 | # Virtual environment 6 | venv/ 7 | .venv/ 8 | 9 | # IDE and editor files 10 | .vscode/ 11 | .idea/ 12 | *.sublime-project 13 | *.sublime-workspace 14 | 15 | # Jupyter Notebook checkpoints 16 | .ipynb_checkpoints/ 17 | 18 | # Python package distribution files 19 | dist/ 20 | build/ 21 | *.egg-info/ 22 | 23 | # Local development files 24 | .env 25 | .env.* 26 | .python-version 27 | 28 | # macOS files 29 | .DS_Store 30 | 31 | # Logs and temporary files 32 | *.log 33 | *.swp 34 | *.bak -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ryan Taylor 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include pyproject.toml 4 | include requirements.txt 5 | 6 | recursive-include infini_transformer *.py 7 | recursive-include examples *.py 8 | recursive-include tests *.py -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # examples/__init__.py 2 | 3 | # This file is intentionally left empty. 4 | # It is required to make the 'examples' directory a Python package. 
--------------------------------------------------------------------------------
/examples/modinfiniformer.py:
--------------------------------------------------------------------------------
 1 | from typing import List, Optional, Tuple
 2 | 
 3 | import torch
 4 | from torch import nn
 5 | from torch.utils.data import DataLoader
 6 | 
 7 | from infini_transformer import MoDInfiniTransformer, InfiniTransformer, YaRNEmbeddings
 8 | 
 9 | 
10 | class NextTokenModel(nn.Module):
11 |     """
12 |     An Infini-Transformer-based model for next-token prediction.
13 |     """
14 |     def __init__(
15 |         self,
16 |         vocab_size: int,
17 |         embedding_dim: int,
18 |         num_layers: int,
19 |         dim_hidden: int,
20 |         dim_key: int,
21 |         dim_value: int,
22 |         num_heads: int,
23 |         activation: str, segment_len: int,
24 |         sampling_factor: int,
25 |         update="linear",
26 |         causal: bool = False,
27 |         init_state_learnable: bool = False,
28 |         dropout: float = 0.0,
29 |     ):
30 |         """
31 |         Initialize the module.
32 |         Parameters:
33 |             vocab_size (int): Size of the vocabulary.
34 |             embedding_dim (int): Dimensionality of the embedding space.
35 |             num_layers (int): Number of Infini-Transformer layers.
36 |             dim_hidden (int): Hidden dimension for the MLP.
37 |             dim_key (int): Key dimension for the CompressiveMemory.
38 |             dim_value (int): Value dimension for the CompressiveMemory.
39 |             num_heads (int): Number of attention heads for the CompressiveMemory. activation (str): Activation function for each layer's MLP; must be a key in the ACTIVATIONS dictionary (e.g. "ffngeglu").
40 |             segment_len (int): Segment length for the CompressiveMemory.
41 |             sampling_factor (int): Reciprocal of the sampling rate for the Mixture-of-Depths mechanism.
42 |             update (str, optional): Type of memory update rule to use for the CompressiveMemory ("linear" or "delta"). Defaults to "linear".
43 |             causal (bool, optional): Whether to use causal attention masking for the CompressiveMemory. Defaults to False.
44 |             init_state_learnable (bool, optional): Whether the initial state of the CompressiveMemory should be learnable. Defaults to False.
45 |             dropout (float, optional): Dropout rate for the MLP. Defaults to 0.0.
46 | """ 47 | super(NextTokenModel, self).__init__() 48 | 49 | self.embedding = nn.Embedding(vocab_size, embedding_dim) 50 | 51 | transformers = [] 52 | for layer_n in range(num_layers): 53 | if layer_n % 2 == 0: 54 | transformers.append( 55 | MoDInfiniTransformer( 56 | dim_input=embedding_dim, 57 | dim_hidden=dim_hidden, 58 | dim_key=dim_key, 59 | dim_value=dim_value, 60 | num_heads=num_heads, 61 | segment_len=segment_len, 62 | sampling_factor=sampling_factor, 63 | update=update, 64 | causal=causal, 65 | position_embedder=YaRNEmbeddings( 66 | dim=dim_key, 67 | seq_len=segment_len, 68 | context_len=32768, 69 | context_len_ext=65536, 70 | dim_embedding_pct=0.5, 71 | base=10000, 72 | alpha=1, 73 | beta=32, 74 | length_scale=None 75 | ), 76 | init_state_learnable=init_state_learnable, 77 | dropout=dropout 78 | ) 79 | ) 80 | else: 81 | transformers.append( 82 | InfiniTransformer( 83 | dim_input=embedding_dim, 84 | dim_hidden=dim_hidden, 85 | dim_key=dim_key, 86 | dim_value=dim_value, 87 | num_heads=num_heads, 88 | segment_len=segment_len, 89 | update=update, 90 | causal=causal, 91 | position_embedder=YaRNEmbeddings( 92 | dim=dim_key, 93 | seq_len=segment_len, 94 | context_len=32768, 95 | context_len_ext=65536, 96 | dim_embedding_pct=0.5, 97 | base=10000, 98 | alpha=1, 99 | beta=32, 100 | length_scale=None 101 | ), 102 | init_state_learnable=init_state_learnable, 103 | dropout=dropout 104 | ) 105 | ) 106 | self.transformers = nn.ModuleList(transformers) 107 | 108 | self.proj_final = nn.Linear(embedding_dim, vocab_size) 109 | self.softmax = nn.Softmax(dim=-1) 110 | 111 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor], List[Optional[torch.Tensor]]]: 112 | """ 113 | Forward pass of the model for computing the next token probabilities and other outputs if applicable. 114 | 115 | Parameters: 116 | x (torch.Tensor): Input tensor containing token indices. 117 | 118 | Returns: 119 | Tuple containing the probability distribution over the next tokens, actual modification tokens (if any), and predicted modification tokens (if any). 
120 | """ 121 | mod_token_actuals = [] 122 | mod_token_preds = [] 123 | 124 | x = self.embedding(x) 125 | for ix, transformer in enumerate(self.transformers): 126 | if ix % 2 == 0: 127 | x, mod_token_actual, mod_token_pred = transformer(x) 128 | mod_token_actuals.append(mod_token_actual) 129 | mod_token_preds.append(mod_token_pred) 130 | else: 131 | x = transformer(x) 132 | 133 | x = self.proj_final(x) 134 | 135 | next_token_probs = self.softmax(x) 136 | 137 | return next_token_probs, mod_token_actuals, mod_token_preds 138 | 139 | 140 | def train_model( 141 | model: NextTokenModel, 142 | dataloader_train: DataLoader, 143 | dataloader_val: DataLoader, 144 | epochs: int, 145 | device: str 146 | ): 147 | # Switch model to training mode and move to the specified device 148 | model = model.train() 149 | model = model.to(device=device) 150 | 151 | # Optimizer 152 | optimizer = torch.optim.AdamW( 153 | model.parameters(), 154 | lr=1e-4, 155 | betas=(0.9, 0.95), 156 | weight_decay=0.01 157 | ) 158 | 159 | # Learning rate scheduler 160 | lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) 161 | 162 | # Main loss function 163 | loss_fn_main = nn.CrossEntropyLoss().to(device) 164 | # Auxiliary loss function 165 | loss_fn_sampling = nn.functional.binary_cross_entropy_with_logits 166 | 167 | for epoch in range(epochs): 168 | for ix, batch in enumerate(dataloader_train): 169 | # Move data batch to specified device 170 | batch = batch.to(device) 171 | # Target labels for training 172 | target = batch[:, 1:].clone() 173 | # Generate predictions and auxiliary outputs from model 174 | preds, mod_actuals, mod_preds = model(batch) 175 | 176 | # Calculate the main loss (Cross-Entropy) 177 | loss_main = loss_fn_main(input=preds[:, :-1, :], target=target) 178 | 179 | # Calculate auxiliary loss 180 | loss_aux = torch.tensor(0.0, device=device) 181 | for mod_actual, mod_pred in zip(mod_actuals, mod_preds): 182 | loss_aux += loss_fn_sampling(input=mod_pred, target=mod_actual) 183 | 184 | # Total loss is the sum of main and auxiliary losses 185 | loss = loss_main + loss_aux 186 | 187 | # Clear gradients, perform backpropagation, and update the model parameters 188 | optimizer.zero_grad() 189 | loss.backward() 190 | optimizer.step() 191 | 192 | print( 193 | f'Epoch: {epoch + 1}/epochs ({ix + 1}/{len(dataloader_train)}) | Training Loss: {loss_main.detach().cpu().item():.6f}\r', 194 | end="" 195 | ) 196 | 197 | # Update the learning rate schedule after each epoch 198 | lr_schedule.step() 199 | 200 | # Validation phase 201 | with torch.no_grad(): 202 | total_loss = 0.0 203 | num_obs = 0 204 | 205 | for batch in dataloader_val: 206 | batch = batch.to(device) 207 | target = batch[:, 1:].clone() 208 | 209 | preds, _, _ = model(batch) 210 | 211 | total_loss += loss_fn_main(input=preds[:, :-1, :], target=target).detach().cpu().item() * batch.size(0) 212 | num_obs += batch.size(0) 213 | 214 | # Calculate the average validation loss 215 | val_loss = total_loss / num_obs 216 | 217 | print( 218 | f'\nEpoch: {epoch + 1}/{epochs} | Validation Loss -- (CE): {val_loss:.6f}' 219 | ) 220 | 221 | return model 222 | -------------------------------------------------------------------------------- /infini_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # infini_transformer/__init__.py 2 | 3 | from .compressive_memory import CompressiveMemory 4 | from .positional_embeddings import RoPEEmbeddings, YaRNEmbeddings 5 | from .transformer import 
InfiniTransformer, MoDInfiniTransformer 6 | 7 | 8 | __all__ = [ 9 | "InfiniTransformer", 10 | "MoDInfiniTransformer", 11 | "CompressiveMemory", 12 | "RoPEEmbeddings", 13 | "YaRNEmbeddings" 14 | ] -------------------------------------------------------------------------------- /infini_transformer/activations.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | class Swish(torch.nn.Module): 7 | """Swish activation module""" 8 | def __init__(self, beta: Optional[float] = None): 9 | """Initialize the module. 10 | 11 | Args: 12 | beta (Optional[float], optional): Shape parameter. If None, it's a learnable parameter. Defaults to None. 13 | """ 14 | super(Swish, self).__init__() 15 | # If beta is None, make it a learnable parameter 16 | if beta is None: 17 | self.beta = torch.nn.Parameter(torch.ones(1)) 18 | # Otherwise, set it to a fixed constant 19 | else: 20 | self.beta = beta 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | """Forward pass. 24 | 25 | Args: 26 | x (torch.Tensor): Input tensor. 27 | 28 | Returns: 29 | torch.Tensor: Activation tensor. 30 | """ 31 | return x * torch.nn.functional.sigmoid(self.beta * x) 32 | 33 | class SwiGLU(torch.nn.Module): 34 | """SwiGLU activation module.""" 35 | def __init__(self, dim: int): 36 | """Initialize the module. 37 | 38 | Args: 39 | dim (int): Dimension of the input tensor. 40 | """ 41 | super(SwiGLU, self).__init__() 42 | self.swish = Swish() 43 | self.W = torch.nn.Parameter(torch.randn(dim, dim) / dim ** 0.5) 44 | self.V = torch.nn.Parameter(torch.randn(dim, dim) / dim ** 0.5) 45 | self.b = torch.nn.Parameter(torch.zeros(dim)) 46 | self.c = torch.nn.Parameter(torch.zeros(dim)) 47 | 48 | def forward(self, x: torch.Tensor) -> torch.Tensor: 49 | """Forward pass. 50 | 51 | Args: 52 | x (torch.Tensor): Input tensor. 53 | 54 | Returns: 55 | torch.Tensor: Activation tensor. 56 | """ 57 | return self.swish(x @ self.W + self.b) * (x @ self.V + self.c) 58 | 59 | 60 | class GEGLU(torch.nn.Module): 61 | """GEGLU activation module.""" 62 | def __init__(self, dim: int): 63 | """Initialize the module. 64 | 65 | Args: 66 | dim (int): Dimension of the input tensor. 67 | """ 68 | super(GEGLU, self).__init__() 69 | self.W = torch.nn.Parameter(torch.randn(dim, dim) / dim ** 0.5) 70 | self.V = torch.nn.Parameter(torch.randn(dim, dim) / dim ** 0.5) 71 | self.b = torch.nn.Parameter(torch.zeros(dim)) 72 | self.c = torch.nn.Parameter(torch.zeros(dim)) 73 | 74 | def forward(self, x: torch.Tensor) -> torch.Tensor: 75 | """Forward pass. 76 | 77 | Args: 78 | x (torch.Tensor): Input tensor. 79 | 80 | Returns: 81 | torch.Tensor: Activation tensor. 82 | """ 83 | return torch.nn.functional.gelu(x @ self.W + self.b) * (x @ self.V + self.c) 84 | 85 | class FFNGLU(torch.nn.Module): 86 | """FFN GLU activation module.""" 87 | def __init__(self, dim: int): 88 | """Initialize the module. 89 | 90 | Args: 91 | dim (int): Dimension of the input tensor. 92 | """ 93 | super(FFNGLU, self).__init__() 94 | inner_dim = math.ceil(dim * 2 / 3) 95 | self.W1 = torch.nn.Parameter(torch.randn(dim, inner_dim) / inner_dim ** 0.5) 96 | self.W2 = torch.nn.Parameter(torch.randn(inner_dim, dim) / dim ** 0.5) 97 | self.V = torch.nn.Parameter(torch.randn(dim, inner_dim) / inner_dim ** 0.5) 98 | 99 | def forward(self, x: torch.Tensor) -> torch.Tensor: 100 | """Forward pass. 101 | 102 | Args: 103 | x (torch.Tensor): Input tensor. 
104 | 105 | Returns: 106 | torch.Tensor: Activation tensor. 107 | """ 108 | return (torch.nn.functional.sigmoid(x @ self.W1) * (x @ self.V)) @ self.W2 109 | 110 | class FFNGEGLU(torch.nn.Module): 111 | """FFN GELU activation module.""" 112 | def __init__(self, dim: int): 113 | """Initialize the module. 114 | 115 | Args: 116 | dim (int): Dimension of the input tensor. 117 | """ 118 | super(FFNGEGLU, self).__init__() 119 | inner_dim = math.ceil(dim * 2 / 3) 120 | self.W1 = torch.nn.Parameter(torch.randn(dim, inner_dim) / inner_dim ** 0.5) 121 | self.W2 = torch.nn.Parameter(torch.randn(inner_dim, dim) / dim ** 0.5) 122 | self.V = torch.nn.Parameter(torch.randn(dim, inner_dim) / inner_dim ** 0.5) 123 | 124 | def forward(self, x: torch.Tensor) -> torch.Tensor: 125 | """Forward pass. 126 | 127 | Args: 128 | x (torch.Tensor): Input tensor. 129 | 130 | Returns: 131 | torch.Tensor: Activation tensor. 132 | """ 133 | return (torch.nn.functional.gelu(x @ self.W1) * (x @ self.V)) @ self.W2 134 | 135 | class FFNSwiGLU(torch.nn.Module): 136 | """FFN SwiGLU activation module.""" 137 | def __init__(self, dim: int): 138 | """Initialize the module. 139 | 140 | Args: 141 | dim (int): Dimension of the input tensor. 142 | """ 143 | super(FFNSwiGLU, self).__init__() 144 | inner_dim = math.ceil(dim * 2 / 3) 145 | self.swish = Swish(beta=1) 146 | self.W1 = torch.nn.Parameter(torch.randn(dim, inner_dim) / inner_dim ** 0.5) 147 | self.W2 = torch.nn.Parameter(torch.randn(inner_dim, dim) / dim ** 0.5) 148 | self.V = torch.nn.Parameter(torch.randn(dim, inner_dim) / inner_dim ** 0.5) 149 | 150 | def forward(self, x: torch.Tensor) -> torch.Tensor: 151 | """Forward pass. 152 | 153 | Args: 154 | x (torch.Tensor): Input tensor. 155 | 156 | Returns: 157 | torch.Tensor: Activation tensor. 158 | """ 159 | return (self.swish(x @ self.W1) * (x @ self.V)) @ self.W2 160 | 161 | class Abs(torch.nn.Module): 162 | """Absolute value activation module.""" 163 | def __init__(self): 164 | """Initialize the module.""" 165 | super(Abs, self).__init__() 166 | 167 | def forward(self, x: torch.Tensor) -> torch.Tensor: 168 | """Forward pass. 169 | 170 | Args: 171 | x (torch.Tensor): Input tensor. 172 | 173 | Returns: 174 | torch.Tensor: Activation tensor. 175 | """ 176 | return torch.abs(x) 177 | 178 | # Importable container for available activations 179 | ACTIVATIONS = { 180 | "relu": torch.nn.ReLU, 181 | "gelu": torch.nn.GELU, 182 | "swish": Swish, 183 | "swiglu": SwiGLU, 184 | "geglu": GEGLU, 185 | "ffnglu": FFNGLU, 186 | "ffngeglu": FFNGEGLU, 187 | "ffnswiglu": FFNSwiGLU, 188 | "abs": Abs 189 | } 190 | -------------------------------------------------------------------------------- /infini_transformer/compressive_memory.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .positional_embeddings import PositionEmbeddings 7 | 8 | class CompressiveMemory(nn.Module): 9 | """Implements the Compressive Transformer memory module as described in "Leave No Context Behind: 10 | Efficient Infinite Context Transformers with Infini-attention" by Munkhdalai et al. 
11 | (https://arxiv.org/abs/2404.07143)""" 12 | 13 | def __init__( 14 | self, 15 | dim_input: int, 16 | dim_key: int, 17 | dim_value: int, 18 | num_heads: int, 19 | segment_len: int, 20 | sampling_factor: Optional[int] = None, 21 | update: str = "linear", 22 | causal: bool = False, 23 | position_embedder: Optional[PositionEmbeddings] = None, 24 | init_state_learnable: bool = False 25 | ): 26 | """Initialize module. 27 | 28 | Args: 29 | dim_input (int): Input dimension. 30 | dim_key (int): Key dimension. 31 | dim_value (int): Value dimension. 32 | num_heads (int): Number of attention heads. 33 | segment_len (int): Segment length (must be a factor of the input sequence length). 34 | sampling_factor (Optional[int], optional): Reciprocal of the sampling rate for the Mixture-of-Depths mechanism. Defaults to None. 35 | update (str, optional): Type of memory update rule to use ("linear" or "delta"). Defaults to "linear". 36 | causal (bool, optional): Whether to use causal attention masking. Defaults to False. 37 | position_embedder (Optional[PositionEmbeddings], optional): Position embedding module. Defaults to None. 38 | init_state_learnable (bool, optional): Whether the initial memory and normalization are learnable. Defaults to False. 39 | """ 40 | super(CompressiveMemory, self).__init__() 41 | 42 | # Record input parameters 43 | self.num_heads = num_heads 44 | self.segment_len = segment_len 45 | self.sampling_factor = sampling_factor 46 | 47 | self.dim_input = dim_input 48 | self.dim_key = dim_key 49 | self.dim_value = dim_value 50 | 51 | self.update = update 52 | self.causal = causal 53 | 54 | self.position_embedder = position_embedder 55 | 56 | # Projections for stacked SDP attention 57 | self.proj_k = nn.Linear(dim_input, num_heads * dim_key, bias=False) 58 | self.proj_v = nn.Linear(dim_input, num_heads * dim_value, bias=False) 59 | self.proj_q = nn.Linear(dim_input, num_heads * dim_key, bias=False) 60 | 61 | # Initialize betas for weighted average of dot-product and memory-based attention 62 | self.betas = nn.Parameter(torch.randn(1, num_heads, 1, dim_value)) 63 | 64 | # Projection for output 65 | self.proj_out = nn.Linear(num_heads * dim_value, dim_input, bias=False) 66 | 67 | # If init_state_learnable is set, create parameters for the initial memory matrix 68 | # and normalization vector; otherwise, set them to None 69 | if init_state_learnable: 70 | self.init_mem = nn.Parameter(torch.randn(1, self.num_heads, self.dim_key, self.dim_value)) 71 | self.init_z = nn.Parameter(torch.ones(1, self.num_heads, self.dim_key, 1)) 72 | else: 73 | self.init_mem = None 74 | self.init_z = None 75 | 76 | def forward(self, x: torch.Tensor, sample_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 77 | """ 78 | Applies Compressive Memory Attention to the input tensor. 79 | 80 | Args: 81 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim_input). 82 | sample_mask (Optional[torch.Tensor], optional): Mask tensor of shape (batch_size, seq_len) used to sample the input sequence. Defaults to None. 83 | Returns: 84 | torch.Tensor: Output tensor of shape (batch_size, seq_len, dim_input). 85 | """ 86 | batch_size, seq_len, _ = x.shape 87 | 88 | num_segments, rem = divmod(seq_len, self.segment_len) 89 | num_segments += 1 if rem > 0 else 0 90 | 91 | out = [] 92 | 93 | # Initialize mem and normalization 94 | if self.init_mem is not None and self.init_z is not None: 95 | mem = self.init_mem 96 | z = self.init_z 97 | else: 98 | # !!! 
Initialization was never specified in the paper, so this is an educated guess 99 | mem = torch.zeros(1, self.num_heads, self.dim_key, self.dim_value) 100 | z = torch.ones(batch_size, self.num_heads, self.dim_key, 1) / self.dim_key 101 | 102 | # Project the input tensor to get the key, value, and query tensors 103 | k_full = self.proj_k(x).unsqueeze(1).view( 104 | (batch_size, self.num_heads, x.size(1), self.dim_key)) 105 | v_full = self.proj_v(x).unsqueeze(1).view( 106 | (batch_size, self.num_heads, x.size(1), self.dim_value)) 107 | q_full = self.proj_q(x).unsqueeze(1).view( 108 | (batch_size, self.num_heads, x.size(1), self.dim_key)) 109 | 110 | for ix in range(num_segments): 111 | ix_lo = ix * self.segment_len 112 | ix_hi = min(ix_lo + self.segment_len, x.size(1)) 113 | seg_len = ix_hi - ix_lo 114 | 115 | # Extract segment from key, value and query tensors 116 | k = k_full[:, :, ix_lo:ix_hi, :] 117 | v = v_full[:, :, ix_lo:ix_hi, :] 118 | q = q_full[:, :, ix_lo:ix_hi, :] 119 | 120 | # If sample_mask was given, extract segment from it 121 | if sample_mask is not None: 122 | if self.sampling_factor is None: 123 | raise ValueError("sampling_factor must be specified if sample_mask is provided") 124 | ix_lo_seg = ix * self.segment_len * self.sampling_factor 125 | ix_hi_seg = min(ix_lo_seg + self.segment_len * self.sampling_factor, sample_mask.size(1)) 126 | sample_mask_seg = sample_mask[:, ix_lo_seg:ix_hi_seg] 127 | else: 128 | sample_mask_seg = None 129 | 130 | # If position embedder is specified, add positional embeddings to q and k 131 | if self.position_embedder is not None: 132 | if sample_mask is None: 133 | k_pos = self.position_embedder(k, total_seq_len=seq_len, offset=ix_lo) 134 | q_pos = self.position_embedder(q, total_seq_len=seq_len, offset=ix_lo) 135 | else: 136 | k_pos = self.position_embedder(k, total_seq_len=seq_len, offset=ix_lo_seg, select_mask=sample_mask_seg) 137 | q_pos = self.position_embedder(q, total_seq_len=seq_len, offset=ix_lo_seg, select_mask=sample_mask_seg) 138 | 139 | # Pre-calculate sigma(q) for updating memory and calculating attention 140 | # The calculation is described on page 4 of the paper under the subsection 141 | # "Memory retrieval" 142 | # shape: (batch_size, num_heads, segment_len, dim_key) 143 | sigma_q = (nn.functional.elu(q) + 1.0) 144 | 145 | # Apply SDP attention, as part of equation (2) of the paper 146 | if self.position_embedder is not None: 147 | scores = q_pos @ k_pos.transpose(-2, -1) / self.dim_key ** 0.5 148 | else: 149 | scores = q @ k.transpose(-2, -1) / self.dim_key ** 0.5 150 | 151 | # If causal mask specified, calculate and apply it 152 | if self.causal: 153 | mask = torch.tril(torch.ones((seg_len, seg_len), dtype=torch.bool), diagonal=0) 154 | mask = mask.unsqueeze(0).unsqueeze(0).repeat((batch_size, self.num_heads, 1, 1)) 155 | scores.masked_fill_(torch.logical_not(mask), float('-inf')) 156 | 157 | # Calculate SDP attention, completing equation (2) of the paper 158 | att_dot = nn.functional.softmax(scores, dim=-1) @ v 159 | 160 | # Calculate normalized linear attention 161 | # The calculation is described in equation (3) of the paper 162 | # shape: (batch_size, num_heads, segment_len, dim_value) 163 | att_mem = (sigma_q @ mem) / (sigma_q @ z) 164 | 165 | # Apply mem update 166 | # The update rules are described in equations (4) and (5) of the paper 167 | sigma_k = nn.functional.elu(k) + 1.0 168 | if self.update == "linear": 169 | mem = mem + sigma_k.transpose(-2, -1) @ v 170 | elif self.update == "delta": 171 | mem = mem + \ 172 
| sigma_k.transpose(-2, -1) @ (v - (sigma_k @ mem) / (sigma_k @ z)) 173 | 174 | # Apply normalization term update 175 | # The calculation is described in equation (4) of the paper 176 | z = z + (nn.functional.elu(k) + 1.0).sum(dim=-2, keepdim=True).transpose(-2, -1) 177 | 178 | # Calculate weighted average of dot-product and memory-based attention 179 | # The calculation is described in equation (6) of the paper 180 | att = nn.functional.sigmoid( 181 | self.betas) * att_mem + (1 - nn.functional.sigmoid(self.betas)) * att_dot 182 | att = att.view((batch_size, seg_len, 183 | self.num_heads * self.dim_value)) 184 | 185 | # Append output to buffer 186 | # The calculation is described in equation (7) of the paper 187 | out.append(self.proj_out(att)) 188 | 189 | # Return concatenated full sequence from buffer 190 | out = torch.concat(out, dim=1) 191 | 192 | return out 193 | 194 | 195 | def test_compressive_memory( 196 | short_seq_len: bool = False, 197 | even_seq_len: bool = True, 198 | causal_masking: bool = False, 199 | update: str = "linear" 200 | ) -> None: 201 | # Set example module parameters 202 | dim_input = 512 203 | dim_key = 64 204 | dim_value = 64 205 | num_heads = 8 206 | segment_len = 32 207 | causal = causal_masking 208 | 209 | # Set dummy input dimensions 210 | batch_size = 4 211 | 212 | # Handle sequence length based on test case 213 | if short_seq_len: 214 | seq_len = 16 215 | else: 216 | if even_seq_len: 217 | seq_len = 128 218 | else: 219 | seq_len = 144 220 | 221 | # Initialize module 222 | model = CompressiveMemory( 223 | dim_input, dim_key, dim_value, num_heads, segment_len, update, causal) 224 | 225 | # Generate random input 226 | batch = torch.randn(batch_size, seq_len, dim_input) 227 | 228 | # Apply the CompressiveMemory module 229 | model(batch) 230 | 231 | 232 | if __name__ == "__main__": 233 | # Test all cases with short sequence lengths 234 | print("Testing with short sequence lengths:") 235 | 236 | short_seq_len = True 237 | # In this case even_seq_len doesn't matter -- arbitrarily setting it to True 238 | even_seq_len = True 239 | 240 | for causal_masking in [True, False]: 241 | for update in ["linear", "delta"]: 242 | print(f" Testing with causal_masking={causal_masking} and update={update}") 243 | test_compressive_memory( 244 | short_seq_len=short_seq_len, 245 | even_seq_len=even_seq_len, 246 | causal_masking=causal_masking, 247 | update=update 248 | ) 249 | 250 | # Test all cases with short sequence lengths 251 | print("Testing with non-short sequence lengths:") 252 | 253 | short_seq_len = False 254 | 255 | for even_seq_len in [True, False]: 256 | for causal_masking in [True, False]: 257 | for update in ["linear", "delta"]: 258 | print(f" Testing with even_seq_len={even_seq_len}, causal_masking={causal_masking} and update={update}") 259 | test_compressive_memory( 260 | short_seq_len=short_seq_len, 261 | even_seq_len=even_seq_len, 262 | causal_masking=causal_masking, 263 | update=update 264 | ) -------------------------------------------------------------------------------- /infini_transformer/positional_embeddings.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | 9 | class PositionEmbeddings(nn.Module): 10 | """Base class for different types of positional embeddings.""" 11 | def __init__(self): 12 | super(PositionEmbeddings, self).__init__() 13 | 14 | def forward(x: torch.Tensor) -> torch.Tensor: 15 | raise 
NotImplementedError("Subclasses must implement the forward method.") 16 | 17 | 18 | class RoPEEmbeddings(PositionEmbeddings): 19 | """Implements rotary positional embeddings (RoPE) as described in the paper: 20 | "RoFormer: Enhanced Transformer with Rotary Position Embedding" by Su et al. 21 | (https://arxiv.org/abs/2104.09864). 22 | 23 | Modifications have been made to make it compatible with both Infini-Attention 24 | and Mixture-of-Experts.""" 25 | def __init__(self, dim: int, seq_len: int, dim_embedding_pct: float = 0.5, base: int = 10000): 26 | """Instantiate the module. 27 | 28 | Args: 29 | dim (int): Key/Value dimension of the attention layer. 30 | seq_len (int): Maximum sequence length. 31 | dim_embedding_pct (float): Percentage of the total embedding dimension to use for the positional embeddings. Must be within the interval (0, 1]. Defaults to 0.5. 32 | base (int, optional): Base used for calculating thetas. Defaults to 10000. 33 | """ 34 | super(RoPEEmbeddings, self).__init__() 35 | 36 | # Record input parameters 37 | self.dim = dim 38 | self.effective_dim = int(dim * dim_embedding_pct) 39 | self.seq_len = seq_len 40 | self.dim_embedding_pct = dim_embedding_pct 41 | self.base = base 42 | self.last_offset = 0 43 | 44 | # Initialize theta matrix 45 | self._calculate_thetas() 46 | 47 | # Initialize sin component indices for input tensor 48 | # Indices for rearranging the input follow the pattern [1, 0, 3, 2, 5, 4, ...] 49 | # Indices that need to be negated in calculating the positional embeddings are [0, 2, 4, ...] 50 | self.ixs_sin = torch.empty(self.effective_dim, dtype=torch.long) 51 | self.ixs_sin_neg = 2 * torch.arange(self.effective_dim // 2) 52 | self.ixs_sin[self.ixs_sin_neg] = self.ixs_sin_neg + 1 53 | self.ixs_sin[self.ixs_sin_neg + 1] = self.ixs_sin_neg 54 | 55 | def _calculate_thetas(self, offset: int = 0, select_mask: Optional[torch.Tensor] = None) -> None: 56 | """Calculate the cosine and sine component matrices for the rotary positional embeddings. 57 | Uses multidimensional extension of theta as defined in Sec 3.2.2 as well as equation (34) 58 | from the RoFormer paper 59 | 60 | Args: 61 | offset (int, optional): Position offset for Infini-Former compatibility. Defaults to 0. 62 | select_mask (Optional[torch.Tensor], optional): Mask to select a subset of the positional embeddings for Mixture-of-Depths compatibility. Defaults to None. 63 | """ 64 | if select_mask is None: 65 | # Calculate matrix of angles: thetas[i,j] = base^(-2 * ceil(i/2)) * (j + offset) 66 | thetas = torch.repeat_interleave( 67 | (self.base ** (-2. * torch.arange(1, self.effective_dim//2 + 1))).unsqueeze(-1).repeat((1, self.seq_len)), 68 | repeats=2, 69 | dim=0 70 | ) 71 | # Multiply by index positions, then transpose to get correct shape 72 | thetas *= torch.arange(1 + offset, self.seq_len + 1 + offset).unsqueeze(0) 73 | self.thetas = thetas.transpose(0, 1).unsqueeze(0).unsqueeze(0) 74 | else: 75 | # (n_obs, select_seq_len) 76 | select_ixs = 1 + offset + torch.argwhere(select_mask)[:, 1].view((select_mask.size(0), -1)) 77 | # (n_obs, select_seq_len, effective_dim) 78 | select_ixs = select_ixs.unsqueeze(-1).repeat((1, 1, self.effective_dim)) 79 | # (effective_dim, select_seq_len) 80 | thetas = torch.repeat_interleave( 81 | (self.base ** (-2. 
* torch.arange(1, self.effective_dim//2 + 1))).unsqueeze(-1).repeat((1, select_ixs.size(1))), 82 | repeats=2, 83 | dim=0 84 | ) 85 | # (n_obs, select_seq_len, effective_dim) 86 | thetas = thetas.transpose(0, 1).unsqueeze(0).repeat((select_mask.size(0), 1, 1)) 87 | thetas *= select_ixs 88 | self.thetas = thetas.unsqueeze(1) 89 | 90 | def forward(self, x: torch.Tensor, total_seq_len: int = 0, offset: int = 0, select_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 91 | """Applies rotary positional embeddings to the input tensor. Uses a multidimensional 92 | extension of equation (34) of the RoFormer paper. 93 | 94 | Args: 95 | x (torch.Tensor): Input tensor of shape (batch_size, num_heads, seq_len, dim). 96 | total_seq_len (int, optional): Unused input for YaRN compatibility. Defaults to 0. 97 | offset (int, optional): Position offset for Infini-Former compatibility. Defaults to 0. 98 | select_mask (Optional[torch.Tensor], optional): Mask to select a subset of the positional embeddings for Mixture-of-Depths compatibility. Defaults to None. 99 | 100 | Returns: 101 | torch.Tensor: Transformed input tensor with rotary positional embeddings applied. 102 | """ 103 | if offset != self.last_offset: 104 | self._calculate_thetas(offset=offset, select_mask=select_mask) 105 | self.last_offset = offset 106 | cos_sin_recalculated = True 107 | else: 108 | cos_sin_recalculated = False 109 | 110 | if self.dim_embedding_pct < 1.0: 111 | x_pos = x[..., :self.effective_dim] 112 | x_pass = x[..., self.effective_dim:] 113 | else: 114 | x_pos = x 115 | 116 | # If no selection mask is specified, add embeddings as usual 117 | if select_mask is None: 118 | # If the sequence length is less than the maximum sequence length, perform calculations 119 | # with truncated cos_component and sin_component, along the sequence axis 120 | if x.size(2) < self.seq_len: 121 | x_cos = self.thetas.cos()[:, :, :x_pos.size(2), :].repeat(x_pos.size(0), x_pos.size(1), 1, 1) * x_pos 122 | x_sin = x_pos[..., self.ixs_sin] 123 | x_sin[..., self.ixs_sin_neg] = -x_sin[...,self.ixs_sin_neg] 124 | x_sin *= self.thetas.sin()[:, :, :x_pos.size(2), :].repeat(x_pos.size(0), x_pos.size(1), 1, 1) 125 | # Otherwise, perform calculations with the full cos_component and sin_component 126 | else: 127 | x_cos = self.thetas.cos().repeat(x_pos.size(0), x_pos.size(1), 1, 1) * x_pos 128 | x_sin = x_pos[..., self.ixs_sin] 129 | x_sin[..., self.ixs_sin_neg] = -x_sin[...,self.ixs_sin_neg] 130 | x_sin *= self.thetas.sin().repeat(x_pos.size(0), x_pos.size(1), 1, 1) 131 | # If a selection mask is specified, incorporate it into the positional embeddings 132 | else: 133 | if not cos_sin_recalculated: 134 | self._calculate_thetas(offset=offset, select_mask=select_mask) 135 | self.last_offset = offset 136 | x_cos = self.thetas.cos().repeat(1, x_pos.size(1), 1, 1) * x_pos 137 | x_sin = x_pos[..., self.ixs_sin] 138 | x_sin[..., self.ixs_sin_neg] = -x_sin[...,self.ixs_sin_neg] 139 | x_sin *= self.thetas.sin().repeat(1, x_pos.size(1), 1, 1) 140 | 141 | # If the sequence length is less than the maximum sequence length, concatenate positionally embedded 142 | # entries with original entries, otherwise return the positionally embedded entries 143 | if self.dim_embedding_pct < 1.0: 144 | out = torch.cat([x_cos + x_sin, x_pass], dim=-1) 145 | else: 146 | out = x_cos + x_sin 147 | 148 | return out 149 | 150 | class YaRNEmbeddings(PositionEmbeddings): 151 | """Implements Yet Another RoPE ExtensioN (YaRN) as described in the paper: 152 | "YaRN: Efficient Context Window 
Extension of Large Language Models" by Peng et al. 153 | (https://arxiv.org/abs/2309.00071). 154 | 155 | Modifications have been made to make it compatible with both Infini-Attention 156 | and Mixture-of-Experts.""" 157 | def __init__( 158 | self, 159 | dim: int, 160 | seq_len: int, 161 | context_len: int, 162 | context_len_ext: int, 163 | dim_embedding_pct: float = 0.5, 164 | base: int = 10000, 165 | alpha: int = 1, 166 | beta: int = 32, 167 | length_scale: Optional[float] = None 168 | ): 169 | """Instantiate the module. 170 | 171 | Args: 172 | dim (int): Key/Value dimension of the attention layer. 173 | seq_len (int): Maximum sequence length. 174 | context_len (int): Length of the context window. 175 | context_len_ext (int): Extended length of the context window. 176 | dim_embedding_pct (float): Percentage of the total embedding dimension to use for the positional embeddings. Must be within the interval (0, 1]. Defaults to 0.5. 177 | base (int, optional): Base used for calculating thetas. Defaults to 10000. 178 | alpha (int, optional): Interpolation minimum for dynamic scaling. Defaults to 1. 179 | beta (int, optional): Interpolation maximum for dynamic scaling. Defaults to 32. 180 | len_scale (Optional[float], optional): Length scale for attention calculation. Defaults to None. 181 | """ 182 | super(YaRNEmbeddings, self).__init__() 183 | 184 | # Record input parameters 185 | self.dim = dim 186 | self.effective_dim = int(dim * dim_embedding_pct) 187 | self.seq_len = seq_len 188 | self.context_len = context_len 189 | self.context_len_ext = context_len_ext 190 | self.dim_embedding_pct = dim_embedding_pct 191 | self.base = base 192 | self.alpha = alpha 193 | self.beta = beta 194 | self.length_scale = length_scale 195 | 196 | self.last_offset = -1 197 | 198 | # Initialize sin component indices for input tensor 199 | # Indices for rearranging the input follow the pattern [1, 0, 3, 2, 5, 4, ...] 200 | # Indices that need to be negated in calculating the positional embeddings are [0, 2, 4, ...] 201 | self.ixs_sin = torch.empty(self.effective_dim, dtype=torch.long) 202 | self.ixs_sin_neg = 2 * torch.arange(self.effective_dim // 2) 203 | self.ixs_sin[self.ixs_sin_neg] = self.ixs_sin_neg + 1 204 | self.ixs_sin[self.ixs_sin_neg + 1] = self.ixs_sin_neg 205 | 206 | def _scale_factor(self, seq_len: int) -> float: 207 | """Calculate the scale factor for the given sequence length from section 3.3 in the paper. 208 | 209 | Args: 210 | seq_len (int): The sequence length to calculate the scale factor for. 211 | 212 | Returns: 213 | float: The scale factor for the given sequence length. 214 | """ 215 | return max(1., seq_len / self.context_len) 216 | 217 | def _base_ext(self, seq_len: int) -> float: 218 | """Calculate the extended base from equation (16) in the paper. 219 | 220 | Args: 221 | seq_len (int): The sequence length to calculate the extended base for. 222 | 223 | Returns: 224 | float: The extended base for the given sequence length. 225 | """ 226 | return self.base * (self._scale_factor(seq_len) ** (self.dim / (self.dim - 2))) 227 | 228 | def _wavelength_d(self, d: torch.Tensor) -> torch.Tensor: 229 | """Calculate the wavelength for the given dimension index tensor from equation (13) in the paper. 230 | 231 | Args: 232 | d (torch.Tensor): Tensor of dimension indices to calculate the wavelengths for. 233 | 234 | Returns: 235 | torch.Tensor: The wavelengths of the given dimension index tensor. 236 | """ 237 | return 2. 
* math.pi * self.base ** (2 * d / self.dim) 238 | 239 | def _wavelength_theta(self, theta: torch.Tensor) -> torch.Tensor: 240 | """Calculate the wavelength of the given theta tensor from equation (13) in the paper. 241 | 242 | Args: 243 | theta (torch.Tensor): The theta tensor to calculate the wavelengths for. 244 | 245 | Returns: 246 | torch.Tensor: The wavelengths of the given theta tensor. 247 | """ 248 | return 2. * math.pi / theta 249 | 250 | def _wavelength_context_ratio(self, wavelength: torch.Tensor) -> torch.Tensor: 251 | """Calculate the wavelength ratio for the context extension from equation (17) in the paper. 252 | 253 | Args: 254 | wavelength (torch.Tensor): The wavelengths to calculate the ratio for. 255 | 256 | Returns: 257 | torch.Tensor: The wavelength ratio for the context extension. 258 | """ 259 | return self.context_len / wavelength 260 | 261 | def _ramp(self, ratio: torch.Tensor) -> torch.Tensor: 262 | """Calculate the ramp function from equation (18) in the paper. 263 | 264 | Args: 265 | ratio (torch.Tensor): The wavelength ratio to calculate the ramp function for. 266 | 267 | Returns: 268 | torch.Tensor: The ramp function values for the given wavelength ratio. 269 | """ 270 | out = torch.zeros_like(ratio, device=ratio.device, dtype=torch.float) 271 | interp_mask = torch.logical_and(ratio >= self.alpha, ratio <= self.beta) 272 | one_mask = ratio > self.beta 273 | out[interp_mask] = (ratio[interp_mask] - self.alpha) / (self.beta - self.alpha) 274 | out[one_mask] = 1. 275 | return out 276 | 277 | def _calculate_thetas(self, total_seq_len: int, offset: int = 0, select_mask: Optional[torch.Tensor] = None) -> None: 278 | """Calculate the cosine and sine component matrices for the rotary positional embeddings. 279 | Uses multidimensional extension of theta as defined in Sec 3.2.2 as well as equation (34) 280 | from the RoFormer paper 281 | 282 | Args: 283 | total_seq_len (int): The total sequence length to calculate the thetas for. 284 | offset (int, optional): Position offset for Infini-Former compatibility. Defaults to 0. 285 | select_mask (Optional[torch.Tensor], optional): Mask to select a subset of the positional embeddings for Mixture-of-Depths compatibility. Defaults to None. 286 | """ 287 | if select_mask is None: 288 | # Calculate matrix of angles 289 | # Shape: (effective_dim, seq_len) 290 | thetas = torch.repeat_interleave( 291 | (self.base ** (-2. * torch.arange(1, self.effective_dim//2 + 1) / self.dim)).unsqueeze(-1).repeat((1, self.seq_len)), 292 | repeats=2, 293 | dim=0 294 | ) 295 | ramp = self._ramp(self._wavelength_context_ratio(self._wavelength_theta(thetas))) 296 | scale = self._scale_factor(total_seq_len) 297 | length_scale = 0.1 * math.log(scale) + 1. if self.length_scale is None else self.length_scale 298 | thetas = ((1. - ramp) * thetas / scale + ramp * thetas) * length_scale 299 | # Multiply by index positions, then transpose to get correct shape 300 | thetas *= torch.arange(1 + offset, self.seq_len + 1 + offset).unsqueeze(0) 301 | self.thetas = thetas.transpose(0, 1).unsqueeze(0).unsqueeze(0) 302 | else: 303 | # (n_obs, select_seq_len) 304 | select_ixs = 1 + offset + torch.argwhere(select_mask)[:, 1].view((select_mask.size(0), -1)) 305 | # (n_obs, select_seq_len, effective_dim) 306 | select_ixs = select_ixs.unsqueeze(-1).repeat((1, 1, self.effective_dim)) 307 | # (effective_dim, select_seq_len) 308 | thetas = torch.repeat_interleave( 309 | (self.base ** (-2. 
* torch.arange(1, self.effective_dim//2 + 1))).unsqueeze(-1).repeat((1, select_ixs.size(1))), 310 | repeats=2, 311 | dim=0 312 | ) 313 | ramp = self._ramp(self._wavelength_context_ratio(self._wavelength_theta(thetas))) 314 | scale = self._scale_factor(total_seq_len) 315 | length_scale = 0.1 * math.log(scale) + 1. if self.length_scale is None else self.length_scale 316 | thetas = ((1. - ramp) * thetas / scale + ramp * thetas) * length_scale 317 | # (n_obs, select_seq_len, effective_dim) 318 | thetas = thetas.transpose(0, 1).unsqueeze(0).repeat((select_mask.size(0), 1, 1)) 319 | thetas *= select_ixs 320 | self.thetas = thetas.unsqueeze(1) 321 | 322 | def forward(self, x: torch.Tensor, total_seq_len: int, offset: int = 0, select_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 323 | """Applies rotary positional embeddings to the input tensor. Uses a multidimensional 324 | extension of equation (34) of the RoFormer paper. 325 | 326 | Args: 327 | x (torch.Tensor): Input tensor of shape (batch_size, num_heads, seq_len, dim). 328 | total_seq_len (int): The total sequence length of the input. 329 | offset (int, optional): Position offset for Infini-Former compatibility. Defaults to 0. 330 | select_mask (Optional[torch.Tensor], optional): Mask to select a subset of the positional embeddings for Mixture-of-Depths compatibility. Defaults to None. 331 | 332 | Returns: 333 | torch.Tensor: Transformed input tensor with rotary positional embeddings applied. 334 | """ 335 | if offset != self.last_offset: 336 | self._calculate_thetas(total_seq_len=total_seq_len, offset=offset, select_mask=select_mask) 337 | self.last_offset = offset 338 | cos_sin_recalculated = True 339 | else: 340 | cos_sin_recalculated = False 341 | 342 | if self.dim_embedding_pct < 1.0: 343 | x_pos = x[..., :self.effective_dim] 344 | x_pass = x[..., self.effective_dim:] 345 | else: 346 | x_pos = x 347 | 348 | # If no selection mask is specified, add embeddings as usual 349 | if select_mask is None: 350 | # If the sequence length is less than the maximum sequence length, perform calculations 351 | # with truncated cos_component and sin_component, along the sequence axis 352 | if x.size(2) < self.seq_len: 353 | x_cos = self.thetas.cos()[:, :, :x_pos.size(2), :].repeat(x_pos.size(0), x_pos.size(1), 1, 1) * x_pos 354 | x_sin = x_pos[..., self.ixs_sin] 355 | x_sin[..., self.ixs_sin_neg] = -x_sin[...,self.ixs_sin_neg] 356 | x_sin *= self.thetas.sin()[:, :, :x_pos.size(2), :].repeat(x_pos.size(0), x_pos.size(1), 1, 1) 357 | # Otherwise, perform calculations with the full cos_component and sin_component 358 | else: 359 | x_cos = self.thetas.cos().repeat(x_pos.size(0), x_pos.size(1), 1, 1) * x_pos 360 | x_sin = x_pos[..., self.ixs_sin] 361 | x_sin[..., self.ixs_sin_neg] = -x_sin[...,self.ixs_sin_neg] 362 | x_sin *= self.thetas.sin().repeat(x_pos.size(0), x_pos.size(1), 1, 1) 363 | # If a selection mask is specified, incorporate it into the positional embeddings 364 | else: 365 | if not cos_sin_recalculated: 366 | self._calculate_thetas(total_seq_len=total_seq_len, offset=offset, select_mask=select_mask) 367 | self.last_offset = offset 368 | x_cos = self.thetas.cos().repeat(1, x_pos.size(1), 1, 1) * x_pos 369 | x_sin = x_pos[..., self.ixs_sin] 370 | x_sin[..., self.ixs_sin_neg] = -x_sin[...,self.ixs_sin_neg] 371 | x_sin *= self.thetas.sin().repeat(1, x_pos.size(1), 1, 1) 372 | 373 | # If the sequence length is less than the maximum sequence length, concatenate positionally embedded 374 | # entries with original entries, otherwise return the 
positionally embedded entries 375 | if self.dim_embedding_pct < 1.0: 376 | out = torch.cat([x_cos + x_sin, x_pass], dim=-1) 377 | else: 378 | out = x_cos + x_sin 379 | 380 | return out -------------------------------------------------------------------------------- /infini_transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from .activations import ACTIVATIONS 8 | from .compressive_memory import CompressiveMemory 9 | from .positional_embeddings import PositionEmbeddings 10 | 11 | 12 | class InfiniTransformer(nn.Module): 13 | """Transformer layer with compressive memory.""" 14 | 15 | def __init__( 16 | self, 17 | dim_input: int, 18 | dim_hidden: int, 19 | dim_key: int, 20 | dim_value: int, 21 | num_heads: int, 22 | activation: str, 23 | segment_len: int, 24 | update: str = "linear", 25 | causal: bool = False, 26 | position_embedder: Optional[PositionEmbeddings] = None, 27 | init_state_learnable: bool = False, 28 | dropout: float = 0.0, 29 | **kwargs 30 | ): 31 | """Initializes the module. 32 | 33 | Args: 34 | dim_input (int): Input dimension. 35 | dim_hidden (int): Hidden dimension for the MLP. 36 | dim_key (int): Key dimension for the CompressiveMemory. 37 | dim_value (int): Value dimension for the CompressiveMemory. 38 | num_heads (int): Number of attention heads for the CompressiveMemory. 39 | activation (str): Activation function to use for the MLP. Must be a key in the ACTIVATIONS dictionary. 40 | segment_len (int): Segment length for the CompressiveMemory. 41 | update (str, optional): Type of memory update rule to use for the CompressiveMemory ("linear" or "delta"). Defaults to "linear". 42 | causal (bool, optional): Whether to use causal attention masking for the CompressiveMemory. Defaults to False. 43 | position_embedder (Optional[PositionEmbeddings], optional): Position embedding module for the CompressiveMemory. Defaults to None. 44 | init_state_learnable (bool, optional): Whether the initial state of the CompressiveMemory should be learnable. Defaults to False. 45 | dropout (float, optional): Dropout rate for the MLP. Defaults to 0.0. 46 | """ 47 | super(InfiniTransformer, self).__init__() 48 | 49 | # If sampling_factor passed to kwargs, use it, otherwise set to None 50 | sampling_factor = kwargs.get("sampling_factor", None) 51 | 52 | # Multi-head attention 53 | self.attn = CompressiveMemory( 54 | dim_input=dim_input, 55 | dim_key=dim_key, 56 | dim_value=dim_value, 57 | num_heads=num_heads, 58 | segment_len=segment_len, 59 | sampling_factor=sampling_factor, 60 | update=update, 61 | causal=causal, 62 | position_embedder=position_embedder, 63 | init_state_learnable=init_state_learnable) 64 | # MLP 65 | if activation not in ACTIVATIONS: 66 | raise ValueError(f"Invalid activation function: {activation}") 67 | if activation in ["swiglu", "geglu", "ffnglu", "ffngeglu", "ffnswiglu"]: 68 | act = ACTIVATIONS[activation](dim_hidden) 69 | else: 70 | act = ACTIVATIONS[activation]() 71 | self.mlp = nn.Sequential( 72 | nn.Linear(dim_input, dim_hidden), 73 | nn.Dropout(dropout), 74 | act, 75 | nn.Linear(dim_hidden, dim_input), 76 | nn.Dropout(dropout) 77 | ) 78 | self.layer_norm = nn.LayerNorm(dim_input) 79 | 80 | def forward(self, x: torch.Tensor) -> torch.Tensor: 81 | """Forward pass. 82 | 83 | Args: 84 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim_input). 
85 | 86 | Returns: 87 | torch.Tensor: Output tensor of shape (batch_size, seq_len, dim_input). 88 | """ 89 | # Apply multi-head attention, followed by MLP and layer normalization with residual connection. 90 | x_ = self.attn(x) 91 | x_ = self.mlp(x_) 92 | 93 | return self.layer_norm(x_ + x) 94 | 95 | 96 | class MoDInfiniTransformer(InfiniTransformer): 97 | """Mixture-of-Depths Infini-Transformer Layer.""" 98 | 99 | def __init__( 100 | self, 101 | dim_input: int, 102 | dim_hidden: int, 103 | dim_key: int, 104 | dim_value: int, 105 | num_heads: int, 106 | activation: str, 107 | segment_len: int, 108 | sampling_factor: int, 109 | update="linear", 110 | causal: bool = False, 111 | position_embedder: Optional[PositionEmbeddings] = None, 112 | init_state_learnable: bool = False, 113 | dropout: float = 0.0 114 | ): 115 | """Instantiate module. 116 | 117 | Args: 118 | dim_input (int): Input dimension. 119 | dim_hidden (int): Hidden dimension for the MLP. 120 | dim_key (int): Key dimension for the CompressiveMemory. 121 | dim_value (int): Value dimension for the CompressiveMemory. 122 | num_heads (int): Number of attention heads for the CompressiveMemory. 123 | activation (str): Activation function to use for the MLP. Must be a key in the ACTIVATIONS dictionary. 124 | segment_len (int): Segment length for the CompressiveMemory. 125 | sampling_factor (int): Reciprocal of the sampling rate for the Mixture-of-Depths mechanism. 126 | update (str, optional): Type of memory update rule to use for the CompressiveMemory ("linear" or "delta"). Defaults to "linear". 127 | causal (bool, optional): Whether to use causal attention masking for the CompressiveMemory. Defaults to False. 128 | position_embedder (Optional[PositionEmbeddings], optional): Position embedding module for the CompressiveMemory. Defaults to None. 129 | init_state_learnable (bool, optional): Whether the initial state of the CompressiveMemory should be learnable. Defaults to False. 130 | dropout (float, optional): Dropout rate for the MLP. Defaults to 0.0. 131 | """ 132 | # Initialize ordinary InfiniTransformer, but with segment length reduced by sampling_factor 133 | super(MoDInfiniTransformer, self).__init__( 134 | dim_input=dim_input, 135 | dim_hidden=dim_hidden, 136 | dim_key=dim_key, 137 | dim_value=dim_value, 138 | num_heads=num_heads, 139 | activation=activation, 140 | segment_len=math.ceil(segment_len / sampling_factor), 141 | update=update, 142 | causal=causal, 143 | position_embedder=position_embedder, 144 | init_state_learnable=init_state_learnable, 145 | dropout=dropout, 146 | sampling_factor=sampling_factor 147 | ) 148 | 149 | # Record additional init arguments for forward pass 150 | self.segment_len = math.ceil(segment_len / sampling_factor) 151 | self.full_segment_len = segment_len 152 | self.sampling_factor = sampling_factor 153 | self.dim_input = dim_input 154 | 155 | # Projection for tensor of logits when sampling 156 | self.proj_sampling = nn.Linear(dim_input, 1) 157 | 158 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 159 | """Forward pass wrapper -- used to check at inference time whether to handle each observation individually. 160 | 161 | Args: 162 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim_input). 163 | 164 | Returns: 165 | torch.Tensor: Output tensor of shape (batch_size, seq_len, dim_input). 166 | torch.Tensor: Token selection mask of shape (batch_size * seq_len, 1) or None. 
167 | torch.Tensor: Predicted token selection scores of shape (batch_size * seq_len, 1) or None. 168 | """ 169 | if self.training: 170 | return self.forward_(x) 171 | else: 172 | # !!! TEMPORARY: Each sample may have a different sequence length, resulting in a ragged array 173 | # !!! the current fix is to process each sample individually and concatenate the results 174 | 175 | out = [] 176 | 177 | # Loop through samples and produce output for each 178 | for ix in range(x.size(0)): 179 | sample_out, _, _ = self.forward_(x[ix:ix+1,...]) 180 | out.append(sample_out) 181 | 182 | # Concatenate results 183 | out = torch.cat(out, dim=0) 184 | 185 | return out, None, None 186 | 187 | def forward_(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 188 | """Forward pass. 189 | 190 | Args: 191 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim_input). 192 | 193 | Returns: 194 | torch.Tensor: Output tensor of shape (batch_size, seq_len, dim_input). 195 | torch.Tensor: Token selection mask of shape (batch_size * seq_len, 1). 196 | torch.Tensor: Predicted token selection scores of shape (batch_size * seq_len, 1) or None. 197 | """ 198 | # Calculate number of total segments, samples 199 | batch_size, seq_len, _ = x.shape 200 | num_segments, rem = divmod(seq_len, self.full_segment_len) 201 | num_segments += 1 if rem > 0 else 0 202 | 203 | # Initialize list of token sample masks 204 | sample_masks = [] 205 | 206 | # Use linear embedding for sample scores 207 | sample_scores = self.proj_sampling(x).squeeze(-1) 208 | 209 | # For each segment, sample the tokens with the highest scores 210 | for seg_num in range(num_segments): 211 | # Compute segment indices 212 | ix_lo = seg_num * self.full_segment_len 213 | ix_hi = ix_lo + self.full_segment_len 214 | 215 | if self.training: 216 | # During training, take the top-k tokens by score 217 | # Argsort by sample scores to get indices of tokens to keep 218 | sort_ixs = torch.argsort( 219 | sample_scores[:, ix_lo:ix_hi], dim=1, descending=True) 220 | 221 | # Convert token indices to a binary mask 222 | sample_mask_seg = torch.zeros_like( 223 | sample_scores[:, ix_lo:ix_hi], device=x.device) 224 | sample_mask_seg.scatter_( 225 | dim=1, index=sort_ixs[:, :self.segment_len], value=1.0) 226 | else: 227 | # During inference, take the tokens with score greater than zero 228 | sample_mask_seg = (sample_scores[:, ix_lo:ix_hi] > 0.0).float() 229 | 230 | sample_masks.append(sample_mask_seg) 231 | 232 | # Combine segment masks into a single mask 233 | sample_mask = torch.cat(sample_masks, dim=1).bool() 234 | 235 | # Extract subsequcne of input tensor based on sample mask 236 | sample_shape = (batch_size, self.segment_len * num_segments, self.dim_input) 237 | x_ = x[sample_mask.unsqueeze(-1).repeat((1, 1, self.dim_input))].view(sample_shape) 238 | 239 | # Apply multi-head attention to sample, followed by MLP 240 | x_ = self.attn(x_, sample_mask=sample_mask) 241 | x_ = self.mlp(x_) 242 | 243 | # Add result of attended tokens to the result (equivalent to making the result 244 | # for non-attended tokens zero) 245 | x[sample_mask.unsqueeze(-1).repeat((1, 1, self.dim_input))] += x_.view(-1) 246 | 247 | # Pad the output tensor to the original sequence length 248 | padding_mask = torch.arange(x.size(1), device=x.device)[None, :] < sample_mask.view(batch_size, -1).sum(dim=1)[:, None] 249 | x = x * padding_mask.unsqueeze(-1) 250 | 251 | # Flatten sample scores and concatenation of top-k masks for auxiliary training task 252 | 
sample_scores = sample_scores.view((-1, 1)) 253 | sample_mask = sample_mask.view((-1, 1)).float() 254 | 255 | return x, sample_mask, sample_scores 256 | 257 | 258 | def demo_mod_infini_transformer(): 259 | """ 260 | Demonstrates the usage of the MoDInfiniTransformer class. 261 | """ 262 | # Define the model parameters 263 | dim_input = 512 264 | dim_hidden = 2048 265 | dim_key = 64 266 | dim_value = 64 267 | num_heads = 8 268 | activation = "ffngeglu" 269 | segment_len = 2048 270 | sampling_factor = 8 271 | update = "linear" 272 | dropout = 0.1 273 | position_embedder = None 274 | 275 | # Define batch dimensions 276 | seq_len = 4096 277 | batch_size = 2 278 | 279 | # Create the MoDInfiniTransformer layer 280 | layer = MoDInfiniTransformer( 281 | dim_input=dim_input, 282 | dim_hidden=dim_hidden, 283 | dim_key=dim_key, 284 | dim_value=dim_value, 285 | num_heads=num_heads, 286 | activation=activation, 287 | segment_len=segment_len, 288 | sampling_factor=sampling_factor, 289 | update=update, 290 | dropout=dropout, 291 | position_embedder=position_embedder 292 | ) 293 | 294 | # Generate dummy batch 295 | x = torch.randn(batch_size, seq_len, dim_input) 296 | 297 | # Test outputs for the case where the net is training 298 | layer.train() 299 | x_att, sample_mask, sample_scores_pred = layer(x) 300 | 301 | # Test output for the case where the net is not training 302 | layer.eval() 303 | x_att, sample_mask, sample_scores_pred = layer(x) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Guide (user-friendly): 2 | # https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ 3 | 4 | # Specification (technical, formal): 5 | # https://packaging.python.org/en/latest/specifications/pyproject-toml/ 6 | 7 | # Choosing a build backend: 8 | [build-system] 9 | requires = ["setuptools"] # REQUIRED if [build-system] table is used 10 | build-backend = "setuptools.build_meta" # If not defined, then legacy behavior can happen. 11 | 12 | [project] 13 | name = "infini-transformer" 14 | version = "0.2.7" 15 | description = "Infini-Transformer is a powerful and versatile transformer model designed for a wide range of natural language processing tasks." 16 | readme = "readme.md" 17 | requires-python = ">=3.8" 18 | license = {file = "LICENSE"} 19 | keywords = ["deep learning", "LLM", "transformer"] 20 | authors = [ 21 | {name = "Ryan Taylor", email = "ryan@beta-reduce.net" } 22 | ] 23 | maintainers = [ 24 | {name = "Ryan Taylor", email = "ryan@beta-reduce.net" } 25 | ] 26 | 27 | # Classifiers help users find your project by categorizing it. 28 | classifiers = [ 29 | # How mature is this project? 
Common values are
30 | # 3 - Alpha, 4 - Beta, 5 - Production/Stable
31 | "Development Status :: 3 - Alpha",
32 | 
33 | # Indicate who your project is intended for
34 | "Intended Audience :: Developers",
35 | 
36 | # Pick your license as you wish
37 | "License :: OSI Approved :: MIT License",
38 | "Programming Language :: Python :: 3 :: Only",
39 | ]
40 | 
41 | dependencies = [
42 | "torch>=2.0.0"
43 | ]
44 | 
45 | [project.urls]
46 | "Homepage" = "https://github.com/dingo-actual/infini-transformer"
47 | "Bug Reports" = "https://github.com/dingo-actual/infini-transformer/issues"
48 | "Publication" = "https://arxiv.org/abs/2404.07143"
49 | "Source" = "https://github.com/dingo-actual/infini-transformer/"
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Infini-Transformer
2 | 
3 | - [Infini-Transformer](#infini-transformer)
4 | - [Overview](#overview)
5 | - [Features](#features)
6 | - [Directory structure](#directory-structure)
7 | - [Getting Started](#getting-started)
8 | - [Usage](#usage)
9 | - [`CompressiveMemory`](#compressivememory)
10 | - [`InfiniTransformer`](#infinitransformer)
11 | - [`MoDInfiniTransformer`](#modinfinitransformer)
12 | - [`RoPEEmbeddings`](#ropeembeddings)
13 | - [`YaRNEmbeddings`](#yarnembeddings)
14 | - [Example Usage](#example-usage)
15 | - [License](#license)
16 | - [Acknowledgments](#acknowledgments)
17 | 
18 | ## Overview
19 | 
20 | Infini-Transformer ([https://arxiv.org/abs/2404.07143](https://arxiv.org/abs/2404.07143)) is a powerful and versatile transformer model designed for a wide range of natural language processing tasks. It leverages state-of-the-art techniques and architectures to achieve exceptional performance and scalability to infinite context lengths.
21 | 
22 | ## Features
23 | 
24 | - Scalable architecture for handling long sequences
25 | - Support for large-scale pre-training on diverse datasets
26 | - Support for multiple downstream tasks, including text classification, question answering, and language generation
27 | - Efficient fine-tuning for task-specific adaptation
28 | - Includes a Mixture-of-Depths ([https://arxiv.org/abs/2404.02258](https://arxiv.org/abs/2404.02258)) transformer layer that incorporates Infini-Attention
29 | - Implementation of RoPE ([https://arxiv.org/abs/2104.09864](https://arxiv.org/abs/2104.09864)) that conforms to Infini-Attention's and Mixture-of-Depths' memory-efficient designs
30 | - Implementation of YaRN ([https://arxiv.org/abs/2309.00071](https://arxiv.org/abs/2309.00071)) that conforms to Infini-Attention's and Mixture-of-Depths' memory-efficient designs
31 | 
32 | ## Directory structure
33 | 
34 | ```default
35 | infini-transformer/
36 | │
37 | ├── infini_transformer/
38 | │   ├── __init__.py
39 | │   ├── transformer.py
40 | │   ├── compressive_memory.py
41 | │   ├── positional_embeddings.py
42 | │   └── activations.py
43 | │
44 | ├── examples/
45 | │   ├── __init__.py
46 | │   └── modinfiniformer.py
47 | │
48 | ├── tests/
49 | │   ├── __init__.py
50 | │   └── test_transformer.py
51 | │
52 | ├── LICENSE
53 | ├── readme.md
54 | ├── requirements.txt
55 | ├── MANIFEST.in
56 | └── pyproject.toml
57 | ```
58 | 
59 | ## Getting Started
60 | 
61 | To get started with Infini-Transformer, you can clone the repository and install it from source:
62 | 
63 | ```bash
64 | git clone https://github.com/dingo-actual/infini-transformer.git
65 | cd infini-transformer
66 | pip install -e .
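# After the editable install, the bundled test script can be run directly to
# verify the setup (this assumes the install above completed successfully):
#   python tests/test_transformer.py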
67 | ```
68 | 
69 | ## Usage
70 | 
71 | ### `CompressiveMemory`
72 | 
73 | The `CompressiveMemory` module is a key component of the Infini-Transformer architecture. It is designed to handle long sequences efficiently by compressing and storing the input tokens in a memory matrix and normalization vector. This allows the model to maintain a large context window while keeping the memory usage bounded.
74 | 
75 | It performs a variant of multi-head self-attention with a recurrent update step by dividing the input tensor along its sequence dimension (which is assumed to be dimension 1). It begins by performing learned linear projections of the input into key, query and value tensors, from which it extracts segments for each recurrent step.
76 | 
77 | At each recurrent step, it calculates a learned linear combination of linear attention (which uses the memory matrix and normalization vector) and SDP attention. It then updates the memory matrix and normalization vector using the current step's key and value matrices, along with the current memory matrix and normalization vector. Before output, the per-head attention outputs are concatenated along the embedding dimension, then projected back to the input dimension.
78 | 
79 | The outputs from each recurrent step are concatenated along the sequence dimension (dimension 1) to produce the final output tensor.
80 | 
81 | The update for the memory matrix has two variants: linear and delta.
82 | 
83 | The linear update rule is:
84 | $$M_t = M_{t-1} + \bigl(\textrm{ELU}(K_{t-1}) + 1\bigr)^T V_{t-1}$$
85 | 
86 | The delta update rule is:
87 | $$M_t = M_{t-1} + \bigl(\textrm{ELU}(K_{t-1}) + 1\bigr)^T \biggl( V_{t-1} - \frac{(\textrm{ELU}(K_{t-1}) + 1)M_{t-1}}{(\textrm{ELU}(K_{t-1}) + 1)z_{t-1}}\biggr)$$
88 | 
89 | where $M_i$ is the memory matrix and $z_i$ is the normalization vector at step $i$. The $K$ and $V$ matrices are subscripted to indicate the recurrent steps they correspond to. A short code sketch of both update rules is given after the argument list below.
90 | 
91 | Computations are stacked along the embedding dimension whenever possible to make use of multi-head attention in an efficient manner.
92 | 
93 | The `CompressiveMemory` module takes the following arguments:
94 | 
95 | - `dim_input`: The input dimension of the tensors.
96 | - `dim_key`: The dimension of the key and query tensors.
97 | - `dim_value`: The dimension of the value tensor.
98 | - `num_heads`: The number of attention heads.
99 | - `segment_len`: The length of each segment in the recurrent attention computation.
100 | - `sampling_factor`: The sampling factor used with Mixture-of-Depths (set to None when Mixture-of-Depths is not used). (Default is None.)
101 | - `update`: The type of update to use for the memory matrix. Can be "linear" or "delta". (Default is "linear".)
102 | - `causal`: Whether to use causal attention in SDP calculations (where each position can only attend to previous positions). (Default is False.)
103 | - `positional_embedder`: An optional `PositionEmbeddings` object: `RoPEEmbeddings` or `YaRNEmbeddings`. (Default is None.)
104 | - `init_state_learnable`: Whether the initial memory state and normalization vector are learnable parameters. (Default is False.)
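To make the two update rules concrete, the sketch below shows a single recurrent memory update in plain PyTorch. It is a minimal illustration of the equations above, not the internal code of `CompressiveMemory`; the tensor names and the small epsilon added to the denominator for numerical safety are assumptions made here.

```python
import torch
import torch.nn.functional as F


def memory_update(k_seg, v_seg, mem, z, update="linear"):
    # k_seg: (batch, segment_len, dim_key), v_seg: (batch, segment_len, dim_value)
    # mem:   (batch, dim_key, dim_value),   z:     (batch, dim_key, 1)
    sigma_k = F.elu(k_seg) + 1.0  # ELU(K) + 1, as in the update equations above

    if update == "linear":
        # M_t = M_{t-1} + (ELU(K) + 1)^T V
        mem = mem + sigma_k.transpose(-2, -1) @ v_seg
    else:  # "delta"
        # Subtract what is already retrievable from memory before writing the new values
        retrieved = (sigma_k @ mem) / (sigma_k @ z + 1e-8)
        mem = mem + sigma_k.transpose(-2, -1) @ (v_seg - retrieved)

    # The normalization vector accumulates the summed (ELU + 1) keys in both variants
    z = z + sigma_k.sum(dim=1, keepdim=True).transpose(-2, -1)
    return mem, z
```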
105 | 
106 | Example usage of the `CompressiveMemory` module is as follows:
107 | 
108 | ```python
109 | import torch
110 | 
111 | from infini_transformer.compressive_memory import CompressiveMemory
112 | 
113 | 
114 | cm = CompressiveMemory(
115 | dim_input=768,
116 | dim_key=64,
117 | dim_value=64,
118 | num_heads=8,
119 | segment_len=2048,
120 | sampling_factor=None,
121 | update="linear",
122 | causal=True,
123 | positional_embedder=None,  # optionally a RoPEEmbeddings or YaRNEmbeddings instance
124 | init_state_learnable=False
125 | )
126 | 
127 | batch = torch.randn(
128 | 2, # batch size
129 | 65536, # sequence length
130 | 768 # input dimension
131 | )
132 | 
133 | output = cm(batch)
134 | ```
135 | 
136 | During training, no special handling of the output is required.
137 | 
138 | ### `InfiniTransformer`
139 | 
140 | The `InfiniTransformer` class implements a variation on the original transformer that utilizes `CompressiveMemory` in place of standard self-attention. This allows the model to handle long sequences efficiently by compressing and storing the input tokens in a memory matrix and normalization vector, using the `CompressiveMemory` module to perform a variant of multi-head self-attention with a recurrent update step.
141 | 
142 | The primary difference between `InfiniTransformer` and an ordinary transformer is the substitution of `CompressiveMemory` for the standard multi-head self-attention mechanism. A sketch of how several layers can be stacked into a small model follows the argument list below.
143 | 
144 | The `InfiniTransformer` module takes the following arguments:
145 | 
146 | - `dim_input`: The input dimension of the tensors.
147 | - `dim_hidden`: The hidden dimension of the MLP applied after multi-head self-attention.
148 | - `dim_key`: The dimension of the key and query tensors.
149 | - `dim_value`: The dimension of the value tensor.
150 | - `num_heads`: The number of attention heads.
151 | - `activation`: The nonlinear activation function to apply in the MLP. The following activations are supported:
152 | 
153 | - `"relu"`: ReLU activation
154 | - `"abs"`: Absolute value activation
155 | - `"gelu"`: Gaussian Error Linear Unit (GELU) activation
156 | - `"swish"`: Swish activation
157 | - `"swiglu"`: SwiGLU activation
158 | - `"geglu"`: Gated Gaussian Error Linear Unit (GeGELU) activation
159 | - `"ffnglu"`: Feed-Forward Network with Gated Linear Unit (FFNGLU) activation
160 | - `"ffngeglu"`: Feed-Forward Network with Gated Gaussian Error Linear Unit (FFNGeGLU) activation
161 | - `"ffnswiglu"`: Feed-Forward Network with Swish Gated Linear Unit (FFNSwiGLU) activation
162 | 
163 | - `segment_len`: The length of each segment in the recurrent attention computation.
164 | - `update`: The type of update to use for the memory matrix. Can be "linear" or "delta". (Default is "linear".)
165 | - `causal`: Whether to use causal attention in SDP calculations (where each position can only attend to previous positions). (Default is False.)
166 | - `positional_embedder`: An optional `PositionEmbeddings` object: `RoPEEmbeddings` or `YaRNEmbeddings`. (Default is None.)
167 | - `init_state_learnable`: Whether the initial memory state and normalization vector are learnable parameters. (Default is False.)
168 | - `dropout`: The dropout rate to apply in the MLP. (Default is 0.0.)
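Since `InfiniTransformer` is a single layer (Infini-Attention followed by an MLP), a full model is built by stacking several of them. The hypothetical wrapper below (`TinyInfiniLM` is not part of this package) sketches one way to do that for next-token prediction, using the same constructor arguments documented above. Basic usage of a single layer is shown next.

```python
import torch
from torch import nn

from infini_transformer import InfiniTransformer


class TinyInfiniLM(nn.Module):
    # Hypothetical wrapper: stacks InfiniTransformer layers for next-token prediction.
    def __init__(self, vocab_size: int, dim_input: int = 768, num_layers: int = 4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim_input)
        self.layers = nn.ModuleList([
            InfiniTransformer(
                dim_input=dim_input,
                dim_hidden=4 * dim_input,
                dim_key=64,
                dim_value=64,
                num_heads=8,
                activation="ffngeglu",
                segment_len=2048,
                update="delta",
                causal=True,
                positional_embedder=None,
                init_state_learnable=False,
                dropout=0.1,
            )
            for _ in range(num_layers)
        ])
        self.head = nn.Linear(dim_input, vocab_size)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # token_ids: (batch_size, seq_len) -> logits: (batch_size, seq_len, vocab_size)
        x = self.embedding(token_ids)
        for layer in self.layers:
            x = layer(x)
        return self.head(x)
```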
169 | 
170 | Example usage of the `InfiniTransformer` module is as follows:
171 | 
172 | ```python
173 | import torch
174 | 
175 | from infini_transformer import InfiniTransformer
176 | 
177 | 
178 | tfm = InfiniTransformer(
179 | dim_input=768,
180 | dim_hidden=2048,
181 | dim_key=64,
182 | dim_value=64,
183 | num_heads=8,
184 | activation="ffngeglu",
185 | segment_len=2048,
186 | update="delta",
187 | causal=True,
188 | positional_embedder=None,
189 | init_state_learnable=False,
190 | dropout=0.1
191 | )
192 | 
193 | batch = torch.randn(
194 | 2, # batch size
195 | 65536, # sequence length
196 | 768 # input dimension
197 | )
198 | 
199 | output = tfm(batch)
200 | ```
201 | 
202 | During training, no special handling of the output is required.
203 | 
204 | ### `MoDInfiniTransformer`
205 | 
206 | The `MoDInfiniTransformer` module extends the `InfiniTransformer` module to incorporate Mixture-of-Depths (Raposo, et al.; [https://arxiv.org/abs/2404.02258](https://arxiv.org/abs/2404.02258)). A `MoDInfiniTransformer` block takes a learned linear projection of its input to a single dimension, and uses the tokens with the top-k highest values for the operations performed by `InfiniTransformer`, adding all remaining tokens to the residual connection. This allows the model to focus its capacity on the most important parts of the input sequence, reducing overall computation and memory requirements even further than `InfiniTransformer` alone.
207 | 
208 | The top-k selection would ordinarily cause segments within the recurrent loop to have different lengths. We avoid this by dividing the selection evenly amongst all segments.
209 | 
210 | Due to the non-causal nature of the top-k selection, at inference time the scores produced during projection to a single dimension are taken to be logits for independent binary classifiers. As such, we train the model with an additional term added to the loss for each `MoDInfiniTransformer` layer, which is the binary cross-entropy loss between these logits and the top-k selection mask used during training.
211 | 
212 | Accordingly, the output from `MoDInfiniTransformer` is a tuple of three tensors:
213 | 
214 | - The usual output tensor which matches the dimensions of the input tensor
215 | - A tensor of shape `(batch_size * sequence_length, 1)`, which represents a binary mask of top-k tokens selected during training. This will be the target for our additional binary cross-entropy loss.
216 | - A tensor of shape `(batch_size * sequence_length, 1)` of logits corresponding to the binary mask above. This represents the scores used to select the top-k tokens and is considered the prediction for the additional binary cross-entropy loss.
217 | 
218 | At inference time, the second and third elements of the tuple can be safely ignored, as all token selection logic is handled within the `MoDInfiniTransformer` module itself.
219 | 
220 | > **IMPORTANT NOTE**: The binary-classifier-based token selection mechanism for inference has no guarantee of selecting the same number of tokens for each element in a batch. If left unchecked, this would result in a ragged array, which is currently unsupported by PyTorch. The current solution in place is to force the batch size to 1 and concatenate forward passes over single observations. We are aware this is sub-optimal and hope to address it in the near future.
221 | 
222 | The `MoDInfiniTransformer` module takes the following arguments:
223 | 
224 | - `dim_input`: The input dimension of the tensors.
225 | - `dim_hidden`: The hidden dimension of the MLP applied after multi-head self-attention.
226 | - `dim_key`: The dimension of the key and query tensors.
227 | - `dim_value`: The dimension of the value tensor.
228 | - `num_heads`: The number of attention heads.
229 | - `activation`: The nonlinear activation function to apply in the MLP. The following activations are supported:
230 | 
231 | - `"relu"`: ReLU activation
232 | - `"abs"`: Absolute value activation
233 | - `"gelu"`: Gaussian Error Linear Unit (GELU) activation
234 | - `"swish"`: Swish activation
235 | - `"swiglu"`: SwiGLU activation
236 | - `"geglu"`: Gated Gaussian Error Linear Unit (GeGELU) activation
237 | - `"ffnglu"`: Feed-Forward Network with Gated Linear Unit (FFNGLU) activation
238 | - `"ffngeglu"`: Feed-Forward Network with Gated Gaussian Error Linear Unit (FFNGeGLU) activation
239 | - `"ffnswiglu"`: Feed-Forward Network with Swish Gated Linear Unit (FFNSwiGLU) activation
240 | 
241 | - `segment_len`: The length of each segment in the recurrent attention computation.
242 | - `sampling_factor`: A numeric value in the interval (1, `segment_len`) that determines the number of tokens to select from each segment during the top-k selection. A larger value of `sampling_factor` results in fewer tokens being selected.
243 | - `update`: The type of update to use for the memory matrix. Can be "linear" or "delta". (Default is "linear".)
244 | - `causal`: Whether to use causal attention in SDP calculations (where each position can only attend to previous positions). (Default is False.)
245 | - `positional_embedder`: An optional `PositionEmbeddings` object: `RoPEEmbeddings` or `YaRNEmbeddings`. (Default is None.)
246 | - `init_state_learnable`: Whether the initial memory state and normalization vector are learnable parameters. (Default is False.)
247 | - `dropout`: The dropout rate to apply in the MLP. (Default is 0.0.)
248 | 
249 | Example usage of the `MoDInfiniTransformer` module is as follows:
250 | 
251 | ```python
252 | import torch
253 | 
254 | from infini_transformer import MoDInfiniTransformer
255 | 
256 | 
257 | tfm = MoDInfiniTransformer(
258 | dim_input=768,
259 | dim_hidden=2048,
260 | dim_key=64,
261 | dim_value=64,
262 | num_heads=8,
263 | activation="ffngeglu",
264 | segment_len=2048,
265 | sampling_factor=8,
266 | update="delta",
267 | causal=True,
268 | init_state_learnable=False,
269 | positional_embedder=None,
270 | dropout=0.1
271 | )
272 | 
273 | batch = torch.randn(
274 | 2, # batch size
275 | 65536, # sequence length
276 | 768 # input dimension
277 | )
278 | 
279 | output, select_target, select_pred = tfm(batch)
280 | ```
281 | 
282 | During training, we must account for the additional outputs from `MoDInfiniTransformer` so we can use them for the binary cross-entropy loss. See [examples/modinfiniformer.py](examples/modinfiniformer.py) for an example of how to incorporate the additional outputs into both the overall model output and the training loop.
283 | 
284 | ### `RoPEEmbeddings`
285 | 
286 | The `RoPEEmbeddings` module applies RoPE from the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding" by Su, et al. ([https://arxiv.org/abs/2104.09864](https://arxiv.org/abs/2104.09864)). Once instantiated, it can be passed to the `InfiniTransformer` or `MoDInfiniTransformer` modules as the `positional_embedder` parameter, which passes it through to `CompressiveMemory`, where the position-aware embeddings are applied to the key and query tensors.
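For intuition, RoPE rotates pairs of feature dimensions of the query and key vectors by position-dependent angles, so relative positions are encoded in their inner products. The snippet below is a generic, self-contained illustration of that rotation (the function name `rope_rotate` is made up here); it is not the `RoPEEmbeddings` implementation.

```python
import torch


def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (batch, seq_len, dim) with dim even; rotates feature pairs by position-dependent angles.
    batch, seq_len, dim = x.shape
    half = dim // 2
    freqs = base ** (-torch.arange(0, half, dtype=torch.float32) / half)   # (half,)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * freqs   # (seq_len, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    # Standard rotary transform applied to each (x1, x2) pair
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```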
287 | 
288 | The `RoPEEmbeddings` module takes the following arguments:
289 | 
290 | - `dim`: The dimension of the key/value tensors.
291 | - `seq_len`: The maximum length of the input sequence to `CompressiveMemory` (this must match `CompressiveMemory`'s `segment_len` parameter).
292 | - `dim_embeddings_pct`: The proportion of the key/value tensor dimension to use for the position-aware embeddings. For example, if `dim` is 64 and `dim_embeddings_pct` is 0.5, then 32 dimensions will be used for the position-aware embeddings. (Default is 0.5.)
293 | - `base`: The base value to use for the position embedding angles. (Default is 10000.)
294 | 
295 | Example usage of the `RoPEEmbeddings` module is as follows:
296 | 
297 | ```python
298 | import torch
299 | 
300 | from infini_transformer import InfiniTransformer
301 | from infini_transformer import RoPEEmbeddings
302 | 
303 | embedder = RoPEEmbeddings(
304 | dim=64, # must match dim_key parameter in InfiniTransformer
305 | seq_len=2048, # must match segment_len parameter in InfiniTransformer
306 | dim_embeddings_pct=0.5,
307 | base=10000
308 | )
309 | 
310 | tfm = InfiniTransformer(
311 | dim_input=768,
312 | dim_hidden=2048,
313 | dim_key=64, # must match dim parameter in RoPEEmbeddings
314 | dim_value=64,
315 | num_heads=8,
316 | activation="ffngeglu",
317 | segment_len=2048, # must match seq_len parameter in RoPEEmbeddings
318 | update="delta",
319 | causal=True,
320 | positional_embedder=embedder,
321 | init_state_learnable=False,
322 | dropout=0.1
323 | )
324 | 
325 | batch = torch.randn(
326 | 2, # batch size
327 | 65536, # sequence length
328 | 768 # input dimension
329 | )
330 | 
331 | output = tfm(batch)
332 | ```
333 | 
334 | ### `YaRNEmbeddings`
335 | 
336 | The `YaRNEmbeddings` module applies YaRN from the paper "YaRN: Efficient Context Window Extension of Large Language Models" by Peng, et al. ([https://arxiv.org/abs/2309.00071](https://arxiv.org/abs/2309.00071)). Once instantiated, it can be passed to the `InfiniTransformer` or `MoDInfiniTransformer` modules as the `positional_embedder` parameter, which passes it through to `CompressiveMemory`, where the position-aware embeddings are applied to the key and query tensors.
337 | 
338 | The `YaRNEmbeddings` module takes the following arguments:
339 | 
340 | - `dim`: The dimension of the key/value tensors.
341 | - `seq_len`: The maximum length of the input sequence to `CompressiveMemory` (this must match `CompressiveMemory`'s `segment_len` parameter).
342 | - `context_len`: Context length used during training.
343 | - `context_len_ext`: Context length to extend to.
344 | - `dim_embeddings_pct`: The proportion of the key/value tensor dimension to use for the position-aware embeddings. For example, if `dim` is 64 and `dim_embeddings_pct` is 0.5, then 32 dimensions will be used for the position-aware embeddings. (Default is 0.5.)
345 | - `base`: The base value to use for the position embedding angles. (Default is 10000.)
346 | - `alpha`: Lower bound of the interpolation ramp used for dynamic scaling (the roles of `alpha` and `beta` are sketched after this list). (Default is 1.)
347 | - `beta`: Upper bound of the interpolation ramp used for dynamic scaling. (Default is 32.)
348 | - `len_scale`: Length scale for attention calculation. Defaults to None (automatically calculated).
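For intuition about `alpha` and `beta`: in the YaRN paper they bound a per-dimension ramp that decides how strongly each rotary frequency is interpolated when the context window is extended. The snippet below sketches that ramp in isolation; the function name `yarn_ramp` is made up here, and it is not claimed to match the internal implementation of `YaRNEmbeddings`.

```python
import torch


def yarn_ramp(dim: int, context_len: int, base: float = 10000.0,
              alpha: float = 1.0, beta: float = 32.0) -> torch.Tensor:
    # Per-dimension blend factor in [0, 1]: 0 = full interpolation, 1 = no interpolation.
    # Low-frequency dimensions (few rotations over the training context) are interpolated;
    # high-frequency dimensions are left untouched.
    half = dim // 2
    freqs = base ** (-torch.arange(0, half, dtype=torch.float32) / half)
    # Number of full rotations each dimension completes over the training context
    rotations = context_len * freqs / (2 * torch.pi)
    ramp = (rotations - alpha) / (beta - alpha)
    return ramp.clamp(0.0, 1.0)
```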
349 | 
350 | Example usage of the `YaRNEmbeddings` module is as follows:
351 | 
352 | ```python
353 | import torch
354 | 
355 | from infini_transformer import InfiniTransformer
356 | from infini_transformer import YaRNEmbeddings
357 | 
358 | embedder = YaRNEmbeddings(
359 | dim=64, # must match dim_key in InfiniTransformer
360 | seq_len=2048, # must match segment_len parameter in InfiniTransformer
361 | context_len=32768,
362 | context_len_ext=65536,
363 | dim_embeddings_pct=0.5,
364 | base=10000,
365 | alpha=1,
366 | beta=32,
367 | len_scale=None
368 | )
369 | 
370 | tfm = InfiniTransformer(
371 | dim_input=768,
372 | dim_hidden=2048,
373 | dim_key=64, # must match dim in YaRNEmbeddings
374 | dim_value=64,
375 | num_heads=8,
376 | activation="ffngeglu",
377 | segment_len=2048, # must match seq_len parameter in YaRNEmbeddings
378 | update="delta",
379 | causal=True,
380 | positional_embedder=embedder,
381 | init_state_learnable=False,
382 | dropout=0.1
383 | )
384 | 
385 | batch = torch.randn(
386 | 2, # batch size
387 | 65536, # sequence length
388 | 768 # input dimension
389 | )
390 | 
391 | output = tfm(batch)
392 | ```
393 | 
394 | ### Example Usage
395 | 
396 | Please see [examples/modinfiniformer.py](examples/modinfiniformer.py) for an example of a model and training routine using the `MoDInfiniTransformer` module.
397 | 
398 | More examples will be forthcoming.
399 | 
400 | ## License
401 | 
402 | This project is licensed under the [MIT License](LICENSE).
403 | 
404 | ## Acknowledgments
405 | 
406 | We would like to thank the researchers and developers whose work has inspired and contributed to the development of Infini-Transformer and Mixture-of-Depths Transformer.
407 | 
408 | Also, we'd like to give special thanks to all the contributors, collaborators and people who have given feedback. Your efforts have made what was a rough outline of an implementation into something actually usable.
409 | 
410 | If you have any questions or need further assistance, please feel free to reach out to me at [ryan@beta-reduce.net](mailto:ryan@beta-reduce.net).
411 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0.0
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # tests/__init__.py
2 | 
3 | # This file is intentionally left empty.
4 | # It is required to make the 'tests' directory a Python package.
-------------------------------------------------------------------------------- /tests/test_transformer.py: -------------------------------------------------------------------------------- 1 | # tests/test_transformer.py 2 | 3 | import torch 4 | from infini_transformer import InfiniTransformer, MoDInfiniTransformer 5 | from infini_transformer import YaRNEmbeddings 6 | 7 | def test_infini_transformer(): 8 | dim_input = 512 9 | dim_hidden = 2048 10 | dim_key = 64 11 | dim_value = 64 12 | num_heads = 8 13 | activation = "ffngeglu" 14 | segment_len = 2048 15 | update = "delta" 16 | causal = True 17 | init_state_learnable = True 18 | dropout = 0.1 19 | 20 | positional_embedder = YaRNEmbeddings( 21 | dim=dim_key, 22 | seq_len=segment_len, 23 | context_len=32000, 24 | context_len_ext=64000, 25 | dim_embedding_pct=0.5, 26 | base=10000, 27 | alpha=1, 28 | beta=32, 29 | length_scale=None 30 | ) 31 | 32 | layer = InfiniTransformer( 33 | dim_input=dim_input, 34 | dim_hidden=dim_hidden, 35 | dim_key=dim_key, 36 | dim_value=dim_value, 37 | num_heads=num_heads, 38 | activation=activation, 39 | segment_len=segment_len, 40 | update=update, 41 | causal=causal, 42 | positional_embedder=positional_embedder, 43 | init_state_learnable=init_state_learnable, 44 | dropout=dropout 45 | ) 46 | 47 | batch_size = 2 48 | seq_len = 4096 49 | x = torch.randn(batch_size, seq_len, dim_input) 50 | 51 | layer.eval() # Set the layer to evaluation mode 52 | x_att = layer(x) 53 | 54 | assert x_att.shape == (batch_size, seq_len, dim_input) 55 | 56 | def test_mod_infini_transformer(): 57 | dim_input = 768 58 | dim_hidden = 3072 59 | dim_key = 96 60 | dim_value = 96 61 | num_heads = 12 62 | activation = "gelu" 63 | segment_len = 1024 64 | sampling_factor = 8 65 | update = "delta" 66 | causal = True 67 | init_state_learnable = True 68 | dropout = 0.2 69 | 70 | positional_embedder = YaRNEmbeddings( 71 | dim=dim_key, 72 | seq_len=segment_len, 73 | context_len=32000, 74 | context_len_ext=64000, 75 | dim_embedding_pct=0.5, 76 | base=10000, 77 | alpha=1, 78 | beta=32, 79 | length_scale=None 80 | ) 81 | 82 | layer = MoDInfiniTransformer( 83 | dim_input=dim_input, 84 | dim_hidden=dim_hidden, 85 | dim_key=dim_key, 86 | dim_value=dim_value, 87 | num_heads=num_heads, 88 | activation=activation, 89 | segment_len=segment_len, 90 | sampling_factor=sampling_factor, 91 | update=update, 92 | causal=causal, 93 | position_embedder=positional_embedder, 94 | init_state_learnable=init_state_learnable, 95 | dropout=dropout 96 | ) 97 | 98 | batch_size = 4 99 | seq_len = 2048 100 | x = torch.randn(batch_size, seq_len, dim_input) 101 | 102 | layer.train() # Set the layer to training mode 103 | x_att, sample_mask, sample_scores_pred = layer(x) 104 | 105 | assert x_att.shape == (batch_size, seq_len, dim_input) 106 | assert sample_mask.shape == (batch_size * seq_len, 1) 107 | assert sample_scores_pred.shape == (batch_size * seq_len, 1) 108 | 109 | 110 | if __name__=="__main__": 111 | test_infini_transformer() 112 | test_mod_infini_transformer() --------------------------------------------------------------------------------