├── .github
│   └── workflows
│       ├── python-package.yml
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── denoise.py
├── diagram.png
├── se3_transformer_pytorch
│   ├── __init__.py
│   ├── basis.py
│   ├── data
│   │   ├── J_dense.npy
│   │   └── J_dense.pt
│   ├── irr_repr.py
│   ├── reversible.py
│   ├── rotary.py
│   ├── se3_transformer_pytorch.py
│   ├── spherical_harmonics.py
│   └── utils.py
├── setup.cfg
├── setup.py
└── tests
    ├── test_basis.py
    ├── test_equivariance.py
    ├── test_irrep_repr.py
    └── test_spherical_harmonics.py
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies and run tests with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: [3.8, 3.9]
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | python setup.py install
30 | python -m pip install pytest
31 | python -m pip install torch==1.10.0
32 | - name: Test with pytest
33 | run: |
34 | python setup.py pytest
35 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # custom don't-upload files
7 | custom_tests/*
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include se3_transformer_pytorch/data/J_dense.pt
2 | include se3_transformer_pytorch/data/J_dense.npy
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## SE3 Transformer - Pytorch
4 |
5 | Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.
6 |
7 | [](https://colab.research.google.com/drive/1ICW0DpXfUuVYsnNkt1DHwUyyTduHHvE3?usp=sharing) Example of equivariance
8 |
9 | If you were using any version of SE3 Transformers prior to version 0.6.0, please update. A huge bug was uncovered by @MattMcPartlon that affects anyone who was not using the adjacency sparse neighbors settings and was instead relying on the nearest neighbors functionality
10 |
11 | Update: It is recommended that you use Equiformer instead
12 |
13 | ## Install
14 |
15 | ```bash
16 | $ pip install se3-transformer-pytorch
17 | ```
18 |
19 | ## Usage
20 |
21 | ```python
22 | import torch
23 | from se3_transformer_pytorch import SE3Transformer
24 |
25 | model = SE3Transformer(
26 | dim = 512,
27 | heads = 8,
28 | depth = 6,
29 | dim_head = 64,
30 | num_degrees = 4,
31 | valid_radius = 10
32 | )
33 |
34 | feats = torch.randn(1, 1024, 512)
35 | coors = torch.randn(1, 1024, 3)
36 | mask = torch.ones(1, 1024).bool()
37 |
38 | out = model(feats, coors, mask) # (1, 1024, 512)
39 | ```
40 |
41 | Potential example usage in Alphafold2, as outlined here
42 |
43 | ```python
44 | import torch
45 | from se3_transformer_pytorch import SE3Transformer
46 |
47 | model = SE3Transformer(
48 | dim = 64,
49 | depth = 2,
50 | input_degrees = 1,
51 | num_degrees = 2,
52 | output_degrees = 2,
53 | reduce_dim_out = True,
54 | differentiable_coors = True
55 | )
56 |
57 | atom_feats = torch.randn(2, 32, 64)
58 | coors = torch.randn(2, 32, 3)
59 | mask = torch.ones(2, 32).bool()
60 |
61 | refined_coors = coors + model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)
62 | ```
63 |
64 | You can also let the base transformer class take care of embedding the type 0 features being passed in, assuming they are atom indices
65 |
66 | ```python
67 | import torch
68 | from se3_transformer_pytorch import SE3Transformer
69 |
70 | model = SE3Transformer(
71 | num_tokens = 28, # 28 unique atoms
72 | dim = 64,
73 | depth = 2,
74 | input_degrees = 1,
75 | num_degrees = 2,
76 | output_degrees = 2,
77 | reduce_dim_out = True
78 | )
79 |
80 | atoms = torch.randint(0, 28, (2, 32))
81 | coors = torch.randn(2, 32, 3)
82 | mask = torch.ones(2, 32).bool()
83 |
84 | refined_coors = coors + model(atoms, coors, mask, return_type = 1) # (2, 32, 3)
85 | ```
86 |
87 | If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass them in as follows.
88 |
89 | ```python
90 | import torch
91 | from se3_transformer_pytorch import SE3Transformer
92 |
93 | model = SE3Transformer(
94 | dim = 64,
95 | depth = 2,
96 | input_degrees = 2,
97 | num_degrees = 2,
98 | output_degrees = 2,
99 | reduce_dim_out = True # reduce out the final dimension
100 | )
101 |
102 | atom_feats = torch.randn(2, 32, 64, 1) # b x n x d x type0
103 | coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1
104 |
105 | # atom features are type 0, predicted coordinates are type 1
106 | features = {'0': atom_feats, '1': coors_feats}
107 | coors = torch.randn(2, 32, 3)
108 | mask = torch.ones(2, 32).bool()
109 |
110 | refined_coors = coors + model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates
111 | ```
112 |
113 | ## Edges
114 |
115 | To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization.
116 |
117 | ```python
118 | import torch
119 | from se3_transformer_pytorch import SE3Transformer
120 |
121 | model = SE3Transformer(
122 | num_tokens = 28,
123 | dim = 64,
124 |     num_edge_tokens = 4,         # number of edge types, say 4 bond types
125 | edge_dim = 16, # dimension of edge embedding
126 | depth = 2,
127 | input_degrees = 1,
128 | num_degrees = 3,
129 | output_degrees = 1,
130 | reduce_dim_out = True
131 | )
132 |
133 | atoms = torch.randint(0, 28, (2, 32))
134 | bonds = torch.randint(0, 4, (2, 32, 32))
135 | coors = torch.randn(2, 32, 3)
136 | mask = torch.ones(2, 32).bool()
137 |
138 | pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)
139 | ```
140 |
141 | If you would like to pass in continuous values for your edges, you can leave `num_edge_tokens` unset, embed your discrete bond types yourself, and then concatenate them to the Fourier features of these continuous values
142 |
143 | ```python
144 | import torch
145 | from se3_transformer_pytorch import SE3Transformer
146 | from se3_transformer_pytorch.utils import fourier_encode
147 |
148 | model = SE3Transformer(
149 | dim = 64,
150 | depth = 1,
151 | attend_self = True,
152 | num_degrees = 2,
153 | output_degrees = 2,
154 | edge_dim = 34 # edge dimension must match the final dimension of the edges being passed in
155 | )
156 |
157 | feats = torch.randn(1, 32, 64)
158 | coors = torch.randn(1, 32, 3)
159 | mask = torch.ones(1, 32).bool()
160 |
161 | pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2)) # say there are 2 continuous values per pair of nodes
162 |
163 | edges = fourier_encode(
164 | pairwise_continuous_values,
165 | num_encodings = 8,
166 | include_self = True
167 | ) # (1, 32, 32, 34) - {2 * (2 * 8 + 1)}
168 |
169 | out = model(feats, coors, mask, edges = edges, return_type = 1)
170 | ```
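To check the arithmetic on the trailing dimension: with `include_self = True`, each of the 2 continuous values contributes `2 * 8` sine / cosine features plus the raw value itself, so the final edge dimension is `2 * (2 * 8 + 1) = 34`, which is why `edge_dim = 34` above.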
171 |
172 | ## Sparse Neighbors
173 |
174 | If you know the connectivity of your points (say you are working with molecules), you can pass in an adjacency matrix, in the form of a boolean mask (where `True` indicates connectivity).
175 |
176 | ```python
177 | import torch
178 | from se3_transformer_pytorch import SE3Transformer
179 |
180 | model = SE3Transformer(
181 | dim = 32,
182 | heads = 8,
183 | depth = 1,
184 | dim_head = 64,
185 | num_degrees = 2,
186 | valid_radius = 10,
187 | attend_sparse_neighbors = True, # this must be set to true, in which case it will assert that you pass in the adjacency matrix
188 | num_neighbors = 0, # if you set this to 0, it will only consider the connected neighbors as defined by the adjacency matrix. but if you set a value greater than 0, it will continue to fetch the closest points up to this many, excluding the ones already specified by the adjacency matrix
189 | max_sparse_neighbors = 8 # you can cap the number of neighbors, sampled from within your sparse set of neighbors as defined by the adjacency matrix, if specified
190 | )
191 |
192 | feats = torch.randn(1, 128, 32)
193 | coors = torch.randn(1, 128, 3)
194 | mask = torch.ones(1, 128).bool()
195 |
196 | # placeholder adjacency matrix
197 | # naively assuming the sequence is one long chain (128, 128)
198 |
199 | i = torch.arange(128)
200 | adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
201 |
202 | out = model(feats, coors, mask, adj_mat = adj_mat) # (1, 128, 32)
203 | ```
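The two comparisons above keep exactly the entries where `|i - j| <= 1`, so each node is marked as adjacent to itself and to its immediate neighbors along the chain.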
204 |
205 | You can also have the network automatically derive for you the Nth-degree neighbors with one extra keyword `num_adj_degrees`. If you would like the system to differentiate between the degree of the neighbors as edge information, further pass in a non-zero `adj_dim`.
206 |
207 | ```python
208 | import torch
209 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
210 |
211 | model = SE3Transformer(
212 | dim = 64,
213 | depth = 1,
214 | attend_self = True,
215 | num_degrees = 2,
216 | output_degrees = 2,
217 | num_neighbors = 0,
218 | attend_sparse_neighbors = True,
219 | num_adj_degrees = 2, # automatically derive 2nd degree neighbors
220 | adj_dim = 4 # embed 1st and 2nd degree neighbors (as well as null neighbors) with edge embeddings of this dimension
221 | )
222 |
223 | feats = torch.randn(1, 32, 64)
224 | coors = torch.randn(1, 32, 3)
225 | mask = torch.ones(1, 32).bool()
226 |
227 | # placeholder adjacency matrix
228 | # naively assuming the sequence is one long chain (32, 32)
229 |
230 | i = torch.arange(32)
231 | adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
232 |
233 | out = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1)
234 | ```
235 |
236 | To have fine control over the dimensionality of each type, you can use the `hidden_fiber_dict` and `out_fiber_dict` keywords to pass in a dictionary with degrees as keys and dimensions as values.
237 |
238 | ```python
239 | import torch
240 | from se3_transformer_pytorch import SE3Transformer
241 |
242 | model = SE3Transformer(
243 | num_tokens = 28,
244 | dim = 64,
245 | num_edge_tokens = 4,
246 | edge_dim = 16,
247 | depth = 2,
248 | input_degrees = 1,
249 | num_degrees = 3,
250 | output_degrees = 1,
251 | hidden_fiber_dict = {0: 16, 1: 8, 2: 4},
252 | out_fiber_dict = {0: 16, 1: 1},
253 | reduce_dim_out = False
254 | )
255 |
256 | atoms = torch.randint(0, 28, (2, 32))
257 | bonds = torch.randint(0, 4, (2, 32, 32))
258 | coors = torch.randn(2, 32, 3)
259 | mask = torch.ones(2, 32).bool()
260 |
261 | pred = model(atoms, coors, mask, edges = bonds)
262 |
263 | pred['0'] # (2, 32, 16)
264 | pred['1'] # (2, 32, 1, 3)
265 | ```
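Note the trailing dimensions: a degree-`l` feature carries `2l + 1` components, so the single type-1 channel from `out_fiber_dict = {0: 16, 1: 1}` comes back with shape `(2, 32, 1, 3)`, while the 16 type-0 channels come back squeezed to `(2, 32, 16)`.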
266 |
267 | ## Neighbors
268 |
269 | You can further control which nodes can be considered by passing in a neighbor mask. All `False` values will be masked out of consideration.
270 |
271 | ```python
272 | import torch
273 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
274 |
275 | model = SE3Transformer(
276 | dim = 16,
277 | dim_head = 16,
278 | attend_self = True,
279 | num_degrees = 4,
280 | output_degrees = 2,
281 | num_edge_tokens = 4,
282 | num_neighbors = 8, # make sure you set this value as the maximum number of neighbors set by your neighbor_mask, or it will throw a warning
283 | edge_dim = 2,
284 | depth = 3
285 | )
286 |
287 | feats = torch.randn(1, 32, 16)
288 | coors = torch.randn(1, 32, 3)
289 | mask = torch.ones(1, 32).bool()
290 | bonds = torch.randint(0, 4, (1, 32, 32))
291 |
292 | neighbor_mask = torch.ones(1, 32, 32).bool() # set the nodes you wish to be masked out as False
293 |
294 | out = model(
295 | feats,
296 | coors,
297 | mask,
298 | edges = bonds,
299 | neighbor_mask = neighbor_mask,
300 | return_type = 1
301 | )
302 | ```
303 |
304 | ## Global Nodes
305 |
306 | This feature allows you to pass in vectors that can be viewed as global nodes that are seen by all other nodes. The idea would be to pool your graph into a few feature vectors, which will be projected to key / values across all the attention layers in the network. All nodes will have full access to global node information, regardless of nearest neighbors or adjacency calculation.
307 |
308 | ```python
309 | import torch
310 | from torch import nn
311 | from se3_transformer_pytorch import SE3Transformer
312 |
313 | model = SE3Transformer(
314 | dim = 64,
315 | depth = 1,
316 | num_degrees = 2,
317 | num_neighbors = 4,
318 | valid_radius = 10,
319 | global_feats_dim = 32 # this must be set to the dimension of the global features, in this example, 32
320 | )
321 |
322 | feats = torch.randn(1, 32, 64)
323 | coors = torch.randn(1, 32, 3)
324 | mask = torch.ones(1, 32).bool()
325 |
326 | # naively derive global features
327 | # by pooling features and projecting
328 | global_feats = nn.Linear(64, 32)(feats.mean(dim = 1, keepdim = True)) # (1, 1, 32)
329 |
330 | out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)
331 | ```
332 |
333 | Todo:
334 |
335 | - [ ] allow global nodes to attend to all other nodes, to give the network a global conduit for information. (Similar to BigBird, ETC, Longformer etc)
336 |
337 | ## Autoregressive
338 |
339 | You can use SE3 Transformers autoregressively with just one extra flag
340 |
341 | ```python
342 | import torch
343 | from se3_transformer_pytorch import SE3Transformer
344 |
345 | model = SE3Transformer(
346 | dim = 512,
347 | heads = 8,
348 | depth = 6,
349 | dim_head = 64,
350 | num_degrees = 4,
351 | valid_radius = 10,
352 | causal = True # set this to True
353 | )
354 |
355 | feats = torch.randn(1, 1024, 512)
356 | coors = torch.randn(1, 1024, 3)
357 | mask = torch.ones(1, 1024).bool()
358 |
359 | out = model(feats, coors, mask) # (1, 1024, 512)
360 | ```
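Setting `causal = True` masks attention so that each node only attends to nodes that come earlier in the sequence ordering, on top of the usual neighbor selection.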
361 |
362 | ## Experimental Features
363 |
364 | ### Non-pairwise convolved keys
365 |
366 | I've discovered that using linearly projected keys (rather than the pairwise convolution) seems to do ok in a toy denoising task. This leads to 25% memory savings. You can try this feature by setting `linear_proj_keys = True`
367 |
368 | ```python
369 | import torch
370 | from se3_transformer_pytorch import SE3Transformer
371 |
372 | model = SE3Transformer(
373 | dim = 64,
374 | depth = 1,
375 | num_degrees = 4,
376 | num_neighbors = 8,
377 | valid_radius = 10,
378 | splits = 4,
379 | linear_proj_keys = True # set this to True
380 | ).cuda()
381 |
382 | feats = torch.randn(1, 32, 64).cuda()
383 | coors = torch.randn(1, 32, 3).cuda()
384 | mask = torch.ones(1, 32).bool().cuda()
385 |
386 | out = model(feats, coors, mask, return_type = 0)
387 | ```
388 |
389 | ### Shared key / values across all heads
390 |
391 | There is a relatively unknown technique for transformers where one can share one key / value head across all the heads of the queries. In my experience in NLP, this usually leads to worse performance, but if you really need to trade off memory for more depth or a higher number of degrees, this may be a good option.
392 |
393 | ```python
394 | import torch
395 | from se3_transformer_pytorch import SE3Transformer
396 |
397 | model = SE3Transformer(
398 | dim = 64,
399 | depth = 8,
400 | num_degrees = 4,
401 | num_neighbors = 8,
402 | valid_radius = 10,
403 | splits = 4,
404 | one_headed_key_values = True # one head of key / values shared across all heads of the queries
405 | ).cuda()
406 |
407 | feats = torch.randn(1, 32, 64).cuda()
408 | coors = torch.randn(1, 32, 3).cuda()
409 | mask = torch.ones(1, 32).bool().cuda()
410 |
411 | out = model(feats, coors, mask, return_type = 0)
412 | ```
413 |
414 | ### Tied key / values
415 |
416 | You can also tie the keys and values (have them be the same), halving the memory needed for them
417 |
418 | ```python
419 | import torch
420 | from se3_transformer_pytorch import SE3Transformer
421 |
422 | model = SE3Transformer(
423 | dim = 64,
424 | depth = 8,
425 | num_degrees = 4,
426 | num_neighbors = 8,
427 | valid_radius = 10,
428 | splits = 4,
429 | tie_key_values = True # set this to True
430 | ).cuda()
431 |
432 | feats = torch.randn(1, 32, 64).cuda()
433 | coors = torch.randn(1, 32, 3).cuda()
434 | mask = torch.ones(1, 32).bool().cuda()
435 |
436 | out = model(feats, coors, mask, return_type = 0)
437 | ```
438 |
439 | ### Using EGNN
440 |
441 | This is an experimental version of EGNN that works for higher types, and greater dimensionality than just 1 (for the coordinates). The class name is still `SE3Transformer` since it reuses some preexisting logic, so just ignore that for now until I clean it up later.
442 |
443 | ```python
444 | import torch
445 | from se3_transformer_pytorch import SE3Transformer
446 |
447 | model = SE3Transformer(
448 | dim = 32,
449 | num_neighbors = 8,
450 | num_edge_tokens = 4,
451 | edge_dim = 4,
452 | num_degrees = 4, # number of higher order types - will use basis on a TCN to project to these dimensions
453 | use_egnn = True, # set this to true to use EGNN instead of equivariant attention layers
454 | egnn_hidden_dim = 64, # egnn hidden dimension
455 | depth = 4, # depth of EGNN
456 | reduce_dim_out = True # will project the dimension of the higher types to 1
457 | ).cuda()
458 |
459 | feats = torch.randn(2, 32, 32).cuda()
460 | coors = torch.randn(2, 32, 3).cuda()
461 | bonds = torch.randint(0, 4, (2, 32, 32)).cuda()
462 | mask = torch.ones(2, 32).bool().cuda()
463 |
464 | refinement = model(feats, coors, mask, edges = bonds, return_type = 1) # (2, 32, 3)
465 |
466 | coors = coors + refinement # update coors with refinement
467 | ```
468 |
469 | If you would like to specify individual dimensions for each of the higher types, just pass in `hidden_fiber_dict` where the dictionary is in the format {<degree>: <dim>} instead of `num_degrees`
470 |
471 | ```python
472 | import torch
473 | from se3_transformer_pytorch import SE3Transformer
474 |
475 | model = SE3Transformer(
476 | dim = 32,
477 | num_neighbors = 8,
478 | hidden_fiber_dict = {0: 32, 1: 16, 2: 8, 3: 4},
479 | use_egnn = True,
480 | depth = 4,
481 | egnn_hidden_dim = 64,
482 | egnn_weights_clamp_value = 2,
483 | reduce_dim_out = True
484 | ).cuda()
485 |
486 | feats = torch.randn(2, 32, 32).cuda()
487 | coors = torch.randn(2, 32, 3).cuda()
488 | mask = torch.ones(2, 32).bool().cuda()
489 |
490 | refinement = model(feats, coors, mask, return_type = 1) # (2, 32, 3)
491 |
492 | coors = coors + refinement # update coors with refinement
493 | ```
494 |
495 | ## Scaling (wip)
496 |
497 | This section will list ongoing efforts to make SE3 Transformer scale a little better.
498 |
499 | Firstly, I have added reversible networks, which recompute activations during the backward pass instead of storing them. This allows me to add a little more depth before hitting the usual memory roadblocks. Equivariance preservation is demonstrated in the tests.
500 |
501 | ```python
502 | import torch
503 | from se3_transformer_pytorch import SE3Transformer
504 |
505 | model = SE3Transformer(
506 | num_tokens = 20,
507 | dim = 32,
508 | dim_head = 32,
509 | heads = 4,
510 | depth = 12, # 12 layers
511 | input_degrees = 1,
512 | num_degrees = 3,
513 | output_degrees = 1,
514 | reduce_dim_out = True,
515 | reversible = True # set reversible to True
516 | ).cuda()
517 |
518 | atoms = torch.randint(0, 4, (2, 32)).cuda()
519 | coors = torch.randn(2, 32, 3).cuda()
520 | mask = torch.ones(2, 32).bool().cuda()
521 |
522 | pred = model(atoms, coors, mask = mask, return_type = 0)
523 |
524 | loss = pred.sum()
525 | loss.backward()
526 | ```
527 |
528 | ## Examples
529 |
530 | First install `sidechainnet`
531 |
532 | ```bash
533 | $ pip install sidechainnet
534 | ```
535 |
536 | Then run the protein backbone denoising task
537 |
538 | ```bash
539 | $ python denoise.py
540 | ```
541 |
542 | ## Caching
543 |
544 | By default, the basis vectors are cached. However, if there is ever the need to clear the cache, you simply have to set the environment variable `CLEAR_CACHE` to some value when invoking the script
545 |
546 | ```bash
547 | $ CLEAR_CACHE=1 python train.py
548 | ```
549 |
550 | Or you can try deleting the cache directory, which should exist at
551 |
552 | ```bash
553 | $ rm -rf ~/.cache.equivariant_attention
554 | ```
555 |
556 | You can also designate your own directory where you want the caches to be stored, in case the default directory has permission issues
557 |
558 | ```bash
559 | CACHE_PATH=./path/to/my/cache python train.py
560 | ```
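For reference, here is a sketch of how the cache location is resolved at the top of `se3_transformer_pytorch/basis.py` (only the two environment variables above are consulted):

```python
import os

# falls back to ~/.cache.equivariant_attention when CACHE_PATH is unset,
# and disables caching entirely when CLEAR_CACHE is set to any value
CACHE_PATH = os.getenv('CACHE_PATH', os.path.expanduser('~/.cache.equivariant_attention'))
CACHE_PATH = CACHE_PATH if os.environ.get('CLEAR_CACHE') is None else None
```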
561 |
562 | ## Testing
563 |
564 | ```bash
565 | $ python setup.py pytest
566 | ```
567 |
568 | ## Credit
569 |
570 | This library is largely a port of Fabian's official repository, but without the DGL library.
571 |
572 | ## Citations
573 |
574 | ```bibtex
575 | @misc{fuchs2020se3transformers,
576 | title = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks},
577 | author = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
578 | year = {2020},
579 | eprint = {2006.10503},
580 | archivePrefix = {arXiv},
581 | primaryClass = {cs.LG}
582 | }
583 | ```
584 |
585 | ```bibtex
586 | @misc{satorras2021en,
587 | title = {E(n) Equivariant Graph Neural Networks},
588 | author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
589 | year = {2021},
590 | eprint = {2102.09844},
591 | archivePrefix = {arXiv},
592 | primaryClass = {cs.LG}
593 | }
594 | ```
595 |
596 | ```bibtex
597 | @misc{gomez2017reversible,
598 | title = {The Reversible Residual Network: Backpropagation Without Storing Activations},
599 | author = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},
600 | year = {2017},
601 | eprint = {1707.04585},
602 | archivePrefix = {arXiv},
603 | primaryClass = {cs.CV}
604 | }
605 | ```
606 |
607 | ```bibtex
608 | @misc{shazeer2019fast,
609 | title = {Fast Transformer Decoding: One Write-Head is All You Need},
610 | author = {Noam Shazeer},
611 | year = {2019},
612 | eprint = {1911.02150},
613 | archivePrefix = {arXiv},
614 | primaryClass = {cs.NE}
615 | }
616 | ```
617 |
--------------------------------------------------------------------------------
/denoise.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.optim import Adam
4 |
5 | from einops import rearrange, repeat
6 |
7 | import sidechainnet as scn
8 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
9 |
10 | torch.set_default_dtype(torch.float64)
11 |
12 | BATCH_SIZE = 1
13 | GRADIENT_ACCUMULATE_EVERY = 16
14 |
15 | def cycle(loader, len_thres = 500):
16 | while True:
17 | for data in loader:
18 | if data.seqs.shape[1] > len_thres:
19 | continue
20 | yield data
21 |
22 | transformer = SE3Transformer(
23 | num_tokens = 24,
24 | dim = 8,
25 | dim_head = 8,
26 | heads = 2,
27 | depth = 2,
28 | attend_self = True,
29 | input_degrees = 1,
30 | output_degrees = 2,
31 | reduce_dim_out = True,
32 | differentiable_coors = True,
33 | num_neighbors = 0,
34 | attend_sparse_neighbors = True,
35 | num_adj_degrees = 2,
36 | adj_dim = 4,
37 | num_degrees=2,
38 | )
39 |
40 | data = scn.load(
41 | casp_version = 12,
42 | thinning = 30,
43 | with_pytorch = 'dataloaders',
44 | batch_size = BATCH_SIZE,
45 | dynamic_batching = False
46 | )
47 | # Add gaussian noise to the coords
48 | # Testing the refinement algorithm
49 |
50 | dl = cycle(data['train'])
51 | optim = Adam(transformer.parameters(), lr=1e-4)
52 | transformer = transformer.cuda()
53 |
54 | for _ in range(10000):
55 | for _ in range(GRADIENT_ACCUMULATE_EVERY):
56 | batch = next(dl)
57 | seqs, coords, masks = batch.seqs, batch.crds, batch.msks
58 |
59 | seqs = seqs.cuda().argmax(dim = -1)
60 | coords = coords.cuda().type(torch.float64)
61 | masks = masks.cuda().bool()
62 |
63 | l = seqs.shape[1]
64 | coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)
65 |
66 | # Keeping only the backbone coordinates
67 | coords = coords[:, :, 0:3, :]
68 | coords = rearrange(coords, 'b l s c -> b (l s) c')
69 |
70 | seq = repeat(seqs, 'b n -> b (n c)', c = 3)
71 | masks = repeat(masks, 'b n -> b (n c)', c = 3)
72 |
73 | noised_coords = coords + torch.randn_like(coords).cuda()
74 |
75 | i = torch.arange(seq.shape[-1], device = seqs.device)
76 | adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
77 |
78 | out = transformer(
79 | seq,
80 | noised_coords,
81 | mask = masks,
82 | adj_mat = adj_mat,
83 | return_type = 1
84 | )
85 |
86 | denoised_coords = noised_coords + out
87 |
88 | loss = F.mse_loss(denoised_coords[masks], coords[masks])
89 | (loss / GRADIENT_ACCUMULATE_EVERY).backward()
90 |
91 | print('loss:', loss.item())
92 | optim.step()
93 | optim.zero_grad()
94 |
--------------------------------------------------------------------------------
/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/se3-transformer-pytorch/e1669ee69345c271b7aa0a3aeeda452d32e26736/diagram.png
--------------------------------------------------------------------------------
/se3_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
2 |
--------------------------------------------------------------------------------
/se3_transformer_pytorch/basis.py:
--------------------------------------------------------------------------------
1 | import os
2 | from math import pi
3 | import torch
4 | from torch import einsum
5 | from einops import rearrange
6 | from itertools import product
7 | from contextlib import contextmanager
8 |
9 | from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
10 | from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, default, to_order
11 | from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache
12 |
13 | # constants
14 |
15 | CACHE_PATH = default(os.getenv('CACHE_PATH'), os.path.expanduser('~/.cache.equivariant_attention'))
16 | CACHE_PATH = CACHE_PATH if not exists(os.environ.get('CLEAR_CACHE')) else None
17 |
18 | # todo (figure out why this was hard coded in official repo)
19 |
20 | RANDOM_ANGLES = [
21 | [4.41301023, 5.56684102, 4.59384642],
22 | [4.93325116, 6.12697327, 4.14574096],
23 | [0.53878964, 4.09050444, 5.36539036],
24 | [2.16017393, 3.48835314, 5.55174441],
25 | [2.52385107, 0.2908958, 3.90040975]
26 | ]
27 |
28 | # helpers
29 |
30 | @contextmanager
31 | def null_context():
32 | yield
33 |
34 | # functions
35 |
36 | def get_matrix_kernel(A, eps = 1e-10):
37 | '''
38 | Compute an orthonormal basis of the kernel (x_1, x_2, ...)
39 | A x_i = 0
40 | scalar_product(x_i, x_j) = delta_ij
41 |
42 | :param A: matrix
43 | :return: matrix where each row is a basis vector of the kernel of A
44 | '''
45 | _u, s, v = torch.svd(A)
46 | kernel = v.t()[s < eps]
47 | return kernel
48 |
49 |
50 | def get_matrices_kernel(As, eps = 1e-10):
51 | '''
52 | Computes the common kernel of all the As matrices
53 | '''
54 | matrix = torch.cat(As, dim=0)
55 | return get_matrix_kernel(matrix, eps)
56 |
57 | def get_spherical_from_cartesian(cartesian, divide_radius_by = 1.0):
58 | """
59 | # ON ANGLE CONVENTION
60 | #
61 | # sh has following convention for angles:
62 | # :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
63 | # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
64 | #
65 | # the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
66 | # beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1)).
67 | # alpha = phi
68 | #
69 | """
70 | # initialise return array
71 | spherical = torch.zeros_like(cartesian)
72 |
73 | # indices for return array
74 | ind_radius, ind_alpha, ind_beta = 0, 1, 2
75 |
76 | cartesian_x, cartesian_y, cartesian_z = 2, 0, 1
77 |
78 | # get projected radius in xy plane
79 | r_xy = cartesian[..., cartesian_x] ** 2 + cartesian[..., cartesian_y] ** 2
80 |
81 | # get second angle
82 | # version 'elevation angle defined from Z-axis down'
83 | spherical[..., ind_beta] = torch.atan2(torch.sqrt(r_xy), cartesian[..., cartesian_z])
84 |
85 | # get angle in x-y plane
86 | spherical[...,ind_alpha] = torch.atan2(cartesian[...,cartesian_y], cartesian[...,cartesian_x])
87 |
88 | # get overall radius
89 | radius = torch.sqrt(r_xy + cartesian[...,cartesian_z]**2)
90 |
91 | if divide_radius_by != 1.0:
92 | radius /= divide_radius_by
93 |
94 | spherical[..., ind_radius] = radius
95 | return spherical
96 |
97 | def kron(a, b):
98 | """
99 | A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk
100 |
101 | Kronecker product of matrices a and b with leading batch dimensions.
102 |     Batch dimensions are broadcast.
103 | :type a: torch.Tensor
104 | :type b: torch.Tensor
105 | :rtype: torch.Tensor
106 | """
107 | res = einsum('... i j, ... k l -> ... i k j l', a, b)
108 | return rearrange(res, '... i j k l -> ... (i j) (k l)')
109 |
110 | def get_R_tensor(order_out, order_in, a, b, c):
111 | return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))
112 |
113 | def sylvester_submatrix(order_out, order_in, J, a, b, c):
114 | ''' generate Kronecker product matrix for solving the Sylvester equation in subspace J '''
115 | R_tensor = get_R_tensor(order_out, order_in, a, b, c) # [m_out * m_in, m_out * m_in]
116 | R_irrep_J = irr_repr(J, a, b, c) # [m, m]
117 |
118 | R_tensor_identity = torch.eye(R_tensor.shape[0])
119 | R_irrep_J_identity = torch.eye(R_irrep_J.shape[0])
120 |
121 | return kron(R_tensor, R_irrep_J_identity) - kron(R_tensor_identity, R_irrep_J.t()) # [(m_out * m_in) * m, (m_out * m_in) * m]
122 |
123 | @cache_dir(CACHE_PATH)
124 | @torch_default_dtype(torch.float64)
125 | @torch.no_grad()
126 | def basis_transformation_Q_J(J, order_in, order_out, random_angles = RANDOM_ANGLES):
127 | """
128 | :param J: order of the spherical harmonics
129 | :param order_in: order of the input representation
130 | :param order_out: order of the output representation
131 | :return: one part of the Q^-1 matrix of the article
132 | """
133 | sylvester_submatrices = [sylvester_submatrix(order_out, order_in, J, a, b, c) for a, b, c in random_angles]
134 | null_space = get_matrices_kernel(sylvester_submatrices)
135 | assert null_space.size(0) == 1, null_space.size() # unique subspace solution
136 | Q_J = null_space[0] # [(m_out * m_in) * m]
137 | Q_J = Q_J.view(to_order(order_out) * to_order(order_in), to_order(J)) # [m_out * m_in, m]
138 | return Q_J.float() # [m_out * m_in, m]
139 |
140 | def precompute_sh(r_ij, max_J):
141 | """
142 |     pre-compute spherical harmonics up to order max_J
143 |
144 | :param r_ij: relative positions
145 | :param max_J: maximum order used in entire network
146 | :return: dict where each entry has shape [B,N,K,2J+1]
147 | """
148 | i_alpha, i_beta = 1, 2
149 | Y_Js = {J: spherical_harmonics(J, r_ij[...,i_alpha], r_ij[...,i_beta]) for J in range(max_J + 1)}
150 | clear_spherical_harmonics_cache()
151 | return Y_Js
152 |
153 | def get_basis(r_ij, max_degree, differentiable = False):
154 | """Return equivariant weight basis (basis)
155 |
156 | Call this function *once* at the start of each forward pass of the model.
157 | It computes the equivariant weight basis, W_J^lk(x), and internodal
158 | distances, needed to compute varphi_J^lk(x), of eqn 8 of
159 | https://arxiv.org/pdf/2006.10503.pdf. The return values of this function
160 | can be shared as input across all SE(3)-Transformer layers in a model.
161 |
162 | Args:
163 | r_ij: relative positional vectors
164 | max_degree: non-negative int for degree of highest feature-type
165 | differentiable: whether r_ij should receive gradients from basis
166 | Returns:
167 |         dict of equivariant bases, with keys of the form '<d_in>,<d_out>'
168 | """
169 |
170 | # Relative positional encodings (vector)
171 |     context = null_context if differentiable else torch.no_grad
172 |
173 | device, dtype = r_ij.device, r_ij.dtype
174 |
175 | with context():
176 | r_ij = get_spherical_from_cartesian(r_ij)
177 |
178 | # Spherical harmonic basis
179 | Y = precompute_sh(r_ij, 2 * max_degree)
180 |
181 |         # Equivariant basis (dict['<d_in>,<d_out>'])
182 |
183 | basis = {}
184 | for d_in, d_out in product(range(max_degree+1), range(max_degree+1)):
185 | K_Js = []
186 | for J in range(abs(d_in - d_out), d_in + d_out + 1):
187 | # Get spherical harmonic projection matrices
188 | Q_J = basis_transformation_Q_J(J, d_in, d_out)
189 | Q_J = Q_J.type(dtype).to(device)
190 |
191 | # Create kernel from spherical harmonics
192 | K_J = torch.matmul(Y[J], Q_J.T)
193 | K_Js.append(K_J)
194 |
195 | # Reshape so can take linear combinations with a dot product
196 | K_Js = torch.stack(K_Js, dim = -1)
197 | size = (*r_ij.shape[:-1], 1, to_order(d_out), 1, to_order(d_in), to_order(min(d_in,d_out)))
198 | basis[f'{d_in},{d_out}'] = K_Js.view(*size)
199 |
200 | # extra detach for safe measure
201 | if not differentiable:
202 | for k, v in basis.items():
203 | basis[k] = v.detach()
204 |
205 | return basis
206 |
--------------------------------------------------------------------------------
/se3_transformer_pytorch/data/J_dense.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/se3-transformer-pytorch/e1669ee69345c271b7aa0a3aeeda452d32e26736/se3_transformer_pytorch/data/J_dense.npy
--------------------------------------------------------------------------------
/se3_transformer_pytorch/data/J_dense.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/se3-transformer-pytorch/e1669ee69345c271b7aa0a3aeeda452d32e26736/se3_transformer_pytorch/data/J_dense.pt
--------------------------------------------------------------------------------
/se3_transformer_pytorch/irr_repr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch import sin, cos, atan2, acos
5 | from math import pi
6 | from pathlib import Path
7 | from functools import wraps
8 |
9 | from se3_transformer_pytorch.utils import exists, default, cast_torch_tensor, to_order
10 | from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics, clear_spherical_harmonics_cache
11 |
12 | DATA_PATH = Path(os.path.dirname(__file__)) / 'data'
13 |
14 | try:
15 | path = DATA_PATH / 'J_dense.pt'
16 | Jd = torch.load(str(path))
17 | except Exception:
18 | path = DATA_PATH / 'J_dense.npy'
19 | Jd_np = np.load(str(path), allow_pickle = True)
20 | Jd = list(map(torch.from_numpy, Jd_np))
21 |
22 | def wigner_d_matrix(degree, alpha, beta, gamma, dtype = None, device = None):
23 | """Create wigner D matrices for batch of ZYZ Euler anglers for degree l."""
24 | J = Jd[degree].type(dtype).to(device)
25 | order = to_order(degree)
26 | x_a = z_rot_mat(alpha, degree)
27 | x_b = z_rot_mat(beta, degree)
28 | x_c = z_rot_mat(gamma, degree)
29 | res = x_a @ J @ x_b @ J @ x_c
30 | return res.view(order, order)
31 |
32 | def z_rot_mat(angle, l):
33 | device, dtype = angle.device, angle.dtype
34 | order = to_order(l)
35 | m = angle.new_zeros((order, order))
36 | inds = torch.arange(0, order, 1, dtype=torch.long, device=device)
37 | reversed_inds = torch.arange(2 * l, -1, -1, dtype=torch.long, device=device)
38 | frequencies = torch.arange(l, -l - 1, -1, dtype=dtype, device=device)[None]
39 |
40 | m[inds, reversed_inds] = sin(frequencies * angle[None])
41 | m[inds, inds] = cos(frequencies * angle[None])
42 | return m
43 |
44 | def irr_repr(order, alpha, beta, gamma, dtype = None):
45 | """
46 | irreducible representation of SO3
47 | - compatible with compose and spherical_harmonics
48 | """
49 | cast_ = cast_torch_tensor(lambda t: t)
50 | dtype = default(dtype, torch.get_default_dtype())
51 | alpha, beta, gamma = map(cast_, (alpha, beta, gamma))
52 | return wigner_d_matrix(order, alpha, beta, gamma, dtype = dtype)
53 |
54 | @cast_torch_tensor
55 | def rot_z(gamma):
56 | '''
57 | Rotation around Z axis
58 | '''
59 | return torch.tensor([
60 | [cos(gamma), -sin(gamma), 0],
61 | [sin(gamma), cos(gamma), 0],
62 | [0, 0, 1]
63 | ], dtype=gamma.dtype)
64 |
65 | @cast_torch_tensor
66 | def rot_y(beta):
67 | '''
68 | Rotation around Y axis
69 | '''
70 | return torch.tensor([
71 | [cos(beta), 0, sin(beta)],
72 | [0, 1, 0],
73 | [-sin(beta), 0, cos(beta)]
74 | ], dtype=beta.dtype)
75 |
76 | @cast_torch_tensor
77 | def x_to_alpha_beta(x):
78 | '''
79 | Convert point (x, y, z) on the sphere into (alpha, beta)
80 | '''
81 | x = x / torch.norm(x)
82 | beta = acos(x[2])
83 | alpha = atan2(x[1], x[0])
84 | return (alpha, beta)
85 |
86 | def rot(alpha, beta, gamma):
87 | '''
88 | ZYZ Euler angles rotation
89 | '''
90 | return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
91 |
92 | def compose(a1, b1, c1, a2, b2, c2):
93 | """
94 | (a, b, c) = (a1, b1, c1) composed with (a2, b2, c2)
95 | """
96 | comp = rot(a1, b1, c1) @ rot(a2, b2, c2)
97 | xyz = comp @ torch.tensor([0, 0, 1.])
98 | a, b = x_to_alpha_beta(xyz)
99 | rotz = rot(0, -b, -a) @ comp
100 | c = atan2(rotz[1, 0], rotz[0, 0])
101 | return a, b, c
102 |
103 | def spherical_harmonics(order, alpha, beta, dtype = None):
104 | return get_spherical_harmonics(order, theta = (pi - beta), phi = alpha)
105 |
--------------------------------------------------------------------------------
/se3_transformer_pytorch/reversible.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd.function import Function
4 | from torch.utils.checkpoint import get_device_states, set_device_states
5 |
6 | # helpers
7 |
8 | def map_values(fn, x):
9 | out = {}
10 | for (k, v) in x.items():
11 | out[k] = fn(v)
12 | return out
13 |
14 | def dict_chunk(x, chunks, dim):
15 | out1 = {}
16 | out2 = {}
17 | for (k, v) in x.items():
18 | c1, c2 = v.chunk(chunks, dim = dim)
19 | out1[k] = c1
20 | out2[k] = c2
21 | return out1, out2
22 |
23 | def dict_sum(x, y):
24 | out = {}
25 | for k in x.keys():
26 | out[k] = x[k] + y[k]
27 | return out
28 |
29 | def dict_subtract(x, y):
30 | out = {}
31 | for k in x.keys():
32 | out[k] = x[k] - y[k]
33 | return out
34 |
35 | def dict_cat(x, y, dim):
36 | out = {}
37 | for k, v1 in x.items():
38 | v2 = y[k]
39 | out[k] = torch.cat((v1, v2), dim = dim)
40 | return out
41 |
42 | def dict_set_(x, key, value):
43 | for k, v in x.items():
44 | setattr(v, key, value)
45 |
46 | def dict_backwards_(outputs, grad_tensors):
47 | for k, v in outputs.items():
48 | torch.autograd.backward(v, grad_tensors[k], retain_graph = True)
49 |
50 | def dict_del_(x):
51 | for k, v in x.items():
52 | del v
53 | del x
54 |
55 | def values(d):
56 | return [v for _, v in d.items()]
57 |
58 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
59 | class Deterministic(nn.Module):
60 | def __init__(self, net):
61 | super().__init__()
62 | self.net = net
63 | self.cpu_state = None
64 | self.cuda_in_fwd = None
65 | self.gpu_devices = None
66 | self.gpu_states = None
67 |
68 | def record_rng(self, *args):
69 | self.cpu_state = torch.get_rng_state()
70 | if torch.cuda._initialized:
71 | self.cuda_in_fwd = True
72 | self.gpu_devices, self.gpu_states = get_device_states(*args)
73 |
74 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
75 | if record_rng:
76 | self.record_rng(*args)
77 |
78 | if not set_rng:
79 | return self.net(*args, **kwargs)
80 |
81 | rng_devices = []
82 | if self.cuda_in_fwd:
83 | rng_devices = self.gpu_devices
84 |
85 | with torch.random.fork_rng(devices=rng_devices, enabled=True):
86 | torch.set_rng_state(self.cpu_state)
87 | if self.cuda_in_fwd:
88 | set_device_states(self.gpu_devices, self.gpu_states)
89 | return self.net(*args, **kwargs)
90 |
91 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
92 | # once multi-GPU is confirmed working, refactor and send PR back to source
93 | class ReversibleBlock(nn.Module):
94 | def __init__(self, f, g):
95 | super().__init__()
96 | self.f = Deterministic(f)
97 | self.g = Deterministic(g)
98 |
99 | def forward(self, x, **kwargs):
100 | training = self.training
101 | x1, x2 = dict_chunk(x, 2, dim = -1)
102 | y1, y2 = None, None
103 |
104 | with torch.no_grad():
105 | y1 = dict_sum(x1, self.f(x2, record_rng = training, **kwargs))
106 | y2 = dict_sum(x2, self.g(y1, record_rng = training))
107 |
108 | return dict_cat(y1, y2, dim = -1)
109 |
110 | def backward_pass(self, y, dy, **kwargs):
111 | y1, y2 = dict_chunk(y, 2, dim = -1)
112 | dict_del_(y)
113 |
114 | dy1, dy2 = dict_chunk(dy, 2, dim = -1)
115 | dict_del_(dy)
116 |
117 | with torch.enable_grad():
118 | dict_set_(y1, 'requires_grad', True)
119 | gy1 = self.g(y1, set_rng = True)
120 | dict_backwards_(gy1, dy2)
121 |
122 | with torch.no_grad():
123 | x2 = dict_subtract(y2, gy1)
124 | dict_del_(y2)
125 | dict_del_(gy1)
126 |
127 | dx1 = dict_sum(dy1, map_values(lambda t: t.grad, y1))
128 | dict_del_(dy1)
129 | dict_set_(y1, 'grad', None)
130 |
131 | with torch.enable_grad():
132 | dict_set_(x2, 'requires_grad', True)
133 | fx2 = self.f(x2, set_rng = True, **kwargs)
134 | dict_backwards_(fx2, dx1)
135 |
136 | with torch.no_grad():
137 | x1 = dict_subtract(y1, fx2)
138 | dict_del_(y1)
139 | dict_del_(fx2)
140 |
141 | dx2 = dict_sum(dy2, map_values(lambda t: t.grad, x2))
142 | dict_del_(dy2)
143 | dict_set_(x2, 'grad', None)
144 |
145 | x2 = map_values(lambda t: t.detach(), x2)
146 |
147 | x = dict_cat(x1, x2, dim = -1)
148 | dx = dict_cat(dx1, dx2, dim = -1)
149 |
150 | return x, dx
151 |
152 | class _ReversibleFunction(Function):
153 | @staticmethod
154 | def forward(ctx, x, blocks, kwargs):
155 | input_keys = kwargs.pop('input_keys')
156 | split_dims = kwargs.pop('split_dims')
157 | input_values = x.split(split_dims, dim = -1)
158 | x = dict(zip(input_keys, input_values))
159 |
160 | ctx.kwargs = kwargs
161 | ctx.split_dims = split_dims
162 | ctx.input_keys = input_keys
163 |
164 | for block in blocks:
165 | x = block(x, **kwargs)
166 |
167 | ctx.y = map_values(lambda t: t.detach(), x)
168 | ctx.blocks = blocks
169 |
170 | x = torch.cat(values(x), dim = -1)
171 | return x
172 |
173 | @staticmethod
174 | def backward(ctx, dy):
175 | y = ctx.y
176 | kwargs = ctx.kwargs
177 | input_keys = ctx.input_keys
178 | split_dims = ctx.split_dims
179 |
180 | dy = dy.split(split_dims, dim = -1)
181 | dy = dict(zip(input_keys, dy))
182 |
183 | for block in ctx.blocks[::-1]:
184 | y, dy = block.backward_pass(y, dy, **kwargs)
185 |
186 | dy = torch.cat(values(dy), dim = -1)
187 | return dy, None, None
188 |
189 | class SequentialSequence(nn.Module):
190 | def __init__(self, blocks):
191 | super().__init__()
192 | self.blocks = blocks
193 |
194 | def forward(self, x, **kwargs):
195 | for (attn, ff) in self.blocks:
196 | x = attn(x, **kwargs)
197 | x = ff(x)
198 | return x
199 |
200 | class ReversibleSequence(nn.Module):
201 | def __init__(self, blocks):
202 | super().__init__()
203 | self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])
204 |
205 | def forward(self, x, **kwargs):
206 | blocks = self.blocks
207 |
208 | x = map_values(lambda t: torch.cat((t, t), dim = -1), x)
209 |
210 | input_keys = x.keys()
211 | split_dims = tuple(map(lambda t: t.shape[-1], x.values()))
212 | block_kwargs = {'input_keys': input_keys, 'split_dims': split_dims, **kwargs}
213 |
214 | x = torch.cat(values(x), dim = -1)
215 |
216 | x = _ReversibleFunction.apply(x, blocks, block_kwargs)
217 |
218 | x = dict(zip(input_keys, x.split(split_dims, dim = -1)))
219 | x = map_values(lambda t: torch.stack(t.chunk(2, dim = -1)).mean(dim = 0), x)
220 | return x
221 |
--------------------------------------------------------------------------------
/se3_transformer_pytorch/rotary.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange, repeat
4 |
5 | class SinusoidalEmbeddings(nn.Module):
6 | def __init__(self, dim):
7 | super().__init__()
8 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
9 | self.register_buffer('inv_freq', inv_freq)
10 |
11 | def forward(self, t):
12 | freqs = t[..., None].float() * self.inv_freq[None, :]
13 | return repeat(freqs, '... d -> ... (d r)', r = 2)
14 |
15 | def rotate_half(x):
16 | x = rearrange(x, '... (d j) m -> ... d j m', j = 2)
17 | x1, x2 = x.unbind(dim = -2)
18 | return torch.cat((-x2, x1), dim = -2)
19 |
20 | def apply_rotary_pos_emb(t, freqs):
21 | rot_dim = freqs.shape[-2]
22 | t, t_pass = t[..., :rot_dim, :], t[..., rot_dim:, :]
23 | t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
24 | return torch.cat((t, t_pass), dim = -2)
25 |
--------------------------------------------------------------------------------
/se3_transformer_pytorch/se3_transformer_pytorch.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | from itertools import product
3 | from collections import namedtuple
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn, einsum
8 |
9 | from se3_transformer_pytorch.basis import get_basis
10 | from se3_transformer_pytorch.utils import exists, default, uniq, map_values, batched_index_select, masked_mean, to_order, fourier_encode, cast_tuple, safe_cat, fast_split, rand_uniform, broadcat
11 | from se3_transformer_pytorch.reversible import ReversibleSequence, SequentialSequence
12 | from se3_transformer_pytorch.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb
13 |
14 | from einops import rearrange, repeat
15 |
16 | # fiber helpers
17 |
18 | FiberEl = namedtuple('FiberEl', ['degrees', 'dim'])
19 |
20 | class Fiber(nn.Module):
21 | def __init__(
22 | self,
23 | structure
24 | ):
25 | super().__init__()
26 | if isinstance(structure, dict):
27 | structure = [FiberEl(degree, dim) for degree, dim in structure.items()]
28 | self.structure = structure
29 |
30 | @property
31 | def dims(self):
32 | return uniq(map(lambda t: t[1], self.structure))
33 |
34 | @property
35 | def degrees(self):
36 | return map(lambda t: t[0], self.structure)
37 |
38 | @staticmethod
39 | def create(num_degrees, dim):
40 | dim_tuple = dim if isinstance(dim, tuple) else ((dim,) * num_degrees)
41 | return Fiber([FiberEl(degree, dim) for degree, dim in zip(range(num_degrees), dim_tuple)])
42 |
43 | def __getitem__(self, degree):
44 | return dict(self.structure)[degree]
45 |
46 | def __iter__(self):
47 | return iter(self.structure)
48 |
49 | def __mul__(self, fiber):
50 | return product(self.structure, fiber.structure)
51 |
52 | def __and__(self, fiber):
53 | out = []
54 | degrees_out = fiber.degrees
55 | for degree, dim in self:
56 | if degree in fiber.degrees:
57 | dim_out = fiber[degree]
58 | out.append((degree, dim, dim_out))
59 | return out
60 |
61 | def get_tensor_device_and_dtype(features):
62 | first_tensor = next(iter(features.items()))[1]
63 | return first_tensor.device, first_tensor.dtype
64 |
65 | # classes
66 |
67 | class ResidualSE3(nn.Module):
68 | """ only support instance where both Fibers are identical """
69 | def forward(self, x, res):
70 | out = {}
71 | for degree, tensor in x.items():
72 | degree = str(degree)
73 | out[degree] = tensor
74 | if degree in res:
75 | out[degree] = out[degree] + res[degree]
76 | return out
77 |
78 | class LinearSE3(nn.Module):
79 | def __init__(
80 | self,
81 | fiber_in,
82 | fiber_out
83 | ):
84 | super().__init__()
85 | self.weights = nn.ParameterDict()
86 |
87 | for (degree, dim_in, dim_out) in (fiber_in & fiber_out):
88 | key = str(degree)
89 | self.weights[key] = nn.Parameter(torch.randn(dim_in, dim_out) / sqrt(dim_in))
90 |
91 | def forward(self, x):
92 | out = {}
93 | for degree, weight in self.weights.items():
94 | out[degree] = einsum('b n d m, d e -> b n e m', x[degree], weight)
95 | return out
96 |
97 | class NormSE3(nn.Module):
98 | """Norm-based SE(3)-equivariant nonlinearity.
99 |
100 | Nonlinearities are important in SE(3) equivariant GCNs. They are also quite
101 | expensive to compute, so it is convenient for them to share resources with
102 | other layers, such as normalization. The general workflow is as follows:
103 |
104 | > for feature type in features:
105 | > norm, phase <- feature
106 | > output = fnc(norm) * phase
107 |
108 | where fnc: {R+}^m -> R^m is a learnable map from m norms to m scalars.
109 | """
110 | def __init__(
111 | self,
112 | fiber,
113 | nonlin = nn.GELU(),
114 | gated_scale = False,
115 | eps = 1e-12,
116 | ):
117 | super().__init__()
118 | self.fiber = fiber
119 | self.nonlin = nonlin
120 | self.eps = eps
121 |
122 | # Norm mappings: 1 per feature type
123 | self.transform = nn.ModuleDict()
124 | for degree, chan in fiber:
125 | self.transform[str(degree)] = nn.ParameterDict({
126 | 'scale': nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None,
127 | 'w_gate': nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None
128 | })
129 |
130 | def forward(self, features):
131 | output = {}
132 | for degree, t in features.items():
133 | # Compute the norms and normalized features
134 | norm = t.norm(dim = -1, keepdim = True).clamp(min = self.eps)
135 | phase = t / norm
136 |
137 | # Transform on norms
138 | parameters = self.transform[degree]
139 | gate_weights, scale = parameters['w_gate'], parameters['scale']
140 |
141 | transformed = rearrange(norm, '... () -> ...')
142 |
143 | if not exists(scale):
144 | scale = einsum('b n d, d e -> b n e', transformed, gate_weights)
145 |
146 | transformed = self.nonlin(transformed * scale)
147 | transformed = rearrange(transformed, '... -> ... ()')
148 |
149 | # Nonlinearity on norm
150 | output[degree] = (transformed * phase).view(*t.shape)
151 |
152 | return output
153 |
154 | class ConvSE3(nn.Module):
155 | """A tensor field network layer
156 |
157 | ConvSE3 stands for a Convolution SE(3)-equivariant layer. It is the
158 | equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph
159 | conv layer in a GCN.
160 |
161 | At each node, the activations are split into different "feature types",
162 | indexed by the SE(3) representation type: non-negative integers 0, 1, 2, ..
163 | """
164 | def __init__(
165 | self,
166 | fiber_in,
167 | fiber_out,
168 | self_interaction = True,
169 | pool = True,
170 | edge_dim = 0,
171 | fourier_encode_dist = False,
172 | num_fourier_features = 4,
173 | splits = 4
174 | ):
175 | super().__init__()
176 | self.fiber_in = fiber_in
177 | self.fiber_out = fiber_out
178 | self.edge_dim = edge_dim
179 | self.self_interaction = self_interaction
180 |
181 | self.num_fourier_features = num_fourier_features
182 | self.fourier_encode_dist = fourier_encode_dist
183 |
184 | # radial function will assume a dimension of at minimum 1, for the relative distance - extra fourier features must be added to the edge dimension
185 | edge_dim += (0 if not fourier_encode_dist else (num_fourier_features * 2))
186 |
187 | # Neighbor -> center weights
188 | self.kernel_unary = nn.ModuleDict()
189 |
190 | self.splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage
191 |
192 | for (di, mi), (do, mo) in (self.fiber_in * self.fiber_out):
193 | self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim = edge_dim, splits = splits)
194 |
195 | self.pool = pool
196 |
197 | # Center -> center weights
198 | if self_interaction:
199 | assert self.pool, 'must pool edges if followed with self interaction'
200 | self.self_interact = LinearSE3(fiber_in, fiber_out)
201 | self.self_interact_sum = ResidualSE3()
202 |
203 | def forward(
204 | self,
205 | inp,
206 | edge_info,
207 | rel_dist = None,
208 | basis = None
209 | ):
210 | splits = self.splits
211 | neighbor_indices, neighbor_masks, edges = edge_info
212 | rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')
213 |
214 | kernels = {}
215 | outputs = {}
216 |
217 | if self.fourier_encode_dist:
218 | rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)
219 |
220 | # split basis
221 |
222 | basis_keys = basis.keys()
223 | split_basis_values = list(zip(*list(map(lambda t: fast_split(t, splits, dim = 1), basis.values()))))
224 | split_basis = list(map(lambda v: dict(zip(basis_keys, v)), split_basis_values))
225 |
226 | # go through every permutation of input degree type to output degree type
227 |
228 | for degree_out in self.fiber_out.degrees:
229 | output = 0
230 | degree_out_key = str(degree_out)
231 |
232 | for degree_in, m_in in self.fiber_in:
233 | etype = f'({degree_in},{degree_out})'
234 |
235 | x = inp[str(degree_in)]
236 |
237 | x = batched_index_select(x, neighbor_indices, dim = 1)
238 | x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)
239 |
240 | kernel_fn = self.kernel_unary[etype]
241 | edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist
242 |
243 | output_chunk = None
244 | split_x = fast_split(x, splits, dim = 1)
245 | split_edge_features = fast_split(edge_features, splits, dim = 1)
246 |
247 | # process input, edges, and basis in chunks along the sequence dimension
248 |
249 |                 for x_chunk, edge_chunk, basis_chunk in zip(split_x, split_edge_features, split_basis):
250 |                     kernel = kernel_fn(edge_chunk, basis = basis_chunk)
251 | chunk = einsum('... o i, ... i c -> ... o c', kernel, x_chunk)
252 | output_chunk = safe_cat(output_chunk, chunk, dim = 1)
253 |
254 | output = output + output_chunk
255 |
256 | if self.pool:
257 | output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2)
258 |
259 | leading_shape = x.shape[:2] if self.pool else x.shape[:3]
260 | output = output.view(*leading_shape, -1, to_order(degree_out))
261 |
262 | outputs[degree_out_key] = output
263 |
264 | if self.self_interaction:
265 | self_interact_out = self.self_interact(inp)
266 | outputs = self.self_interact_sum(outputs, self_interact_out)
267 |
268 | return outputs
269 |
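# Hypothetical usage sketch (the fibers and names below are illustrative,
# not fixed by the module): a ConvSE3 mapping 16 type-0 channels to 16
# channels each of types 0 and 1 would be built and called roughly as
#
#   conv = ConvSE3(Fiber.create(1, 16), Fiber.create(2, 16))
#   out = conv({'0': feats}, edge_info, rel_dist = dists, basis = basis)
#
# where feats has shape (b, n, 16, 1), and edge_info, dists and basis come
# from the neighbor selection and get_basis logic in SE3Transformer.forward.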
270 | class RadialFunc(nn.Module):
271 | """NN parameterized radial profile function."""
272 | def __init__(
273 | self,
274 | num_freq,
275 | in_dim,
276 | out_dim,
277 | edge_dim = None,
278 | mid_dim = 128
279 | ):
280 | super().__init__()
281 | self.num_freq = num_freq
282 | self.in_dim = in_dim
283 | self.mid_dim = mid_dim
284 | self.out_dim = out_dim
285 | self.edge_dim = default(edge_dim, 0)
286 |
287 | self.net = nn.Sequential(
288 | nn.Linear(self.edge_dim + 1, mid_dim),
289 | nn.LayerNorm(mid_dim),
290 | nn.GELU(),
291 | nn.Linear(mid_dim, mid_dim),
292 | nn.LayerNorm(mid_dim),
293 | nn.GELU(),
294 | nn.Linear(mid_dim, num_freq * in_dim * out_dim)
295 | )
296 |
297 | def forward(self, x):
298 | y = self.net(x)
299 | return rearrange(y, '... (o i f) -> ... o () i () f', i = self.in_dim, o = self.out_dim)
300 |
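# The rearrange above inserts singleton axes so that the radial weights,
# shaped (..., out, 1, in, 1, num_freq), broadcast elementwise against the
# equivariant basis inside PairwiseConv.forward below.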
301 | class PairwiseConv(nn.Module):
302 | """SE(3)-equivariant convolution between two single-type features"""
303 | def __init__(
304 | self,
305 | degree_in,
306 | nc_in,
307 | degree_out,
308 | nc_out,
309 | edge_dim = 0,
310 | splits = 4
311 | ):
312 | super().__init__()
313 | self.degree_in = degree_in
314 | self.degree_out = degree_out
315 | self.nc_in = nc_in
316 | self.nc_out = nc_out
317 |
318 | self.num_freq = to_order(min(degree_in, degree_out))
319 | self.d_out = to_order(degree_out)
320 | self.edge_dim = edge_dim
321 |
322 | self.rp = RadialFunc(self.num_freq, nc_in, nc_out, edge_dim)
323 |
324 | self.splits = splits
325 |
326 | def forward(self, feat, basis):
327 | splits = self.splits
328 | R = self.rp(feat)
329 | B = basis[f'{self.degree_in},{self.degree_out}']
330 |
331 | out_shape = (*R.shape[:3], self.d_out * self.nc_out, -1)
332 |
333 | # torch.sum(R * B, dim = -1) is too memory intensive
334 | # needs to be chunked to reduce peak memory usage
335 |
336 | out = 0
337 | for i in range(R.shape[-1]):
338 | out += R[..., i] * B[..., i]
339 |
340 | out = rearrange(out, 'b n h s ... -> (b n h s) ...')
341 |
342 | # reshape and out
343 | return out.view(*out_shape)
344 |
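# The frequency loop above computes the same contraction as
#
#   out = torch.sum(R * B, dim = -1)
#
# but accumulates one frequency at a time, so the full broadcasted R * B
# product is never materialized and peak memory stays low.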
345 | # feed forwards
346 |
347 | class FeedForwardSE3(nn.Module):
348 | def __init__(
349 | self,
350 | fiber,
351 | mult = 4
352 | ):
353 | super().__init__()
354 | self.fiber = fiber
355 | fiber_hidden = Fiber(list(map(lambda t: (t[0], t[1] * mult), fiber)))
356 |
357 | self.project_in = LinearSE3(fiber, fiber_hidden)
358 | self.nonlin = NormSE3(fiber_hidden)
359 | self.project_out = LinearSE3(fiber_hidden, fiber)
360 |
361 | def forward(self, features):
362 | outputs = self.project_in(features)
363 | outputs = self.nonlin(outputs)
364 | outputs = self.project_out(outputs)
365 | return outputs
366 |
367 | class FeedForwardBlockSE3(nn.Module):
368 | def __init__(
369 | self,
370 | fiber,
371 | norm_gated_scale = False
372 | ):
373 | super().__init__()
374 | self.fiber = fiber
375 | self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
376 | self.feedforward = FeedForwardSE3(fiber)
377 | self.residual = ResidualSE3()
378 |
379 | def forward(self, features):
380 | res = features
381 | out = self.prenorm(features)
382 | out = self.feedforward(out)
383 | return self.residual(out, res)
384 |
385 | # attention
386 |
387 | class AttentionSE3(nn.Module):
388 | def __init__(
389 | self,
390 | fiber,
391 | dim_head = 64,
392 | heads = 8,
393 | attend_self = False,
394 | edge_dim = None,
395 | fourier_encode_dist = False,
396 | rel_dist_num_fourier_features = 4,
397 | use_null_kv = False,
398 | splits = 4,
399 | global_feats_dim = None,
400 | linear_proj_keys = False,
401 | tie_key_values = False
402 | ):
403 | super().__init__()
404 | hidden_dim = dim_head * heads
405 | hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
406 | project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])
407 |
408 | self.scale = dim_head ** -0.5
409 | self.heads = heads
410 |
411 | self.linear_proj_keys = linear_proj_keys # whether to linearly project features for keys, rather than convolve with basis
412 |
413 | self.to_q = LinearSE3(fiber, hidden_fiber)
414 | self.to_v = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
415 |
416 | assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'
417 |
418 | if linear_proj_keys:
419 | self.to_k = LinearSE3(fiber, hidden_fiber)
420 | elif not tie_key_values:
421 | self.to_k = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
422 | else:
423 | self.to_k = None
424 |
425 | self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()
426 |
427 | self.use_null_kv = use_null_kv
428 | if use_null_kv:
429 | self.null_keys = nn.ParameterDict()
430 | self.null_values = nn.ParameterDict()
431 |
432 | for degree in fiber.degrees:
433 | m = to_order(degree)
434 | degree_key = str(degree)
435 | self.null_keys[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))
436 | self.null_values[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))
437 |
438 | self.attend_self = attend_self
439 | if attend_self:
440 | self.to_self_k = LinearSE3(fiber, hidden_fiber)
441 | self.to_self_v = LinearSE3(fiber, hidden_fiber)
442 |
443 | self.accept_global_feats = exists(global_feats_dim)
444 | if self.accept_global_feats:
445 | global_input_fiber = Fiber.create(1, global_feats_dim)
446 | global_output_fiber = Fiber.create(1, hidden_fiber[0])
447 | self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
448 | self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)
449 |
450 | def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
451 | h, attend_self = self.heads, self.attend_self
452 | device, dtype = get_tensor_device_and_dtype(features)
453 | neighbor_indices, neighbor_mask, edges = edge_info
454 |
455 | if exists(neighbor_mask):
456 | neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')
457 |
458 | queries = self.to_q(features)
459 | values = self.to_v(features, edge_info, rel_dist, basis)
460 |
461 | if self.linear_proj_keys:
462 | keys = self.to_k(features)
463 | keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
464 | elif not exists(self.to_k):
465 | keys = values
466 | else:
467 | keys = self.to_k(features, edge_info, rel_dist, basis)
468 |
469 | if attend_self:
470 | self_keys, self_values = self.to_self_k(features), self.to_self_v(features)
471 |
472 | if exists(global_feats):
473 | global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)
474 |
475 | outputs = {}
476 | for degree in features.keys():
477 | q, k, v = map(lambda t: t[degree], (queries, keys, values))
478 |
479 | q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)
480 | k, v = map(lambda t: rearrange(t, 'b i j (h d) m -> b h i j d m', h = h), (k, v))
481 |
482 | if attend_self:
483 | self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
484 | self_k, self_v = map(lambda t: rearrange(t, 'b n (h d) m -> b h n () d m', h = h), (self_k, self_v))
485 | k = torch.cat((self_k, k), dim = 3)
486 | v = torch.cat((self_v, v), dim = 3)
487 |
488 | if exists(pos_emb) and degree == '0':
489 | query_pos_emb, key_pos_emb = pos_emb
490 | query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
491 | key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b () i j d ()')
492 | q = apply_rotary_pos_emb(q, query_pos_emb)
493 | k = apply_rotary_pos_emb(k, key_pos_emb)
494 | v = apply_rotary_pos_emb(v, key_pos_emb)
495 |
496 | if self.use_null_kv:
497 | null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
498 | null_k, null_v = map(lambda t: repeat(t, 'h d m -> b h i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
499 | k = torch.cat((null_k, k), dim = 3)
500 | v = torch.cat((null_v, v), dim = 3)
501 |
502 | if exists(global_feats) and degree == '0':
503 | global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
504 | global_k, global_v = map(lambda t: repeat(t, 'b j (h d) m -> b h i j d m', h = h, i = k.shape[2]), (global_k, global_v))
505 | k = torch.cat((global_k, k), dim = 3)
506 | v = torch.cat((global_v, v), dim = 3)
507 |
508 | sim = einsum('b h i d m, b h i j d m -> b h i j', q, k) * self.scale
509 |
510 | if exists(neighbor_mask):
511 | num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
512 | mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
513 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
514 |
515 | attn = sim.softmax(dim = -1)
516 | out = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
517 | outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')
518 |
519 | return self.to_out(outputs)
520 |
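# Shape sketch for the contraction above (illustrative): with h heads, head
# dimension d, and representation order m = 2 * degree + 1,
#
#   q : (b, h, i, d, m)        one query per node i
#   k : (b, h, i, j, d, m)     one key per (node i, neighbor j) pair
#   sim = einsum('b h i d m, b h i j d m -> b h i j', q, k)
#
# summing over both the channel and representation axes, so the attention
# logits are rotation-invariant scalars per head, node and neighbor.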
521 | # AttentionSE3, but with one key / value projection shared across all query heads
522 | class OneHeadedKVAttentionSE3(nn.Module):
523 | def __init__(
524 | self,
525 | fiber,
526 | dim_head = 64,
527 | heads = 8,
528 | attend_self = False,
529 | edge_dim = None,
530 | fourier_encode_dist = False,
531 | rel_dist_num_fourier_features = 4,
532 | use_null_kv = False,
533 | splits = 4,
534 | global_feats_dim = None,
535 | linear_proj_keys = False,
536 | tie_key_values = False
537 | ):
538 | super().__init__()
539 | hidden_dim = dim_head * heads
540 | hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
541 | kv_hidden_fiber = Fiber(list(map(lambda t: (t[0], dim_head), fiber)))
542 | project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])
543 |
544 | self.scale = dim_head ** -0.5
545 | self.heads = heads
546 |
547 | self.linear_proj_keys = linear_proj_keys # whether to linearly project features for keys, rather than convolve with basis
548 |
549 | self.to_q = LinearSE3(fiber, hidden_fiber)
550 | self.to_v = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
551 |
552 | assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'
553 |
554 | if linear_proj_keys:
555 | self.to_k = LinearSE3(fiber, kv_hidden_fiber)
556 | elif not tie_key_values:
557 | self.to_k = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
558 | else:
559 | self.to_k = None
560 |
561 | self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()
562 |
563 | self.use_null_kv = use_null_kv
564 | if use_null_kv:
565 | self.null_keys = nn.ParameterDict()
566 | self.null_values = nn.ParameterDict()
567 |
568 | for degree in fiber.degrees:
569 | m = to_order(degree)
570 | degree_key = str(degree)
571 | self.null_keys[degree_key] = nn.Parameter(torch.zeros(dim_head, m))
572 | self.null_values[degree_key] = nn.Parameter(torch.zeros(dim_head, m))
573 |
574 | self.attend_self = attend_self
575 | if attend_self:
576 | self.to_self_k = LinearSE3(fiber, kv_hidden_fiber)
577 | self.to_self_v = LinearSE3(fiber, kv_hidden_fiber)
578 |
579 | self.accept_global_feats = exists(global_feats_dim)
580 | if self.accept_global_feats:
581 | global_input_fiber = Fiber.create(1, global_feats_dim)
582 | global_output_fiber = Fiber.create(1, kv_hidden_fiber[0])
583 | self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
584 | self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)
585 |
586 | def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
587 | h, attend_self = self.heads, self.attend_self
588 | device, dtype = get_tensor_device_and_dtype(features)
589 | neighbor_indices, neighbor_mask, edges = edge_info
590 |
591 | if exists(neighbor_mask):
592 | neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')
593 |
594 | queries = self.to_q(features)
595 | values = self.to_v(features, edge_info, rel_dist, basis)
596 |
597 | if self.linear_proj_keys:
598 | keys = self.to_k(features)
599 | keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
600 | elif not exists(self.to_k):
601 | keys = values
602 | else:
603 | keys = self.to_k(features, edge_info, rel_dist, basis)
604 |
605 | if attend_self:
606 | self_keys, self_values = self.to_self_k(features), self.to_self_v(features)
607 |
608 | if exists(global_feats):
609 | global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)
610 |
611 | outputs = {}
612 | for degree in features.keys():
613 | q, k, v = map(lambda t: t[degree], (queries, keys, values))
614 |
615 | q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)
616 |
617 | if attend_self:
618 | self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
619 | self_k, self_v = map(lambda t: rearrange(t, 'b n d m -> b n () d m'), (self_k, self_v))
620 | k = torch.cat((self_k, k), dim = 2)
621 | v = torch.cat((self_v, v), dim = 2)
622 |
623 | if exists(pos_emb) and degree == '0':
624 | query_pos_emb, key_pos_emb = pos_emb
625 | query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
626 | key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()')
627 | q = apply_rotary_pos_emb(q, query_pos_emb)
628 | k = apply_rotary_pos_emb(k, key_pos_emb)
629 | v = apply_rotary_pos_emb(v, key_pos_emb)
630 |
631 | if self.use_null_kv:
632 | null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
633 | null_k, null_v = map(lambda t: repeat(t, 'd m -> b i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
634 | k = torch.cat((null_k, k), dim = 2)
635 | v = torch.cat((null_v, v), dim = 2)
636 |
637 | if exists(global_feats) and degree == '0':
638 | global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
639 | global_k, global_v = map(lambda t: repeat(t, 'b j d m -> b i j d m', i = k.shape[1]), (global_k, global_v))
640 | k = torch.cat((global_k, k), dim = 2)
641 | v = torch.cat((global_v, v), dim = 2)
642 |
643 | sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale
644 |
645 | if exists(neighbor_mask):
646 | num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
647 | mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
648 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
649 |
650 | attn = sim.softmax(dim = -1)
651 | out = einsum('b h i j, b i j d m -> b h i d m', attn, v)
652 | outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')
653 |
654 | return self.to_out(outputs)
655 |
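# Compared to AttentionSE3 above, keys and values here keep shape
# (b, i, j, d, m) with no head axis: a single key / value projection is
# broadcast across all h query heads inside the einsums, cutting the
# key / value convolution cost by a factor of h.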
656 | class AttentionBlockSE3(nn.Module):
657 | def __init__(
658 | self,
659 | fiber,
660 | dim_head = 24,
661 | heads = 8,
662 | attend_self = False,
663 | edge_dim = None,
664 | use_null_kv = False,
665 | fourier_encode_dist = False,
666 | rel_dist_num_fourier_features = 4,
667 | splits = 4,
668 | global_feats_dim = False,
669 | linear_proj_keys = False,
670 | tie_key_values = False,
671 | attention_klass = AttentionSE3,
672 | norm_gated_scale = False
673 | ):
674 | super().__init__()
675 | self.attn = attention_klass(fiber, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, use_null_kv = use_null_kv, rel_dist_num_fourier_features = rel_dist_num_fourier_features, fourier_encode_dist =fourier_encode_dist, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, tie_key_values = tie_key_values)
676 | self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
677 | self.residual = ResidualSE3()
678 |
679 | def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
680 | res = features
681 | outputs = self.prenorm(features)
682 | outputs = self.attn(outputs, edge_info, rel_dist, basis, global_feats, pos_emb, mask)
683 | return self.residual(outputs, res)
684 |
685 | # egnn
686 |
687 | class Swish_(nn.Module):
688 | def forward(self, x):
689 | return x * x.sigmoid()
690 |
691 | SiLU = nn.SiLU if hasattr(nn, 'SiLU') else Swish_
692 |
693 | class HtypesNorm(nn.Module):
694 | def __init__(self, dim, eps = 1e-8, scale_init = 1e-2, bias_init = 1e-2):
695 | super().__init__()
696 | self.eps = eps
697 | scale = torch.empty(1, 1, 1, dim, 1).fill_(scale_init)
698 | bias = torch.empty(1, 1, 1, dim, 1).fill_(bias_init)
699 | self.scale = nn.Parameter(scale)
700 | self.bias = nn.Parameter(bias)
701 |
702 | def forward(self, coors):
703 | norm = coors.norm(dim = -1, keepdim = True)
704 | normed_coors = coors / norm.clamp(min = self.eps)
705 | return normed_coors * (norm * self.scale + self.bias)
706 |
707 | class EGNN(nn.Module):
708 | def __init__(
709 | self,
710 | fiber,
711 | hidden_dim = 32,
712 | edge_dim = 0,
713 | init_eps = 1e-3,
714 | coor_weights_clamp_value = None
715 | ):
716 | super().__init__()
717 | self.fiber = fiber
718 | node_dim = fiber[0]
719 |
720 | htypes = list(filter(lambda t: t.degrees != 0, fiber))
721 | num_htypes = len(htypes)
722 | htype_dims = sum([fiberel.dim for fiberel in htypes])
723 |
724 | edge_input_dim = node_dim * 2 + htype_dims + edge_dim + 1
725 |
726 | self.node_norm = nn.LayerNorm(node_dim)
727 |
728 | self.edge_mlp = nn.Sequential(
729 | nn.Linear(edge_input_dim, edge_input_dim * 2),
730 | SiLU(),
731 | nn.Linear(edge_input_dim * 2, hidden_dim),
732 | SiLU()
733 | )
734 |
735 | self.htype_norms = nn.ModuleDict({})
736 | self.htype_gating = nn.ModuleDict({})
737 |
738 | for degree, dim in fiber:
739 | if degree == 0:
740 | continue
741 | self.htype_norms[str(degree)] = HtypesNorm(dim)
742 | self.htype_gating[str(degree)] = nn.Linear(node_dim, dim)
743 |
744 | self.htypes_mlp = nn.Sequential(
745 | nn.Linear(hidden_dim, hidden_dim * 4),
746 | SiLU(),
747 | nn.Linear(hidden_dim * 4, htype_dims)
748 | )
749 |
750 | self.node_mlp = nn.Sequential(
751 | nn.Linear(node_dim + hidden_dim, node_dim * 2),
752 | SiLU(),
753 | nn.Linear(node_dim * 2, node_dim)
754 | )
755 |
756 | self.coor_weights_clamp_value = coor_weights_clamp_value
757 | self.init_eps = init_eps
758 | self.apply(self.init_)
759 |
760 | def init_(self, module):
761 | if type(module) in {nn.Linear}:
762 | nn.init.normal_(module.weight, std = self.init_eps)
763 |
764 | def forward(
765 | self,
766 | features,
767 | edge_info,
768 | rel_dist,
769 | mask = None,
770 | **kwargs
771 | ):
772 | neighbor_indices, neighbor_masks, edges = edge_info
773 |
774 | mask = neighbor_masks
775 |
776 | # type 0 features
777 |
778 | nodes = features['0']
779 | nodes = rearrange(nodes, '... () -> ...')
780 |
781 | # higher types (htype)
782 |
783 | htypes = list(filter(lambda t: t[0] != '0', features.items()))
784 | htype_degrees = list(map(lambda t: t[0], htypes))
785 | htype_dims = list(map(lambda t: t[1].shape[-2], htypes))
786 |
787 | # prepare higher types
788 |
789 | rel_htypes = []
790 | rel_htypes_dists = []
791 |
792 | for degree, htype in htypes:
793 | rel_htype = rearrange(htype, 'b i d m -> b i () d m') - rearrange(htype, 'b j d m -> b () j d m')
794 | rel_htype_dist = rel_htype.norm(dim = -1)
795 |
796 | rel_htypes.append(rel_htype)
797 | rel_htypes_dists.append(rel_htype_dist)
798 |
799 | # prepare edges for edge MLP
800 |
801 | nodes_i = rearrange(nodes, 'b i d -> b i () d')
802 | nodes_j = batched_index_select(nodes, neighbor_indices, dim = 1)
803 | neighbor_higher_type_dists = map(lambda t: batched_index_select(t, neighbor_indices, dim = 2), rel_htypes_dists)
804 | coor_rel_dist = rearrange(rel_dist, 'b i j -> b i j ()')
805 |
806 | edge_mlp_inputs = broadcat((nodes_i, nodes_j, *neighbor_higher_type_dists, coor_rel_dist), dim = -1)
807 |
808 | if exists(edges):
809 | edge_mlp_inputs = torch.cat((edge_mlp_inputs, edges), dim = -1)
810 |
811 | # get intermediate representation
812 |
813 | m_ij = self.edge_mlp(edge_mlp_inputs)
814 |
815 | # to coordinates
816 |
817 | htype_weights = self.htypes_mlp(m_ij)
818 |
819 | if exists(self.coor_weights_clamp_value):
820 | clamp_value = self.coor_weights_clamp_value
821 | htype_weights.clamp_(min = -clamp_value, max = clamp_value)
822 |
823 |         htype_updates = []
824 |
825 |         if exists(mask):
826 |             htype_mask = rearrange(mask, 'b i j -> b i j ()')
827 |             htype_weights = htype_weights.masked_fill(~htype_mask, 0.)
828 |
829 |         split_htype_weights = htype_weights.split(htype_dims, dim = -1) # split after masking, so masked-out edges contribute zero weight
830 |
831 | for degree, rel_htype, htype_weight in zip(htype_degrees, rel_htypes, split_htype_weights):
832 | normed_rel_htype = self.htype_norms[str(degree)](rel_htype)
833 | normed_rel_htype = batched_index_select(normed_rel_htype, neighbor_indices, dim = 2)
834 |
835 | htype_update = einsum('b i j d m, b i j d -> b i d m', normed_rel_htype, htype_weight)
836 | htype_updates.append(htype_update)
837 |
838 | # to nodes
839 |
840 | if exists(mask):
841 | m_ij_mask = rearrange(mask, '... -> ... ()')
842 | m_ij = m_ij.masked_fill(~m_ij_mask, 0.)
843 |
844 | m_i = m_ij.sum(dim = -2)
845 |
846 | normed_nodes = self.node_norm(nodes)
847 | node_mlp_input = torch.cat((normed_nodes, m_i), dim = -1)
848 | node_out = self.node_mlp(node_mlp_input) + nodes
849 |
850 | # update nodes
851 |
852 | features['0'] = rearrange(node_out, '... -> ... ()')
853 |
854 | # update higher types
855 |
856 | update_htype_dicts = dict(zip(htype_degrees, htype_updates))
857 |
858 | for degree, update_htype in update_htype_dicts.items():
859 | features[degree] = features[degree] + update_htype
860 |
861 | for degree in htype_degrees:
862 | gating = self.htype_gating[str(degree)](node_out).sigmoid()
863 | features[degree] = features[degree] * rearrange(gating, '... -> ... ()')
864 |
865 | return features
866 |
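# Illustrative summary of the update performed above, in EGNN style (h are
# the type-0 features, x stands for any higher-type feature):
#
#   m_ij = MLP(h_i, h_j, invariant rel. features, e_ij)   # invariant messages
#   x_i  = x_i + sum_j norm(x_i - x_j) * MLP(m_ij)        # equivariant update
#   h_i  = h_i + MLP(h_i, sum_j m_ij)                     # node update
#
# with every higher feature type getting its own norm and a sigmoid gate
# conditioned on the updated node features.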
867 | class EGnnNetwork(nn.Module):
868 | def __init__(
869 | self,
870 | *,
871 | fiber,
872 | depth,
873 | edge_dim = 0,
874 | hidden_dim = 32,
875 | coor_weights_clamp_value = None,
876 | feedforward = False
877 | ):
878 | super().__init__()
879 | self.fiber = fiber
880 | self.layers = nn.ModuleList([])
881 | for _ in range(depth):
882 | self.layers.append(nn.ModuleList([
883 | EGNN(fiber = fiber, edge_dim = edge_dim, hidden_dim = hidden_dim, coor_weights_clamp_value = coor_weights_clamp_value),
884 | FeedForwardBlockSE3(fiber) if feedforward else None
885 | ]))
886 |
887 | def forward(
888 | self,
889 | features,
890 | edge_info,
891 | rel_dist,
892 | basis,
893 | global_feats = None,
894 | pos_emb = None,
895 | mask = None,
896 | **kwargs
897 | ):
898 | neighbor_indices, neighbor_masks, edges = edge_info
899 | device = neighbor_indices.device
900 |
901 |         # modify neighbors to include self (the SE3 transformer excludes attention from a token to itself, but this does not apply to EGNN)
902 |
903 | self_indices = torch.arange(neighbor_indices.shape[1], device = device)
904 | self_indices = rearrange(self_indices, 'i -> () i ()')
905 | neighbor_indices = broadcat((self_indices, neighbor_indices), dim = -1)
906 |
907 | neighbor_masks = F.pad(neighbor_masks, (1, 0), value = True)
908 | rel_dist = F.pad(rel_dist, (1, 0), value = 0.)
909 |
910 | if exists(edges):
911 | edges = F.pad(edges, (0, 0, 1, 0), value = 0.) # make edge of token to itself 0 for now
912 |
913 | edge_info = (neighbor_indices, neighbor_masks, edges)
914 |
915 | # go through layers
916 |
917 | for egnn, ff in self.layers:
918 | features = egnn(
919 | features,
920 | edge_info = edge_info,
921 | rel_dist = rel_dist,
922 | basis = basis,
923 | global_feats = global_feats,
924 | pos_emb = pos_emb,
925 | mask = mask,
926 | **kwargs
927 | )
928 |
929 | if exists(ff):
930 | features = ff(features)
931 |
932 | return features
933 |
934 | # main class
935 |
936 | class SE3Transformer(nn.Module):
937 | def __init__(
938 | self,
939 | *,
940 | dim,
941 | heads = 8,
942 | dim_head = 24,
943 | depth = 2,
944 | input_degrees = 1,
945 | num_degrees = None,
946 | output_degrees = 1,
947 | valid_radius = 1e5,
948 | reduce_dim_out = False,
949 | num_tokens = None,
950 | num_positions = None,
951 | num_edge_tokens = None,
952 | edge_dim = None,
953 | reversible = False,
954 | attend_self = True,
955 | use_null_kv = False,
956 | differentiable_coors = False,
957 | fourier_encode_dist = False,
958 | rel_dist_num_fourier_features = 4,
959 | num_neighbors = float('inf'),
960 | attend_sparse_neighbors = False,
961 | num_adj_degrees = None,
962 | adj_dim = 0,
963 | max_sparse_neighbors = float('inf'),
964 | dim_in = None,
965 | dim_out = None,
966 | norm_out = False,
967 | num_conv_layers = 0,
968 | causal = False,
969 | splits = 4,
970 | global_feats_dim = None,
971 | linear_proj_keys = False,
972 | one_headed_key_values = False,
973 | tie_key_values = False,
974 | rotary_position = False,
975 | rotary_rel_dist = False,
976 | norm_gated_scale = False,
977 | use_egnn = False,
978 | egnn_hidden_dim = 32,
979 | egnn_weights_clamp_value = None,
980 | egnn_feedforward = False,
981 | hidden_fiber_dict = None,
982 | out_fiber_dict = None
983 | ):
984 | super().__init__()
985 | dim_in = default(dim_in, dim)
986 | self.dim_in = cast_tuple(dim_in, input_degrees)
987 | self.dim = dim
988 |
989 | # token embedding
990 |
991 | self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
992 |
993 | # positional embedding
994 |
995 | self.num_positions = num_positions
996 | self.pos_emb = nn.Embedding(num_positions, dim) if exists(num_positions) else None
997 |
998 | self.rotary_rel_dist = rotary_rel_dist
999 | self.rotary_position = rotary_position
1000 |
1001 | self.rotary_pos_emb = None
1002 | if rotary_position or rotary_rel_dist:
1003 | num_rotaries = int(rotary_position) + int(rotary_rel_dist)
1004 | self.rotary_pos_emb = SinusoidalEmbeddings(dim_head // num_rotaries)
1005 |
1006 | # edges
1007 |
1008 | assert not (exists(num_edge_tokens) and not exists(edge_dim)), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
1009 |
1010 | self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
1011 | self.has_edges = exists(edge_dim) and edge_dim > 0
1012 |
1013 | self.input_degrees = input_degrees
1014 |
1015 |         assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'num_adj_degrees must be at least 1 when specified'
1016 |
1017 | self.num_degrees = num_degrees if exists(num_degrees) else (max(hidden_fiber_dict.keys()) + 1)
1018 |
1019 | output_degrees = output_degrees if not use_egnn else None
1020 | self.output_degrees = output_degrees
1021 |
1022 | # whether to differentiate through basis, needed for alphafold2
1023 |
1024 | self.differentiable_coors = differentiable_coors
1025 |
1026 | # neighbors hyperparameters
1027 |
1028 | self.valid_radius = valid_radius
1029 | self.num_neighbors = num_neighbors
1030 |
1031 | # sparse neighbors, derived from adjacency matrix or edges being passed in
1032 |
1033 | self.attend_sparse_neighbors = attend_sparse_neighbors
1034 | self.max_sparse_neighbors = max_sparse_neighbors
1035 |
1036 | # adjacent neighbor derivation and embed
1037 |
1038 | self.num_adj_degrees = num_adj_degrees
1039 | self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None
1040 |
1041 | edge_dim = (edge_dim if self.has_edges else 0) + (adj_dim if exists(self.adj_emb) else 0)
1042 |
1043 | # define fibers and dimensionality
1044 |
1045 | dim_in = default(dim_in, dim)
1046 | dim_out = default(dim_out, dim)
1047 |
1048 | assert exists(num_degrees) or exists(hidden_fiber_dict), 'either num_degrees or hidden_fiber_dict must be specified'
1049 |
1050 | fiber_in = Fiber.create(input_degrees, dim_in)
1051 |
1052 | if exists(hidden_fiber_dict):
1053 | fiber_hidden = Fiber(hidden_fiber_dict)
1054 | elif exists(num_degrees):
1055 | fiber_hidden = Fiber.create(num_degrees, dim)
1056 |
1057 | if exists(out_fiber_dict):
1058 | fiber_out = Fiber(out_fiber_dict)
1059 | self.output_degrees = max(out_fiber_dict.keys()) + 1
1060 | elif exists(output_degrees):
1061 | fiber_out = Fiber.create(output_degrees, dim_out)
1062 | else:
1063 | fiber_out = None
1064 |
1065 | conv_kwargs = dict(edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
1066 |
1067 | # causal
1068 |
1069 | assert not (causal and not attend_self), 'attending to self must be turned on if in autoregressive mode (for the first token)'
1070 | self.causal = causal
1071 |
1072 | # main network
1073 |
1074 | self.conv_in = ConvSE3(fiber_in, fiber_hidden, **conv_kwargs)
1075 |
1076 | # pre-convs
1077 |
1078 | self.convs = nn.ModuleList([])
1079 | for _ in range(num_conv_layers):
1080 | self.convs.append(nn.ModuleList([
1081 | ConvSE3(fiber_hidden, fiber_hidden, **conv_kwargs),
1082 | NormSE3(fiber_hidden, gated_scale = norm_gated_scale)
1083 | ]))
1084 |
1085 | # global features
1086 |
1087 | self.accept_global_feats = exists(global_feats_dim)
1088 | assert not (reversible and self.accept_global_feats), 'reversibility and global features are not compatible'
1089 |
1090 | # trunk
1091 |
1092 | self.attend_self = attend_self
1093 |
1094 | default_attention_klass = OneHeadedKVAttentionSE3 if one_headed_key_values else AttentionSE3
1095 |
1096 | if use_egnn:
1097 | self.net = EGnnNetwork(fiber = fiber_hidden, depth = depth, edge_dim = edge_dim, hidden_dim = egnn_hidden_dim, coor_weights_clamp_value = egnn_weights_clamp_value, feedforward = egnn_feedforward)
1098 | else:
1099 | layers = nn.ModuleList([])
1100 | for ind in range(depth):
1101 | attention_klass = default_attention_klass
1102 |
1103 | layers.append(nn.ModuleList([
1104 | AttentionBlockSE3(fiber_hidden, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, rel_dist_num_fourier_features = rel_dist_num_fourier_features, use_null_kv = use_null_kv, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, attention_klass = attention_klass, tie_key_values = tie_key_values, norm_gated_scale = norm_gated_scale),
1105 | FeedForwardBlockSE3(fiber_hidden, norm_gated_scale = norm_gated_scale)
1106 | ]))
1107 |
1108 | execution_class = ReversibleSequence if reversible else SequentialSequence
1109 | self.net = execution_class(layers)
1110 |
1111 | # out
1112 |
1113 | self.conv_out = ConvSE3(fiber_hidden, fiber_out, **conv_kwargs) if exists(fiber_out) else None
1114 |
1115 | self.norm = NormSE3(fiber_out, gated_scale = norm_gated_scale, nonlin = nn.Identity()) if (norm_out or reversible) and exists(fiber_out) else nn.Identity()
1116 |
1117 | final_fiber = default(fiber_out, fiber_hidden)
1118 |
1119 | self.linear_out = LinearSE3(
1120 | final_fiber,
1121 | Fiber(list(map(lambda t: FiberEl(degrees = t[0], dim = 1), final_fiber)))
1122 | ) if reduce_dim_out else None
1123 |
1124 | def forward(
1125 | self,
1126 | feats,
1127 | coors,
1128 | mask = None,
1129 | adj_mat = None,
1130 | edges = None,
1131 | return_type = None,
1132 | return_pooled = False,
1133 | neighbor_mask = None,
1134 | global_feats = None
1135 | ):
1136 | assert not (self.accept_global_feats ^ exists(global_feats)), 'you cannot pass in global features unless you init the class correctly'
1137 |
1138 | _mask = mask
1139 |
1140 | if self.output_degrees == 1:
1141 | return_type = 0
1142 |
1143 | if exists(self.token_emb):
1144 | feats = self.token_emb(feats)
1145 |
1146 | if exists(self.pos_emb):
1147 | assert feats.shape[1] <= self.num_positions, 'feature sequence length must be less than the number of positions given at init'
1148 | pos_emb = self.pos_emb(torch.arange(feats.shape[1], device = feats.device))
1149 | feats += rearrange(pos_emb, 'n d -> () n d')
1150 |
1151 |         assert not (self.attend_sparse_neighbors and not exists(adj_mat)), 'adjacency matrix (adj_mat) must be passed in when attending to sparse neighbors'
1152 | assert not (self.has_edges and not exists(edges)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'
1153 |
1154 | if torch.is_tensor(feats):
1155 | feats = {'0': feats[..., None]}
1156 |
1157 | if torch.is_tensor(global_feats):
1158 | global_feats = {'0': global_feats[..., None]}
1159 |
1160 | b, n, d, *_, device = *feats['0'].shape, feats['0'].device
1161 |
1162 | assert d == self.dim_in[0], f'feature dimension {d} must be equal to dimension given at init {self.dim_in[0]}'
1163 |         assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degrees'
1164 |
1165 | num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius
1166 |
1167 | assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'
1168 |
1169 | # se3 transformer by default cannot have a node attend to itself
1170 |
1171 | exclude_self_mask = rearrange(~torch.eye(n, dtype = torch.bool, device = device), 'i j -> () i j')
1172 | remove_self = lambda t: t.masked_select(exclude_self_mask).reshape(b, n, n - 1)
1173 | get_max_value = lambda t: torch.finfo(t.dtype).max
1174 |
1175 | # create N-degrees adjacent matrix from 1st degree connections
1176 |
1177 | if exists(self.num_adj_degrees):
1178 | if len(adj_mat.shape) == 2:
1179 | adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)
1180 |
1181 | adj_indices = adj_mat.clone().long()
1182 |
1183 | for ind in range(self.num_adj_degrees - 1):
1184 | degree = ind + 2
1185 |
1186 | next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
1187 | next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
1188 | adj_indices = adj_indices.masked_fill(next_degree_mask, degree)
1189 | adj_mat = next_degree_adj_mat.clone()
1190 |
1191 | adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
1192 |
1193 | # calculate sparsely connected neighbors
1194 |
1195 | sparse_neighbor_mask = None
1196 | num_sparse_neighbors = 0
1197 |
1198 | if self.attend_sparse_neighbors:
1199 | assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'
1200 |
1201 | if exists(adj_mat):
1202 | if len(adj_mat.shape) == 2:
1203 | adj_mat = repeat(adj_mat, 'i j -> b i j', b = b)
1204 |
1205 | adj_mat = remove_self(adj_mat)
1206 |
1207 | adj_mat_values = adj_mat.float()
1208 | adj_mat_max_neighbors = adj_mat_values.sum(dim = -1).max().item()
1209 |
1210 | if max_sparse_neighbors < adj_mat_max_neighbors:
1211 | noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
1212 | adj_mat_values += noise
1213 |
1214 | num_sparse_neighbors = int(min(max_sparse_neighbors, adj_mat_max_neighbors))
1215 | values, indices = adj_mat_values.topk(num_sparse_neighbors, dim = -1)
1216 | sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(-1, indices, values)
1217 | sparse_neighbor_mask = sparse_neighbor_mask > 0.5
1218 |
1219 | # exclude edge of token to itself
1220 |
1221 | indices = repeat(torch.arange(n, device = device), 'j -> b i j', b = b, i = n)
1222 | rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d')
1223 |
1224 | indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
1225 | rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, 3)
1226 |
1227 | if exists(mask):
1228 | mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
1229 | mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)
1230 |
1231 | if exists(edges):
1232 | if exists(self.edge_emb):
1233 | edges = self.edge_emb(edges)
1234 |
1235 | edges = edges.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, -1)
1236 |
1237 | if exists(self.adj_emb):
1238 | adj_emb = self.adj_emb(adj_indices)
1239 | edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb
1240 |
1241 | rel_dist = rel_pos.norm(dim = -1)
1242 |
1243 | # rel_dist gets modified using adjacency or neighbor mask
1244 |
1245 | modified_rel_dist = rel_dist.clone()
1246 | max_value = get_max_value(modified_rel_dist) # for masking out nodes from being considered as neighbors
1247 |
1248 | # neighbors
1249 |
1250 | if exists(neighbor_mask):
1251 | neighbor_mask = remove_self(neighbor_mask)
1252 |
1253 | max_neighbors = neighbor_mask.sum(dim = -1).max().item()
1254 | if max_neighbors > neighbors:
1255 | print(f'neighbor_mask shows maximum number of neighbors as {max_neighbors} but specified number of neighbors is {neighbors}')
1256 |
1257 | modified_rel_dist = modified_rel_dist.masked_fill(~neighbor_mask, max_value)
1258 |
1259 |         # use sparse neighbor mask to give bonded neighbors priority (distance set to 0)
1260 |
1261 | if exists(sparse_neighbor_mask):
1262 | modified_rel_dist = modified_rel_dist.masked_fill(sparse_neighbor_mask, 0.)
1263 |
1264 | # mask out future nodes to high distance if causal turned on
1265 |
1266 | if self.causal:
1267 | causal_mask = torch.ones(n, n - 1, device = device).triu().bool()
1268 | modified_rel_dist = modified_rel_dist.masked_fill(causal_mask[None, ...], max_value)
1269 |
1270 | # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix
1271 |
1272 | if neighbors == 0:
1273 | valid_radius = 0
1274 |
1275 | # get neighbors and neighbor mask, excluding self
1276 |
1277 | neighbors = int(min(neighbors, n - 1))
1278 | total_neighbors = int(neighbors + num_sparse_neighbors)
1279 | assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'
1280 |
1281 | total_neighbors = int(min(total_neighbors, n - 1)) # make sure total neighbors does not exceed the length of the sequence itself
1282 |
1283 | dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False)
1284 | neighbor_mask = dist_values <= valid_radius
1285 |
1286 | neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2)
1287 | neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2)
1288 | neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2)
1289 |
1290 | if exists(mask):
1291 | neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2)
1292 |
1293 | if exists(edges):
1294 | edges = batched_index_select(edges, nearest_indices, dim = 2)
1295 |
1296 | # calculate rotary pos emb
1297 |
1298 | rotary_pos_emb = None
1299 | rotary_query_pos_emb = None
1300 | rotary_key_pos_emb = None
1301 |
1302 | if self.rotary_position:
1303 | seq = torch.arange(n, device = device)
1304 | seq_pos_emb = self.rotary_pos_emb(seq)
1305 | self_indices = torch.arange(neighbor_indices.shape[1], device = device)
1306 | self_indices = repeat(self_indices, 'i -> b i ()', b = b)
1307 | neighbor_indices_with_self = torch.cat((self_indices, neighbor_indices), dim = 2)
1308 | pos_emb = batched_index_select(seq_pos_emb, neighbor_indices_with_self, dim = 0)
1309 |
1310 | rotary_key_pos_emb = pos_emb
1311 | rotary_query_pos_emb = repeat(seq_pos_emb, 'n d -> b n d', b = b)
1312 |
1313 | if self.rotary_rel_dist:
1314 | neighbor_rel_dist_with_self = F.pad(neighbor_rel_dist, (1, 0), value = 0) * 1e2
1315 | rel_dist_pos_emb = self.rotary_pos_emb(neighbor_rel_dist_with_self)
1316 | rotary_key_pos_emb = safe_cat(rotary_key_pos_emb, rel_dist_pos_emb, dim = -1)
1317 |
1318 | query_dist = torch.zeros(n, device = device)
1319 | query_pos_emb = self.rotary_pos_emb(query_dist)
1320 | query_pos_emb = repeat(query_pos_emb, 'n d -> b n d', b = b)
1321 |
1322 | rotary_query_pos_emb = safe_cat(rotary_query_pos_emb, query_pos_emb, dim = -1)
1323 |
1324 | if exists(rotary_query_pos_emb) and exists(rotary_key_pos_emb):
1325 | rotary_pos_emb = (rotary_query_pos_emb, rotary_key_pos_emb)
1326 |
1327 | # calculate basis
1328 |
1329 | basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable = self.differentiable_coors)
1330 |
1331 | # main logic
1332 |
1333 | edge_info = (neighbor_indices, neighbor_mask, edges)
1334 | x = feats
1335 |
1336 | # project in
1337 |
1338 | x = self.conv_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)
1339 |
1340 | # preconvolution layers
1341 |
1342 | for conv, nonlin in self.convs:
1343 | x = nonlin(x)
1344 | x = conv(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)
1345 |
1346 | # transformer layers
1347 |
1348 | x = self.net(x, edge_info = edge_info, rel_dist = neighbor_rel_dist, basis = basis, global_feats = global_feats, pos_emb = rotary_pos_emb, mask = _mask)
1349 |
1350 | # project out
1351 |
1352 | if exists(self.conv_out):
1353 | x = self.conv_out(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)
1354 |
1355 | # norm
1356 |
1357 | x = self.norm(x)
1358 |
1359 | # reduce dim if specified
1360 |
1361 | if exists(self.linear_out):
1362 | x = self.linear_out(x)
1363 | x = map_values(lambda t: t.squeeze(dim = 2), x)
1364 |
1365 | if return_pooled:
1366 | mask_fn = (lambda t: masked_mean(t, _mask, dim = 1)) if exists(_mask) else (lambda t: t.mean(dim = 1))
1367 | x = map_values(mask_fn, x)
1368 |
1369 | if '0' in x:
1370 | x['0'] = x['0'].squeeze(dim = -1)
1371 |
1372 | if exists(return_type):
1373 | return x[str(return_type)]
1374 |
1375 | return x
1376 |
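# Minimal usage sketch, mirroring tests/test_equivariance.py:
#
#   model = SE3Transformer(dim = 64, depth = 1, num_degrees = 2, num_neighbors = 4)
#   feats = torch.randn(1, 32, 64)
#   coors = torch.randn(1, 32, 3)
#   mask = torch.ones(1, 32).bool()
#   out = model(feats, coors, mask, return_type = 0)   # (1, 32, 64)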
--------------------------------------------------------------------------------
/se3_transformer_pytorch/spherical_harmonics.py:
--------------------------------------------------------------------------------
1 | from math import pi, sqrt
2 | from functools import reduce
3 | from operator import mul
4 | import torch
5 |
6 | from functools import lru_cache
7 | from se3_transformer_pytorch.utils import cache
8 |
9 | # constants
10 |
11 | CACHE = {}
12 |
13 | def clear_spherical_harmonics_cache():
14 | CACHE.clear()
15 |
16 | def lpmv_cache_key_fn(l, m, x):
17 | return (l, m)
18 |
19 | # spherical harmonics
20 |
21 | @lru_cache(maxsize = 1000)
22 | def semifactorial(x):
23 | return reduce(mul, range(x, 1, -2), 1.)
24 |
25 | @lru_cache(maxsize = 1000)
26 | def pochhammer(x, k):
27 | return reduce(mul, range(x + 1, x + k), float(x))
28 |
29 | def negative_lpmv(l, m, y):
30 | if m < 0:
31 | y *= ((-1) ** m / pochhammer(l + m + 1, -2 * m))
32 | return y
33 |
34 | @cache(cache = CACHE, key_fn = lpmv_cache_key_fn)
35 | def lpmv(l, m, x):
36 | """Associated Legendre function including Condon-Shortley phase.
37 |
38 | Args:
39 |         l: int degree
40 |         m: int order
41 |         x: float argument tensor
42 |     Returns:
43 |         tensor of the same shape as x
44 | """
45 | # Check memoized versions
46 | m_abs = abs(m)
47 |
48 | if m_abs > l:
49 | return None
50 |
51 | if l == 0:
52 | return torch.ones_like(x)
53 |
54 | # Check if on boundary else recurse solution down to boundary
55 | if m_abs == l:
56 | # Compute P_m^m
57 | y = (-1)**m_abs * semifactorial(2*m_abs-1)
58 | y *= torch.pow(1-x*x, m_abs/2)
59 | return negative_lpmv(l, m, y)
60 |
61 | # Recursively precompute lower degree harmonics
62 | lpmv(l-1, m, x)
63 |
64 | # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
65 | # Inplace speedup
66 | y = ((2*l-1) / (l-m_abs)) * x * lpmv(l-1, m_abs, x)
67 |
68 | if l - m_abs > 1:
69 | y -= ((l+m_abs-1)/(l-m_abs)) * CACHE[(l-2, m_abs)]
70 |
71 | if m < 0:
72 |         y = negative_lpmv(l, m, y)
73 | return y
74 |
75 | def get_spherical_harmonics_element(l, m, theta, phi):
76 | """Tesseral spherical harmonic with Condon-Shortley phase.
77 |
78 | The Tesseral spherical harmonics are also known as the real spherical
79 | harmonics.
80 |
81 | Args:
82 | l: int for degree
83 |         m: int for order, where -l <= m <= l
84 |         theta: colatitude or polar angle
85 | phi: longitude or azimuth
86 | Returns:
87 |         tensor with the same shape as theta
88 | """
89 | m_abs = abs(m)
90 | assert m_abs <= l, "absolute value of order m must be <= degree l"
91 |
92 | N = sqrt((2*l + 1) / (4 * pi))
93 | leg = lpmv(l, m_abs, torch.cos(theta))
94 |
95 | if m == 0:
96 | return N * leg
97 |
98 | if m > 0:
99 | Y = torch.cos(m * phi)
100 | else:
101 | Y = torch.sin(m_abs * phi)
102 |
103 | Y *= leg
104 | N *= sqrt(2. / pochhammer(l - m_abs + 1, 2 * m_abs))
105 | Y *= N
106 | return Y
107 |
108 | def get_spherical_harmonics(l, theta, phi):
109 | """ Tesseral harmonic with Condon-Shortley phase.
110 |
111 | The Tesseral spherical harmonics are also known as the real spherical
112 | harmonics.
113 |
114 | Args:
115 | l: int for degree
116 |         theta: colatitude or polar angle
117 | phi: longitude or azimuth
118 | Returns:
119 | tensor of shape [*theta.shape, 2*l+1]
120 | """
121 | return torch.stack([ get_spherical_harmonics_element(l, m, theta, phi) \
122 | for m in range(-l, l+1) ],
123 | dim = -1)
124 |
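# Example: degree l = 1 yields the 2l + 1 = 3 real spherical harmonics,
# evaluated pointwise over the angle tensors:
#
#   >>> theta, phi = torch.rand(32), torch.rand(32)
#   >>> get_spherical_harmonics(1, theta, phi).shape
#   torch.Size([32, 3])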
--------------------------------------------------------------------------------
/se3_transformer_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import pickle
5 | import gzip
6 | import torch
7 | import contextlib
8 | from functools import wraps, lru_cache
9 | from filelock import FileLock
10 |
11 | from einops import rearrange
12 |
13 | # helper functions
14 |
15 | def exists(val):
16 | return val is not None
17 |
18 | def default(val, d):
19 | return val if exists(val) else d
20 |
21 | def uniq(arr):
22 | return list({el: True for el in arr}.keys())
23 |
24 | def to_order(degree):
25 | return 2 * degree + 1
26 |
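# a degree-d irreducible representation has dimension 2d + 1
# (d = 0 scalars -> 1, d = 1 vectors -> 3, d = 2 -> 5, ...)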
27 | def map_values(fn, d):
28 | return {k: fn(v) for k, v in d.items()}
29 |
30 | def safe_cat(arr, el, dim):
31 | if not exists(arr):
32 | return el
33 | return torch.cat((arr, el), dim = dim)
34 |
35 | def cast_tuple(val, depth):
36 | return val if isinstance(val, tuple) else (val,) * depth
37 |
38 | def broadcat(tensors, dim = -1):
39 | num_tensors = len(tensors)
40 | shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
41 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
42 | shape_len = list(shape_lens)[0]
43 |
44 | dim = (dim + shape_len) if dim < 0 else dim
45 | dims = list(zip(*map(lambda t: list(t.shape), tensors)))
46 |
47 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
48 |     assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
49 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
50 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
51 | expanded_dims.insert(dim, (dim, dims[dim]))
52 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
53 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
54 | return torch.cat(tensors, dim = dim)
55 |
56 | def batched_index_select(values, indices, dim = 1):
57 | value_dims = values.shape[(dim + 1):]
58 | values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
59 | indices = indices[(..., *((None,) * len(value_dims)))]
60 | indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
61 | value_expand_len = len(indices_shape) - (dim + 1)
62 | values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
63 |
64 | value_expand_shape = [-1] * len(values.shape)
65 | expand_slice = slice(dim, (dim + value_expand_len))
66 | value_expand_shape[expand_slice] = indices.shape[expand_slice]
67 | values = values.expand(*value_expand_shape)
68 |
69 | dim += value_expand_len
70 | return values.gather(dim, indices)
71 |
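# Example: gathering 4 neighbor rows per node from a (batch, nodes, dim)
# tensor using a (batch, nodes, 4) index tensor:
#
#   >>> values = torch.randn(2, 10, 16)
#   >>> indices = torch.randint(0, 10, (2, 10, 4))
#   >>> batched_index_select(values, indices, dim = 1).shape
#   torch.Size([2, 10, 4, 16])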
72 | def masked_mean(tensor, mask, dim = -1):
73 | diff_len = len(tensor.shape) - len(mask.shape)
74 | mask = mask[(..., *((None,) * diff_len))]
75 | tensor.masked_fill_(~mask, 0.)
76 |
77 | total_el = mask.sum(dim = dim)
78 | mean = tensor.sum(dim = dim) / total_el.clamp(min = 1.)
79 | mean.masked_fill_(total_el == 0, 0.)
80 | return mean
81 |
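# Note: masked_mean zeroes the masked-out entries of `tensor` in place via
# masked_fill_; callers that need the original tensor afterwards should pass
# in a clone.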
82 | def rand_uniform(size, min_val, max_val):
83 | return torch.empty(size).uniform_(min_val, max_val)
84 |
85 | def fast_split(arr, splits, dim=0):
86 | axis_len = arr.shape[dim]
87 | splits = min(axis_len, max(splits, 1))
88 | chunk_size = axis_len // splits
89 | remainder = axis_len - chunk_size * splits
90 | s = 0
91 | for i in range(splits):
92 | adjust, remainder = 1 if remainder > 0 else 0, remainder - 1
93 | yield torch.narrow(arr, dim, s, chunk_size + adjust)
94 | s += chunk_size + adjust
95 |
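# Example: a length-10 axis split into 4 chunks yields sizes 3, 3, 2, 2,
# with the remainder spread over the leading chunks:
#
#   >>> [t.shape[0] for t in fast_split(torch.randn(10, 2), 4)]
#   [3, 3, 2, 2]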
96 | def fourier_encode(x, num_encodings = 4, include_self = True, flatten = True):
97 | x = x.unsqueeze(-1)
98 | device, dtype, orig_x = x.device, x.dtype, x
99 | scales = 2 ** torch.arange(num_encodings, device = device, dtype = dtype)
100 | x = x / scales
101 | x = torch.cat([x.sin(), x.cos()], dim=-1)
102 | x = torch.cat((x, orig_x), dim = -1) if include_self else x
103 | x = rearrange(x, 'b m n ... -> b m n (...)') if flatten else x
104 | return x
105 |
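# Example: with num_encodings = 4 and include_self = True (the default), each
# scalar becomes 4 sines + 4 cosines + the original value = 9 features:
#
#   >>> fourier_encode(torch.randn(1, 8, 8), num_encodings = 4).shape
#   torch.Size([1, 8, 8, 9])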
106 | # default dtype context manager
107 |
108 | @contextlib.contextmanager
109 | def torch_default_dtype(dtype):
110 | prev_dtype = torch.get_default_dtype()
111 | torch.set_default_dtype(dtype)
112 | yield
113 | torch.set_default_dtype(prev_dtype)
114 |
115 | def cast_torch_tensor(fn):
116 | @wraps(fn)
117 | def inner(t):
118 | if not torch.is_tensor(t):
119 | t = torch.tensor(t, dtype = torch.get_default_dtype())
120 | return fn(t)
121 | return inner
122 |
123 | # benchmark tool
124 |
125 | def benchmark(fn):
126 | def inner(*args, **kwargs):
127 | start = time.time()
128 | res = fn(*args, **kwargs)
129 | diff = time.time() - start
130 | return diff, res
131 | return inner
132 |
133 | # caching functions
134 |
135 | def cache(cache, key_fn):
136 | def cache_inner(fn):
137 | @wraps(fn)
138 | def inner(*args, **kwargs):
139 | key_name = key_fn(*args, **kwargs)
140 | if key_name in cache:
141 | return cache[key_name]
142 | res = fn(*args, **kwargs)
143 | cache[key_name] = res
144 | return res
145 |
146 | return inner
147 | return cache_inner
148 |
149 | # cache in directory
150 |
151 | def cache_dir(dirname, maxsize=128):
152 | '''
153 | Cache a function with a directory
154 |
155 | :param dirname: the directory path
156 | :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache)
157 | '''
158 | def decorator(func):
159 |
160 | @lru_cache(maxsize=maxsize)
161 | @wraps(func)
162 | def wrapper(*args, **kwargs):
163 | if not exists(dirname):
164 | return func(*args, **kwargs)
165 |
166 | os.makedirs(dirname, exist_ok = True)
167 |
168 | indexfile = os.path.join(dirname, "index.pkl")
169 | lock = FileLock(os.path.join(dirname, "mutex"))
170 |
171 | with lock:
172 | index = {}
173 | if os.path.exists(indexfile):
174 | with open(indexfile, "rb") as file:
175 | index = pickle.load(file)
176 |
177 | key = (args, frozenset(kwargs), func.__defaults__)
178 |
179 | if key in index:
180 | filename = index[key]
181 | else:
182 | index[key] = filename = f"{len(index)}.pkl.gz"
183 | with open(indexfile, "wb") as file:
184 | pickle.dump(index, file)
185 |
186 | filepath = os.path.join(dirname, filename)
187 |
188 | if os.path.exists(filepath):
189 | with lock:
190 | with gzip.open(filepath, "rb") as file:
191 | result = pickle.load(file)
192 | return result
193 |
194 | print(f"compute {filename}... ", end="", flush = True)
195 | result = func(*args, **kwargs)
196 | print(f"save {filename}... ", end="", flush = True)
197 |
198 | with lock:
199 | with gzip.open(filepath, "wb") as file:
200 | pickle.dump(result, file)
201 |
202 | print("done")
203 |
204 | return result
205 | return wrapper
206 | return decorator
207 |
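# Hypothetical usage sketch (the directory name below is an assumption, not
# part of this package): memoize an expensive function both in RAM and on disk
#
#   @cache_dir('.cache/basis')
#   def expensive_computation(arg):
#       ...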
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 | [tool:pytest]
5 | addopts = --verbose
6 | python_files = tests/*.py
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'se3-transformer-pytorch',
5 | packages = find_packages(),
6 | include_package_data = True,
7 | version = '0.9.0',
8 | license='MIT',
9 | description = 'SE3 Transformer - Pytorch',
10 | author = 'Phil Wang',
11 | author_email = 'lucidrains@gmail.com',
12 | url = 'https://github.com/lucidrains/se3-transformer-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'attention mechanism',
16 | 'transformers',
17 | 'equivariance',
18 | 'SE3'
19 | ],
20 | install_requires=[
21 | 'einops>=0.3',
22 | 'filelock',
23 | 'numpy',
24 | 'torch>=1.6'
25 | ],
26 | setup_requires=[
27 | 'pytest-runner',
28 | ],
29 | tests_require=[
30 | 'pytest',
31 | 'lie_learn',
32 | 'numpy',
33 | ],
34 | classifiers=[
35 | 'Development Status :: 4 - Beta',
36 | 'Intended Audience :: Developers',
37 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
38 | 'License :: OSI Approved :: MIT License',
39 | 'Programming Language :: Python :: 3.6',
40 | ],
41 | )
42 |
--------------------------------------------------------------------------------
/tests/test_basis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from se3_transformer_pytorch.basis import get_basis, get_R_tensor, basis_transformation_Q_J
3 | from se3_transformer_pytorch.irr_repr import irr_repr
4 |
5 | def test_basis():
6 | max_degree = 3
7 | x = torch.randn(2, 1024, 3)
8 | basis = get_basis(x, max_degree)
9 | assert len(basis.keys()) == (max_degree + 1) ** 2, 'correct number of basis kernels'
10 |
11 | def test_basis_transformation_Q_J():
12 | rand_angles = torch.rand(4, 3)
13 | J, order_out, order_in = 1, 1, 1
14 | Q_J = basis_transformation_Q_J(J, order_in, order_out).float()
15 | assert all(torch.allclose(get_R_tensor(order_out, order_in, a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in rand_angles)
16 |
--------------------------------------------------------------------------------
/tests/test_equivariance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
3 | from se3_transformer_pytorch.irr_repr import rot
4 | from se3_transformer_pytorch.utils import torch_default_dtype, fourier_encode
5 |
6 | def test_transformer():
7 | model = SE3Transformer(
8 | dim = 64,
9 | depth = 1,
10 | num_degrees = 2,
11 | num_neighbors = 4,
12 | valid_radius = 10
13 | )
14 |
15 | feats = torch.randn(1, 32, 64)
16 | coors = torch.randn(1, 32, 3)
17 | mask = torch.ones(1, 32).bool()
18 |
19 | out = model(feats, coors, mask, return_type = 0)
20 | assert out.shape == (1, 32, 64), 'output must be of the right shape'
21 |
22 | def test_causal_se3_transformer():
23 | model = SE3Transformer(
24 | dim = 64,
25 | depth = 1,
26 | num_degrees = 2,
27 | num_neighbors = 4,
28 | valid_radius = 10,
29 | causal = True
30 | )
31 |
32 | feats = torch.randn(1, 32, 64)
33 | coors = torch.randn(1, 32, 3)
34 | mask = torch.ones(1, 32).bool()
35 |
36 | out = model(feats, coors, mask, return_type = 0)
37 | assert out.shape == (1, 32, 64), 'output must be of the right shape'
38 |
39 | def test_se3_transformer_with_global_nodes():
40 | model = SE3Transformer(
41 | dim = 64,
42 | depth = 1,
43 | num_degrees = 2,
44 | num_neighbors = 4,
45 | valid_radius = 10,
46 | global_feats_dim = 16
47 | )
48 |
49 | feats = torch.randn(1, 32, 64)
50 | coors = torch.randn(1, 32, 3)
51 | mask = torch.ones(1, 32).bool()
52 |
53 | global_feats = torch.randn(1, 2, 16)
54 |
55 | out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)
56 | assert out.shape == (1, 32, 64), 'output must be of the right shape'
57 |
58 | def test_one_headed_key_values_se3_transformer_with_global_nodes():
59 | model = SE3Transformer(
60 | dim = 64,
61 | depth = 1,
62 | num_degrees = 2,
63 | num_neighbors = 4,
64 | valid_radius = 10,
65 | global_feats_dim = 16,
66 | one_headed_key_values = True
67 | )
68 |
69 | feats = torch.randn(1, 32, 64)
70 | coors = torch.randn(1, 32, 3)
71 | mask = torch.ones(1, 32).bool()
72 |
73 | global_feats = torch.randn(1, 2, 16)
74 |
75 | out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)
76 | assert out.shape == (1, 32, 64), 'output must be of the right shape'
77 |
78 | def test_transformer_with_edges():
79 | model = SE3Transformer(
80 | dim = 64,
81 | depth = 1,
82 | num_degrees = 2,
83 | num_neighbors = 4,
84 | edge_dim = 4,
85 | num_edge_tokens = 4
86 | )
87 |
88 | feats = torch.randn(1, 32, 64)
89 | edges = torch.randint(0, 4, (1, 32))
90 | coors = torch.randn(1, 32, 3)
91 | mask = torch.ones(1, 32).bool()
92 |
93 | out = model(feats, coors, mask, edges = edges, return_type = 0)
94 | assert out.shape == (1, 32, 64), 'output must be of the right shape'
95 |
96 | def test_transformer_with_continuous_edges():
97 | model = SE3Transformer(
98 | dim = 64,
99 | depth = 1,
100 | attend_self = True,
101 | num_degrees = 2,
102 | output_degrees = 2,
103 | edge_dim = 34
104 | )
105 |
106 | feats = torch.randn(1, 32, 64)
107 | coors = torch.randn(1, 32, 3)
108 | mask = torch.ones(1, 32).bool()
109 |
110 | pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2))
111 |
112 | edges = fourier_encode(
113 | pairwise_continuous_values,
114 | num_encodings = 8,
115 | include_self = True
116 | )
117 |
118 | out = model(feats, coors, mask, edges = edges, return_type = 1)
119 |     assert out.shape[-1] == 3, 'type 1 outputs must be 3d vectors'
120 |
121 | def test_different_input_dimensions_for_types():
122 | model = SE3Transformer(
123 | dim_in = (4, 2),
124 | dim = 4,
125 | depth = 1,
126 | input_degrees = 2,
127 | num_degrees = 2,
128 | output_degrees = 2,
129 | reduce_dim_out = True
130 | )
131 |
132 | atom_feats = torch.randn(2, 32, 4, 1)
133 | coors_feats = torch.randn(2, 32, 2, 3)
134 |
135 | features = {'0': atom_feats, '1': coors_feats}
136 | coors = torch.randn(2, 32, 3)
137 | mask = torch.ones(2, 32).bool()
138 |
139 | refined_coors = coors + model(features, coors, mask, return_type = 1)
140 |     assert refined_coors.shape == coors.shape, 'refined coordinates must keep the coordinate shape'
141 |
142 | def test_equivariance():
143 | model = SE3Transformer(
144 | dim = 64,
145 | depth = 1,
146 | attend_self = True,
147 | num_neighbors = 4,
148 | num_degrees = 2,
149 | output_degrees = 2,
150 | fourier_encode_dist = True
151 | )
152 |
153 | feats = torch.randn(1, 32, 64)
154 | coors = torch.randn(1, 32, 3)
155 | mask = torch.ones(1, 32).bool()
156 |
157 | R = rot(15, 0, 45)
158 | out1 = model(feats, coors @ R, mask, return_type = 1)
159 | out2 = model(feats, coors, mask, return_type = 1) @ R
160 |
161 |     diff = (out1 - out2).abs().max()
162 | assert diff < 1e-4, 'is not equivariant'
163 |
164 | def test_equivariance_with_egnn_backbone():
165 | model = SE3Transformer(
166 | dim = 64,
167 | depth = 1,
168 | attend_self = True,
169 | num_neighbors = 4,
170 | num_degrees = 2,
171 | output_degrees = 2,
172 | fourier_encode_dist = True,
173 | use_egnn = True
174 | )
175 |
176 | feats = torch.randn(1, 32, 64)
177 | coors = torch.randn(1, 32, 3)
178 | mask = torch.ones(1, 32).bool()
179 |
180 | R = rot(15, 0, 45)
181 | out1 = model(feats, coors @ R, mask, return_type = 1)
182 | out2 = model(feats, coors, mask, return_type = 1) @ R
183 |
184 |     diff = (out1 - out2).abs().max()
185 | assert diff < 1e-4, 'is not equivariant'
186 |
187 | def test_rotary():
188 | model = SE3Transformer(
189 | dim = 64,
190 | depth = 1,
191 | attend_self = True,
192 | num_neighbors = 4,
193 | num_degrees = 2,
194 | output_degrees = 2,
195 | fourier_encode_dist = True,
196 | rotary_position = True,
197 | rotary_rel_dist = True
198 | )
199 |
200 | feats = torch.randn(1, 32, 64)
201 | coors = torch.randn(1, 32, 3)
202 | mask = torch.ones(1, 32).bool()
203 |
204 | R = rot(15, 0, 45)
205 | out1 = model(feats, coors @ R, mask, return_type = 1)
206 | out2 = model(feats, coors, mask, return_type = 1) @ R
207 |
208 |     diff = (out1 - out2).abs().max()
209 | assert diff < 1e-4, 'is not equivariant'
210 |
211 | def test_equivariance_linear_proj_keys():
212 | model = SE3Transformer(
213 | dim = 64,
214 | depth = 1,
215 | attend_self = True,
216 | num_neighbors = 4,
217 | num_degrees = 2,
218 | output_degrees = 2,
219 | fourier_encode_dist = True,
220 | linear_proj_keys = True
221 | )
222 |
223 | feats = torch.randn(1, 32, 64)
224 | coors = torch.randn(1, 32, 3)
225 | mask = torch.ones(1, 32).bool()
226 |
227 | R = rot(15, 0, 45)
228 | out1 = model(feats, coors @ R, mask, return_type = 1)
229 | out2 = model(feats, coors, mask, return_type = 1) @ R
230 |
231 |     diff = (out1 - out2).abs().max()
232 | assert diff < 1e-4, 'is not equivariant'
233 |
234 | @torch_default_dtype(torch.float64)
235 | def test_equivariance_only_sparse_neighbors():
236 | model = SE3Transformer(
237 | dim = 64,
238 | depth = 1,
239 | attend_self = True,
240 | num_degrees = 2,
241 | output_degrees = 2,
242 | num_neighbors = 0,
243 | attend_sparse_neighbors = True,
244 | num_adj_degrees = 2,
245 | adj_dim = 4
246 | )
247 |
248 | feats = torch.randn(1, 32, 64)
249 | coors = torch.randn(1, 32, 3)
250 | mask = torch.ones(1, 32).bool()
251 |
252 | seq = torch.arange(32)
253 | adj_mat = (seq[:, None] >= (seq[None, :] - 1)) & (seq[:, None] <= (seq[None, :] + 1))
254 |
255 | R = rot(15, 0, 45)
256 | out1 = model(feats, coors @ R, mask, adj_mat = adj_mat, return_type = 1)
257 | out2 = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1) @ R
258 |
259 |     diff = (out1 - out2).abs().max()
260 | assert diff < 1e-4, 'is not equivariant'
261 |
262 | def test_equivariance_with_reversible_network():
263 | model = SE3Transformer(
264 | dim = 64,
265 | depth = 1,
266 | attend_self = True,
267 | num_neighbors = 4,
268 | num_degrees = 2,
269 | output_degrees = 2,
270 | reversible = True
271 | )
272 |
273 | feats = torch.randn(1, 32, 64)
274 | coors = torch.randn(1, 32, 3)
275 | mask = torch.ones(1, 32).bool()
276 |
277 | R = rot(15, 0, 45)
278 | out1 = model(feats, coors @ R, mask, return_type = 1)
279 | out2 = model(feats, coors, mask, return_type = 1) @ R
280 |
281 |     diff = (out1 - out2).abs().max()
282 | assert diff < 1e-4, 'is not equivariant'
283 |
284 | def test_equivariance_with_type_one_input():
285 | model = SE3Transformer(
286 | dim = 64,
287 | depth = 1,
288 | attend_self = True,
289 | num_neighbors = 4,
290 | num_degrees = 2,
291 | input_degrees = 2,
292 | output_degrees = 2
293 | )
294 |
295 | atom_features = torch.randn(1, 32, 64, 1)
296 | pred_coors = torch.randn(1, 32, 64, 3)
297 |
298 | coors = torch.randn(1, 32, 3)
299 | mask = torch.ones(1, 32).bool()
300 |
301 | R = rot(15, 0, 45)
302 | out1 = model({'0': atom_features, '1': pred_coors @ R}, coors @ R, mask, return_type = 1)
303 | out2 = model({'0': atom_features, '1': pred_coors}, coors, mask, return_type = 1) @ R
304 |
305 |     diff = (out1 - out2).abs().max()
306 | assert diff < 1e-4, 'is not equivariant'
307 |
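308 | # All of the equivariance tests above exercise the same property: for a
309 | # rotation matrix R (here built from Euler angles via rot(15, 0, 45)),
310 | # rotating the input coordinates before the model must agree with rotating
311 | # the model's type-1 output afterwards, i.e.
312 | #
313 | #   model(feats, coors @ R)  ≈  model(feats, coors) @ R
314 | #
315 | # up to a small numerical tolerance, checked via the max absolute difference.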
--------------------------------------------------------------------------------
/tests/test_irrep_repr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache
3 | from se3_transformer_pytorch.irr_repr import spherical_harmonics, irr_repr, compose
4 | from se3_transformer_pytorch.utils import torch_default_dtype
5 |
6 | @torch_default_dtype(torch.float64)
7 | def test_irr_repr():
8 | """
9 |     This test checks that
10 | - irr_repr
11 | - compose
12 | - spherical_harmonics
13 | are compatible
14 |
15 | Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
16 | with x = Z(a) Y(b) eta
17 | """
18 | for order in range(7):
19 | a, b = torch.rand(2)
20 | alpha, beta, gamma = torch.rand(3)
21 |
22 | ra, rb, _ = compose(alpha, beta, gamma, a, b, 0)
23 | Yrx = spherical_harmonics(order, ra, rb)
24 | clear_spherical_harmonics_cache()
25 |
26 | Y = spherical_harmonics(order, a, b)
27 | clear_spherical_harmonics_cache()
28 |
29 | DrY = irr_repr(order, alpha, beta, gamma) @ Y
30 |
31 | d, r = (Yrx - DrY).abs().max(), Y.abs().max()
32 | print(d.item(), r.item())
33 | assert d < 1e-10 * r, d / r
34 |
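35 | # Here compose(alpha, beta, gamma, a, b, 0) composes the two rotations given
36 | # by Euler angles and returns the Euler angles of the product, so (ra, rb) are
37 | # the spherical angles of the rotated point; evaluating Y there must match
38 | # applying the Wigner-D matrix D(alpha, beta, gamma) to Y at the original point.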
--------------------------------------------------------------------------------
/tests/test_spherical_harmonics.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import numpy as np
4 |
5 | from lie_learn.representations.SO3.spherical_harmonics import sh
6 |
7 | from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics_element
8 | from se3_transformer_pytorch.utils import benchmark
9 |
10 | def test_spherical_harmonics():
11 | dtype = torch.float64
12 |
13 | theta = 0.1 * torch.randn(32, 1024, 10, dtype=dtype)
14 | phi = 0.1 * torch.randn(32, 1024, 10, dtype=dtype)
15 |
16 | s0 = s1 = 0
17 | max_error = -1.
18 |
19 | for l in range(8):
20 | for m in range(-l, l + 1):
21 | start = time.time()
22 |
23 | diff, y = benchmark(get_spherical_harmonics_element)(l, m, theta, phi)
24 | y = y.type(torch.float32)
25 | s0 += diff
26 |
27 | diff, z = benchmark(sh)(l, m, theta, phi)
28 | s1 += diff
29 |
30 | error = np.mean(np.abs((y.cpu().numpy() - z) / z))
31 | max_error = max(max_error, error)
32 | print(f"l: {l}, m: {m} ", error)
33 |
34 | time_diff_ratio = s0 / s1
35 |
36 |     print(f"Max error: {max_error}")
37 |     print(f"Time diff ratio: {time_diff_ratio}")
38 |
39 |     assert max_error < 1e-4, 'maximum error must be less than 1e-4'
40 |     assert time_diff_ratio < 1., 'our spherical harmonics must be faster than the implementation from lie_learn'
41 |
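42 | # Note: benchmark(fn) is assumed to wrap fn and return (elapsed_seconds,
43 | # result), so s0 and s1 accumulate total runtime across all (l, m) pairs and
44 | # time_diff_ratio < 1 means this implementation beats lie_learn's `sh`.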
--------------------------------------------------------------------------------