├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   └── enwik8.gz
├── metaformer.png
├── metaformer_gpt
│   ├── __init__.py
│   ├── autoregressive_wrapper.py
│   └── metaformer_gpt.py
├── setup.py
└── train.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------

# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

MIT License

Copyright (c) 2022 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## Metaformer - GPT (wip)

Implementation of MetaFormer, but in an autoregressive manner. The MetaFormer paper argues that the token mixer in a transformer block can be something very simple. Here, token mixing is done with a parameter-less causal mean centering (each token is subtracted from the running mean of the sequence up to and including it), alternated with feedforwards. Each block also carries a multi-head exponential time decay mixer in the spirit of ETSformer (second citation below).
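
As a rough sketch (not part of the package API), the causal mean centering at the heart of `MeanCenteringPool` boils down to the following; in the actual module it is preceded by a LayerNorm and followed by a bias-less linear projection.

```python
import torch

def causal_mean_center(x):
    # x: (batch, seq, dim)
    seq_len = x.shape[1]
    # running mean over all tokens up to and including each position (keeps causality)
    denom = torch.arange(1, seq_len + 1, device = x.device).view(1, seq_len, 1)
    running_mean = x.cumsum(dim = 1) / denom
    # "mean centering": subtract each token from its running mean
    return running_mean - x

x = torch.randn(1, 1024, 512)
assert causal_mean_center(x).shape == (1, 1024, 512)
```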

## Install

```bash
$ pip install metaformer-gpt
```

## Usage

```python
import torch
from metaformer_gpt import MetaformerGPT

gpt = MetaformerGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8
)

ids = torch.randint(0, 256, (1, 1024))
logits = gpt(ids) # (1, 1024, 256)
```
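
For language modelling, the repository also ships an `AutoregressiveWrapper` (this is what `train.py` uses). A minimal sketch of how it might be wired up - the sequence lengths here are only illustrative:

```python
import torch
from metaformer_gpt import MetaformerGPT
from metaformer_gpt.autoregressive_wrapper import AutoregressiveWrapper

gpt = MetaformerGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8
)

gpt = AutoregressiveWrapper(gpt, max_seq_len = 1024)

seq = torch.randint(0, 256, (1, 1025)) # one extra token, since targets are the inputs shifted by one
loss = gpt(seq)                        # cross entropy loss
loss.backward()

# after much training

prime = torch.randint(0, 256, (1, 512))
sampled = gpt.generate(prime, 256)     # (1, 256) - newly sampled tokens only
```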

## Citations

```bibtex
@article{Yu2021MetaFormerIA,
    title   = {MetaFormer is Actually What You Need for Vision},
    author  = {Weihao Yu and Mi Luo and Pan Zhou and Chenyang Si and Yichen Zhou and Xinchao Wang and Jiashi Feng and Shuicheng Yan},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2111.11418}
}
```

```bibtex
@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```

--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------

# Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/metaformer-gpt/9853266c8dc8daf985b0f346c6fbf2b63991c7f3/data/enwik8.gz

--------------------------------------------------------------------------------
/metaformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/metaformer-gpt/9853266c8dc8daf985b0f346c6fbf2b63991c7f3/metaformer.png

--------------------------------------------------------------------------------
/metaformer_gpt/__init__.py:
--------------------------------------------------------------------------------

from metaformer_gpt.metaformer_gpt import MetaformerGPT, MultiheadExponentialTimeDecay

--------------------------------------------------------------------------------
/metaformer_gpt/autoregressive_wrapper.py:
--------------------------------------------------------------------------------

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

# helper functions


def exists(val):
    return val is not None


def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out

    return inner


# top k filtering


def top_k(logits, thres=0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(1, ind, val)
    return probs


class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, max_seq_len=2048, pad_value=0):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.pad_value = pad_value
        self.net = net

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            logits = self.net(out, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres=filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_tokens = out == eos_token

                if is_eos_tokens.any(dim=-1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        out = out[:, t:]
        return out

    def forward(self, x, **kwargs):
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        logits = self.net(x_inp, **kwargs)
        # cross entropy expects the class (token) dimension to come second
        return F.cross_entropy(rearrange(logits, "b n c -> b c n"), x_labels)

--------------------------------------------------------------------------------
/metaformer_gpt/metaformer_gpt.py:
--------------------------------------------------------------------------------

import torch
from torch import nn, einsum
from einops import rearrange, repeat

from scipy.fftpack import next_fast_len

# helper functions

def cummean(x, *, dim):
    numer = x.cumsum(dim = dim)
    denom = torch.arange(x.shape[1], device = x.device) + 1
    return numer / rearrange(denom, '... -> ... 1')

def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
    # O(N log(N)) 1d convolution using some fourier trick

    N = x.shape[dim]
    M = weights.shape[weight_dim]

    fast_len = next_fast_len(N + M - 1)

    f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
    f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)

    f_v_weight = f_x * rearrange(f_weight.conj(), '... -> ... 1')
    out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
    out = out.roll(-1, dims = (dim,))

    indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
    out = out.index_select(dim, indices)
    return out

# classes

class MeanCenteringPool(nn.Module):
    def __init__(
        self,
        dim
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)
        x = cummean(x, dim = 1) - x
        return self.proj(x)

class MultiheadExponentialTimeDecay(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        self.heads = heads
        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)
        self.alpha = nn.Parameter(torch.randn(heads))

        self.project_in = nn.Linear(dim, inner_dim, bias = False)
        self.project_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        b, n, d, h, device = *x.shape, self.heads, x.device

        x = self.norm(x)

        # linear project in

        x = self.project_in(x)

        # split out heads

        x = rearrange(x, 'b n (h d) -> b h n d', h = h)

        # prepare exponential alpha

        alpha = self.alpha.sigmoid()
        alpha = rearrange(alpha, 'h -> h 1')

        # arange == powers

        arange = torch.arange(n, device = device)
        weights = alpha * (1 - alpha) ** torch.flip(arange, dims = (0,))
        output = conv1d_fft(x, weights)

        # merge heads

        output = rearrange(output, 'b h n d -> b n (h d)')
        return self.project_out(output)

def FeedForward(dim, mult = 4):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias = False)
    )

class MetaformerGPT(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        heads = 16,
        dim_head = 32,
        max_seq_len = 2048,
        ff_mult = 4
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                MultiheadExponentialTimeDecay(dim, heads = heads, dim_head = dim_head),
                MeanCenteringPool(dim),
                FeedForward(dim, mult = ff_mult)
            ]))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

    def forward(self, x):
        n, device = x.shape[1], x.device

        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(n, device = device))

        for mh_esa, pool, ff in self.layers:
            x = mh_esa(x) + x
            x = pool(x) + x
            x = ff(x) + x

        return self.to_logits(x)

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------

from setuptools import setup, find_packages

setup(
  name = 'metaformer-gpt',
  packages = find_packages(exclude=[]),
  version = '0.0.5',
  license='MIT',
  description = 'Metaformer - GPT',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/metaformer-gpt',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention-less'
  ],
  install_requires=[
    'einops>=0.4',
    'scipy',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------

import gzip
import random

import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from metaformer_gpt import MetaformerGPT
from metaformer_gpt.autoregressive_wrapper import AutoregressiveWrapper

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024

# helpers


def cycle(loader):
    while True:
        for data in loader:
            yield data


def decode_token(token):
    return str(chr(max(32, token)))


def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))


# instantiate GPT-like decoder model

model = MetaformerGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 16,
    dim_head = 32
)

model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
model.cuda()

# prepare enwik8 data

with gzip.open("./data/enwik8.gz") as file:
    # np.fromstring is deprecated; frombuffer reads the raw bytes, and copy makes the array writable for torch
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)


class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len


train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

--------------------------------------------------------------------------------