├── requirements.txt
├── README.md
└── main.py

/requirements.txt:
--------------------------------------------------------------------------------
torch
pytorch-lightning
einops
rotary-embedding-torch
datasets
tokenizers
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Retentive Network
A PyTorch implementation of [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf).

## References
```
@misc{sun2023retentive,
    title={Retentive Network: A Successor to Transformer for Large Language Models},
    author={Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei},
    year={2023},
    eprint={2307.08621},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import pytorch_lightning as pl

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader

from einops import rearrange, repeat
from rotary_embedding_torch import RotaryEmbedding

from datasets import load_dataset

from typing import Tuple, Optional

from itertools import chain

from tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers


class MSR(nn.Module):
    """Multi-Scale Retention (MSR) layer."""

    def __init__(self,
                 hidden_dim: int,
                 head_dim: int,
                 ):
        super().__init__()

        self.n_heads = n_heads = hidden_dim // head_dim
        self.head_dim = head_dim

        self.gn = nn.GroupNorm(n_heads, n_heads)
        self.act = nn.Mish()  # gate activation (the paper uses swish/SiLU)

        # Per-head decay rates: gamma_i = 1 - 2^(-5 - i)
        self.register_buffer('gamma', 1 - torch.pow(2.0, -5.0 - torch.arange(0, n_heads)))

        self.wg = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.wqkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.pos_emb = RotaryEmbedding(head_dim // 2, use_xpos=True)

    def forward(self, x: torch.Tensor, hidden_state: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: input tensor of shape [bs, seqlen, hidden_dim]
        hidden_state: recurrent state from the previous step. Initialized by the
            layer itself when it is None and the module is in eval mode.
        """
        bs, seqlen, _ = x.shape

        qkv: torch.Tensor = self.wqkv(x)  # [bs, seqlen, 3 x hidden_dim]
        qkv = rearrange(qkv, 'B S (H D) -> B H S D', H=self.n_heads)
        q, k, v = qkv.chunk(3, dim=-1)  # 3 x [bs, heads, seqlen, head_dim]

        # Eq. (5)
        # Apply the xPos embedding on Q/K
        q, k = self.pos_emb.rotate_queries_and_keys(q, k)
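
        # Parallel-form decay mask (Eq. (5)): D[n, m] = gamma^(n - m) when n >= m,
        # and 0 otherwise, built per head from the gammas registered above.
        # Illustrative example (gamma = 0.5 chosen only for readability; the real
        # per-head gammas are 1 - 2^(-5 - i)), with seqlen = 3:
        #   D = [[1.  , 0.  , 0.  ],
        #        [0.5 , 1.  , 0.  ],
        #        [0.25, 0.5 , 1.  ]]
        # so each position attends only to itself and earlier positions, with an
        # exponentially decaying weight.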
        if self.training:
            # Eq. (5): parallel representation
            # D[n, m] = gamma^(n - m) if n >= m else 0
            # TODO: implement retention score normalization
            nm_index = torch.arange(1, seqlen + 1, device=x.device)
            nm = repeat(nm_index, 'W -> W H', H=seqlen) - repeat(nm_index, 'W -> H W', H=seqlen)
            # Clamp before pow so gamma**(negative) cannot overflow; those entries
            # are zeroed out by the causal mask anyway.
            decay_mask = torch.pow(self.gamma.view(-1, 1, 1), nm.clamp(min=0)) * (nm >= 0).int()

            ret: torch.Tensor = q @ k.transpose(-1, -2)
            ret = ret * decay_mask
            ret = ret @ v
            ret = self.gn(ret)
        else:
            # Eq. (6): recurrent representation. Note that this form matches the
            # parallel one only when tokens are fed one step at a time (seqlen == 1).
            if hidden_state is None:
                hidden_state = torch.zeros(bs, self.n_heads, q.shape[-1], v.shape[-1],
                                           device=x.device, dtype=x.dtype)

            hidden_state = self.gamma.view(1, -1, 1, 1) * hidden_state + k.transpose(-1, -2) @ v
            ret = q @ hidden_state
            # GroupNorm belongs to Eq. (8) and is applied in both branches.
            ret = self.gn(ret)

        # Eq. (8)
        y = rearrange(ret, 'B H S D -> B S (H D)')
        y = self.act(self.wg(x)) * y
        y = self.wo(y)

        return y, hidden_state


class FFN(nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act(self.w1(x)))


class RetNetBlock(nn.Module):
    def __init__(self, hidden_dim: int,
                 head_dim: int):
        super().__init__()

        self.ln1 = nn.LayerNorm(hidden_dim)
        self.msr = MSR(hidden_dim=hidden_dim, head_dim=head_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.ffn = FFN(hidden_dim)

    def forward(self, x: torch.Tensor, hidden_state: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Eq. (9)
        y, hidden_state = self.msr(self.ln1(x), hidden_state)
        y = y + x
        return self.ffn(self.ln2(y)) + y, hidden_state


class RetentionNetwork(nn.Module):
    def __init__(self,
                 vocab_size: int,
                 hidden_dim: int,
                 n_layers: int,
                 head_dim: int = 256,
                 ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.layers = nn.ModuleList([RetNetBlock(hidden_dim=hidden_dim, head_dim=head_dim) for _ in range(n_layers)])

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        x = self.emb(x)

        # Each layer keeps its own recurrent state (only used in eval mode);
        # feeding one layer's state into the next would mix unrelated statistics.
        hidden_states = [None] * len(self.layers)
        for i, layer in enumerate(self.layers):
            x, hidden_states[i] = layer(x, hidden_states[i])

        return x
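
# Shape walk-through for RetentionNetwork, using the same configuration as the
# __main__ block below (3000-token WordPiece vocabulary, hidden_dim=256,
# n_layers=3, head_dim=32, seqlen=512; the batch of 2 is illustrative):
#   tokens = torch.randint(0, 3000, (2, 512))            # [bs, seqlen]
#   feats = RetentionNetwork(3000, 256, 3, 32)(tokens)    # -> [2, 512, 256]
# The classifier below flattens this to [bs, seqlen * hidden_dim] before the
# linear head, so inputs must be padded/truncated to exactly max_seqlen tokens.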

class RetNetClassification(pl.LightningModule):
    def __init__(self,
                 vocab_size: int,
                 hidden_dim: int,
                 n_layers: int,
                 head_dim: int = 256,
                 max_seqlen: int = 512,
                 num_classes: int = 2,
                 *,
                 lr=0.001,
                 betas=(0.9, 0.98),
                 weight_decay=0.05
                 ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_seqlen = max_seqlen

        self.retnet = RetentionNetwork(vocab_size=vocab_size, hidden_dim=hidden_dim, n_layers=n_layers, head_dim=head_dim)
        self.head = nn.Linear(max_seqlen * hidden_dim, num_classes)

        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        y = self.retnet(x)
        y = self.head(y.view(-1, self.max_seqlen * self.hidden_dim))
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally,
        # so an explicit softmax here would distort the loss.
        return y

    def configure_optimizers(self):
        optim = torch.optim.AdamW(chain(self.retnet.parameters(), self.head.parameters()),
                                  lr=self.lr,
                                  betas=self.betas,
                                  weight_decay=self.weight_decay
                                  )
        return optim

    def training_step(self, batch, batch_idx):
        x, labels = batch
        y = self.forward(x)

        loss = self.loss(y, labels)
        self.log_dict({'train_loss': loss})

        return loss

    def test_step(self, batch, batch_idx):
        x, labels = batch
        y = self.forward(x)

        loss = self.loss(y, labels)

        self.log_dict({'test_loss': loss})


class IMDBDataModule(pl.LightningDataModule):
    def __init__(self,
                 tokenizer: Tokenizer,
                 train_batch_size: int = 32,
                 test_batch_size: int = 4,
                 predict_batch_size: int = 4,
                 max_len: int = 512,
                 ):
        super().__init__()
        self.name = 'imdb'
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.predict_batch_size = predict_batch_size
        self.max_len = max_len

        self.tokenizer = tokenizer

    def prepare_data(self):
        load_dataset(self.name)

    def train_tokenizer(self, trainer: trainers.Trainer):
        splits = load_dataset(self.name)

        # Each call to train_from_iterator re-trains the tokenizer from scratch,
        # so the text of every split is fed through a single iterator.
        def batch_iterator(batch_size=1000):
            for split, dataset in splits.items():
                print(f'Training tokenizer on split {split}')
                for i in range(0, len(dataset), batch_size):
                    yield dataset[i:i + batch_size]["text"]

        self.tokenizer.train_from_iterator(batch_iterator(), trainer,
                                           length=sum(len(d) for d in splits.values()))

    def tokenize(self, item):
        return {
            'tokens': [o.ids for o in self.tokenizer.encode_batch(item['text'])]
        }

    def setup(self, stage: str):
        if stage == 'fit':
            self.train_dataset = load_dataset(self.name, split='train')
            self.train_dataset.set_format(type='torch')
            self.train_dataset = self.train_dataset.map(lambda e: self.tokenize(e), batched=True)

        elif stage == 'test':
            self.test_dataset = load_dataset(self.name, split='test')
            self.test_dataset.set_format(type='torch')
            self.test_dataset = self.test_dataset.map(lambda e: self.tokenize(e), batched=True)

        elif stage == 'predict':
            self.predict_dataset = load_dataset(self.name, split='unsupervised')
            self.predict_dataset.set_format(type='torch')
            self.predict_dataset = self.predict_dataset.map(lambda e: self.tokenize(e), batched=True)

    def collate(self, batch):
        # Pad with token id 0 (or truncate) every example to exactly max_len tokens.
        bs = len(batch)
        batched_ids = torch.zeros(bs, self.max_len, dtype=torch.long)
        batched_labels = torch.zeros(bs, dtype=torch.long)

        for idx, item in enumerate(batch):
            tokens, label = item['tokens'], item['label']
            l = min(len(tokens), self.max_len)
            batched_ids[idx, :l] = torch.as_tensor(tokens[:l], dtype=torch.long)
            batched_labels[idx] = label
        return batched_ids, batched_labels
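
    # Illustrative collate output (hypothetical token ids; max_len shortened to 6
    # instead of the default 512 for readability):
    #   batch = [{'tokens': [12, 7, 93], 'label': 1}]
    #   -> batched_ids    = [[12, 7, 93, 0, 0, 0]]   (zero-padded, truncated at max_len)
    #   -> batched_labels = [1]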

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, collate_fn=lambda e: self.collate(e))

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.test_batch_size, collate_fn=lambda e: self.collate(e))

    def predict_dataloader(self):
        return DataLoader(self.predict_dataset, batch_size=self.predict_batch_size, collate_fn=lambda e: self.collate(e))


if __name__ == '__main__':

    torch.set_float32_matmul_precision('high')

    # Reuse a previously trained tokenizer if one is cached on disk; otherwise
    # build a fresh WordPiece tokenizer and train it on the IMDB corpus below.
    try:
        tokenizer = Tokenizer.from_file('imdb.json')
        tokenizer_trainer = None
    except Exception:
        tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))

        tokenizer.decoder = decoders.WordPiece()
        tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
        tokenizer_trainer = trainers.WordPieceTrainer(vocab_size=3000,
                                                      special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
                                                      )

    datamodule = IMDBDataModule(tokenizer=tokenizer)
    datamodule.prepare_data()

    if tokenizer_trainer is not None:
        datamodule.train_tokenizer(tokenizer_trainer)
        tokenizer.save('imdb.json')

    model = RetNetClassification(vocab_size=tokenizer.get_vocab_size(), hidden_dim=256, n_layers=3, head_dim=32, max_seqlen=512)
    # model = MSR(256, 32)
    # model = model.eval()
    # x = torch.zeros(1, 512, dtype=torch.long)
    # y = model(x)
    # print(x.shape, y.shape)

    trainer = pl.Trainer(max_epochs=10)

    trainer.fit(model, datamodule)

--------------------------------------------------------------------------------