├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── examples
│   ├── AirPassengers.csv
│   └── AirPassengers.ipynb
├── pyproject.toml
└── src
    └── torchxlstm
        ├── __init__.py
        └── xlstm.py

/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
name: Upload Python Package to PyPI when a Release is Created

on:
  release:
    types: [created]

jobs:
  pypi-publish:
    name: Publish release to PyPI
    runs-on: ubuntu-latest
    environment:
      name: release
      url: https://pypi.org/p/torchxlstm
    permissions:
      id-token: write
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.x"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install --upgrade build
          pip install setuptools wheel
      - name: Build package
        run: |
          python -m build
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Akaash Dash

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# xLSTM

A pure PyTorch implementation of the [xLSTM paper](https://arxiv.org/abs/2405.04517).

## TODO
- Ensure correctness against the paper
- Create some usage examples (a first sketch is in the Usage section below)
- Implement parallelization
  - CUDA?
- Allow for different initializations according to: https://pytorch.org/docs/stable/nn.init.html
- Allow for flattening of x for greater shape conformity
- Allow for batching
- Add tests
  - https://github.com/catid/audio_prediction/tree/master
  - other classic RNN/LSTM tasks
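
## Usage
A minimal sketch of how the classes in `src/torchxlstm/xlstm.py` can be driven. The sizes and the `"msm"` layer string below are illustrative assumptions, not values prescribed by the package:

```python
import torch
from torchxlstm import sLSTM, mLSTM, xLSTM

input_size = 16  # feature dimension of each timestep (illustrative)
head_size = 8    # per-head width; the hidden size becomes head_size * num_heads
num_heads = 4

# 's' selects an sLSTM block and 'm' an mLSTM block, one character per layer.
model = xLSTM(input_size, head_size, num_heads, layers="msm", batch_first=True)

x = torch.randn(2, 10, input_size)  # (batch, seq_len, input_size) because batch_first=True
output, state = model(x)            # output keeps the (batch, seq_len, input_size) shape
```

`sLSTM` and `mLSTM` are used the same way, taking a `num_layers` count instead of a layer string.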

## References
- https://arxiv.org/abs/2405.04517
- https://discuss.pytorch.org/t/causal-convolution/3456/3
- https://pytorch.org/docs/stable/generated/torch.nn.RNN.html
- https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
--------------------------------------------------------------------------------
/examples/AirPassengers.csv:
--------------------------------------------------------------------------------
Month,#Passengers
1949-01,112
1949-02,118
1949-03,132
1949-04,129
1949-05,121
1949-06,135
1949-07,148
1949-08,148
1949-09,136
1949-10,119
1949-11,104
1949-12,118
1950-01,115
1950-02,126
1950-03,141
1950-04,135
1950-05,125
1950-06,149
1950-07,170
1950-08,170
1950-09,158
1950-10,133
1950-11,114
1950-12,140
1951-01,145
1951-02,150
1951-03,178
1951-04,163
1951-05,172
1951-06,178
1951-07,199
1951-08,199
1951-09,184
1951-10,162
1951-11,146
1951-12,166
1952-01,171
1952-02,180
1952-03,193
1952-04,181
1952-05,183
1952-06,218
1952-07,230
1952-08,242
1952-09,209
1952-10,191
1952-11,172
1952-12,194
1953-01,196
1953-02,196
1953-03,236
1953-04,235
1953-05,229
1953-06,243
1953-07,264
1953-08,272
1953-09,237
1953-10,211
1953-11,180
1953-12,201
1954-01,204
1954-02,188
1954-03,235
1954-04,227
1954-05,234
1954-06,264
1954-07,302
1954-08,293
1954-09,259
1954-10,229
1954-11,203
1954-12,229
1955-01,242
1955-02,233
1955-03,267
1955-04,269
1955-05,270
1955-06,315
1955-07,364
1955-08,347
1955-09,312
1955-10,274
1955-11,237
1955-12,278
1956-01,284
1956-02,277
1956-03,317
1956-04,313
1956-05,318
1956-06,374
1956-07,413
1956-08,405
1956-09,355
1956-10,306
1956-11,271
1956-12,306
1957-01,315
1957-02,301
1957-03,356
1957-04,348
1957-05,355
1957-06,422
1957-07,465
1957-08,467
1957-09,404
1957-10,347
1957-11,305
1957-12,336
1958-01,340
1958-02,318
1958-03,362
1958-04,348
1958-05,363
1958-06,435
1958-07,491
1958-08,505
1958-09,404
1958-10,359
1958-11,310
1958-12,337
1959-01,360
1959-02,342
1959-03,406
1959-04,396
1959-05,420
1959-06,472
1959-07,548
1959-08,559
1959-09,463
1959-10,407
1959-11,362
1959-12,405
1960-01,417
1960-02,391
1960-03,419
1960-04,461
1960-05,472
1960-06,535
1960-07,622
1960-08,606
1960-09,508
1960-10,461
1960-11,390
1960-12,432
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "torchxlstm"
version = "0.0.3"
description = "A pure PyTorch implementation of xLSTM."
readme = "README.md"
requires-python = ">=3"
license = {file = "LICENSE"}
keywords = ["xlstm", "lstm", "torch", "pytorch"]
authors = [
  {name = "Akaash Dash", email = "akaash.dash@gmail.com"}
]
maintainers = [
  {name = "Akaash Dash", email = "akaash.dash@gmail.com"}
]
classifiers = [
  "Programming Language :: Python :: 3",
  "License :: OSI Approved :: MIT License",
  "Operating System :: OS Independent",
]
dependencies = [
  "torch"
]

[project.urls]
"Homepage" = "https://github.com/akaashdash/xlstm"
"Source" = "https://github.com/akaashdash/xlstm/"
--------------------------------------------------------------------------------
/src/torchxlstm/__init__.py:
--------------------------------------------------------------------------------
from .xlstm import *
--------------------------------------------------------------------------------
/src/torchxlstm/xlstm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super(CausalConv1D, self).__init__()
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.padding, dilation=dilation, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        # Trim the trailing padded positions so the convolution stays causal.
        return x[:, :, :-self.padding]

class BlockDiagonal(nn.Module):
    def __init__(self, in_features, out_features, num_blocks):
        super(BlockDiagonal, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_blocks = num_blocks

        assert out_features % num_blocks == 0

        block_out_features = out_features // num_blocks

        # One linear layer per block; each sees the full input and produces its slice of the output.
        self.blocks = nn.ModuleList([
            nn.Linear(in_features, block_out_features)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        x = [block(x) for block in self.blocks]
        x = torch.cat(x, dim=-1)
        return x

class sLSTMBlock(nn.Module):
    def __init__(self, input_size, head_size, num_heads, proj_factor=4/3):
        super(sLSTMBlock, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.proj_factor = proj_factor

        assert proj_factor > 0

        self.layer_norm = nn.LayerNorm(input_size)
        self.causal_conv = CausalConv1D(1, 1, 4)

        self.Wz = BlockDiagonal(input_size, self.hidden_size, num_heads)
        self.Wi = BlockDiagonal(input_size, self.hidden_size, num_heads)
        self.Wf = BlockDiagonal(input_size, self.hidden_size, num_heads)
        self.Wo = BlockDiagonal(input_size, self.hidden_size, num_heads)

        self.Rz = BlockDiagonal(self.hidden_size, self.hidden_size, num_heads)
        self.Ri = BlockDiagonal(self.hidden_size, self.hidden_size, num_heads)
        self.Rf = BlockDiagonal(self.hidden_size, self.hidden_size, num_heads)
        self.Ro = BlockDiagonal(self.hidden_size, self.hidden_size, num_heads)

        self.group_norm = nn.GroupNorm(num_heads, self.hidden_size)

        self.up_proj_left = nn.Linear(self.hidden_size, int(self.hidden_size * proj_factor))
        self.up_proj_right = nn.Linear(self.hidden_size, int(self.hidden_size * proj_factor))
        self.down_proj = nn.Linear(int(self.hidden_size * proj_factor), input_size)

    def forward(self, x, prev_state):
        assert x.size(-1) == self.input_size
        h_prev, c_prev, n_prev, m_prev = prev_state

        h_prev = h_prev.to(x.device)
        c_prev = c_prev.to(x.device)
        n_prev = n_prev.to(x.device)
        m_prev = m_prev.to(x.device)

        x_norm = self.layer_norm(x)
        # NOTE: the causal conv runs across the feature dimension of a single timestep.
        x_conv = F.silu(self.causal_conv(x_norm.unsqueeze(1)).squeeze(1))

        z = torch.tanh(self.Wz(x_norm) + self.Rz(h_prev))
        o = torch.sigmoid(self.Wo(x_norm) + self.Ro(h_prev))
        i_tilde = self.Wi(x_conv) + self.Ri(h_prev)
        f_tilde = self.Wf(x_conv) + self.Rf(h_prev)

        # Stabilizer state keeps the exponential input/forget gates numerically safe.
        m_t = torch.max(f_tilde + m_prev, i_tilde)
        i = torch.exp(i_tilde - m_t)
        f = torch.exp(f_tilde + m_prev - m_t)

        c_t = f * c_prev + i * z
        n_t = f * n_prev + i
        h_t = o * c_t / n_t

        output = h_t
        output_norm = self.group_norm(output)
        output_left = self.up_proj_left(output_norm)
        output_right = self.up_proj_right(output_norm)
        output_gated = F.gelu(output_right)
        output = output_left * output_gated
        output = self.down_proj(output)
        final_output = output + x

        return final_output, (h_t, c_t, n_t, m_t)

class sLSTM(nn.Module):
    # TODO: Add bias, dropout, bidirectional
    def __init__(self, input_size, head_size, num_heads, num_layers=1, batch_first=False, proj_factor=4/3):
        super(sLSTM, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.proj_factor_slstm = proj_factor

        self.layers = nn.ModuleList([sLSTMBlock(input_size, head_size, num_heads, proj_factor) for _ in range(num_layers)])

    def forward(self, x, state=None):
        assert x.ndim == 3
        if self.batch_first: x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        if state is not None:
            state = torch.stack(list(state)).to(x.device)
            assert state.ndim == 4
            num_hidden, state_num_layers, state_batch_size, state_hidden_size = state.size()
            assert num_hidden == 4
            assert state_num_layers == self.num_layers
            assert state_batch_size == batch_size
            # The per-layer states (h, c, n, m) live in the hidden dimension, not the input dimension.
            assert state_hidden_size == self.hidden_size
            state = state.transpose(0, 1)
        else:
            state = torch.zeros(self.num_layers, 4, batch_size, self.hidden_size, device=x.device)

        output = []
        for t in range(seq_len):
            x_t = x[t]
            for layer in range(self.num_layers):
                x_t, state_tuple = self.layers[layer](x_t, tuple(state[layer].clone()))
                state[layer] = torch.stack(list(state_tuple))
            output.append(x_t)

        output = torch.stack(output)
        if self.batch_first:
            output = output.transpose(0, 1)
        # Return the state as an (h, c, n, m)-style tuple regardless of batch_first.
        state = tuple(state.transpose(0, 1))
        return output, state

class mLSTMBlock(nn.Module):
    def __init__(self, input_size, head_size, num_heads, proj_factor=2):
        super(mLSTMBlock, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.proj_factor = proj_factor

        assert proj_factor > 0

        self.layer_norm = nn.LayerNorm(input_size)
        self.up_proj_left = nn.Linear(input_size, int(input_size * proj_factor))
        self.up_proj_right = nn.Linear(input_size, self.hidden_size)
        self.down_proj = nn.Linear(self.hidden_size, input_size)

        self.causal_conv = CausalConv1D(1, 1, 4)
        self.skip_connection = nn.Linear(int(input_size * proj_factor), self.hidden_size)

        self.Wq = BlockDiagonal(int(input_size * proj_factor), self.hidden_size, num_heads)
        self.Wk = BlockDiagonal(int(input_size * proj_factor), self.hidden_size, num_heads)
        self.Wv = BlockDiagonal(int(input_size * proj_factor), self.hidden_size, num_heads)
        self.Wi = nn.Linear(int(input_size * proj_factor), self.hidden_size)
        self.Wf = nn.Linear(int(input_size * proj_factor), self.hidden_size)
        self.Wo = nn.Linear(int(input_size * proj_factor), self.hidden_size)

        self.group_norm = nn.GroupNorm(num_heads, self.hidden_size)

    def forward(self, x, prev_state):
        h_prev, c_prev, n_prev, m_prev = prev_state

        h_prev = h_prev.to(x.device)
        c_prev = c_prev.to(x.device)
        n_prev = n_prev.to(x.device)
        m_prev = m_prev.to(x.device)

        assert x.size(-1) == self.input_size
        x_norm = self.layer_norm(x)
        x_up_left = self.up_proj_left(x_norm)
        x_up_right = self.up_proj_right(x_norm)

        x_conv = F.silu(self.causal_conv(x_up_left.unsqueeze(1)).squeeze(1))
        x_skip = self.skip_connection(x_conv)

        q = self.Wq(x_conv)
        k = self.Wk(x_conv) / (self.head_size ** 0.5)
        v = self.Wv(x_up_left)

        i_tilde = self.Wi(x_conv)
        f_tilde = self.Wf(x_conv)
        o = torch.sigmoid(self.Wo(x_up_left))

        m_t = torch.max(f_tilde + m_prev, i_tilde)
        i = torch.exp(i_tilde - m_t)
        f = torch.exp(f_tilde + m_prev - m_t)

        # Elementwise simplification of the paper's matrix memory update C_t = f * C_{t-1} + i * (v k^T).
        c_t = f * c_prev + i * (v * k)
        n_t = f * n_prev + i * k
        # Elementwise analogue of h_t = o * (C_t q_t) / max{|n_t^T q_t|, 1}: the normalizer is clamped at 1.
        h_t = o * (c_t * q) / torch.clamp(torch.abs(n_t * q), min=1.0)

        output = h_t
        output_norm = self.group_norm(output)
        output = output_norm + x_skip
        output = output * F.silu(x_up_right)
        output = self.down_proj(output)
        final_output = output + x

        return final_output, (h_t, c_t, n_t, m_t)

class mLSTM(nn.Module):
    # TODO: Add bias, dropout, bidirectional
    def __init__(self, input_size, head_size, num_heads, num_layers=1, batch_first=False, proj_factor=2):
        super(mLSTM, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.proj_factor = proj_factor

        self.layers = nn.ModuleList([mLSTMBlock(input_size, head_size, num_heads, proj_factor) for _ in range(num_layers)])

    def forward(self, x, state=None):
        assert x.ndim == 3
        if self.batch_first: x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        if state is not None:
            state = torch.stack(list(state)).to(x.device)
            assert state.ndim == 4
            num_hidden, state_num_layers, state_batch_size, state_hidden_size = state.size()
            assert num_hidden == 4
            assert state_num_layers == self.num_layers
            assert state_batch_size == batch_size
            assert state_hidden_size == self.hidden_size
            state = state.transpose(0, 1)
        else:
            state = torch.zeros(self.num_layers, 4, batch_size, self.hidden_size, device=x.device)

        output = []
        for t in range(seq_len):
            x_t = x[t]
            for layer in range(self.num_layers):
                x_t, state_tuple = self.layers[layer](x_t, tuple(state[layer].clone()))
                state[layer] = torch.stack(list(state_tuple))
            output.append(x_t)

        output = torch.stack(output)
        if self.batch_first:
            output = output.transpose(0, 1)
        state = tuple(state.transpose(0, 1))
        return output, state

class xLSTM(nn.Module):
    # TODO: Add bias, dropout, bidirectional
    def __init__(self, input_size, head_size, num_heads, layers, batch_first=False, proj_factor_slstm=4/3, proj_factor_mlstm=2):
        super(xLSTM, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.layer_types = layers
        self.num_layers = len(layers)
        self.batch_first = batch_first
        self.proj_factor_slstm = proj_factor_slstm
        self.proj_factor_mlstm = proj_factor_mlstm

        self.layers = nn.ModuleList()
        for layer_type in layers:
            if layer_type == 's':
                layer = sLSTMBlock(input_size, head_size, num_heads, proj_factor_slstm)
            elif layer_type == 'm':
                layer = mLSTMBlock(input_size, head_size, num_heads, proj_factor_mlstm)
            else:
                raise ValueError(f"Invalid layer type: {layer_type}. Choose 's' for sLSTM or 'm' for mLSTM.")
            self.layers.append(layer)

    def forward(self, x, state=None):
        assert x.ndim == 3
        if self.batch_first: x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        if state is not None:
            state = torch.stack(list(state)).to(x.device)
            assert state.ndim == 4
            num_hidden, state_num_layers, state_batch_size, state_hidden_size = state.size()
            assert num_hidden == 4
            assert state_num_layers == self.num_layers
            assert state_batch_size == batch_size
            assert state_hidden_size == self.hidden_size
            state = state.transpose(0, 1)
        else:
            state = torch.zeros(self.num_layers, 4, batch_size, self.hidden_size, device=x.device)

        output = []
        for t in range(seq_len):
            x_t = x[t]
            for layer in range(self.num_layers):
                x_t, state_tuple = self.layers[layer](x_t, tuple(state[layer].clone()))
                state[layer] = torch.stack(list(state_tuple))
            output.append(x_t)

        output = torch.stack(output)
        if self.batch_first:
            output = output.transpose(0, 1)
        state = tuple(state.transpose(0, 1))
        return output, state
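
if __name__ == "__main__":
    # Illustrative smoke test -- an assumed usage sketch, not part of the published package.
    # It pushes a tiny random batch through each model class and checks the output shapes;
    # the hyperparameters below are arbitrary example values.
    batch_size, seq_len, input_size, head_size, num_heads = 2, 5, 6, 3, 2
    x = torch.randn(seq_len, batch_size, input_size)  # (seq_len, batch, features) since batch_first=False
    models = [
        sLSTM(input_size, head_size, num_heads, num_layers=2),
        mLSTM(input_size, head_size, num_heads, num_layers=2),
        xLSTM(input_size, head_size, num_heads, layers="ms"),
    ]
    for model in models:
        output, state = model(x)
        assert output.shape == (seq_len, batch_size, input_size)
--------------------------------------------------------------------------------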