├── .gitignore ├── README.md ├── config └── llm.yaml ├── requirements.txt ├── res └── loss.jpeg ├── run.py ├── setup.py ├── test ├── test_mlstm.py ├── test_slstm.py ├── test_stories.py └── test_xlstm.py └── xlstm ├── __init__.py ├── data.py ├── llm.py ├── lstm.py ├── stories.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.vs* 2 | *.pytest* 3 | *.pyc 4 | *.ipynb 5 | 6 | *egg* 7 | *log* 8 | .local 9 | test/.* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # xLSTM in Easy Pytorch 2 | 3 | This repo contains the _unofficial_ implementation of `xLSTM` model as introduced in [Beck et al. (2024)](https://arxiv.org/abs/2405.04517). This repo is developed mainly for didactic purposes to spell out the details of a modern `Long-Short Term Memory` with competitive performances against modern `Transformers` or `State-Space` models (e.g. `Mamba`). 4 | 5 | Just for fun, this repo tries to implement a basic LLM (see `📂 xlstm.llm`) using [Lightning](https://lightning.ai/docs/pytorch/stable/) so that training on multi-gpu (should) be just one variable away. 6 | 7 | # Results 8 | 9 | Just for fun I set up to train a small `xLSTM` LLM model on the cute `TinyStories` dataset and logged its progress as it learned (I always find it amusing to read the incoherent first attempts and was actually surprised by how quickly it got the general structure). Here what I get for the highly original `Once upon a time` prompt: 10 | 11 | **At initialization** 12 | 13 | ```text 14 | Once upon a timeboro wit carryingabellaastered Greens intestinal Pil su128 configure Patentrowing SeventhNohs implies Burger ® Cities lowacommTYelligimilationbender Manual authored Comprehensivelow fightingrinasq intercourse377 gradientafe bluntlyaroo coats Witchhiba Jeff Flags ambassadors iT deleted Deals reassCruzka...(you get the idea) 15 | ``` 16 | 17 | **After 320 steps** 18 | 19 | ```text 20 | Once upon a time. She and took them. He is and they with something. She asked, a big dog on the park. Lily went to the park, ''That wanted it is not she is 21 | verv hanov into the around's mom man was a lot him to the "Thank 22 | he couldn't sad and. He is a time. "What and not to go be careful. She was that the little girl, I will. Then it?''' Tom things. He took it they saw a bia." 23 | ``` 24 | 25 | ![Validation Loss](res/loss.jpeg) 26 | 27 | **After 20K steps** 28 | 29 | ```text 30 | Once upon a time. Jack and ran across the hill. When she always a bit embarrassed and felt so much to play!" And they couldn't know what you should always made of the park." One day she wanted to help make some new friends." 31 | "The boy was so happy to a time. 32 | "Lily's help. He was very sorry, there. Then, and it looked at how he saw the ball. When she was happy and had so excited to buy the ground. He used to fly was very happy and daddy was so excited and the car. Timmy went to go home." 33 | ``` 34 | 35 | # Usage 36 | 37 | The `xlstm` module exposes both the `sLSTM` (scalar-LSTM) and the `mLSTM` (matrix-LSTM) modules. Both expect their input to have shape `(batch_size, d_input)` as they consume an input sequence sequentially. They output the model current (projected) hidden state `h_t` (which is considered the module output and has the same shape as the input, see Figure 9 in the Appendix of [Beck et al. 
(2024)](https://arxiv.org/abs/2405.04517)), plus their updated hidden variables (a tuple of tensors).
38 | 
39 | ```python
40 | import torch
41 | from xlstm import sLSTM
42 | from itertools import pairwise
43 | seq_len = 32
44 | batch_size = 4
45 | 
46 | inp_dim = 16
47 | head_dim = 8
48 | head_num = 4
49 | 
50 | # Create a mock-up input sequence
51 | seq = torch.randn(seq_len, batch_size, inp_dim)
52 | 
53 | lstm = sLSTM(
54 |     inp_dim,      # Input sequence dimension
55 |     head_dim,     # Dimension of each head
56 |     head_num,     # Number of heads
57 |     p_factor=4/3, # Tunable expansion factor
58 | )
59 | 
60 | # Initialize the hidden states
61 | hid = lstm.init_hidden(batch_size)
62 | 
63 | criterion = ... # Pick some loss function, e.g. MSE
64 | 
65 | # Iterate through the sequence length
66 | loss = 0
67 | for prev, succ in pairwise(seq):
68 |     # Get the model prediction plus the updated hidden states
69 |     pred, hid = lstm(prev, hid)
70 | 
71 |     # Target is the next sequence token
72 |     loss += criterion(pred, succ)
73 | 
74 | # Compute gradients
75 | loss.backward()
76 | ```
77 | 
78 | This repo also provides an implementation of an `xLSTM` LLM (simply a stack of `sLSTM` and `mLSTM` blocks plus a prediction head) built using `PyTorch Lightning`, which unlocks easy training on multiple GPUs. To use it, one can simply run the following example:
79 | 
80 | ```python
81 | from lightning import Trainer
82 | from transformers import AutoTokenizer
83 | 
84 | from xlstm import xLSTM
85 | from xlstm.stories import TinyStoriesLightning
86 | 
87 | config = ... # Path to the YAML configuration file
88 | 
89 | # Load an off-the-shelf tokenizer from HF
90 | tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')
91 | 
92 | # Load the xLSTM model from a config file
93 | model = xLSTM.from_config(config, key='llm')
94 | 
95 | # Load the dataset
96 | dataset = TinyStoriesLightning.from_config(
97 |     config,
98 |     tokenizer,
99 |     key='dataset'
100 | )
101 | 
102 | trainer = Trainer(
103 |     max_epochs  = 500,
104 |     accelerator = 'gpu',
105 |     devices     = 4, # Piece of cake multi-gpu support!
106 |     strategy    = 'ddp_find_unused_parameters_false',
107 | )
108 | 
109 | # Train the model
110 | trainer.fit(model, dataset)
111 | ```
112 | 
113 | Alternatively, one can run the training script `run.py` directly, which leverages the `LightningCLI` API and offers great flexibility for customization. The script expects a configuration file path (see the example configuration file in `📂 config/llm.yaml`) and accepts all the Trainer arguments (and more! See [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html#lightning.pytorch.cli.LightningCLI) for reference).
114 | 
115 | ```bash
116 | python run.py fit --config config/llm.yaml
117 | ```
118 | 
119 | A cool feature of the current `xLSTM` implementation is lazy (batched) inference implemented via a generator. One can thus print tokens to screen as they are streamed by the model, with no need to wait for the whole inference to finish! A mock-up script would look like the following.
120 | 
121 | ```python
122 | from xlstm import xLSTM
123 | from transformers import AutoTokenizer
124 | 
125 | # Get an off-the-shelf tokenizer
126 | tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')
127 | 
128 | tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
129 | 
130 | # Parameters for the LLM
131 | vocab_size = tokenizer.vocab_size + 1
132 | num_layers = 8
133 | signature = (7, 1)
134 | inp_dim = 16
135 | head_dim = 8
136 | head_num = 4
137 | ker_size = 4
138 | p_factor = (2, 4/3)
139 | 
140 | model = xLSTM(
141 |     vocab_size = vocab_size,
142 |     num_layers = num_layers,
143 |     signature  = signature,
144 |     inp_dim    = inp_dim,
145 |     head_dim   = head_dim,
146 |     head_num   = head_num,
147 |     p_factor   = p_factor,
148 |     ker_size   = ker_size,
149 | )
150 | 
151 | # Parameters for the inference
152 | token_lim = 16
153 | use_top_k = 50
154 | temperature = 0.7
155 | 
156 | # Generate text
157 | stream = model.generate(
158 |     # We can provide more than one prompt!
159 |     prompt=[
160 |         'Once upon a time',
161 |         'In a galaxy far far away',
162 |     ],
163 |     tokenizer=tokenizer,
164 |     token_lim=token_lim,
165 |     use_top_k=use_top_k,
166 |     temperature=temperature,
167 | )
168 | 
169 | for token in stream:
170 |     # Each token is a dictionary indexed by the
171 |     # batch-id and contains the produced string
172 |     # as value, so we can print the first batch as:
173 |     print(token[0], end='')
174 | ```
175 | 
176 | # Roadmap
177 | 
178 | - [x] Put all the essential pieces together (i.e. `sLSTM` & `mLSTM`)
179 | - [x] Add implementation for a full `xLSTM`
180 | - [x] Add functioning training script (Lightning)
181 | - [x] Show some results
182 | 
183 | # Requirements
184 | 
185 | Code was tested with Python 3.11+. To install the required dependencies, simply run `pip install -r requirements.txt`.
186 | 
187 | ```
188 | torch==2.3.0
189 | PyYAML==6.0.1
190 | einops==0.8.0
191 | lightning==2.2.4
192 | setuptools==69.5.1
193 | transformers==4.40.2
194 | ```
195 | 
196 | # Citations
197 | 
198 | ```bibtex
199 | @article{beck2024xlstm,
200 |   title={xLSTM: Extended Long Short-Term Memory},
201 |   author={Beck, Maximilian and P{\"o}ppel, Korbinian and Spanring, Markus and Auer, Andreas and Prudnikova, Oleksandra and Kopp, Michael and Klambauer, G{\"u}nter and Brandstetter, Johannes and Hochreiter, Sepp},
202 |   journal={arXiv preprint arXiv:2405.04517},
203 |   year={2024}
204 | }
205 | ```
206 | 
--------------------------------------------------------------------------------
/config/llm.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 31415
2 | 
3 | model:
4 |   vocab_size: 50257
5 |   num_layers: 8
6 |   signature: [7, 1]
7 |   inp_dim: 64
8 |   head_dim: 16
9 |   head_num: 4
10 |   p_factor: [2, 1.33333] # (2, 4/3)
11 |   ker_size: 4
12 |   tokenizer:
13 |     class_path: xlstm.utils.TokenizerWrapper
14 |     init_args:
15 |       pretrained_model_name_or_path: openai-community/gpt2
16 |       special_tokens:
17 |         pad_token: <|pad|>
18 | 
19 |   # Parameters relevant for the model inference
20 |   inference_kw:
21 |     prompt: [Once upon a time, In a galaxy far far away]
22 |     token_lim: 128
23 |     use_top_k: 50
24 |     temperature: 1.
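    # NOTE: the inference_kw entries above mirror the arguments of xLSTM.generate
    # and are forwarded to it when sample completions are logged after validation.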
25 | 26 | optimizer: 27 | class_path: torch.optim.AdamW 28 | init_args: 29 | lr: 1e-4 30 | weight_decay: 0.01 31 | 32 | data: 33 | root: /path/to/data/root 34 | read_chunk: 4096 35 | batch_size: 64 36 | num_workers: 4 37 | tokenizer: 38 | class_path: xlstm.utils.TokenizerWrapper 39 | init_args: 40 | pretrained_model_name_or_path: openai-community/gpt2 41 | special_tokens: 42 | pad_token: <|pad|> 43 | 44 | trainer: 45 | max_epochs: 40 46 | accelerator: gpu 47 | devices: 1 48 | strategy: ddp_find_unused_parameters_false 49 | precision: 16-mixed 50 | log_every_n_steps: 1 51 | callbacks: 52 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 53 | init_args: 54 | monitor: val_loss 55 | save_last: true 56 | logger: 57 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 58 | init_args: 59 | save_dir: /path/to/save/dir 60 | name: llm-xlstm 61 | version: null 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.0 2 | PyYAML==6.0.1 3 | einops==0.8.0 4 | lightning==2.2.4 5 | setuptools==69.5.1 6 | transformers==4.40.2 -------------------------------------------------------------------------------- /res/loss.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myscience/x-lstm/b1635c99cd6105e96b94fdb5433b287cb1594cbf/res/loss.jpeg -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.cli import LightningCLI 2 | 3 | from xlstm import xLSTM 4 | from xlstm.stories import TinyStoriesLightning 5 | 6 | def cli_main(): 7 | ''' 8 | Main function for the training script. 9 | ''' 10 | 11 | # That's all it takes for LightningCLI to work! 12 | # No need to call .fit() or .test() or anything like that. 
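    # The subcommand (fit / validate / test / predict) and every remaining
    # option are parsed from the command line, e.g.:
    #   python run.py fit --config config/llm.yaml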
13 | cli = LightningCLI( 14 | xLSTM, 15 | TinyStoriesLightning, 16 | ) 17 | 18 | if __name__ == '__main__': 19 | cli_main() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='xlstm', 5 | version='0.1', 6 | packages=find_packages(), 7 | ) -------------------------------------------------------------------------------- /test/test_mlstm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from xlstm import mLSTM 5 | 6 | class TestMLSTM(unittest.TestCase): 7 | def setUp(self): 8 | self.inp_dim = 10 9 | self.head_dim = 8 10 | self.head_num = 4 11 | self.hid_dim = self.head_num * self.head_dim 12 | 13 | self.batch_size = 5 14 | 15 | # Create an instance of mLSTM 16 | self.model = mLSTM(self.inp_dim, self.head_num, self.head_dim) 17 | self.input = torch.randn(self.batch_size, self.inp_dim) 18 | 19 | self.hid_0 = self.model.init_hidden(self.batch_size) 20 | 21 | def test_forward(self): 22 | 23 | # Forward pass 24 | output, next_hid = self.model(self.input, self.hid_0) 25 | 26 | # Check if the output shape is correct 27 | self.assertEqual(output.shape, (self.batch_size, self.inp_dim)) 28 | 29 | self.assertEqual(next_hid[0].shape, (self.batch_size, self.head_num, self.head_dim, self.head_dim)) 30 | self.assertEqual(next_hid[1].shape, (self.batch_size, self.head_num, self.head_dim)) 31 | self.assertEqual(next_hid[2].shape, (self.batch_size, self.head_num)) 32 | 33 | def test_backward(self): 34 | criterion = torch.nn.MSELoss() 35 | 36 | # Forward pass 37 | target = torch.randn(self.batch_size, self.inp_dim) 38 | output, next_hid = self.model(self.input, self.hid_0) 39 | 40 | # Define target tensor 41 | target = torch.randn(self.batch_size, self.inp_dim) 42 | 43 | # Compute loss & backward pass 44 | loss = criterion(output, target) 45 | loss.backward() 46 | 47 | # Check if gradients are computed for all parameters 48 | for param in self.model.parameters(): 49 | self.assertIsNotNone(param.grad) 50 | 51 | if __name__ == '__main__': 52 | unittest.main() -------------------------------------------------------------------------------- /test/test_slstm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from xlstm import sLSTM 4 | 5 | class TestSLSTM(unittest.TestCase): 6 | def setUp(self): 7 | self.inp_dim = 10 8 | self.head_dim = 8 9 | self.head_num = 4 10 | self.hid_dim = self.head_num * self.head_dim 11 | 12 | self.batch_size = 5 13 | 14 | self.model = sLSTM(self.inp_dim, self.head_dim, self.head_num) 15 | self.input = torch.randn(self.batch_size, self.inp_dim) 16 | 17 | self.hid_0 = self.model.init_hidden(self.batch_size) 18 | 19 | def test_output_shape(self): 20 | output, _ = self.model(self.input, self.hid_0) 21 | 22 | self.assertEqual(output.shape, (self.batch_size, self.inp_dim)) 23 | 24 | def test_hidden_shape(self): 25 | hid = self.model.init_hidden(self.batch_size) 26 | self.assertEqual(len(hid), 4) 27 | 28 | self.assertEqual(hid[0].shape, (self.batch_size, self.hid_dim,)) 29 | self.assertEqual(hid[1].shape, (self.batch_size, self.hid_dim,)) 30 | self.assertEqual(hid[2].shape, (self.batch_size, self.hid_dim,)) 31 | self.assertEqual(hid[3].shape, (self.batch_size, self.hid_dim,)) 32 | 33 | def test_forward_no_conv(self): 34 | output, _ = 
self.model(self.input, self.hid_0) 35 | self.assertEqual(output.shape, (self.batch_size, self.inp_dim)) 36 | 37 | def test_forward_with_conv(self): 38 | output, _ = self.model(self.input, self.hid_0, use_conv=True) 39 | self.assertEqual(output.shape, (self.batch_size, self.inp_dim)) 40 | 41 | def test_backward(self): 42 | criterion = torch.nn.MSELoss() 43 | 44 | target = torch.randn(self.batch_size, self.inp_dim) 45 | output, _ = self.model(self.input, self.hid_0) 46 | 47 | loss = criterion(output, target) 48 | loss.backward() 49 | 50 | # Check if gradients are computed for all parameters 51 | # with the possible exception of the causal conv 52 | for name, param in self.model.named_parameters(): 53 | if 'causal_conv' in name: continue 54 | self.assertIsNotNone(param.grad) 55 | 56 | if __name__ == '__main__': 57 | unittest.main() -------------------------------------------------------------------------------- /test/test_stories.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import yaml 4 | from os import path 5 | from unittest.mock import Mock 6 | from torch import Tensor 7 | from transformers import AutoTokenizer 8 | 9 | from xlstm.stories import TinyStories, TinyStoriesLightning 10 | from xlstm.utils import TokenizerWrapper 11 | 12 | # Loading `local_settings.json` for custom local settings 13 | test_folder = path.dirname(path.abspath(__file__)) 14 | local_settings = path.join(test_folder, '.local.yaml') 15 | 16 | with open(local_settings, 'r') as f: 17 | local_settings = yaml.safe_load(f) 18 | 19 | class TestTinyStories(unittest.TestCase): 20 | 21 | def setUp(self): 22 | self.tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2') 23 | 24 | self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'}) 25 | 26 | 27 | def test_len(self): 28 | self.dataset = TinyStories( 29 | root=local_settings['tiny_stories_root'], 30 | tokenizer=self.tokenizer, 31 | max_length=256, 32 | data_split='train', 33 | read_chunk=4096 34 | ) 35 | 36 | self.assertEqual(len(self.dataset), 0) # Replace 0 with the expected length 37 | 38 | def test_iter(self): 39 | iterator = iter(self.dataset) 40 | inputs, labels = next(iterator) 41 | self.assertIsInstance(inputs, Tensor) 42 | self.assertIsInstance(labels, Tensor) 43 | # Add more assertions to validate the data returned by the iterator 44 | 45 | def tearDown(self): 46 | pass 47 | 48 | class TestTinyStoriesLightning(unittest.TestCase): 49 | 50 | def setUp(self): 51 | self.seq_len = 256 52 | self.batch_size = 16 53 | self.num_workers = 2 54 | 55 | wrapper = TokenizerWrapper( 56 | pretrained_model_name_or_path='openai-community/gpt2', 57 | special_tokens={'pad_token': '<|pad|>'} 58 | ) 59 | 60 | self.module = TinyStoriesLightning( 61 | tokenizer=wrapper, 62 | root=local_settings['tiny_stories_root'], 63 | max_length=self.seq_len, 64 | read_chunk=1024, 65 | batch_size=self.batch_size, 66 | num_workers=self.num_workers, 67 | ) 68 | 69 | self.module.setup('fit') 70 | 71 | def test_setup_fit(self): 72 | self.assertIsInstance(self.module.train_dataset, TinyStories) 73 | self.assertIsInstance(self.module.valid_dataset, TinyStories) 74 | 75 | # Add more assertions to validate the setup for the 'fit' stage 76 | # NOTE: We discard the first batch because it appears to be inconsistent 77 | # for multi-workers setup probably due to strange synchronizations 78 | trainset = iter(self.module.train_dataloader()) 79 | _ = next(trainset) 80 | batch = next(trainset) 81 | 82 | prev, post = batch 83 | 84 | 
self.assertEqual(prev.shape, (self.batch_size, self.seq_len)) 85 | self.assertEqual(post.shape, (self.batch_size, self.seq_len)) 86 | 87 | # Add more assertions to validate the setup for the 'fit' stage 88 | valset = iter(self.module.val_dataloader()) 89 | _ = next(valset) 90 | batch = next(valset) 91 | 92 | prev, post = batch 93 | 94 | self.assertEqual(prev.shape, (self.batch_size, self.seq_len)) 95 | self.assertEqual(post.shape, (self.batch_size, self.seq_len)) 96 | 97 | def test_setup_test(self): 98 | self.module.setup('test') 99 | 100 | self.assertIsInstance(self.module.test__dataset, TinyStories) 101 | 102 | batch = next(iter(self.module.test_dataloader())) 103 | 104 | prev, post = batch 105 | 106 | self.assertEqual(prev.shape, (self.batch_size, self.seq_len)) 107 | self.assertEqual(post.shape, (self.batch_size, self.seq_len)) 108 | 109 | def tearDown(self): 110 | pass 111 | 112 | if __name__ == '__main__': 113 | unittest.main() -------------------------------------------------------------------------------- /test/test_xlstm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import yaml 4 | import torch 5 | from os import path 6 | 7 | from transformers import AutoTokenizer 8 | 9 | from xlstm import xLSTM, mLSTM, sLSTM 10 | from xlstm.stories import TinyStoriesLightning 11 | from xlstm.utils import default_iterdata_worker_init 12 | 13 | # Loading `local_settings.json` for custom local settings 14 | test_folder = path.dirname(path.abspath(__file__)) 15 | local_settings = path.join(test_folder, '.local.yaml') 16 | 17 | class TestXLSTM(unittest.TestCase): 18 | def setUp(self): 19 | self.num_layers = 8 20 | self.signature = (7, 1) 21 | self.inp_dim = 16 22 | self.head_dim = 8 23 | self.head_num = 4 24 | self.ker_size = 4 25 | self.p_factor = (2, 4/3) 26 | 27 | self.seq_len = 32 28 | self.batch_size = 4 29 | self.vocab_size = 24 30 | 31 | # Mockup input for example purposes 32 | self.seq = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) 33 | 34 | def test_llm_init(self): 35 | LLM = xLSTM( 36 | vocab_size = self.vocab_size, 37 | num_layers = self.num_layers, 38 | signature = self.signature, 39 | inp_dim= self.inp_dim, 40 | head_dim= self.head_dim, 41 | head_num= self.head_num, 42 | p_factor= self.p_factor, 43 | ker_size = self.ker_size, 44 | ) 45 | 46 | m_num, s_num = self.signature 47 | which = [True] * m_num + [False] * s_num 48 | 49 | for layer, kind in zip(LLM.llm, which): 50 | self.assertIsInstance(layer, mLSTM if kind else sLSTM) 51 | 52 | def test_llm_forward(self): 53 | 54 | xlstm = xLSTM( 55 | vocab_size = self.vocab_size, 56 | num_layers = self.num_layers, 57 | signature = self.signature, 58 | inp_dim= self.inp_dim, 59 | head_dim= self.head_dim, 60 | head_num= self.head_num, 61 | p_factor= self.p_factor, 62 | ker_size = self.ker_size, 63 | ) 64 | 65 | 66 | # Compute the output using the xLSTM architecture 67 | out, _ = xlstm.forward(self.seq, batch_first=True) 68 | 69 | self.assertEqual(out.shape, (self.batch_size, self.seq_len, self.vocab_size)) 70 | 71 | def test_llm_dataloader(self): 72 | 73 | # Get the local path to tiny stories 74 | with open(local_settings, 'r') as f: 75 | root = yaml.safe_load(f)['tiny_stories_path'] 76 | 77 | # Get an off-the-shelf tokenizer 78 | tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2') 79 | 80 | vocab_size = tokenizer.vocab_size 81 | 82 | xlstm = xLSTM( 83 | vocab_size = vocab_size, 84 | num_layers = self.num_layers, 85 | signature = self.signature, 
86 | inp_dim= self.inp_dim, 87 | head_dim= self.head_dim, 88 | head_num= self.head_num, 89 | p_factor= self.p_factor, 90 | ker_size = self.ker_size, 91 | ) 92 | 93 | loader = TinyStoriesLightning( 94 | root, 95 | tokenizer, 96 | max_length=self.seq_len, 97 | batch_size=self.batch_size, 98 | worker_init_fn=default_iterdata_worker_init, 99 | ) 100 | 101 | loader.setup(stage='fit') 102 | batch = next(iter(loader.train_dataloader())) 103 | 104 | prev, post = batch 105 | 106 | logits, _ = xlstm(prev) 107 | 108 | loss = xlstm.compute_loss(prev, post) 109 | 110 | self.assertTrue((loss >= 0).all()) 111 | self.assertEqual(logits.shape, (*prev.shape, vocab_size)) 112 | 113 | def test_llm_generate(self): 114 | # Get an off-the-shelf tokenizer 115 | tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2') 116 | 117 | tokenizer.add_special_tokens({'pad_token': '<|pad|>'}) 118 | 119 | vocab_size = tokenizer.vocab_size + 1 120 | token_lim = 16 121 | 122 | model = xLSTM( 123 | vocab_size = vocab_size, 124 | num_layers = self.num_layers, 125 | signature = self.signature, 126 | inp_dim= self.inp_dim, 127 | head_dim= self.head_dim, 128 | head_num= self.head_num, 129 | p_factor= self.p_factor, 130 | ker_size = self.ker_size, 131 | ) 132 | 133 | # Generate text 134 | gen = model.generate( 135 | prompt=[ 136 | 'Once upon a time', 137 | 'In a galaxy far far away', 138 | ], 139 | tokenizer=tokenizer, 140 | token_lim=token_lim, 141 | ) 142 | 143 | for tok in gen: 144 | print(tok[0], end='') 145 | 146 | self.assertTrue(True) 147 | 148 | if __name__ == '__main__': 149 | unittest.main() -------------------------------------------------------------------------------- /xlstm/__init__.py: -------------------------------------------------------------------------------- 1 | from xlstm.llm import xLSTM 2 | from xlstm.lstm import sLSTM, mLSTM -------------------------------------------------------------------------------- /xlstm/data.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from abc import abstractmethod 3 | 4 | from torch import Tensor 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data import IterableDataset 7 | 8 | from lightning import LightningDataModule 9 | 10 | from typing import Callable 11 | 12 | from .utils import default 13 | from .utils import default_iterdata_worker_init 14 | 15 | class LightningDataset(LightningDataModule): 16 | ''' 17 | Abstract Lightning Data Module that represents a dataset we 18 | can train a Lightning module on. 19 | ''' 20 | 21 | @classmethod 22 | def from_config(cls, conf_path : str, *args, key : str = 'dataset') -> 'LightningDataset': 23 | ''' 24 | Construct a Lightning DataModule from a configuration file. 
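
        Args:
            conf_path (str): Path to the YAML configuration file.
            *args: Positional arguments passed through to the data module
                constructor (e.g. a tokenizer wrapper), ahead of the values
                read from the configuration file.
            key (str, optional): Top-level key of the YAML file that holds
                the dataset settings. Defaults to 'dataset'.

        Returns:
            LightningDataset: The data module initialized from the parsed settings.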
25 | ''' 26 | 27 | with open(conf_path, 'r') as f: 28 | conf = yaml.safe_load(f) 29 | 30 | data_conf = conf[key] 31 | 32 | return cls( 33 | *args, 34 | **data_conf, 35 | ) 36 | 37 | def __init__( 38 | self, 39 | *args, 40 | batch_size : int = 16, 41 | num_workers : int = 0, 42 | train_shuffle : bool | None = None, 43 | val_shuffle : bool | None = None, 44 | val_batch_size : None | int = None, 45 | worker_init_fn : None | Callable = None, 46 | collate_fn : None | Callable = None, 47 | train_sampler : None | Callable = None, 48 | val_sampler : None | Callable = None, 49 | test_sampler : None | Callable = None, 50 | ) -> None: 51 | super().__init__() 52 | 53 | self.train_dataset = None 54 | self.valid_dataset = None 55 | self.test__dataset = None 56 | 57 | val_batch_size = default(val_batch_size, batch_size) 58 | 59 | self.num_workers = num_workers 60 | self.batch_size = batch_size 61 | self.train_shuffle = train_shuffle 62 | self.val_shuffle = val_shuffle 63 | self.train_sampler = train_sampler 64 | self.valid_sampler = val_sampler 65 | self.test__sampler = test_sampler 66 | self.collate_fn = collate_fn 67 | self.worker_init_fn = worker_init_fn 68 | self.val_batch_size = val_batch_size 69 | 70 | @abstractmethod 71 | def setup(self, stage: str) -> None: 72 | msg = \ 73 | ''' 74 | This is an abstract datamodule class. You should use one of 75 | the concrete subclasses that represents an actual dataset. 76 | ''' 77 | 78 | raise NotImplementedError(msg) 79 | 80 | def train_dataloader(self) -> DataLoader: 81 | if isinstance(self.train_dataset, IterableDataset): 82 | worker_init_fn = default(self.worker_init_fn, default_iterdata_worker_init) 83 | 84 | return DataLoader( 85 | self.train_dataset, # type: ignore 86 | sampler = self.train_sampler, # type: ignore 87 | batch_size = self.batch_size, 88 | shuffle = self.train_shuffle, 89 | collate_fn = self.collate_fn, 90 | num_workers = self.num_workers, 91 | worker_init_fn = worker_init_fn, 92 | ) 93 | 94 | def val_dataloader(self) -> DataLoader: 95 | if isinstance(self.train_dataset, IterableDataset): 96 | worker_init_fn = default(self.worker_init_fn, default_iterdata_worker_init) 97 | 98 | return DataLoader( 99 | self.valid_dataset, # type: ignore 100 | sampler = self.valid_sampler, # type: ignore 101 | batch_size = self.val_batch_size, 102 | shuffle = self.val_shuffle, 103 | collate_fn = self.collate_fn, 104 | num_workers = self.num_workers, 105 | worker_init_fn = worker_init_fn, 106 | ) 107 | 108 | def test_dataloader(self) -> DataLoader: 109 | if isinstance(self.train_dataset, IterableDataset): 110 | worker_init_fn = default(self.worker_init_fn, default_iterdata_worker_init) 111 | 112 | return DataLoader( 113 | self.test__dataset, # type: ignore 114 | sampler = self.test__sampler, # type: ignore 115 | batch_size = self.val_batch_size, 116 | shuffle = self.val_shuffle, 117 | collate_fn = self.collate_fn, 118 | num_workers = self.num_workers, 119 | worker_init_fn = worker_init_fn, 120 | ) -------------------------------------------------------------------------------- /xlstm/llm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from warnings import warn 4 | from lightning import LightningModule 5 | 6 | from torch import Tensor 7 | from torch.optim import AdamW 8 | from torch.optim import Optimizer 9 | from torch.nn .functional import softmax 10 | from torch.nn.functional import cross_entropy 11 | 12 | from transformers import PreTrainedTokenizerBase 13 | from typing 
import Any, Dict, Generator, List, Tuple, Callable, Iterable 14 | 15 | from itertools import repeat 16 | from einops import rearrange 17 | 18 | 19 | from .lstm import sLSTM 20 | from .lstm import mLSTM 21 | from .utils import Hidden 22 | from .utils import default 23 | from .utils import TokenizerWrapper 24 | 25 | OptimizerCallable = Callable[[Iterable], Optimizer] 26 | 27 | class xLSTM(LightningModule): 28 | '''The extended Long Short Term Memory (xLSTM) module as 29 | originally introduced in Beck et al. (2024)] see: 30 | (https://arxiv.org/abs/2405.04517). 31 | 32 | This model stacks sLSTM and mLSTM modules with residual 33 | connections and offers superior memory and performance 34 | compared to the standard LSTM model, achieving competitive 35 | or better performance and scaling than Transformer models 36 | or State-Space models. 37 | ''' 38 | 39 | def __init__( 40 | self, 41 | vocab_size : int, 42 | num_layers : int, 43 | signature : Tuple[int, int], 44 | inp_dim : int, 45 | head_dim : int, 46 | head_num : int, 47 | p_factor : Tuple[float, float] = (2, 4/3), 48 | ker_size : int = 4, 49 | optimizer : OptimizerCallable = AdamW, 50 | tokenizer: TokenizerWrapper | None = None, 51 | inference_kw: Dict[str, Any] = {} 52 | ) -> None: 53 | '''Initialize the LLM model. 54 | 55 | Args: 56 | vocab_size (int): The size of the vocabulary. 57 | num_layers (int): The number of layers in the LLM model. 58 | signature (Tuple[int, int]): The signature of the LLM model, 59 | which represents the ration of the mLSTM-to-sLSTM blocks. 60 | inp_dim (int): The dimension of the input tokens. 61 | head_dim (int): The dimension of each attention head. 62 | head_num (int): The number of attention heads. 63 | p_factor (Tuple[float, float], optional): The expansion factor 64 | for the MLP projection in the m|s-LSTM blocks. Defaults to (2, 4/3). 65 | ker_size (int, optional): The kernel size for the causal convolutional layers. 66 | Defaults to 4. 67 | 68 | kwargs: Additional keyword arguments used at inference time (see relevant 69 | arguments of the generate method). 70 | ''' 71 | super().__init__() 72 | 73 | self.optimizer = optimizer 74 | self.inference_kw = inference_kw 75 | self.tokenizer = None if tokenizer is None else tokenizer.get_tokenizer() 76 | 77 | num_embeddings = vocab_size if tokenizer is None else\ 78 | self.tokenizer.vocab_size + len(self.tokenizer.added_tokens_decoder) 79 | 80 | if num_embeddings != vocab_size: 81 | warn('Tokenizer detected. Using tokenizer vocabulary size. 
Vocabulary size will be ignored.') 82 | 83 | # Needed embedding layer for mapping input tokens to the network 84 | self.embedding = nn.Embedding( 85 | num_embeddings=num_embeddings, 86 | embedding_dim=inp_dim, 87 | ) 88 | 89 | m_factor, s_factor = p_factor 90 | 91 | mlstm_par = { 92 | 'inp_dim' : inp_dim, 93 | 'head_dim' : head_dim, 94 | 'head_num' : head_num, 95 | 'p_factor' : m_factor, 96 | 'ker_size' : ker_size, 97 | } 98 | 99 | slstm_par = { 100 | 'inp_dim' : inp_dim, 101 | 'head_dim' : head_dim, 102 | 'head_num' : head_num, 103 | 'p_factor' : s_factor, 104 | 'ker_size' : ker_size, 105 | } 106 | 107 | m_num, s_num = signature 108 | which = [True] * m_num + [False] * s_num 109 | 110 | self.llm : List[mLSTM | sLSTM] = nn.ModuleList([ 111 | mLSTM(**mlstm_par) if v else sLSTM(**slstm_par) 112 | for w in repeat(which, num_layers) for v in w 113 | ]) 114 | 115 | # Prediction head to map the output of the xLSTM model to the vocabulary 116 | self.head = nn.Linear(inp_dim, vocab_size, bias=False) 117 | 118 | self.save_hyperparameters() 119 | 120 | def forward( 121 | self, 122 | tok: Tensor, 123 | hid: Hidden | None = None, 124 | batch_first : bool = False, 125 | ) -> Tuple[Tensor, Hidden]: 126 | '''Forward pass of the xLSTM model. 127 | 128 | Args: 129 | tok (Tensor): Input tensor representing the sequence tokens. 130 | Expected shape: (batch, seq_len) if batch_first=True, 131 | else (seq_len, batch). 132 | hid (Hidden, optional): Cache object for storing intermediate hidden 133 | values of the m|s-LSTM blocks of the model. If None, the hidden 134 | states are initialized by the models. Defaults to None. 135 | 136 | Returns: 137 | Tuple[Tensor, Hidden]: Returns tensor of predicted logits of shape 138 | (batch, seq_len, vocab_size) if batch_first=True or of shape 139 | (seq_len, batch, vocab_size) if batch_first=False, and the 140 | updated hidden model states. 
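
        Example (a minimal sketch, with the shapes used in test/test_xlstm.py;
        ``model`` is assumed to be an xLSTM built with vocab_size=24):
            >>> tok = torch.randint(0, 24, (4, 32))         # (batch, seq_len)
            >>> logits, hid = model(tok, batch_first=True)  # logits: (4, 32, 24)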
141 | ''' 142 | 143 | tok : Tensor = torch.atleast_2d(tok) 144 | seq : Tensor = self.embedding(tok) 145 | 146 | if batch_first: seq = rearrange(seq, 'b s i -> s b i') 147 | if hid is None: hid = [l.init_hidden(seq.shape[1]) for l in self.llm] 148 | 149 | # Pass the sequence through the mLSTM and sLSTM blocks 150 | out = [] 151 | for inp in seq: 152 | # Compute model output and update the hidden states 153 | for i, lstm in enumerate(self.llm): 154 | inp, hid[i] = lstm(inp, hid[i]) 155 | 156 | out.append(inp) 157 | 158 | out = torch.stack(out, dim=1 if batch_first else 0) 159 | out = self.head(out) 160 | 161 | return out, hid 162 | 163 | @torch.no_grad() 164 | def generate( 165 | self, 166 | prompt : str | List[str], 167 | token_lim : int = 300, 168 | use_top_k : int = 50, 169 | tokenizer : PreTrainedTokenizerBase | None = None, 170 | temperature : float = 1.0, 171 | ) -> Generator[Dict[int, str], None, None]: 172 | # Set model in evaluation model for inference 173 | self.eval() 174 | 175 | tokenizer = default(tokenizer, self.tokenizer) 176 | if tokenizer is None: raise ValueError('Tokenizer not available.') 177 | 178 | if isinstance(prompt, str): 179 | prompt = [prompt] 180 | 181 | # Encode the prompt using the tokenizer 182 | inp = tokenizer( 183 | prompt, 184 | return_tensors='pt', 185 | padding=True, 186 | truncation=True, 187 | ).input_ids.to(self.device) 188 | 189 | batch_size, inp_len = inp.shape 190 | vocab_size = tokenizer.vocab_size + len(tokenizer.added_tokens_decoder) # type: ignore 191 | 192 | # Consume the prompt to get the hidden states 193 | logits, hid = self(inp, batch_first=True) 194 | 195 | # Start generating the output sequence until either the maximum 196 | # token limit is reach or the model generates the<|endoftext|> token 197 | num_tokes = 0 198 | out, pred = [inp], inp[:, -1] 199 | pidx = torch.arange(batch_size, device=self.device) 200 | 201 | yield {int(pid) : tokenizer.decode(raw, skip_special_tokens=True) for pid, raw in zip(pidx, inp)} 202 | 203 | while num_tokes < token_lim and len(pred): 204 | logits, hid = self(pred, hid) 205 | 206 | # Get the token with the highest probability by zeroing out 207 | # the probability of the lowest probability tokens 208 | prob = softmax(logits[-1] / temperature, dim=-1) 209 | idxs = prob.topk(k=vocab_size - use_top_k, largest=False, sorted=False).indices 210 | prob.scatter_(dim=-1, index=idxs, src=torch.zeros_like(prob)) 211 | prob /= prob.sum(dim=-1, keepdim=True) 212 | 213 | # Sample the next token from the distribution modelled by the llm 214 | pred = torch.multinomial(prob, num_samples=1, replacement=True).squeeze() 215 | 216 | # Append the token to the input sequence 217 | out.append(pred) 218 | 219 | num_tokes += 1 220 | 221 | # Drop from the batch every prediction that reached the <|endoftext|> token 222 | mask = pred != tokenizer.eos_token_id 223 | 224 | pred = pred[mask] 225 | pidx = pidx[mask] 226 | hid = [[val[mask] for val in layer] for layer in hid] 227 | 228 | # Yield the decoded tokens 229 | yield {int(pid) : tokenizer.decode(raw, skip_special_tokens=True) for pid, raw in zip(pidx, pred)} 230 | 231 | self.train() 232 | 233 | def compute_loss(self, prev : Tensor, post : Tensor) -> Tensor: 234 | '''Compute the cross-entropy loss between the predicted (logits) and 235 | the actual next token, for all tokens in the batch. 236 | 237 | Args: 238 | prev (Tensor): The tensor containing the previous tokens. 239 | Expected shape: (batch, seq_len). 240 | post (Tensor): The tensor containing the next tokens, i.e. 
241 | the targets. Expected shape: (batch, seq_len). 242 | 243 | Returns: 244 | Tensor: The computed loss between the predicted tokens and the target tokens. 245 | ''' 246 | # Compute model predictions (logits) for the next tokens based 247 | # on the previous tokens 248 | pred, _ = self(prev) 249 | 250 | pred = rearrange(pred, 'b s v -> (b s) v') 251 | post = rearrange(post, 'b s -> (b s)') 252 | 253 | # Compute the loss using the cross entropy loss 254 | loss = cross_entropy(pred, post) 255 | 256 | return loss 257 | 258 | def training_step(self, batch : Tuple[Tensor, Tensor], batch_idx : int) -> Tensor: 259 | prev_tok, next_tok = batch 260 | 261 | loss = self.compute_loss(prev_tok, next_tok) 262 | 263 | self.log_dict( 264 | {'train_loss' : loss}, 265 | logger=True, 266 | on_step=True, 267 | sync_dist=True 268 | ) 269 | 270 | return loss 271 | 272 | def validation_step(self, batch : Tensor, batch_idx : int) -> Tensor: 273 | prev_tok, next_tok = batch 274 | 275 | loss = self.compute_loss(prev_tok, next_tok) 276 | 277 | self.log_dict( 278 | {'val_loss' : loss}, 279 | logger=True, 280 | on_step=True, 281 | sync_dist=True 282 | ) 283 | 284 | return loss 285 | 286 | def on_validation_end(self) -> None: 287 | # No need to generate text if the tokenizer is not available 288 | if self.tokenizer is None: return 289 | 290 | inference_kw = { 291 | 'prompt' : 'Once upon a time', 292 | 'tokenizer' : self.tokenizer, 293 | **self.inference_kw 294 | } 295 | 296 | # Generate the model output on the given prompt 297 | output = list( # List needed to consume the generator 298 | self.generate( 299 | **inference_kw 300 | ) 301 | ) 302 | 303 | # Assemble the outputs based on the batch id 304 | pids = list(output[0].keys()) 305 | output = {pid : ''.join([out[pid] for out in output]) for pid in pids} 306 | 307 | for pid, text in output.items(): 308 | self.logger.experiment.add_text( 309 | f'Prompt ID:{pid}', 310 | text, 311 | global_step=self.global_step, 312 | ) 313 | 314 | def configure_optimizers(self) -> Optimizer: 315 | optim = self.optimizer( 316 | self.parameters(), 317 | ) 318 | 319 | return optim 320 | -------------------------------------------------------------------------------- /xlstm/lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from math import sqrt 5 | from torch import exp 6 | from torch import tanh 7 | from torch import sigmoid 8 | from einops import einsum, rearrange 9 | 10 | from torch import Tensor 11 | from typing import Tuple 12 | from torch.nn.functional import silu 13 | from torch.nn.functional import gelu 14 | 15 | from .utils import enlarge_as 16 | from .utils import BlockLinear 17 | from .utils import CausalConv1d 18 | 19 | class sLSTM(nn.Module): 20 | '''The scalar-Long Short Term Memory (sLSTM) module as 21 | originally introduced in Beck et al. (2024)] see: 22 | (https://arxiv.org/abs/2405.04517). 23 | 24 | This model is a variant of the standard LSTM model and 25 | offers two major improvements: 26 | - Exponential gating with appropriate state normalization 27 | to avoid overflows induced by the exponential function. 28 | - A new memory mixing within heads but not across heads. 
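
    Example (a minimal sketch, mirroring the usage shown in the README):
        >>> lstm = sLSTM(inp_dim=16, head_dim=8, head_num=4)
        >>> hid = lstm.init_hidden(4)                  # batch size of 4
        >>> out, hid = lstm(torch.randn(4, 16), hid)   # out: (4, 16)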
29 | ''' 30 | 31 | def __init__( 32 | self, 33 | inp_dim : int, 34 | head_dim : int, 35 | head_num : int, 36 | ker_size : int = 4, 37 | p_factor : float = 4/3, 38 | ) -> None: 39 | super().__init__() 40 | 41 | self.inp_dim = inp_dim 42 | self.head_dim = head_dim 43 | self.head_num = head_num 44 | 45 | self.inp_norm = nn.LayerNorm(inp_dim) 46 | self.hid_norm = nn.GroupNorm(head_num, head_dim * head_num) 47 | 48 | self.causal_conv = CausalConv1d(1, 1, kernel_size=ker_size) 49 | 50 | self.W_z = nn.Linear(inp_dim, head_num * head_dim) 51 | self.W_i = nn.Linear(inp_dim, head_num * head_dim) 52 | self.W_o = nn.Linear(inp_dim, head_num * head_dim) 53 | self.W_f = nn.Linear(inp_dim, head_num * head_dim) 54 | 55 | self.R_z = BlockLinear([(head_dim, head_dim)] * head_num) 56 | self.R_i = BlockLinear([(head_dim, head_dim)] * head_num) 57 | self.R_o = BlockLinear([(head_dim, head_dim)] * head_num) 58 | self.R_f = BlockLinear([(head_dim, head_dim)] * head_num) 59 | 60 | # NOTE: The factor of two in the output dimension of the up_proj 61 | # is due to the fact that the output needs to branch into two 62 | # separate outputs to account for the the gated GeLU connection. 63 | # See Fig. 9 in the paper. 64 | proj_dim = int(p_factor * head_num * head_dim) 65 | self.up_proj = nn.Linear(head_num * head_dim, 2 * proj_dim) 66 | self.down_proj = nn.Linear(proj_dim, inp_dim) 67 | 68 | @property 69 | def device(self) -> str: 70 | '''Get the device of the model. 71 | 72 | Returns: 73 | str: The device of the model. 74 | ''' 75 | return next(self.parameters()).device 76 | 77 | def init_hidden(self, bs : int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 78 | '''Initialize the hidden state of the sLSTM model. 79 | 80 | Args: 81 | batch_size (int): The batch size of the input sequence. 82 | 83 | Returns: 84 | Tuple[Tensor, Tensor, Tensor, Tensor]: The hidden state tuple containing the cell state, 85 | normalizer state, hidden state, and stabilizer state. 86 | ''' 87 | 88 | n_0 = torch.ones (bs, self.head_num * self.head_dim, device=self.device) 89 | c_0 = torch.zeros(bs, self.head_num * self.head_dim, device=self.device) 90 | h_0 = torch.zeros(bs, self.head_num * self.head_dim, device=self.device) 91 | m_0 = torch.zeros(bs, self.head_num * self.head_dim, device=self.device) 92 | 93 | return c_0, n_0, h_0, m_0 94 | 95 | def forward( 96 | self, 97 | seq: Tensor, 98 | hid: Tuple[Tensor, Tensor, Tensor, Tensor], 99 | use_conv : bool = False, 100 | ) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]: 101 | '''Forward pass of the sLSTM model. 102 | 103 | Args: 104 | seq (Tensor): The input sequence tensor of shape (batch_size, input_dim). 105 | hid (Tuple[Tensor, Tensor, Tensor, Tensor]): The hidden state tuple containing the cell state, 106 | normalizer state, hidden state, and stabilizer state. 107 | 108 | Returns: 109 | Tuple[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]: The output tensor with the residual 110 | connection and the newly updated hidden state tuple. 111 | ''' 112 | 113 | b, d = seq.shape 114 | 115 | # Separate the hidden (previous) state into the cell state, 116 | # the normalizer state, the hidden state, and the stabilizer state. 117 | c_tm1, n_tm1, h_tm1, m_tm1 = hid 118 | 119 | x_t : Tensor = self.inp_norm(seq) 120 | 121 | # Optional causal convolution block for the input 122 | # and forget gates. See Fig. 9 in the paper. 123 | if use_conv: 124 | # FIXME: The causal conv branch is broken. 
125 | x_c = self.causal_conv(x_t) 126 | x_c = silu(x_c).squeeze() 127 | else: 128 | x_c = x_t 129 | 130 | # Project the input to the different heads for all 131 | # the gates. 132 | # NOTE: For input (i) and forget (f) inputs we use 133 | # the output of the causal conv. See Fig. 9 in the paper. 134 | i_t: Tensor = self.W_i(x_c) + self.R_i(h_tm1) 135 | f_t: Tensor = self.W_f(x_c) + self.R_f(h_tm1) 136 | z_t: Tensor = self.W_z(x_t) + self.R_z(h_tm1) 137 | o_t: Tensor = self.W_o(x_t) + self.R_o(h_tm1) 138 | 139 | # Compute the gated outputs for the newly computed inputs 140 | m_t = torch.max(f_t + m_tm1, i_t) 141 | 142 | i_t = exp(i_t - m_t) # Eq. (16) in ref. paper | or Eq. (38) in supp. mat. 143 | f_t = exp(f_t - m_t + m_tm1) # Eq. (17) in ref. paper | or Eq. (39) in supp. mat. 144 | 145 | z_t = tanh(z_t) # Eq. (11) in ref. paper 146 | o_t = sigmoid(o_t) # Eq. (14) in ref. paper 147 | 148 | # Update the internal states of the model 149 | c_t = f_t * c_tm1 + i_t * z_t # Eq. (8) in ref. paper 150 | n_t = f_t * n_tm1 + i_t # Eq. (9) in ref. paper 151 | h_t = o_t * (c_t / n_t) # Eq. (10) in ref. paper 152 | 153 | # Compute the output of the LSTM block 154 | out = self.hid_norm(h_t) 155 | 156 | # Perform up-and-down projection of the output with 157 | # projection factor 4/3. See Fig. (9) in supp. mat. 158 | out1, out2 = self.up_proj(out).chunk(2, dim=-1) 159 | 160 | out = out1 + gelu(out2) 161 | out = self.down_proj(out) 162 | 163 | # Return output with the residual connection and the 164 | # newly updated hidden state. 165 | return out + seq, (c_t, n_t, h_t, m_t) 166 | 167 | class mLSTM(nn.Module): 168 | '''The matrix-Long Short Term Memory (mLSTM) module as 169 | originally introduced in Beck et al. (2024)] see: 170 | (https://arxiv.org/abs/2405.04517). 171 | 172 | This model is a variant of the standard LSTM model and 173 | offers superior memory due to its storing values in a 174 | matrix instead of a scalar. It is fully parallelizable 175 | and updates internal memory with the covariance rule. 176 | ''' 177 | 178 | def __init__( 179 | self, 180 | inp_dim : int, 181 | head_num : int, 182 | head_dim : int, 183 | p_factor : int = 2, 184 | ker_size : int = 4, 185 | ) -> None: 186 | super().__init__() 187 | 188 | self.inp_dim = inp_dim 189 | self.head_num = head_num 190 | self.head_dim = head_dim 191 | 192 | hid_dim = head_num * head_dim 193 | 194 | self.inp_norm = nn.LayerNorm(inp_dim) 195 | self.hid_norm = nn.GroupNorm(head_num, hid_dim) 196 | 197 | # NOTE: The factor of two in the output dimension of the up_proj 198 | # is due to the fact that the output needs to branch into two 199 | self.up_l_proj = nn.Linear(inp_dim, int(p_factor * inp_dim)) 200 | self.up_r_proj = nn.Linear(inp_dim, hid_dim) 201 | self.down_proj = nn.Linear(hid_dim, inp_dim) 202 | 203 | self.causal_conv = CausalConv1d(1, 1, kernel_size=ker_size) 204 | 205 | self.skip = nn.Conv1d(int(p_factor * inp_dim), hid_dim, kernel_size=1, bias=False) 206 | 207 | self.W_i = nn.Linear(int(p_factor * inp_dim), head_num) 208 | self.W_f = nn.Linear(int(p_factor * inp_dim), head_num) 209 | self.W_o = nn.Linear(int(p_factor * inp_dim), hid_dim) 210 | 211 | self.W_q = nn.Linear(int(p_factor * inp_dim), hid_dim) 212 | self.W_k = nn.Linear(int(p_factor * inp_dim), hid_dim) 213 | self.W_v = nn.Linear(int(p_factor * inp_dim), hid_dim) 214 | 215 | @property 216 | def device(self) -> str: 217 | '''Get the device of the model. 218 | 219 | Returns: 220 | str: The device of the model. 
221 | ''' 222 | return next(self.parameters()).device 223 | 224 | def init_hidden(self, bs : int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 225 | '''Initialize the hidden state of the sLSTM model. 226 | 227 | Args: 228 | batch_size (int): The batch size of the input sequence. 229 | 230 | Returns: 231 | Tuple[Tensor, Tensor, Tensor, Tensor]: The hidden state tuple containing the cell state, 232 | normalizer state, hidden state, and stabilizer state. 233 | ''' 234 | 235 | c_0 = torch.zeros(bs, self.head_num, self.head_dim, self.head_dim, device=self.device) 236 | n_0 = torch.ones (bs, self.head_num, self.head_dim , device=self.device) 237 | m_0 = torch.zeros(bs, self.head_num , device=self.device) 238 | 239 | return c_0, n_0, m_0 240 | 241 | def forward( 242 | self, 243 | seq: Tensor, 244 | hid: Tuple[Tensor, Tensor], 245 | ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: 246 | '''_summary_ 247 | 248 | Args: 249 | seq (Tensor): _description_ 250 | hid (Tuple[Tensor, Tensor]): _description_ 251 | 252 | Returns: 253 | Tuple[Tensor, Tuple[Tensor, Tensor]]: _description_ 254 | ''' 255 | 256 | # Separate the hidden (previous) state into the cell state, 257 | # the normalizer state, the hidden state, and the stabilizer state. 258 | c_tm1, n_tm1, m_tm1 = hid 259 | 260 | x_n : Tensor = self.inp_norm(seq) # shape: b i 261 | 262 | x_t = self.up_l_proj(x_n) # shape: b (i * p_factor) 263 | r_t = self.up_r_proj(x_n) # shape: b (h d) 264 | 265 | # Compute the causal convolutional input (to be 266 | # used for the query and key gates) 267 | x_c = self.causal_conv(x_t) # shape: b 1 (i * p_factor) 268 | x_c = rearrange(silu(x_c), 'b ... -> b (...)') # shape: b (i * p_factor) 269 | 270 | q_t = rearrange(self.W_q(x_c), 'b (h d) -> b h d', h=self.head_num) 271 | k_t = rearrange(self.W_k(x_c), 'b (h d) -> b h d', h=self.head_num) / sqrt(self.head_dim) 272 | v_t = rearrange(self.W_v(x_t), 'b (h d) -> b h d', h=self.head_num) 273 | 274 | i_t: Tensor = self.W_i(x_c) # shape: b h 275 | f_t: Tensor = self.W_f(x_c) # shape: b h 276 | o_t: Tensor = self.W_o(x_t) # shape: b (h d) 277 | 278 | # Compute the gated outputs for the newly computed inputs 279 | m_t = torch.max(f_t + m_tm1, i_t) 280 | 281 | i_t = exp(i_t - m_t) # Eq. (25) in ref. paper 282 | f_t = exp(f_t - m_t + m_tm1) # Eq. (26) in ref. paper 283 | o_t = sigmoid(o_t) # Eq. (27) in ref. paper 284 | 285 | # Update the internal states of the model 286 | c_t = enlarge_as(f_t, c_tm1) * c_tm1 + enlarge_as(i_t, c_tm1) * einsum(v_t, k_t, 'b h d, b h p -> b h d p') 287 | n_t = enlarge_as(f_t, n_tm1) * n_tm1 + enlarge_as(i_t, k_t) * k_t 288 | h_t = o_t * rearrange( 289 | einsum(c_t, q_t, 'b h d p, b h p -> b h d') / 290 | einsum(n_t, q_t, 'b h d, b h d -> b h').clamp(min=1).unsqueeze(-1), 291 | 'b h d -> b (h d)' 292 | ) # Eq. (21) in ref. paper 293 | 294 | x_c = rearrange(x_c, 'b i -> b i 1') 295 | out = self.hid_norm(h_t) + self.skip(x_c).squeeze() # shape: b (h d) 296 | out = out * silu(r_t) # shape: b (h d) 297 | out = self.down_proj(out) # shape: h i 298 | 299 | # Return output with the residual connection and the 300 | # newly updated hidden state. 
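        # Hidden-state shapes (cf. init_hidden): c_t is (batch, head_num, head_dim, head_dim),
        # n_t is (batch, head_num, head_dim) and m_t is (batch, head_num).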
301 | return out + seq, (c_t, n_t, m_t) -------------------------------------------------------------------------------- /xlstm/stories.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from os import path 3 | from torch import Tensor 4 | from torch.utils.data import IterableDataset 5 | 6 | from typing import Literal, Tuple, Generator 7 | from transformers import PreTrainedTokenizerBase 8 | 9 | from .data import LightningDataset 10 | from .utils import TokenizerWrapper 11 | 12 | class TinyStories(IterableDataset): 13 | 14 | def __init__( 15 | self, 16 | root : str, 17 | tokenizer : PreTrainedTokenizerBase, 18 | max_length : int = 256, 19 | data_split : Literal['train', 'valid', 'test'] = 'train', 20 | read_chunk : int = 4096, 21 | ) -> None: 22 | super().__init__() 23 | 24 | text_path = path.join(root, f'{data_split}.txt') 25 | 26 | with open(text_path, 'r', encoding='utf-8') as f: 27 | # Move the file pointer to the end of the file 28 | f.seek(0, 2) 29 | 30 | # Get the current position of the file pointer, which is the file size 31 | self.file_size = f.tell() 32 | 33 | self.tokenizer = tokenizer 34 | self.read_chunk = read_chunk 35 | self.max_length = max_length + 1 36 | 37 | self.tokens = [] 38 | self.stream = open(text_path, 'r', encoding='utf-8') 39 | 40 | self._start = 0 41 | self._end = len(self) 42 | 43 | def __len__(self) -> int: 44 | return self.file_size 45 | 46 | def __del__(self) -> None: 47 | if hasattr(self, 'stream'): self.stream.close() 48 | 49 | def __iter__(self) -> Generator[Tuple[Tensor, Tensor], None, None]: 50 | 51 | self.stream.seek(self._start) 52 | 53 | while self.stream.tell() < self._end: 54 | while len(self.tokens) < self.max_length: 55 | self.tokens.extend( 56 | self.tokenizer.encode( 57 | self.stream.read(self.read_chunk) 58 | ) 59 | ) 60 | 61 | tokens, self.tokens = self.tokens[:self.max_length], self.tokens[self.max_length:] 62 | 63 | prev = torch.tensor(tokens[:-1]) 64 | post = torch.tensor(tokens[+1:]) 65 | 66 | yield prev, post 67 | 68 | class TinyStoriesLightning(LightningDataset): 69 | '''Lightning Dataset class for the Tiny Stories dataset. The Tiny 70 | Stories dataset is a small dataset of short stories, each consisting 71 | of a few sentences. The dataset is used for training a language model. 
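
    Args:
        tokenizer (TokenizerWrapper): Wrapper around the HuggingFace tokenizer
            used to encode the raw text.
        root (str, optional): Directory containing the `train.txt`, `valid.txt`
            and `test.txt` splits. Defaults to './'.
        max_length (int, optional): Number of tokens per training example.
            Defaults to 256.
        read_chunk (int, optional): Number of characters read from disk at a
            time while tokenizing the text stream. Defaults to 1024.
        **kwargs: Additional arguments forwarded to `LightningDataset`
            (e.g. batch_size, num_workers, samplers).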
72 | ''' 73 | 74 | def __init__( 75 | self, 76 | tokenizer : TokenizerWrapper, 77 | root : str = './', 78 | max_length : int = 256, 79 | read_chunk : int = 1024, 80 | **kwargs, 81 | ) -> None: 82 | super().__init__(**kwargs) 83 | 84 | self.root = root 85 | self.tokenizer = tokenizer.get_tokenizer() 86 | self.max_length = max_length 87 | self.read_chunk = read_chunk 88 | 89 | # NOTE: We ignore the tokenizer key to avoid having 90 | # a repetition with the LightningModule 91 | self.save_hyperparameters(ignore=['tokenizer']) 92 | 93 | def setup(self, stage: str) -> None: 94 | 95 | match stage: 96 | case 'fit': 97 | self.train_dataset = TinyStories( 98 | root=self.root, 99 | tokenizer=self.tokenizer, 100 | max_length=self.max_length, 101 | data_split='train', 102 | read_chunk=self.read_chunk 103 | ) 104 | self.valid_dataset = TinyStories( 105 | root=self.root, 106 | tokenizer=self.tokenizer, 107 | max_length=self.max_length, 108 | data_split='valid', 109 | read_chunk=self.read_chunk 110 | ) 111 | case 'test': 112 | self.test__dataset = TinyStories( 113 | root=self.root, 114 | tokenizer=self.tokenizer, 115 | max_length=self.max_length, 116 | data_split='test', 117 | read_chunk=self.read_chunk 118 | ) 119 | case _: 120 | raise ValueError(f'Invalid stage: {stage}') -------------------------------------------------------------------------------- /xlstm/utils.py: -------------------------------------------------------------------------------- 1 | from tokenizers import AddedToken 2 | import torch 3 | import torch.nn as nn 4 | 5 | from einops import rearrange 6 | 7 | from torch import Tensor 8 | from torch.utils.data import get_worker_info 9 | 10 | from transformers import AutoTokenizer 11 | from transformers import PreTrainedTokenizerBase 12 | 13 | from typing import Dict, List, Tuple, TypeVar 14 | 15 | T = TypeVar('T') 16 | D = TypeVar('D') 17 | 18 | Hidden = List[Tuple[Tensor, ...]] 19 | 20 | def exists(var : T | None) -> bool: 21 | return var is not None 22 | 23 | def default(var : T | None, val : D) -> T | D: 24 | return var if exists(var) else val 25 | 26 | def enlarge_as(src : Tensor, other : Tensor) -> Tensor: 27 | ''' 28 | Add sufficient number of singleton dimensions 29 | to tensor a **to the right** so to match the 30 | shape of tensor b. NOTE that simple broadcasting 31 | works in the opposite direction. 32 | ''' 33 | return rearrange(src, f'... 
-> ...{" 1" * (other.dim() - src.dim())}').contiguous() 34 | 35 | def default_iterdata_worker_init(worker_id : int) -> None: 36 | torch.manual_seed(torch.initial_seed() + worker_id) 37 | worker_info = get_worker_info() 38 | 39 | if worker_info is None: return 40 | 41 | dataset = worker_info.dataset 42 | glob_start = dataset._start # type: ignore 43 | glob_end = dataset._end # type: ignore 44 | 45 | per_worker = int((glob_end - glob_start) / worker_info.num_workers) 46 | worker_id = worker_info.id 47 | 48 | dataset._start = glob_start + worker_id * per_worker # type: ignore 49 | dataset._end = min(dataset._start + per_worker, glob_end) # type: ignore 50 | 51 | class CausalConv1d(nn.Conv1d): 52 | def __init__( 53 | self, 54 | in_channels, 55 | out_channels, 56 | kernel_size, 57 | stride=1, 58 | dilation=1, 59 | groups=1, 60 | bias=True 61 | ): 62 | self._padding = (kernel_size - 1) * dilation 63 | 64 | super(CausalConv1d, self).__init__( 65 | in_channels, 66 | out_channels, 67 | kernel_size=kernel_size, 68 | stride=stride, 69 | padding=self._padding, 70 | dilation=dilation, 71 | groups=groups, 72 | bias=bias) 73 | 74 | def forward(self, inp : Tensor) -> Tensor: 75 | # Handle the case where input has only two dimensions 76 | # we expect them to have semantics (batch, channels), 77 | # so we add the missing dimension manually 78 | if inp.dim() == 2: inp = rearrange(inp, 'b i -> b 1 i') 79 | 80 | result = super(CausalConv1d, self).forward(inp) 81 | if self._padding != 0: return result[..., :-self._padding] 82 | return result 83 | 84 | class BlockLinear(nn.Module): 85 | def __init__( 86 | self, 87 | block_dims : List[int | List[int]], 88 | bias : bool = False, 89 | ): 90 | super(BlockLinear, self).__init__() 91 | 92 | self._blocks = nn.ParameterList([ 93 | nn.Parameter(torch.randn(size, requires_grad=True)) 94 | for size in block_dims 95 | ]) 96 | 97 | self._bias = nn.Parameter(torch.zeros(sum(block_dims))) if bias else None 98 | 99 | def forward(self, inp : Tensor) -> Tensor: 100 | # Assemble the blocks into a block-diagonal matrix 101 | full = torch.block_diag(*self._blocks) 102 | 103 | out = torch.matmul(inp, full) 104 | 105 | if self._bias is not None: 106 | out = out + self._bias 107 | 108 | return out 109 | 110 | class TokenizerWrapper: 111 | ''' 112 | A wrapper class for tokenizers. 113 | 114 | This class provides a convenient way to initialize and access tokenizers for various pretrained models. 115 | 116 | Args: 117 | pretrained_model_name_or_path (str): The name or path of the pretrained model. 118 | 119 | Attributes: 120 | tokenizer (PreTrainedTokenizerBase): The tokenizer object. 121 | 122 | Methods: 123 | get_tokenizer: Returns the tokenizer object. 124 | 125 | Example: 126 | >>> tokenizer = TokenizerWrapper('bert-base-uncased') 127 | >>> tokenizer.get_tokenizer() 128 | 129 | ''' 130 | 131 | def __init__( 132 | self, 133 | pretrained_model_name_or_path: str, 134 | special_tokens: Dict[str, str | AddedToken] = {}, 135 | ): 136 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) 137 | self.tokenizer.add_special_tokens(special_tokens) 138 | 139 | 140 | def get_tokenizer(self) -> PreTrainedTokenizerBase: 141 | ''' 142 | Returns the tokenizer object. 143 | 144 | Returns: 145 | PreTrainedTokenizerBase: The tokenizer object. 146 | ''' 147 | return self.tokenizer 148 | --------------------------------------------------------------------------------