├── tranception.png
├── tranception_pytorch
│   ├── __init__.py
│   └── tranception_pytorch.py
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── README.md
└── .gitignore
/tranception.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/tranception-pytorch/HEAD/tranception.png
--------------------------------------------------------------------------------
/tranception_pytorch/__init__.py:
--------------------------------------------------------------------------------
from tranception_pytorch.tranception_pytorch import Tranception
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'tranception-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.8',
  license='MIT',
  description = 'Tranception - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/tranception-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'protein fitness'
  ],
  install_requires=[
    'einops>=0.4',
    'einops-exts',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Tranception - Pytorch (wip)

Implementation of Tranception, an attention network paired with retrieval, that is SOTA for protein fitness prediction. The transformer architecture is inspired by Primer (squared ReLU activations and depthwise convolutions over the queries / keys / values) and uses a learned ALiBi relative positional bias.

## Install

```bash
$ pip install tranception-pytorch
```

## Usage

```python
import torch
from tranception_pytorch import Tranception

model = Tranception(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

amino_acids = torch.randint(0, 21, (1, 512))

logits = model(amino_acids) # (1, 512, 21)
```
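
Since the model is autoregressive over amino acid tokens, a sequence can be scored by its average token log-likelihood, which is the quantity the Tranception paper builds its fitness scores on (the paper additionally averages over both sequence directions). A minimal sketch of single-direction scoring (the `seq_log_prob` helper below is hypothetical, not part of this package):

```python
import torch
from tranception_pytorch import Tranception

model = Tranception(dim = 512, depth = 6, heads = 8, dim_head = 64)

def seq_log_prob(model, seq):
    # predict each token from the ones before it, then average the log-likelihood of the true tokens
    logits = model(seq[:, :-1])
    log_probs = logits.log_softmax(dim = -1)
    token_ll = log_probs.gather(-1, seq[:, 1:].unsqueeze(-1)).squeeze(-1)
    return token_ll.mean(dim = -1)

wildtype = torch.randint(0, 21, (1, 512))
mutant = wildtype.clone()
mutant[0, 100] = (mutant[0, 100] + 1) % 21 # toy point mutation

# higher relative log-likelihood ~ higher predicted fitness
fitness_delta = seq_log_prob(model, mutant) - seq_log_prob(model, wildtype)
```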

## Todo

- [x] grouped heads with customizable depthwise convs (for variable k-mers), as well as grouped alibi pos bias - see the example below
- [ ] figure out attention to retrieved sequences (looks like axial attention?)
- [ ] play around with ProteinGym, and start betting on huggingface's accelerate
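
The kernel sizes are exposed through the constructor: each entry in `ds_conv_kernel_sizes` defines one group of heads (`0` means no convolution, other entries must be odd), and `heads` must be divisible by the number of groups.

```python
from tranception_pytorch import Tranception

model = Tranception(
    dim = 512,
    depth = 6,
    heads = 8,                           # split evenly across the 4 groups below
    dim_head = 64,
    ds_conv_kernel_sizes = (0, 3, 7, 15) # one depthwise conv kernel size per group of heads
)
```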

## Citations

```bibtex
@article{Notin2022TranceptionPF,
    title   = {Tranception: protein fitness prediction with autoregressive transformers and inference-time retrieval},
    author  = {Pascal Notin and Mafalda Dias and Jonathan Frazer and Javier Marchena-Hurtado and Aidan N. Gomez and Debora S. Marks and Yarin Gal},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.13760}
}
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/tranception_pytorch/tranception_pytorch.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange
from einops_exts import rearrange_many
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# relative positional bias

class LearnedAlibiPosBias(nn.Module):
    def __init__(self, heads):
        super().__init__()
        self.heads = heads
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        self.slopes = nn.Parameter(slopes)
        self.register_buffer('bias', None, persistent = False)

    def get_bias(self, i, j, device):
        i_arange = torch.arange(i, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    @staticmethod
    def _get_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2 ** (-2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads - closest_power_of_2]

    def forward(self, qk_sim):
        h, i, j, device = *qk_sim.shape[-3:], qk_sim.device

        if exists(self.bias) and self.bias.shape[-1] >= j:
            return self.bias[..., :i, :j]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        num_heads_unalibied = h - bias.shape[0]
        bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
        self.register_buffer('bias', bias, persistent = False)

        return bias

# helper classes

class ReluSquared(nn.Module):
    """ found with neural architecture search in the Primer paper """
    def forward(self, x):
        return F.relu(x) ** 2

def FeedForward(dim, mult = 4):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        ReluSquared(),
        nn.Linear(hidden_dim, dim)
    )

class DepthwiseConv1d(nn.Module):
    def __init__(self, dim, kernel_size, causal = True):
        super().__init__()
        assert (kernel_size % 2) == 1

        # causal convs pad only on the left, so no position sees the future
        self.padding = (kernel_size - 1, 0) if causal else (kernel_size // 2, kernel_size // 2)
        self.conv = nn.Conv1d(dim, dim, kernel_size = kernel_size, groups = dim)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        ds_conv_kernel_sizes = (0, 3, 5, 7) # heads are split into one group per kernel size, each group getting a depthwise conv after the queries / keys / values projection (0 = no conv)
    ):
        super().__init__()
        self.groups = len(ds_conv_kernel_sizes)
        assert heads >= self.groups and (heads % self.groups) == 0, f'heads must be greater than or equal to {self.groups} and divisible by {self.groups}'

        self.scale = dim_head ** -0.5
        self.causal = causal

        self.heads = heads
        self.heads_per_group = heads // self.groups

        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)

        # ds convs with a different kernel size for each group of heads

        self.qkv_ds_convs = nn.ModuleList([])

        for _ in range(3): # for queries, keys, values
            ds_convs = nn.ModuleList([])

            for kernel_size in ds_conv_kernel_sizes:
                if kernel_size == 0:
                    ds_convs.append(nn.Identity())
                    continue

                ds_convs.append(DepthwiseConv1d(dim_head * self.heads_per_group, kernel_size, causal = causal))

            self.qkv_ds_convs.append(ds_convs)

        # learned alibi positional bias, one per group of heads

        self.learned_alibi_pos_biases = nn.ModuleList([LearnedAlibiPosBias(heads = self.heads_per_group) for _ in range(self.groups)])

        # outward projection

        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        device, heads_per_group = x.device, self.heads_per_group

        x = self.norm(x)
        x = rearrange(x, 'b n d -> b d n')

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        q, k, v = rearrange_many((q, k, v), 'b (h d) n -> b h d n', h = self.heads)

        # apply causal depthwise conv to queries, keys, values (a la Primer) with a different kernel size per group of heads

        def apply_causal_ds_conv_to_grouped_heads(args):
            projs, ds_convs = args

            projs = rearrange_many(projs.split(heads_per_group, dim = 1), 'b h d n -> b (h d) n')
            conv_out = [fn(t) for fn, t in zip(ds_convs, projs)]
            conv_out = map(lambda t: rearrange(t, 'b (h d) n -> b h d n', h = heads_per_group), conv_out)
            conv_out = torch.cat(tuple(conv_out), dim = 1)
            return rearrange(conv_out, 'b h d n -> b h n d')

        q, k, v = map(apply_causal_ds_conv_to_grouped_heads, zip((q, k, v), self.qkv_ds_convs))

        # scale and similarity

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # learned alibi pos bias, applied per group of heads
        # so heads specialize to looking at different distances of kmers

        grouped_sims = sim.split(self.heads // self.groups, dim = 1)
        grouped_sims = [(alibi(sim_group) + sim_group) for alibi, sim_group in zip(self.learned_alibi_pos_biases, grouped_sims)]

        sim = torch.cat(grouped_sims, dim = 1)

        # causal mask

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention, but of course

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# classes

class Tranception(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = 21,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        ds_conv_kernel_sizes = (0, 3, 5, 7),
        causal = True
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, ds_conv_kernel_sizes = ds_conv_kernel_sizes, causal = causal),
                FeedForward(dim, mult = ff_mult)
            ]))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        x = self.token_emb(x)

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.to_logits(x)
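
# quick sanity check (a sketch for local testing only, not part of the published module);
# assumes nothing beyond torch and the classes defined above

if __name__ == '__main__':
    model = Tranception(dim = 128, depth = 2, heads = 8, dim_head = 32)
    tokens = torch.randint(0, 21, (2, 64)) # batch of 2 sequences, 64 residues each
    logits = model(tokens)
    assert logits.shape == (2, 64, 21)     # per-position logits over the 21-token vocabulary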
--------------------------------------------------------------------------------