├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   └── enwik8.gz
├── metaformer.png
├── metaformer_gpt
│   ├── __init__.py
│   ├── autoregressive_wrapper.py
│   └── metaformer_gpt.py
├── setup.py
└── train.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Metaformer - GPT (wip)
4 |
5 | Implementation of Metaformer, but in an autoregressive manner. In particular, they propose simply using mean centering as a parameter-less way to do token mixing, alternating with feedforwards.
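
Concretely, the causal token mixing step is just the running (cumulative) mean of the sequence with the current token subtracted off. A minimal sketch of that operation, mirroring `cummean` / `MeanCenteringPool` in `metaformer_gpt/metaformer_gpt.py` (which additionally wraps it with a LayerNorm and an output projection):

```python
import torch

def causal_mean_center(x):
    # x: (batch, seq, dim) - running mean over the prefix at each position, minus the token itself
    numer = x.cumsum(dim = 1)
    denom = torch.arange(1, x.shape[1] + 1, device = x.device).view(1, -1, 1)
    return numer / denom - x
```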
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install metaformer-gpt
11 | ```
12 |
13 | ## Usage
14 |
15 | ```python
16 | import torch
17 | from metaformer_gpt import MetaformerGPT
18 |
19 | gpt = MetaformerGPT(
20 | num_tokens = 256,
21 | dim = 512,
22 | depth = 8
23 | )
24 |
25 | ids = torch.randint(0, 256, (1, 1024))
26 | logits = gpt(ids) # (1, 1024, 256)
27 | ```
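
The repository also includes an `AutoregressiveWrapper` (the same one `train.py` uses) that adds next-token cross entropy and sampling on top of the model. A minimal sketch of wiring it up:

```python
import torch
from metaformer_gpt import MetaformerGPT
from metaformer_gpt.autoregressive_wrapper import AutoregressiveWrapper

model = AutoregressiveWrapper(
    MetaformerGPT(num_tokens = 256, dim = 512, depth = 8),
    max_seq_len = 1024
)

seq = torch.randint(0, 256, (1, 1025))   # labels are the input shifted by one, handled internally
loss = model(seq)                        # next-token cross entropy
loss.backward()

prime = seq[:, :32]
sampled = model.generate(prime, 128)     # (1, 128) newly generated token ids
```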
28 |
29 | ## Citations
30 |
31 | ```bibtex
32 | @article{Yu2021MetaFormerIA,
33 | title = {MetaFormer is Actually What You Need for Vision},
34 | author = {Weihao Yu and Mi Luo and Pan Zhou and Chenyang Si and Yichen Zhou and Xinchao Wang and Jiashi Feng and Shuicheng Yan},
35 | journal = {ArXiv},
36 | year = {2021},
37 | volume = {abs/2111.11418}
38 | }
39 | ```
40 |
41 | ```bibtex
42 | @misc{woo2022etsformer,
43 | title = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
44 | author = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
45 | year = {2022},
46 | eprint = {2202.01381},
47 | archivePrefix = {arXiv},
48 | primaryClass = {cs.LG}
49 | }
50 | ```
51 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 |
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
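
For reference, a minimal sketch of how `train.py` reads this file into train/validation splits (the first 90M bytes for training, the following 5M for validation):

```python
import gzip
import numpy as np
import torch

with gzip.open("./data/enwik8.gz") as f:
    data = np.frombuffer(f.read(int(95e6)), dtype = np.uint8).copy()

train_np, val_np = np.split(data, [int(90e6)])
data_train, data_val = torch.from_numpy(train_np), torch.from_numpy(val_np)
```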
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/metaformer-gpt/9853266c8dc8daf985b0f346c6fbf2b63991c7f3/data/enwik8.gz
--------------------------------------------------------------------------------
/metaformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/metaformer-gpt/9853266c8dc8daf985b0f346c6fbf2b63991c7f3/metaformer.png
--------------------------------------------------------------------------------
/metaformer_gpt/__init__.py:
--------------------------------------------------------------------------------
1 | from metaformer_gpt.metaformer_gpt import MetaformerGPT, MultiheadExponentialTimeDecay
2 |
--------------------------------------------------------------------------------
/metaformer_gpt/autoregressive_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange
4 | from torch import nn
5 |
6 | # helper function
7 |
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 |
13 | def eval_decorator(fn):
14 | def inner(model, *args, **kwargs):
15 | was_training = model.training
16 | model.eval()
17 | out = fn(model, *args, **kwargs)
18 | model.train(was_training)
19 | return out
20 |
21 | return inner
22 |
23 |
24 | # top k filtering
25 |
26 |
27 | def top_k(logits, thres=0.9):
28 | k = int((1 - thres) * logits.shape[-1])
29 | val, ind = torch.topk(logits, k)
30 | probs = torch.full_like(logits, float("-inf"))
31 | probs.scatter_(1, ind, val)
32 | return probs
33 |
34 |
35 | class AutoregressiveWrapper(nn.Module):
36 | def __init__(self, net, max_seq_len=2048, pad_value=0):
37 | super().__init__()
38 | self.max_seq_len = max_seq_len
39 | self.pad_value = pad_value
40 | self.net = net
41 |
42 | @torch.no_grad()
43 | @eval_decorator
44 | def generate(
45 | self,
46 | start_tokens,
47 | seq_len,
48 | eos_token=None,
49 | temperature=1.0,
50 | filter_thres=0.9,
51 | **kwargs
52 | ):
53 | b, t, device = *start_tokens.shape, start_tokens.device
54 |
55 | out = start_tokens
56 |
57 | for _ in range(seq_len):
58 | logits = self.net(out, **kwargs)[:, -1, :]
59 |
60 | filtered_logits = top_k(logits, thres=filter_thres)
61 | probs = F.softmax(filtered_logits / temperature, dim=-1)
62 |
63 | sample = torch.multinomial(probs, 1)
64 |
65 | out = torch.cat((out, sample), dim=-1)
66 |
67 | if exists(eos_token):
68 | is_eos_token = out == eos_token
69 |
70 | if is_eos_token.any(dim=-1).all():
71 | # mask out everything after the eos tokens
72 |                     shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
73 | mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
74 | out = out.masked_fill(mask, self.pad_value)
75 | break
76 |
77 | out = out[:, t:]
78 | return out
79 |
80 | def forward(self, x, **kwargs):
81 | x_inp, x_labels = x[:, :-1], x[:, 1:]
82 | logits = self.net(x_inp, **kwargs)
83 |         return F.cross_entropy(rearrange(logits, "b n c -> b c n"), x_labels)
84 |
--------------------------------------------------------------------------------
/metaformer_gpt/metaformer_gpt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange, repeat
4 |
5 | from scipy.fftpack import next_fast_len
6 |
7 | # helper functions
8 |
9 | def cummean(x, *, dim):
10 | numer = x.cumsum(dim = dim)
11 | denom = torch.arange(x.shape[1], device = x.device) + 1
12 | return numer / rearrange(denom, '... -> ... 1')
13 |
14 | def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
15 |     # O(N log N) 1d convolution computed with FFTs (convolution theorem)
16 |
17 | N = x.shape[dim]
18 | M = weights.shape[weight_dim]
19 |
20 | fast_len = next_fast_len(N + M - 1)
21 |
22 | f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
23 | f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)
24 |
25 | f_v_weight = f_x * rearrange(f_weight.conj(), '... -> ... 1')
26 | out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
27 | out = out.roll(-1, dims = (dim,))
28 |
29 | indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
30 | out = out.index_select(dim, indices)
31 | return out
32 |
33 | # classes
34 |
35 | class MeanCenteringPool(nn.Module):
36 | def __init__(
37 | self,
38 | dim
39 | ):
40 | super().__init__()
41 | self.norm = nn.LayerNorm(dim)
42 | self.proj = nn.Linear(dim, dim, bias = False)
43 |
44 | def forward(self, x):
45 | x = self.norm(x)
46 | x = cummean(x, dim = 1) - x
47 | return self.proj(x)
48 |
49 | class MultiheadExponentialTimeDecay(nn.Module):
50 | def __init__(
51 | self,
52 | dim,
53 | *,
54 | heads = 8,
55 | dim_head = 64
56 | ):
57 | super().__init__()
58 | self.heads = heads
59 | inner_dim = heads * dim_head
60 |
61 | self.norm = nn.LayerNorm(dim)
62 | self.alpha = nn.Parameter(torch.randn(heads))
63 |
64 | self.project_in = nn.Linear(dim, inner_dim, bias = False)
65 | self.project_out = nn.Linear(inner_dim, dim, bias = False)
66 |
67 | def forward(self, x):
68 | b, n, d, h, device = *x.shape, self.heads, x.device
69 |
70 | x = self.norm(x)
71 |
72 | # linear project in
73 |
74 | x = self.project_in(x)
75 |
76 | # split out heads
77 |
78 | x = rearrange(x, 'b n (h d) -> b h n d', h = h)
79 |
80 | # prepare exponential alpha
81 |
82 | alpha = self.alpha.sigmoid()
83 | alpha = rearrange(alpha, 'h -> h 1')
84 |
85 | # arange == powers
86 |
87 | arange = torch.arange(n, device = device)
88 | weights = alpha * (1 - alpha) ** torch.flip(arange, dims = (0,))
89 | output = conv1d_fft(x, weights)
90 |
91 | # merge heads
92 |
93 | output = rearrange(output, 'b h n d -> b n (h d)')
94 | return self.project_out(output)
95 |
96 | def FeedForward(dim, mult = 4):
97 | hidden_dim = int(dim * mult)
98 | return nn.Sequential(
99 | nn.LayerNorm(dim),
100 | nn.Linear(dim, hidden_dim, bias = False),
101 | nn.GELU(),
102 | nn.Linear(hidden_dim, dim, bias = False)
103 | )
104 |
105 | class MetaformerGPT(nn.Module):
106 | def __init__(
107 | self,
108 | *,
109 | num_tokens,
110 | dim,
111 | depth,
112 | heads = 16,
113 | dim_head = 32,
114 | max_seq_len = 2048,
115 | ff_mult = 4
116 | ):
117 | super().__init__()
118 | self.token_emb = nn.Embedding(num_tokens, dim)
119 | self.pos_emb = nn.Embedding(max_seq_len, dim)
120 |
121 | self.layers = nn.ModuleList([])
122 | for _ in range(depth):
123 | self.layers.append(nn.ModuleList([
124 | MultiheadExponentialTimeDecay(dim, heads = heads, dim_head = dim_head),
125 | MeanCenteringPool(dim),
126 | FeedForward(dim, mult = ff_mult)
127 | ]))
128 |
129 | self.to_logits = nn.Sequential(
130 | nn.LayerNorm(dim),
131 | nn.Linear(dim, num_tokens, bias = False)
132 | )
133 |
134 | def forward(self, x):
135 | n, device = x.shape[1], x.device
136 |
137 | x = self.token_emb(x)
138 | x = x + self.pos_emb(torch.arange(n, device = device))
139 |
140 | for mh_esa, pool, ff in self.layers:
141 | x = mh_esa(x) + x
142 | x = pool(x) + x
143 | x = ff(x) + x
144 |
145 | return self.to_logits(x)
146 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'metaformer-gpt',
5 | packages = find_packages(exclude=[]),
6 | version = '0.0.5',
7 | license='MIT',
8 | description = 'Metaformer - GPT',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/metaformer-gpt',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention-less'
18 | ],
19 | install_requires=[
20 | 'einops>=0.4',
21 | 'scipy',
22 | 'torch>=1.6',
23 | ],
24 | classifiers=[
25 | 'Development Status :: 4 - Beta',
26 | 'Intended Audience :: Developers',
27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
28 | 'License :: OSI Approved :: MIT License',
29 | 'Programming Language :: Python :: 3.6',
30 | ],
31 | )
32 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | import torch.optim as optim
7 | import tqdm
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader, Dataset
10 |
11 | from metaformer_gpt import MetaformerGPT
12 | from metaformer_gpt.autoregressive_wrapper import AutoregressiveWrapper
13 |
14 | # constants
15 |
16 | NUM_BATCHES = int(1e5)
17 | BATCH_SIZE = 4
18 | GRADIENT_ACCUMULATE_EVERY = 4
19 | LEARNING_RATE = 2e-4
20 | VALIDATE_EVERY = 100
21 | GENERATE_EVERY = 500
22 | GENERATE_LENGTH = 512
23 | SEQ_LEN = 1024
24 |
25 | # helpers
26 |
27 |
28 | def cycle(loader):
29 | while True:
30 | for data in loader:
31 | yield data
32 |
33 |
34 | def decode_token(token):
35 | return str(chr(max(32, token)))
36 |
37 |
38 | def decode_tokens(tokens):
39 | return "".join(list(map(decode_token, tokens)))
40 |
41 |
42 | # instantiate GPT-like decoder model
43 |
44 | model = MetaformerGPT(
45 | num_tokens = 256,
46 | dim = 512,
47 | depth = 8,
48 | heads = 16,
49 | dim_head = 32
50 | )
51 |
52 | model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
53 | model.cuda()
54 |
55 | # prepare enwik8 data
56 |
57 | with gzip.open("./data/enwik8.gz") as file:
58 |     X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
59 | trX, vaX = np.split(X, [int(90e6)])
60 | data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
61 |
62 |
63 | class TextSamplerDataset(Dataset):
64 | def __init__(self, data, seq_len):
65 | super().__init__()
66 | self.data = data
67 | self.seq_len = seq_len
68 |
69 | def __getitem__(self, index):
70 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
71 | full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
72 | return full_seq.cuda()
73 |
74 | def __len__(self):
75 | return self.data.size(0) // self.seq_len
76 |
77 |
78 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
79 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
80 | train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
81 | val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
82 |
83 | # optimizer
84 |
85 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
86 |
87 | # training
88 |
89 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
90 | model.train()
91 |
92 | for __ in range(GRADIENT_ACCUMULATE_EVERY):
93 | loss = model(next(train_loader))
94 | loss.backward()
95 |
96 | print(f"training loss: {loss.item()}")
97 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
98 | optim.step()
99 | optim.zero_grad()
100 |
101 | if i % VALIDATE_EVERY == 0:
102 | model.eval()
103 | with torch.no_grad():
104 | loss = model(next(val_loader))
105 | print(f"validation loss: {loss.item()}")
106 |
107 | if i % GENERATE_EVERY == 0:
108 | model.eval()
109 | inp = random.choice(val_dataset)[:-1]
110 | prime = decode_tokens(inp)
111 |         print(f"{prime} \n\n {'*' * 100}")
112 |
113 | sample = model.generate(inp[None, ...], GENERATE_LENGTH)
114 | output_str = decode_tokens(sample[0])
115 | print(output_str)
116 |
--------------------------------------------------------------------------------