├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── graph-transformer.png ├── graph_transformer_pytorch ├── __init__.py └── graph_transformer_pytorch.py └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | 
*.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Graph Transformer - Pytorch 4 | 5 | Implementation of Graph Transformer in Pytorch, for potential use in replicating Alphafold2. This was recently used by both Costa et al. and the Baker lab for transforming MSA and pair-wise embeddings into 3D coordinates. 6 | 7 | ## Install 8 | 9 | ```bash 10 | $ pip install graph-transformer-pytorch 11 | ``` 12 | 13 | ## Usage 14 | 15 | ```python 16 | import torch 17 | from graph_transformer_pytorch import GraphTransformer 18 | 19 | model = GraphTransformer( 20 | dim = 256, 21 | depth = 6, 22 | edge_dim = 512, # optional - if left out, the edge dimension is assumed to be the same as the node dimension above 23 | with_feedforwards = True, # whether to add a feedforward after each attention layer, suggested by literature to be needed 24 | gated_residual = True, # to use the gated residual to prevent over-smoothing 25 | rel_pos_emb = True # set to True if the nodes are ordered, defaults to False 26 | ) 27 | 28 | nodes = torch.randn(1, 128, 256) 29 | edges = torch.randn(1, 128, 128, 512) 30 | mask = torch.ones(1, 128).bool() 31 | 32 | nodes, edges = model(nodes, edges, mask = mask) 33 | 34 | nodes.shape # (1, 128, 256) - project to R^3 for coordinates 35 | ``` 36 | 37 | If you want it to handle an adjacency matrix: 38 | 39 | ```python 40 | import torch 41 | from graph_transformer_pytorch import GraphTransformer 42 | 43 | model = GraphTransformer( 44 | dim = 256, 45 | depth = 6, 46 | edge_dim = 512, 47 | with_feedforwards = True, 48 | gated_residual = True, 49 | rel_pos_emb = True, 50 | accept_adjacency_matrix = True # set this to True 51 | ) 52 | 53 | nodes = torch.randn(2, 128, 256) 54 | adj_mat = torch.randint(0, 2, (2, 128, 128)) 55 | mask = torch.ones(2, 128).bool() 56 | 57 | nodes, edges = model(nodes, adj_mat = adj_mat, mask = mask) 58 | 59 |
nodes.shape # (2, 128, 256) - project to R^3 for coordinates 60 | ``` 61 | 62 | ## Citations 63 | 64 | ```bibtex 65 | @article {Costa2021.06.02.446809, 66 | author = {Costa, Allan and Ponnapati, Manvitha and Jacobson, Joseph M. and Chatterjee, Pranam}, 67 | title = {Distillation of MSA Embeddings to Folded Protein Structures with Graph Transformers}, 68 | year = {2021}, 69 | doi = {10.1101/2021.06.02.446809}, 70 | publisher = {Cold Spring Harbor Laboratory}, 71 | URL = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809}, 72 | eprint = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809.full.pdf}, 73 | journal = {bioRxiv} 74 | } 75 | ``` 76 | 77 | ```bibtex 78 | @article {Baek2021.06.14.448402, 79 | author = {Baek, Minkyung and DiMaio, Frank and Anishchenko, Ivan and Dauparas, Justas and Ovchinnikov, Sergey and Lee, Gyu Rie and Wang, Jue and Cong, Qian and Kinch, Lisa N. and Schaeffer, R. Dustin and Mill{\'a}n, Claudia and Park, Hahnbeom and Adams, Carson and Glassman, Caleb R. and DeGiovanni, Andy and Pereira, Jose H. and Rodrigues, Andria V. and van Dijk, Alberdina A. and Ebrecht, Ana C. and Opperman, Diederik J. and Sagmeister, Theo and Buhlheller, Christoph and Pavkov-Keller, Tea and Rathinaswamy, Manoj K and Dalwadi, Udit and Yip, Calvin K and Burke, John E and Garcia, K. Christopher and Grishin, Nick V. and Adams, Paul D. and Read, Randy J.
and Baker, David}, 80 | title = {Accurate prediction of protein structures and interactions using a 3-track network}, 81 | year = {2021}, 82 | doi = {10.1101/2021.06.14.448402}, 83 | publisher = {Cold Spring Harbor Laboratory}, 84 | URL = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402}, 85 | eprint = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402.full.pdf}, 86 | journal = {bioRxiv} 87 | } 88 | ``` 89 | 90 | ```bibtex 91 | @misc{shi2021masked, 92 | title = {Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification}, 93 | author = {Yunsheng Shi and Zhengjie Huang and Shikun Feng and Hui Zhong and Wenjin Wang and Yu Sun}, 94 | year = {2021}, 95 | eprint = {2009.03509}, 96 | archivePrefix = {arXiv}, 97 | primaryClass = {cs.LG} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /graph-transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/graph-transformer-pytorch/8613d2141c84f1cbaa4e2802e3091638c706c9ca/graph-transformer.png -------------------------------------------------------------------------------- /graph_transformer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from graph_transformer_pytorch.graph_transformer_pytorch import GraphTransformer 2 | -------------------------------------------------------------------------------- /graph_transformer_pytorch/graph_transformer_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange, repeat 4 | 5 | from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb 6 | 7 | # helpers 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def default(val, d): 13 | return val if exists(val) else d 14 | 15 | List = nn.ModuleList 
16 | 17 | # normalizations 18 | 19 | class PreNorm(nn.Module): 20 | def __init__( 21 | self, 22 | dim, 23 | fn 24 | ): 25 | super().__init__() 26 | self.fn = fn 27 | self.norm = nn.LayerNorm(dim) 28 | 29 | def forward(self, x, *args, **kwargs): 30 | x = self.norm(x) 31 | return self.fn(x, *args,**kwargs) 32 | 33 | # gated residual 34 | 35 | class Residual(nn.Module): 36 | def forward(self, x, res): 37 | return x + res 38 | 39 | class GatedResidual(nn.Module): 40 | def __init__(self, dim): 41 | super().__init__() 42 | self.proj = nn.Sequential( 43 | nn.Linear(dim * 3, 1, bias = False), 44 | nn.Sigmoid() 45 | ) 46 | 47 | def forward(self, x, res): 48 | gate_input = torch.cat((x, res, x - res), dim = -1) 49 | gate = self.proj(gate_input) 50 | return x * gate + res * (1 - gate) 51 | 52 | # attention 53 | 54 | class Attention(nn.Module): 55 | def __init__( 56 | self, 57 | dim, 58 | pos_emb = None, 59 | dim_head = 64, 60 | heads = 8, 61 | edge_dim = None 62 | ): 63 | super().__init__() 64 | edge_dim = default(edge_dim, dim) 65 | 66 | inner_dim = dim_head * heads 67 | self.heads = heads 68 | self.scale = dim_head ** -0.5 69 | 70 | self.pos_emb = pos_emb 71 | 72 | self.to_q = nn.Linear(dim, inner_dim) 73 | self.to_kv = nn.Linear(dim, inner_dim * 2) 74 | self.edges_to_kv = nn.Linear(edge_dim, inner_dim) 75 | 76 | self.to_out = nn.Linear(inner_dim, dim) 77 | 78 | def forward(self, nodes, edges, mask = None): 79 | h = self.heads 80 | 81 | q = self.to_q(nodes) 82 | k, v = self.to_kv(nodes).chunk(2, dim = -1) 83 | 84 | e_kv = self.edges_to_kv(edges) 85 | 86 | q, k, v, e_kv = map(lambda t: rearrange(t, 'b ... (h d) -> (b h) ... 
d', h = h), (q, k, v, e_kv)) 87 | 88 | if exists(self.pos_emb): 89 | freqs = self.pos_emb(torch.arange(nodes.shape[1], device = nodes.device)) 90 | freqs = rearrange(freqs, 'n d -> () n d') 91 | q = apply_rotary_emb(freqs, q) 92 | k = apply_rotary_emb(freqs, k) 93 | 94 | ek, ev = e_kv, e_kv 95 | 96 | k, v = map(lambda t: rearrange(t, 'b j d -> b () j d '), (k, v)) 97 | k = k + ek 98 | v = v + ev 99 | 100 | sim = einsum('b i d, b i j d -> b i j', q, k) * self.scale 101 | 102 | if exists(mask): 103 | mask = rearrange(mask, 'b i -> b i ()') & rearrange(mask, 'b j -> b () j') 104 | mask = repeat(mask, 'b i j -> (b h) i j', h = h) 105 | max_neg_value = -torch.finfo(sim.dtype).max 106 | sim.masked_fill_(~mask, max_neg_value) 107 | 108 | attn = sim.softmax(dim = -1) 109 | out = einsum('b i j, b i j d -> b i d', attn, v) 110 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h) 111 | return self.to_out(out) 112 | 113 | # optional feedforward 114 | 115 | def FeedForward(dim, ff_mult = 4): 116 | return nn.Sequential( 117 | nn.Linear(dim, dim * ff_mult), 118 | nn.GELU(), 119 | nn.Linear(dim * ff_mult, dim) 120 | ) 121 | 122 | # classes 123 | 124 | class GraphTransformer(nn.Module): 125 | def __init__( 126 | self, 127 | dim, 128 | depth, 129 | dim_head = 64, 130 | edge_dim = None, 131 | heads = 8, 132 | gated_residual = True, 133 | with_feedforwards = False, 134 | norm_edges = False, 135 | rel_pos_emb = False, 136 | accept_adjacency_matrix = False 137 | ): 138 | super().__init__() 139 | self.layers = List([]) 140 | edge_dim = default(edge_dim, dim) 141 | self.norm_edges = nn.LayerNorm(edge_dim) if norm_edges else nn.Identity() 142 | 143 | self.adj_emb = nn.Embedding(2, edge_dim) if accept_adjacency_matrix else None 144 | 145 | pos_emb = RotaryEmbedding(dim_head) if rel_pos_emb else None 146 | 147 | for _ in range(depth): 148 | self.layers.append(List([ 149 | List([ 150 | PreNorm(dim, Attention(dim, pos_emb = pos_emb, edge_dim = edge_dim, dim_head = dim_head, heads = heads)), 
151 | GatedResidual(dim) 152 | ]), 153 | List([ 154 | PreNorm(dim, FeedForward(dim)), 155 | GatedResidual(dim) 156 | ]) if with_feedforwards else None 157 | ])) 158 | 159 | def forward( 160 | self, 161 | nodes, 162 | edges = None, 163 | adj_mat = None, 164 | mask = None 165 | ): 166 | batch, seq, _ = nodes.shape 167 | 168 | if exists(edges): 169 | edges = self.norm_edges(edges) 170 | 171 | if exists(adj_mat): 172 | assert adj_mat.shape == (batch, seq, seq) 173 | assert exists(self.adj_emb), 'accept_adjacency_matrix must be set to True' 174 | adj_mat = self.adj_emb(adj_mat.long()) 175 | 176 | all_edges = default(edges, 0) + default(adj_mat, 0) 177 | 178 | for attn_block, ff_block in self.layers: 179 | attn, attn_residual = attn_block 180 | nodes = attn_residual(attn(nodes, all_edges, mask = mask), nodes) 181 | 182 | if exists(ff_block): 183 | ff, ff_residual = ff_block 184 | nodes = ff_residual(ff(nodes), nodes) 185 | 186 | return nodes, edges 187 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'graph-transformer-pytorch', 5 | packages = find_packages(), 6 | version = '0.1.1', 7 | license='MIT', 8 | description = 'Graph Transformer - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/graph-transformer-pytorch', 12 | long_description_content_type = 'text/markdown', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'graphs' 18 | ], 19 | install_requires=[ 20 | 'einops>=0.3', 21 | 'rotary-embedding-torch', 22 | 'torch>=1.6' 23 | ], 24 | classifiers=[ 25 | 'Development Status :: 4 - Beta', 26 | 'Intended Audience :: Developers', 27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 28 | 'License :: OSI Approved :: MIT License', 29 | 
'Programming Language :: Python :: 3.6', 30 | ], 31 | ) 32 | --------------------------------------------------------------------------------