├── assets ├── .DS_Store ├── memory.jpg ├── overview.png ├── rodimus.jpg ├── scaling.jpg ├── benchmark.jpg ├── CodeFuse-logo.jpg ├── needlebench.jpg └── rodimus-plus-coder-chat-evaluation.png ├── .gitignore ├── LEGAL.md ├── examples ├── generation_script.py └── chat_script.py ├── modules ├── utils.py ├── mlp.py ├── cache.py ├── rodimus_flow.py ├── rodimus_attention.py └── chat_format.py ├── configuration_rodimus.py ├── ops ├── swiglu.py ├── apply_rotary.py ├── layernorm_gated.py ├── rotary.py └── layernorm.py ├── tokenization_rodimus_fast.py ├── LICENSE ├── README.md └── modeling_rodimus.py /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/rodimus/main/assets/.DS_Store -------------------------------------------------------------------------------- /assets/memory.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/rodimus/main/assets/memory.jpg -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/rodimus/main/assets/overview.png -------------------------------------------------------------------------------- /assets/rodimus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/rodimus/main/assets/rodimus.jpg -------------------------------------------------------------------------------- /assets/scaling.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/rodimus/main/assets/scaling.jpg -------------------------------------------------------------------------------- /assets/benchmark.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/rodimus/main/assets/benchmark.jpg -------------------------------------------------------------------------------- /assets/CodeFuse-logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/rodimus/main/assets/CodeFuse-logo.jpg -------------------------------------------------------------------------------- /assets/needlebench.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/rodimus/main/assets/needlebench.jpg -------------------------------------------------------------------------------- /assets/rodimus-plus-coder-chat-evaluation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/rodimus/main/assets/rodimus-plus-coder-chat-evaluation.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # intermediate files 2 | build 3 | dist 4 | *.egg-info 5 | 6 | # pycache 7 | __pycache__ 8 | 9 | # vscode 10 | .vscode 11 | 12 | # macOS 13 | .DS_Store 14 | 15 | # ais 16 | .theia -------------------------------------------------------------------------------- /LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | 3 | Within this source code, the comments in Chinese 
shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 4 | 5 | 法律免责声明 6 | 7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。 -------------------------------------------------------------------------------- /examples/generation_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from modeling_rodimus import RodimusForCausalLM 4 | from tokenization_rodimus_fast import RodimusTokenizer 5 | 6 | # load model 7 | ckpt_dir = "model_path" 8 | tokenizer = RodimusTokenizer.from_pretrained(ckpt_dir) 9 | model = RodimusForCausalLM.from_pretrained( 10 | ckpt_dir, 11 | torch_dtype=torch.float16, 12 | device_map="cuda" 13 | ).eval() 14 | 15 | # inference 16 | input_prompt = "你好!你是谁?" 17 | model_inputs = tokenizer(input_prompt, return_tensors="pt").to(model.device) 18 | outputs = model.generate(**model_inputs, max_length=32) 19 | response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] 20 | 21 | print(response) 22 | -------------------------------------------------------------------------------- /examples/chat_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from modeling_rodimus import RodimusForCausalLM 4 | from tokenization_rodimus_fast import RodimusTokenizer 5 | 6 | # load model 7 | ckpt_dir = "model_path" 8 | tokenizer = RodimusTokenizer.from_pretrained(ckpt_dir) 9 | model = RodimusForCausalLM.from_pretrained( 10 | ckpt_dir, 11 | torch_dtype=torch.float16, 12 | device_map="cuda" 13 | ).eval() 14 | 15 | # inference 16 | input_prompt = "简单介绍一下大型语言模型。" 17 | messages = [ 18 | {"role": "HUMAN", "content": input_prompt} 19 | ] 20 | 21 | text = tokenizer.apply_chat_template( 22 | messages, 23 | system='You are Rodimus$+$, created by AntGroup. 
You are a helpful assistant.', 24 | tokenize=False, 25 | ) 26 | print(text) 27 | model_inputs = tokenizer(text, return_tensors="pt").to(model.device) 28 | outputs = model.generate(**model_inputs, max_length=2048) 29 | response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] 30 | 31 | print(response) 32 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from einops import rearrange, repeat 9 | 10 | 11 | def align_multiple(value, multiple_size=8): 12 | if value % multiple_size != 0: 13 | value += multiple_size - (value % multiple_size) 14 | return value 15 | 16 | 17 | def safe_eval_number(s): 18 | if s is None: 19 | return s 20 | try: 21 | return int(s) 22 | except ValueError: 23 | try: 24 | return float(s) 25 | except ValueError: 26 | return s 27 | 28 | 29 | def autocast_to_2B(x): 30 | if x.dtype not in {torch.float16, torch.bfloat16}: 31 | return x.to(dtype=torch.bfloat16) 32 | else: 33 | return x 34 | 35 | 36 | def xavier_uniform_(weight, gain=2 ** -2.5): 37 | nn.init.xavier_uniform_(weight, gain=2 ** -2.5) 38 | weight._no_reinit = True 39 | 40 | 41 | def reset_parameters_(linear_module): 42 | linear_module.reset_parameters() 43 | linear_module._is_hf_initialized = True 44 | -------------------------------------------------------------------------------- /modules/mlp.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from einops import rearrange, repeat 9 | 10 | from modules.utils import align_multiple 11 | from ops.swiglu import swiglu_linear 12 | 13 | 14 | class GLU(nn.Module): 15 | def __init__( 16 | self, 17 | dim, 18 | expand_ratio, 19 | dropout=0., 20 | activation_dropout=0., 21 | use_fast_path=True, 22 | ): 23 | super().__init__() 24 | self.dim = dim 25 | self.ffn_dim = align_multiple(int(dim * expand_ratio), 8) 26 | self.dropout = dropout 27 | self.activation_dropout = activation_dropout 28 | self.use_fast_path = use_fast_path 29 | 30 | if self.use_fast_path: 31 | assert swiglu_linear is not None 32 | 33 | self.fc = nn.Linear(self.dim, self.ffn_dim * 2, bias=False) 34 | self.out_proj = nn.Linear(self.ffn_dim, self.dim, bias=False) 35 | 36 | self.dropout_module = nn.Dropout(self.dropout) 37 | self.activation_dropout_module = nn.Dropout(self.activation_dropout) 38 | 39 | def forward(self, x): 40 | """ 41 | x: (B L D) 42 | """ 43 | x, g = self.fc(x).chunk(2, -1) 44 | if self.use_fast_path and self.activation_dropout == 0.: 45 | y = swiglu_linear(g, x, self.out_proj.weight, self.out_proj.bias) 46 | else: 47 | x_g = F.silu(g) * x 48 | x_g = self.activation_dropout_module(x_g) 49 | y = self.out_proj(x_g) 50 | 51 | y = self.dropout_module(y) 52 | 53 | return y 54 | -------------------------------------------------------------------------------- /configuration_rodimus.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | 3 | DEFAULT_MIXER_CFG = { 4 | 5 | } 6 | DEFAULT_ATTN_CFG = { 7 | 8 | } 9 | 10 | 11 | class RodimusConfig(PretrainedConfig): 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 
| model_type = "rodimus" 14 | 15 | def __init__( 16 | self, 17 | block_type="rodimus", 18 | d_model=2048, 19 | n_layer=24, 20 | vocab_size=50277, 21 | norm_epsilon=1e-5, 22 | initializer_range=0.02, 23 | rescale_prenorm_residual=True, 24 | residual_in_fp32=True, 25 | use_fast_path=True, 26 | use_fused_cross_entropy=False, 27 | use_fused_swiglu=True, 28 | dropout=0., 29 | activation_dropout=0., 30 | attention_dropout=0., 31 | mixer_cfg=DEFAULT_MIXER_CFG, 32 | attn_cfg=DEFAULT_ATTN_CFG, 33 | max_position_embeddings=2048, 34 | pad_token_id=None, 35 | bos_token_id=0, 36 | eos_token_id=0, 37 | tie_word_embeddings=False, 38 | output_attentions=False, 39 | output_hidden_states=False, 40 | use_cache=True, 41 | use_scale_embedding=False, 42 | use_norm_embedding=False, 43 | no_weight_decay_on_bias=True, 44 | no_weight_decay_on_norm=True, 45 | no_weight_decay_on_embedding=True, 46 | **kwargs, 47 | ): 48 | assert block_type in ["rodimus", "rodimus_plus"] 49 | self.block_type = block_type 50 | 51 | self.d_model = d_model 52 | self.n_layer = n_layer 53 | self.vocab_size = vocab_size 54 | self.max_position_embeddings = max_position_embeddings 55 | self.mixer_cfg = mixer_cfg 56 | self.attn_cfg = attn_cfg 57 | 58 | self.norm_epsilon = norm_epsilon 59 | self.initializer_range = initializer_range 60 | self.rescale_prenorm_residual = rescale_prenorm_residual 61 | self.residual_in_fp32 = residual_in_fp32 62 | self.use_fast_path = use_fast_path 63 | self.use_fused_cross_entropy = use_fused_cross_entropy 64 | self.use_fused_swiglu = use_fused_swiglu 65 | 66 | self.dropout = dropout 67 | self.activation_dropout = activation_dropout 68 | self.attention_dropout = attention_dropout 69 | 70 | self.output_attentions = output_attentions 71 | self.output_hidden_states = output_hidden_states 72 | self.use_cache = use_cache 73 | 74 | self.use_scale_embedding = use_scale_embedding 75 | self.use_norm_embedding = use_norm_embedding 76 | 77 | self.no_weight_decay_on_bias = no_weight_decay_on_bias 78 | self.no_weight_decay_on_norm = no_weight_decay_on_norm 79 | self.no_weight_decay_on_embedding = no_weight_decay_on_embedding 80 | 81 | super().__init__( 82 | bos_token_id=bos_token_id, 83 | eos_token_id=eos_token_id, 84 | pad_token_id=pad_token_id, 85 | tie_word_embeddings=tie_word_embeddings, 86 | **kwargs, 87 | ) 88 | -------------------------------------------------------------------------------- /ops/swiglu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch.cuda.amp import custom_bwd, custom_fwd 5 | 6 | swiglu_fwd_codestring = """ 7 | template T swiglu_fwd(T x, T y) { 8 | return float(x) * float(y) / (1.0f + ::exp(-float(x))); 9 | } 10 | """ 11 | swiglu_bwd_codestring = """ 12 | template T swiglu_bwd(T x, T y, T g, T& dx, T& dy) { 13 | float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); 14 | dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); 15 | dy = float(x) * x_sigmoid * float(g); 16 | } 17 | """ 18 | 19 | swiglu_bwd_with_output_codestring = """ 20 | template T swiglu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) { 21 | float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); 22 | float x_swish = float(x) * x_sigmoid; 23 | dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); 24 | dy = x_swish * float(g); 25 | z = x_swish * float(y); 26 | } 27 | """ 28 | 29 | swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) 30 | swiglu_bwd = 
torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) 31 | swiglu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_with_output_codestring, num_outputs=3) 32 | 33 | 34 | class SwiGLUFunction(torch.autograd.Function): 35 | r""" 36 | Swish-Gated Linear Unit (SwiGLU) function. 37 | 38 | .. math:: 39 | \text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y 40 | """ 41 | 42 | @staticmethod 43 | def forward(ctx, x, y): 44 | ctx.save_for_backward(x, y) 45 | return swiglu_fwd(x, y) 46 | 47 | @staticmethod 48 | def backward(ctx, dout): 49 | x, y = ctx.saved_tensors 50 | return swiglu_bwd(x, y, dout) 51 | 52 | 53 | class SwiGLULinearFunction(torch.autograd.Function): 54 | r""" 55 | Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation. 56 | 57 | .. math:: 58 | \text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b 59 | 60 | This simple wrap discards the intermediate results of SwiGLU(x, y) to save memory. 61 | """ 62 | 63 | @staticmethod 64 | @custom_fwd 65 | def forward(ctx, x, y, weight, bias): 66 | z = swiglu_fwd(x, y) 67 | out = F.linear(z.to(weight.dtype), weight, bias) 68 | # We don't store z, will be recomputed in the backward pass to save memory 69 | ctx.save_for_backward(x, y, weight) 70 | ctx.linear_bias_is_none = bias is None 71 | return out 72 | 73 | @staticmethod 74 | @custom_bwd 75 | def backward(ctx, dout, *args): 76 | x, y, weight = ctx.saved_tensors 77 | dout = dout.reshape(-1, dout.shape[-1]) 78 | dz = F.linear(dout, weight.t()).view_as(x) 79 | dx, dy, z = swiglu_bwd_with_output(x, y, dz) 80 | dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1])) 81 | dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) 82 | return dx, dy, dlinear_weight, dlinear_bias 83 | 84 | 85 | swiglu = SwiGLUFunction.apply 86 | swiglu_linear = SwiGLULinearFunction.apply 87 | -------------------------------------------------------------------------------- /tokenization_rodimus_fast.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import warnings 5 | from shutil import copyfile 6 | from typing import Any, Dict, List, Optional, Tuple, Union 7 | 8 | import numpy as np 9 | import torch 10 | from transformers.file_utils import to_py_obj 11 | from transformers.tokenization_utils_base import ( 12 | AddedToken, 13 | BatchEncoding, 14 | EncodedInput, 15 | PreTokenizedInput, 16 | TextInput, 17 | TruncationStrategy, 18 | ) 19 | from transformers.utils import PaddingStrategy, TensorType, logging 20 | from transformers import AutoTokenizer, PreTrainedTokenizerFast 21 | 22 | from modules.chat_format import Chat 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | class RodimusTokenizer(PreTrainedTokenizerFast): 29 | slow_tokenizer_class = None 30 | padding_side = "left" 31 | model_input_names = ["input_ids", "attention_mask"] 32 | slow_tokenizer_class = None 33 | 34 | SPECIAL_TOKENS_ATTRIBUTES = [ 35 | "bos_token", 36 | "eos_token", 37 | "unk_token", 38 | "sep_token", 39 | "pad_token", 40 | "cls_token", 41 | "mask_token", 42 | "gmask_token", 43 | "additional_special_tokens", 44 | ] 45 | 46 | def __init__( 47 | self, 48 | vocab_file=None, 49 | merges_file=None, 50 | tokenizer_file=None, 51 | clean_up_tokenization_spaces=False, 52 | bos_token="<|startoftext|>", 53 | eos_token="<|endoftext|>", 54 | cls_token="[CLS]", 55 | pad_token="<|endoftext|>", 56 | gmask_token="[gMASK]", 57 | add_bos_token=False, 58 
| add_eos_token=False, 59 | **kwargs, 60 | ): 61 | self._gmask_token = ( 62 | AddedToken(gmask_token, lstrip=False, 63 | rstrip=False, normalized=False) 64 | if isinstance(gmask_token, str) 65 | else gmask_token 66 | ) 67 | 68 | self._sop_token = ( 69 | AddedToken(bos_token, lstrip=False, rstrip=False, normalized=False) 70 | if isinstance(bos_token, str) 71 | else bos_token 72 | ) 73 | 74 | self._eop_token = ( 75 | AddedToken(eos_token, lstrip=False, rstrip=False, normalized=False) 76 | if isinstance(eos_token, str) 77 | else eos_token 78 | ) 79 | 80 | super().__init__( 81 | vocab_file=vocab_file, 82 | merges_file=merges_file, 83 | tokenizer_file=tokenizer_file, 84 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 85 | bos_token=bos_token, 86 | eos_token=eos_token, 87 | cls_token=cls_token, 88 | pad_token=eos_token, 89 | gmask_token=gmask_token, 90 | add_bos_token=add_bos_token, 91 | add_eos_token=add_eos_token, 92 | **kwargs, 93 | ) 94 | 95 | self.check_special_tokens() 96 | 97 | def check_special_tokens(self): 98 | ''' 99 | eos_token, cls_token, mask_token 100 | special tokens should init, check special token is not None 101 | ''' 102 | for name, special_token in zip( 103 | ['eos', 'bos', 'cls', 'gmask'], 104 | [self.eos_token, self.bos_token, self.cls_token, self.gmask_token], 105 | ): 106 | assert special_token is not None, f'should init special token [{name}] in tokenizer_config.json' 107 | 108 | @property 109 | def gmask_token(self) -> Optional[str]: 110 | if self._gmask_token is None: 111 | if self.verbose: 112 | logger.error("Using gmask_token, but it is not set yet.") 113 | return None 114 | return str(self._gmask_token) 115 | 116 | @gmask_token.setter 117 | def gmask_token(self, value): 118 | if not isinstance(value, (str, AddedToken)) and value is not None: 119 | raise ValueError( 120 | "Cannot set a non-string value as the gmask token") 121 | self._gmask_token = value 122 | 123 | @property 124 | def gmask_token_id(self) -> Optional[int]: 125 | if self._gmask_token is None: 126 | return None 127 | return self.convert_tokens_to_ids(self.gmask_token) 128 | 129 | @property 130 | def sop_token(self) -> Optional[str]: 131 | if self._sop_token is None: 132 | if self.verbose: 133 | logger.error("Using sop_token, but it is not set yet.") 134 | return None 135 | return str(self._sop_token) 136 | 137 | @sop_token.setter 138 | def sop_token(self, value): 139 | if not isinstance(value, (str, AddedToken)) and value is not None: 140 | raise ValueError("Cannot set a non-string value as the sop token") 141 | self._sop_token = value 142 | 143 | @property 144 | def sop_token_id(self) -> Optional[int]: 145 | if self._sop_token is None: 146 | return None 147 | return self.convert_tokens_to_ids(self.sop_token) 148 | 149 | @property 150 | def eop_token(self) -> Optional[str]: 151 | if self._eop_token is None: 152 | if self.verbose: 153 | logger.error("Using eop_token, but it is not set yet.") 154 | return None 155 | return str(self._eop_token) 156 | 157 | @eop_token.setter 158 | def eop_token(self, value): 159 | if not isinstance(value, (str, AddedToken)) and value is not None: 160 | raise ValueError("Cannot set a non-string value as the eop token") 161 | self._eop_token = value 162 | 163 | @property 164 | def eop_token_id(self) -> Optional[int]: 165 | if self._eop_token is None: 166 | return None 167 | return self.convert_tokens_to_ids(self.eop_token) 168 | 169 | @property 170 | def vocab_size(self): 171 | return len(self.get_vocab()) 172 | 173 | def apply_chat_template( 174 | self, 175 | 
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], 176 | system: str = None, 177 | tokenize=False, 178 | padding: bool = False, 179 | truncation: bool = False, 180 | max_length: Optional[int] = None, 181 | return_tensors: Optional[Union[str, TensorType]] = None, 182 | return_dict: bool = False, 183 | **kwargs, 184 | ): 185 | chat_format = kwargs.get('chat_format', 'rodimus_chat') 186 | 187 | is_batched = False 188 | 189 | if isinstance(conversation, List) and ( 190 | isinstance(conversation[0], (list, tuple) 191 | ) or "messages" in conversation[0] 192 | ): 193 | conversations = conversation 194 | is_batched = True 195 | 196 | if not is_batched: 197 | conversations = [conversation] 198 | 199 | rendered = [] 200 | for chat in conversations: 201 | if "messages" not in chat: 202 | # Indicates it's a Conversation object 203 | chat = {'messages': chat} 204 | if system: 205 | chat['system_message'] = system 206 | rendered_chat = Chat.from_json(chat, name=chat_format).prompt_str 207 | rendered.append(rendered_chat) 208 | 209 | if not is_batched: 210 | rendered = rendered[0] 211 | 212 | if tokenize: 213 | out = self( 214 | rendered, 215 | padding=padding, 216 | truncation=truncation, 217 | max_length=max_length, 218 | add_special_tokens=False, 219 | return_tensors=return_tensors, 220 | ) 221 | if return_dict: 222 | return out 223 | else: 224 | return out["input_ids"] 225 | else: 226 | return rendered 227 | -------------------------------------------------------------------------------- /ops/apply_rotary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | from typing import Optional, Union 4 | 5 | import torch 6 | 7 | import triton 8 | import triton.language as tl 9 | 10 | 11 | @triton.jit 12 | def rotary_kernel( 13 | OUT, # Pointers to matrices 14 | X, 15 | COS, 16 | SIN, 17 | CU_SEQLENS, 18 | SEQLEN_OFFSETS, # this could be int or a pointer 19 | # Matrix dimensions 20 | seqlen, 21 | rotary_dim, 22 | seqlen_ro, 23 | # strides 24 | stride_out_batch, 25 | stride_out_seqlen, 26 | stride_out_nheads, 27 | stride_out_headdim, 28 | stride_x_batch, 29 | stride_x_seqlen, 30 | stride_x_nheads, 31 | stride_x_headdim, 32 | # Meta-parameters 33 | BLOCK_K: tl.constexpr, 34 | IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, 35 | IS_VARLEN: tl.constexpr, 36 | INTERLEAVED: tl.constexpr, 37 | CONJUGATE: tl.constexpr, 38 | BLOCK_M: tl.constexpr, 39 | ): 40 | pid_m = tl.program_id(axis=0) 41 | pid_batch = tl.program_id(axis=1) 42 | pid_head = tl.program_id(axis=2) 43 | rotary_dim_half = rotary_dim // 2 44 | 45 | if not IS_VARLEN: 46 | X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads 47 | OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads 48 | else: 49 | start_idx = tl.load(CU_SEQLENS + pid_batch) 50 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx 51 | X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads 52 | OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads 53 | 54 | if pid_m * BLOCK_M >= seqlen: 55 | return 56 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 57 | if not IS_SEQLEN_OFFSETS_TENSOR: 58 | rm_cs = rm + SEQLEN_OFFSETS 59 | else: 60 | rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) 61 | rk = tl.arange(0, BLOCK_K) 62 | rk_half = tl.arange(0, BLOCK_K // 2) 63 | 64 | if not INTERLEAVED: 65 | # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT 66 | X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * 
stride_x_headdim) 67 | COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) 68 | SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) 69 | cos = tl.load( 70 | COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 71 | ).to(tl.float32) 72 | sin = tl.load( 73 | SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 74 | ).to(tl.float32) 75 | x0 = tl.load( 76 | X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 77 | ).to(tl.float32) 78 | x1 = tl.load( 79 | X + rotary_dim_half * stride_x_headdim, 80 | mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), 81 | other=0.0, 82 | ).to(tl.float32) 83 | if CONJUGATE: 84 | sin = -sin 85 | o0 = x0 * cos - x1 * sin 86 | o1 = x0 * sin + x1 * cos 87 | # write back result 88 | OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) 89 | tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) 90 | tl.store( 91 | OUT + rotary_dim_half * stride_out_headdim, 92 | o1, 93 | mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), 94 | ) 95 | else: 96 | # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. 97 | # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. 98 | # Loading x0 will be fast but x1 will be slow. 99 | # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. 100 | # Then we do the calculation and use tl.where to pick put the right outputs for the even 101 | # and for the odd indices. 102 | rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... 103 | rk_repeat = tl.arange(0, BLOCK_K) // 2 104 | X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) 105 | X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) 106 | COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) 107 | SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) 108 | cos = tl.load( 109 | COS, 110 | mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), 111 | other=1.0, 112 | ).to(tl.float32) 113 | sin = tl.load( 114 | SIN, 115 | mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), 116 | other=0.0, 117 | ).to(tl.float32) 118 | x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( 119 | tl.float32 120 | ) 121 | x1 = tl.load( 122 | X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 123 | ).to(tl.float32) 124 | if CONJUGATE: 125 | sin = -sin 126 | x0_cos = x0 * cos 127 | x1_sin = x1 * sin 128 | out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) 129 | OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) 130 | tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) 131 | 132 | 133 | def apply_rotary( 134 | x: torch.Tensor, 135 | cos: torch.Tensor, 136 | sin: torch.Tensor, 137 | seqlen_offsets: Union[int, torch.Tensor] = 0, 138 | cu_seqlens: Optional[torch.Tensor] = None, 139 | max_seqlen: Optional[int] = None, 140 | interleaved=False, 141 | inplace=False, 142 | conjugate=False, 143 | ) -> torch.Tensor: 144 | """ 145 | Arguments: 146 | x: (batch, seqlen, nheads, headdim) if cu_seqlens is None 147 | else (total_seqlen, nheads, headdim). 
148 | cos: (seqlen_ro, rotary_dim / 2) 149 | sin: (seqlen_ro, rotary_dim / 2) 150 | seqlen_offsets: integer or integer tensor of size (batch,) 151 | cu_seqlens: (batch + 1,) or None 152 | max_seqlen: int 153 | Returns: 154 | y: (batch, seqlen, nheads, headdim) 155 | """ 156 | is_varlen = cu_seqlens is not None 157 | if not is_varlen: 158 | batch, seqlen, nheads, headdim = x.shape 159 | else: 160 | assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" 161 | total_seqlen, nheads, headdim = x.shape 162 | batch_p_1 = cu_seqlens.shape[0] 163 | batch = batch_p_1 - 1 164 | seqlen = max_seqlen 165 | seqlen_ro, rotary_dim = cos.shape 166 | assert sin.shape == cos.shape 167 | rotary_dim *= 2 168 | assert rotary_dim <= headdim, "rotary_dim must be <= headdim" 169 | assert headdim <= 256, "Only support headdim <= 256" 170 | assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" 171 | 172 | assert ( 173 | cos.dtype == sin.dtype 174 | ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" 175 | assert ( 176 | x.dtype == cos.dtype 177 | ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" 178 | 179 | cos, sin = cos.contiguous(), sin.contiguous() 180 | if isinstance(seqlen_offsets, torch.Tensor): 181 | assert seqlen_offsets.shape == (batch,) 182 | assert seqlen_offsets.dtype in [torch.int32, torch.int64] 183 | seqlen_offsets = seqlen_offsets.contiguous() 184 | else: 185 | assert seqlen_offsets + seqlen <= seqlen_ro 186 | 187 | output = torch.empty_like(x) if not inplace else x 188 | if rotary_dim < headdim and not inplace: 189 | output[..., rotary_dim:].copy_(x[..., rotary_dim:]) 190 | 191 | BLOCK_K = ( 192 | 32 193 | if rotary_dim <= 32 194 | else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) 195 | ) 196 | grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa 197 | BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) 198 | 199 | # Need this, otherwise Triton tries to launch from cuda:0 and we get 200 | # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
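    # Kernel launch layout (descriptive note): `grid` above maps one program instance to each
    # (BLOCK_M-token sequence tile, batch element, attention head), matching pid_m / pid_batch /
    # pid_head inside `rotary_kernel`; the guard below runs the launch with x's GPU as the
    # current CUDA device.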
201 | with torch.cuda.device(x.device.index): 202 | rotary_kernel[grid]( 203 | output, # data ptrs 204 | x, 205 | cos, 206 | sin, 207 | cu_seqlens, 208 | seqlen_offsets, 209 | seqlen, # shapes 210 | rotary_dim, 211 | seqlen_ro, 212 | output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 213 | output.stride(-3), # seqlen_stride or total_seqlen_stride 214 | output.stride(-2), # nheads_stride 215 | output.stride(-1), # headdim_stride 216 | x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 217 | x.stride(-3), # seqlen stride or total_seqlen_stride 218 | x.stride(-2), # nheads stride 219 | x.stride(-1), # headdim stride 220 | BLOCK_K, 221 | isinstance(seqlen_offsets, torch.Tensor), 222 | is_varlen, 223 | interleaved, 224 | conjugate, 225 | BLOCK_M, 226 | ) 227 | return output 228 | -------------------------------------------------------------------------------- /modules/cache.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any, Dict, List, Optional, Tuple 6 | 7 | import torch 8 | from transformers.cache_utils import Cache 9 | 10 | 11 | class HybridCache(Cache): 12 | def __init__( 13 | self, 14 | seen_tokens: int = 0, 15 | has_ssm: bool = True, 16 | has_attn: bool = False, 17 | ) -> None: 18 | assert has_attn or has_ssm, "set `has_attn=True` or `has_ssm=True`" 19 | 20 | self._seen_tokens = seen_tokens 21 | 22 | self.has_ssm = has_ssm 23 | self.has_attn = has_attn 24 | self.num_ssm_states = 2 if self.has_ssm else 0 25 | 26 | self.conv_states: List[torch.Tensor] = [] 27 | self.ssm_states: List[torch.Tensor] = [] 28 | self.key_caches: List[torch.Tensor] = [] 29 | self.value_caches: List[torch.Tensor] = [] 30 | 31 | def __getitem__(self, layer_idx: int) -> torch.Tensor: 32 | if layer_idx < len(self): 33 | states = () 34 | if self.has_ssm: 35 | states += (self.conv_states[layer_idx], self.ssm_states[layer_idx]) 36 | if self.has_attn: 37 | states += (self.key_caches[layer_idx], self.value_caches[layer_idx]) 38 | return states 39 | else: 40 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") 41 | 42 | def get_ssm_states(self, layer_idx: int) -> torch.Tensor: 43 | assert self.has_ssm 44 | return self[layer_idx][:self.num_ssm_states] 45 | 46 | def get_attn_states(self, layer_idx: int) -> torch.Tensor: 47 | assert self.has_attn 48 | return self[layer_idx][self.num_ssm_states:] 49 | 50 | def __iter__(self): 51 | for layer_idx in range(len(self)): 52 | yield self[layer_idx] 53 | 54 | def __len__(self): 55 | if self.has_ssm: 56 | return len(self.conv_states) 57 | else: 58 | return len(self.key_caches) 59 | 60 | def update( 61 | self, 62 | layer_idx: int, 63 | conv_state: torch.Tensor = None, 64 | ssm_state: torch.Tensor = None, 65 | key_cache: torch.Tensor = None, 66 | value_cache: torch.Tensor = None, 67 | offset: Optional[int, torch.Tensor] = 1, 68 | skip_copy: Optional[bool] = False, 69 | ) -> Tuple[torch.Tensor]: 70 | is_end = True 71 | if self.has_ssm: 72 | if conv_state is not None: 73 | if len(self.conv_states) <= layer_idx: 74 | self.conv_states.append(conv_state) 75 | is_end = False 76 | elif not skip_copy: 77 | self.conv_states[layer_idx].copy_(conv_state) 78 | 79 | if ssm_state is not None: 80 | if len(self.ssm_states) <= layer_idx: 81 | self.ssm_states.append(ssm_state) 82 | is_end = False 83 | elif not skip_copy: 84 | self.ssm_states[layer_idx].copy_(ssm_state) 
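        # Descriptive note: the attention branch below maintains a fixed-length, sliding-window
        # style KV cache. While the window is not yet full, new keys/values are written at their
        # clamped positions; once it is full, the stored entries are rotated left and the newest
        # entries are written in place at the end of the window.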
85 | 86 | if self.has_attn: 87 | if key_cache is not None: 88 | if len(self.key_caches) <= layer_idx: 89 | self.key_caches.append(key_cache) 90 | is_end = False 91 | elif not skip_copy: 92 | # cache_seq_len = min(key_cache.shape[1], self.key_caches[layer_idx].shape[1]) 93 | # # self.key_caches[layer_idx] = torch.roll( 94 | # # self.key_caches[layer_idx], shifts=-cache_seq_len, dims=1) # b l h d 95 | # # self.key_caches[layer_idx][:, -cache_seq_len:, :, :].copy_(key_cache[:, -cache_seq_len:, :, :]) 96 | # self.key_caches[layer_idx] = torch.cat((self.key_caches[layer_idx][:, cache_seq_len:, :, :], key_cache[:, -cache_seq_len:, :, :]), dim=1) 97 | 98 | if key_cache.shape[0] > self.key_caches[layer_idx].shape[1]: 99 | max_cache_len = self.key_caches[layer_idx].shape[1] 100 | k_out = key_cache[:, -max_cache_len:, :, :] 101 | self.key_caches[layer_idx] += k_out 102 | else: 103 | k_out = self.key_caches[layer_idx] 104 | 105 | max_cache_len = self.key_caches[layer_idx].shape[1] 106 | input_num_tokens = key_cache.shape[1] 107 | 108 | slicing = torch.ones(max_cache_len, dtype=torch.long, device=key_cache.device).cumsum(0) 109 | cache_position = torch.arange(self._seen_tokens, self._seen_tokens + input_num_tokens, device=key_cache.device).clamp(0, max_cache_len - 1) 110 | to_shift = cache_position >= max_cache_len - 1 111 | indices = (slicing + to_shift[-1].int() - 1) % max_cache_len 112 | 113 | k_out = k_out[:, indices, :, :] 114 | k_out.index_copy_(1, cache_position, key_cache) 115 | 116 | self.key_caches[layer_idx].zero_() 117 | self.key_caches[layer_idx] += k_out 118 | 119 | if value_cache is not None: 120 | if len(self.value_caches) <= layer_idx: 121 | self.value_caches.append(value_cache) 122 | is_end = False 123 | elif not skip_copy: 124 | # cache_seq_len = min(value_cache.shape[1], self.value_caches[layer_idx].shape[1]) 125 | # # self.value_caches[layer_idx] = torch.roll( 126 | # # self.value_caches[layer_idx], shifts=-cache_seq_len, dims=1) # b l h d 127 | # # self.value_caches[layer_idx][:, -cache_seq_len:, :, :].copy_(value_cache[:, -cache_seq_len:, :, :]) 128 | # self.value_caches[layer_idx] = torch.cat((self.value_caches[layer_idx][:, cache_seq_len:, :, :], value_cache[:, -cache_seq_len:, :, :]), dim=1) 129 | 130 | if value_cache.shape[0] > self.value_caches[layer_idx].shape[1]: 131 | max_cache_len = self.value_caches[layer_idx].shape[1] 132 | v_out = value_cache[:, -max_cache_len:, :, :] 133 | self.value_caches[layer_idx] += v_out 134 | else: 135 | v_out = self.value_caches[layer_idx] 136 | 137 | max_cache_len = self.value_caches[layer_idx].shape[1] 138 | input_num_tokens = value_cache.shape[1] 139 | 140 | slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_cache.device).cumsum(0) 141 | cache_position = torch.arange(self._seen_tokens, self._seen_tokens + input_num_tokens, device=value_cache.device).clamp(0, max_cache_len - 1) 142 | to_shift = cache_position >= max_cache_len - 1 143 | indices = (slicing + to_shift[-1].int() - 1) % max_cache_len 144 | 145 | v_out = v_out[:, indices, :, :] 146 | v_out.index_copy_(1, cache_position, value_cache) 147 | 148 | self.value_caches[layer_idx].zero_() 149 | self.value_caches[layer_idx] += v_out 150 | 151 | # update the number of seen tokens once we achieve the last layer 152 | if layer_idx == len(self) - 1 and is_end: 153 | if self.has_attn and self.has_ssm: 154 | if value_cache is not None: # update `offset` once 155 | self._seen_tokens += offset 156 | else: 157 | self._seen_tokens += offset 158 | 159 | return self[layer_idx] 160 | 
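    # Minimal usage sketch (illustrative only; `embed`, `layers`, and `layer.mixer` are
    # hypothetical stand-ins, not part of this repository). Each mixer reads its slot from the
    # cache and writes it back through `update()` during incremental decoding:
    #
    #     cache = HybridCache(has_ssm=True, has_attn=False)
    #     hidden = embed(prompt_ids)                               # hypothetical embedding step
    #     for layer_idx, layer in enumerate(layers):               # hypothetical layer stack
    #         hidden, cache = layer.mixer(hidden, past_key_values=cache, use_cache=True)
    #     print(cache.get_seq_length())                            # tokens seen so far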
161 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 162 | """Returns the sequence length of the cached states. A layer index can be optionally passed.""" 163 | if len(self) <= layer_idx: 164 | return 0 165 | return self._seen_tokens 166 | 167 | def get_max_length(self) -> Optional[int]: 168 | """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" 169 | return None 170 | 171 | def reorder_cache(self, beam_idx: torch.LongTensor): 172 | """Reorders the cache for beam search, given the selected beam indices.""" 173 | for layer_idx in range(len(self)): 174 | if self.has_ssm: 175 | device = self.conv_states[layer_idx].device 176 | self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) 177 | self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) 178 | if self.has_attn: 179 | device = self.key_caches[layer_idx].device 180 | self.key_caches[layer_idx] = self.key_caches[layer_idx].index_select(0, beam_idx.to(device)) 181 | self.value_caches[layer_idx] = self.value_caches[layer_idx].index_select(0, beam_idx.to(device)) 182 | 183 | def to_legacy_cache(self) -> Tuple[torch.Tensor]: 184 | legacy_cache = [] 185 | for layer_idx in range(len(self)): 186 | layer_cache = () 187 | if self.has_ssm: 188 | layer_cache += (self.conv_states[layer_idx], self.ssm_states[layer_idx]) 189 | if self.has_attn: 190 | layer_cache += (self.key_caches[layer_idx], self.value_caches[layer_idx]) 191 | legacy_cache.append(layer_cache) 192 | return legacy_cache 193 | 194 | @classmethod 195 | def from_legacy_cache( 196 | cls, 197 | past_key_values: Optional[Tuple[torch.Tensor]] = None, 198 | seen_tokens: int = 0, 199 | has_ssm: bool = True, 200 | has_attn: bool = False, 201 | **kwargs, 202 | ) -> HybridCache: 203 | """Converts a cache in the legacy cache format into an equivalent `HybridCache`.""" 204 | cache = cls(seen_tokens, has_ssm=has_ssm, has_attn=has_attn, **kwargs) 205 | 206 | if past_key_values is not None: 207 | for layer_idx in range(len(past_key_values)): 208 | unpack_states = {} 209 | if cache.has_ssm: 210 | assert len(past_key_values[layer_idx]) >= 2 211 | conv_state, ssm_state = past_key_values[layer_idx][:cache.num_ssm_states] 212 | unpack_states["conv_state"] = conv_state 213 | unpack_states["ssm_state"] = ssm_state 214 | if cache.has_attn: 215 | assert len(past_key_values[layer_idx]) >= 4 216 | key_cache, value_cache = past_key_values[layer_idx][cache.num_ssm_states:] 217 | unpack_states["key_cache"] = key_cache 218 | unpack_states["value_cache"] = value_cache 219 | cache.update( 220 | layer_idx, 221 | **unpack_states, 222 | ) 223 | 224 | return cache -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /modules/rodimus_flow.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from einops import rearrange, repeat 9 | from functools import partial 10 | 11 | from transformers.cache_utils import Cache 12 | from transformers.utils import is_torchdynamo_compiling 13 | 14 | from modules.utils import ( 15 | align_multiple, 16 | safe_eval_number, 17 | xavier_uniform_, 18 | reset_parameters_, 19 | ) 20 | 21 | try: 22 | from causal_conv1d import ( 23 | causal_conv1d_fn, 24 | causal_conv1d_update, 25 | ) 26 | except: 27 | causal_conv1d_update = None 28 | causal_conv1d_fn = None 29 | 30 | from fla.ops.gla import ( 31 | fused_chunk_gla, 32 | chunk_gla, 33 | fused_recurrent_gla 34 | ) 35 | from ops.layernorm import RMSNorm 36 | from ops.layernorm_gated import RMSNorm as RMSNormGated 37 | 38 | 39 | def _unsqueeze(x): 40 | return x.unsqueeze(1) 41 | 42 | 43 | def _squeeze(x): 44 | return x.squeeze(1) 45 | 46 | 47 | class ShortConv(nn.Module): 48 | def __init__( 49 | self, 50 | dim, 51 | d_conv=4, 52 | act="silu", # silu or None 53 | use_fast_path=True, 54 | causal=True, 55 | ): 56 | super().__init__() 57 | self.dim = dim 58 | self.d_conv = d_conv 59 | self.use_fast_path = use_fast_path 60 | self.causal = causal 61 | 62 | if self.use_fast_path: 63 | assert causal_conv1d_fn is not None 64 | 65 | self.act = act 66 | self.conv1d = nn.Conv1d( 67 | in_channels=self.dim, 68 | out_channels=self.dim, 69 | bias=True, 70 | kernel_size=self.d_conv, 71 | groups=self.dim, 72 | padding=self.d_conv - 1, 73 | ) 74 | 75 | if not self.causal: 76 | self.reverse_conv1d = nn.Conv1d( 77 | in_channels=self.dim, 78 | out_channels=self.dim, 79 | bias=True, 80 | kernel_size=self.d_conv, 81 | groups=self.dim, 82 | padding=self.d_conv - 1, 83 | ) 84 | 85 | self._init_weights() 86 | 87 | def _init_weights(self,): 88 | # self.conv1d.reset_parameters() 89 | self.conv1d._is_hf_initialized = True 90 | if not self.causal: 91 | # self.reverse_conv1d.reset_parameters() 92 | self.reverse_conv1d.zero_() 93 | self.reverse_conv1d._is_hf_initialized = True 94 | 95 | def allocate_inference_cache( 96 | self, 97 | batch_size, 98 | ): 99 | param = next(self.parameters()) 100 | conv_state = param.new_zeros(batch_size, self.dim, self.d_conv) 101 | return conv_state 102 | 103 | def forward( 104 | self, 105 | x: torch.Tensor, 106 | mask: Optional[torch.Tensor] = None, 107 | cache: Optional[torch.Tensor] = None, 108 | ): 109 | if mask is not None: 110 | x = x.masked_fill(~mask.unsqueeze(-1), 0.) 
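        # When a padding mask is given, the padded positions are zeroed above so they cannot
        # leak into the depthwise causal convolution applied below.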
111 | 112 | seq_len = x.size(1) 113 | if cache is not None and seq_len == 1: 114 | return self.step(x, cache) 115 | 116 | re_x = rearrange(x, "b l d -> b d l") 117 | 118 | # Update state (B D W) 119 | if cache is not None: 120 | cache.copy_(F.pad(re_x, (self.d_conv - re_x.shape[-1], 0))) 121 | 122 | if self.use_fast_path: 123 | re_weight = rearrange(self.conv1d.weight, "d 1 w -> d w") 124 | x = causal_conv1d_fn( 125 | x=re_x, 126 | weight=re_weight, 127 | bias=self.conv1d.bias, 128 | activation=self.act if self.causal else None, 129 | ) 130 | 131 | if not self.causal: 132 | re_reverse_weight = rearrange(self.reverse_conv1d.weight, "d 1 w -> d w") 133 | x = x + causal_conv1d_fn( 134 | x=re_x.flip(-1), 135 | weight=re_reverse_weight, 136 | bias=self.reverse_conv1d.bias, 137 | activation=None, 138 | ) 139 | if self.act is not None: 140 | x = F.silu(x) 141 | else: 142 | x = self.conv1d(re_x)[..., :seq_len] 143 | if self.act is not None and self.causal: 144 | x = F.silu(x) 145 | 146 | if not self.causal: 147 | x = x + self.reverse_conv1d(re_x.flip(-1))[..., :seq_len] 148 | if self.act is not None: 149 | x = F.silu(x) 150 | 151 | x = rearrange(x, "b d l -> b l d") 152 | return x 153 | 154 | def step( 155 | self, 156 | x: torch.Tensor, 157 | cache: torch.Tensor 158 | ): 159 | assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now" 160 | 161 | x = x.squeeze(1) 162 | if self.use_fast_path: 163 | re_weight = rearrange(self.conv1d.weight, "d 1 w -> d w") 164 | x = causal_conv1d_update( 165 | x=x, 166 | conv_state=cache, 167 | weight=re_weight, 168 | bias=self.conv1d.bias, 169 | activation=self.act, 170 | ) 171 | else: 172 | dtype = x.dtype 173 | cache.copy_(torch.roll(cache, shifts=-1, dims=-1)) 174 | cache[:, :, -1] = x 175 | x = torch.sum(cache * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) 176 | if self.conv1d.bias is not None: 177 | x = x + self.conv1d.bias 178 | if self.act is not None: 179 | x = F.silu(x) 180 | return x.unsqueeze(1) 181 | 182 | 183 | class RodimusFlowInner(nn.Module): 184 | def __init__( 185 | self, 186 | d_inner, 187 | d_conv=4, 188 | mem_size=64, 189 | input_gate_low_rank="auto", 190 | mode="fused_chunk", 191 | norm_epsilon=1e-5, 192 | post_norm_epsilon=None, 193 | normalize_epsilon=None, 194 | residual_in_fp32=True, 195 | use_fast_path=True, 196 | layer_idx=None, 197 | causal=True, 198 | ): 199 | super().__init__() 200 | self.d_conv = d_conv 201 | self.d_inner = d_inner 202 | self.mem_size = mem_size 203 | self.input_gate_low_rank = input_gate_low_rank 204 | 205 | self.residual_in_fp32 = residual_in_fp32 206 | self.use_fast_path = use_fast_path 207 | self.mode = mode 208 | self.norm_epsilon = norm_epsilon 209 | self.post_norm_epsilon = post_norm_epsilon if post_norm_epsilon is not None else norm_epsilon 210 | self.normalize_epsilon = normalize_epsilon if normalize_epsilon is not None else 1e-12 211 | self.layer_idx = layer_idx 212 | 213 | self.scale = 1 / math.sqrt(self.mem_size) 214 | 215 | self.short_conv = ShortConv( 216 | self.d_inner, 217 | d_conv=self.d_conv, 218 | use_fast_path=self.use_fast_path, 219 | causal=True, 220 | ) 221 | self.residual_weight = nn.Parameter(torch.ones( 222 | (self.d_inner, ), dtype=torch.float32 if self.residual_in_fp32 else None), requires_grad=True) 223 | self.residual_weight._no_weight_decay = True 224 | 225 | self.in_proj = nn.Linear(self.d_inner, self.mem_size * 2, bias=False) 226 | 227 | self.ch_gate_proj = nn.Sequential(nn.Linear(self.d_inner, self.input_gate_low_rank, bias=False),) 228 | 
self.ch_gate_proj.append(nn.Linear(self.input_gate_low_rank, self.d_inner, bias=True)) 229 | self.ch_gate_proj.append(nn.Sigmoid()) 230 | 231 | self.mem_gate_proj = nn.Linear(self.d_inner, self.mem_size * 2, bias=True) 232 | 233 | self._init_weights() 234 | 235 | def allocate_inference_cache( 236 | self, 237 | batch_size, 238 | ): 239 | param = next(self.parameters()) 240 | conv_state = self.short_conv.allocate_inference_cache(batch_size) 241 | ssm_state = param.new_zeros(batch_size, 1, self.mem_size, self.d_inner,) 242 | 243 | if not is_torchdynamo_compiling(): 244 | idx = self.layer_idx 245 | 246 | self.register_buffer(f"conv_state_{idx}", conv_state) 247 | conv_state = getattr(self, f"conv_state_{idx}") 248 | torch._dynamo.mark_static_address(conv_state) 249 | 250 | self.register_buffer(f"ssm_state_{idx}", ssm_state) 251 | ssm_state = getattr(self, f"ssm_state_{idx}") 252 | torch._dynamo.mark_static_address(ssm_state) 253 | 254 | return conv_state, ssm_state 255 | 256 | @torch.no_grad() 257 | def _init_weights(self): 258 | xavier_uniform_(self.in_proj.weight) 259 | 260 | sigmoid_bias_max = 0.999 261 | sigmoid_bias_min = 0.9 262 | init_floor = 1e-4 263 | 264 | xavier_uniform_(self.ch_gate_proj[0].weight) 265 | xavier_uniform_(self.ch_gate_proj[1].weight) 266 | 267 | bias = [] 268 | max_ = 1 - sigmoid_bias_min 269 | min_ = 1 - sigmoid_bias_max 270 | rt_bias = torch.exp( 271 | torch.rand(self.mem_size) * (math.log(max_) - math.log(min_)) 272 | + math.log(min_) 273 | ).clamp(min=init_floor) 274 | rt_bias = rt_bias + torch.log(-torch.expm1(-rt_bias)) 275 | 276 | bias.append(rt_bias) 277 | 278 | tau_bias = torch.empty((self.mem_size,)).uniform_(1/16, 0.9) 279 | tau_bias = torch.logit(tau_bias.float()).to(tau_bias.dtype) 280 | bias.append(tau_bias) 281 | 282 | xavier_uniform_(self.mem_gate_proj.weight) 283 | 284 | if self.mem_gate_proj.bias.shape[0] > 0: 285 | bias = torch.cat([b.to(device=self.mem_gate_proj.bias.device) for b in bias], dim=0) 286 | self.mem_gate_proj.bias.copy_(bias) 287 | self.mem_gate_proj.bias._no_reinit = True 288 | else: 289 | import warnings 290 | warnings.warn('mem_gate_proj.bias cannot be initialized using the meta context. 
Please note that when loading a pre-trained model') 291 | 292 | def forward( 293 | self, 294 | hidden_states: torch.Tensor, 295 | attention_mask: Optional[torch.Tensor] = None, 296 | past_key_values: Optional[Cache] = None, 297 | use_cache: Optional[bool] = False, 298 | output_attentions: Optional[bool] = False, 299 | **kwargs 300 | ): 301 | mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode 302 | 303 | if past_key_values is not None: 304 | last_state = past_key_values.get_ssm_states(self.layer_idx) 305 | attention_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None 306 | else: 307 | last_state = None 308 | 309 | if last_state is not None: 310 | conv_state, ssm_state = last_state 311 | else: 312 | conv_state, ssm_state = None, None 313 | 314 | if ( 315 | ssm_state is not None 316 | and hidden_states.dtype != ssm_state.dtype 317 | and mode == "chunk" 318 | ): 319 | mode = "fused_chunk" 320 | 321 | shift_hidden_states = self.short_conv( 322 | hidden_states, 323 | mask=attention_mask, 324 | cache=conv_state, 325 | ) 326 | residual = shift_hidden_states.float() if self.residual_in_fp32 else shift_hidden_states 327 | 328 | r, k = self.in_proj(shift_hidden_states).chunk(2, -1) 329 | 330 | u = self.ch_gate_proj(hidden_states) * hidden_states 331 | 332 | if attention_mask is not None: 333 | u = u.masked_fill(~attention_mask.unsqueeze(-1), 0.) 334 | 335 | mem_gates = F.linear(shift_hidden_states, self.mem_gate_proj.weight) + self.mem_gate_proj.bias.float() 336 | select_gate, tau_gate = mem_gates.chunk(2, -1) 337 | 338 | select_gate = F.softplus(select_gate) 339 | it_gate = select_gate 340 | rt_gate_log = -select_gate 341 | 342 | tau_gate = F.sigmoid(tau_gate) 343 | it_gate = it_gate ** tau_gate 344 | rt_gate_log = rt_gate_log * tau_gate 345 | 346 | k = F.normalize(k.float(), dim=-1, eps=self.normalize_epsilon) 347 | 348 | r, k, u, rt_gate_log = map(_unsqueeze, (r, k.float() * it_gate, u, rt_gate_log)) 349 | 350 | if mode == 'fused_recurrent': 351 | o, ssm_state = fused_recurrent_gla(r, k, u, rt_gate_log, scale=self.scale, 352 | initial_state=ssm_state, output_final_state=use_cache) 353 | elif mode == 'fused_chunk': 354 | o, ssm_state = fused_chunk_gla(r, k, u, rt_gate_log, scale=self.scale, 355 | initial_state=ssm_state, output_final_state=use_cache) 356 | elif mode == 'chunk': 357 | r, k, rt_gate_log = map(lambda x: x.to(u.dtype), (r, k, rt_gate_log)) 358 | o, ssm_state = chunk_gla(r, k, u, rt_gate_log, scale=self.scale, 359 | initial_state=ssm_state, output_final_state=use_cache) 360 | 361 | if past_key_values is not None: 362 | past_key_values.update( 363 | self.layer_idx, 364 | conv_state=conv_state, 365 | ssm_state=ssm_state, 366 | offset=u.shape[-2], 367 | ) 368 | o = (_squeeze(o) + residual * self.residual_weight).to(o.dtype) # TODO: fused 369 | 370 | return o, past_key_values 371 | 372 | 373 | class RodimusFlow(nn.Module): 374 | def __init__( 375 | self, 376 | dim, 377 | d_conv=4, 378 | expand_ratio=2, 379 | mem_size=64, 380 | dropout=0., 381 | activation_dropout=0., 382 | input_gate_low_rank="auto", 383 | mode="fused_chunk", 384 | norm_epsilon=1e-5, 385 | post_norm_epsilon=None, 386 | normalize_epsilon=None, 387 | residual_in_fp32=True, 388 | use_fast_path=True, 389 | layer_idx=None, 390 | causal=True, 391 | ): 392 | super().__init__() 393 | input_gate_low_rank = safe_eval_number(input_gate_low_rank) 394 | 395 | self.dim = dim 396 | self.d_conv = d_conv 397 | self.d_inner = align_multiple(int(dim * expand_ratio), 8) 398 | 
self.mem_size = mem_size 399 | self.dropout = dropout 400 | self.activation_dropout = activation_dropout 401 | self.input_gate_low_rank = max(self.dim // 64, 16) if input_gate_low_rank == "auto" else input_gate_low_rank 402 | 403 | self.residual_in_fp32 = residual_in_fp32 404 | self.use_fast_path = use_fast_path 405 | self.mode = mode 406 | self.norm_epsilon = norm_epsilon 407 | self.post_norm_epsilon = post_norm_epsilon if post_norm_epsilon is not None else norm_epsilon 408 | self.normalize_epsilon = normalize_epsilon if normalize_epsilon is not None else 1e-12 409 | self.layer_idx = layer_idx 410 | self.causal = causal 411 | 412 | self.act_norm = RMSNormGated(self.d_inner, eps=self.norm_epsilon, norm_before_gate=False) 413 | 414 | self.fc = nn.Linear(self.dim, self.d_inner * 2, bias=False) 415 | self.out_proj = nn.Linear(self.d_inner, self.dim, bias=False) 416 | 417 | inner_cls = partial( 418 | RodimusFlowInner, 419 | d_inner=self.d_inner, 420 | d_conv=self.d_conv, 421 | mem_size=self.mem_size, 422 | input_gate_low_rank=self.input_gate_low_rank, 423 | mode=self.mode, 424 | norm_epsilon=self.norm_epsilon, 425 | post_norm_epsilon=self.post_norm_epsilon, 426 | normalize_epsilon=self.normalize_epsilon, 427 | residual_in_fp32=True, 428 | use_fast_path=True, 429 | layer_idx=layer_idx, 430 | ) 431 | 432 | self.inner_mixer = inner_cls() 433 | if not self.causal: 434 | self.reverse_inner_mixer = inner_cls() 435 | 436 | self.dropout_module = nn.Dropout(self.dropout) 437 | self.activation_dropout_module = nn.Dropout(self.activation_dropout) 438 | 439 | def allocate_inference_cache( 440 | self, 441 | batch_size, 442 | ): 443 | assert not hasattr(self, "reverse_inner_mixer") 444 | return self.inner_mixer.allocate_inference_cache(batch_size) 445 | 446 | def forward( 447 | self, 448 | hidden_states: torch.Tensor, 449 | attention_mask: Optional[torch.Tensor] = None, 450 | past_key_values: Optional[Cache] = None, 451 | use_cache: Optional[bool] = False, 452 | output_attentions: Optional[bool] = False, 453 | **kwargs 454 | ): 455 | x, g = self.fc(hidden_states).chunk(2, -1) 456 | 457 | o, past_key_values = self.inner_mixer( 458 | hidden_states=x, 459 | attention_mask=attention_mask, 460 | past_key_values=past_key_values, 461 | use_cache=use_cache, 462 | output_attentions=output_attentions, 463 | ) 464 | 465 | if not self.causal: 466 | assert past_key_values is None 467 | assert use_cache is False 468 | reverse_o, _ = self.reverse_inner_mixer( 469 | hidden_states=x.flip(-2), 470 | attention_mask=attention_mask.flip(-1) if attention_mask is not None else None, 471 | past_key_values=None, 472 | use_cache=False, 473 | output_attentions=output_attentions, 474 | ) 475 | o = o + reverse_o 476 | 477 | x_g = self.act_norm(o, g) 478 | x_g = self.activation_dropout_module(x_g) 479 | y = self.out_proj(x_g) 480 | 481 | y = self.dropout_module(y) 482 | 483 | return y, past_key_values 484 | -------------------------------------------------------------------------------- /modules/rodimus_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.cuda import amp 8 | 9 | from einops import rearrange, repeat 10 | 11 | from transformers.cache_utils import Cache 12 | from transformers.utils import is_torchdynamo_compiling 13 | 14 | import fla 15 | from ops.layernorm import RMSNorm 16 | from modules.utils import ( 17 | 
autocast_to_2B, 18 | safe_eval_number, 19 | align_multiple 20 | ) 21 | from ops.rotary import RotaryEmbedding 22 | 23 | USE_FLASH_ATTN = True 24 | try: 25 | from flash_attn import ( 26 | flash_attn_kvpacked_func, 27 | flash_attn_qkvpacked_func, 28 | flash_attn_varlen_kvpacked_func, 29 | flash_attn_varlen_qkvpacked_func, 30 | flash_attn_func, 31 | flash_attn_varlen_func, 32 | ) 33 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 34 | except: 35 | USE_FLASH_ATTN = False 36 | 37 | flash_attn_with_kvcache = None # TODO 38 | if USE_FLASH_ATTN: 39 | try: 40 | from flash_attn import flash_attn_with_kvcache 41 | except ImportError: 42 | flash_attn_with_kvcache = None 43 | 44 | 45 | def _get_unpad_data(attention_mask): 46 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 47 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 48 | max_seqlen_in_batch = seqlens_in_batch.max().item() 49 | cu_seqlens = F.pad(torch.cumsum( 50 | seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 51 | return ( 52 | indices, 53 | cu_seqlens, 54 | max_seqlen_in_batch, 55 | ) 56 | 57 | 58 | class FlashAttention(nn.Module): 59 | def __init__( 60 | self, 61 | causal=True, 62 | softmax_scale=None, 63 | attention_dropout=0.0, 64 | window_size=(-1, -1), 65 | ): 66 | super().__init__() 67 | assert USE_FLASH_ATTN, "FlashAttention is not installed" 68 | self.causal = causal 69 | self.softmax_scale = softmax_scale 70 | self.drop = nn.Dropout(attention_dropout) 71 | self.window_size = window_size 72 | 73 | def __repr__(self): 74 | return f"{self.drop}\n\twindow_size: {self.window_size}, causal: {self.causal}, softmax_scale: {self.softmax_scale}" 75 | 76 | @amp.autocast(False) 77 | def forward( 78 | self, 79 | qkv, 80 | kv=None, 81 | v=None, 82 | cu_seqlens=None, 83 | max_seqlen=None, 84 | cu_seqlens_k=None, 85 | max_seqlen_k=None, 86 | ): 87 | dtype = qkv.dtype 88 | unpadded = cu_seqlens is not None 89 | 90 | if kv is None: 91 | assert v is None 92 | assert qkv.dtype in [torch.float16, torch.bfloat16] 93 | 94 | if unpadded: 95 | assert cu_seqlens.dtype == torch.int32 96 | assert max_seqlen is not None 97 | assert isinstance(max_seqlen, int) 98 | return flash_attn_varlen_qkvpacked_func( 99 | qkv, 100 | cu_seqlens, 101 | max_seqlen, 102 | self.drop.p if self.training else 0.0, 103 | softmax_scale=self.softmax_scale, 104 | causal=self.causal, 105 | # [i - window_size[0], i + window_size[1]] 106 | window_size=self.window_size, 107 | ).to(dtype) 108 | else: 109 | return flash_attn_qkvpacked_func( 110 | qkv, 111 | self.drop.p if self.training else 0.0, 112 | softmax_scale=self.softmax_scale, 113 | causal=self.causal, 114 | window_size=self.window_size, 115 | ).to(dtype) 116 | 117 | else: 118 | assert qkv.dtype in [torch.float16, torch.bfloat16] 119 | assert kv.dtype in [torch.float16, torch.bfloat16] 120 | if v is None: 121 | q = qkv 122 | if unpadded: 123 | assert cu_seqlens.dtype == torch.int32 124 | assert max_seqlen is not None 125 | assert isinstance(max_seqlen, int) 126 | assert cu_seqlens_k is not None 127 | assert cu_seqlens_k.dtype == torch.int32 128 | assert max_seqlen_k is not None 129 | assert isinstance(max_seqlen, int) 130 | return flash_attn_varlen_kvpacked_func( 131 | q, 132 | kv, 133 | cu_seqlens, 134 | cu_seqlens_k, 135 | max_seqlen, 136 | max_seqlen_k, 137 | self.drop.p if self.training else 0.0, 138 | softmax_scale=self.softmax_scale, 139 | causal=self.causal, 140 | window_size=self.window_size, 141 | ).to(dtype) 142 | else: 143 | 
batch_size, seqlen_q = q.shape[0], q.shape[1] 144 | seqlen_k = kv.shape[1] 145 | assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] 146 | return flash_attn_kvpacked_func( 147 | q, 148 | kv, 149 | self.drop.p if self.training else 0.0, 150 | causal=self.causal, 151 | softmax_scale=self.softmax_scale, 152 | window_size=self.window_size, 153 | ).to(dtype) 154 | else: 155 | assert v.dtype in [torch.float16, torch.bfloat16] 156 | 157 | q = qkv 158 | k = kv 159 | if unpadded: 160 | assert cu_seqlens.dtype == torch.int32 161 | assert max_seqlen is not None 162 | assert isinstance(max_seqlen, int) 163 | assert cu_seqlens_k is not None 164 | assert cu_seqlens_k.dtype == torch.int32 165 | assert max_seqlen_k is not None 166 | assert isinstance(max_seqlen, int) 167 | return flash_attn_varlen_func( 168 | q, 169 | k, 170 | v, 171 | cu_seqlens, 172 | cu_seqlens_k, 173 | max_seqlen, 174 | max_seqlen_k, 175 | self.drop.p if self.training else 0.0, 176 | softmax_scale=self.softmax_scale, 177 | causal=self.causal, 178 | window_size=self.window_size, 179 | ).to(dtype) 180 | else: 181 | batch_size, seqlen_q = q.shape[0], q.shape[1] 182 | seqlen_k = kv.shape[1] 183 | assert k.shape[3] == q.shape[3] 184 | return flash_attn_func( 185 | q, 186 | k, 187 | v, 188 | self.drop.p if self.training else 0.0, 189 | causal=self.causal, 190 | softmax_scale=self.softmax_scale, 191 | window_size=self.window_size, 192 | ).to(dtype) 193 | 194 | 195 | class SlideWindowSharedKeyAttention(nn.Module): 196 | def __init__( 197 | self, 198 | dim, 199 | num_heads, 200 | num_heads_k=None, 201 | num_heads_v=None, 202 | window_size=None, 203 | softmax_scale=None, 204 | causal=True, 205 | layer_idx=None, 206 | rotary_emb_dim=-1, 207 | rotary_emb_base=10000.0, 208 | rotary_emb_scale_base=None, 209 | rotary_emb_interleaved=False, 210 | dropout=0., 211 | activation_dropout=0., 212 | attention_dropout=0., 213 | max_position_embeddings=None, 214 | ): 215 | super().__init__() 216 | self.dim = dim 217 | self.num_heads = num_heads 218 | self.num_heads_k = num_heads_k 219 | self.num_heads_v = num_heads_v 220 | 221 | self.causal = causal 222 | self.layer_idx = layer_idx 223 | 224 | assert USE_FLASH_ATTN, "pip install flash_attn" 225 | if window_size is not None and window_size > 0: 226 | self.window_size = ( 227 | window_size // 2, 0) if self.causal else (window_size // 2, window_size // 2) 228 | else: 229 | self.window_size = (-1, 0) if self.causal else (-1, -1) 230 | 231 | self.head_dim = self.dim // self.num_heads 232 | assert self.head_dim * self.num_heads == self.dim 233 | 234 | assert self.num_heads % self.num_heads_k == 0 235 | assert self.num_heads % self.num_heads_v == 0 236 | 237 | self.rotary_emb_dim = rotary_emb_dim 238 | self.rotary_emb_base = rotary_emb_base 239 | self.rotary_emb_scale_base = rotary_emb_scale_base 240 | self.rotary_emb_interleaved = rotary_emb_interleaved 241 | 242 | self.dropout = dropout 243 | self.activation_dropout = activation_dropout 244 | self.attention_dropout = attention_dropout 245 | self.max_position_embeddings = max_position_embeddings 246 | 247 | self.dropout_module = nn.Dropout(self.dropout) 248 | self.activation_dropout_module = nn.Dropout(self.activation_dropout) 249 | 250 | if self.rotary_emb_dim < 0: 251 | self.rotary_emb_dim = self.head_dim 252 | 253 | if self.rotary_emb_dim > 0: # 0 -> nope 254 | self.rotary_emb = RotaryEmbedding( 255 | self.rotary_emb_dim, 256 | base=rotary_emb_base, 257 | scale_base=rotary_emb_scale_base, 258 | interleaved=rotary_emb_interleaved, 259 | ) 260 | # 
self.rotary_emb = None 261 | else: 262 | self.rotary_emb = None 263 | 264 | scale = None 265 | self.softmax_scale = softmax_scale 266 | if self.softmax_scale is not None: 267 | if self.softmax_scale == "norm": 268 | scale = 1. 269 | self.register_buffer("s", torch.arange( 270 | 1., 16., step=self.num_heads).unsqueeze(-1)) 271 | 272 | self.inner_attn = FlashAttention( 273 | causal=self.causal, 274 | softmax_scale=scale, 275 | attention_dropout=self.attention_dropout, 276 | window_size=self.window_size, 277 | ) 278 | 279 | self.q_proj = nn.Linear( 280 | self.dim, self.head_dim * self.num_heads, bias=False) 281 | self.k_proj = nn.Linear( 282 | self.dim, self.head_dim * self.num_heads_k, bias=False) 283 | self.v_proj = nn.Linear( 284 | self.dim, self.head_dim * self.num_heads_v, bias=False) 285 | 286 | self.out_proj = nn.Linear(self.dim, self.dim, bias=False) 287 | 288 | def allocate_inference_cache( 289 | self, 290 | batch_size, 291 | ): 292 | param = next(self.parameters()) 293 | key_caches = param.new_zeros( 294 | (batch_size, self.window_size[0], self.num_heads_k, self.head_dim)) 295 | value_caches = param.new_zeros( 296 | (batch_size, self.window_size[0], self.num_heads_v, self.head_dim)) 297 | 298 | if not is_torchdynamo_compiling(): 299 | idx = self.layer_idx 300 | 301 | self.register_buffer(f"key_cache_{idx}", key_caches) 302 | key_caches = getattr(self, f"key_cache_{idx}") 303 | torch._dynamo.mark_static_address(key_caches) 304 | 305 | self.register_buffer(f"value_cache_{idx}", value_caches) 306 | value_caches = getattr(self, f"value_cache_{idx}") 307 | torch._dynamo.mark_static_address(value_caches) 308 | 309 | return key_caches, value_caches 310 | 311 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 312 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( 313 | attention_mask) 314 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 315 | 316 | key_layer = index_first_axis( 317 | rearrange(key_layer, "b l ... -> (b l) ..."), indices_k 318 | ) 319 | value_layer = index_first_axis( 320 | rearrange(value_layer, "b l ... -> (b l) ..."), indices_k 321 | ) 322 | if query_length == kv_seq_len: 323 | query_layer = index_first_axis( 324 | rearrange(query_layer, "b l ... 
-> (b l) ..."), indices_k 325 | ) 326 | cu_seqlens_q = cu_seqlens_k 327 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 328 | indices_q = indices_k 329 | elif query_length == 1: 330 | max_seqlen_in_batch_q = 1 331 | cu_seqlens_q = torch.arange( 332 | batch_size + 1, dtype=torch.int32, device=query_layer.device 333 | ) 334 | indices_q = cu_seqlens_q[:-1] 335 | query_layer = query_layer.squeeze(1) 336 | else: 337 | raise NotImplementedError("Not implemented `cross_attention`") 338 | 339 | return ( 340 | query_layer, 341 | key_layer, 342 | value_layer, 343 | indices_q, 344 | (cu_seqlens_q, cu_seqlens_k), 345 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 346 | ) 347 | 348 | def forward( 349 | self, 350 | hidden_states: torch.Tensor, 351 | attention_mask: Optional[torch.Tensor] = None, 352 | # pack cache of key and value, with other params 353 | past_key_values: Optional[Cache] = None, 354 | use_cache: Optional[bool] = False, 355 | output_attentions: Optional[bool] = False, 356 | ): 357 | # print(attention_mask, attention_mask.all() if attention_mask is not None else None) 358 | 359 | batch_size, seq_len, _ = hidden_states.size() 360 | 361 | q = self.q_proj(hidden_states) 362 | k = self.k_proj(hidden_states) 363 | v = self.v_proj(hidden_states) 364 | 365 | q, k, v = map(lambda x: rearrange( 366 | x, "... (h d) -> ... h d", d=self.head_dim), (q, k, v)) 367 | 368 | if self.softmax_scale is not None: 369 | if self.softmax_scale == "norm": 370 | q = F.normalize(q, dim=-1) * self.s 371 | k = F.normalize(k, dim=-1) 372 | 373 | if past_key_values is not None: 374 | assert self.causal 375 | # assert not self.training, "inference" 376 | seqlen_offset = past_key_values.get_seq_length(self.layer_idx) 377 | cached_seqlen = seqlen_offset + seq_len 378 | rotary_max_seqlen = self.max_position_embeddings if self.max_position_embeddings is not None else 0 379 | if rotary_max_seqlen < cached_seqlen: 380 | rotary_max_seqlen = cached_seqlen 381 | else: 382 | seqlen_offset = 0 383 | rotary_max_seqlen = self.max_position_embeddings if self.max_position_embeddings else 0 384 | if rotary_max_seqlen < seq_len: 385 | rotary_max_seqlen = seq_len 386 | 387 | if self.rotary_emb is not None: 388 | # TODO 389 | q, k = self.rotary_emb( 390 | q, k, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen 391 | ) 392 | q, k, v = map(lambda x: autocast_to_2B(x), (q, k, v)) 393 | 394 | if past_key_values is not None: 395 | past_key_values.update( 396 | self.layer_idx, 397 | key_cache=k, 398 | value_cache=v, 399 | offset=seq_len, 400 | ) 401 | if seqlen_offset > 0: 402 | assert seq_len == 1, "during inference, length of query should be equal to 1" 403 | key_caches, value_cahces = past_key_values.get_attn_states( 404 | self.layer_idx) 405 | # k = key_caches[:, -cached_seqlen:, :, :] 406 | # v = value_cahces[:, -cached_seqlen:, :, :] 407 | k = key_caches[:, :cached_seqlen, :, :] 408 | v = value_cahces[:, :cached_seqlen, :, :] 409 | attention_mask = attention_mask[:, -k.shape[1] 410 | :] if attention_mask is not None else None 411 | 412 | if self.num_heads_k != self.num_heads: 413 | k = repeat(k, "... h d -> ... (n h) d", 414 | n=self.num_heads // self.num_heads_k) 415 | if self.num_heads_v != self.num_heads: 416 | v = repeat(v, "... h d -> ... 
(n h) d", 417 | n=self.num_heads // self.num_heads_v) 418 | 419 | if attention_mask is not None: 420 | q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 421 | q, k, v, attention_mask, seq_len, 422 | ) 423 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 424 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 425 | else: 426 | cu_seqlens_q, cu_seqlens_k = None, None 427 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = None, None 428 | 429 | context = self.inner_attn( 430 | q, k, v, 431 | cu_seqlens=cu_seqlens_q, 432 | cu_seqlens_k=cu_seqlens_k, 433 | max_seqlen=max_seqlen_in_batch_q, 434 | max_seqlen_k=max_seqlen_in_batch_k 435 | ) 436 | if attention_mask is not None: 437 | context = pad_input(context, indices_q, batch_size, seq_len) 438 | 439 | x_g = rearrange(context, "... h d -> ... (h d)") 440 | x_g = self.activation_dropout_module(x_g) 441 | out = self.out_proj(x_g.to(self.out_proj.weight.dtype)) 442 | 443 | out = self.dropout_module(out) 444 | return out, past_key_values 445 | -------------------------------------------------------------------------------- /ops/layernorm_gated.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Tri Dao. 2 | # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html 3 | # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. 4 | # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. 5 | # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 6 | 7 | import math 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | import triton 13 | import triton.language as tl 14 | 15 | from einops import rearrange 16 | 17 | 18 | def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True): 19 | dtype = x.dtype 20 | N = x.shape[-1] 21 | weight = weight.float() 22 | bias = bias.float() if bias is not None else None 23 | if upcast: 24 | x = x.float() 25 | z = z.float() if z is not None else z 26 | if z is not None and not norm_before_gate: 27 | x = x * F.silu(z) 28 | if group_size is None: 29 | rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) 30 | out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) 31 | else: 32 | x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) 33 | rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) 34 | out = rearrange(x_group * rstd, "... g d -> ... 
(g d)") * weight 35 | if bias is not None: 36 | out = out + bias 37 | if z is not None and norm_before_gate: 38 | out *= F.silu(z) 39 | return out.to(dtype) 40 | 41 | 42 | @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 43 | @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) 44 | @triton.jit 45 | def _layer_norm_fwd_1pass_kernel( 46 | X, # pointer to the input 47 | Y, # pointer to the output 48 | W, # pointer to the weights 49 | B, # pointer to the biases 50 | Z, # pointer to the other branch 51 | Mean, # pointer to the mean 52 | Rstd, # pointer to the 1/std 53 | stride_x_row, # how much to increase the pointer when moving by 1 row 54 | stride_y_row, 55 | stride_z_row, 56 | M, # number of rows in X 57 | N, # number of columns in X 58 | eps, # epsilon to avoid division by zero 59 | BLOCK_N: tl.constexpr, 60 | HAS_BIAS: tl.constexpr, 61 | HAS_Z: tl.constexpr, 62 | NORM_BEFORE_GATE: tl.constexpr, 63 | IS_RMS_NORM: tl.constexpr, 64 | ): 65 | # Map the program id to the row of X and Y it should compute. 66 | row = tl.program_id(0) 67 | group = tl.program_id(1) 68 | X += row * stride_x_row + group * N 69 | Y += row * stride_y_row + group * N 70 | if HAS_Z: 71 | Z += row * stride_z_row + group * N 72 | if not IS_RMS_NORM: 73 | Mean += group * M 74 | Rstd += group * M 75 | W += group * N 76 | if HAS_BIAS: 77 | B += group * N 78 | # Compute mean and variance 79 | cols = tl.arange(0, BLOCK_N) 80 | x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) 81 | if HAS_Z and not NORM_BEFORE_GATE: 82 | z = tl.load(Z + cols, mask=cols < N).to(tl.float32) 83 | x *= z * tl.sigmoid(z) 84 | if not IS_RMS_NORM: 85 | mean = tl.sum(x, axis=0) / N 86 | tl.store(Mean + row, mean) 87 | xbar = tl.where(cols < N, x - mean, 0.) 88 | var = tl.sum(xbar * xbar, axis=0) / N 89 | else: 90 | xbar = tl.where(cols < N, x, 0.) 
91 | var = tl.sum(xbar * xbar, axis=0) / N 92 | rstd = 1 / tl.sqrt(var + eps) 93 | tl.store(Rstd + row, rstd) 94 | # Normalize and apply linear transformation 95 | mask = cols < N 96 | w = tl.load(W + cols, mask=mask).to(tl.float32) 97 | if HAS_BIAS: 98 | b = tl.load(B + cols, mask=mask).to(tl.float32) 99 | x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd 100 | y = x_hat * w + b if HAS_BIAS else x_hat * w 101 | if HAS_Z and NORM_BEFORE_GATE: 102 | z = tl.load(Z + cols, mask=mask).to(tl.float32) 103 | y *= z * tl.sigmoid(z) 104 | # Write output 105 | tl.store(Y + cols, y, mask=mask) 106 | 107 | 108 | def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False): 109 | M, N = x.shape 110 | if group_size is None: 111 | group_size = N 112 | assert N % group_size == 0 113 | ngroups = N // group_size 114 | assert x.stride(-1) == 1 115 | if z is not None: 116 | assert z.stride(-1) == 1 117 | assert z.shape == (M, N) 118 | assert weight.shape == (N,) 119 | assert weight.stride(-1) == 1 120 | if bias is not None: 121 | assert bias.stride(-1) == 1 122 | assert bias.shape == (N,) 123 | # allocate output 124 | if out is not None: 125 | assert out.shape == x.shape 126 | else: 127 | out = torch.empty_like(x) 128 | assert out.stride(-1) == 1 129 | mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None 130 | rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) 131 | # Less than 64KB per feature: enqueue fused kernel 132 | MAX_FUSED_SIZE = 65536 // x.element_size() 133 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) 134 | if group_size > BLOCK_N: 135 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 136 | # heuristics for number of warps 137 | num_warps = min(max(BLOCK_N // 256, 1), 8) 138 | grid = (M, ngroups) 139 | with torch.cuda.device(x.device.index): 140 | _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd, 141 | x.stride(0), out.stride(0), z.stride(0) if z is not None else 0, 142 | M, group_size, eps, 143 | BLOCK_N=BLOCK_N, 144 | NORM_BEFORE_GATE=norm_before_gate, 145 | IS_RMS_NORM=is_rms_norm, 146 | num_warps=num_warps) 147 | return out, mean, rstd 148 | 149 | 150 | 151 | @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 152 | @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) 153 | @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) 154 | @triton.jit 155 | def _layer_norm_bwd_kernel( 156 | X, # pointer to the input 157 | W, # pointer to the weights 158 | B, # pointer to the biases 159 | Z, # pointer to the other branch 160 | Y, # pointer to the output to be recomputed 161 | DY, # pointer to the output gradient 162 | DX, # pointer to the input gradient 163 | DW, # pointer to the partial sum of weights gradient 164 | DB, # pointer to the partial sum of biases gradient 165 | DZ, # pointer to the other branch 166 | Mean, # pointer to the mean 167 | Rstd, # pointer to the 1/std 168 | stride_x_row, # how much to increase the pointer when moving by 1 row 169 | stride_z_row, 170 | stride_y_row, 171 | stride_dy_row, 172 | stride_dx_row, 173 | stride_dz_row, 174 | stride_dw_row, 175 | stride_db_row, 176 | M, # number of rows in X 177 | N, # number of columns in X 178 | eps, # epsilon to avoid division by zero 179 | rows_per_program, 180 | NORM_BEFORE_GATE: tl.constexpr, 181 | IS_RMS_NORM: tl.constexpr, 182 | HAS_BIAS: tl.constexpr, 183 | HAS_Z: 
tl.constexpr, 184 | RECOMPUTE_OUTPUT: tl.constexpr, 185 | BLOCK_N: tl.constexpr, 186 | ): 187 | # Map the program id to the elements of X, DX, and DY it should compute. 188 | row_block_id = tl.program_id(0) 189 | group = tl.program_id(1) 190 | row_start = row_block_id * rows_per_program 191 | cols = tl.arange(0, BLOCK_N) 192 | mask = cols < N 193 | X += row_start * stride_x_row + group * N 194 | if HAS_Z: 195 | Z += row_start * stride_z_row + group * N 196 | DZ += row_start * stride_dz_row + group * N 197 | DY += row_start * stride_dy_row + group * N 198 | DX += row_start * stride_dx_row + group * N 199 | if RECOMPUTE_OUTPUT: 200 | Y += row_start * stride_y_row + group * N 201 | if not IS_RMS_NORM: 202 | Mean += group * M 203 | Rstd += group * M 204 | W += group * N 205 | w = tl.load(W + cols, mask=mask).to(tl.float32) 206 | if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS: 207 | B += group * N 208 | b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32) 209 | dw = tl.zeros((BLOCK_N,), dtype=tl.float32) 210 | if HAS_BIAS: 211 | db = tl.zeros((BLOCK_N,), dtype=tl.float32) 212 | row_end = min((row_block_id + 1) * rows_per_program, M) 213 | for row in range(row_start, row_end): 214 | # Load data to SRAM 215 | x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) 216 | dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) 217 | if not IS_RMS_NORM: 218 | mean = tl.load(Mean + row) 219 | if HAS_Z and not NORM_BEFORE_GATE: 220 | z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) 221 | x_og = x 222 | x = x_og * z * tl.sigmoid(z) 223 | rstd = tl.load(Rstd + row) 224 | # Compute dx 225 | xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd 226 | xhat = tl.where(mask, xhat, 0.) 227 | if HAS_Z and NORM_BEFORE_GATE: 228 | z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) 229 | z_sigmoid = tl.sigmoid(z) 230 | y = xhat * w + b if HAS_BIAS else xhat * w 231 | if RECOMPUTE_OUTPUT: 232 | tl.store(Y + cols, y * z * z_sigmoid, mask=mask) 233 | dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid)) 234 | tl.store(DZ + cols, dz, mask=mask) 235 | dy *= z * z_sigmoid 236 | else: 237 | if RECOMPUTE_OUTPUT: 238 | y = xhat * w + b if HAS_BIAS else xhat * w 239 | tl.store(Y + cols, y, mask=mask) 240 | wdy = w * dy 241 | c1 = tl.sum(xhat * wdy, axis=0) / N 242 | if not IS_RMS_NORM: 243 | c2 = tl.sum(wdy, axis=0) / N 244 | dx = (wdy - (xhat * c1 + c2)) * rstd 245 | else: 246 | dx = (wdy - xhat * c1) * rstd 247 | dw += dy * xhat 248 | if HAS_BIAS: 249 | db += dy 250 | if HAS_Z and not NORM_BEFORE_GATE: 251 | z_sigmoid = tl.sigmoid(z) 252 | dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid)) 253 | tl.store(DZ + cols, dz, mask=mask) 254 | dx *= z * z_sigmoid 255 | # Write dx 256 | tl.store(DX + cols, dx, mask=mask) 257 | 258 | X += stride_x_row 259 | if HAS_Z: 260 | Z += stride_z_row 261 | DZ += stride_dz_row 262 | if RECOMPUTE_OUTPUT: 263 | Y += stride_y_row 264 | DY += stride_dy_row 265 | DX += stride_dx_row 266 | tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask) 267 | if HAS_BIAS: 268 | tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask) 269 | 270 | 271 | def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None, 272 | norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None): 273 | M, N = x.shape 274 | if group_size is None: 275 | group_size = N 276 | assert N % group_size == 0 277 | ngroups = N // group_size 278 | assert x.stride(-1) == 1 279 | assert dy.stride(-1) == 1 280 | assert 
dy.shape == (M, N) 281 | if z is not None: 282 | assert z.stride(-1) == 1 283 | assert z.shape == (M, N) 284 | assert weight.shape == (N,) 285 | assert weight.stride(-1) == 1 286 | if bias is not None: 287 | assert bias.stride(-1) == 1 288 | assert bias.shape == (N,) 289 | # allocate output 290 | dx = torch.empty_like(x) 291 | if dz is not None: 292 | assert z is not None 293 | assert dz.shape == z.shape 294 | assert dz.stride(-1) == 1 295 | else: 296 | dz = torch.empty_like(z) if z is not None else None 297 | if recompute_output: 298 | if out is None: 299 | out = torch.empty_like(x) 300 | assert out.shape == x.shape 301 | 302 | # Less than 64KB per feature: enqueue fused kernel 303 | MAX_FUSED_SIZE = 65536 // x.element_size() 304 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) 305 | if group_size > BLOCK_N: 306 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 307 | # heuristics for number of warps 308 | num_warps = min(max(BLOCK_N // 256, 1), 8) 309 | sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count 310 | # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs 311 | # would limit the occupancy. 312 | nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups) 313 | _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device) 314 | _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None 315 | rows_per_program = math.ceil(M / nrow_groups) 316 | grid = (nrow_groups, ngroups) 317 | with torch.cuda.device(x.device.index): 318 | _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None, 319 | dy, dx, _dw, _db, dz, mean, rstd, 320 | x.stride(0), 321 | z.stride(0) if z is not None else 0, 322 | 0 if not recompute_output else out.stride(0), 323 | dy.stride(0), dx.stride(0), 324 | dz.stride(0) if dz is not None else 0, 325 | _dw.stride(0), 326 | _db.stride(0) if _db is not None else 0, 327 | M, group_size, eps, 328 | rows_per_program, 329 | BLOCK_N=BLOCK_N, 330 | NORM_BEFORE_GATE=norm_before_gate, 331 | IS_RMS_NORM=is_rms_norm, 332 | num_warps=num_warps) 333 | dw = _dw.sum(0).to(weight.dtype) 334 | db = _db.sum(0).to(bias.dtype) if bias is not None else None 335 | return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out) 336 | 337 | 338 | class LayerNormFn(torch.autograd.Function): 339 | 340 | @staticmethod 341 | def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, 342 | is_rms_norm=False): 343 | """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) 344 | """ 345 | 346 | x_shape_og = x.shape 347 | # reshape input data into 2D tensor 348 | x = x.reshape(-1, x.shape[-1]) 349 | if x.stride(-1) != 1: 350 | x = x.contiguous() 351 | if z is not None: 352 | assert z.shape == x_shape_og 353 | z = z.reshape(-1, z.shape[-1]) 354 | if z.stride(-1) != 1: 355 | z = z.contiguous() 356 | weight = weight.contiguous() 357 | if bias is not None: 358 | bias = bias.contiguous() 359 | y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm) 360 | ctx.save_for_backward(x, weight, bias, mean, rstd, z) 361 | ctx.x_shape_og = x_shape_og 362 | ctx.eps = eps 363 | ctx.group_size = group_size 364 | ctx.norm_before_gate = norm_before_gate 365 | ctx.is_rms_norm = is_rms_norm 366 | return y.reshape(x_shape_og) 367 | 368 | @staticmethod 369 | def 
backward(ctx, dy): 370 | x, weight, bias, mean, rstd, z = ctx.saved_tensors 371 | dy = dy.reshape(-1, dy.shape[-1]) 372 | if dy.stride(-1) != 1: 373 | dy = dy.contiguous() 374 | assert dy.shape == x.shape 375 | dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size, 376 | ctx.norm_before_gate, ctx.is_rms_norm) 377 | return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None 378 | 379 | 380 | def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): 381 | return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm) 382 | 383 | 384 | def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True): 385 | return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True) 386 | 387 | 388 | class LayerNorm(torch.nn.Module): 389 | 390 | def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None): 391 | """If group_size is not None, we do GroupNorm with each group having group_size elements. 392 | group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 393 | """ 394 | 395 | factory_kwargs = {"device": device, "dtype": dtype} 396 | super().__init__() 397 | self.eps = eps 398 | self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) 399 | self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) 400 | self.group_size = group_size 401 | self.norm_before_gate = norm_before_gate 402 | self.reset_parameters() 403 | 404 | def reset_parameters(self): 405 | torch.nn.init.ones_(self.weight) 406 | torch.nn.init.zeros_(self.bias) 407 | 408 | def forward(self, x, z=None): 409 | """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) 410 | """ 411 | return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps, 412 | norm_before_gate=self.norm_before_gate) 413 | 414 | 415 | class RMSNorm(torch.nn.Module): 416 | 417 | def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None): 418 | """If group_size is not None, we do GroupNorm with each group having group_size elements. 419 | group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 420 | """ 421 | factory_kwargs = {"device": device, "dtype": dtype} 422 | super().__init__() 423 | self.eps = eps 424 | self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) 425 | self.register_parameter("bias", None) 426 | self.group_size = group_size 427 | self.norm_before_gate = norm_before_gate 428 | self.reset_parameters() 429 | 430 | def reset_parameters(self): 431 | torch.nn.init.ones_(self.weight) 432 | 433 | def forward(self, x, z=None): 434 | """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) 435 | """ 436 | return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size, 437 | norm_before_gate=self.norm_before_gate) 438 | -------------------------------------------------------------------------------- /ops/rotary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 
2 | 3 | import math 4 | from typing import Optional, Tuple, Union 5 | 6 | import torch 7 | from einops import rearrange, repeat 8 | from .apply_rotary import apply_rotary 9 | 10 | 11 | def rotate_half(x, interleaved=False): 12 | if not interleaved: 13 | x1, x2 = x.chunk(2, dim=-1) 14 | return torch.cat((-x2, x1), dim=-1) 15 | else: 16 | x1, x2 = x[..., ::2], x[..., 1::2] 17 | return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) 18 | 19 | 20 | def apply_rotary_emb_torch(x, cos, sin, interleaved=False): 21 | """ 22 | x: (batch_size, seqlen, nheads, headdim) 23 | cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) 24 | """ 25 | ro_dim = cos.shape[-1] * 2 26 | assert ro_dim <= x.shape[-1] 27 | cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") 28 | sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") 29 | return torch.cat( 30 | [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], 31 | dim=-1, 32 | ) 33 | 34 | 35 | class ApplyRotaryEmb(torch.autograd.Function): 36 | @staticmethod 37 | def forward( 38 | ctx, 39 | x, 40 | cos, 41 | sin, 42 | interleaved=False, 43 | inplace=False, 44 | seqlen_offsets: Union[int, torch.Tensor] = 0, 45 | cu_seqlens: Optional[torch.Tensor] = None, 46 | max_seqlen: Optional[int] = None, 47 | ): 48 | out = apply_rotary( 49 | x, 50 | cos, 51 | sin, 52 | seqlen_offsets=seqlen_offsets, 53 | cu_seqlens=cu_seqlens, 54 | max_seqlen=max_seqlen, 55 | interleaved=interleaved, 56 | inplace=inplace, 57 | ) 58 | if isinstance(seqlen_offsets, int): 59 | ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward 60 | ctx.seqlen_offsets = seqlen_offsets 61 | else: 62 | ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) 63 | ctx.seqlen_offsets = None 64 | ctx.interleaved = interleaved 65 | ctx.inplace = inplace 66 | ctx.max_seqlen = max_seqlen 67 | return out if not inplace else x 68 | 69 | @staticmethod 70 | def backward(ctx, do): 71 | seqlen_offsets = ctx.seqlen_offsets 72 | if seqlen_offsets is None: 73 | cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors 74 | else: 75 | cos, sin, cu_seqlens = ctx.saved_tensors 76 | # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with 77 | # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. 78 | if not ctx.interleaved and not ctx.inplace: 79 | do = do.clone() 80 | dx = apply_rotary( 81 | do, 82 | cos, 83 | sin, 84 | seqlen_offsets=seqlen_offsets, 85 | cu_seqlens=cu_seqlens, 86 | max_seqlen=ctx.max_seqlen, 87 | interleaved=ctx.interleaved, 88 | inplace=ctx.inplace, 89 | conjugate=True, 90 | ) 91 | return dx, None, None, None, None, None, None, None 92 | 93 | 94 | def apply_rotary_emb( 95 | x, 96 | cos, 97 | sin, 98 | interleaved=False, 99 | inplace=False, 100 | seqlen_offsets: Union[int, torch.Tensor] = 0, 101 | cu_seqlens: Optional[torch.Tensor] = None, 102 | max_seqlen: Optional[int] = None, 103 | ): 104 | """ 105 | Arguments: 106 | x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None 107 | else (total_seqlen, nheads, headdim) 108 | cos, sin: (seqlen_rotary, rotary_dim / 2) 109 | interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead 110 | of 1st half and 2nd half (GPT-NeoX style). 111 | inplace: if True, apply rotary embedding in-place. 112 | seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. 
113 | Most commonly used in inference when we have KV cache. 114 | cu_seqlens: (batch + 1,) or None 115 | max_seqlen: int 116 | Return: 117 | out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None 118 | else (total_seqlen, nheads, headdim) 119 | rotary_dim must be <= headdim 120 | Apply rotary embedding to the first rotary_dim of x. 121 | """ 122 | return ApplyRotaryEmb.apply( 123 | x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen 124 | ) 125 | 126 | 127 | # For backward compatibility 128 | apply_rotary_emb_func = apply_rotary_emb 129 | 130 | 131 | class ApplyRotaryEmbQKV_(torch.autograd.Function): 132 | @staticmethod 133 | def forward( 134 | ctx, 135 | qkv, 136 | cos, 137 | sin, 138 | cos_k=None, 139 | sin_k=None, 140 | interleaved=False, 141 | seqlen_offsets: Union[int, torch.Tensor] = 0, 142 | ): 143 | batch, seqlen, three, nheads, headdim = qkv.shape 144 | assert three == 3 145 | if cos_k is None and sin_k is None and qkv.is_contiguous(): 146 | # Call 1 kernel instead of 2 kernels 147 | # We need qkv to be contiguous so that when we reshape to combine (3, nheads) 148 | # dimensions, we get the same tensor 149 | # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") 150 | qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) 151 | apply_rotary( 152 | qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True 153 | ) 154 | else: 155 | cos_k = cos if cos_k is None else cos_k 156 | sin_k = sin if sin_k is None else sin_k 157 | q, k = qkv[:, :, 0], qkv[:, :, 1] 158 | apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) 159 | apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) 160 | ctx.save_for_backward(cos, sin, cos_k, sin_k) 161 | if isinstance(seqlen_offsets, int): 162 | ctx.save_for_backward(cos, sin, cos_k, sin_k) 163 | ctx.seqlen_offsets = seqlen_offsets 164 | else: 165 | ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) 166 | ctx.seqlen_offsets = None 167 | ctx.interleaved = interleaved 168 | return qkv 169 | 170 | @staticmethod 171 | def backward(ctx, dqkv): 172 | seqlen_offsets = ctx.seqlen_offsets 173 | if seqlen_offsets is None: 174 | cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors 175 | else: 176 | cos, sin, cos_k, sin_k = ctx.saved_tensors 177 | if cos_k is None and sin_k is None and dqkv.is_contiguous(): 178 | # Call 1 kernel instead of 2 kernels 179 | # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) 180 | # dimensions, we get the same tensor 181 | dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") 182 | apply_rotary( 183 | dqk, 184 | cos, 185 | sin, 186 | seqlen_offsets=seqlen_offsets, 187 | interleaved=ctx.interleaved, 188 | inplace=True, 189 | conjugate=True, 190 | ) 191 | else: 192 | cos_k = cos if cos_k is None else cos_k 193 | sin_k = sin if sin_k is None else sin_k 194 | dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] 195 | apply_rotary( 196 | dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True 197 | ) 198 | apply_rotary( 199 | dk, 200 | cos_k, 201 | sin_k, 202 | seqlen_offsets, 203 | interleaved=ctx.interleaved, 204 | inplace=True, 205 | conjugate=True, 206 | ) 207 | return dqkv, None, None, None, None, None, None 208 | 209 | 210 | def apply_rotary_emb_qkv_( 211 | qkv, 212 | cos, 213 | sin, 214 | cos_k=None, 215 | sin_k=None, 216 | interleaved=False, 217 | seqlen_offsets: Union[int, torch.Tensor] = 0, 218 | ): 219 | """ 220 | Arguments: 221 | qkv: 
(batch_size, seqlen, 3, nheads, headdim) 222 | cos, sin: (seqlen, rotary_dim / 2) 223 | cos_k, sin_k: (seqlen, rotary_dim / 2), optional 224 | interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 225 | 1st half and 2nd half (GPT-NeoX style). 226 | seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. 227 | Most commonly used in inference when we have KV cache. 228 | Return: 229 | qkv: (batch_size, seqlen, 3, nheads, headdim) 230 | rotary_dim must be <= headdim 231 | Apply rotary embedding *inplace* to the first rotary_dim of Q and K. 232 | """ 233 | return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets) 234 | 235 | 236 | class ApplyRotaryEmbKV_(torch.autograd.Function): 237 | @staticmethod 238 | def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): 239 | batch, seqlen, two, nheads, headdim = kv.shape 240 | assert two == 2 241 | k = kv[:, :, 0] 242 | apply_rotary( 243 | k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True 244 | ) 245 | if isinstance(seqlen_offsets, int): 246 | ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward 247 | ctx.seqlen_offsets = seqlen_offsets 248 | else: 249 | ctx.save_for_backward(cos, sin, seqlen_offsets) 250 | ctx.seqlen_offsets = None 251 | ctx.interleaved = interleaved 252 | return kv 253 | 254 | @staticmethod 255 | def backward(ctx, dkv): 256 | seqlen_offsets = ctx.seqlen_offsets 257 | if seqlen_offsets is None: 258 | cos, sin, seqlen_offsets = ctx.saved_tensors 259 | else: 260 | cos, sin = ctx.saved_tensors 261 | apply_rotary( 262 | dkv[:, :, 0], 263 | cos, 264 | sin, 265 | seqlen_offsets=seqlen_offsets, 266 | interleaved=ctx.interleaved, 267 | inplace=True, 268 | conjugate=True, 269 | ) 270 | return dkv, None, None, None, None 271 | 272 | 273 | apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply 274 | 275 | 276 | def apply_rotary_emb_kv_( 277 | kv, 278 | cos, 279 | sin, 280 | interleaved=False, 281 | seqlen_offsets: Union[int, torch.Tensor] = 0, 282 | ): 283 | """ 284 | Arguments: 285 | kv: (batch_size, seqlen, 2, nheads, headdim) 286 | cos, sin: (seqlen, rotary_dim / 2) 287 | interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 288 | 1st half and 2nd half (GPT-NeoX style). 289 | seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. 290 | Most commonly used in inference when we have KV cache. 291 | Return: 292 | kv: (batch_size, seqlen, 2, nheads, headdim) 293 | rotary_dim must be <= headdim 294 | Apply rotary embedding *inplace* to the first rotary_dim of K. 295 | """ 296 | return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets) 297 | 298 | 299 | class RotaryEmbedding(torch.nn.Module): 300 | """ 301 | The rotary position embeddings from RoFormer_ (Su et. al). 302 | A crucial insight from the method is that the query and keys are 303 | transformed by rotation matrices which depend on the relative positions. 304 | 305 | Other implementations are available in the Rotary Transformer repo_ and in 306 | GPT-NeoX_, GPT-NeoX was an inspiration 307 | 308 | .. _RoFormer: https://arxiv.org/abs/2104.09864 309 | .. _repo: https://github.com/ZhuiyiTechnology/roformer 310 | .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox 311 | 312 | If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). 
313 | A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 314 | Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py 315 | """ 316 | 317 | def __init__( 318 | self, 319 | dim: int, 320 | base=10000.0, 321 | interleaved=False, 322 | scale_base=None, 323 | pos_idx_in_fp32=True, 324 | device=None, 325 | ): 326 | """ 327 | interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead 328 | of 1st half and 2nd half (GPT-NeoX style). 329 | pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, 330 | otherwise they might be in lower precision. 331 | This option was added because previously (before 2023-07-02), when we construct 332 | the position indices, we use the dtype of self.inv_freq. In most cases this would 333 | be fp32, but if the model is trained in pure bf16 (not mixed precision), then 334 | self.inv_freq would be bf16, and the position indices are also in bf16. 335 | Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the 336 | embeddings for some positions will coincide. 337 | To maintain compatibility with models previously trained in pure bf16, 338 | we add this option. 339 | """ 340 | super().__init__() 341 | self.dim = dim 342 | self.base = float(base) 343 | self.pos_idx_in_fp32 = pos_idx_in_fp32 344 | # Generate and save the inverse frequency buffer (non trainable) 345 | inv_freq = self._compute_inv_freq(device) 346 | self.register_buffer("inv_freq", inv_freq, persistent=False) 347 | self.interleaved = interleaved 348 | self.scale_base = scale_base 349 | scale = ( 350 | (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) 351 | if scale_base is not None 352 | else None 353 | ) 354 | self.register_buffer("scale", scale, persistent=False) 355 | 356 | self._seq_len_cached = 0 357 | self._cos_cached = None 358 | self._sin_cached = None 359 | self._cos_k_cached = None 360 | self._sin_k_cached = None 361 | 362 | def _compute_inv_freq(self, device=None): 363 | return 1.0 / ( 364 | self.base 365 | ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) 366 | ) 367 | 368 | def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): 369 | # Reset the tables if the sequence length has changed, 370 | # if we're on a new device (possibly due to tracing for instance), 371 | # or if we're switching from inference mode to training 372 | if ( 373 | seqlen > self._seq_len_cached 374 | or self._cos_cached is None 375 | or self._cos_cached.device != device 376 | or self._cos_cached.dtype != dtype 377 | or (self.training and self._cos_cached.is_inference()) 378 | ): 379 | self._seq_len_cached = seqlen 380 | # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 381 | # And the output of arange can be quite large, so bf16 would lose a lot of precision. 382 | # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. 383 | if self.pos_idx_in_fp32: 384 | t = torch.arange(seqlen, device=device, dtype=torch.float32) 385 | # We want fp32 here as well since inv_freq will be multiplied with t, and the output 386 | # will be large. Having it in bf16 will lose a lot of precision and cause the 387 | # cos & sin output to change significantly. 
388 | # We want to recompute self.inv_freq if it was not loaded in fp32 389 | if self.inv_freq.dtype != torch.float32: 390 | inv_freq = self._compute_inv_freq(device=device) 391 | else: 392 | inv_freq = self.inv_freq 393 | else: 394 | t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) 395 | inv_freq = self.inv_freq 396 | # Don't do einsum, it converts fp32 to fp16 under AMP 397 | # freqs = torch.einsum("i,j->ij", t, self.inv_freq) 398 | freqs = torch.outer(t, inv_freq) 399 | if self.scale is None: 400 | self._cos_cached = torch.cos(freqs).to(dtype) 401 | self._sin_cached = torch.sin(freqs).to(dtype) 402 | else: 403 | power = ( 404 | torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) 405 | - seqlen // 2 406 | ) / self.scale_base 407 | scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") 408 | # We want the multiplication by scale to happen in fp32 409 | self._cos_cached = (torch.cos(freqs) * scale).to(dtype) 410 | self._sin_cached = (torch.sin(freqs) * scale).to(dtype) 411 | self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) 412 | self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) 413 | 414 | def forward( 415 | self, 416 | qkv: torch.Tensor, 417 | kv: Optional[torch.Tensor] = None, 418 | seqlen_offset: Union[int, torch.Tensor] = 0, 419 | max_seqlen: Optional[int] = None, 420 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 421 | """ 422 | qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, 423 | else it's just q of shape (batch, seqlen, nheads, headdim) 424 | kv: (batch, seqlen, 2, nheads, headdim) 425 | seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. 426 | Most commonly used in inference when we have KV cache. 427 | If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one 428 | should pass in max_seqlen, which will update the cos / sin cache up to that length. 429 | Apply rotary embedding *inplace* to qkv and / or kv. 430 | """ 431 | seqlen = qkv.shape[1] 432 | if max_seqlen is not None: 433 | self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) 434 | elif isinstance(seqlen_offset, int): 435 | self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) 436 | if kv is None: 437 | return apply_rotary_emb_qkv_( 438 | qkv, 439 | self._cos_cached, 440 | self._sin_cached, 441 | interleaved=self.interleaved, 442 | seqlen_offsets=seqlen_offset, 443 | ) 444 | else: 445 | q = qkv 446 | q = apply_rotary_emb_func( 447 | q, 448 | self._cos_cached, 449 | self._sin_cached, 450 | interleaved=self.interleaved, 451 | inplace=True, 452 | seqlen_offsets=seqlen_offset, 453 | ) 454 | 455 | if kv.ndim > 4: 456 | kv = apply_rotary_emb_kv_( 457 | kv, 458 | self._cos_cached, 459 | self._sin_cached, 460 | interleaved=self.interleaved, 461 | seqlen_offsets=seqlen_offset, 462 | ) 463 | else: # only k 464 | kv = apply_rotary_emb_func( 465 | kv, 466 | self._cos_cached, 467 | self._sin_cached, 468 | interleaved=self.interleaved, 469 | inplace=True, 470 | seqlen_offsets=seqlen_offset, 471 | ) 472 | return q, kv 473 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 |

Rodimus*: Breaking the Accuracy-Efficiency Trade-Off with Efficient Attentions 6 |

7 |
If you like our project, please give us a star ⭐ on GitHub to follow the latest updates.
8 | 9 |
10 | 11 | 12 | [![hf](https://img.shields.io/badge/🤗-Hugging%20Face-blue.svg)](https://huggingface.co/) 13 | [![ModelScope](https://img.shields.io/badge/🤖-ModelScope-3771C8.svg)](https://modelscope.cn) 14 | [![ICLR](https://img.shields.io/badge/ICLR-2025-orange?logo=iclryear)](https://openreview.net/forum?id=IIVYiJ1ggK) 15 | [![License](https://img.shields.io/badge/Code%20License-Apache2.0-yellow)](https://choosealicense.com/licenses/apache-2.0/) 16 |
17 | 18 | ## Overview 19 | 20 | We propose Rodimus*, comprising Rodimus and Rodimus+, which aims to break the accuracy-efficiency trade-off of vanilla Transformers by introducing several innovative features. 21 | 22 | **Rodimus:** 23 | * Linear attention-based, purely recurrent model. 24 | * Incorporates Data-Dependent Tempered Selection (DDTS) for semantic compression (see the sketch below). 25 | * Reduced memory usage. 26 | 27 | **Rodimus+:** 28 | * Hybrid model combining Rodimus with Sliding Window Shared-Key Attention (SW-SKA). 29 | * Enhances semantic, token, and head compression. 30 | 31 | 
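To make the DDTS idea concrete, here is a minimal sketch of the gate computation used in Rodimus's recurrent branch, mirroring the logic in `modules/rodimus_flow.py` (the standalone helper and its argument names are illustrative, not part of the released API):

```python
import torch
import torch.nn.functional as F

def ddts_gates(shift_hidden_states: torch.Tensor, mem_gate_proj: torch.nn.Linear):
    # Project the short-convolved hidden states to two pre-activations per
    # memory slot: a selection signal and a temperature.
    select_gate, tau_gate = mem_gate_proj(shift_hidden_states).chunk(2, dim=-1)

    select_gate = F.softplus(select_gate)   # positive, data-dependent selection strength
    tau = torch.sigmoid(tau_gate)           # per-slot temperature in (0, 1)

    # Tempered gates: the same selection signal drives both how strongly the
    # current token is written (input gate) and how quickly the previous
    # state decays (log-space forget gate).
    input_gate = select_gate ** tau
    log_forget_gate = -select_gate * tau
    return input_gate, log_forget_gate
```

Because these gates act on a fixed number of memory slots rather than a growing KV cache, the recurrent branch keeps a constant memory footprint during decoding.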
32 | 33 |
34 | 35 | **Rodimus+-Coder:** 36 | * We train and open-source the lightweight Rodimus+-Coder models, available in 1.6B and 4B sizes, which surpass SOTA models of similar size. 37 | 38 | 
39 | 40 |
41 | 42 | ## Highlights 43 | 44 | * **Constant memory footprint but better language modeling performance.** 45 |
46 | 47 |
48 | 49 | * **Better scaling performance than Transformers.** 50 |
51 | 52 |
53 | 54 | * **A truly lite model, with no O(T) KV-cache memory complexity.** 55 | 56 | ## Pretrained Checkpoints 57 | 58 | ### Benchmark Checkpoints 59 | 60 | > These checkpoints completed training before the paper was submitted and are used to reproduce the benchmarks reported in the paper. 61 | > 62 | > If you want a more practical model, we strongly recommend downloading the checkpoints under **Rodimus+-Coder**. 63 | 64 |
65 | 66 | | **Model (2024/10/01)** | **#Total Params** | **Training Tokens** | **Context Length** | **Download** | 67 | | :----------------: | :---------------: | :----------------: | :----------: | :-------------: | 68 | | Rodimus-1.4B-Base | 1.4B | 500B | 2K | [🤗 HuggingFace](https://huggingface.co/codefuse-admin/rodimus_1B4_base_20241001)
[🤖 ModelScope](https://www.modelscope.cn/models/codefuse-ai/rodimus_1B4_base_20241001) | 69 | | Rodimus+-1.6B-Base | 1.6B | 1T | 2K | [🤗 HuggingFace](https://huggingface.co/codefuse-ai/rodimus_plus_1B6_base_20241001)
[🤖 ModelScope](https://www.modelscope.cn/models/codefuse-ai/rodimus_plus_1B6_base_20241001) | 70 | | Rodimus+-Coder-1.6B-Base-20241001 | 1.6B | 2.5T | 4K | [🤗 HuggingFace](https://huggingface.co/codefuse-ai/rodimus_plus_coder_1B6_base_20241001)
[🤖 ModelScope](https://www.modelscope.cn/models/codefuse-ai/rodimus_plus_coder_1B6_base_20241001) | 71 | 72 |
73 | 74 | `Rodimus+-Coder-1.6B-Base-20241001` is the model further enhanced by the multi-stage training with math and code datasets described in the paper. 75 | 76 | ### Rodimus+-Coder Checkpoints 77 | 78 | Refer to the following table to pick the checkpoint that fits your use case (a short download snippet follows the table). If you are located in mainland China, we also provide the models on modelscope.cn to speed up the download process. 79 | 80 |
81 | 82 | | **Model** | **#Total Params** | **Training Tokens** | **Context Length** | **Download** | 83 | | :----------------: | :---------------: | :----------------: | :----------------: | :----------: | 84 | | Rodimus+-Coder-1.6B-Base | 1.6B | 8.2T | 4K | [🤗 HuggingFace](https://huggingface.co/codefuse-ai/Rodimus-Plus-Coder-1.6B-Base)
[🤖 ModelScope](https://modelscope.cn/models/codefuse-ai/Rodimus-Plus-Coder-1.6B-Base) | 85 | | Rodimus+-Coder-1.6B-Chat | 1.6B | - | 4K | [🤗 HuggingFace](https://huggingface.co/codefuse-ai/Rodimus-Plus-Coder-1.6B-Chat)
[🤖 ModelScope](https://modelscope.cn/models/codefuse-ai/Rodimus-Plus-Coder-1.6B-Chat) | 86 | | Rodimus+-Coder-4B-Base | 4B | 8.2T | 4K | [🤗 HuggingFace](https://huggingface.co/codefuse-ai/Rodimus-Plus-Coder-4B-Base)
[🤖 ModelScope](https://modelscope.cn/models/codefuse-ai/Rodimus-Plus-Coder-4B-Base) | 87 | | Rodimus+-Coder-4B-Chat | 4B | - | 4K | [🤗 HuggingFace](https://huggingface.co/codefuse-ai/Rodimus-Plus-Coder-4B-Chat)
[🤖 ModelScope](https://modelscope.cn/models/codefuse-ai/Rodimus-Plus-Coder-4B-Chat) | 88 | 89 |
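As a quick illustration (not part of the original example scripts), you can fetch one of the checkpoints above with `huggingface_hub` and point the Quick Starts examples at the downloaded directory. The repo id below is taken from the table and is otherwise an arbitrary choice:

```python
from huggingface_hub import snapshot_download

# Download one of the checkpoints listed above to a local directory.
ckpt_dir = snapshot_download(repo_id="codefuse-ai/Rodimus-Plus-Coder-1.6B-Chat")

# `ckpt_dir` can then be used as the `model_path` in the generation and chat
# examples shown in the Quick Starts section below.
print(ckpt_dir)
```

ModelScope provides a similar `snapshot_download` helper if you prefer to download from modelscope.cn.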
90 | 91 | ## Rodimus+-Coder Evaluation 92 | 93 | We re-evaluate the metrics of the Qwen series models, and the metrics of other series models are quoted from the original paper. For detailed evaluation code, please refer to the evaluation method of Ling-Coder-Lite in [CodeFuse-Evaluation](https://github.com/codefuse-ai/codefuse-evaluation). 94 | 95 | ### Rodimus+-Coder-Base 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 |
| **Datasets** | **Qwen2.5-Coder-1.5B** | **Rodimus+-Coder-1.6B-Base** | **Gemma2-2B-PT** | **Qwen2.5-Coder-3B** | **Rodimus+-Coder-4B-Base** | **Gemma3-4B-PT** | **Qwen2.5-Coder-7B** |
| :---------------- | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
| **Coding Tasks** | | | | | | | |
| HumanEval | 41.5 | 51.2 | 19.5 | 51.8 | 60.4 | 36.0 | 60.4 |
| HumanEval+ | 34.8 | 45.1 | - | 40.9 | 52.4 | - | 50.6 |
| MBPP | 57.2 | 51.2 | 31.0 | 62.6 | 64.6 | 46.0 | 70.0 |
| MBPP+ | 66.1 | 62.2 | - | 65.9 | 71.4 | - | 70.1 |
| BCB-COMPLETION | 21.6 | 17.9 | - | 26.2 | 30.8 | - | 30.4 |
| MultiPL-E | 46.1 | 52.5 | - | 49.4 | 60.7 | - | 56.9 |
| CRUXEval | 38.5 | 45.1 | - | 44.6 | 56.4 | - | 56.8 |
| Coding Avg. | 43.7 | 46.5 | - | 48.8 | 56.7 | - | 56.4 |
| **General Tasks** | | | | | | | |
| C-EVAL | 55.2 | 56.7 | - | 65.3 | 70.2 | - | 69.1 |
| CMMLU | 54.5 | 52.3 | - | 65.4 | 68.3 | - | 72.7 |
| MMLU | 55.5 | 51.1 | 52.2 | 63.3 | 62.6 | 59.6 | 70.5 |
| BBH | 21.8 | 46.8 | 42.4 | 32.5 | 61.9 | 50.9 | 67.3 |
| General Avg. | 46.8 | 51.7 | - | 56.6 | 65.8 | - | 69.9 |
| **Mathematics Tasks** | | | | | | | |
| GSM8K | 60.4 | 68.7 | 25.0 | 72.1 | 78.5 | 38.4 | 83.4 |
| MATH | 23.7 | 29.0 | 16.4 | 31.9 | 37.0 | 24.2 | 42.2 |
| Math Avg. | 41.9 | 48.9 | 20.7 | 52.0 | 57.8 | 31.3 | 62.8 |
| **Overall** | | | | | | | |
| Overall | 44.4 | 48.4 | - | 51.7 | 59.6 | - | 61.6 |
291 | 292 | ### Rodimus+-Coder-Chat 293 |
| **Datasets** | **Qwen2.5-Coder-1.5B-Instruct** | **Rodimus+-Coder-1.6B-Chat** | **Gemma2-2B-IT** | **Qwen2.5-Coder-Instruct** | **Phi-4-Mini-3.8B** | **Rodimus+-Coder-4B-Chat** | **Gemma3-4B-IT** | **Qwen2.5-Coder-7B-Instruct** |
| :---------------- | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
| **Coding Tasks** | | | | | | | | |
| HumanEval | 64.6 | 76.8 | 20.1 | 79.9 | 74.4 | 86.6 | 71.3 | 87.2 |
| HumanEval+ | 63.4 | 73.8 | - | 80.5 | 68.3 | 82.9 | - | 82.3 |
| MBPP | 51.0 | 59.0 | 36.6 | 59.2 | 65.3 | 68.0 | 63.2 | 75.8 |
| MBPP+ | 53.0 | 66.4 | - | 61.9 | 63.8 | 68.5 | - | 75.1 |
| LCB (24.08-24.11) | 4.0 | 10.9 | - | 13.0 | - | 13.9 | - | 22.8 |
| BCB-INSTRUCT | 10.8 | 21.5 | - | 21.7 | 33.8 | 26.6 | - | 30.6 |
| HumanEval-Mul | 50.8 | 57.3 | - | 67.4 | - | 70.6 | - | 76.1 |
| MBPP-Mul | 43.4 | 52.4 | - | 53.4 | - | 59.6 | - | 61.4 |
| MBXP-EN | 55.8 | 75.5 | - | 76.0 | - | 87.3 | - | 87.7 |
| MBXP-CN | 48.8 | 75.0 | - | 68.7 | - | 84.3 | - | 83.5 |
| CRUXEval | 28.6 | 55.0 | - | 51.6 | - | 63.2 | - | 69.3 |
| HumanEvalFix | 38.9 | 52.6 | - | 55.5 | - | 68.8 | - | 69.3 |
| Spider | 61.2 | 71.4 | - | 71.8 | 42.2 | 73.5 | - | 82.0 |
| Coding Avg. | 44.2 | 57.5 | - | 58.5 | - | 65.7 | - | 69.5 |
| **General Tasks** | | | | | | | | |
| C-EVAL | 51.5 | 50.8 | - | 62.0 | - | 61.6 | - | 66.4 |
| CMMLU | 45.2 | 50.5 | - | 60.1 | - | 62.0 | - | 64.9 |
| MMLU | 52.0 | 49.3 | 56.1 | 61.7 | 67.3 | 57.5 | 58.1 | 66.1 |
| BBH | 24.2 | 58.7 | 41.4 | 57.3 | 70.4 | 63.7 | 72.2 | 59.1 |
| General Avg. | 43.2 | 52.3 | - | 60.3 | - | 61.2 | - | 64.1 |
| **Mathematics Tasks** | | | | | | | | |
| GSM8K | 54.4 | 68.5 | 62.6 | 73.5 | 88.6 | 79.2 | 89.2 | 79.5 |
| MATH | 38.1 | 33.5 | 27.2 | 44.1 | 64.0 | 44.1 | 75.6 | 60.8 |
| Math Avg. | 46.2 | 51.0 | 44.9 | 58.8 | 68.8 | 61.7 | 82.4 | 70.1 |
| **Overall** | | | | | | | | |
| Overall | 44.2 | 55.8 | - | 58.9 | - | 64.3 | - | 68.4 |
572 | 573 | ## Quick Starts 574 | 575 | ### Installation 576 | 577 | 1. The latest version of `transformers` is recommended (at least 4.42.0). 578 | 2. We evaluate our models with `python=3.8` and `torch==2.1.2`. 579 | 3. If you use Rodimus, you need to install `flash-linear-attention`, `causal_conv1d` and `triton>=2.2.0`. If you use Rodimus+, you need to further install `flash-attention`. 580 | 581 | ### Examples 582 | 583 | In `examples/generation_script.py`, we provide a code snippet showing how to use the model for text generation: 584 | 585 | ```python 586 | import os 587 | import torch 588 | from modeling_rodimus import RodimusForCausalLM 589 | from tokenization_rodimus_fast import RodimusTokenizer 590 | 591 | # load model 592 | ckpt_dir = "model_path" 593 | tokenizer = RodimusTokenizer.from_pretrained(ckpt_dir) 594 | model = RodimusForCausalLM.from_pretrained( 595 | ckpt_dir, 596 | torch_dtype=torch.float16, 597 | device_map="cuda" 598 | ).eval() 599 | 600 | # inference 601 | input_prompt = "你好!你是谁?" 602 | model_inputs = tokenizer(input_prompt, return_tensors="pt").to(model.device) 603 | outputs = model.generate(**model_inputs, max_length=32) 604 | response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] 605 | 606 | print(response) 607 | ``` 608 | 609 | In `examples/chat_script.py`, we further show how to chat with Rodimus+: 610 | 611 | ```python 612 | import os 613 | import torch 614 | from modeling_rodimus import RodimusForCausalLM 615 | from tokenization_rodimus_fast import RodimusTokenizer 616 | 617 | # load model 618 | ckpt_dir = "model_path" 619 | tokenizer = RodimusTokenizer.from_pretrained(ckpt_dir) 620 | model = RodimusForCausalLM.from_pretrained( 621 | ckpt_dir, 622 | torch_dtype=torch.float16, 623 | device_map="cuda" 624 | ).eval() 625 | 626 | # inference 627 | input_prompt = "简单介绍一下大型语言模型。" 628 | messages = [ 629 | {"role": "HUMAN", "content": input_prompt} 630 | ] 631 | 632 | text = tokenizer.apply_chat_template( 633 | messages, 634 | system='You are Rodimus$+$, created by AntGroup. You are a helpful assistant.', 635 | tokenize=False, 636 | ) 637 | print(text) 638 | model_inputs = tokenizer(text, return_tensors="pt").to(model.device) 639 | outputs = model.generate(**model_inputs, max_length=2048) 640 | response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] 641 | 642 | print(response) 643 | ``` 644 | 645 | ## Citation 646 | 647 | If you find our work helpful, please consider citing us.
648 | 649 | ``` 650 | @inproceedings{ 651 | he2025rodimus, 652 | title={Rodimus*: Breaking the Accuracy-Efficiency Trade-Off with Efficient Attentions}, 653 | author={Zhihao He and Hang Yu and Zi Gong and Shizhan Liu and Jianguo Li and Weiyao Lin}, 654 | booktitle={The Thirteenth International Conference on Learning Representations}, 655 | year={2025}, 656 | url={https://openreview.net/forum?id=IIVYiJ1ggK} 657 | } 658 | ``` 659 | -------------------------------------------------------------------------------- /modules/chat_format.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import dataclasses 3 | import logging 4 | import re 5 | import uuid 6 | from copy import deepcopy 7 | from enum import IntEnum, auto 8 | from typing import Dict, List, Optional, Tuple 9 | from transformers.utils import TensorType, logging 10 | 11 | logger = logging.get_logger(__name__) 12 | 13 | 14 | class PromptStyle(IntEnum): 15 | '''Prompt styles.''' 16 | RODIMUS_CHAT = auto() 17 | CHATML = auto() 18 | LLAMA2 = auto() 19 | CHATGLM = auto() 20 | CHATGLM3 = auto() 21 | BAICHUAN2 = auto() 22 | 23 | 24 | @dataclasses.dataclass 25 | class Chat: 26 | 27 | id: str = None 28 | name: Optional[str] = None 29 | prompt_style: Optional[PromptStyle] = None 30 | 31 | system_template: str = 'SYSTEM{}' 32 | system_message: str = '' 33 | 34 | role_human: str = 'HUMAN' 35 | role_assistant: str = 'ASSISTANT' 36 | role_observation: str = 'OBSERVATION' 37 | role_template: str = '{}' 38 | 39 | turn_start: str = '' 40 | human_end: str = '' 41 | assistant_start: str = '' 42 | assistant_end: str = '' 43 | assistant_end_ids: Optional[List[int]] = None 44 | general_role_end: str = '' 45 | 46 | tool_template = '{}' 47 | code_template = '{}' 48 | arithemetic_templte = '{}' 49 | image_template = '{}' 50 | 51 | messages: List[Tuple[str, str]] = () 52 | 53 | offset: int = 0 54 | 55 | source: Optional[str] = None 56 | lang: Optional[str] = None 57 | topic: Optional[str] = None 58 | 59 | origin_json: Optional[dict] = None 60 | 61 | @classmethod 62 | def from_json( 63 | cls, 64 | input: dict, 65 | name: Optional[str] = None, 66 | prompt_style: Optional[PromptStyle] = None, 67 | ): 68 | _id = input.get('id') 69 | if name: 70 | _name = name 71 | else: 72 | _name = input.get('name') 73 | source = input.get('source') 74 | lang = input.get('lang') 75 | topic = input.get('topic') 76 | kwargs = {} 77 | if 'system_template' in input: 78 | kwargs['system_template'] = input['system_template'] 79 | if 'system_message' in input: 80 | kwargs['system_message'] = input['system_message'] 81 | 82 | chat = cls( 83 | id=_id, 84 | name=_name, 85 | prompt_style=prompt_style, 86 | source=source, 87 | lang=lang, 88 | topic=topic, 89 | origin_json=deepcopy(input), 90 | **kwargs, 91 | ) 92 | if 'messages' in input: 93 | for msg in input['messages']: 94 | if msg['role'] == 'HUMAN': 95 | role = chat.role_human 96 | elif msg['role'] == 'OBSERVATION': 97 | role = chat.role_observation 98 | elif msg['role'] == 'ASSISTANT': 99 | role = chat.role_assistant 100 | else: 101 | raise ValueError(f'不支持数据集中的 role: {msg["role"]}') 102 | 103 | chat.append_message(role, msg['content']) 104 | 105 | elif 'turns' in input: 106 | for turn in input['turns']: 107 | if 'HUMAN' in turn: 108 | content = turn['HUMAN'] 109 | chat.append_message(chat.role_human, content) 110 | if 'OBSERVATION' in turn: 111 | content = turn['OBSERVATION'] 112 | chat.append_message(chat.role_observation, content) 113 | if 'ASSISTANT' in turn: 114 | content = 
turn['ASSISTANT'] 115 | chat.append_message(chat.role_assistant, content) 116 | 117 | return chat 118 | 119 | @classmethod 120 | def from_pack( 121 | cls, 122 | packs: Dict[str, List[str]], 123 | name: str, 124 | prompt_style: Optional[PromptStyle] = None, 125 | ) -> list: 126 | chat = cls(name=name, prompt_style=prompt_style) 127 | packs = cls._format_packs(packs) 128 | 129 | sys_pattern = re.compile( 130 | chat.system_template.format(r'(.*?)'), re.DOTALL) 131 | turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL) 132 | human_pattern = re.compile(chat.role_template.format( 133 | chat.role_human).strip(), re.DOTALL) 134 | observe_pattern = re.compile(chat.role_template.format( 135 | chat.role_observation).strip(), re.DOTALL) 136 | assistant_pattern = re.compile(chat.role_template.format( 137 | chat.role_assistant).strip(), re.DOTALL) 138 | 139 | chats = [] 140 | for input, output in zip(packs['input'], packs['output']): 141 | # system message 142 | sys_match = sys_pattern.search(input) 143 | if sys_match and sys_match.group(0): 144 | # system 指令只在首轮, 新增 chat 对象 145 | if len(chat.messages) > 0: 146 | chats.append(chat) 147 | chat = cls(name=name, prompt_style=prompt_style) 148 | 149 | input = input[sys_match.end():] 150 | chat.system_message = sys_match.group(1) 151 | 152 | # turn start 153 | turn_match = turn_pattern.search(input) 154 | if turn_match and turn_match.group(0): 155 | if name == 'chatglm2': 156 | round_start = 1 157 | else: 158 | round_start = 0 159 | 160 | if all( 161 | [ 162 | len(turn_match.groups()) > 0, 163 | int(turn_match.group(1)) == round_start, 164 | len(chat.messages) > 0, 165 | ] 166 | ): 167 | chats.append(chat) 168 | chat = cls(name=name, prompt_style=prompt_style) 169 | 170 | input = input[turn_match.end():] 171 | 172 | human_iter = human_pattern.finditer(input) 173 | observe_iter = observe_pattern.finditer(input) 174 | assistant_iter = assistant_pattern.finditer(input) 175 | human_match = next(human_iter, None) 176 | observe_match = next(observe_iter, None) 177 | assistant_match = next(assistant_iter, None) 178 | 179 | if not human_match and not observe_match: 180 | chat.append_message(chat.role_human, input) 181 | 182 | while human_match or observe_match: 183 | next_human_match = next(human_iter, None) 184 | next_observe_match = next(observe_iter, None) 185 | input = cls._append_human_observation( 186 | chat, 187 | input, 188 | human_match=human_match, 189 | next_human_match=next_human_match, 190 | observe_match=observe_match, 191 | next_observe_match=next_observe_match, 192 | assistant_match=assistant_match, 193 | ) 194 | 195 | human_match = next_human_match 196 | observe_match = next_observe_match 197 | next_human_match = next(human_iter, None) 198 | next_observe_match = next(observe_iter, None) 199 | 200 | if output: 201 | chat.append_message(chat.role_assistant, output) 202 | 203 | if chat.messages: 204 | chats.append(chat) 205 | 206 | return chats 207 | 208 | @classmethod 209 | def _append_human_observation( 210 | cls, 211 | chat, 212 | input: str, 213 | human_match: Optional[re.Match] = None, 214 | next_human_match: Optional[re.Match] = None, 215 | observe_match: Optional[re.Match] = None, 216 | next_observe_match: Optional[re.Match] = None, 217 | assistant_match: Optional[re.Match] = None, 218 | ) -> str: 219 | if observe_match: 220 | if observe_match.span()[0] > observe_match.span()[0]: 221 | human_str = input[observe_match.span()[1]: observe_match.span()[0]] 222 | observe_str = input[observe_match.span()[1]: 
assistant_match.span()[0]] 223 | chat.append_message(chat.role_human, human_str.strip()) 224 | input_end = observe_match.span()[1] 225 | if observe_match.span()[0] < next_human_match.span()[0]: 226 | chat.append_message( 227 | chat.role_observation, observe_str.strip()) 228 | input_end = assistant_match.span()[1] 229 | else: 230 | human_str = input[observe_match.span()[1]: assistant_match.span()[0]] 231 | observe_str = input[observe_match.span()[1]: observe_match.span()[0]] 232 | chat.append_message(chat.role_observation, observe_str.strip()) 233 | input_end = observe_match.span()[1] 234 | if observe_match.span()[0] < next_observe_match.span()[0]: 235 | chat.append_message(chat.role_human, human_str.strip()) 236 | input_end = assistant_match.span()[1] 237 | else: 238 | if assistant_match: 239 | human_str = input[human_match.span( 240 | )[1]: assistant_match.span()[0]] 241 | input_end = assistant_match.span()[1] 242 | else: 243 | human_str = input[human_match.span()[1]:] 244 | input_end = len(input) 245 | chat.append_message(chat.role_human, human_str.strip()) 246 | 247 | return input[input_end:] 248 | 249 | @classmethod 250 | def from_inout( 251 | cls, 252 | sample: Dict[str, str], 253 | name: str, 254 | prompt_style: Optional[PromptStyle] = None, 255 | ): 256 | chat = cls(name=name, prompt_style=prompt_style) 257 | input = sample['input'] 258 | output = sample['output'] 259 | 260 | sys_pattern = re.compile( 261 | chat.system_template.format(r'(.*?)'), re.DOTALL) 262 | turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL) 263 | human_pattern = re.compile(chat.role_template.format( 264 | chat.role_human).strip(), re.DOTALL) 265 | observe_pattern = re.compile(chat.role_template.format( 266 | chat.role_observation).strip(), re.DOTALL) 267 | assistant_pattern = re.compile(chat.role_template.format( 268 | chat.role_assistant).strip(), re.DOTALL) 269 | 270 | input = turn_pattern.sub('', input) 271 | 272 | sys_match = sys_pattern.search(input) 273 | if sys_match and sys_match.group(0): 274 | input = input[sys_match.end():] 275 | chat.system_message = sys_match.group(1) 276 | 277 | human_iter = human_pattern.finditer(input) 278 | observe_iter = observe_pattern.finditer(input) 279 | assistant_iter = assistant_pattern.finditer(input) 280 | human_match = next(human_iter, None) 281 | observe_match = next(observe_iter, None) 282 | assistant_match = next(assistant_iter, None) 283 | next_human_match = next(human_iter, None) 284 | next_observe_match = next(observe_iter, None) 285 | 286 | while any( 287 | [ 288 | human_match, 289 | observe_match, 290 | assistant_match, 291 | ] 292 | ): 293 | while any( 294 | [ 295 | human_match and human_match.span( 296 | )[0] < assistant_match.span()[0], 297 | observe_match and observe_match.span( 298 | )[0] < assistant_match.span()[0], 299 | next_human_match and next_human_match.span()[0] < assistant_match.span()[ 300 | 0], 301 | next_observe_match and next_observe_match.span()[0] < assistant_match.span()[ 302 | 0], 303 | ] 304 | ): 305 | if not input: 306 | break 307 | 308 | cls._append_human_observation( 309 | chat, 310 | input, 311 | human_match=human_match, 312 | next_human_match=next_human_match, 313 | observe_match=observe_match, 314 | next_observe_match=next_observe_match, 315 | assistant_match=assistant_match, 316 | ) 317 | 318 | human_match = next_human_match 319 | observe_match = next_observe_match 320 | next_human_match = next(human_iter, None) 321 | next_observe_match = next(observe_iter, None) 322 | 323 | if assistant_match and 
assistant_match.span(): 324 | if observe_match: 325 | if observe_match.span() and observe_match.span()[0] < human_match.span()[0]: 326 | assistant_str = input[assistant_match.span()[1]: observe_match.span()[ 327 | 0]] 328 | elif human_match: 329 | if human_match.span(): 330 | assistant_str = input[assistant_match.span()[1]: human_match.span()[ 331 | 0]] 332 | else: 333 | assistant_str = input[assistant_match.span()[1]:] 334 | 335 | if assistant_str: 336 | chat.append_message(chat.role_assistant, assistant_str) 337 | 338 | assistant_match = next(assistant_iter, None) 339 | 340 | if output: 341 | chat.append_message(chat.role_assistant, output) 342 | 343 | return chat 344 | 345 | def __hash__(self): 346 | return hash(self.id) 347 | 348 | def __post_init__(self): 349 | self.id = str(uuid.uuid4()) 350 | if not self.messages: 351 | self.messages = [] 352 | 353 | if not self.name and not self.prompt_style: 354 | raise ValueError 355 | 356 | if not self.name and self.prompt_style == PromptStyle.RODIMUS_CHAT: 357 | logger.info( 358 | "The input parameter of the Chat object does not have `name`. By default, the `name` is RODIMUS_CHAT', format: \n" 359 | f'role_human: {self.role_human}\n' 360 | f'role_assistant: {self.role_assistant}\n' 361 | f'role_observation: {self.role_observation}\n' 362 | f'role_template: {self.role_template}\n' 363 | f'turn_start: {self.turn_start}\n' 364 | f'human_end: {self.human_end}\n' 365 | f'assistant_start: {self.assistant_start}\n' 366 | f'assistant_end: {self.assistant_end}\n' 367 | f'assistant_end_ids: {self.assistant_end_ids}\n' 368 | f'general_role_end: {self.general_role_end}\n' 369 | f'tool_template: {self.tool_template}\n' 370 | f'code_template: {self.code_template}\n' 371 | f'arithemetic_templte: {self.arithemetic_templte}\n' 372 | f'image_template: {self.image_template}\n' 373 | f'\n入参 `name` 支持: ``' 374 | ) 375 | return 376 | 377 | if self.name in ['chatglm1', 'chatglm2'] or self.prompt_style == PromptStyle.CHATGLM: 378 | self.prompt_style = PromptStyle.CHATGLM 379 | self.role_template = '{}' 380 | self.role_human = '问:' 381 | self.role_assistant = '答:' 382 | self.turn_start = '[Round {}]\n' 383 | if self.name == 'chatglm1': 384 | self.general_role_end = '\n' 385 | else: 386 | self.general_role_end = '\n\n' 387 | 388 | elif self.name == 'chatglm3' or self.prompt_style == PromptStyle.CHATGLM3: 389 | self.prompt_style = PromptStyle.CHATGLM3 390 | self.system_template = '<|system|>\n {}' 391 | self.role_human = '<|user|>\n ' 392 | self.role_assistant = '<|assistant|>\n ' 393 | self.role_template = '{}' 394 | 395 | elif self.name == 'llama2' or self.prompt_style == PromptStyle.LLAMA2: 396 | self.prompt_style = PromptStyle.LLAMA2 397 | self.role_template = '{}' 398 | self.system_template = '[INST] <>\n{}\n<>\n\n' 399 | self.role_human = '[INST] ' 400 | self.role_assistant = '[/INST] ' 401 | self.human_end = ' ' 402 | self.assistant_end = ' ' 403 | 404 | elif self.name == 'qwen': 405 | self.prompt_style = PromptStyle.CHATML 406 | self.role_template = '{}' 407 | self.system_template = '<|im_start|>system\n{}' 408 | if not self.system_message: 409 | self.system_message = 'You are a helpful assistant.' 
410 | self.role_human = '<|im_start|>user\n' 411 | self.role_assistant = '<|im_start|>assistant\n' 412 | self.general_role_end = '<|im_end|>\n' 413 | 414 | elif self.name == 'baichuan': 415 | self.prompt_style = PromptStyle.BAICHUAN2 416 | self.role_template = '{}' 417 | self.system_template = '{}' 418 | self.role_human = '' 419 | self.role_assistant = '' 420 | 421 | if not self.system_template: 422 | self.system_template = '{}' 423 | 424 | def readable_messages(self) -> str: 425 | pass 426 | 427 | @property 428 | def prompt_str(self) -> str: 429 | return f'{self.prompt_inout["input"]}{self.prompt_inout["output"]}' 430 | 431 | @classmethod 432 | def _format_packs(cls, packs: Dict[str, List[str]]) -> Dict[str, List[str]]: 433 | _packs = copy.deepcopy(packs) 434 | if len(_packs['input']) - 1 == len(_packs['output']): 435 | _packs['output'].append('') 436 | 437 | if len(_packs['input']) != len(_packs['output']): 438 | print(packs) 439 | raise ValueError( 440 | '输入 input 和 output 数量不匹配, ' 441 | f'input num: {len(packs["input"])}, ' 442 | f'output num: {len(packs["output"])}' 443 | ) 444 | 445 | return _packs 446 | 447 | @property 448 | def prompt_inout(self) -> Dict[str, str]: 449 | packs = self._format_packs(self.prompt_pack) 450 | 451 | prompt_input = ''.join([f'{x}{y}' for x, y in zip( 452 | packs['input'][:-1], packs['output'][:-1])]) 453 | prompt_input += packs['input'][-1] 454 | prompt_output = packs['output'][-1] 455 | 456 | return { 457 | 'input': prompt_input, 458 | 'output': prompt_output, 459 | } 460 | 461 | @property 462 | def prompt_pack(self) -> Dict[str, List[str]]: 463 | inputs = [] 464 | outputs = [] 465 | 466 | system_prompt = '' 467 | if self.system_message: 468 | system_prompt = self.system_template.format(self.system_message) 469 | 470 | if system_prompt: 471 | ret = system_prompt + self.general_role_end 472 | else: 473 | ret = '' 474 | 475 | if self.name in ['chatglm2']: 476 | round_start = 1 477 | else: 478 | round_start = 0 479 | 480 | for i, (role, message) in enumerate(self.messages): 481 | if self.name in ['chatglm1', 'chatglm2']: 482 | if i % 2 == 0: 483 | ret += self.turn_start.format(i // 2 + round_start) 484 | 485 | role_end = self.general_role_end 486 | if role == self.role_assistant and self.assistant_end: 487 | role_end = self.assistant_end 488 | elif self.human_end: 489 | role_end = self.human_end 490 | 491 | ret += self.role_template.format(role) + message + role_end 492 | 493 | if role == self.role_assistant: 494 | if not message: 495 | outputs.append('') 496 | else: 497 | outputs.append(message + role_end) 498 | # input 需要连接 assistant role 499 | inputs[-1] += ret[: -len(message + role_end)] 500 | elif all( 501 | [ 502 | role == self.role_observation, 503 | len(self.messages) > 1, 504 | self.messages[i - 1][0] != self.role_assistant, 505 | ] 506 | ): 507 | continue 508 | else: 509 | inputs.append(ret) 510 | ret = '' 511 | 512 | if i == len(self.messages) - 1 and role != self.role_assistant: 513 | inputs[-1] += self.role_template.format( 514 | self.role_assistant).strip() 515 | 516 | return { 517 | 'input': inputs, 518 | 'output': outputs, 519 | } 520 | 521 | @property 522 | def turns_num(self) -> int: 523 | return sum([1 if msg[0] == self.role_human else 0 for msg in self.messages]) 524 | 525 | def to_json(self) -> dict: 526 | turns = [] 527 | messages = [] 528 | turn = {} 529 | for msg in self.messages: 530 | if msg[0] == self.role_assistant: 531 | messages.append({'role': 'ASSISTANT', 'content': msg[1]}) 532 | turn['ASSISTANT'] = msg[1] 533 | 
turns.append(turn) 534 | turn = {} 535 | 536 | if msg[0] == self.role_human: 537 | messages.append({'role': 'HUMAN', 'content': msg[1]}) 538 | turn['HUMAN'] = msg[1] 539 | 540 | if msg[0] == self.role_observation: 541 | messages.append({'role': 'OBSERVATION', 'content': msg[1]}) 542 | turn['OBSERVATION'] = msg[1] 543 | 544 | if self.messages[-1][0] == self.role_human: 545 | messages.append({'role': 'ASSISTANT', 'content': ''}) 546 | turn['ASSISTANT'] = '' 547 | turns.append(turn) 548 | 549 | result = self.origin_json or {} 550 | result.update( 551 | { 552 | 'id': self.id, 553 | 'name': self.name, 554 | 'source': self.source, 555 | 'lang': self.lang, 556 | 'topic': self.topic, 557 | 'system_template': self.system_template, 558 | 'system_message': self.system_message, 559 | 'turns': turns, 560 | 'messages': messages, 561 | } 562 | ) 563 | 564 | return result 565 | 566 | def set_system_message(self, system_message: str): 567 | self.system_message = system_message 568 | 569 | def append_message(self, role: str, message: str): 570 | if not message: 571 | message = '' 572 | self.messages.append([role, message]) 573 | 574 | def to_openai_api_messages(self) -> List[dict]: 575 | ret = [{'role': 'system', 'content': self.system_message}] 576 | 577 | for i, (_, msg) in enumerate(self.messages[self.offset:]): 578 | if i % 2 == 0: 579 | ret.append({'role': 'user', 'content': msg}) 580 | else: 581 | if msg is not None: 582 | ret.append({'role': 'assistant', 'content': msg}) 583 | return ret 584 | 585 | def copy(self): 586 | return copy.deepcopy(self) 587 | -------------------------------------------------------------------------------- /ops/layernorm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | # Implement residual + layer_norm / rms_norm. 3 | 4 | # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html 5 | # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. 6 | # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. 7 | # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
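# Usage note (added for clarity): the fused entry points defined below are
# `layer_norm_fn` / `rms_norm_fn`, which optionally fuse a residual addition with the
# normalization in a single kernel, e.g.
#     out = rms_norm_fn(x, weight, bias=None, residual=res, residual_in_fp32=True)
# With `prenorm=True` they instead return `(out, residual_out)`, so the caller can keep
# accumulating the residual stream (optionally in fp32) across layers, which is how the
# RMSNorm wrapper at the bottom of this file is used by the Rodimus blocks.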
8 | 9 | import math 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.cuda.amp import custom_fwd, custom_bwd 14 | 15 | import triton 16 | import triton.language as tl 17 | 18 | 19 | def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): 20 | dtype = x.dtype 21 | if upcast: 22 | weight = weight.float() 23 | bias = bias.float() if bias is not None else None 24 | if upcast: 25 | x = x.float() 26 | residual = residual.float() if residual is not None else residual 27 | if residual is not None: 28 | x = (x + residual).to(x.dtype) 29 | out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( 30 | dtype 31 | ) 32 | return out if not prenorm else (out, x) 33 | 34 | 35 | def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): 36 | dtype = x.dtype 37 | if upcast: 38 | weight = weight.float() 39 | bias = bias.float() if bias is not None else None 40 | if upcast: 41 | x = x.float() 42 | residual = residual.float() if residual is not None else residual 43 | if residual is not None: 44 | x = (x + residual).to(x.dtype) 45 | rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) 46 | out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) 47 | out = out.to(dtype) 48 | return out if not prenorm else (out, x) 49 | 50 | 51 | @triton.autotune( 52 | configs=[ 53 | triton.Config({}, num_warps=1), 54 | triton.Config({}, num_warps=2), 55 | triton.Config({}, num_warps=4), 56 | triton.Config({}, num_warps=8), 57 | triton.Config({}, num_warps=16), 58 | triton.Config({}, num_warps=32), 59 | ], 60 | key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], 61 | ) 62 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 63 | # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) 64 | @triton.jit 65 | def _layer_norm_fwd_1pass_kernel( 66 | X, # pointer to the input 67 | Y, # pointer to the output 68 | W, # pointer to the weights 69 | B, # pointer to the biases 70 | RESIDUAL, # pointer to the residual 71 | RESIDUAL_OUT, # pointer to the residual 72 | Mean, # pointer to the mean 73 | Rstd, # pointer to the 1/std 74 | stride_x_row, # how much to increase the pointer when moving by 1 row 75 | stride_y_row, 76 | stride_res_row, 77 | stride_res_out_row, 78 | N, # number of columns in X 79 | eps, # epsilon to avoid division by zero 80 | IS_RMS_NORM: tl.constexpr, 81 | BLOCK_N: tl.constexpr, 82 | HAS_RESIDUAL: tl.constexpr, 83 | STORE_RESIDUAL_OUT: tl.constexpr, 84 | HAS_BIAS: tl.constexpr, 85 | ): 86 | # Map the program id to the row of X and Y it should compute. 
87 | row = tl.program_id(0) 88 | X += row * stride_x_row 89 | Y += row * stride_y_row 90 | if HAS_RESIDUAL: 91 | RESIDUAL += row * stride_res_row 92 | if STORE_RESIDUAL_OUT: 93 | RESIDUAL_OUT += row * stride_res_out_row 94 | # Compute mean and variance 95 | cols = tl.arange(0, BLOCK_N) 96 | x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) 97 | if HAS_RESIDUAL: 98 | residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) 99 | x += residual 100 | if STORE_RESIDUAL_OUT: 101 | tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) 102 | if not IS_RMS_NORM: 103 | mean = tl.sum(x, axis=0) / N 104 | tl.store(Mean + row, mean) 105 | xbar = tl.where(cols < N, x - mean, 0.0) 106 | var = tl.sum(xbar * xbar, axis=0) / N 107 | else: 108 | xbar = tl.where(cols < N, x, 0.0) 109 | var = tl.sum(xbar * xbar, axis=0) / N 110 | rstd = 1 / tl.sqrt(var + eps) 111 | tl.store(Rstd + row, rstd) 112 | # Normalize and apply linear transformation 113 | mask = cols < N 114 | w = tl.load(W + cols, mask=mask).to(tl.float32) 115 | if HAS_BIAS: 116 | b = tl.load(B + cols, mask=mask).to(tl.float32) 117 | x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd 118 | y = x_hat * w + b if HAS_BIAS else x_hat * w 119 | # Write output 120 | tl.store(Y + cols, y, mask=mask) 121 | 122 | 123 | def _layer_norm_fwd( 124 | x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False 125 | ): 126 | if residual is not None: 127 | residual_dtype = residual.dtype 128 | M, N = x.shape 129 | assert x.stride(-1) == 1 130 | if residual is not None: 131 | assert residual.stride(-1) == 1 132 | assert residual.shape == (M, N) 133 | assert weight.shape == (N,) 134 | assert weight.stride(-1) == 1 135 | if bias is not None: 136 | assert bias.stride(-1) == 1 137 | assert bias.shape == (N,) 138 | # allocate output 139 | y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) 140 | assert y.stride(-1) == 1 141 | if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): 142 | residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) 143 | assert residual_out.stride(-1) == 1 144 | else: 145 | residual_out = None 146 | mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None 147 | rstd = torch.empty((M,), dtype=torch.float32, device=x.device) 148 | # Less than 64KB per feature: enqueue fused kernel 149 | MAX_FUSED_SIZE = 65536 // x.element_size() 150 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 151 | if N > BLOCK_N: 152 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 153 | # heuristics for number of warps 154 | with torch.cuda.device(x.device.index): 155 | _layer_norm_fwd_1pass_kernel[(M,)]( 156 | x, 157 | y, 158 | weight, 159 | bias, 160 | residual, 161 | residual_out, 162 | mean, 163 | rstd, 164 | x.stride(0), 165 | y.stride(0), 166 | residual.stride(0) if residual is not None else 0, 167 | residual_out.stride(0) if residual_out is not None else 0, 168 | N, 169 | eps, 170 | is_rms_norm, 171 | BLOCK_N, 172 | residual is not None, 173 | residual_out is not None, 174 | bias is not None, 175 | ) 176 | # residual_out is None if residual is None and residual_dtype == input_dtype 177 | return y, mean, rstd, residual_out if residual_out is not None else x 178 | 179 | 180 | @triton.autotune( 181 | configs=[ 182 | triton.Config({}, num_warps=1), 183 | triton.Config({}, num_warps=2), 184 | triton.Config({}, num_warps=4), 185 | triton.Config({}, num_warps=8), 186 | 
triton.Config({}, num_warps=16), 187 | triton.Config({}, num_warps=32), 188 | ], 189 | key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], 190 | ) 191 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 192 | # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) 193 | # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) 194 | @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) 195 | @triton.jit 196 | def _layer_norm_bwd_kernel( 197 | X, # pointer to the input 198 | W, # pointer to the weights 199 | B, # pointer to the biases 200 | Y, # pointer to the output to be recomputed 201 | DY, # pointer to the output gradient 202 | DX, # pointer to the input gradient 203 | DW, # pointer to the partial sum of weights gradient 204 | DB, # pointer to the partial sum of biases gradient 205 | DRESIDUAL, 206 | DRESIDUAL_IN, 207 | Mean, # pointer to the mean 208 | Rstd, # pointer to the 1/std 209 | stride_x_row, # how much to increase the pointer when moving by 1 row 210 | stride_y_row, 211 | stride_dy_row, 212 | stride_dx_row, 213 | stride_dres_row, 214 | stride_dres_in_row, 215 | M, # number of rows in X 216 | N, # number of columns in X 217 | eps, # epsilon to avoid division by zero 218 | rows_per_program, 219 | IS_RMS_NORM: tl.constexpr, 220 | BLOCK_N: tl.constexpr, 221 | HAS_DRESIDUAL: tl.constexpr, 222 | STORE_DRESIDUAL: tl.constexpr, 223 | HAS_BIAS: tl.constexpr, 224 | RECOMPUTE_OUTPUT: tl.constexpr, 225 | ): 226 | # Map the program id to the elements of X, DX, and DY it should compute. 227 | row_block_id = tl.program_id(0) 228 | row_start = row_block_id * rows_per_program 229 | cols = tl.arange(0, BLOCK_N) 230 | mask = cols < N 231 | X += row_start * stride_x_row 232 | if HAS_DRESIDUAL: 233 | DRESIDUAL += row_start * stride_dres_row 234 | if STORE_DRESIDUAL: 235 | DRESIDUAL_IN += row_start * stride_dres_in_row 236 | DY += row_start * stride_dy_row 237 | DX += row_start * stride_dx_row 238 | if RECOMPUTE_OUTPUT: 239 | Y += row_start * stride_y_row 240 | w = tl.load(W + cols, mask=mask).to(tl.float32) 241 | if RECOMPUTE_OUTPUT and HAS_BIAS: 242 | b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) 243 | dw = tl.zeros((BLOCK_N,), dtype=tl.float32) 244 | if HAS_BIAS: 245 | db = tl.zeros((BLOCK_N,), dtype=tl.float32) 246 | row_end = min((row_block_id + 1) * rows_per_program, M) 247 | for row in range(row_start, row_end): 248 | # Load data to SRAM 249 | x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) 250 | dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) 251 | if not IS_RMS_NORM: 252 | mean = tl.load(Mean + row) 253 | rstd = tl.load(Rstd + row) 254 | # Compute dx 255 | xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd 256 | xhat = tl.where(mask, xhat, 0.0) 257 | if RECOMPUTE_OUTPUT: 258 | y = xhat * w + b if HAS_BIAS else xhat * w 259 | tl.store(Y + cols, y, mask=mask) 260 | wdy = w * dy 261 | dw += dy * xhat 262 | if HAS_BIAS: 263 | db += dy 264 | if not IS_RMS_NORM: 265 | c1 = tl.sum(xhat * wdy, axis=0) / N 266 | c2 = tl.sum(wdy, axis=0) / N 267 | dx = (wdy - (xhat * c1 + c2)) * rstd 268 | else: 269 | c1 = tl.sum(xhat * wdy, axis=0) / N 270 | dx = (wdy - xhat * c1) * rstd 271 | if HAS_DRESIDUAL: 272 | dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) 273 | dx += dres 274 | # Write dx 275 | if STORE_DRESIDUAL: 276 | tl.store(DRESIDUAL_IN + cols, dx, mask=mask) 277 | tl.store(DX + cols, dx, mask=mask) 278 | 279 | X += 
stride_x_row 280 | if HAS_DRESIDUAL: 281 | DRESIDUAL += stride_dres_row 282 | if STORE_DRESIDUAL: 283 | DRESIDUAL_IN += stride_dres_in_row 284 | if RECOMPUTE_OUTPUT: 285 | Y += stride_y_row 286 | DY += stride_dy_row 287 | DX += stride_dx_row 288 | tl.store(DW + row_block_id * N + cols, dw, mask=mask) 289 | if HAS_BIAS: 290 | tl.store(DB + row_block_id * N + cols, db, mask=mask) 291 | 292 | 293 | def _layer_norm_bwd( 294 | dy, 295 | x, 296 | weight, 297 | bias, 298 | eps, 299 | mean, 300 | rstd, 301 | dresidual=None, 302 | has_residual=False, 303 | is_rms_norm=False, 304 | x_dtype=None, 305 | recompute_output=False, 306 | ): 307 | M, N = x.shape 308 | assert x.stride(-1) == 1 309 | assert dy.stride(-1) == 1 310 | assert dy.shape == (M, N) 311 | if dresidual is not None: 312 | assert dresidual.stride(-1) == 1 313 | assert dresidual.shape == (M, N) 314 | assert weight.shape == (N,) 315 | assert weight.stride(-1) == 1 316 | if bias is not None: 317 | assert bias.stride(-1) == 1 318 | assert bias.shape == (N,) 319 | # allocate output 320 | dx = ( 321 | torch.empty_like(x) 322 | if x_dtype is None 323 | else torch.empty(M, N, dtype=x_dtype, device=x.device) 324 | ) 325 | dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None 326 | y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None 327 | 328 | # Less than 64KB per feature: enqueue fused kernel 329 | MAX_FUSED_SIZE = 65536 // x.element_size() 330 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 331 | if N > BLOCK_N: 332 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 333 | sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count 334 | _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) 335 | _db = ( 336 | torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) 337 | if bias is not None 338 | else None 339 | ) 340 | rows_per_program = math.ceil(M / sm_count) 341 | grid = (sm_count,) 342 | with torch.cuda.device(x.device.index): 343 | _layer_norm_bwd_kernel[grid]( 344 | x, 345 | weight, 346 | bias, 347 | y, 348 | dy, 349 | dx, 350 | _dw, 351 | _db, 352 | dresidual, 353 | dresidual_in, 354 | mean, 355 | rstd, 356 | x.stride(0), 357 | 0 if not recompute_output else y.stride(0), 358 | dy.stride(0), 359 | dx.stride(0), 360 | dresidual.stride(0) if dresidual is not None else 0, 361 | dresidual_in.stride(0) if dresidual_in is not None else 0, 362 | M, 363 | N, 364 | eps, 365 | rows_per_program, 366 | is_rms_norm, 367 | BLOCK_N, 368 | dresidual is not None, 369 | dresidual_in is not None, 370 | bias is not None, 371 | ) 372 | dw = _dw.sum(0).to(weight.dtype) 373 | db = _db.sum(0).to(bias.dtype) if bias is not None else None 374 | # Don't need to compute dresidual_in separately in this case 375 | if has_residual and dx.dtype == x.dtype: 376 | dresidual_in = dx 377 | return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) 378 | 379 | 380 | class LayerNormFn(torch.autograd.Function): 381 | @staticmethod 382 | def forward( 383 | ctx, 384 | x, 385 | weight, 386 | bias, 387 | residual=None, 388 | eps=1e-6, 389 | prenorm=False, 390 | residual_in_fp32=False, 391 | is_rms_norm=False, 392 | ): 393 | x_shape_og = x.shape 394 | # reshape input data into 2D tensor 395 | x = x.reshape(-1, x.shape[-1]) 396 | if x.stride(-1) != 1: 397 | x = x.contiguous() 398 | if residual is not None: 399 | assert residual.shape == x_shape_og 400 | residual = residual.reshape(-1, 
residual.shape[-1]) 401 | if residual.stride(-1) != 1: 402 | residual = residual.contiguous() 403 | weight = weight.contiguous() 404 | if bias is not None: 405 | bias = bias.contiguous() 406 | residual_dtype = ( 407 | residual.dtype 408 | if residual is not None 409 | else (torch.float32 if residual_in_fp32 else None) 410 | ) 411 | y, mean, rstd, residual_out = _layer_norm_fwd( 412 | x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm 413 | ) 414 | ctx.save_for_backward(residual_out, weight, bias, mean, rstd) 415 | ctx.x_shape_og = x_shape_og 416 | ctx.eps = eps 417 | ctx.is_rms_norm = is_rms_norm 418 | ctx.has_residual = residual is not None 419 | ctx.prenorm = prenorm 420 | ctx.x_dtype = x.dtype 421 | y = y.reshape(x_shape_og) 422 | return y if not prenorm else (y, residual_out.reshape(x_shape_og)) 423 | 424 | @staticmethod 425 | def backward(ctx, dy, *args): 426 | x, weight, bias, mean, rstd = ctx.saved_tensors 427 | dy = dy.reshape(-1, dy.shape[-1]) 428 | if dy.stride(-1) != 1: 429 | dy = dy.contiguous() 430 | assert dy.shape == x.shape 431 | if ctx.prenorm: 432 | dresidual = args[0] 433 | dresidual = dresidual.reshape(-1, dresidual.shape[-1]) 434 | if dresidual.stride(-1) != 1: 435 | dresidual = dresidual.contiguous() 436 | assert dresidual.shape == x.shape 437 | else: 438 | dresidual = None 439 | dx, dw, db, dresidual_in = _layer_norm_bwd( 440 | dy, 441 | x, 442 | weight, 443 | bias, 444 | ctx.eps, 445 | mean, 446 | rstd, 447 | dresidual, 448 | ctx.has_residual, 449 | ctx.is_rms_norm, 450 | x_dtype=ctx.x_dtype, 451 | ) 452 | return ( 453 | dx.reshape(ctx.x_shape_og), 454 | dw, 455 | db, 456 | dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, 457 | None, 458 | None, 459 | None, 460 | None, 461 | ) 462 | 463 | 464 | def layer_norm_fn( 465 | x, 466 | weight, 467 | bias, 468 | residual=None, 469 | eps=1e-6, 470 | prenorm=False, 471 | residual_in_fp32=False, 472 | is_rms_norm=False, 473 | ): 474 | return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) 475 | 476 | 477 | def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): 478 | return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) 479 | 480 | 481 | class RMSNorm(torch.nn.Module): 482 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 483 | factory_kwargs = {"device": device, "dtype": dtype} 484 | super().__init__() 485 | self.eps = eps 486 | self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) 487 | self.register_parameter("bias", None) 488 | self.reset_parameters() 489 | 490 | def reset_parameters(self): 491 | torch.nn.init.ones_(self.weight) 492 | 493 | def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): 494 | return rms_norm_fn( 495 | x, 496 | self.weight, 497 | self.bias, 498 | residual=residual, 499 | eps=self.eps, 500 | prenorm=prenorm, 501 | residual_in_fp32=residual_in_fp32, 502 | ) 503 | 504 | 505 | class LayerNormLinearFn(torch.autograd.Function): 506 | @staticmethod 507 | @custom_fwd 508 | def forward( 509 | ctx, 510 | x, 511 | norm_weight, 512 | norm_bias, 513 | linear_weight, 514 | linear_bias, 515 | residual=None, 516 | eps=1e-6, 517 | prenorm=False, 518 | residual_in_fp32=False, 519 | is_rms_norm=False, 520 | ): 521 | x_shape_og = x.shape 522 | # reshape input data into 2D tensor 523 | x = x.reshape(-1, x.shape[-1]) 524 | if x.stride(-1) != 1: 525 | x = x.contiguous() 526 | if residual 
is not None: 527 | assert residual.shape == x_shape_og 528 | residual = residual.reshape(-1, residual.shape[-1]) 529 | if residual.stride(-1) != 1: 530 | residual = residual.contiguous() 531 | norm_weight = norm_weight.contiguous() 532 | if norm_bias is not None: 533 | norm_bias = norm_bias.contiguous() 534 | residual_dtype = ( 535 | residual.dtype 536 | if residual is not None 537 | else (torch.float32 if residual_in_fp32 else None) 538 | ) 539 | y, mean, rstd, residual_out = _layer_norm_fwd( 540 | x, 541 | norm_weight, 542 | norm_bias, 543 | eps, 544 | residual, 545 | out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), 546 | residual_dtype=residual_dtype, 547 | is_rms_norm=is_rms_norm, 548 | ) 549 | y = y.reshape(x_shape_og) 550 | dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype 551 | linear_weight = linear_weight.to(dtype) 552 | linear_bias = linear_bias.to(dtype) if linear_bias is not None else None 553 | out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) 554 | # We don't store y, will be recomputed in the backward pass to save memory 555 | ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) 556 | ctx.x_shape_og = x_shape_og 557 | ctx.eps = eps 558 | ctx.is_rms_norm = is_rms_norm 559 | ctx.has_residual = residual is not None 560 | ctx.prenorm = prenorm 561 | ctx.x_dtype = x.dtype 562 | ctx.linear_bias_is_none = linear_bias is None 563 | return out if not prenorm else (out, residual_out.reshape(x_shape_og)) 564 | 565 | @staticmethod 566 | @custom_bwd 567 | def backward(ctx, dout, *args): 568 | x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors 569 | dout = dout.reshape(-1, dout.shape[-1]) 570 | dy = F.linear(dout, linear_weight.t()) 571 | dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) 572 | if dy.stride(-1) != 1: 573 | dy = dy.contiguous() 574 | assert dy.shape == x.shape 575 | if ctx.prenorm: 576 | dresidual = args[0] 577 | dresidual = dresidual.reshape(-1, dresidual.shape[-1]) 578 | if dresidual.stride(-1) != 1: 579 | dresidual = dresidual.contiguous() 580 | assert dresidual.shape == x.shape 581 | else: 582 | dresidual = None 583 | dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( 584 | dy, 585 | x, 586 | norm_weight, 587 | norm_bias, 588 | ctx.eps, 589 | mean, 590 | rstd, 591 | dresidual, 592 | ctx.has_residual, 593 | ctx.is_rms_norm, 594 | x_dtype=ctx.x_dtype, 595 | recompute_output=True, 596 | ) 597 | dlinear_weight = torch.einsum("bo,bi->oi", dout, y) 598 | return ( 599 | dx.reshape(ctx.x_shape_og), 600 | dnorm_weight, 601 | dnorm_bias, 602 | dlinear_weight, 603 | dlinear_bias, 604 | dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, 605 | None, 606 | None, 607 | None, 608 | None, 609 | ) 610 | 611 | 612 | def layer_norm_linear_fn( 613 | x, 614 | norm_weight, 615 | norm_bias, 616 | linear_weight, 617 | linear_bias, 618 | residual=None, 619 | eps=1e-6, 620 | prenorm=False, 621 | residual_in_fp32=False, 622 | is_rms_norm=False, 623 | ): 624 | return LayerNormLinearFn.apply( 625 | x, 626 | norm_weight, 627 | norm_bias, 628 | linear_weight, 629 | linear_bias, 630 | residual, 631 | eps, 632 | prenorm, 633 | residual_in_fp32, 634 | is_rms_norm, 635 | ) 636 | -------------------------------------------------------------------------------- /modeling_rodimus.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from typing import 
List, Optional, Tuple, Union 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange, repeat 11 | 12 | from transformers.modeling_utils import PreTrainedModel 13 | from transformers.modeling_outputs import ( 14 | BaseModelOutputWithPast, 15 | CausalLMOutputWithPast, 16 | BaseModelOutput, 17 | MaskedLMOutput 18 | ) 19 | from transformers.cache_utils import Cache 20 | 21 | from modules.cache import HybridCache 22 | from modules.rodimus_flow import RodimusFlow 23 | from modules.rodimus_attention import SlideWindowSharedKeyAttention 24 | from modules.mlp import GLU 25 | 26 | from ops.layernorm import RMSNorm 27 | 28 | from configuration_rodimus import RodimusConfig 29 | 30 | import logging 31 | logger = logging.getLogger(__name__) 32 | 33 | try: 34 | from fla.modules import FusedCrossEntropyLoss 35 | except ImportError: 36 | FusedCrossEntropyLoss = None 37 | 38 | 39 | def _apply_no_weight_decay_on_norm(module): 40 | from ops.layernorm_gated import RMSNorm as RMSNormWithGate 41 | from ops.layernorm_gated import LayerNorm as LayerNormWithGate 42 | 43 | if isinstance(module, RMSNorm) or isinstance(module, RMSNormWithGate): 44 | module.weight._no_weight_decay = True 45 | elif isinstance(module, nn.LayerNorm) or isinstance(module, LayerNormWithGate): 46 | module.weight._no_weight_decay = True 47 | module.bias._no_weight_decay = True 48 | 49 | 50 | def _apply_no_weight_decay_on_embedding(module, lm_head_param=None): 51 | if isinstance(module, nn.Embedding): 52 | if lm_head_param is not None: 53 | if lm_head_param.weight != module.weight: 54 | module.weight._no_weight_decay = True 55 | else: 56 | logger.warning_once( 57 | "Unable to find the lm_head, forcibly set embedding's weight decay to 0.0") 58 | module.weight._no_weight_decay = True 59 | 60 | 61 | def _set_no_weight_decay( 62 | module: nn.Module, 63 | no_weight_decay_on_bias=True, 64 | no_weight_decay_on_norm=True, 65 | no_weight_decay_on_embedding=False, 66 | ): 67 | if no_weight_decay_on_bias: 68 | for n, p in module.named_parameters(): 69 | if n.endswith("bias") and p is not None: 70 | p._no_weight_decay = True 71 | 72 | if no_weight_decay_on_norm: 73 | module.apply(_apply_no_weight_decay_on_norm) 74 | 75 | if no_weight_decay_on_embedding: 76 | lm_head_param = None 77 | for n, p in module.named_parameters(): 78 | if n.endswith("lm_head"): 79 | lm_head_param = p 80 | break 81 | module.apply(partial(_apply_no_weight_decay_on_embedding, 82 | lm_head_param=lm_head_param)) 83 | 84 | 85 | def _init_weights( 86 | module: nn.Module, 87 | initializer_range: float = 0.02, 88 | rescale_prenorm_residual: bool = True, 89 | num_residuals_per_layer: int = 1, 90 | n_layer: int = 1, 91 | ): 92 | if isinstance(module, nn.Linear): 93 | if not getattr(module.weight, "_no_reinit", False): 94 | nn.init.normal_(module.weight, mean=0.0, std=initializer_range) 95 | if module.bias is not None: 96 | if not getattr(module.bias, "_no_reinit", False): 97 | nn.init.zeros_(module.bias) 98 | 99 | elif isinstance(module, nn.Embedding): 100 | if not getattr(module.weight, "_no_reinit", False): 101 | nn.init.normal_(module.weight, mean=0.0, std=initializer_range) 102 | if module.padding_idx is not None: 103 | module.weight.data[module.padding_idx].zero_() 104 | 105 | if rescale_prenorm_residual: 106 | for name, p in module.named_parameters(): 107 | if name in ["out_proj.weight"]: 108 | with torch.no_grad(): 109 | p /= math.sqrt(num_residuals_per_layer * n_layer) 110 | 111 | 112 | 
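# Note on the helpers above: `_set_no_weight_decay` tags biases and norm weights with a
# `_no_weight_decay` attribute so that an optimizer builder can exclude them from weight
# decay, and `_init_weights` follows the GPT-2/Mamba-style scheme of shrinking residual
# output projections ("out_proj.weight") by 1/sqrt(num_residuals_per_layer * n_layer) so
# that the variance of the residual stream stays roughly constant with depth.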
class RodimusTrainedModel(PreTrainedModel): 113 | config_class = RodimusConfig 114 | supports_gradient_checkpointing = True 115 | _no_split_modules = ["RodimusBlock"] 116 | 117 | def __init__(self, *inputs, **kwargs): 118 | super().__init__(*inputs, **kwargs) 119 | if self.config.block_type == "rodimus": 120 | self.num_residuals_per_layer = 1 121 | elif self.config.block_type == "rodimus_plus": 122 | self.num_residuals_per_layer = 3 123 | else: 124 | raise NotImplementedError() 125 | 126 | def _init_weights( 127 | self, 128 | module: nn.Module, 129 | ): 130 | _init_weights( 131 | module, 132 | initializer_range=self.config.initializer_range, 133 | rescale_prenorm_residual=self.config.rescale_prenorm_residual, 134 | num_residuals_per_layer=self.num_residuals_per_layer, 135 | n_layer=self.config.n_layer 136 | ) 137 | 138 | 139 | class RodimusBlock(nn.Module): 140 | def __init__( 141 | self, 142 | block_type, 143 | d_model, 144 | max_position_embeddings=None, 145 | mixer_cfg={}, 146 | attn_cfg={}, 147 | norm_epsilon=1e-5, 148 | residual_in_fp32=True, 149 | use_fast_path=True, 150 | use_fused_swiglu=True, 151 | layer_idx=None, 152 | causal=True, 153 | dropout=0., 154 | activation_dropout=0., 155 | attention_dropout=0., 156 | ): 157 | super().__init__() 158 | self.block_type = block_type 159 | self.d_model = d_model 160 | self.norm_epsilon = norm_epsilon 161 | self.residual_in_fp32 = residual_in_fp32 162 | self.use_fast_path = use_fast_path 163 | self.use_fused_swiglu = use_fused_swiglu 164 | self.causal = causal 165 | 166 | attn_cfg = attn_cfg.copy() 167 | mixer_cfg = mixer_cfg.copy() 168 | 169 | self.mixer_norm = RMSNorm(self.d_model, eps=self.norm_epsilon) 170 | self.mixer = RodimusFlow( 171 | d_model, layer_idx=layer_idx, **mixer_cfg, 172 | use_fast_path=use_fast_path, residual_in_fp32=residual_in_fp32, 173 | causal=self.causal, 174 | dropout=dropout, 175 | activation_dropout=activation_dropout, 176 | norm_epsilon=self.norm_epsilon, 177 | ) 178 | 179 | if self.block_type == "rodimus_plus": 180 | attn_cfg["num_heads"] = d_model // 128 if "num_heads" not in attn_cfg or attn_cfg["num_heads"] is None else attn_cfg["num_heads"] 181 | ffn_expand_ratio = attn_cfg.pop("ffn_expand_ratio", 4/3) 182 | 183 | self.attn_norm = RMSNorm(self.d_model, eps=self.norm_epsilon) 184 | self.attn = SlideWindowSharedKeyAttention( 185 | dim=d_model, 186 | **attn_cfg, 187 | layer_idx=layer_idx, 188 | causal=self.causal, 189 | dropout=dropout, 190 | activation_dropout=activation_dropout, 191 | attention_dropout=attention_dropout, 192 | max_position_embeddings=max_position_embeddings, 193 | ) 194 | 195 | self.ffn_norm = RMSNorm(self.d_model, eps=self.norm_epsilon) 196 | self.ffn = GLU( 197 | d_model, ffn_expand_ratio, 198 | use_fast_path=use_fused_swiglu, 199 | dropout=dropout, 200 | activation_dropout=activation_dropout, 201 | ) 202 | 203 | def forward( 204 | self, 205 | hidden_states: torch.Tensor, 206 | residual: Optional[torch.Tensor] = None, 207 | attention_mask: Optional[torch.Tensor] = None, 208 | past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, 209 | use_cache: Optional[bool] = False, 210 | output_attentions: Optional[bool] = False, 211 | **kwargs, 212 | ): 213 | hidden_states, residual = self.mixer_norm( 214 | hidden_states, 215 | residual=residual, 216 | prenorm=True, 217 | residual_in_fp32=self.residual_in_fp32 218 | ) 219 | hidden_states, past_key_values = self.mixer( 220 | hidden_states=hidden_states, 221 | attention_mask=attention_mask, 222 | past_key_values=past_key_values, 223 | 
use_cache=use_cache, 224 | output_attentions=output_attentions, 225 | ) 226 | 227 | if self.block_type == "rodimus_plus": 228 | hidden_states, residual = self.attn_norm( 229 | hidden_states, 230 | residual=residual, 231 | prenorm=True, 232 | residual_in_fp32=self.residual_in_fp32 233 | ) 234 | 235 | hidden_states, past_key_values = self.attn( 236 | hidden_states=hidden_states, 237 | attention_mask=attention_mask, 238 | past_key_values=past_key_values, 239 | use_cache=use_cache, 240 | output_attentions=output_attentions, 241 | ) 242 | 243 | hidden_states = self.ffn_norm( 244 | hidden_states, 245 | residual=residual, 246 | prenorm=False, 247 | residual_in_fp32=self.residual_in_fp32 248 | ) 249 | hidden_states = self.ffn(hidden_states) 250 | 251 | return hidden_states, residual, past_key_values 252 | 253 | 254 | class RodimusModel(RodimusTrainedModel): 255 | def __init__( 256 | self, 257 | config: RodimusConfig, 258 | causal=True, 259 | ): 260 | super().__init__(config) 261 | self.config = config 262 | self.d_model = config.d_model 263 | self.n_layer = config.n_layer 264 | self.vocab_size = config.vocab_size 265 | self.padding_idx = config.pad_token_id 266 | self.norm_epsilon = config.norm_epsilon 267 | self.residual_in_fp32 = config.residual_in_fp32 268 | self.use_fast_path = config.use_fast_path 269 | self.use_fused_swiglu = config.use_fused_swiglu 270 | self.causal = causal 271 | self.max_position_embeddings = config.max_position_embeddings 272 | 273 | self.block_type = config.block_type 274 | 275 | self.embeddings = nn.Embedding( 276 | self.vocab_size, self.d_model, padding_idx=self.padding_idx) 277 | 278 | if self.config.use_scale_embedding: 279 | mem_size = self.config.mixer_cfg['mem_size'] if 'mem_size' in self.config.mixer_cfg else 64 280 | self.embed_scale = math.sqrt(mem_size) 281 | else: 282 | self.embed_scale = 1.
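# Note: when `use_scale_embedding` is enabled, `forward` multiplies the token embeddings by
# `embed_scale = sqrt(mem_size)` (sqrt(64) = 8 for the default memory size) before the first block;
# `use_norm_embedding` below additionally applies an RMSNorm on top of the (scaled) embeddings.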
283 | 284 | if self.config.use_norm_embedding: 285 | self.embed_norm = RMSNorm(self.d_model, eps=self.norm_epsilon) 286 | else: 287 | self.embed_norm = None 288 | 289 | self.layers = nn.ModuleList([]) 290 | 291 | for i in range(self.n_layer): 292 | block = RodimusBlock( 293 | self.config.block_type, 294 | self.d_model, 295 | layer_idx=i, 296 | max_position_embeddings=self.max_position_embeddings, 297 | mixer_cfg=self.config.mixer_cfg, 298 | attn_cfg=self.config.attn_cfg, 299 | norm_epsilon=self.norm_epsilon, 300 | residual_in_fp32=self.residual_in_fp32, 301 | use_fast_path=self.use_fast_path, 302 | use_fused_swiglu=self.use_fused_swiglu, 303 | causal=self.causal, 304 | dropout=self.config.dropout, 305 | activation_dropout=self.config.activation_dropout, 306 | attention_dropout=self.config.attention_dropout, 307 | ) 308 | 309 | self.layers.append(block) 310 | 311 | self.norm_f = RMSNorm(self.d_model, eps=self.norm_epsilon) 312 | 313 | self.has_ssm = hasattr(self.layers[0], "mixer") 314 | self.has_attn = hasattr(self.layers[0], "attn") 315 | assert self.has_ssm or self.has_attn 316 | 317 | _set_no_weight_decay( 318 | self, 319 | no_weight_decay_on_bias=self.config.no_weight_decay_on_bias, 320 | no_weight_decay_on_norm=self.config.no_weight_decay_on_norm, 321 | no_weight_decay_on_embedding=False, # do this at `RodimusForCausalLM` 322 | ) 323 | 324 | self.gradient_checkpointing = False 325 | 326 | self.post_init() 327 | 328 | def get_input_embeddings(self): 329 | return self.embeddings 330 | 331 | def set_input_embeddings(self, value): 332 | self.embeddings = value 333 | 334 | def forward( 335 | self, 336 | input_ids: Optional[torch.LongTensor] = None, 337 | attention_mask: Optional[torch.Tensor] = None, # noqa 338 | inputs_embeds: Optional[torch.FloatTensor] = None, 339 | past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, 340 | use_cache: Optional[bool] = None, 341 | output_attentions: Optional[bool] = None, 342 | output_hidden_states: Optional[bool] = None, 343 | return_dict: Optional[bool] = None 344 | ): 345 | if output_attentions: 346 | warnings.warn( 347 | "`Model` does not `output_attentions` now, setting it to `False`.") 348 | output_attentions = False 349 | 350 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 351 | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 352 | use_cache = use_cache if use_cache is not None else ( 353 | self.config.use_cache if not self.training else False) 354 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 355 | if attention_mask is not None: 356 | attention_mask = attention_mask.to(torch.bool) 357 | if attention_mask.dtype == torch.bool: 358 | attention_mask = attention_mask if False in attention_mask else None 359 | else: 360 | attention_mask = attention_mask if 0.0 in attention_mask else None 361 | else: 362 | attention_mask = None 363 | 364 | if input_ids is not None and inputs_embeds is not None: 365 | raise ValueError( 366 | "You cannot specify both input_ids and inputs_embeds at the same time") 367 | 368 | if inputs_embeds is None: 369 | inputs_embeds = self.embeddings(input_ids) 370 | inputs_embeds *= self.embed_scale 371 | 372 | if self.embed_norm is not None: 373 | inputs_embeds = self.embed_norm(inputs_embeds) 374 | 375 | hidden_states = inputs_embeds 376 | 377 | if self.gradient_checkpointing and self.training: 378 | if use_cache: 379 | logger.warning_once( 380 | 
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 381 | ) 382 | use_cache = False 383 | 384 | if use_cache: 385 | if past_key_values is None: # init states 386 | past_key_values = [] 387 | for layer in self.layers: 388 | cache = () 389 | if self.has_ssm: 390 | cache += layer.mixer.allocate_inference_cache( 391 | hidden_states.size(0)) 392 | if self.has_attn: 393 | cache += layer.attn.allocate_inference_cache( 394 | hidden_states.size(0)) 395 | past_key_values.append(cache) 396 | if not isinstance(past_key_values, HybridCache): 397 | past_key_values = HybridCache.from_legacy_cache( 398 | past_key_values=past_key_values, 399 | seen_tokens=0, 400 | has_ssm=self.has_ssm, 401 | has_attn=self.has_attn, 402 | ) 403 | else: 404 | past_key_values = None 405 | 406 | all_hidden_states = () if output_hidden_states else None 407 | all_attns = () if output_attentions else None 408 | residual = None 409 | for layer in self.layers: 410 | if output_hidden_states: 411 | all_hidden_states += (hidden_states, ) 412 | 413 | if self.gradient_checkpointing and self.training: 414 | hidden_states, residual, past_key_values = self._gradient_checkpointing_func( 415 | layer.__call__, 416 | hidden_states, 417 | residual, 418 | attention_mask, 419 | past_key_values, 420 | use_cache, 421 | output_attentions, 422 | ) 423 | else: 424 | hidden_states, residual, past_key_values = layer( 425 | hidden_states=hidden_states, 426 | residual=residual, 427 | attention_mask=attention_mask, 428 | past_key_values=past_key_values, 429 | use_cache=use_cache, 430 | output_attentions=output_attentions, 431 | ) 432 | 433 | hidden_states = self.norm_f( 434 | hidden_states, 435 | residual=residual, 436 | prenorm=False, 437 | residual_in_fp32=self.residual_in_fp32 438 | ) 439 | 440 | if output_hidden_states: 441 | all_hidden_states += (hidden_states,) 442 | 443 | next_cache = None 444 | if use_cache: 445 | next_cache = past_key_values.to_legacy_cache() 446 | if not return_dict: 447 | return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None) 448 | 449 | if self.causal: 450 | return BaseModelOutputWithPast( 451 | last_hidden_state=hidden_states, 452 | past_key_values=next_cache, 453 | hidden_states=all_hidden_states, 454 | attentions=all_attns 455 | ) 456 | else: 457 | return BaseModelOutput( 458 | last_hidden_state=hidden_states, 459 | hidden_states=all_hidden_states, 460 | attentions=all_attns 461 | ) 462 | 463 | 464 | class RodimusForCausalLM(RodimusTrainedModel): 465 | _tied_weights_keys = ["lm_head.weight"] 466 | 467 | def __init__( 468 | self, 469 | config: RodimusConfig 470 | ): 471 | super().__init__(config) 472 | self.config = config 473 | 474 | self.model = RodimusModel(config, causal=True) 475 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 476 | 477 | _set_no_weight_decay( 478 | self, 479 | no_weight_decay_on_bias=False, 480 | no_weight_decay_on_norm=False, 481 | no_weight_decay_on_embedding=( 482 | not self.config.tie_word_embeddings) and self.config.no_weight_decay_on_embedding, 483 | ) 484 | 485 | self.post_init() 486 | 487 | def get_input_embeddings(self): 488 | return self.model.embeddings 489 | 490 | def set_input_embeddings(self, value): 491 | self.model.embeddings = value 492 | 493 | def get_output_embeddings(self): 494 | return self.lm_head 495 | 496 | def set_output_embeddings(self, new_embeddings): 497 | self.lm_head = new_embeddings 498 | 499 | def generate(self, *args, **kwargs): 500 | try: 501 | return 
super().generate(*args, **kwargs) 502 | except AttributeError as exception: 503 | if 'past_key_values' in str(exception): 504 | raise AttributeError( 505 | f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " 506 | f"which is not supported for {self.__class__.__name__}. " 507 | f"Try another generation strategy instead. " 508 | f"For the available generation strategies, check this doc: " 509 | f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" 510 | ) 511 | else: 512 | raise exception 513 | 514 | def prepare_inputs_for_generation( 515 | self, 516 | input_ids: torch.LongTensor = None, 517 | past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, 518 | attention_mask: Optional[torch.Tensor] = None, 519 | inputs_embeds: Optional[torch.Tensor] = None, 520 | **kwargs 521 | ): 522 | # only last token for `inputs_ids` if the `past_key_values` is passed along. 523 | if past_key_values is not None: 524 | if not isinstance(past_key_values, HybridCache): 525 | past_key_values = HybridCache.from_legacy_cache( 526 | past_key_values=past_key_values, 527 | seen_tokens=input_ids.shape[1] - 1, 528 | has_ssm=self.model.has_ssm, 529 | has_attn=self.model.has_attn, 530 | ) 531 | # input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] 532 | input_ids = input_ids[:, -1:] 533 | 534 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 535 | if inputs_embeds is not None and past_key_values is None: 536 | model_inputs = {'inputs_embeds': inputs_embeds} 537 | else: 538 | # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise 539 | # recompiles graphs as the stride of the inputs is a guard. 540 | # Ref: https://github.com/huggingface/transformers/pull/29114 541 | # TODO: use `next_tokens` directly instead. 
542 | model_inputs = {'input_ids': input_ids.contiguous()} 543 | 544 | model_inputs.update({ 545 | 'past_key_values': past_key_values, 546 | 'use_cache': kwargs.get('use_cache'), 547 | 'attention_mask': attention_mask, 548 | }) 549 | return model_inputs 550 | 551 | def forward( 552 | self, 553 | input_ids: torch.LongTensor = None, 554 | attention_mask: Optional[torch.Tensor] = None, 555 | inputs_embeds: Optional[torch.Tensor] = None, 556 | past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, 557 | labels: Optional[torch.LongTensor] = None, 558 | use_cache: Optional[bool] = None, 559 | output_attentions: Optional[bool] = None, 560 | output_hidden_states: Optional[bool] = None, 561 | return_dict: Optional[bool] = None, 562 | **kwargs, 563 | ) -> Union[Tuple, CausalLMOutputWithPast]: 564 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 565 | output_hidden_states = ( 566 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 567 | ) 568 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 569 | 570 | outputs = self.model( 571 | input_ids=input_ids, 572 | attention_mask=attention_mask, 573 | inputs_embeds=inputs_embeds, 574 | past_key_values=past_key_values, 575 | use_cache=use_cache, 576 | output_attentions=output_attentions, 577 | output_hidden_states=output_hidden_states, 578 | return_dict=return_dict 579 | ) 580 | hidden_states = outputs[0] 581 | 582 | if self.lm_head is not None: 583 | logits = self.lm_head(hidden_states) 584 | else: 585 | logits = hidden_states 586 | logits = logits.float() 587 | 588 | loss = None 589 | if labels is not None and logits is not None: 590 | # Shift so that tokens < n predict n 591 | shift_logits = logits[..., :-1, :].contiguous() 592 | shift_labels = labels[..., 1:].contiguous() 593 | 594 | if FusedCrossEntropyLoss is not None and self.config.use_fused_cross_entropy: 595 | loss_fct = FusedCrossEntropyLoss(inplace_backward=True) 596 | else: 597 | loss_fct = nn.CrossEntropyLoss() 598 | 599 | # Flatten the tokens 600 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 601 | shift_labels = shift_labels.view(-1) 602 | # Enable model parallelism 603 | shift_labels = shift_labels.to(shift_logits.device) 604 | loss = loss_fct(shift_logits, shift_labels) 605 | 606 | if not return_dict: 607 | output = (logits,) + outputs[1:] 608 | return (loss,) + output if loss is not None else output 609 | 610 | # output 611 | return CausalLMOutputWithPast( 612 | loss=loss, 613 | logits=logits, 614 | past_key_values=outputs.past_key_values, 615 | hidden_states=outputs.hidden_states, 616 | attentions=outputs.attentions, 617 | ) 618 | 619 | --------------------------------------------------------------------------------
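Note on optimizer setup: `_set_no_weight_decay` only tags parameters with a `_no_weight_decay` attribute (biases, norm weights and, when the embeddings are untied from the `lm_head`, the embedding matrix); the training script is expected to read that flag when building its optimizer parameter groups. The sketch below shows one way this could be wired up; `build_param_groups`, the weight-decay value and the learning rate are illustrative assumptions, not part of this repository.

def build_param_groups(model, weight_decay=0.1):
    # Split parameters by the `_no_weight_decay` flag set in modeling_rodimus.py.
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# `model` is assumed to be an already-constructed `RodimusForCausalLM`:
# optimizer = torch.optim.AdamW(build_param_groups(model), lr=3e-4)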