├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── graph-transformer.png
├── graph_transformer_pytorch
│   ├── __init__.py
│   └── graph_transformer_pytorch.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Graph Transformer - Pytorch
4 |
5 | Implementation of Graph Transformer in Pytorch, for potential use in replicating Alphafold2. This architecture was recently used by both Costa et al. and the Baker lab for transforming MSA and pairwise embeddings into 3D coordinates.
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install graph-transformer-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | ```python
16 | import torch
17 | from graph_transformer_pytorch import GraphTransformer
18 |
19 | model = GraphTransformer(
20 | dim = 256,
21 | depth = 6,
22 |     edge_dim = 512,             # optional - if left out, the edge dimension is assumed to be the same as the node dimension above
23 |     with_feedforwards = True,   # whether to add a feedforward after each attention layer, which the literature suggests is needed
24 |     gated_residual = True,      # use the gated residual to prevent over-smoothing
25 |     rel_pos_emb = True          # set to True if the nodes are ordered, defaults to False
26 | )
27 |
28 | nodes = torch.randn(1, 128, 256)
29 | edges = torch.randn(1, 128, 128, 512)
30 | mask = torch.ones(1, 128).bool()
31 |
32 | nodes, edges = model(nodes, edges, mask = mask)
33 |
34 | nodes.shape # (1, 128, 256) - project to R^3 for coordinates
35 | ```
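
As the comment above suggests, for coordinate prediction you would project the output node embeddings down to 3 dimensions. A minimal sketch, where the `to_coords` head is a hypothetical addition and not part of this library:

```python
import torch
from torch import nn

nodes = torch.randn(1, 128, 256)  # output node embeddings from the model above

# hypothetical projection head (not part of this library) - one (x, y, z) per node
to_coords = nn.Linear(256, 3)
coords = to_coords(nodes)  # (1, 128, 3)
```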
36 |
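For reference, `gated_residual = True` swaps the plain `x + res` residual for the learned gating of Shi et al (cited below). The snippet here mirrors the `GatedResidual` module in this repository:

```python
import torch
from torch import nn

dim = 256

# gate = sigmoid(W [x ; res ; x - res]), out = gate * x + (1 - gate) * res
to_gate = nn.Sequential(
    nn.Linear(dim * 3, 1, bias = False),
    nn.Sigmoid()
)

x, res = torch.randn(1, 128, dim), torch.randn(1, 128, dim)
gate = to_gate(torch.cat((x, res, x - res), dim = -1))
out = x * gate + res * (1 - gate)
```
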
37 | If you want it to handle an adjacency matrix
38 |
39 | ```python
40 | import torch
41 | from graph_transformer_pytorch import GraphTransformer
42 |
43 | model = GraphTransformer(
44 | dim = 256,
45 | depth = 6,
46 | edge_dim = 512,
47 | with_feedforwards = True,
48 | gated_residual = True,
49 | rel_pos_emb = True,
50 | accept_adjacency_matrix = True # set this to True
51 | )
52 |
53 | nodes = torch.randn(2, 128, 256)
54 | adj_mat = torch.randint(0, 2, (2, 128, 128))
55 | mask = torch.ones(2, 128).bool()
56 |
57 | nodes, edges = model(nodes, adj_mat = adj_mat, mask = mask)
58 |
59 | nodes.shape # (2, 128, 256) - project to R^3 for coordinates
60 | ```
61 |
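Continuous edge features and a discrete adjacency matrix can also be passed together; internally, the adjacency embedding is simply summed with the edge features. A sketch:

```python
import torch
from graph_transformer_pytorch import GraphTransformer

model = GraphTransformer(
    dim = 256,
    depth = 6,
    edge_dim = 512,
    accept_adjacency_matrix = True
)

nodes = torch.randn(2, 128, 256)
edges = torch.randn(2, 128, 128, 512)
adj_mat = torch.randint(0, 2, (2, 128, 128))
mask = torch.ones(2, 128).bool()

nodes, edges = model(nodes, edges = edges, adj_mat = adj_mat, mask = mask)
```
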
62 | ## Citations
63 |
64 | ```bibtex
65 | @article {Costa2021.06.02.446809,
66 | author = {Costa, Allan and Ponnapati, Manvitha and Jacobson, Joseph M. and Chatterjee, Pranam},
67 | title = {Distillation of MSA Embeddings to Folded Protein Structures with Graph Transformers},
68 | year = {2021},
69 | doi = {10.1101/2021.06.02.446809},
70 | publisher = {Cold Spring Harbor Laboratory},
71 | URL = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809},
72 | eprint = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809.full.pdf},
73 | journal = {bioRxiv}
74 | }
75 | ```
76 |
77 | ```bibtex
78 | @article {Baek2021.06.14.448402,
79 | author = {Baek, Minkyung and DiMaio, Frank and Anishchenko, Ivan and Dauparas, Justas and Ovchinnikov, Sergey and Lee, Gyu Rie and Wang, Jue and Cong, Qian and Kinch, Lisa N. and Schaeffer, R. Dustin and Mill{\'a}n, Claudia and Park, Hahnbeom and Adams, Carson and Glassman, Caleb R. and DeGiovanni, Andy and Pereira, Jose H. and Rodrigues, Andria V. and van Dijk, Alberdina A. and Ebrecht, Ana C. and Opperman, Diederik J. and Sagmeister, Theo and Buhlheller, Christoph and Pavkov-Keller, Tea and Rathinaswamy, Manoj K and Dalwadi, Udit and Yip, Calvin K and Burke, John E and Garcia, K. Christopher and Grishin, Nick V. and Adams, Paul D. and Read, Randy J. and Baker, David},
80 | title = {Accurate prediction of protein structures and interactions using a 3-track network},
81 | year = {2021},
82 | doi = {10.1101/2021.06.14.448402},
83 | publisher = {Cold Spring Harbor Laboratory},
84 | URL = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402},
85 | eprint = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402.full.pdf},
86 | journal = {bioRxiv}
87 | }
88 | ```
89 |
90 | ```bibtex
91 | @misc{shi2021masked,
92 | title = {Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification},
93 | author = {Yunsheng Shi and Zhengjie Huang and Shikun Feng and Hui Zhong and Wenjin Wang and Yu Sun},
94 | year = {2021},
95 | eprint = {2009.03509},
96 | archivePrefix = {arXiv},
97 | primaryClass = {cs.LG}
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/graph-transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/graph-transformer-pytorch/8613d2141c84f1cbaa4e2802e3091638c706c9ca/graph-transformer.png
--------------------------------------------------------------------------------
/graph_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from graph_transformer_pytorch.graph_transformer_pytorch import GraphTransformer
2 |
--------------------------------------------------------------------------------
/graph_transformer_pytorch/graph_transformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange, repeat
4 |
5 | from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
6 |
7 | # helpers
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | def default(val, d):
13 | return val if exists(val) else d
14 |
15 | List = nn.ModuleList
16 |
17 | # normalizations
18 |
19 | class PreNorm(nn.Module):
20 | def __init__(
21 | self,
22 | dim,
23 | fn
24 | ):
25 | super().__init__()
26 | self.fn = fn
27 | self.norm = nn.LayerNorm(dim)
28 |
29 | def forward(self, x, *args, **kwargs):
30 | x = self.norm(x)
31 |         return self.fn(x, *args, **kwargs)
32 |
33 | # gated residual
34 |
35 | class Residual(nn.Module):
36 | def forward(self, x, res):
37 | return x + res
38 |
39 | class GatedResidual(nn.Module):
40 | def __init__(self, dim):
41 | super().__init__()
42 | self.proj = nn.Sequential(
43 | nn.Linear(dim * 3, 1, bias = False),
44 | nn.Sigmoid()
45 | )
46 |
47 | def forward(self, x, res):
48 | gate_input = torch.cat((x, res, x - res), dim = -1)
49 | gate = self.proj(gate_input)
50 | return x * gate + res * (1 - gate)
51 |
52 | # attention
53 |
54 | class Attention(nn.Module):
55 | def __init__(
56 | self,
57 | dim,
58 | pos_emb = None,
59 | dim_head = 64,
60 | heads = 8,
61 | edge_dim = None
62 | ):
63 | super().__init__()
64 | edge_dim = default(edge_dim, dim)
65 |
66 | inner_dim = dim_head * heads
67 | self.heads = heads
68 | self.scale = dim_head ** -0.5
69 |
70 | self.pos_emb = pos_emb
71 |
72 | self.to_q = nn.Linear(dim, inner_dim)
73 | self.to_kv = nn.Linear(dim, inner_dim * 2)
74 | self.edges_to_kv = nn.Linear(edge_dim, inner_dim)
75 |
76 | self.to_out = nn.Linear(inner_dim, dim)
77 |
78 | def forward(self, nodes, edges, mask = None):
79 | h = self.heads
80 |
81 | q = self.to_q(nodes)
82 | k, v = self.to_kv(nodes).chunk(2, dim = -1)
83 |
84 | e_kv = self.edges_to_kv(edges)
85 |
86 | q, k, v, e_kv = map(lambda t: rearrange(t, 'b ... (h d) -> (b h) ... d', h = h), (q, k, v, e_kv))
87 |
88 | if exists(self.pos_emb):
89 | freqs = self.pos_emb(torch.arange(nodes.shape[1], device = nodes.device))
90 | freqs = rearrange(freqs, 'n d -> () n d')
91 | q = apply_rotary_emb(freqs, q)
92 | k = apply_rotary_emb(freqs, k)
93 |
94 |         ek, ev = e_kv, e_kv  # the single edge projection is shared between keys and values
95 | 
96 |         k, v = map(lambda t: rearrange(t, 'b j d -> b () j d'), (k, v))
97 |         k = k + ek  # condition keys on edge representations
98 |         v = v + ev  # condition values on edge representations
99 |
100 | sim = einsum('b i d, b i j d -> b i j', q, k) * self.scale
101 |
102 | if exists(mask):
103 | mask = rearrange(mask, 'b i -> b i ()') & rearrange(mask, 'b j -> b () j')
104 | mask = repeat(mask, 'b i j -> (b h) i j', h = h)
105 | max_neg_value = -torch.finfo(sim.dtype).max
106 | sim.masked_fill_(~mask, max_neg_value)
107 |
108 | attn = sim.softmax(dim = -1)
109 | out = einsum('b i j, b i j d -> b i d', attn, v)
110 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
111 | return self.to_out(out)
112 |
113 | # optional feedforward
114 |
115 | def FeedForward(dim, ff_mult = 4):
116 | return nn.Sequential(
117 | nn.Linear(dim, dim * ff_mult),
118 | nn.GELU(),
119 | nn.Linear(dim * ff_mult, dim)
120 | )
121 |
122 | # classes
123 |
124 | class GraphTransformer(nn.Module):
125 | def __init__(
126 | self,
127 | dim,
128 | depth,
129 | dim_head = 64,
130 | edge_dim = None,
131 | heads = 8,
132 | gated_residual = True,
133 | with_feedforwards = False,
134 | norm_edges = False,
135 | rel_pos_emb = False,
136 | accept_adjacency_matrix = False
137 | ):
138 | super().__init__()
139 | self.layers = List([])
140 | edge_dim = default(edge_dim, dim)
141 | self.norm_edges = nn.LayerNorm(edge_dim) if norm_edges else nn.Identity()
142 |
143 | self.adj_emb = nn.Embedding(2, edge_dim) if accept_adjacency_matrix else None
144 |
145 | pos_emb = RotaryEmbedding(dim_head) if rel_pos_emb else None
146 |
147 | for _ in range(depth):
148 | self.layers.append(List([
149 | List([
150 | PreNorm(dim, Attention(dim, pos_emb = pos_emb, edge_dim = edge_dim, dim_head = dim_head, heads = heads)),
151 | GatedResidual(dim)
152 | ]),
153 | List([
154 | PreNorm(dim, FeedForward(dim)),
155 | GatedResidual(dim)
156 | ]) if with_feedforwards else None
157 | ]))
158 |
159 | def forward(
160 | self,
161 | nodes,
162 | edges = None,
163 | adj_mat = None,
164 | mask = None
165 | ):
166 | batch, seq, _ = nodes.shape
167 |
168 | if exists(edges):
169 | edges = self.norm_edges(edges)
170 |
171 | if exists(adj_mat):
172 | assert adj_mat.shape == (batch, seq, seq)
173 | assert exists(self.adj_emb), 'accept_adjacency_matrix must be set to True'
174 | adj_mat = self.adj_emb(adj_mat.long())
175 |
176 |         all_edges = default(edges, 0) + default(adj_mat, 0)  # sum edge features with the adjacency embedding, if both are given
177 |
178 | for attn_block, ff_block in self.layers:
179 | attn, attn_residual = attn_block
180 | nodes = attn_residual(attn(nodes, all_edges, mask = mask), nodes)
181 |
182 | if exists(ff_block):
183 | ff, ff_residual = ff_block
184 | nodes = ff_residual(ff(nodes), nodes)
185 |
186 | return nodes, edges
187 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'graph-transformer-pytorch',
5 | packages = find_packages(),
6 | version = '0.1.1',
7 | license='MIT',
8 | description = 'Graph Transformer - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/graph-transformer-pytorch',
12 | long_description_content_type = 'text/markdown',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'graphs'
18 | ],
19 | install_requires=[
20 | 'einops>=0.3',
21 | 'rotary-embedding-torch',
22 | 'torch>=1.6'
23 | ],
24 | classifiers=[
25 | 'Development Status :: 4 - Beta',
26 | 'Intended Audience :: Developers',
27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
28 | 'License :: OSI Approved :: MIT License',
29 | 'Programming Language :: Python :: 3.6',
30 | ],
31 | )
32 |
--------------------------------------------------------------------------------