├── tranception.png
├── tranception_pytorch
│   ├── __init__.py
│   └── tranception_pytorch.py
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── README.md
└── .gitignore

/tranception.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/tranception-pytorch/HEAD/tranception.png
--------------------------------------------------------------------------------

/tranception_pytorch/__init__.py:
--------------------------------------------------------------------------------
from tranception_pytorch.tranception_pytorch import Tranception
--------------------------------------------------------------------------------

/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name = 'tranception-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.0.8',
    license = 'MIT',
    description = 'Tranception - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/tranception-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'protein fitness'
    ],
    install_requires = [
        'einops>=0.4',
        'einops-exts',
        'torch>=1.6',
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
## Tranception - Pytorch (wip)

Implementation of Tranception, an attention network paired with retrieval that achieves state-of-the-art protein fitness prediction. The Transformer architecture is inspired by Primer and uses ALiBi relative positional encoding.

## Install

```bash
$ pip install tranception-pytorch
```

## Usage

```python
import torch
from tranception_pytorch import Tranception

model = Tranception(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

amino_acids = torch.randint(0, 21, (1, 512))

logits = model(amino_acids) # (1, 512, 21)
```
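Since the network is autoregressive, a sequence can be scored by its log-likelihood, which is the basis of the paper's fitness estimate (a minimal sketch only; the full method also folds in inference-time retrieval, which is not implemented here):

```python
import torch.nn.functional as F

log_probs = F.log_softmax(logits, dim = -1)              # (1, 512, 21)

# position i predicts residue i + 1, so align predictions with the next token

next_tokens = amino_acids[:, 1:, None]                   # (1, 511, 1)
per_position = log_probs[:, :-1].gather(-1, next_tokens) # (1, 511, 1)

log_likelihood = per_position.sum(dim = (1, 2))          # higher means more probable under the model
```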
## Todo

- [x] grouped heads with customizable depthwise convs (for variable k-mers), as well as grouped alibi pos bias
- [ ] figure out attention to retrieved sequences (looks like axial attention?)
- [ ] play around with ProteinGym, and start betting on huggingface's accelerate

## Citations

```bibtex
@article{Notin2022TranceptionPF,
    title   = {Tranception: protein fitness prediction with autoregressive transformers and inference-time retrieval},
    author  = {Pascal Notin and Mafalda Dias and Jonathan Frazer and Javier Marchena-Hurtado and Aidan N. Gomez and Debora S. Marks and Yarin Gal},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.13760}
}
```
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/tranception_pytorch/tranception_pytorch.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange
from einops_exts import rearrange_many

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# relative positional bias

class LearnedAlibiPosBias(nn.Module):
    def __init__(self, heads):
        super().__init__()
        self.heads = heads
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        self.slopes = nn.Parameter(slopes)
        self.register_buffer('bias', None, persistent = False)

    def get_bias(self, i, j, device):
        # -|i - j| distance penalty, as in ALiBi
        i_arange = torch.arange(i, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    @staticmethod
    def _get_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2 ** (-2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads - closest_power_of_2]

    def forward(self, qk_sim):
        h, i, j, device = *qk_sim.shape[-3:], qk_sim.device

        # reuse the cached bias if it already covers the sequence length

        if exists(self.bias) and self.bias.shape[-1] >= j:
            return self.bias[..., :i, :j]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        # pad out any heads beyond the number of slopes with a zero bias

        num_heads_unalibied = h - bias.shape[0]
        bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
        self.register_buffer('bias', bias, persistent = False)

        return bias

# helper classes

class ReluSquared(nn.Module):
    """ activation found with neural architecture search in the Primer paper """
    def forward(self, x):
        return F.relu(x) ** 2

def FeedForward(dim, mult = 4):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        ReluSquared(),
        nn.Linear(hidden_dim, dim)
    )

class DepthwiseConv1d(nn.Module):
    def __init__(self, dim, kernel_size, causal = True):
        super().__init__()
        assert (kernel_size % 2) == 1

        self.padding = (kernel_size - 1, 0) if causal else (kernel_size // 2, kernel_size // 2)
        self.conv = nn.Conv1d(dim, dim, kernel_size = kernel_size, groups = dim)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)
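
# note: with causal = True, the depthwise conv above is left-padded with
# (kernel_size - 1) zeros, so each output position only mixes information from
# itself and earlier positions, mirroring Primer's causal convolutions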

class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        ds_conv_kernel_sizes = (0, 3, 5, 7) # heads are split into 4 groups, each given a depthwise conv of a different kernel size after the queries / keys / values projection
    ):
        super().__init__()
        self.groups = len(ds_conv_kernel_sizes)
        assert heads >= self.groups and (heads % self.groups) == 0, f'heads must be greater than or equal to {self.groups} and divisible by {self.groups}'

        self.scale = dim_head ** -0.5
        self.causal = causal

        self.heads = heads
        self.heads_per_group = heads // self.groups

        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)

        # ds convs with different kernel sizes for the groups of heads (kernel size 0 means no conv for that group)

        self.qkv_ds_convs = nn.ModuleList([])

        for _ in range(3): # for queries, keys, values
            ds_convs = nn.ModuleList([])

            for kernel_size in ds_conv_kernel_sizes:
                if kernel_size == 0:
                    ds_convs.append(nn.Identity())
                    continue

                ds_convs.append(DepthwiseConv1d(dim_head * self.heads_per_group, kernel_size, causal = causal))

            self.qkv_ds_convs.append(ds_convs)

        # learned alibi positional bias, one per group of heads

        self.learned_alibi_pos_biases = nn.ModuleList([LearnedAlibiPosBias(heads = self.heads_per_group) for _ in range(self.groups)])

        # output projection

        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        device, heads_per_group = x.device, self.heads_per_group

        x = self.norm(x)
        x = rearrange(x, 'b n d -> b d n')

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        q, k, v = rearrange_many((q, k, v), 'b (h d) n -> b h d n', h = self.heads)

        # apply depthwise conv to queries, keys, values (a la Primer), with a different kernel size per group of heads

        def apply_ds_conv_to_grouped_heads(args):
            projs, ds_convs = args

            projs = rearrange_many(projs.split(heads_per_group, dim = 1), 'b h d n -> b (h d) n')
            conv_out = [fn(t) for fn, t in zip(ds_convs, projs)]
            conv_out = map(lambda t: rearrange(t, 'b (h d) n -> b h d n', h = heads_per_group), conv_out)
            conv_out = torch.cat(tuple(conv_out), dim = 1)
            return rearrange(conv_out, 'b h d n -> b h n d')

        q, k, v = map(apply_ds_conv_to_grouped_heads, zip((q, k, v), self.qkv_ds_convs))

        # scale and similarity

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # learned alibi pos bias across the groups of heads
        # so heads specialize to looking at different distances of kmers

        grouped_sims = sim.split(self.heads // self.groups, dim = 1)
        grouped_sims = [(alibi(sim_group) + sim_group) for alibi, sim_group in zip(self.learned_alibi_pos_biases, grouped_sims)]

        sim = torch.cat(grouped_sims, dim = 1)

        # causal mask

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention, but of course

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# classes

class Tranception(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = 21,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        ds_conv_kernel_sizes = (0, 3, 5, 7),
        causal = True
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, ds_conv_kernel_sizes = ds_conv_kernel_sizes, causal = causal),
                FeedForward(dim, mult = ff_mult)
            ]))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(
        self,
        x,
        mask = None # accepted for api compatibility, but not used yet
    ):
        x = self.token_emb(x)

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.to_logits(x)
--------------------------------------------------------------------------------
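
Since `Tranception` above is a causal language model over amino acid tokens, a training step looks like ordinary next-token prediction. A minimal sketch (the repository ships no training loop, so the optimizer, learning rate, and random stand-in data below are assumptions for illustration):

```python
import torch
import torch.nn.functional as F
from tranception_pytorch import Tranception

model = Tranception(dim = 512, depth = 6, heads = 8, dim_head = 64)
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)  # hypothetical optimizer / lr

seqs = torch.randint(0, 21, (4, 512))  # stand-in batch of amino acid token ids

logits = model(seqs[:, :-1])           # predict token t + 1 from tokens <= t
loss = F.cross_entropy(
    logits.transpose(1, 2),            # cross_entropy expects (batch, classes, seq)
    seqs[:, 1:]
)

loss.backward()
optim.step()
optim.zero_grad()
```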