├── src
    ├── __init__.py
    ├── example.py
    ├── complex
    │   ├── util.py
    │   ├── test_retnet.py
    │   ├── retnet.py
    │   ├── test_retention.py
    │   └── retention.py
    ├── retnet.py
    ├── xpos_relative_position.py
    ├── tests.py
    └── retention.py
├── CONTRIBUTING.md
├── license.md
└── readme.md

/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | import retnet
5 | 
6 | if __name__ == "__main__":
7 |     # verify model size for hyperparameters in paper
8 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 | 
10 |     # 1.3B model
11 |     layers = 24
12 |     hidden_dim = 2048
13 |     ffn_size = 4096
14 |     heads = 16
15 | 
16 |     model = retnet.RetNet(layers, hidden_dim, ffn_size, heads, double_v_dim=True).to(device)
17 |     print("1.3B model:", sum(p.numel() for p in model.parameters() if p.requires_grad))
18 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 | All contributions are welcome.
3 | 
4 | Please read below for information on creating good issues and pull requests.
5 | ## Issues
6 | If you want to draw attention to an error, bug, or unexpected behaviour, please apply the `bug` label to your issue. Provide steps for reproducing the error and any relevant platform information.
7 | ## Pull requests
8 | For bug fixes, provide a means of verifying that the bug has been resolved and link to the corresponding issue if one exists.
9 | 
10 | For new features, ensure code is consistent with the style in the rest of the repository.
--------------------------------------------------------------------------------
/license.md:
--------------------------------------------------------------------------------
1 | Copyright 2023 Jamie Stirling
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # RetNet
2 | An implementation of [Retentive Network: A Successor to Transformer
3 | for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf) in PyTorch.
4 | 
5 | ## About this repository
6 | This is a minimal, pure PyTorch implementation of RetNet.
RetNet paper: [Retentive Network: A Successor to Transformer
7 | for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf).
8 | 
9 | The contributor(s) to this repository are not authors of the original paper. All credit for the idea and formulation of RetNet goes to the original authors.
10 | 
11 | The purpose of this repository is to aid scientific and technological understanding and advancement. The code prioritizes correctness and readability over optimization.
12 | 
13 | ## Features implemented
14 | * Single-scale and multi-scale retention:
15 |     - parallel paradigm
16 |     - recurrent paradigm
17 |     - chunkwise paradigm
18 | * Multi-layer retentive network with FFN and LayerNorm:
19 |     - parallel paradigm
20 |     - recurrent paradigm
21 |     - chunkwise paradigm
22 | * Causal language model (CLM) built on top of the retentive network
23 | 
24 | ## Usage and Examples
25 | * See scripts prefixed with `test_` for examples of basic usage, and the minimal example below
26 | 
27 | ## Positional Encodings
28 | The main implementation in `src/` uses [Microsoft's xPos](https://github.com/microsoft/torchscale/blob/main/torchscale/component/xpos_relative_position.py) for positional encoding.
29 | 
30 | The implementation in `src/complex` uses complex values to encode position, which requires parameters and activations to be `torch.ComplexFloat` (64-bit). This has some limitations, since PyTorch does not yet support half-precision complex types, and it requires twice as much memory as real-valued data at 32-bit precision.
31 | 
32 | ## Contributions
33 | All contributions are welcome. Please see [issues](https://github.com/Jamie-Stirling/RetNet/issues) for an idea of what needs doing.
34 | 
35 | If you would like to contribute to this project, please fork it and submit a pull request for review.
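
## Minimal example
A short usage sketch for the main (real-valued) implementation in `src/`, mirroring `src/tests.py`. It assumes it is run from inside `src/`, and the zero-initialised recurrent states follow the shapes used in the tests; see `src/tests.py` for the chunkwise paradigm.

```
import torch
from retnet import RetNet

batch_size, sequence_length = 2, 5
layers, hidden_dim, ffn_size, heads = 4, 36, 128, 3

model = RetNet(layers, hidden_dim, ffn_size, heads, double_v_dim=False)
X = torch.rand(batch_size, sequence_length, hidden_dim)

# parallel paradigm: process the whole sequence at once
Y_parallel = model(X)

# recurrent paradigm: one token at a time, carrying per-layer, per-head states
s_n_1s = [
    [
        torch.zeros(hidden_dim // heads, model.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1)
        for _ in range(heads)
    ]
    for _ in range(layers)
]
Y_recurrent = []
for i in range(sequence_length):
    y_n, s_n_1s = model.forward_recurrent(X[:, i:i+1, :], s_n_1s, i)
    Y_recurrent.append(y_n)
Y_recurrent = torch.concat(Y_recurrent, dim=1)

# the two paradigms agree up to numerical tolerance
assert torch.allclose(Y_parallel, Y_recurrent, atol=1e-5)
```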
36 | 37 | ## References 38 | ``` 39 | @misc{sun2023retentive, 40 | title={Retentive Network: A Successor to Transformer for Large Language Models}, 41 | author={Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei}, 42 | year={2023}, 43 | eprint={2307.08621}, 44 | archivePrefix={arXiv}, 45 | primaryClass={cs.CL} 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /src/complex/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ComplexGroupNorm(nn.Module): 6 | def __init__(self, num_groups, num_channels, eps=1e-5): 7 | super(ComplexGroupNorm, self).__init__() 8 | self.num_groups = num_groups 9 | self.num_channels = num_channels 10 | self.eps = eps 11 | self.weight = nn.Parameter(torch.ones(num_channels, dtype=torch.float32)) 12 | self.bias = nn.Parameter(torch.zeros(num_channels, dtype=torch.float32)) 13 | 14 | def forward(self, X): 15 | """ 16 | X: (batch_size, sequence_length, hidden_size) 17 | X is assumed to be complex 18 | """ 19 | X = X.reshape(-1, self.num_groups, self.num_channels // self.num_groups) 20 | mean = X.mean(dim=2, keepdim=True) 21 | var = X.var(dim=2, keepdim=True) 22 | X = (X - mean) / torch.sqrt(var + self.eps) 23 | X = X.reshape(-1, self.num_channels) 24 | X = X * self.weight + self.bias 25 | 26 | return X 27 | 28 | class ComplexLayerNorm(nn.Module): 29 | def __init__(self, num_channels, eps=1e-5): 30 | super(ComplexLayerNorm, self).__init__() 31 | self.num_channels = num_channels 32 | self.eps = eps 33 | self.weight = nn.Parameter(torch.ones(num_channels, dtype=torch.float32)) 34 | self.bias = nn.Parameter(torch.zeros(num_channels, dtype=torch.float32)) 35 | 36 | def forward(self, X): 37 | """ 38 | X: unknown shape ending in hidden_size 39 | we treat the last dimension as the hidden_size 40 | """ 41 | X_shape = X.shape 42 | X = X.reshape(-1, X_shape[-1]) 43 | mean = X.mean(dim=1, keepdim=True) 44 | var = X.abs().var(dim=1, keepdim=True) 45 | X = (X - mean) / torch.sqrt(var + self.eps) 46 | X = X * self.weight + self.bias 47 | X = X.reshape(X_shape) 48 | return X 49 | 50 | 51 | class ComplexFFN(nn.Module): 52 | """ 53 | 2 linear layers with no bias 54 | """ 55 | def __init__(self, hidden_size, ffn_size): 56 | super(ComplexFFN, self).__init__() 57 | self.W1 = nn.Parameter(torch.randn(hidden_size, ffn_size, dtype=torch.float32) / math.sqrt(hidden_size)) 58 | self.W2 = nn.Parameter(torch.randn(ffn_size, hidden_size, dtype=torch.float32) / math.sqrt(ffn_size)) 59 | self.gelu = lambda x: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 60 | 61 | def forward(self, X): 62 | """ 63 | X: (batch_size, sequence_length, hidden_size) 64 | X is assumed to be complex 65 | """ 66 | # reshaping 67 | X = X @ self.W1.to(X) 68 | X = self.gelu(X) 69 | X = X @ self.W2.to(X) 70 | 71 | return X 72 | -------------------------------------------------------------------------------- /src/retnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from retention import MultiScaleRetention 5 | 6 | class RetNet(nn.Module): 7 | def __init__(self, layers, hidden_dim, ffn_size, heads, double_v_dim=False): 8 | super(RetNet, self).__init__() 9 | self.layers = layers 10 | self.hidden_dim = hidden_dim 11 | self.ffn_size = ffn_size 12 | self.heads = heads 13 | 
self.v_dim = hidden_dim * 2 if double_v_dim else hidden_dim 14 | 15 | self.retentions = nn.ModuleList([ 16 | MultiScaleRetention(hidden_dim, heads, double_v_dim) 17 | for _ in range(layers) 18 | ]) 19 | self.ffns = nn.ModuleList([ 20 | nn.Sequential( 21 | nn.Linear(hidden_dim, ffn_size), 22 | nn.GELU(), 23 | nn.Linear(ffn_size, hidden_dim) 24 | ) 25 | for _ in range(layers) 26 | ]) 27 | self.layer_norms_1 = nn.ModuleList([ 28 | nn.LayerNorm(hidden_dim) 29 | for _ in range(layers) 30 | ]) 31 | self.layer_norms_2 = nn.ModuleList([ 32 | nn.LayerNorm(hidden_dim) 33 | for _ in range(layers) 34 | ]) 35 | 36 | def forward(self, X): 37 | """ 38 | X: (batch_size, sequence_length, hidden_size) 39 | """ 40 | for i in range(self.layers): 41 | Y = self.retentions[i](self.layer_norms_1[i](X)) + X 42 | 43 | X = self.ffns[i](self.layer_norms_2[i](Y)) + Y 44 | 45 | return X 46 | 47 | def forward_recurrent(self, x_n, s_n_1s, n): 48 | """ 49 | X: (batch_size, sequence_length, hidden_size) 50 | s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) 51 | 52 | """ 53 | s_ns = [] 54 | for i in range(self.layers): 55 | # list index out of range 56 | o_n, s_n = self.retentions[i].forward_recurrent(self.layer_norms_1[i](x_n), s_n_1s[i], n) 57 | y_n = o_n + x_n 58 | s_ns.append(s_n) 59 | x_n = self.ffns[i](self.layer_norms_2[i](y_n)) + y_n 60 | 61 | return x_n, s_ns 62 | 63 | def forward_chunkwise(self, x_i, r_i_1s, i): 64 | """ 65 | X: (batch_size, sequence_length, hidden_size) 66 | r_i_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) 67 | 68 | """ 69 | r_is = [] 70 | for j in range(self.layers): 71 | o_i, r_i = self.retentions[j].forward_chunkwise(self.layer_norms_1[j](x_i), r_i_1s[j], i) 72 | y_i = o_i + x_i 73 | r_is.append(r_i) 74 | x_i = self.ffns[j](self.layer_norms_2[j](y_i)) + y_i 75 | 76 | return x_i, r_is 77 | -------------------------------------------------------------------------------- /src/xpos_relative_position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Microsoft 2 | # Licensed under The MIT License (https://github.com/microsoft/torchscale/blob/main/LICENSE) 3 | import torch 4 | import torch.nn as nn 5 | 6 | def fixed_pos_embedding(x): 7 | seq_len, dim = x.shape 8 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim)) 9 | sinusoid_inp = ( 10 | torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) 11 | ) 12 | return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) 13 | 14 | def rotate_every_two(x): 15 | x1 = x[:, :, ::2] 16 | x2 = x[:, :, 1::2] 17 | x = torch.stack((-x2, x1), dim=-1) 18 | if x.shape[-1]%2 == 1: 19 | # fill last dim with zero if hidden_size is odd 20 | x2 = torch.concat((x2, torch.zeros_like(x2[:, :, :1])), dim=-1) 21 | return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ 22 | 23 | def duplicate_interleave(m): 24 | """ 25 | A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. 
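    For example, each row [a, b] becomes [a, a, b, b].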
26 | """ 27 | dim0 = m.shape[0] 28 | m = m.view(-1, 1) # flatten the matrix 29 | m = m.repeat(1, 2) # repeat all elements into the 2nd dimension 30 | m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy 31 | return m 32 | 33 | def apply_rotary_pos_emb(x, sin, cos, scale=1): 34 | sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos)) 35 | # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) 36 | return (x * cos[:, :x.shape[-1]]) + (rotate_every_two(x) * sin)[:, :, :x.shape[-1]] 37 | 38 | 39 | class XPOS(nn.Module): 40 | def __init__( 41 | self, head_dim, scale_base=512 42 | ): 43 | super().__init__() 44 | self.head_dim = head_dim 45 | self.scale_base = scale_base 46 | self.register_buffer( 47 | "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim) 48 | ) 49 | 50 | def forward(self, x, offset=0, downscale=False): 51 | length = x.shape[1] 52 | min_pos = 0 53 | max_pos = length + offset + min_pos 54 | scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] 55 | sin, cos = fixed_pos_embedding(scale) 56 | 57 | if scale.shape[0] > length: 58 | scale = scale[-length:] 59 | sin = sin[-length:] 60 | cos = cos[-length:] 61 | 62 | if downscale: 63 | scale = 1 / scale 64 | 65 | x = apply_rotary_pos_emb(x, sin, cos, scale) 66 | return x 67 | 68 | def forward_reverse(self, x, offset=0, downscale=False): 69 | length = x.shape[1] 70 | min_pos = -(length + offset) // 2 71 | max_pos = length + offset + min_pos 72 | scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] 73 | sin, cos = fixed_pos_embedding(scale) 74 | 75 | if scale.shape[0] > length: 76 | scale = scale[-length:] 77 | sin = sin[-length:] 78 | cos = cos[-length:] 79 | 80 | if downscale: 81 | scale = 1 / scale 82 | 83 | x = apply_rotary_pos_emb(x, -sin, cos, scale) 84 | return x 85 | 86 | # test 87 | if __name__ == "__main__": 88 | x = torch.eye(4).unsqueeze(0) 89 | xpos = XPOS(4) 90 | x_rot = xpos(x) 91 | # apply reverse 92 | x_rot_rev = xpos.forward(x) 93 | 94 | print(x_rot @ x_rot_rev.transpose(-1, -2)) -------------------------------------------------------------------------------- /src/complex/test_retnet.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from retnet import RetNet, RetNetCLM 4 | 5 | class TestRetNet(unittest.TestCase): 6 | 7 | def test_paradigms_equivalent(self): 8 | batch_size = 2 9 | layers = 2 10 | hidden_dim = 8 11 | heads = 4 12 | sequence_length = 4 13 | ffn_size = 16 14 | 15 | X = torch.rand(batch_size, sequence_length, hidden_dim) 16 | 17 | retnet = RetNet(layers, hidden_dim, ffn_size, heads) 18 | Y_parallel = retnet(X) 19 | 20 | s_n_1s = [ 21 | [ 22 | torch.zeros(hidden_dim // heads, hidden_dim // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) 23 | for _ in range(heads) 24 | ] for _ in range(layers) 25 | ] 26 | 27 | Y_recurrent = [] 28 | for i in range(sequence_length): 29 | Y, s_ns = retnet.forward_recurrent(X[:, i, :], s_n_1s, i+1) 30 | Y_recurrent.append(Y) 31 | s_n_1s = s_ns 32 | 33 | Y_recurrent = torch.stack(Y_recurrent, dim=1) 34 | 35 | print((Y_parallel - Y_recurrent).abs().max()) 36 | 37 | self.assertTrue((Y_parallel - Y_recurrent).abs().max() < 1e-4) 38 | 39 | def test_clm(self): 40 | batch_size = 2 41 | layers = 2 42 | hidden_dim = 16 43 | heads = 4 44 | sequence_length = 6 45 | ffn_size = 32 46 | vocab_size = 10 47 | 48 | X 
= torch.randint(0, vocab_size, (batch_size, sequence_length)) 49 | 50 | retnet = RetNetCLM(layers, hidden_dim, ffn_size, heads, vocab_size) 51 | Y_parallel = retnet(X) 52 | 53 | s_n_1s = [ 54 | [ 55 | torch.zeros(hidden_dim // heads, hidden_dim // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) 56 | for _ in range(heads) 57 | ] for _ in range(layers) 58 | ] 59 | 60 | Y_recurrent = [] 61 | for i in range(sequence_length): 62 | Y, s_ns = retnet.forward_recurrent(X[:, i], s_n_1s, i+1) 63 | Y_recurrent.append(Y) 64 | s_n_1s = s_ns 65 | 66 | Y_recurrent = torch.stack(Y_recurrent, dim=1) 67 | 68 | # test sample 69 | Y_sample = retnet.sample(X, 5) 70 | 71 | self.assertTrue(Y_sample.shape == (batch_size, 5)) 72 | 73 | self.assertTrue((Y_parallel - Y_recurrent).abs().max() < 1e-4) 74 | 75 | def test_training(self): 76 | batch_size = 2 77 | layers = 3 78 | hidden_dim = 16 79 | heads = 4 80 | sequence_length = 6 81 | ffn_size = 32 82 | vocab_size = 10 83 | bos_idx = 0 84 | 85 | data = torch.randint(0, vocab_size, (batch_size, sequence_length - 1)) 86 | X = torch.cat([torch.ones(batch_size, 1).long() * bos_idx, data[:,:-1]], dim=1) 87 | Y = data 88 | 89 | # verify we can overfit autoregressive model 90 | model = RetNetCLM(layers, hidden_dim, ffn_size, heads, vocab_size) 91 | 92 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) 93 | criterion = torch.nn.CrossEntropyLoss() 94 | initial_loss = criterion(model(X).reshape(-1, 10), Y.reshape(-1)) 95 | for i in range(10): 96 | optimizer.zero_grad() 97 | output = model(X) 98 | loss = criterion(output.reshape(-1, 10), Y.reshape(-1)) 99 | loss.backward() 100 | optimizer.step() 101 | self.assertTrue((loss < initial_loss).item()) 102 | unittest.main() -------------------------------------------------------------------------------- /src/complex/retnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from retention import MultiScaleRetention 5 | from util import ComplexFFN, ComplexGroupNorm, ComplexLayerNorm 6 | 7 | class RetNet(nn.Module): 8 | def __init__(self, layers, hidden_dim, ffn_size, heads): 9 | super(RetNet, self).__init__() 10 | self.layers = layers 11 | self.hidden_dim = hidden_dim 12 | self.ffn_size = ffn_size 13 | self.heads = heads 14 | 15 | self.retentions = nn.ModuleList([ 16 | MultiScaleRetention(hidden_dim, heads) 17 | for _ in range(layers) 18 | ]) 19 | self.ffns = nn.ModuleList([ 20 | ComplexFFN(hidden_dim, ffn_size) 21 | for _ in range(layers) 22 | ]) 23 | self.layer_norm = ComplexLayerNorm(hidden_dim) 24 | 25 | def forward(self, X): 26 | """ 27 | X: (batch_size, sequence_length, hidden_size) 28 | """ 29 | for i in range(self.layers): 30 | Y = self.retentions[i](self.layer_norm(X)) + X 31 | X = self.ffns[i](self.layer_norm(Y)) + Y 32 | 33 | return X 34 | 35 | def forward_recurrent(self, x_n, s_n_1s, n): 36 | """ 37 | X: (batch_size, sequence_length, hidden_size) 38 | s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) 39 | 40 | """ 41 | s_ns = [] 42 | for i in range(self.layers): 43 | o_n, s_n = self.retentions[i].forward_recurrent(self.layer_norm(x_n), s_n_1s[i], n) 44 | y_n = o_n + x_n 45 | s_ns.append(s_n) 46 | x_n = self.ffns[i](self.layer_norm(y_n)) + y_n 47 | 48 | return x_n, s_ns 49 | 50 | class RetNetCLM(nn.Module): 51 | def __init__(self, layers, hidden_dim, ffn_size, heads, vocab_size): 52 | """ 53 | NOTE: softmax not included! 
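        forward() and forward_recurrent() return raw vocabulary logits; softmax (with temperature) is applied only inside sample().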
54 | """ 55 | super(RetNetCLM, self).__init__() 56 | self.layers = layers 57 | self.hidden_dim = hidden_dim 58 | self.ffn_size = ffn_size 59 | self.heads = heads 60 | self.vocab_size = vocab_size 61 | 62 | self.retnet = RetNet(layers, hidden_dim, ffn_size, heads) 63 | self.embed = nn.Embedding(vocab_size, hidden_dim) 64 | self.proj = nn.Parameter(torch.randn(hidden_dim, vocab_size, dtype=torch.float32) / hidden_dim) 65 | 66 | def forward(self, input_ids): 67 | """ 68 | input_ids: (batch_size, sequence_length) 69 | """ 70 | X = self.embed(input_ids) 71 | X = self.retnet(X) 72 | X = X @ self.proj.to(X.dtype) 73 | 74 | return X.real 75 | 76 | def forward_recurrent(self, input_ids, s_n_1s, n): 77 | """ 78 | input_ids: (batch_size) 79 | s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) 80 | """ 81 | X = self.embed(input_ids) 82 | X, s_ns = self.retnet.forward_recurrent(X, s_n_1s, n) 83 | X = X @ self.proj.to(X.dtype) 84 | 85 | return X.real, s_ns 86 | 87 | def sample(self, input_ids, sample_length, temperature=1.0): 88 | """ 89 | input_ids: (batch_size, sequence_length) 90 | s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) 91 | """ 92 | s_n_1s = [ 93 | [ 94 | torch.zeros(self.hidden_dim // self.heads, self.hidden_dim // self.heads, dtype=torch.complex64).unsqueeze(0).repeat(input_ids.shape[0], 1, 1) 95 | for _ in range(self.heads) 96 | ] for _ in range(self.layers) 97 | ] 98 | for i in range(input_ids.shape[1]): 99 | X, s_n_1s = self.forward_recurrent(input_ids[:, i], s_n_1s, i+1) 100 | 101 | # get softmax of x (real part only) 102 | X = X.real / temperature 103 | X = torch.softmax(X, dim=-1) 104 | X = torch.multinomial(X, num_samples=1) 105 | next_char = X[:, -1] 106 | output_ids = [] 107 | # now start sampling! 
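        # Each iteration embeds the previously sampled token, advances the per-layer,
        # per-head recurrent states (s_n_1s) by one step, and samples the next token
        # from the temperature-scaled softmax. Note that the position index passed
        # below (i+1) restarts from 1 rather than continuing from the prompt length.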
108 | for i in range(sample_length): 109 | X, s_n_1s = self.forward_recurrent(next_char, s_n_1s, i+1) 110 | X = X.real / temperature 111 | X = torch.softmax(X, dim=-1) 112 | X = torch.multinomial(X, num_samples=1) 113 | next_char = X[:, -1] 114 | output_ids.append(next_char) 115 | 116 | output_ids = torch.stack(output_ids, dim=1) 117 | 118 | return output_ids -------------------------------------------------------------------------------- /src/complex/test_retention.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from retention import SimpleRetention, MultiScaleRetention 4 | 5 | class TestSimpleRetention(unittest.TestCase): 6 | def test_simple_retention_parallel(self): 7 | batch_size = 4 8 | hidden_size = 8 9 | sequence_length = 16 10 | gamma = 0.9 11 | 12 | X = torch.rand(batch_size, sequence_length, hidden_size) 13 | retention = SimpleRetention(hidden_size, gamma) 14 | 15 | Y = retention(X) 16 | self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) 17 | 18 | def test_simple_retention_recurrent(self): 19 | batch_size = 4 20 | hidden_size = 8 21 | sequence_length = 16 22 | gamma = 0.9 23 | 24 | X = torch.rand(batch_size, sequence_length, hidden_size) 25 | retention = SimpleRetention(hidden_size, gamma) 26 | 27 | s_n_1 = torch.zeros(hidden_size, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) 28 | Y = [] 29 | for i in range(sequence_length): 30 | y_n, s_n = retention.forward_recurrent(X[:, i, :], s_n_1, i+1) 31 | Y.append(y_n) 32 | s_n_1 = s_n 33 | Y = torch.stack(Y, dim=1) 34 | self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) 35 | 36 | def test_paradigms_identical(self): 37 | """ 38 | check that the parallel and recurrent paradigms have identical outputs 39 | """ 40 | batch_size = 1 41 | hidden_size = 8 42 | sequence_length = 4 43 | gamma = 0.90 44 | 45 | X = torch.rand(batch_size, sequence_length, hidden_size) 46 | retention = SimpleRetention(hidden_size, gamma) 47 | 48 | Y_parallel = retention(X) 49 | 50 | s_n_1 = torch.zeros(hidden_size, hidden_size, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) 51 | Y_recurrent = [] 52 | for i in range(sequence_length): 53 | y_n, s_n = retention.forward_recurrent(X[:, i, :], s_n_1, i+1) 54 | Y_recurrent.append(y_n) 55 | s_n_1 = s_n 56 | Y_recurrent = torch.stack(Y_recurrent, dim=1) 57 | 58 | self.assertTrue(torch.allclose(Y_parallel, Y_recurrent)) 59 | 60 | class TestMultiScaleRetention(unittest.TestCase): 61 | def test_multiscale_retention_parallel(self): 62 | batch_size = 4 63 | sequence_length = 5 64 | hidden_size = 32 65 | heads = 4 66 | retention = MultiScaleRetention(hidden_size, heads) 67 | 68 | X = torch.rand(batch_size, sequence_length, hidden_size) 69 | Y = retention(X) 70 | self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) 71 | 72 | def test_multiscale_retention_recurrent(self): 73 | batch_size = 4 74 | sequence_length = 5 75 | hidden_size = 32 76 | heads = 4 77 | retention = MultiScaleRetention(hidden_size, heads) 78 | 79 | X = torch.rand(batch_size, sequence_length, hidden_size) 80 | s_n_1s = [ 81 | torch.zeros(hidden_size // heads, hidden_size // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) 82 | for _ in range(heads) 83 | ] 84 | Y = [] 85 | for i in range(sequence_length): 86 | y_n, s_ns = retention.forward_recurrent(X[:, i, :], s_n_1s, i) 87 | Y.append(y_n) 88 | s_n_1s = s_ns 89 | Y = torch.stack(Y, dim=1) 90 | self.assertEqual(Y.shape, (batch_size, 
sequence_length, hidden_size)) 91 | 92 | def test_multiscale_paradigms_identical(self): 93 | """ 94 | check that the parallel and recurrent paradigms have identical outputs 95 | """ 96 | batch_size = 2 97 | hidden_size = 36 98 | sequence_length = 5 99 | heads = 3 100 | 101 | X = torch.rand(batch_size, sequence_length, hidden_size) 102 | retention = MultiScaleRetention(hidden_size, heads) 103 | 104 | Y_parallel = retention(X) 105 | 106 | s_n_1s = [ 107 | torch.zeros(hidden_size // heads, hidden_size // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) 108 | for _ in range(heads) 109 | ] 110 | Y_recurrent = [] 111 | for i in range(sequence_length): 112 | y_n, s_ns = retention.forward_recurrent(X[:, i, :], s_n_1s, i) 113 | Y_recurrent.append(y_n) 114 | s_n_1s = s_ns 115 | Y_recurrent = torch.stack(Y_recurrent, dim=1) 116 | 117 | self.assertTrue(torch.allclose(Y_parallel, Y_recurrent)) 118 | 119 | unittest.main() -------------------------------------------------------------------------------- /src/tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from retention import SimpleRetention, MultiScaleRetention 6 | from retnet import RetNet 7 | 8 | class TestRetention(unittest.TestCase): 9 | 10 | def test_simple(self): 11 | """ 12 | verify that the three implementations of SimpleRetention are identical 13 | """ 14 | batch_size = 4 15 | sequence_length = 12 16 | hidden_size = 6 17 | chunk_size = 4 18 | 19 | gamma = 0.9 20 | 21 | X = torch.rand(batch_size, sequence_length, hidden_size) 22 | sr = SimpleRetention(hidden_size, gamma, double_v_dim=True) 23 | 24 | Y_parallel = sr(X) 25 | 26 | s_n_1 = torch.zeros(hidden_size, sr.v_dim).unsqueeze(0).repeat(batch_size, 1, 1) 27 | Y_recurrent = [] 28 | for i in range(sequence_length): 29 | y_n, s_n = sr.forward_recurrent(X[:, i:i+1, :], s_n_1, i) 30 | Y_recurrent.append(y_n) 31 | s_n_1 = s_n 32 | 33 | Y_recurrent = torch.concat(Y_recurrent, dim=1) 34 | 35 | r_n_1 = torch.zeros(hidden_size, sr.v_dim).unsqueeze(0).repeat(batch_size, 1, 1) 36 | Y_chunkwise = [] 37 | for i in range(sequence_length // chunk_size): 38 | y_i, r_i = sr.forward_chunkwise(X[:, i*chunk_size:(i+1)*chunk_size, :], r_n_1, i) 39 | Y_chunkwise.append(y_i) 40 | r_n_1 = r_i 41 | 42 | 43 | Y_chunkwise = torch.concat(Y_chunkwise, dim=1) 44 | 45 | 46 | assert torch.allclose(Y_parallel, Y_recurrent, atol=1e-5) 47 | assert torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5) 48 | 49 | 50 | def test_multiscale(self): 51 | """ 52 | verify that the three implementations of MultiScaleRetention are identical 53 | """ 54 | batch_size = 2 55 | hidden_size = 6 56 | sequence_length = 12 57 | heads = 3 58 | chunk_size = 2 59 | 60 | X = torch.rand(batch_size, sequence_length, hidden_size) 61 | retention = MultiScaleRetention(hidden_size, heads, double_v_dim=False) 62 | # print total number of parameters 63 | print("Default v_dim:",sum(p.numel() for p in retention.parameters() if p.requires_grad)) 64 | 65 | retention = MultiScaleRetention(hidden_size, heads, double_v_dim=True) 66 | print("Double v_dim:",sum(p.numel() for p in retention.parameters() if p.requires_grad)) 67 | 68 | Y_parallel = retention(X) 69 | 70 | s_n_1s = [ 71 | torch.zeros(hidden_size // heads, retention.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) 72 | for _ in range(heads) 73 | ] 74 | Y_recurrent = [] 75 | for i in range(sequence_length): 76 | y_n, s_ns = retention.forward_recurrent(X[:, i:i+1, :], s_n_1s, i) 77 | 
Y_recurrent.append(y_n) 78 | s_n_1s = s_ns 79 | 80 | Y_recurrent = torch.concat(Y_recurrent, dim=1) 81 | 82 | r_n_1s = [ 83 | torch.zeros(hidden_size // heads, retention.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) 84 | for _ in range(heads) 85 | ] 86 | Y_chunkwise = [] 87 | for i in range(sequence_length // chunk_size): 88 | y_i, r_i = retention.forward_chunkwise(X[:, i*chunk_size:(i+1)*chunk_size, :], r_n_1s, i) 89 | Y_chunkwise.append(y_i) 90 | r_n_1s = r_i 91 | 92 | Y_chunkwise = torch.concat(Y_chunkwise, dim=1) 93 | 94 | self.assertTrue(torch.allclose(Y_parallel, Y_recurrent, atol=1e-5)) 95 | self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails 96 | 97 | class TestRetNet(unittest.TestCase): 98 | 99 | def test_retnet(self): 100 | """ 101 | verify that the three implementations of RetNet are identical 102 | """ 103 | batch_size = 2 104 | hidden_size = 36 105 | sequence_length = 5 106 | heads = 3 107 | layers = 4 108 | ffn_size = 128 109 | 110 | X = torch.rand(batch_size, sequence_length, hidden_size) 111 | retnet = RetNet(layers, hidden_size, ffn_size, heads, double_v_dim=False) 112 | # print total number of parameters 113 | print("Default v_dim:",sum(p.numel() for p in retnet.parameters() if p.requires_grad)) 114 | 115 | retnet = RetNet(layers, hidden_size, ffn_size, heads, double_v_dim=True) 116 | print("Double v_dim:",sum(p.numel() for p in retnet.parameters() if p.requires_grad)) 117 | 118 | Y_parallel = retnet(X) 119 | 120 | s_n_1s = [ 121 | [ 122 | torch.zeros(hidden_size // heads, retnet.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) 123 | for _ in range(heads) 124 | ] 125 | for _ in range(layers) 126 | ] 127 | Y_recurrent = [] 128 | for i in range(sequence_length): 129 | y_n, s_ns = retnet.forward_recurrent(X[:, i:i+1, :], s_n_1s, i) 130 | Y_recurrent.append(y_n) 131 | s_n_1s = s_ns 132 | 133 | Y_recurrent = torch.concat(Y_recurrent, dim=1) 134 | 135 | r_n_1s = [ 136 | [ 137 | torch.zeros(hidden_size // heads, retnet.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) 138 | for _ in range(heads) 139 | ] 140 | for _ in range(layers) 141 | ] 142 | Y_chunkwise = [] 143 | for i in range(sequence_length): 144 | y_i, r_i = retnet.forward_chunkwise(X[:, i:i+1, :], r_n_1s, i) 145 | Y_chunkwise.append(y_i) 146 | r_n_1s = r_i 147 | 148 | Y_chunkwise = torch.concat(Y_chunkwise, dim=1) 149 | 150 | self.assertTrue(torch.allclose(Y_parallel, Y_recurrent, atol=1e-5)) 151 | self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) 152 | 153 | if __name__ == "__main__": 154 | unittest.main() 155 | -------------------------------------------------------------------------------- /src/complex/retention.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from util import ComplexGroupNorm 7 | 8 | class SimpleRetention(nn.Module): 9 | def __init__(self, hidden_size, gamma, precision="single"): 10 | """ 11 | Simple retention mechanism based on the paper 12 | "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] 13 | """ 14 | super(SimpleRetention, self).__init__() 15 | 16 | if precision == "half": 17 | raise NotImplementedError("batchmm does not support half precision complex yet.") 18 | self.complex_type = torch.complex32 19 | self.real_type = torch.float16 20 | elif precision == "single": 21 | self.complex_type = torch.complex64 22 | self.real_type = torch.float32 23 | 24 | self.precision = 
precision 25 | self.hidden_size = hidden_size 26 | self.gamma = gamma 27 | 28 | self.i = torch.complex(torch.tensor(0.0), torch.tensor(1.0)) 29 | 30 | self.W_Q = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.real_type) / hidden_size) 31 | self.W_K = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.real_type) / hidden_size) 32 | self.W_V = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.real_type) / hidden_size) 33 | 34 | 35 | self.theta = torch.randn(hidden_size) / hidden_size 36 | self.theta = nn.Parameter(self.theta) 37 | 38 | 39 | 40 | def forward(self, X): 41 | """ 42 | Parallel (default) representation of the retention mechanism. 43 | X: (batch_size, sequence_length, hidden_size) 44 | """ 45 | sequence_length = X.shape[1] 46 | D = self._get_D(sequence_length).to(X.device) 47 | 48 | if X.dtype != self.complex_type: 49 | X = torch.complex(X, torch.zeros_like(X)).to(self.complex_type) 50 | 51 | i = self.i.to(X.device) 52 | ns = torch.arange(1, sequence_length + 1, dtype=self.real_type, device=X.device) 53 | ns = torch.complex(ns, torch.zeros_like(ns)).to(self.complex_type) 54 | Theta = [] 55 | 56 | for n in ns: 57 | Theta.append(torch.exp(i * n * self.theta)) 58 | 59 | Theta = torch.stack(Theta, dim=0) 60 | 61 | Theta_bar = Theta.conj() 62 | 63 | Q = (X @ self.W_Q.to(self.complex_type)) * Theta.unsqueeze(0) 64 | K = (X @ self.W_K.to(self.complex_type)) * Theta_bar.unsqueeze(0) 65 | V = X @ self.W_V.to(self.complex_type) 66 | att = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0) 67 | 68 | return att @ V 69 | 70 | def forward_recurrent(self, x_n, s_n_1, n): 71 | """ 72 | Recurrent representation of the retention mechanism. 73 | x_n: (batch_size, hidden_size) 74 | s_n_1: (batch_size, hidden_size) 75 | """ 76 | if x_n.dtype != self.complex_type: 77 | x_n = torch.complex(x_n, torch.zeros_like(x_n)).to(self.complex_type) 78 | 79 | n = torch.tensor(n, dtype=self.complex_type, device=x_n.device) 80 | 81 | Theta = torch.exp(self.i * n * self.theta) 82 | Theta_bar = Theta.conj() 83 | 84 | Q = (x_n @ self.W_Q.to(self.complex_type)) * Theta 85 | K = (x_n @ self.W_K.to(self.complex_type)) * Theta_bar 86 | V = x_n @ self.W_V.to(self.complex_type) 87 | 88 | # K: (batch_size, hidden_size) 89 | # V: (batch_size, hidden_size) 90 | # s_n_1: (batch_size, hidden_size, hidden_size) 91 | # s_n = gamma * s_n_1 + K^T @ V 92 | 93 | s_n = self.gamma * s_n_1 + K.unsqueeze(2) @ V.unsqueeze(1) 94 | 95 | return (Q.unsqueeze(1) @ s_n).squeeze(1), s_n 96 | 97 | def _get_D(self, sequence_length): 98 | n = torch.arange(sequence_length).unsqueeze(1) 99 | m = torch.arange(sequence_length).unsqueeze(0) 100 | 101 | # Broadcast self.gamma ** (n - m) with appropriate masking to set values where n < m to 0 102 | D = (self.gamma ** (n - m)) * (n >= m).float() #this results in some NaN when n is much larger than m 103 | # fill the NaN with 0 104 | D[D != D] = 0 105 | 106 | return D 107 | 108 | class MultiScaleRetention(nn.Module): 109 | def __init__(self, hidden_size, heads, precision="single"): 110 | """ 111 | Multi-scale retention mechanism based on the paper 112 | "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] 113 | """ 114 | super(MultiScaleRetention, self).__init__() 115 | self.hidden_size = hidden_size 116 | self.heads = heads 117 | self.precision = precision 118 | assert hidden_size % heads == 0, "hidden_size must be divisible by heads" 119 | self.head_size = hidden_size // heads 120 | 121 | if precision == "half": 122 | 
raise NotImplementedError("batchmm does not support half precision complex yet.") 123 | self.complex_type = torch.complex32 124 | self.real_type = torch.float16 125 | elif precision == "single": 126 | self.complex_type = torch.complex64 127 | self.real_type = torch.float32 128 | 129 | self.gammas = (1 - torch.exp(torch.linspace(math.log(1/32), math.log(1/512), heads, dtype=self.real_type))).detach().cpu().tolist() 130 | 131 | self.swish = lambda x: x * torch.sigmoid(x) 132 | self.W_G = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.complex_type) / hidden_size) 133 | self.W_O = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.complex_type) / hidden_size) 134 | self.group_norm = ComplexGroupNorm(heads, hidden_size) 135 | 136 | self.retentions = nn.ModuleList([ 137 | SimpleRetention(self.head_size, gamma) for gamma in self.gammas 138 | ]) 139 | 140 | def forward(self, X): 141 | """ 142 | parallel representation of the multi-scale retention mechanism 143 | """ 144 | if X.dtype != self.complex_type: 145 | X = torch.complex(X, torch.zeros_like(X)).to(self.complex_type) 146 | 147 | # apply each individual retention mechanism to a slice of X 148 | Y = [] 149 | for i in range(self.heads): 150 | Y.append(self.retentions[i](X[:, :, i*self.head_size:(i+1)*self.head_size])) 151 | 152 | Y = torch.cat(Y, dim=2) 153 | Y = self.group_norm(Y.reshape(-1, self.hidden_size)).reshape(X.shape) 154 | 155 | return (self.swish(X @ self.W_G.to(self.complex_type)) * Y) @ self.W_O.to(self.complex_type) 156 | 157 | def forward_recurrent(self, x_n, s_n_1s, n): 158 | """ 159 | recurrent representation of the multi-scale retention mechanism 160 | """ 161 | if x_n.dtype != self.complex_type: 162 | x_n = torch.complex(x_n, torch.zeros_like(x_n)).to(self.complex_type) 163 | n = torch.tensor(n, dtype=self.complex_type, device=x_n.device) 164 | 165 | # apply each individual retention mechanism to a slice of X 166 | Y = [] 167 | s_ns = [] 168 | for i in range(self.heads): 169 | y, s_n = self.retentions[i].forward_recurrent( 170 | x_n[:, i*self.head_size:(i+1)*self.head_size], s_n_1s[i], n 171 | ) 172 | Y.append(y) 173 | s_ns.append(s_n) 174 | 175 | Y = torch.cat(Y, dim=1) 176 | Y = self.group_norm(Y) 177 | return (self.swish(x_n @ self.W_G.to(self.complex_type)) * Y) @ self.W_O.to(self.complex_type), s_ns 178 | -------------------------------------------------------------------------------- /src/retention.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from xpos_relative_position import XPOS 7 | 8 | class SimpleRetention(nn.Module): 9 | def __init__(self, hidden_size, gamma, head_size=None, double_v_dim=False): 10 | """ 11 | Simple retention mechanism based on the paper 12 | "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] 13 | """ 14 | super(SimpleRetention, self).__init__() 15 | 16 | self.hidden_size = hidden_size 17 | if head_size is None: 18 | head_size = hidden_size 19 | self.head_size = head_size 20 | 21 | self.v_dim = head_size * 2 if double_v_dim else head_size 22 | self.gamma = gamma 23 | 24 | self.W_Q = nn.Parameter(torch.randn(hidden_size, head_size) / hidden_size) 25 | self.W_K = nn.Parameter(torch.randn(hidden_size, head_size) / hidden_size) 26 | self.W_V = nn.Parameter(torch.randn(hidden_size, self.v_dim) / hidden_size) 27 | 28 | self.xpos = XPOS(head_size) 29 | 30 | def forward(self, X): 31 | """ 32 | Parallel 
(default) representation of the retention mechanism. 33 | X: (batch_size, sequence_length, hidden_size) 34 | """ 35 | sequence_length = X.shape[1] 36 | D = self._get_D(sequence_length).to(self.W_Q.device) 37 | 38 | Q = (X @ self.W_Q) 39 | K = (X @ self.W_K) 40 | 41 | Q = self.xpos(Q) 42 | K = self.xpos(K, downscale=True) 43 | 44 | V = X @ self.W_V 45 | ret = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0) 46 | 47 | return ret @ V 48 | 49 | def forward_recurrent(self, x_n, s_n_1, n): 50 | """ 51 | Recurrent representation of the retention mechanism. 52 | x_n: (batch_size, 1, hidden_size) 53 | s_n_1: (batch_size, hidden_size, v_dim) 54 | """ 55 | 56 | Q = (x_n @ self.W_Q) 57 | K = (x_n @ self.W_K) 58 | 59 | Q = self.xpos(Q, n+1) 60 | K = self.xpos(K, n+1, downscale=True) 61 | 62 | V = x_n @ self.W_V 63 | 64 | # K: (batch_size, 1, hidden_size) 65 | # V: (batch_size, 1, v_dim) 66 | # s_n = gamma * s_n_1 + K^T @ V 67 | 68 | s_n = self.gamma * s_n_1 + (K.transpose(-1, -2) @ V) 69 | 70 | return (Q @ s_n), s_n 71 | 72 | def forward_chunkwise(self, x_i, r_i_1, i): 73 | """ 74 | Chunkwise representation of the retention mechanism. 75 | x_i: (batch_size, chunk_size, hidden_size) 76 | r_i_1: (batch_size, hidden_size, v_dim) 77 | """ 78 | batch, chunk_size, _ = x_i.shape 79 | D = self._get_D(chunk_size) 80 | 81 | Q = (x_i @ self.W_Q) 82 | K = (x_i @ self.W_K) 83 | 84 | Q = self.xpos(Q, i * chunk_size) 85 | K = self.xpos(K, i * chunk_size, downscale=True) 86 | 87 | V = x_i @ self.W_V 88 | 89 | r_i =(K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1 90 | 91 | inner_chunk = ((Q @ K.transpose(-1, -2)) * D.unsqueeze(0)) @ V 92 | 93 | #e[i,j] = gamma ** (i+1) 94 | e = torch.zeros(batch, chunk_size, 1) 95 | 96 | for _i in range(chunk_size): 97 | e[:, _i, :] = self.gamma ** (_i + 1) 98 | 99 | cross_chunk = (Q @ r_i_1) * e 100 | 101 | return inner_chunk + cross_chunk, r_i 102 | 103 | def _get_D(self, sequence_length): 104 | n = torch.arange(sequence_length).unsqueeze(1) 105 | m = torch.arange(sequence_length).unsqueeze(0) 106 | 107 | # Broadcast self.gamma ** (n - m) with appropriate masking to set values where n < m to 0 108 | D = (self.gamma ** (n - m)) * (n >= m).float() #this results in some NaN when n is much larger than m 109 | # fill the NaN with 0 110 | D[D != D] = 0 111 | 112 | return D 113 | 114 | 115 | 116 | class MultiScaleRetention(nn.Module): 117 | def __init__(self, hidden_size, heads, double_v_dim=False): 118 | """ 119 | Multi-scale retention mechanism based on the paper 120 | "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] 121 | """ 122 | super(MultiScaleRetention, self).__init__() 123 | self.hidden_size = hidden_size 124 | self.v_dim = hidden_size * 2 if double_v_dim else hidden_size 125 | self.heads = heads 126 | assert hidden_size % heads == 0, "hidden_size must be divisible by heads" 127 | self.head_size = hidden_size // heads 128 | self.head_v_dim = hidden_size * 2 if double_v_dim else hidden_size 129 | 130 | self.gammas = (1 - torch.exp(torch.linspace(math.log(1/32), math.log(1/512), heads))).detach().cpu().tolist() 131 | 132 | self.swish = lambda x: x * torch.sigmoid(x) 133 | self.W_G = nn.Parameter(torch.randn(hidden_size, self.v_dim) / hidden_size) 134 | self.W_O = nn.Parameter(torch.randn(self.v_dim, hidden_size) / hidden_size) 135 | self.group_norm = nn.GroupNorm(heads, self.v_dim) 136 | 137 | self.retentions = nn.ModuleList([ 138 | SimpleRetention(self.hidden_size, gamma, 
self.head_size, double_v_dim) for gamma in self.gammas 139 | ]) 140 | 141 | def forward(self, X): 142 | """ 143 | parallel representation of the multi-scale retention mechanism 144 | """ 145 | 146 | # apply each individual retention mechanism to X 147 | Y = [] 148 | for i in range(self.heads): 149 | Y.append(self.retentions[i](X)) 150 | 151 | Y = torch.cat(Y, dim=2) 152 | Y_shape = Y.shape 153 | Y = self.group_norm(Y.reshape(-1, self.v_dim)).reshape(Y_shape) 154 | 155 | return (self.swish(X @ self.W_G) * Y) @ self.W_O 156 | 157 | def forward_recurrent(self, x_n, s_n_1s, n): 158 | """ 159 | recurrent representation of the multi-scale retention mechanism 160 | x_n: (batch_size, 1, hidden_size) 161 | s_n_1s: (batch_size, heads, head_size, head_size) 162 | 163 | """ 164 | 165 | # apply each individual retention mechanism to a slice of X 166 | Y = [] 167 | s_ns = [] 168 | for i in range(self.heads): 169 | y, s_n = self.retentions[i].forward_recurrent( 170 | x_n[:, :, :], s_n_1s[i], n 171 | ) 172 | Y.append(y) 173 | s_ns.append(s_n) 174 | 175 | Y = torch.cat(Y, dim=2) 176 | Y_shape = Y.shape 177 | Y = self.group_norm(Y.reshape(-1, self.v_dim)).reshape(Y_shape) 178 | 179 | return (self.swish(x_n @ self.W_G) * Y) @ self.W_O, s_ns 180 | 181 | def forward_chunkwise(self, x_i, r_i_1s, i): 182 | """ 183 | chunkwise representation of the multi-scale retention mechanism 184 | x_i: (batch_size, chunk_size, hidden_size) 185 | r_i_1s: (batch_size, heads, head_size, head_size) 186 | """ 187 | batch, chunk_size, _ = x_i.shape 188 | 189 | # apply each individual retention mechanism to a slice of X 190 | Y = [] 191 | r_is = [] 192 | for j in range(self.heads): 193 | y, r_i = self.retentions[j].forward_chunkwise( 194 | x_i[:, :, :], r_i_1s[j], i 195 | ) 196 | Y.append(y) 197 | r_is.append(r_i) 198 | 199 | 200 | Y = torch.cat(Y, dim=2) 201 | Y_shape = Y.shape 202 | Y = self.group_norm(Y.reshape(-1, self.v_dim)).reshape(Y_shape) 203 | 204 | return (self.swish(x_i @ self.W_G) * Y) @ self.W_O, r_is 205 | --------------------------------------------------------------------------------