├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   └── enwik8.gz
├── gated_state_spaces_pytorch
│   ├── __init__.py
│   ├── autoregressive_wrapper.py
│   ├── dsconv.py
│   ├── gss.py
│   └── mhesa.py
├── gss.png
├── setup.py
└── train.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Gated State Spaces - Pytorch
4 |
5 | Implementation of Gated State Spaces, from the paper Long Range Language Modeling via Gated State Spaces, in Pytorch. In particular, it will contain the hybrid version combining local self attention with the long-range GSS.
6 |
7 | It will also contain a few more variants for comparing state spaces against a sequence-wise GLU depthwise convolution, and, even simpler, a parameterized exponential moving average along the sequence dimension (usage for both is shown below). That way we get to the bottom of whether state spaces are worth it, or whether it is really all about the `O(L log(L))` FFT convolution trick. Results will be shared in the readme.
8 |
9 | I will also pit the GSS module against the Path-X challenge and see how well it does.
10 |
11 | Update: This paper has beaten S4 on LRA using multi-headed EMA + single head attention.
12 |
13 | ## Install
14 |
15 | ```bash
16 | $ pip install gated-state-spaces-pytorch
17 | ```
18 |
19 | ## Usage
20 |
21 | ```python
22 | import torch
23 | from gated_state_spaces_pytorch import GSS
24 |
25 | gss = GSS(
26 | dim = 512, # dimension
27 | dim_expansion_factor = 4, # hidden dimension (expansion factor x dim) = 2048
28 | dss_kernel_N = 512,
29 | dss_kernel_H = 256
30 | )
31 |
32 | x = torch.randn(1, 65536, 512)
33 |
34 | out = gss(x) # (1, 65536, 512)
35 | ```
36 |
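`GSS` also accepts a `reverse_seq` flag, which simply flips the sequence before and after the block (presumably so a reversed copy can be paired with a forward one in non-causal settings). A minimal sketch:

```python
import torch
from gated_state_spaces_pytorch import GSS

gss_reverse = GSS(
    dim = 512,
    dim_expansion_factor = 4,
    dss_kernel_N = 512,
    dss_kernel_H = 256,
    reverse_seq = True   # process the sequence back-to-front
)

x = torch.randn(1, 65536, 512)

out = gss_reverse(x) # (1, 65536, 512)
```
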
37 | Gated state spaces language model
38 |
39 | ```python
40 | import torch
41 | from gated_state_spaces_pytorch import GatedStateSpacesLM
42 |
43 | gss_lm = GatedStateSpacesLM(
44 | num_tokens = 20000,
45 | depth = 12,
46 | dim = 512,
47 | dim_expansion_factor = 4,
48 | dss_kernel_N = 512,
49 | dss_kernel_H = 256
50 | )
51 |
52 | ids = torch.randint(0, 20000, (1, 1024))
53 |
54 | logits = gss_lm(ids) # (1, 1024, 20000)
55 | ```
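
For character-level training, `train.py` wraps the language model in the bundled `AutoregressiveWrapper`, which shifts the tokens to compute a next-token cross entropy loss and does top-k filtered sampling at generation time. A minimal sketch of that usage:

```python
import torch
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

lm = AutoregressiveWrapper(
    GatedStateSpacesLM(
        num_tokens = 256,
        dim = 512,
        depth = 8
    ),
    max_seq_len = 4096
)

ids = torch.randint(0, 256, (1, 1024))

loss = lm(ids) # next-token cross entropy loss
loss.backward()

sampled = lm.generate(ids[:, :32], 128) # (1, 128) sampled token ids
```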
56 |
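The depthwise-conv and learned-EMA comparison variants mentioned above are exported as well and follow the same interface. A sketch of their usage, with constructor arguments as defined in `dsconv.py` and `mhesa.py`:

```python
import torch
from gated_state_spaces_pytorch import GatedDsConvLM, GatedExponentialSmoothingLM

dsconv_lm = GatedDsConvLM(
    num_tokens = 20000,
    depth = 12,
    dim = 512,
    heads = 8,
    dim_dsconv = 512,
    max_seq_len = 2048  # sequences are asserted to be at most this long
)

ema_lm = GatedExponentialSmoothingLM(
    num_tokens = 20000,
    depth = 12,
    dim = 512,
    heads = 8,
    dim_mhesa = 512
)

ids = torch.randint(0, 20000, (1, 1024))

logits = dsconv_lm(ids) # (1, 1024, 20000)
logits = ema_lm(ids)    # (1, 1024, 20000)
```
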
57 | ## Todo
58 |
59 | - [x] enwik8
60 | - [x] gss lm class
61 | - [x] add dsconv + learned ema
62 | - [ ] add attention
63 |
64 | ## Citations
65 |
66 | ```bibtex
67 | @inproceedings{Mehta2022LongRL,
68 | title = {Long Range Language Modeling via Gated State Spaces},
69 | author = {Harsh Mehta and Ankit Gupta and Ashok Cutkosky and Behnam Neyshabur},
70 | year = {2022}
71 | }
72 | ```
73 |
74 | ```bibtex
75 | @misc{woo2022etsformer,
76 | title = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
77 | author = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
78 | year = {2022},
79 | eprint = {2202.01381},
80 | archivePrefix = {arXiv},
81 | primaryClass = {cs.LG}
82 | }
83 | ```
84 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 |
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/gated-state-spaces-pytorch/d484cf5a17382e50749d561a401f727234e10a06/data/enwik8.gz
--------------------------------------------------------------------------------
/gated_state_spaces_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from gated_state_spaces_pytorch.gss import GSS, GatedStateSpacesLM
2 | from gated_state_spaces_pytorch.dsconv import GatedDsConv, GatedDsConvLM
3 | from gated_state_spaces_pytorch.mhesa import GatedExponentialSmoothingLM, GatedMHESA
4 |
--------------------------------------------------------------------------------
/gated_state_spaces_pytorch/autoregressive_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange
4 | from torch import nn
5 |
6 | # helper function
7 |
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 |
13 | def eval_decorator(fn):
14 | def inner(model, *args, **kwargs):
15 | was_training = model.training
16 | model.eval()
17 | out = fn(model, *args, **kwargs)
18 | model.train(was_training)
19 | return out
20 |
21 | return inner
22 |
23 |
24 | # top k filtering
25 |
26 |
27 | def top_k(logits, thres=0.9):
28 | k = int((1 - thres) * logits.shape[-1])
29 | val, ind = torch.topk(logits, k)
30 | probs = torch.full_like(logits, float('-inf'))
31 | probs.scatter_(1, ind, val)
32 | return probs
33 |
34 |
35 | class AutoregressiveWrapper(nn.Module):
36 | def __init__(self, net, pad_value=0, max_seq_len=4096):
37 | super().__init__()
38 | self.max_seq_len = max_seq_len
39 | self.pad_value = pad_value
40 | self.net = net
41 |
42 | @torch.no_grad()
43 | @eval_decorator
44 | def generate(
45 | self,
46 | start_tokens,
47 | seq_len,
48 | eos_token=None,
49 | temperature=1.0,
50 | filter_thres=0.9,
51 | **kwargs
52 | ):
53 | b, n, device = *start_tokens.shape, start_tokens.device
54 |
55 | out = start_tokens
56 |
57 | for _ in range(seq_len):
58 | logits = self.net(
59 | out[:, -self.max_seq_len:],
60 | **kwargs
61 | )[:, -1]
62 |
63 | filtered_logits = top_k(logits, thres = filter_thres)
64 | probs = F.softmax(filtered_logits / temperature, dim=-1)
65 |
66 | sample = torch.multinomial(probs, 1)
67 | out = torch.cat((out, sample), dim=-1)
68 |
69 | if exists(eos_token):
70 | is_eos_token = out == eos_token
71 |
72 | if is_eos_token.any(dim=-1).all():
73 | # mask out everything after the eos tokens
74 | shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
75 | mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
76 | out = out.masked_fill(mask, self.pad_value)
77 | break
78 |
79 | return out[:, n:]
80 |
81 | def forward(self, x, **kwargs):
82 | inp, labels = x[:, :-1], x[:, 1:]
83 | return self.net(inp, labels = labels, **kwargs)
84 |
--------------------------------------------------------------------------------
/gated_state_spaces_pytorch/dsconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn, einsum
4 | from torch.fft import rfft, irfft
5 |
6 | from einops import rearrange
7 | from scipy.fftpack import next_fast_len
8 |
9 | # functions
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 | def append_dims(x, num_dims):
15 | if num_dims <= 0:
16 | return x
17 | return x.view(*x.shape, *((1,) * num_dims))
18 |
19 | def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
20 | # O(N log(N)) 1d convolution, carried out with FFTs (multiply in the frequency domain, then inverse transform)
21 |
22 | assert weight_dim >= dim
23 |
24 | N = x.shape[dim]
25 | M = weights.shape[weight_dim]
26 |
27 | fast_len = next_fast_len(N + M - 1)
28 |
29 | f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
30 | f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)
31 |
32 | f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
33 | out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
34 | out = out.roll(-1, dims = (dim,))
35 |
36 | indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
37 | out = out.index_select(dim, indices)
38 | return out
39 |
40 | # classes
41 |
42 | class EfficientDsConv(nn.Module):
43 | def __init__(
44 | self,
45 | *,
46 | dim,
47 | heads
48 | ):
49 | super().__init__()
50 | assert (dim % heads) == 0
51 |
52 | self.heads = heads
53 | self.norm = nn.LayerNorm(dim)
54 |
55 | self.to_weight = nn.Linear(dim, heads, bias = False)
56 |
57 | # params D
58 |
59 | self.param_D = nn.Parameter(torch.randn(dim))
60 |
61 | def forward(self, x):
62 | """
63 | einstein notation:
64 | b - batch
65 | h - heads (or groups)
66 | l - sequence length
67 | d - dimension
68 | """
69 |
70 | device, seq_len = x.device, x.shape[1]
71 | u = self.norm(x)
72 |
73 | # learned weighted residual
74 |
75 | residual = u * self.param_D
76 |
77 | # dsconv kernel depends on sequence length
78 |
79 | K = self.to_weight(x)
80 | K = torch.flip(K, dims = (1,))
81 |
82 | # conv1d fft O(nlog(n))
83 |
84 | u = rearrange(u, '... (h d) -> ... h d', h = self.heads)
85 |
86 | out = conv1d_fft(u, K, dim = -3, weight_dim = -2)
87 |
88 | out = rearrange(out, '... h d -> ... (h d)')
89 |
90 | return out + residual
91 |
92 | class GatedDsConv(nn.Module):
93 | """ Pseudocode 3.2 """
94 | """ except state spaces replaced with regular learned convolution kernel """
95 |
96 | def __init__(
97 | self,
98 | *,
99 | dim,
100 | heads = 8,
101 | dim_dsconv = 512,
102 | dim_expansion_factor = 4,
103 | reverse_seq = False
104 | ):
105 | super().__init__()
106 | assert (dim_dsconv % heads) == 0
107 | self.reverse_seq = reverse_seq
108 |
109 | self.norm = nn.LayerNorm(dim)
110 |
111 | dim_hidden = int(dim_expansion_factor * dim)
112 | self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU())
113 | self.to_v = nn.Sequential(nn.Linear(dim, dim_dsconv, bias = False), nn.GELU())
114 |
115 | self.dsconv = EfficientDsConv(dim = dim_dsconv, heads = heads)
116 |
117 | self.to_gate = nn.Linear(dim_dsconv, dim_hidden, bias = False)
118 | self.to_out = nn.Linear(dim_hidden, dim)
119 |
120 | def forward(self, x):
121 | if self.reverse_seq:
122 | x = torch.flip(x, dims = (1,))
123 |
124 | residual, x = x.clone(), self.norm(x)
125 |
126 | u = self.to_u(x)
127 | v = self.to_v(x)
128 |
129 | v = self.dsconv(v)
130 |
131 | uc = self.to_gate(v)
132 | out = self.to_out(uc * u)
133 |
134 | out = out + residual
135 |
136 | if self.reverse_seq:
137 | out = torch.flip(out, dims = (1,))
138 |
139 | return out
140 |
141 | # Gated Dsconv LM
142 |
143 | class GatedDsConvLM(nn.Module):
144 | def __init__(
145 | self,
146 | *,
147 | num_tokens,
148 | dim,
149 | depth,
150 | heads = 8,
151 | dim_dsconv = 512,
152 | max_seq_len = 2048,
153 | dim_expansion_factor = 4,
154 | ):
155 | super().__init__()
156 | self.token_emb = nn.Embedding(num_tokens, dim)
157 | self.max_seq_len = max_seq_len
158 |
159 | self.layers = nn.ModuleList([])
160 | for _ in range(depth):
161 | self.layers.append(
162 | GatedDsConv(
163 | dim = dim,
164 | heads = heads,
165 | dim_dsconv = dim_dsconv,
166 | dim_expansion_factor = dim_expansion_factor
167 | )
168 | )
169 |
170 | self.to_logits = nn.Linear(dim, num_tokens, bias = False)
171 |
172 | def forward(self, x, labels = None):
173 | assert x.shape[1] <= self.max_seq_len
174 |
175 | x = self.token_emb(x)
176 |
177 | for dsconv in self.layers:
178 | x = dsconv(x)
179 |
180 | logits = self.to_logits(x)
181 |
182 | if not exists(labels):
183 | return logits
184 |
185 | logits = rearrange(logits, 'b n c -> b c n')
186 | return F.cross_entropy(logits, labels)
187 |
--------------------------------------------------------------------------------
/gated_state_spaces_pytorch/gss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn, einsum
4 | from torch.fft import rfft, irfft
5 |
6 | from einops import rearrange
7 |
8 | # functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | # classes
14 |
15 | class DSS(nn.Module):
16 | def __init__(
17 | self,
18 | *,
19 | dim,
20 | kernel_N = 512,
21 | dss_kernel_lambda_imag_exp = True
22 | ):
23 | super().__init__()
24 | self.norm = nn.LayerNorm(dim)
25 |
26 | # Lambda
27 |
28 | self.Lambda_real = nn.Parameter(torch.randn(kernel_N))
29 | self.Lambda_imag = nn.Parameter(torch.randn(kernel_N))
30 |
31 | # C
32 |
33 | self.C_real = nn.Parameter(torch.randn(dim, kernel_N))
34 | self.C_imag = nn.Parameter(torch.randn(dim, kernel_N))
35 |
36 | # params D
37 |
38 | self.param_D = nn.Parameter(torch.randn(dim))
39 |
40 | # whether to exponentiate lambda imag - @albertfgu says it is not faithful to the original s4 design (but it is present in the paper's pseudocode)
41 |
42 | self.dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp
43 |
44 | def forward(self, x):
45 | """
46 | einstein notation:
47 | b - batch
48 | l - sequence length
49 | d - dimension
50 | """
51 |
52 | device, seq_len = x.device, x.shape[1]
53 | u = self.norm(x)
54 |
55 | # learned weighted residual
56 |
57 | residual = u * self.param_D
58 |
59 | # derive simple dss kernel
60 |
61 | Lambda_imag = self.Lambda_imag.exp() if self.dss_kernel_lambda_imag_exp else self.Lambda_imag
62 |
63 | Lambda = -self.Lambda_real.exp() + 1j * Lambda_imag
64 | C = self.C_real + 1j * self.C_imag
65 |
66 | arange = torch.arange(seq_len, device = device)
67 |
68 | S = (rearrange(Lambda, 'n -> n 1') * rearrange(arange, 'l -> 1 l')).exp()
69 | C = C * (Lambda.exp() - 1) / Lambda
70 |
71 | K = einsum('h n, n l -> l h', C, S).real
72 |
73 | # conv1d fft O(nlog(n))
74 |
75 | u_f = rfft(u, n = seq_len * 2, dim = -2)
76 | K_f = rfft(K, n = seq_len * 2, dim = -2)
77 |
78 | y = irfft(u_f * K_f, seq_len * 2, dim = -2)[..., :seq_len, :]
79 |
80 | return y + residual
81 |
82 | class GSS(nn.Module):
83 | """ Pseudocode 3.2 """
84 |
85 | def __init__(
86 | self,
87 | *,
88 | dim,
89 | dim_expansion_factor = 4,
90 | dss_kernel_N = 512,
91 | dss_kernel_H = 256,
92 | reverse_seq = False,
93 | dss_kernel_lambda_imag_exp = True
94 | ):
95 | super().__init__()
96 | self.reverse_seq = reverse_seq
97 | self.norm = nn.LayerNorm(dim)
98 |
99 | dim_hidden = int(dim_expansion_factor * dim)
100 | self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU())
101 | self.to_v = nn.Sequential(nn.Linear(dim, dss_kernel_H, bias = False), nn.GELU())
102 |
103 | self.dss = DSS(dim = dss_kernel_H, kernel_N = dss_kernel_N, dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp)
104 |
105 | self.to_gate = nn.Linear(dss_kernel_H, dim_hidden, bias = False)
106 | self.to_out = nn.Linear(dim_hidden, dim)
107 |
108 | def forward(self, x):
109 | if self.reverse_seq:
110 | x = torch.flip(x, dims = (1,))
111 |
112 | residual, x = x.clone(), self.norm(x)
113 |
114 | u = self.to_u(x)
115 | v = self.to_v(x)
116 |
117 | v = self.dss(v)
118 |
119 | uc = self.to_gate(v)
120 | out = self.to_out(uc * u)
121 |
122 | out = out + residual
123 |
124 | if self.reverse_seq:
125 | out = torch.flip(out, dims = (1,))
126 |
127 | return out
128 |
129 | # Gated State Spaces LM
130 |
131 | class GatedStateSpacesLM(nn.Module):
132 | def __init__(
133 | self,
134 | *,
135 | num_tokens,
136 | dim,
137 | depth,
138 | dim_expansion_factor = 4,
139 | dss_kernel_N = 512,
140 | dss_kernel_H = 256,
141 | dss_kernel_lambda_imag_exp = True
142 | ):
143 | super().__init__()
144 | self.token_emb = nn.Embedding(num_tokens, dim)
145 |
146 | self.layers = nn.ModuleList([])
147 | for _ in range(depth):
148 | self.layers.append(
149 | GSS(
150 | dim = dim,
151 | dss_kernel_H = dss_kernel_H,
152 | dss_kernel_N = dss_kernel_N,
153 | dim_expansion_factor = dim_expansion_factor,
154 | dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp
155 | )
156 | )
157 |
158 | self.to_logits = nn.Linear(dim, num_tokens, bias = False)
159 |
160 | def forward(self, x, labels = None):
161 | x = self.token_emb(x)
162 |
163 | for gss in self.layers:
164 | x = gss(x)
165 |
166 | logits = self.to_logits(x)
167 |
168 | if not exists(labels):
169 | return logits
170 |
171 | logits = rearrange(logits, 'b n c -> b c n')
172 | return F.cross_entropy(logits, labels)
173 |
--------------------------------------------------------------------------------
/gated_state_spaces_pytorch/mhesa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn, einsum
4 | from torch.fft import rfft, irfft
5 |
6 | from einops import rearrange
7 | from scipy.fftpack import next_fast_len
8 |
9 | # functions
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 | def append_dims(x, num_dims):
15 | if num_dims <= 0:
16 | return x
17 | return x.view(*x.shape, *((1,) * num_dims))
18 |
19 | def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
20 | # O(N log(N)) 1d convolution, carried out with FFTs (multiply in the frequency domain, then inverse transform)
21 |
22 | assert weight_dim >= dim
23 |
24 | N = x.shape[dim]
25 | M = weights.shape[weight_dim]
26 |
27 | fast_len = next_fast_len(N + M - 1)
28 |
29 | f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
30 | f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)
31 |
32 | f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
33 | out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
34 | out = out.roll(-1, dims = (dim,))
35 |
36 | indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
37 | out = out.index_select(dim, indices)
38 | return out
39 |
40 | # classes
41 |
42 | class MHESA(nn.Module):
43 | """ used for time-series in ETSFormer https://arxiv.org/abs/2202.01381 """
44 |
45 | def __init__(
46 | self,
47 | *,
48 | dim,
49 | heads,
50 | reverse_seq = False
51 | ):
52 | super().__init__()
53 | assert (dim % heads) == 0
54 | self.reverse_seq = reverse_seq
55 |
56 | self.heads = heads
57 | self.norm = nn.LayerNorm(dim)
58 |
59 | self.alphas = nn.Parameter(torch.randn(heads))
60 | self.dampen_factors = nn.Parameter(torch.randn(heads))
61 |
62 | # params D
63 |
64 | self.param_D = nn.Parameter(torch.randn(dim))
65 |
66 | def forward(self, x):
67 | """
68 | einstein notation:
69 | b - batch
70 | h - heads
71 | l - sequence length
72 | d - dimension
73 | """
74 |
75 | if self.reverse_seq:
76 | x = torch.flip(x, dims = (1,))
77 |
78 | device, seq_len = x.device, x.shape[1]
79 | u = self.norm(x)
80 |
81 | # learned weighted residual
82 |
83 | residual = u * self.param_D
84 |
85 | # weights derived from alphas (learned exponential smoothing decay rate)
86 |
87 | alphas = self.alphas.sigmoid()
88 | dampen_factors = self.dampen_factors.sigmoid()
89 |
90 | reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
91 | K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))
92 |
93 | # conv1d fft O(nlog(n))
94 |
95 | u = rearrange(u, '... (h d) -> ... h d', h = self.heads)
96 |
97 | out = conv1d_fft(u, K, dim = -3, weight_dim = -2)
98 |
99 | out = rearrange(out, '... h d -> ... (h d)')
100 |
101 | out = out + residual
102 |
103 | if self.reverse_seq:
104 | out = torch.flip(out, dims = (1,))
105 |
106 | return out
107 |
108 | class GatedMHESA(nn.Module):
109 | """ Pseudocode 3.2 """
110 | """ except state spaces replaced with multi-head exponential smoothing with learned alpha """
111 | """ used for time-series in ETSFormer https://arxiv.org/abs/2202.01381 """
112 |
113 | def __init__(
114 | self,
115 | *,
116 | dim,
117 | heads = 8,
118 | dim_mhesa = 512,
119 | dim_expansion_factor = 4,
120 | ):
121 | super().__init__()
122 | assert (dim_mhesa % heads) == 0
123 |
124 | self.norm = nn.LayerNorm(dim)
125 |
126 | dim_hidden = int(dim_expansion_factor * dim)
127 | self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU())
128 | self.to_v = nn.Sequential(nn.Linear(dim, dim_mhesa, bias = False), nn.GELU())
129 |
130 | self.mhesa = MHESA(dim = dim_mhesa, heads = heads)
131 |
132 | self.to_gate = nn.Linear(dim_mhesa, dim_hidden, bias = False)
133 | self.to_out = nn.Linear(dim_hidden, dim)
134 |
135 | def forward(self, x):
136 | residual, x = x.clone(), self.norm(x)
137 |
138 | u = self.to_u(x)
139 | v = self.to_v(x)
140 |
141 | v = self.mhesa(v)
142 |
143 | uc = self.to_gate(v)
144 | out = self.to_out(uc * u)
145 |
146 | return out + residual
147 |
148 | # Gated Exponential Smoothing LM
149 |
150 | class GatedExponentialSmoothingLM(nn.Module):
151 | def __init__(
152 | self,
153 | *,
154 | num_tokens,
155 | dim,
156 | depth,
157 | heads = 8,
158 | dim_mhesa = 512,
159 | dim_expansion_factor = 4,
160 | ):
161 | super().__init__()
162 | self.token_emb = nn.Embedding(num_tokens, dim)
163 |
164 | self.layers = nn.ModuleList([])
165 | for _ in range(depth):
166 | self.layers.append(
167 | GatedMHESA(
168 | dim = dim,
169 | heads = heads,
170 | dim_mhesa = dim_mhesa,
171 | dim_expansion_factor = dim_expansion_factor
172 | )
173 | )
174 |
175 | self.to_logits = nn.Linear(dim, num_tokens, bias = False)
176 |
177 | def forward(self, x, labels = None):
178 | x = self.token_emb(x)
179 |
180 | for mhesa in self.layers:
181 | x = mhesa(x)
182 |
183 | logits = self.to_logits(x)
184 |
185 | if not exists(labels):
186 | return logits
187 |
188 | logits = rearrange(logits, 'b n c -> b c n')
189 | return F.cross_entropy(logits, labels)
190 |
--------------------------------------------------------------------------------
/gss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/gated-state-spaces-pytorch/d484cf5a17382e50749d561a401f727234e10a06/gss.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'gated-state-spaces-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.1.0',
7 | license='MIT',
8 | description = 'Gated State Spaces - GSS - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/gated-state-spaces-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'state spaces',
17 | 'long context'
18 | ],
19 | install_requires=[
20 | 'einops>=0.4',
21 | 'scipy',
22 | 'torch>=1.6',
23 | ],
24 | classifiers=[
25 | 'Development Status :: 4 - Beta',
26 | 'Intended Audience :: Developers',
27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
28 | 'License :: OSI Approved :: MIT License',
29 | 'Programming Language :: Python :: 3.6',
30 | ],
31 | )
32 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | import torch.optim as optim
7 | import tqdm
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader, Dataset
10 |
11 | from gated_state_spaces_pytorch import GatedStateSpacesLM
12 | from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
13 |
14 | # constants
15 |
16 | NUM_BATCHES = int(1e5)
17 | BATCH_SIZE = 4
18 | GRADIENT_ACCUMULATE_EVERY = 4
19 | LEARNING_RATE = 2e-4
20 | VALIDATE_EVERY = 100
21 | GENERATE_EVERY = 500
22 | GENERATE_LENGTH = 1024
23 | SEQ_LEN = 4096
24 |
25 | # helpers
26 |
27 |
28 | def cycle(loader):
29 | while True:
30 | for data in loader:
31 | yield data
32 |
33 |
34 | def decode_token(token):
35 | return str(chr(max(32, token)))
36 |
37 |
38 | def decode_tokens(tokens):
39 | return "".join(list(map(decode_token, tokens)))
40 |
41 |
42 | # instantiate GPT-like decoder model
43 |
44 | model = GatedStateSpacesLM(
45 | num_tokens = 256,
46 | dim = 512,
47 | depth = 8
48 | )
49 |
50 | model = AutoregressiveWrapper(model)
51 | model.cuda()
52 |
53 | # prepare enwik8 data
54 |
55 | with gzip.open("./data/enwik8.gz") as file:
56 | X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
57 | trX, vaX = np.split(X, [int(90e6)])
58 | data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
59 |
60 |
61 | class TextSamplerDataset(Dataset):
62 | def __init__(self, data, seq_len):
63 | super().__init__()
64 | self.data = data
65 | self.seq_len = seq_len
66 |
67 | def __getitem__(self, index):
68 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
69 | full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
70 | return full_seq.cuda()
71 |
72 | def __len__(self):
73 | return self.data.size(0) // self.seq_len
74 |
75 |
76 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
77 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
78 | train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
79 | val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
80 |
81 | # optimizer
82 |
83 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
84 |
85 | # training
86 |
87 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
88 | model.train()
89 |
90 | for __ in range(GRADIENT_ACCUMULATE_EVERY):
91 | loss = model(next(train_loader))
92 | loss.backward()
93 |
94 | print(f"training loss: {loss.item()}")
95 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
96 | optim.step()
97 | optim.zero_grad()
98 |
99 | if i % VALIDATE_EVERY == 0:
100 | model.eval()
101 | with torch.no_grad():
102 | loss = model(next(val_loader))
103 | print(f"validation loss: {loss.item()}")
104 |
105 | if i % GENERATE_EVERY == 0:
106 | model.eval()
107 | inp = random.choice(val_dataset)[:-1]
108 | prime = decode_tokens(inp)
109 | print(f"{prime} \n\n {'*' * 100}")
110 |
111 | sample = model.generate(inp[None, ...], GENERATE_LENGTH)
112 | output_str = decode_tokens(sample[0])
113 | print(output_str)
114 |
--------------------------------------------------------------------------------