├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md └── enwik8.gz ├── mega.png ├── mega_pytorch ├── __init__.py ├── autoregressive_wrapper.py └── mega_pytorch.py ├── setup.py └── train.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | 
pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Mega - Moving Average Equipped Gated Attention - Pytorch 4 | 5 | Implementation of the Mega layer, the single-head attention with multi-headed EMA layer from the architecture that currently holds SOTA on Long Range Arena, beating S4 on Pathfinder-X and all the other tasks save for audio. 6 | 7 | ## Install 8 | 9 | ```bash 10 | $ pip install mega-pytorch 11 | ``` 12 | 13 | ## Usage 14 | 15 | The Mega layer, combining attention with a learned EMA 16 | 17 | ```python 18 | import torch 19 | from mega_pytorch import MegaLayer 20 | 21 | layer = MegaLayer( 22 | dim = 128, # model dimensions 23 | ema_heads = 16, # number of EMA heads 24 | attn_dim_qk = 64, # dimension of queries / keys in attention 25 | attn_dim_value = 256, # dimension of values in attention 26 | laplacian_attn_fn = False, # whether to use softmax (false) or laplacian attention activation fn (true) 27 | ) 28 | 29 | x = torch.randn(1, 1024, 128) # (batch, seq, dim) 30 | 31 | out = layer(x) # (1, 1024, 128) 32 | ``` 33 | 34 | Full Mega (with layernorm for now) 35 | 36 | ```python 37 | import torch 38 | from mega_pytorch import Mega 39 | 40 | mega = Mega( 41 | num_tokens = 256, # number of tokens 42 | dim = 128, # model dimensions 43 | depth = 6, # depth 44 | ema_heads = 16, # number of EMA heads 45 | attn_dim_qk = 64, # dimension of queries / keys in attention 46 | attn_dim_value = 256, # dimension of values in attention 47 | laplacian_attn_fn = True, # whether to use softmax (false) or laplacian attention activation fn (true) 48 | ) 49 | 50 | x = torch.randint(0, 256, (1, 1024)) 51 | 52 | logits = mega(x) # (1, 1024, 256) 53 | ``` 54 | 55 | ## Todo 56 | 57 | - [ ] add dynamic positional bias for best length extrapolation arch 58 | 59 | ## Citations 60 | 61 | ```bibtex 62 | 
@inproceedings{Ma2022MegaMA, 63 | title = {Mega: Moving Average Equipped Gated Attention}, 64 | author = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer}, 65 | year = {2022} 66 | } 67 | ``` 68 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data source 2 | 3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ -------------------------------------------------------------------------------- /data/enwik8.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/Mega-pytorch/dc765fd5313bf02419a473dfc819b2ab33e046e8/data/enwik8.gz -------------------------------------------------------------------------------- /mega.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/Mega-pytorch/dc765fd5313bf02419a473dfc819b2ab33e046e8/mega.png -------------------------------------------------------------------------------- /mega_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from mega_pytorch.mega_pytorch import MegaLayer, Mega, MultiHeadedEMA 2 | -------------------------------------------------------------------------------- /mega_pytorch/autoregressive_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange 6 | 7 | # helper function 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def eval_decorator(fn): 13 | def inner(model, *args, **kwargs): 14 | was_training = model.training 15 | model.eval() 16 | out = fn(model, *args, **kwargs) 17 | 
model.train(was_training) 18 | return out 19 | return inner 20 | 21 | # top k filtering 22 | 23 | def top_k(logits, thres = 0.9): 24 | k = int((1 - thres) * logits.shape[-1]) 25 | val, ind = torch.topk(logits, k) 26 | probs = torch.full_like(logits, float('-inf')) 27 | probs.scatter_(1, ind, val) 28 | return probs 29 | 30 | class AutoregressiveWrapper(nn.Module): 31 | def __init__(self, net, pad_value = 0): 32 | super().__init__() 33 | self.pad_value = pad_value 34 | self.net = net 35 | 36 | @torch.no_grad() 37 | @eval_decorator 38 | def generate(self, start_tokens, seq_len, temperature = 1., filter_thres = 0.9, **kwargs): 39 | b, t, device = *start_tokens.shape, start_tokens.device 40 | 41 | out = start_tokens 42 | 43 | for _ in range(seq_len): 44 | logits = self.net(out, **kwargs)[:, -1, :] 45 | 46 | filtered_logits = top_k(logits, thres = filter_thres) 47 | probs = F.softmax(filtered_logits / temperature, dim=-1) 48 | 49 | sample = torch.multinomial(probs, 1) 50 | 51 | out = torch.cat((out, sample), dim=-1) 52 | 53 | out = out[:, t:] 54 | return out 55 | 56 | def forward(self, x, **kwargs): 57 | x_inp, x_labels = x[:, :-1], x[:, 1:] 58 | logits = self.net(x_inp, **kwargs) 59 | return F.cross_entropy(rearrange(logits, 'b c n -> b n c'), x_labels) 60 | -------------------------------------------------------------------------------- /mega_pytorch/mega_pytorch.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, einsum 7 | from torch.fft import rfft, irfft 8 | 9 | from einops import rearrange 10 | from einops.layers.torch import Rearrange 11 | 12 | from scipy.fftpack import next_fast_len 13 | 14 | # functions 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | def identity(t, *args, **kwargs): 20 | return t 21 | 22 | def default(val, d): 23 | return val if exists(val) else d 24 | 25 | def 
append_dims(x, num_dims): 26 | if num_dims <= 0: 27 | return x 28 | return x.view(*x.shape, *((1,) * num_dims)) 29 | 30 | def conv1d_fft(x, weights, dim = -2, weight_dim = -1): 31 | # O(N log(N)) 1d convolution using some fourier trick 32 | 33 | assert weight_dim >= dim 34 | 35 | N = x.shape[dim] 36 | M = weights.shape[weight_dim] 37 | 38 | fast_len = next_fast_len(N + M - 1) 39 | 40 | f_x = rfft(x, n = fast_len, dim = dim) 41 | f_weight = rfft(weights, n = fast_len, dim = weight_dim) 42 | 43 | f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim) 44 | out = irfft(f_v_weight, fast_len, dim = dim) 45 | out = out.roll(-1, dims = (dim,)) 46 | 47 | indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device) 48 | out = out.index_select(dim, indices) 49 | return out 50 | 51 | # positional bias for single-headed attention 52 | 53 | class T5RelativePositionBias(nn.Module): 54 | def __init__( 55 | self, 56 | scale, 57 | causal = False, 58 | num_buckets = 32, 59 | max_distance = 128 60 | ): 61 | super().__init__() 62 | self.scale = scale 63 | self.causal = causal 64 | self.num_buckets = num_buckets 65 | self.max_distance = max_distance 66 | self.relative_attention_bias = nn.Embedding(num_buckets, 1) 67 | 68 | @staticmethod 69 | def _relative_position_bucket( 70 | relative_position, 71 | causal = True, 72 | num_buckets = 32, 73 | max_distance = 128 74 | ): 75 | ret = 0 76 | n = -relative_position 77 | if not causal: 78 | num_buckets //= 2 79 | ret += (n < 0).long() * num_buckets 80 | n = torch.abs(n) 81 | else: 82 | n = torch.max(n, torch.zeros_like(n)) 83 | 84 | max_exact = num_buckets // 2 85 | is_small = n < max_exact 86 | 87 | val_if_large = max_exact + ( 88 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 89 | ).long() 90 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 91 | 92 | ret += torch.where(is_small, n, val_if_large) 93 | 
return ret 94 | 95 | def forward(self, x): 96 | i, j, device = *x.shape[-2:], x.device 97 | q_pos = torch.arange(i, dtype = torch.long, device = device) 98 | k_pos = torch.arange(j, dtype = torch.long, device = device) 99 | rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') 100 | rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance) 101 | values = self.relative_attention_bias(rp_bucket) 102 | bias = rearrange(values, 'i j 1 -> i j') 103 | return bias * self.scale 104 | 105 | # classes 106 | 107 | class LaplacianAttnFn(nn.Module): 108 | def forward(self, x): 109 | mu = math.sqrt(0.5) 110 | std = math.sqrt((4 * math.pi) ** -1) 111 | return (1 + torch.special.erf((x - mu) / (std * math.sqrt(2)))) * 0.5 112 | 113 | class OffsetScale(nn.Module): 114 | def __init__(self, dim, heads = 1): 115 | super().__init__() 116 | self.gamma = nn.Parameter(torch.ones(heads, dim)) 117 | self.beta = nn.Parameter(torch.zeros(heads, dim)) 118 | nn.init.normal_(self.gamma, std = 0.02) 119 | 120 | def forward(self, x): 121 | out = einsum('... d, h d -> ... 
h d', x, self.gamma) + self.beta 122 | return out.unbind(dim = -2) 123 | 124 | class SingleHeadedAttention(nn.Module): 125 | def __init__( 126 | self, 127 | *, 128 | dim, 129 | dim_qk, 130 | dim_value, 131 | causal = False, 132 | laplacian_attn_fn = False 133 | ): 134 | super().__init__() 135 | self.causal = causal 136 | self.laplacian_attn_fn = laplacian_attn_fn 137 | 138 | self.attn_fn = partial(F.softmax, dim = -1) if not laplacian_attn_fn else LaplacianAttnFn() 139 | 140 | self.rel_pos_bias = T5RelativePositionBias(causal = causal, scale = dim_qk ** 0.5) 141 | 142 | self.to_qk = nn.Sequential( 143 | nn.Linear(dim, dim_qk), 144 | nn.SiLU() 145 | ) 146 | 147 | self.offsetscale = OffsetScale(dim_qk, heads = 2) 148 | 149 | self.to_v = nn.Sequential( 150 | nn.Linear(dim, dim_value), 151 | nn.SiLU() 152 | ) 153 | 154 | def forward(self, x, v_input = None): 155 | seq_len, dim, device, dtype = *x.shape[-2:], x.device, x.dtype 156 | 157 | v_input = default(v_input, x) 158 | 159 | qk, v = self.to_qk(x), self.to_v(v_input) 160 | q, k = self.offsetscale(qk) 161 | 162 | scale = (seq_len ** -1) if self.laplacian_attn_fn else (dim ** -0.5) 163 | 164 | sim = einsum('b i d, b j d -> b i j', q, k) * scale 165 | 166 | sim = sim + self.rel_pos_bias(sim) 167 | 168 | if self.causal: 169 | causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1) 170 | 171 | if self.causal and not self.laplacian_attn_fn: 172 | # is softmax attention and using large negative value pre-softmax 173 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 174 | 175 | attn = self.attn_fn(sim) 176 | 177 | if self.causal and self.laplacian_attn_fn: 178 | # if using laplacian attention function, zero out upper triangular with 0s 179 | attn = attn.masked_fill(causal_mask, 0.) 
180 | 181 | return einsum('b i j, b j d -> b i d', attn, v) 182 | 183 | class MultiHeadedEMA(nn.Module): 184 | def __init__( 185 | self, 186 | *, 187 | dim, 188 | heads, 189 | bidirectional = False, 190 | norm_mhesa_heads = False 191 | ): 192 | super().__init__() 193 | self.bidirectional = bidirectional 194 | 195 | self.expansion = nn.Parameter(torch.randn(heads * (2 if bidirectional else 1), dim)) 196 | self.reduction = nn.Parameter(torch.randn(heads * (2 if bidirectional else 1), dim)) 197 | 198 | # learned alpha and dampening factors 199 | 200 | self.alphas = nn.Parameter(torch.randn(heads)) 201 | self.dampen_factors = nn.Parameter(torch.randn(heads)) 202 | 203 | if bidirectional: 204 | self.reverse_alphas = nn.Parameter(torch.randn(heads)) 205 | self.reverse_dampen_factors = nn.Parameter(torch.randn(heads)) 206 | 207 | self.heads = heads 208 | 209 | self.norm_heads = nn.Identity() 210 | 211 | if norm_mhesa_heads: 212 | # https://arxiv.org/abs/2210.06423 - retnet used sub-ln with some success as groupnorm 213 | 214 | self.norm_heads = nn.Sequential( 215 | Rearrange('b n h d -> b (h d) n'), 216 | nn.GroupNorm(heads, dim * heads), 217 | Rearrange('b (h d) n -> b n h d', h = heads) 218 | ) 219 | 220 | def forward(self, x): 221 | device, seq_len = x.device, x.shape[1] 222 | 223 | # project in and split heads 224 | 225 | x = einsum('... d, h d -> ... h d', x, self.expansion) 226 | 227 | if self.bidirectional: 228 | x, x_reversed = x.chunk(2, dim = -2) 229 | x_reversed = torch.flip(x_reversed, dims = (1,)) 230 | 231 | # weights derived from alphas (learned exponential smoothing decay rate) 232 | 233 | def apply_learned_ema_with_damping(x, alphas, dampen_factors): 234 | alphas = alphas.sigmoid() 235 | dampen_factors = dampen_factors.sigmoid() 236 | 237 | reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device) 238 | K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... 
l 1')) 239 | 240 | # conv1d fft O(nlog(n)) 241 | 242 | return conv1d_fft(x, K, dim = -3, weight_dim = -2) 243 | 244 | x = apply_learned_ema_with_damping(x, self.alphas, self.dampen_factors) 245 | 246 | if self.bidirectional: 247 | x_reversed = apply_learned_ema_with_damping(x_reversed, self.reverse_alphas, self.reverse_dampen_factors) 248 | x_reversed = torch.flip(x_reversed, dims = (1,)) 249 | x = torch.cat((x, x_reversed), dim = -2) 250 | 251 | # maybe norm heads 252 | 253 | x = self.norm_heads(x) 254 | 255 | # combine heads and out 256 | 257 | return einsum('... h d, h d -> ... d', x, self.reduction) 258 | 259 | # Mega Layer 260 | # Single headed Attention + Multi-headed EMA, then GRU-esque gating 261 | 262 | class MegaLayer(nn.Module): 263 | def __init__( 264 | self, 265 | *, 266 | dim = 128, 267 | ema_heads = 16, 268 | attn_dim_qk = 64, 269 | attn_dim_value = 256, 270 | laplacian_attn_fn = False, 271 | causal = True, 272 | norm_mhesa_heads = False 273 | ): 274 | super().__init__() 275 | 276 | self.single_headed_attn = SingleHeadedAttention( 277 | dim = dim, 278 | dim_qk = attn_dim_qk, 279 | dim_value = attn_dim_value, 280 | causal = causal, 281 | laplacian_attn_fn = laplacian_attn_fn 282 | ) 283 | 284 | self.multi_headed_ema = MultiHeadedEMA( 285 | dim = dim, 286 | heads = ema_heads, 287 | bidirectional = not causal, 288 | norm_mhesa_heads = norm_mhesa_heads 289 | ) 290 | 291 | self.to_reset_gate = nn.Sequential( 292 | nn.Linear(dim, attn_dim_value), 293 | nn.SiLU() 294 | ) 295 | 296 | self.to_update_gate = nn.Sequential( 297 | nn.Linear(dim, dim), 298 | nn.Sigmoid() 299 | ) 300 | 301 | # equation 14, for calculating H 302 | 303 | self.Wh = nn.Parameter(torch.randn(dim, dim)) 304 | self.Uh = nn.Parameter(torch.randn(attn_dim_value, dim)) 305 | self.bh = nn.Parameter(torch.randn(dim)) 306 | 307 | def forward(self, x, residual = None): 308 | residual = default(residual, x) 309 | 310 | ema_output = self.multi_headed_ema(x) 311 | attn_output = 
self.single_headed_attn(ema_output, x) 312 | 313 | reset_gate = self.to_reset_gate(ema_output) 314 | update_gate = self.to_update_gate(ema_output) 315 | 316 | gated_attn_output = attn_output * reset_gate 317 | 318 | # equation 14 319 | 320 | H = F.silu(ema_output @ self.Wh + gated_attn_output @ self.Uh + self.bh) 321 | 322 | # update gate 323 | 324 | return update_gate * H + (1 - update_gate) * residual 325 | 326 | # Mega 327 | 328 | def FeedForward(dim, ff_mult): 329 | dim_hidden = int(dim * ff_mult) 330 | return nn.Sequential( 331 | nn.Linear(dim, dim_hidden), 332 | nn.GELU(), 333 | nn.Linear(dim_hidden, dim) 334 | ) 335 | 336 | class Mega(nn.Module): 337 | def __init__( 338 | self, 339 | *, 340 | dim, 341 | num_tokens, 342 | depth, 343 | ff_mult = 2, 344 | pre_norm = False, 345 | **kwargs 346 | ): 347 | super().__init__() 348 | self.token_emb = nn.Embedding(num_tokens, dim) 349 | self.pre_norm = pre_norm 350 | 351 | self.layers = nn.ModuleList([]) 352 | for _ in range(depth): 353 | self.layers.append(nn.ModuleList([ 354 | MegaLayer(dim = dim, **kwargs), 355 | nn.LayerNorm(dim), 356 | FeedForward(dim = dim, ff_mult = ff_mult), 357 | nn.LayerNorm(dim) 358 | ])) 359 | 360 | self.to_logits = nn.Sequential( 361 | nn.LayerNorm(dim) if pre_norm else nn.Identity(), 362 | nn.Linear(dim, num_tokens) 363 | ) 364 | 365 | def forward(self, x): 366 | pre_norm = self.pre_norm 367 | post_norm = not self.pre_norm 368 | 369 | x = self.token_emb(x) 370 | 371 | for mega_layer, mega_norm, ff, ff_norm in self.layers: 372 | mega_maybe_prenorm = mega_norm if pre_norm else identity 373 | ff_maybe_prenorm = ff_norm if pre_norm else identity 374 | 375 | mega_maybe_postnorm = mega_norm if post_norm else identity 376 | ff_maybe_postnorm = ff_norm if post_norm else identity 377 | 378 | x = mega_layer(mega_maybe_prenorm(x), x) 379 | 380 | x = mega_maybe_postnorm(x) 381 | 382 | x = ff(ff_maybe_prenorm(x)) + x 383 | 384 | x = ff_maybe_postnorm(x) 385 | 386 | return self.to_logits(x) 387 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'Mega-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.1.0', 7 | license='MIT', 8 | description = 'Mega - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/Mega-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'attention mechanism', 17 | 'exponential moving average', 18 | 'long range arena' 19 | ], 20 | install_requires=[ 21 | 'einops>=0.4', 22 | 'scipy', 23 | 'torch>=1.6', 24 | ], 25 | classifiers=[ 26 | 'Development Status :: 4 - Beta', 27 | 'Intended Audience :: Developers', 28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Programming Language :: Python :: 3.6', 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from mega_pytorch.mega_pytorch import Mega 2 | from mega_pytorch.autoregressive_wrapper import AutoregressiveWrapper 3 | 4 | import argparse 5 | import random 6 | import tqdm 7 | import gzip 8 | import numpy as np 9 | 10 | import torch 11 | import torch.optim as optim 12 | from torch.nn import functional as F 13 | from torch.utils.data import DataLoader, Dataset 14 | 15 | # constants 16 | 17 | NUM_BATCHES = int(1e5) 18 | BATCH_SIZE = 4 19 | GRADIENT_ACCUMULATE_EVERY = 4 20 | LEARNING_RATE = 2e-4 21 | VALIDATE_EVERY = 100 22 | GENERATE_EVERY = 500 23 | GENERATE_LENGTH = 512 24 | SEQ_LEN = 512 25 | 26 | # helpers 27 | 28 | def cycle(loader): 29 | while True: 30 | for data in loader: 31 | yield data 32 | 33 | def 
decode_token(token): 34 | return str(chr(max(32, token))) 35 | 36 | def decode_tokens(tokens): 37 | return ''.join(list(map(decode_token, tokens))) 38 | 39 | # instantiate GPT-like decoder model 40 | 41 | model = Mega( 42 | num_tokens = 256, 43 | dim = 512, 44 | depth = 8 45 | ) 46 | 47 | model = AutoregressiveWrapper(model) 48 | 49 | model.cuda() 50 | 51 | # prepare enwik8 data 52 | 53 | with gzip.open('./data/enwik8.gz') as file: 54 | x = np.array(np.frombuffer(file.read(int(95e6)), dtype = np.uint8)) 55 | train_x, valid_x = np.split(x, [int(90e6)]) 56 | data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x) 57 | 58 | class TextSamplerDataset(Dataset): 59 | def __init__(self, data, seq_len): 60 | super().__init__() 61 | self.data = data 62 | self.seq_len = seq_len 63 | 64 | def __getitem__(self, index): 65 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) 66 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long() 67 | return full_seq.cuda() 68 | 69 | def __len__(self): 70 | return self.data.size(0) // self.seq_len 71 | 72 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 73 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN) 74 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE)) 75 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE)) 76 | 77 | # optimizer 78 | 79 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 80 | 81 | # training 82 | 83 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'): 84 | model.train() 85 | 86 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 87 | loss = model(next(train_loader)) 88 | loss.backward() 89 | 90 | print(f'training loss: {loss.item()}') 91 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 92 | optim.step() 93 | optim.zero_grad() 94 | 95 | if i % VALIDATE_EVERY == 0: 96 | model.eval() 97 | with torch.no_grad(): 98 | loss = model(next(val_loader)) 99 | print(f'validation loss: 
{loss.item()}') 100 | 101 | if i % GENERATE_EVERY == 0: 102 | model.eval() 103 | inp = random.choice(val_dataset)[:-1] 104 | prime = decode_tokens(inp) 105 | print(f"\n\n {prime} \n\n {'-' * 80} \n") 106 | 107 | sample = model.generate(inp[None, ...], GENERATE_LENGTH) 108 | output_str = decode_tokens(sample[0]) 109 | print(output_str + "\n\n") 110 | --------------------------------------------------------------------------------