├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md └── enwik8.gz ├── gated_state_spaces_pytorch ├── __init__.py ├── autoregressive_wrapper.py ├── dsconv.py ├── gss.py └── mhesa.py ├── gss.png ├── setup.py └── train.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Phil Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Gated State Spaces - Pytorch
4 | 
5 | Implementation of Gated State Spaces, from the paper Long Range Language Modeling via Gated State Spaces, in Pytorch. In particular, it will contain the hybrid version that combines local self attention with the long-range GSS.
6 | 
7 | It will also contain a few more settings for comparing state spaces to a sequence-wise GLU depthwise conv, and, even simpler, a parameterized exponential moving average along the sequence dimension - so we can get to the bottom of whether state spaces are worth it, or whether it is really all about the `O(L log(L))` FFT convolution trick. Results will be shared in the readme.
8 | 
9 | I will also pit the GSS module against the Path-X challenge and see how well it does.
10 | 
11 | Update: This paper has beaten S4 on LRA using multi-headed EMA + single head attention.
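
For the comparison runs described above, the package also exports the depthwise-conv and exponential-moving-average variants alongside `GSS` (see `gated_state_spaces_pytorch/__init__.py`). A minimal sketch of how they are constructed - the class names and keyword arguments come from `dsconv.py` and `mhesa.py`, while the hyperparameter values simply mirror the GSS usage examples below and are only illustrative:

```python
import torch
from gated_state_spaces_pytorch import GatedDsConvLM, GatedExponentialSmoothingLM

# sequence-wise GLU depthwise conv baseline (dsconv.py)
dsconv_lm = GatedDsConvLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    heads = 8,             # groups for the learned conv kernel
    dim_dsconv = 512,
    max_seq_len = 2048     # the forward pass asserts the sequence fits in this length
)

# multi-head exponential smoothing baseline (mhesa.py)
ema_lm = GatedExponentialSmoothingLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    heads = 8,             # one learned smoothing rate (alpha) per head
    dim_mhesa = 512
)

ids = torch.randint(0, 20000, (1, 1024))

dsconv_logits = dsconv_lm(ids)   # (1, 1024, 20000)
ema_logits = ema_lm(ids)         # (1, 1024, 20000)
```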
12 | 
13 | ## Install
14 | 
15 | ```bash
16 | $ pip install gated-state-spaces-pytorch
17 | ```
18 | 
19 | ## Usage
20 | 
21 | ```python
22 | import torch
23 | from gated_state_spaces_pytorch import GSS
24 | 
25 | gss = GSS(
26 |     dim = 512,                  # dimension
27 |     dim_expansion_factor = 4,   # hidden dimension (expansion factor x dim) = 2048
28 |     dss_kernel_N = 512,
29 |     dss_kernel_H = 256
30 | )
31 | 
32 | x = torch.randn(1, 65536, 512)
33 | 
34 | out = gss(x) # (1, 65536, 512)
35 | ```
36 | 
37 | Gated state spaces language model
38 | 
39 | ```python
40 | import torch
41 | from gated_state_spaces_pytorch import GatedStateSpacesLM
42 | 
43 | gss_lm = GatedStateSpacesLM(
44 |     num_tokens = 20000,
45 |     depth = 12,
46 |     dim = 512,
47 |     dim_expansion_factor = 4,
48 |     dss_kernel_N = 512,
49 |     dss_kernel_H = 256
50 | )
51 | 
52 | ids = torch.randint(0, 20000, (1, 1024))
53 | 
54 | logits = gss_lm(ids) # (1, 1024, 20000)
55 | ```
56 | 
57 | ## Todo
58 | 
59 | - [x] enwik8
60 | - [x] gss lm class
61 | - [x] add dsconv + learned ema
62 | - [ ] add attention
63 | 
64 | ## Citations
65 | 
66 | ```bibtex
67 | @inproceedings{Mehta2022LongRL,
68 |     title  = {Long Range Language Modeling via Gated State Spaces},
69 |     author = {Harsh Mehta and Ankit Gupta and Ashok Cutkosky and Behnam Neyshabur},
70 |     year   = {2022}
71 | }
72 | ```
73 | 
74 | ```bibtex
75 | @misc{woo2022etsformer,
76 |     title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
77 |     author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
78 |     year    = {2022},
79 |     eprint  = {2202.01381},
80 |     archivePrefix = {arXiv},
81 |     primaryClass = {cs.LG}
82 | }
83 | ```
84 | 
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 | 
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/gated-state-spaces-pytorch/d484cf5a17382e50749d561a401f727234e10a06/data/enwik8.gz
--------------------------------------------------------------------------------
/gated_state_spaces_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from gated_state_spaces_pytorch.gss import GSS, GatedStateSpacesLM
2 | from gated_state_spaces_pytorch.dsconv import GatedDsConv, GatedDsConvLM
3 | from gated_state_spaces_pytorch.mhesa import GatedExponentialSmoothingLM, GatedMHESA
4 | 
--------------------------------------------------------------------------------
/gated_state_spaces_pytorch/autoregressive_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange
4 | from torch import nn
5 | 
6 | # helper function
7 | 
8 | 
9 | def exists(val):
10 |     return val is not None
11 | 
12 | 
13 | def eval_decorator(fn):
14 |     def inner(model, *args, **kwargs):
15 |         was_training = model.training
16 |         model.eval()
17 |         out = fn(model, *args, **kwargs)
18 |         model.train(was_training)
19 |         return out
20 | 
21 |     return inner
22 | 
23 | 
24 | # top k filtering
25 | 
26 | 
27 | def top_k(logits, thres=0.9):
28 |     k = int((1 - thres) * logits.shape[-1])
29 |     val, ind = torch.topk(logits, k)
30 |     probs = torch.full_like(logits, float('-inf'))
31 |     probs.scatter_(1, ind, val)
32 |     return probs
33 | 
34 | 
35 | class AutoregressiveWrapper(nn.Module):
36 |     def __init__(self, net, pad_value=0, max_seq_len=4096):
37 |         super().__init__()
38 |         self.max_seq_len = max_seq_len
39 |         self.pad_value = pad_value
40 |         self.net = net
41 | 
42 |     @torch.no_grad()
43 |     @eval_decorator
44 |     def generate(
45 |         self,
46 |         start_tokens,
47 |         seq_len,
48 |         eos_token=None,
49 |         temperature=1.0,
50 |         filter_thres=0.9,
51 |         **kwargs
52 |     ):
53 |         b, n, device = *start_tokens.shape, start_tokens.device
54 | 
55 |         out = start_tokens
56 | 
57 |         for _ in range(seq_len):
58 |             logits = self.net(
59 |                 out[:, -self.max_seq_len:],
60 |                 **kwargs
61 |             )[:, -1]
62 | 
63 |             filtered_logits = top_k(logits, thres = filter_thres)
64 |             probs = F.softmax(filtered_logits / temperature, dim=-1)
65 | 
66 |             sample = torch.multinomial(probs, 1)
67 |             out = torch.cat((out, sample), dim=-1)
68 | 
69 |             if exists(eos_token):
70 |                 is_eos_token = out == eos_token
71 | 
72 |                 if is_eos_token.any(dim=-1).all():
73 |                     # mask out everything after the eos tokens
74 |                     shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
75 |                     mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
76 |                     out = out.masked_fill(mask, self.pad_value)
77 |                     break
78 | 
79 |         return out[:, n:]
80 | 
81 |     def forward(self, x, **kwargs):
82 |         inp, labels = x[:, :-1], x[:, 1:]
83 |         return self.net(inp, labels = labels, **kwargs)
84 | 
--------------------------------------------------------------------------------
/gated_state_spaces_pytorch/dsconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn, einsum
4 | from torch.fft import rfft, irfft
5 | 
6 | from einops import rearrange
7 | from scipy.fftpack import next_fast_len
8 | 
9 | # functions
10 | 
11 | def exists(val):
12 |     return val is not None
13 | 
14 | def append_dims(x, num_dims):
15 |     if num_dims <= 0:
16 |         return x
17 |     return x.view(*x.shape, *((1,) * num_dims))
18 | 
19 | def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
20 |     # O(N log(N)) 1d convolution using the FFT: pad both signals to a fast length, multiply in the frequency domain (conjugating the weight spectrum), then slice out the last N positions
21 | 
22 |     assert weight_dim >= dim
23 | 
24 |     N = x.shape[dim]
25 |     M = weights.shape[weight_dim]
26 | 
27 |     fast_len = next_fast_len(N + M - 1)
28 | 
29 |     f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
30 |     f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)
31 | 
32 |     f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
33 |     out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
34 |     out = out.roll(-1, dims = (dim,))
35 | 
36 |     indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
37 |     out = out.index_select(dim, indices)
38 |     return out
39 | 
40 | # classes
41 | 
42 | class EfficientDsConv(nn.Module):
43 |     def __init__(
44 |         self,
45 |         *,
46 |         dim,
47 |         heads
48 |     ):
49 |         super().__init__()
50 |         assert (dim % heads) == 0
51 | 
52 |         self.heads = heads
53 |         self.norm = nn.LayerNorm(dim)
54 | 
55 |         self.to_weight = nn.Linear(dim, heads, bias = False)
56 | 
57 |         # params D
58 | 
59 |         self.param_D = nn.Parameter(torch.randn(dim))
60 | 
61 |     def forward(self, x):
62 |         """
63 |         einstein notation:
64 |         b - batch
65 |         h - heads (or groups)
66 |         l - sequence length
67 |         d - dimension
68 |         """
69 | 
70 |         device, seq_len = x.device, x.shape[1]
71 |         u = self.norm(x)
72 | 
73 |         # learned weighted residual
74 | 
75 |         residual = u * self.param_D
76 | 
77 |         # dsconv kernel depends on sequence length
78
| 79 | K = self.to_weight(x) 80 | K = torch.flip(K, dims = (1,)) 81 | 82 | # conv1d fft O(nlog(n)) 83 | 84 | u = rearrange(u, '... (h d) -> ... h d', h = self.heads) 85 | 86 | out = conv1d_fft(u, K, dim = -3, weight_dim = -2) 87 | 88 | out = rearrange(out, '... h d -> ... (h d)') 89 | 90 | return out + residual 91 | 92 | class GatedDsConv(nn.Module): 93 | """ Pseudocode 3.2 """ 94 | """ except state spaces replaced with regular learned convolution kernel """ 95 | 96 | def __init__( 97 | self, 98 | *, 99 | dim, 100 | heads = 8, 101 | dim_dsconv = 512, 102 | dim_expansion_factor = 4, 103 | reverse_seq = False 104 | ): 105 | super().__init__() 106 | assert (dim_dsconv % heads) == 0 107 | self.reverse_seq = reverse_seq 108 | 109 | self.norm = nn.LayerNorm(dim) 110 | 111 | dim_hidden = int(dim_expansion_factor * dim) 112 | self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU()) 113 | self.to_v = nn.Sequential(nn.Linear(dim, dim_dsconv, bias = False), nn.GELU()) 114 | 115 | self.dsconv = EfficientDsConv(dim = dim_dsconv, heads = heads) 116 | 117 | self.to_gate = nn.Linear(dim_dsconv, dim_hidden, bias = False) 118 | self.to_out = nn.Linear(dim_hidden, dim) 119 | 120 | def forward(self, x): 121 | if self.reverse_seq: 122 | x = torch.flip(x, dims = (1,)) 123 | 124 | residual, x = x.clone(), self.norm(x) 125 | 126 | u = self.to_u(x) 127 | v = self.to_v(x) 128 | 129 | v = self.dsconv(v) 130 | 131 | uc = self.to_gate(v) 132 | out = self.to_out(uc * u) 133 | 134 | out = out + residual 135 | 136 | if self.reverse_seq: 137 | out = torch.flip(out, dims = (1,)) 138 | 139 | return out 140 | 141 | # Gated Dsconv LM 142 | 143 | class GatedDsConvLM(nn.Module): 144 | def __init__( 145 | self, 146 | *, 147 | num_tokens, 148 | dim, 149 | depth, 150 | heads = 8, 151 | dim_dsconv = 512, 152 | max_seq_len = 2048, 153 | dim_expansion_factor = 4, 154 | ): 155 | super().__init__() 156 | self.token_emb = nn.Embedding(num_tokens, dim) 157 | self.max_seq_len = max_seq_len 158 | 159 | self.layers = nn.ModuleList([]) 160 | for _ in range(depth): 161 | self.layers.append( 162 | GatedDsConv( 163 | dim = dim, 164 | heads = heads, 165 | dim_dsconv = dim_dsconv, 166 | dim_expansion_factor = dim_expansion_factor 167 | ) 168 | ) 169 | 170 | self.to_logits = nn.Linear(dim, num_tokens, bias = False) 171 | 172 | def forward(self, x, labels = None): 173 | assert x.shape[1] <= self.max_seq_len 174 | 175 | x = self.token_emb(x) 176 | 177 | for dsconv in self.layers: 178 | x = dsconv(x) 179 | 180 | logits = self.to_logits(x) 181 | 182 | if not exists(labels): 183 | return logits 184 | 185 | logits = rearrange(logits, 'b n c -> b c n') 186 | return F.cross_entropy(logits, labels) 187 | -------------------------------------------------------------------------------- /gated_state_spaces_pytorch/gss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn, einsum 4 | from torch.fft import rfft, irfft 5 | 6 | from einops import rearrange 7 | 8 | # functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | # classes 14 | 15 | class DSS(nn.Module): 16 | def __init__( 17 | self, 18 | *, 19 | dim, 20 | kernel_N = 512, 21 | dss_kernel_lambda_imag_exp = True 22 | ): 23 | super().__init__() 24 | self.norm = nn.LayerNorm(dim) 25 | 26 | # Lambda 27 | 28 | self.Lambda_real = nn.Parameter(torch.randn(kernel_N)) 29 | self.Lambda_imag = nn.Parameter(torch.randn(kernel_N)) 30 | 31 | # C 32 | 33 | self.C_real = 
nn.Parameter(torch.randn(dim, kernel_N)) 34 | self.C_imag = nn.Parameter(torch.randn(dim, kernel_N)) 35 | 36 | # params D 37 | 38 | self.param_D = nn.Parameter(torch.randn(dim)) 39 | 40 | # whether to exponentiate lambda imag @albertfgu says it is not accurate to s4 original designs (but it is present in the pseudocode) 41 | 42 | self.dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp 43 | 44 | def forward(self, x): 45 | """ 46 | einstein notation: 47 | b - batch 48 | l - sequence length 49 | d - dimension 50 | """ 51 | 52 | device, seq_len = x.device, x.shape[1] 53 | u = self.norm(x) 54 | 55 | # learned weighted residual 56 | 57 | residual = u * self.param_D 58 | 59 | # derive simple dss kernel 60 | 61 | Lambda_imag = self.Lambda_imag.exp() if self.dss_kernel_lambda_imag_exp else self.Lambda_imag 62 | 63 | Lambda = -self.Lambda_real.exp() + 1j * Lambda_imag 64 | C = self.C_real + 1j * self.C_imag 65 | 66 | arange = torch.arange(seq_len, device = device) 67 | 68 | S = (rearrange(Lambda, 'n -> n 1') * rearrange(arange, 'l -> 1 l')).exp() 69 | C = C * (Lambda.exp() - 1) / Lambda 70 | 71 | K = einsum('h n, n l -> l h', C, S).real 72 | 73 | # conv1d fft O(nlog(n)) 74 | 75 | u_f = rfft(u, n = seq_len * 2, dim = -2) 76 | K_f = rfft(K, n = seq_len * 2, dim = -2) 77 | 78 | y = irfft(u_f * K_f, seq_len * 2, dim = -2)[..., :seq_len, :] 79 | 80 | return y + residual 81 | 82 | class GSS(nn.Module): 83 | """ Pseudocode 3.2 """ 84 | 85 | def __init__( 86 | self, 87 | *, 88 | dim, 89 | dim_expansion_factor = 4, 90 | dss_kernel_N = 512, 91 | dss_kernel_H = 256, 92 | reverse_seq = False, 93 | dss_kernel_lambda_imag_exp = True 94 | ): 95 | super().__init__() 96 | self.reverse_seq = reverse_seq 97 | self.norm = nn.LayerNorm(dim) 98 | 99 | dim_hidden = int(dim_expansion_factor * dim) 100 | self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU()) 101 | self.to_v = nn.Sequential(nn.Linear(dim, dss_kernel_H, bias = False), nn.GELU()) 102 | 103 | self.dss = DSS(dim = dss_kernel_H, kernel_N = dss_kernel_N, dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp) 104 | 105 | self.to_gate = nn.Linear(dss_kernel_H, dim_hidden, bias = False) 106 | self.to_out = nn.Linear(dim_hidden, dim) 107 | 108 | def forward(self, x): 109 | if self.reverse_seq: 110 | x = torch.flip(x, dims = (1,)) 111 | 112 | residual, x = x.clone(), self.norm(x) 113 | 114 | u = self.to_u(x) 115 | v = self.to_v(x) 116 | 117 | v = self.dss(v) 118 | 119 | uc = self.to_gate(v) 120 | out = self.to_out(uc * u) 121 | 122 | out = out + residual 123 | 124 | if self.reverse_seq: 125 | out = torch.flip(out, dims = (1,)) 126 | 127 | return out 128 | 129 | # Gated State Spaces LM 130 | 131 | class GatedStateSpacesLM(nn.Module): 132 | def __init__( 133 | self, 134 | *, 135 | num_tokens, 136 | dim, 137 | depth, 138 | dim_expansion_factor = 4, 139 | dss_kernel_N = 512, 140 | dss_kernel_H = 256, 141 | dss_kernel_lambda_imag_exp = True 142 | ): 143 | super().__init__() 144 | self.token_emb = nn.Embedding(num_tokens, dim) 145 | 146 | self.layers = nn.ModuleList([]) 147 | for _ in range(depth): 148 | self.layers.append( 149 | GSS( 150 | dim = dim, 151 | dss_kernel_H = dss_kernel_H, 152 | dss_kernel_N = dss_kernel_N, 153 | dim_expansion_factor = dim_expansion_factor, 154 | dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp 155 | ) 156 | ) 157 | 158 | self.to_logits = nn.Linear(dim, num_tokens, bias = False) 159 | 160 | def forward(self, x, labels = None): 161 | x = self.token_emb(x) 162 | 163 | for gss in self.layers: 164 | x = gss(x) 
165 | 166 | logits = self.to_logits(x) 167 | 168 | if not exists(labels): 169 | return logits 170 | 171 | logits = rearrange(logits, 'b n c -> b c n') 172 | return F.cross_entropy(logits, labels) 173 | -------------------------------------------------------------------------------- /gated_state_spaces_pytorch/mhesa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn, einsum 4 | from torch.fft import rfft, irfft 5 | 6 | from einops import rearrange 7 | from scipy.fftpack import next_fast_len 8 | 9 | # functions 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def append_dims(x, num_dims): 15 | if num_dims <= 0: 16 | return x 17 | return x.view(*x.shape, *((1,) * num_dims)) 18 | 19 | def conv1d_fft(x, weights, dim = -2, weight_dim = -1): 20 | # O(N log(N)) 1d convolution using some fourier trick 21 | 22 | assert weight_dim >= dim 23 | 24 | N = x.shape[dim] 25 | M = weights.shape[weight_dim] 26 | 27 | fast_len = next_fast_len(N + M - 1) 28 | 29 | f_x = torch.fft.rfft(x, n = fast_len, dim = dim) 30 | f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim) 31 | 32 | f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim) 33 | out = torch.fft.irfft(f_v_weight, fast_len, dim = dim) 34 | out = out.roll(-1, dims = (dim,)) 35 | 36 | indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device) 37 | out = out.index_select(dim, indices) 38 | return out 39 | 40 | # classes 41 | 42 | class MHESA(nn.Module): 43 | """ used for time-series in ETSFormer https://arxiv.org/abs/2202.01381 """ 44 | 45 | def __init__( 46 | self, 47 | *, 48 | dim, 49 | heads, 50 | reverse_seq = False 51 | ): 52 | super().__init__() 53 | assert (dim % heads) == 0 54 | self.reverse_seq = reverse_seq 55 | 56 | self.heads = heads 57 | self.norm = nn.LayerNorm(dim) 58 | 59 | self.alphas = nn.Parameter(torch.randn(heads)) 60 | self.dampen_factors = nn.Parameter(torch.randn(heads)) 61 | 62 | # params D 63 | 64 | self.param_D = nn.Parameter(torch.randn(dim)) 65 | 66 | def forward(self, x): 67 | """ 68 | einstein notation: 69 | b - batch 70 | h - heads 71 | l - sequence length 72 | d - dimension 73 | """ 74 | 75 | if self.reverse_seq: 76 | x = torch.flip(x, dims = (1,)) 77 | 78 | device, seq_len = x.device, x.shape[1] 79 | u = self.norm(x) 80 | 81 | # learned weighted residual 82 | 83 | residual = u * self.param_D 84 | 85 | # weights derived from alphas (learned exponential smoothing decay rate) 86 | 87 | alphas = self.alphas.sigmoid() 88 | dampen_factors = self.dampen_factors.sigmoid() 89 | 90 | reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device) 91 | K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1')) 92 | 93 | # conv1d fft O(nlog(n)) 94 | 95 | u = rearrange(u, '... (h d) -> ... h d', h = self.heads) 96 | 97 | out = conv1d_fft(u, K, dim = -3, weight_dim = -2) 98 | 99 | out = rearrange(out, '... h d -> ... 
(h d)') 100 | 101 | out = out + residual 102 | 103 | if self.reverse_seq: 104 | out = torch.flip(out, dims = (1,)) 105 | 106 | return out 107 | 108 | class GatedMHESA(nn.Module): 109 | """ Pseudocode 3.2 """ 110 | """ except state spaces replaced with multi-head exponential smoothing with learned alpha """ 111 | """ used for time-series in ETSFormer https://arxiv.org/abs/2202.01381 """ 112 | 113 | def __init__( 114 | self, 115 | *, 116 | dim, 117 | heads = 8, 118 | dim_mhesa = 512, 119 | dim_expansion_factor = 4, 120 | ): 121 | super().__init__() 122 | assert (dim_mhesa % heads) == 0 123 | 124 | self.norm = nn.LayerNorm(dim) 125 | 126 | dim_hidden = int(dim_expansion_factor * dim) 127 | self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU()) 128 | self.to_v = nn.Sequential(nn.Linear(dim, dim_mhesa, bias = False), nn.GELU()) 129 | 130 | self.mhesa = MHESA(dim = dim_mhesa, heads = heads) 131 | 132 | self.to_gate = nn.Linear(dim_mhesa, dim_hidden, bias = False) 133 | self.to_out = nn.Linear(dim_hidden, dim) 134 | 135 | def forward(self, x): 136 | residual, x = x.clone(), self.norm(x) 137 | 138 | u = self.to_u(x) 139 | v = self.to_v(x) 140 | 141 | v = self.mhesa(v) 142 | 143 | uc = self.to_gate(v) 144 | out = self.to_out(uc * u) 145 | 146 | return out + residual 147 | 148 | # Gated Dsconv LM 149 | 150 | class GatedExponentialSmoothingLM(nn.Module): 151 | def __init__( 152 | self, 153 | *, 154 | num_tokens, 155 | dim, 156 | depth, 157 | heads = 8, 158 | dim_mhesa = 512, 159 | dim_expansion_factor = 4, 160 | ): 161 | super().__init__() 162 | self.token_emb = nn.Embedding(num_tokens, dim) 163 | 164 | self.layers = nn.ModuleList([]) 165 | for _ in range(depth): 166 | self.layers.append( 167 | GatedMHESA( 168 | dim = dim, 169 | heads = heads, 170 | dim_mhesa = dim_mhesa, 171 | dim_expansion_factor = dim_expansion_factor 172 | ) 173 | ) 174 | 175 | self.to_logits = nn.Linear(dim, num_tokens, bias = False) 176 | 177 | def forward(self, x, labels = None): 178 | x = self.token_emb(x) 179 | 180 | for mhesa in self.layers: 181 | x = mhesa(x) 182 | 183 | logits = self.to_logits(x) 184 | 185 | if not exists(labels): 186 | return logits 187 | 188 | logits = rearrange(logits, 'b n c -> b c n') 189 | return F.cross_entropy(logits, labels) 190 | -------------------------------------------------------------------------------- /gss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/gated-state-spaces-pytorch/d484cf5a17382e50749d561a401f727234e10a06/gss.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'gated-state-spaces-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.1.0', 7 | license='MIT', 8 | description = 'Gated State Spaces - GSS - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/gated-state-spaces-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'state spaces', 17 | 'long context' 18 | ], 19 | install_requires=[ 20 | 'einops>=0.4', 21 | 'scipy', 22 | 'torch>=1.6', 23 | ], 24 | classifiers=[ 25 | 'Development Status :: 4 - Beta', 26 | 'Intended Audience :: Developers', 27 | 'Topic :: Scientific/Engineering :: Artificial 
Intelligence',
28 |         'License :: OSI Approved :: MIT License',
29 |         'Programming Language :: Python :: 3.6',
30 |     ],
31 | )
32 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import random
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.optim as optim
7 | import tqdm
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader, Dataset
10 | 
11 | from gated_state_spaces_pytorch import GatedStateSpacesLM
12 | from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
13 | 
14 | # constants
15 | 
16 | NUM_BATCHES = int(1e5)
17 | BATCH_SIZE = 4
18 | GRADIENT_ACCUMULATE_EVERY = 4
19 | LEARNING_RATE = 2e-4
20 | VALIDATE_EVERY = 100
21 | GENERATE_EVERY = 500
22 | GENERATE_LENGTH = 1024
23 | SEQ_LEN = 4096
24 | 
25 | # helpers
26 | 
27 | 
28 | def cycle(loader):
29 |     while True:
30 |         for data in loader:
31 |             yield data
32 | 
33 | 
34 | def decode_token(token):
35 |     return str(chr(max(32, token)))
36 | 
37 | 
38 | def decode_tokens(tokens):
39 |     return "".join(list(map(decode_token, tokens)))
40 | 
41 | 
42 | # instantiate GPT-like decoder model
43 | 
44 | model = GatedStateSpacesLM(
45 |     num_tokens = 256,
46 |     dim = 512,
47 |     depth = 8
48 | )
49 | 
50 | model = AutoregressiveWrapper(model)
51 | model.cuda()
52 | 
53 | # prepare enwik8 data
54 | 
55 | with gzip.open("./data/enwik8.gz") as file:
56 |     X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
57 |     trX, vaX = np.split(X, [int(90e6)])
58 |     data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
59 | 
60 | 
61 | class TextSamplerDataset(Dataset):
62 |     def __init__(self, data, seq_len):
63 |         super().__init__()
64 |         self.data = data
65 |         self.seq_len = seq_len
66 | 
67 |     def __getitem__(self, index):
68 |         rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
69 |         full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
70 |         return full_seq.cuda()
71 | 
72 |     def __len__(self):
73 |         return self.data.size(0) // self.seq_len
74 | 
75 | 
76 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
77 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
78 | train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
79 | val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
80 | 
81 | # optimizer
82 | 
83 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
84 | 
85 | # training
86 | 
87 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
88 |     model.train()
89 | 
90 |     for __ in range(GRADIENT_ACCUMULATE_EVERY):
91 |         loss = model(next(train_loader))
92 |         loss.backward()
93 | 
94 |     print(f"training loss: {loss.item()}")
95 |     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
96 |     optim.step()
97 |     optim.zero_grad()
98 | 
99 |     if i % VALIDATE_EVERY == 0:
100 |         model.eval()
101 |         with torch.no_grad():
102 |             loss = model(next(val_loader))
103 |             print(f"validation loss: {loss.item()}")
104 | 
105 |     if i % GENERATE_EVERY == 0:
106 |         model.eval()
107 |         inp = random.choice(val_dataset)[:-1]
108 |         prime = decode_tokens(inp)
109 |         print(f"{prime} \n\n {'*' * 100}")
110 | 
111 |         sample = model.generate(inp[None, ...], GENERATE_LENGTH)
112 |         output_str = decode_tokens(sample[0])
113 |         print(output_str)
114 | 
--------------------------------------------------------------------------------
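
One usage note on sampling, since train.py only prints generations from inside the training loop: the same `AutoregressiveWrapper.generate` call can be used standalone. The sketch below is illustrative only - it runs untrained on CPU, the random prompt stands in for real enwik8 bytes, and the commented-out checkpoint filename is hypothetical:

```python
import torch

from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

model = AutoregressiveWrapper(
    GatedStateSpacesLM(num_tokens = 256, dim = 512, depth = 8),
    max_seq_len = 4096
)
# model.load_state_dict(torch.load('gss.enwik8.pt'))  # hypothetical checkpoint path

prompt = torch.randint(0, 256, (1, 512))  # byte-level tokens, as in train.py

# top-k filtered sampling; temperature and filter_thres mirror the wrapper defaults
sampled = model.generate(prompt, 1024, temperature = 1.0, filter_thres = 0.9)  # (1, 1024)

decoded = ''.join(chr(max(32, t)) for t in sampled[0].tolist())
print(decoded)
```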