├── .github
└── workflows
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── data
├── README.md
└── enwik8.gz
├── mega.png
├── mega_pytorch
├── __init__.py
├── autoregressive_wrapper.py
└── mega_pytorch.py
├── setup.py
└── train.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    # use current major versions of the official actions; the v2 majors are outdated
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    # third-party action, pinned to an exact commit
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Mega - Moving Average Equipped Gated Attention - Pytorch
4 |
5 | Implementation of the Mega layer, the Single-head Attention with Multi-headed EMA layer that exists in the architecture that currently holds SOTA on Long Range Arena, beating S4 on Pathfinder-X and all the other tasks save for audio.
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install mega-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | The Mega layer, which combines single-headed attention with a learned multi-headed EMA
16 |
17 | ```python
18 | import torch
19 | from mega_pytorch import MegaLayer
20 |
21 | layer = MegaLayer(
22 | dim = 128, # model dimensions
23 | ema_heads = 16, # number of EMA heads
24 | attn_dim_qk = 64, # dimension of queries / keys in attention
25 | attn_dim_value = 256, # dimension of values in attention
26 | laplacian_attn_fn = False, # whether to use softmax (false) or laplacian attention activation fn (true)
27 | )
28 |
29 | x = torch.randn(1, 1024, 128) # (batch, seq, dim)
30 |
31 | out = layer(x) # (1, 1024, 128)
32 | ```
33 |
34 | Full Mega (with layernorm for now)
35 |
36 | ```python
37 | import torch
38 | from mega_pytorch import Mega
39 |
40 | mega = Mega(
41 | num_tokens = 256, # number of tokens
42 | dim = 128, # model dimensions
43 | depth = 6, # depth
44 | ema_heads = 16, # number of EMA heads
45 | attn_dim_qk = 64, # dimension of queries / keys in attention
46 | attn_dim_value = 256, # dimension of values in attention
47 | laplacian_attn_fn = True, # whether to use softmax (false) or laplacian attention activation fn (true)
48 | )
49 |
50 | x = torch.randint(0, 256, (1, 1024))
51 |
52 | logits = mega(x) # (1, 1024, 256)
53 | ```
54 |
55 | ## Todo
56 |
57 | - [ ] add dynamic positional bias for best length extrapolation arch
58 |
59 | ## Citations
60 |
61 | ```bibtex
62 | @inproceedings{Ma2022MegaMA,
63 | title = {Mega: Moving Average Equipped Gated Attention},
64 | author = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
65 | year = {2022}
66 | }
67 | ```
68 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 |
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/Mega-pytorch/dc765fd5313bf02419a473dfc819b2ab33e046e8/data/enwik8.gz
--------------------------------------------------------------------------------
/mega.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/Mega-pytorch/dc765fd5313bf02419a473dfc819b2ab33e046e8/mega.png
--------------------------------------------------------------------------------
/mega_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from mega_pytorch.mega_pytorch import MegaLayer, Mega, MultiHeadedEMA
2 |
--------------------------------------------------------------------------------
/mega_pytorch/autoregressive_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange
6 |
7 | # helper function
8 |
def exists(val):
    """Return True for every value except None (falsy values such as 0 or '' still count)."""
    if val is None:
        return False
    return True
11 |
def eval_decorator(fn):
    """Decorate a method so that it runs with `model` in eval mode.

    The wrapped callable receives the model as its first argument. The model's
    `training` flag is recorded, `model.eval()` is called, `fn` runs, and the
    recorded mode is restored afterwards - also when `fn` raises, so a failed
    call can no longer leave the model stuck in eval mode.
    """
    from functools import wraps  # local import, the file's import block stays as it is

    @wraps(fn)  # keep the wrapped function's name and docstring
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore whatever mode the caller had, even on error
            model.train(was_training)

    return inner
20 |
21 | # top k filtering
22 |
def top_k(logits, thres = 0.9):
    """Keep the highest-scoring logits along the last axis and set the rest to -inf.

    The number of kept entries is `int((1 - thres) * num_classes)`, clamped to
    the range [1, num_classes]. Without the lower clamp a small vocabulary or a
    high threshold truncates to zero kept entries, every logit becomes -inf and
    the following softmax yields NaN.

    Returns a tensor shaped like `logits`.
    """
    num_classes = logits.shape[-1]
    k = int((1 - thres) * num_classes)
    k = min(max(k, 1), num_classes)

    val, ind = torch.topk(logits, k)  # topk works on the last dimension by default

    probs = torch.full_like(logits, float('-inf'))
    # scatter along the same (last) axis topk used, so inputs of any rank work
    probs.scatter_(-1, ind, val)
    return probs
29 |
class AutoregressiveWrapper(nn.Module):
    """Adds a next-token loss (`forward`) and a sampling loop (`generate`) around `net`.

    `net` is called with a 2-d tensor of token ids and its output is indexed as
    (batch, position, class).
    """

    def __init__(self, net, pad_value = 0):
        super().__init__()
        self.net = net
        # stored only; none of the methods of this class read it
        self.pad_value = pad_value

    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, temperature = 1., filter_thres = 0.9, **kwargs):
        """Sample `seq_len` new tokens after `start_tokens` and return only the new ones.

        Each step feeds the whole running sequence through `net`, filters the
        logits of the last position with `top_k`, applies the temperature and
        draws one token per row with `torch.multinomial`.
        """
        # unpacking doubles as a check that the prompt is 2-d
        _, prompt_len = start_tokens.shape

        tokens = start_tokens

        for _ in range(seq_len):
            last_logits = self.net(tokens, **kwargs)[:, -1, :]
            kept_logits = top_k(last_logits, thres = filter_thres)
            probs = F.softmax(kept_logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat((tokens, next_token), dim=-1)

        # drop the prompt
        return tokens[:, prompt_len:]

    def forward(self, x, **kwargs):
        """Return the cross entropy of predicting `x[:, 1:]` from `x[:, :-1]`."""
        inputs = x[:, :-1]
        targets = x[:, 1:]

        logits = self.net(inputs, **kwargs)

        # cross_entropy expects the class axis second, so the last two axes are swapped
        return F.cross_entropy(rearrange(logits, 'b c n -> b n c'), targets)
60 |
--------------------------------------------------------------------------------
/mega_pytorch/mega_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn, einsum
7 | from torch.fft import rfft, irfft
8 |
9 | from einops import rearrange
10 | from einops.layers.torch import Rearrange
11 |
12 | from scipy.fftpack import next_fast_len
13 |
14 | # functions
15 |
def exists(val):
    """Return True for every value except None."""
    if val is None:
        return False
    return True
18 |
def identity(t, *args, **kwargs):
    """Return `t` as is; extra positional and keyword arguments are accepted and ignored."""
    return t
21 |
def default(val, d):
    """Return `val`, or the fallback `d` when `val` is None."""
    if val is None:
        return d
    return val
24 |
def append_dims(x, num_dims):
    """Return a view of `x` with `num_dims` extra trailing axes of size 1.

    `x` itself is returned when `num_dims` is zero or negative.
    """
    if num_dims > 0:
        trailing_ones = (1,) * num_dims
        return x.view(*x.shape, *trailing_ones)
    return x
29 |
def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
    """Combine `x` and `weights` along one axis through FFTs in O(N log(N)).

    Both inputs are transformed with a real FFT padded to a common length of at
    least N + M - 1 (N = size of `x` along `dim`, M = size of `weights` along
    `weight_dim`). The spectrum of `x` is multiplied by the complex conjugate of
    the spectrum of `weights`, the product is transformed back, shifted by one
    position along `dim`, and the last N entries along `dim` are returned.

    NOTE(review): with the reversed-power kernel built in MultiHeadedEMA this
    appears to give each position a weighted sum over itself and earlier
    positions - confirm before relying on it elsewhere.
    """
    assert weight_dim >= dim

    seq_len = x.shape[dim]
    kernel_len = weights.shape[weight_dim]

    # padded transform length, rounded up to a size the FFT handles quickly
    n_fft = next_fast_len(seq_len + kernel_len - 1)

    x_freq = rfft(x, n = n_fft, dim = dim)
    w_freq = rfft(weights, n = n_fft, dim = weight_dim)

    # trailing singleton axes line the kernel spectrum up with the axes of x
    w_freq = append_dims(w_freq.conj(), weight_dim - dim)

    result = irfft(x_freq * w_freq, n_fft, dim = dim)
    result = result.roll(-1, dims = (dim,))

    # keep only the final seq_len positions of the padded result
    keep = torch.arange(n_fft - seq_len, n_fft, dtype = torch.long, device = x.device)
    return result.index_select(dim, keep)
50 |
51 | # positional bias for single-headed attention
52 |
class T5RelativePositionBias(nn.Module):
    """Learned scalar attention bias looked up from bucketed relative positions."""

    def __init__(
        self,
        scale,
        causal = False,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # one scalar per bucket, since the attention is single-headed
        self.relative_attention_bias = nn.Embedding(num_buckets, 1)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        causal = True,
        num_buckets = 32,
        max_distance = 128
    ):
        """Map integer relative positions (key index minus query index) to bucket ids.

        Small distances get one bucket each; larger ones share logarithmically
        sized buckets, capped at the last bucket. When not causal, half of the
        buckets are reserved for one sign of the offset.
        """
        distance = -relative_position
        bucket = 0

        if causal:
            # negative distances are clamped to zero
            distance = torch.max(distance, torch.zeros_like(distance))
        else:
            num_buckets //= 2
            bucket += (distance < 0).long() * num_buckets
            distance = torch.abs(distance)

        max_exact = num_buckets // 2
        is_small = distance < max_exact

        log_ratio = torch.log(distance.float() / max_exact) / math.log(max_distance / max_exact)
        large_bucket = max_exact + (log_ratio * (num_buckets - max_exact)).long()
        large_bucket = torch.min(large_bucket, torch.full_like(large_bucket, num_buckets - 1))

        bucket += torch.where(is_small, distance, large_bucket)
        return bucket

    def forward(self, x):
        """Return an (i, j) bias matrix for the last two dimensions of `x`, multiplied by `scale`."""
        num_queries, num_keys = x.shape[-2:]
        device = x.device

        q_pos = torch.arange(num_queries, dtype = torch.long, device = device)
        k_pos = torch.arange(num_keys, dtype = torch.long, device = device)

        # entry (i, j) holds key position j minus query position i
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')

        buckets = self._relative_position_bucket(
            rel_pos,
            causal = self.causal,
            num_buckets = self.num_buckets,
            max_distance = self.max_distance
        )

        bias = rearrange(self.relative_attention_bias(buckets), 'i j 1 -> i j')
        return bias * self.scale
104 |
105 | # classes
106 |
class LaplacianAttnFn(nn.Module):
    """Elementwise attention activation: 0.5 * (1 + erf((x - mu) / (std * sqrt(2))))."""

    def forward(self, x):
        # fixed constants: mu = sqrt(1/2), std = sqrt(1 / (4 * pi))
        mu = math.sqrt(0.5)
        std = math.sqrt((4 * math.pi) ** -1)

        z = (x - mu) / (std * math.sqrt(2))
        return (1 + torch.special.erf(z)) * 0.5
112 |
class OffsetScale(nn.Module):
    """Produces `heads` versions of the input, each with its own per-feature scale and offset."""

    def __init__(self, dim, heads = 1):
        super().__init__()
        shape = (heads, dim)
        self.gamma = nn.Parameter(torch.ones(*shape))
        self.beta = nn.Parameter(torch.zeros(*shape))
        # the scale is re-initialised from a narrow normal distribution
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        """Return a tuple with one tensor per head, each shaped like `x`."""
        scaled = einsum('... d, h d -> ... h d', x, self.gamma)
        shifted = scaled + self.beta
        # split the head axis, which sits second to last
        return shifted.unbind(dim = -2)
123 |
class SingleHeadedAttention(nn.Module):
    """Single-headed attention with a relative position bias.

    Queries and keys come from one shared projection of the input, split into
    two by a learned per-feature scale and offset. Values are projected from a
    separate input when one is given.
    """

    def __init__(
        self,
        *,
        dim,
        dim_qk,
        dim_value,
        causal = False,
        laplacian_attn_fn = False
    ):
        super().__init__()
        self.causal = causal
        self.laplacian_attn_fn = laplacian_attn_fn

        # softmax over the key axis unless the laplacian activation is requested
        if laplacian_attn_fn:
            self.attn_fn = LaplacianAttnFn()
        else:
            self.attn_fn = partial(F.softmax, dim = -1)

        self.rel_pos_bias = T5RelativePositionBias(causal = causal, scale = dim_qk ** 0.5)

        self.to_qk = nn.Sequential(
            nn.Linear(dim, dim_qk),
            nn.SiLU()
        )

        # two "heads": the first becomes the queries, the second the keys
        self.offsetscale = OffsetScale(dim_qk, heads = 2)

        self.to_v = nn.Sequential(
            nn.Linear(dim, dim_value),
            nn.SiLU()
        )

    def forward(self, x, v_input = None):
        """Attend over `x`; values are taken from `v_input`, or from `x` when it is None."""
        seq_len, dim = x.shape[-2:]
        device = x.device

        v_input = default(v_input, x)

        shared_qk = self.to_qk(x)
        v = self.to_v(v_input)
        q, k = self.offsetscale(shared_qk)

        # NOTE(review): `dim` is the last dimension of `x` (the model width), not dim_qk -
        # confirm this is the intended softmax scale
        if self.laplacian_attn_fn:
            scale = seq_len ** -1
        else:
            scale = dim ** -0.5

        sim = einsum('b i d, b j d -> b i j', q, k) * scale
        sim = sim + self.rel_pos_bias(sim)

        if not self.causal:
            attn = self.attn_fn(sim)
            return einsum('b i j, b j d -> b i d', attn, v)

        # True strictly above the diagonal, i.e. at positions that lie in the future
        causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)

        if self.laplacian_attn_fn:
            # the laplacian activation is elementwise, so future positions are zeroed after it
            attn = self.attn_fn(sim).masked_fill(causal_mask, 0.)
        else:
            # softmax: push future positions to a very large negative value beforehand
            attn = self.attn_fn(sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max))

        return einsum('b i j, b j d -> b i d', attn, v)
182 |
class MultiHeadedEMA(nn.Module):
    """Multi-headed exponential moving average with learned decay and damping.

    The input of shape (batch, seq, dim) is expanded to one copy per head,
    smoothed along the sequence with a per-head learned kernel, optionally
    normalised per head, and reduced back to (batch, seq, dim).

    With `bidirectional = True` a second set of heads, with its own decay and
    damping parameters, processes the flipped sequence.
    """

    def __init__(
        self,
        *,
        dim,
        heads,
        bidirectional = False,
        norm_mhesa_heads = False
    ):
        super().__init__()
        self.bidirectional = bidirectional

        # the reverse direction brings its own set of heads, doubling the head axis
        total_heads = heads * (2 if bidirectional else 1)

        self.expansion = nn.Parameter(torch.randn(total_heads, dim))
        self.reduction = nn.Parameter(torch.randn(total_heads, dim))

        # learned alpha and dampening factors

        self.alphas = nn.Parameter(torch.randn(heads))
        self.dampen_factors = nn.Parameter(torch.randn(heads))

        if bidirectional:
            self.reverse_alphas = nn.Parameter(torch.randn(heads))
            self.reverse_dampen_factors = nn.Parameter(torch.randn(heads))

        self.heads = heads

        self.norm_heads = nn.Identity()

        if norm_mhesa_heads:
            # https://arxiv.org/abs/2210.06423 - retnet used sub-ln with some success as groupnorm

            # sized from total_heads: in the bidirectional case forward() hands over
            # 2 * heads heads, which a norm built for `heads` rejects with a channel mismatch
            self.norm_heads = nn.Sequential(
                Rearrange('b n h d -> b (h d) n'),
                nn.GroupNorm(total_heads, dim * total_heads),
                Rearrange('b (h d) n -> b n h d', h = total_heads)
            )

    def forward(self, x):
        """Apply the learned EMA to `x` of shape (batch, seq, dim); the output has the same shape."""
        device, seq_len = x.device, x.shape[1]

        # project in and split heads

        x = einsum('... d, h d -> ... h d', x, self.expansion)

        if self.bidirectional:
            x, x_reversed = x.chunk(2, dim = -2)
            x_reversed = torch.flip(x_reversed, dims = (1,))

        # weights derived from alphas (learned exponential smoothing decay rate)

        def apply_learned_ema_with_damping(x, alphas, dampen_factors):
            # squash both learned parameters into (0, 1)
            alphas = alphas.sigmoid()
            dampen_factors = dampen_factors.sigmoid()

            # exponents seq_len - 1 ... 0, so the last kernel entry carries exponent 0
            reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
            K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))

            # conv1d fft O(nlog(n))

            return conv1d_fft(x, K, dim = -3, weight_dim = -2)

        x = apply_learned_ema_with_damping(x, self.alphas, self.dampen_factors)

        if self.bidirectional:
            x_reversed = apply_learned_ema_with_damping(x_reversed, self.reverse_alphas, self.reverse_dampen_factors)
            # flip back so both directions are aligned position by position
            x_reversed = torch.flip(x_reversed, dims = (1,))
            x = torch.cat((x, x_reversed), dim = -2)

        # maybe norm heads

        x = self.norm_heads(x)

        # combine heads and out

        return einsum('... h d, h d -> ... d', x, self.reduction)
258 |
259 | # Mega Layer
260 | # Single headed Attention + Multi-headed EMA, then GRU-esque gating
261 |
class MegaLayer(nn.Module):
    """One Mega block: multi-headed EMA, single-headed attention, then gating.

    The EMA output drives the attention queries / keys and both gates, while the
    attention values are projected from the raw layer input.
    """

    def __init__(
        self,
        *,
        dim = 128,
        ema_heads = 16,
        attn_dim_qk = 64,
        attn_dim_value = 256,
        laplacian_attn_fn = False,
        causal = True,
        norm_mhesa_heads = False
    ):
        super().__init__()

        self.single_headed_attn = SingleHeadedAttention(
            dim = dim,
            dim_qk = attn_dim_qk,
            dim_value = attn_dim_value,
            causal = causal,
            laplacian_attn_fn = laplacian_attn_fn
        )

        # a non-causal layer smooths in both directions
        self.multi_headed_ema = MultiHeadedEMA(
            dim = dim,
            heads = ema_heads,
            bidirectional = not causal,
            norm_mhesa_heads = norm_mhesa_heads
        )

        # the reset gate multiplies the attention output, so it lives in the value dimension
        self.to_reset_gate = nn.Sequential(
            nn.Linear(dim, attn_dim_value),
            nn.SiLU()
        )

        # the update gate blends candidate and residual, so it lives in the model dimension
        self.to_update_gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )

        # equation 14, for calculating H

        self.Wh = nn.Parameter(torch.randn(dim, dim))
        self.Uh = nn.Parameter(torch.randn(attn_dim_value, dim))
        self.bh = nn.Parameter(torch.randn(dim))

    def forward(self, x, residual = None):
        """Run the block on `x`; `residual` (default `x`) is what the update gate blends with."""
        residual = default(residual, x)

        ema_out = self.multi_headed_ema(x)

        # queries / keys come from the EMA output, values from the raw input
        attn_out = self.single_headed_attn(ema_out, x)

        reset = self.to_reset_gate(ema_out)
        update = self.to_update_gate(ema_out)

        gated_attn = attn_out * reset

        # equation 14

        candidate = F.silu(ema_out @ self.Wh + gated_attn @ self.Uh + self.bh)

        # update gate

        return update * candidate + (1 - update) * residual
325 |
326 | # Mega
327 |
def FeedForward(dim, ff_mult):
    """Build a Linear -> GELU -> Linear stack whose hidden width is int(dim * ff_mult)."""
    hidden = int(dim * ff_mult)
    stages = [
        nn.Linear(dim, hidden),
        nn.GELU(),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*stages)
335 |
class Mega(nn.Module):
    """Token embedding, `depth` x (MegaLayer + feedforward) with layernorms, and a logits head.

    `pre_norm` selects whether each layernorm runs before its sub-block (with an
    extra layernorm in front of the logits projection) or after it. Remaining
    keyword arguments are forwarded to every MegaLayer.
    """

    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        ff_mult = 2,
        pre_norm = False,
        **kwargs
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pre_norm = pre_norm

        # each entry: (mega layer, its norm, feedforward, its norm)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                MegaLayer(dim = dim, **kwargs),
                nn.LayerNorm(dim),
                FeedForward(dim = dim, ff_mult = ff_mult),
                nn.LayerNorm(dim)
            ])
            for _ in range(depth)
        ])

        final_norm = nn.LayerNorm(dim) if pre_norm else nn.Identity()

        self.to_logits = nn.Sequential(
            final_norm,
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x):
        """Map token ids `x` to logits over `num_tokens` for every position."""
        x = self.token_emb(x)

        for mega_layer, mega_norm, ff, ff_norm in self.layers:
            if self.pre_norm:
                # normalise the sub-block input, keep the un-normalised stream as residual
                x = mega_layer(mega_norm(x), x)
                x = ff(ff_norm(x)) + x
            else:
                # normalise after each sub-block
                x = mega_norm(mega_layer(x, x))
                x = ff_norm(ff(x) + x)

        return self.to_logits(x)
387 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# package metadata handed to setuptools when the distribution is built
setup(
  name = 'Mega-pytorch',
  # an empty exclude list: every package that find_packages discovers is shipped
  packages = find_packages(exclude=[]),
  version = '0.1.0',
  license='MIT',
  description = 'Mega - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  # NOTE(review): no `long_description` is passed alongside this content type -
  # confirm whether the README was meant to be supplied
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/Mega-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention mechanism',
    'exponential moving average',
    'long range arena'
  ],
  # runtime dependencies; scipy is imported by mega_pytorch.py for next_fast_len
  install_requires=[
    'einops>=0.4',
    'scipy',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
33 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from mega_pytorch.mega_pytorch import Mega
2 | from mega_pytorch.autoregressive_wrapper import AutoregressiveWrapper
3 |
4 | import argparse
5 | import random
6 | import tqdm
7 | import gzip
8 | import numpy as np
9 |
10 | import torch
11 | import torch.optim as optim
12 | from torch.nn import functional as F
13 | from torch.utils.data import DataLoader, Dataset
14 |
15 | # constants
16 |
NUM_BATCHES = int(1e5)           # number of iterations of the outer training loop
BATCH_SIZE = 4                   # batch size of both DataLoaders
GRADIENT_ACCUMULATE_EVERY = 4    # backward passes per optimizer step
LEARNING_RATE = 2e-4             # Adam learning rate
VALIDATE_EVERY = 100             # a validation loss is printed every this many iterations
GENERATE_EVERY = 500             # a sample is generated every this many iterations
GENERATE_LENGTH = 512            # number of tokens sampled per generation
SEQ_LEN = 512                    # seq_len passed to TextSamplerDataset
25 |
26 | # helpers
27 |
def cycle(loader):
    """Yield the items of `loader` endlessly, re-iterating it each time it is exhausted."""
    while True:
        yield from loader
32 |
def decode_token(token):
    """Turn a token id into a character; ids below 32 are shown as a space (code point 32)."""
    code_point = token if token > 32 else 32
    return chr(code_point)
35 |
def decode_tokens(tokens):
    """Decode an iterable of token ids into a single string, one character per token."""
    return ''.join(decode_token(token) for token in tokens)
38 |
39 | # instantiate GPT-like decoder model
40 |
# Mega network over 256 token ids (the data is read as uint8 bytes), width 512, 8 layers
model = Mega(
    num_tokens = 256,
    dim = 512,
    depth = 8
)

# the wrapper makes model(x) return a next-token loss and adds `generate`
model = AutoregressiveWrapper(model)

# moves the parameters to the GPU; the script has no CPU fallback
model.cuda()
50 |
51 | # prepare enwik8 data
52 |
with gzip.open('./data/enwik8.gz') as file:
    # read up to 95e6 bytes as uint8; np.array copies the result of frombuffer
    x = np.array(np.frombuffer(file.read(int(95e6)), dtype = np.uint8))
    # the first 90e6 bytes are for training, whatever remains is for validation
    train_x, valid_x = np.split(x, [int(90e6)])
    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
57 |
class TextSamplerDataset(Dataset):
    """Dataset that serves random windows of `seq_len + 1` tokens from a 1-d tensor."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # `index` is ignored: every access draws a fresh random start position
        max_start = self.data.size(0) - self.seq_len
        start = torch.randint(0, max_start, (1,))

        # one extra token, so the window can be split into inputs and shifted targets
        window = self.data[start: start + self.seq_len + 1]
        return window.long().cuda()

    def __len__(self):
        # number of whole seq_len-sized windows that fit in the data
        return self.data.size(0) // self.seq_len
71 |
# both splits are served through the same random-window dataset
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
# `cycle` turns each finite DataLoader into an endless iterator that is consumed with next()
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

# NOTE: this rebinds the name `optim`, which the import block bound to the torch.optim module
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
80 |
81 | # training
82 |
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # accumulate gradients over several micro-batches before one optimizer step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        # scale each micro-batch so the accumulated gradient is the mean over the
        # effective batch rather than the sum; `loss` itself is left unscaled for logging
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # reports the loss of the last micro-batch only
    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # a random validation window without its last token serves as the prompt
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f"\n\n {prime} \n\n {'-' * 80} \n")

        # add a batch dimension for generate, then decode the single returned row
        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str + "\n\n")
110 |
--------------------------------------------------------------------------------