├── mat.png
├── molecule_attention_transformer
│   ├── __init__.py
│   └── molecule_attention_transformer.py
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── LICENSE
├── README.md
└── .gitignore
/mat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/molecule-attention-transformer/HEAD/mat.png
--------------------------------------------------------------------------------
/molecule_attention_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from molecule_attention_transformer.molecule_attention_transformer import MAT
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'molecule-attention-transformer',
5 | packages = find_packages(),
6 | version = '0.0.4',
7 | license='MIT',
8 | description = 'Molecule Attention Transformer - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/molecule-attention-transformer',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'attention mechanism',
15 | 'molecules'
16 | ],
17 | install_requires=[
18 | 'torch>=1.6',
19 | 'einops>=0.3'
20 | ],
21 | classifiers=[
22 | 'Development Status :: 4 - Beta',
23 | 'Intended Audience :: Developers',
24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
25 | 'License :: OSI Approved :: MIT License',
26 | 'Programming Language :: Python :: 3.6',
27 | ],
28 | )
29 |
--------------------------------------------------------------------------------
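A quick smoke test after installing the package (a sketch, not part of the repository) only assumes the MAT export declared in __init__.py above and the constructor arguments defined in molecule_attention_transformer.py below:

```python
# sketch: verify the install exposes MAT and that a model can be built
from molecule_attention_transformer import MAT

model = MAT(dim_in = 26, model_dim = 512, dim_out = 1, depth = 6)
print(sum(p.numel() for p in model.parameters()), 'parameters')
```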
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflows will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Molecule Attention Transformer - Pytorch (wip)
4 |
5 | Pytorch reimplementation of Molecule Attention Transformer, which uses a slightly modified transformer architecture to tackle the graph-like structure of molecules. The repository is also meant to be educational, as a way to understand the limitations (or lack thereof) of transformers for processing graphs.
6 |
7 | Update: Having reread the paper, the results do look convincing. However, I do not like that it still takes hyperparameter sweeps over the relative contributions of the distance, adjacency, and self-attention matrices to achieve good results. There must be a more hands-off way. (A minimal sketch of such a sweep follows this README section.)
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install molecule-attention-transformer
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from molecule_attention_transformer import MAT
20 |
21 | model = MAT(
22 | dim_in = 26,
23 | model_dim = 512,
24 | dim_out = 1,
25 | depth = 6,
26 | Lg = 0.5, # lambda (g)raph - weight for adjacency matrix
27 | Ld = 0.5, # lambda (d)istance - weight for distance matrix
28 | La = 1, # lambda (a)ttention - weight for usual self-attention
29 | dist_kernel_fn = 'exp' # distance kernel fn - either 'exp' or 'softmax'
30 | )
31 |
32 | atoms = torch.randn(2, 100, 26)
33 | mask = torch.ones(2, 100).bool()
34 | adjacency_mat = torch.empty(2, 100, 100).random_(2).float()
35 | distance_mat = torch.randn(2, 100, 100)
36 |
37 | out = model(
38 | atoms,
39 | mask = mask,
40 | adjacency_mat = adjacency_mat,
41 | distance_mat = distance_mat
42 | ) # (2, 1)
43 | ```
44 |
45 | ## Citations
46 |
47 | ```bibtex
48 | @misc{maziarka2020molecule,
49 | title={Molecule Attention Transformer},
50 | author={Łukasz Maziarka and Tomasz Danel and Sławomir Mucha and Krzysztof Rataj and Jacek Tabor and Stanisław Jastrzębski},
51 | year={2020},
52 | eprint={2002.08264},
53 | archivePrefix={arXiv},
54 | primaryClass={cs.LG}
55 | }
56 | ```
57 |
--------------------------------------------------------------------------------
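The update note in the README above mentions that good results still require sweeping the relative weights Lg, Ld, and La. The following is a minimal sketch of such a sweep, not part of the repository; `score_model` is a hypothetical placeholder for training and evaluating on an actual molecular property dataset, and the small model dimensions are only there to keep the sketch cheap to run.

```python
import itertools
import torch
from molecule_attention_transformer import MAT

def score_model(model):
    # hypothetical placeholder - in practice, train the model and return a
    # validation metric; here we just run a forward pass on random inputs
    atoms = torch.randn(2, 100, 26)
    mask = torch.ones(2, 100).bool()
    adjacency_mat = torch.empty(2, 100, 100).random_(2).float()
    distance_mat = torch.randn(2, 100, 100)
    with torch.no_grad():
        out = model(atoms, mask = mask, adjacency_mat = adjacency_mat, distance_mat = distance_mat)
    return out.mean().item()

best = None
for Lg, Ld, La in itertools.product([0.25, 0.5, 1.0], repeat = 3):
    model = MAT(dim_in = 26, model_dim = 64, dim_out = 1, depth = 2, Lg = Lg, Ld = Ld, La = La)
    score = score_model(model)
    if best is None or score < best[0]:
        best = (score, (Lg, Ld, La))

print('best (Lg, Ld, La):', best[1])
```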
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/molecule_attention_transformer/molecule_attention_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from functools import partial
4 | from torch import nn, einsum
5 | from einops import rearrange
6 |
7 | # constants
8 |
9 | DIST_KERNELS = {
10 | 'exp': {
11 | 'fn': lambda t: torch.exp(-t),
12 | 'mask_value_fn': lambda t: torch.finfo(t.dtype).max
13 | },
14 | 'softmax': {
15 | 'fn': lambda t: torch.softmax(t, dim = -1),
16 | 'mask_value_fn': lambda t: -torch.finfo(t.dtype).max
17 | }
18 | }
19 |
20 | # helpers
21 |
22 | def exists(val):
23 | return val is not None
24 |
25 | def default(val, d):
26 | return d if not exists(val) else val
27 |
28 | # helper classes
29 |
30 | class Residual(nn.Module):
31 | def __init__(self, fn):
32 | super().__init__()
33 | self.fn = fn
34 |
35 | def forward(self, x, **kwargs):
36 | return x + self.fn(x, **kwargs)
37 |
38 | class PreNorm(nn.Module):
39 | def __init__(self, dim, fn):
40 | super().__init__()
41 | self.norm = nn.LayerNorm(dim)
42 | self.fn = fn
43 |
44 | def forward(self, x, **kwargs):
45 | x = self.norm(x)
46 | return self.fn(x, **kwargs)
47 |
48 | class FeedForward(nn.Module):
49 | def __init__(self, dim, dim_out = None, mult = 4):
50 | super().__init__()
51 | dim_out = default(dim_out, dim)
52 | self.net = nn.Sequential(
53 | nn.Linear(dim, dim * mult),
54 | nn.GELU(),
55 | nn.Linear(dim * mult, dim_out)
56 | )
57 |
58 | def forward(self, x):
59 | return self.net(x)
60 |
61 | class Attention(nn.Module):
62 | def __init__(self, dim, heads = 8, dim_head = 64, Lg = 0.5, Ld = 0.5, La = 1, dist_kernel_fn = 'exp'):
63 | super().__init__()
64 | inner_dim = dim_head * heads
65 |         self.heads = heads
66 | self.scale = dim_head ** -0.5
67 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
68 | self.to_out = nn.Linear(inner_dim, dim)
69 |
70 | # hyperparameters controlling the weighted linear combination from
71 | # self-attention (La)
72 | # adjacency graph (Lg)
73 | # pair-wise distance matrix (Ld)
74 |
75 | self.La = La
76 | self.Ld = Ld
77 | self.Lg = Lg
78 |
79 | self.dist_kernel_fn = dist_kernel_fn
80 |
81 | def forward(self, x, mask = None, adjacency_mat = None, distance_mat = None):
82 | h, La, Ld, Lg, dist_kernel_fn = self.heads, self.La, self.Ld, self.Lg, self.dist_kernel_fn
83 |
84 | qkv = self.to_qkv(x)
85 | q, k, v = rearrange(qkv, 'b n (h qkv d) -> b h n qkv d', h = h, qkv = 3).unbind(dim = -2)
86 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
87 |
88 |         assert dist_kernel_fn in DIST_KERNELS, f'distance kernel function needs to be one of {DIST_KERNELS.keys()}'
89 | dist_kernel_config = DIST_KERNELS[dist_kernel_fn]
90 |
91 | if exists(distance_mat):
92 | distance_mat = rearrange(distance_mat, 'b i j -> b () i j')
93 |
94 | if exists(adjacency_mat):
95 | adjacency_mat = rearrange(adjacency_mat, 'b i j -> b () i j')
96 |
97 | if exists(mask):
98 | mask_value = torch.finfo(dots.dtype).max
99 | mask = mask[:, None, :, None] * mask[:, None, None, :]
100 |
101 | # mask attention
102 | dots.masked_fill_(~mask, -mask_value)
103 |
104 | if exists(distance_mat):
105 |                 # mask out padded positions in the distance matrix, using a value
106 |                 # appropriate to the kernel ('exp' -> +max so exp(-t) ~ 0, 'softmax' -> -max)
107 | dist_mask_value = dist_kernel_config['mask_value_fn'](dots)
108 | distance_mat.masked_fill_(~mask, dist_mask_value)
109 |
110 | if exists(adjacency_mat):
111 | adjacency_mat.masked_fill_(~mask, 0.)
112 |
113 | attn = dots.softmax(dim = -1)
114 |
115 | # sum contributions from adjacency and distance tensors
116 | attn = attn * La
117 |
118 | if exists(adjacency_mat):
119 | attn = attn + Lg * adjacency_mat
120 |
121 | if exists(distance_mat):
122 | distance_mat = dist_kernel_config['fn'](distance_mat)
123 | attn = attn + Ld * distance_mat
124 |
125 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
126 | out = rearrange(out, 'b h n d -> b n (h d)')
127 | return self.to_out(out)
128 |
129 | # main class
130 |
131 | class MAT(nn.Module):
132 | def __init__(
133 | self,
134 | *,
135 | dim_in,
136 | model_dim,
137 | dim_out,
138 | depth,
139 | heads = 8,
140 | Lg = 0.5,
141 | Ld = 0.5,
142 | La = 1,
143 | dist_kernel_fn = 'exp'
144 | ):
145 | super().__init__()
146 |
147 | self.embed_to_model = nn.Linear(dim_in, model_dim)
148 | self.layers = nn.ModuleList([])
149 |
150 | for _ in range(depth):
151 | layer = nn.ModuleList([
152 | Residual(PreNorm(model_dim, Attention(model_dim, heads = heads, Lg = Lg, Ld = Ld, La = La, dist_kernel_fn = dist_kernel_fn))),
153 | Residual(PreNorm(model_dim, FeedForward(model_dim)))
154 | ])
155 | self.layers.append(layer)
156 |
157 | self.norm_out = nn.LayerNorm(model_dim)
158 | self.ff_out = FeedForward(model_dim, dim_out)
159 |
160 | def forward(
161 | self,
162 | x,
163 | mask = None,
164 | adjacency_mat = None,
165 | distance_mat = None
166 | ):
167 | x = self.embed_to_model(x)
168 |
169 | for (attn, ff) in self.layers:
170 | x = attn(
171 | x,
172 | mask = mask,
173 | adjacency_mat = adjacency_mat,
174 | distance_mat = distance_mat
175 | )
176 | x = ff(x)
177 |
178 | x = self.norm_out(x)
179 | x = x.mean(dim = -2)
180 | x = self.ff_out(x)
181 | return x
182 |
--------------------------------------------------------------------------------
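As a sketch of what the Attention layer above computes (a weighted combination of softmax self-attention via La, the adjacency matrix via Lg, and the kernelized distance matrix via Ld), the layer can be exercised directly. Importing Attention from the module path is an assumption made for illustration; the package's __init__.py only re-exports MAT.

```python
import torch
from molecule_attention_transformer.molecule_attention_transformer import Attention

attn = Attention(
    dim = 64,
    heads = 8,
    dim_head = 16,
    Lg = 0.5,            # weight on the adjacency matrix
    Ld = 0.5,            # weight on the kernelized distance matrix
    La = 1,              # weight on the usual softmax self-attention
    dist_kernel_fn = 'exp'
)

x = torch.randn(2, 16, 64)                               # (batch, atoms, dim)
mask = torch.ones(2, 16).bool()                          # all atoms valid
adjacency_mat = torch.empty(2, 16, 16).random_(2).float()
distance_mat = torch.randn(2, 16, 16).abs()              # non-negative pairwise distances

out = attn(x, mask = mask, adjacency_mat = adjacency_mat, distance_mat = distance_mat)
print(out.shape)  # torch.Size([2, 16, 64])
```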