├── .github
│   └── workflows
│       ├── python-package.yml
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── denoise.py
├── diagram.png
├── se3_transformer_pytorch
│   ├── __init__.py
│   ├── basis.py
│   ├── data
│   │   ├── J_dense.npy
│   │   └── J_dense.pt
│   ├── irr_repr.py
│   ├── reversible.py
│   ├── rotary.py
│   ├── se3_transformer_pytorch.py
│   ├── spherical_harmonics.py
│   └── utils.py
├── setup.cfg
├── setup.py
└── tests
    ├── test_basis.py
    ├── test_equivariance.py
    ├── test_irrep_repr.py
    └── test_spherical_harmonics.py

--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 3 | 
 4 | name: Python package
 5 | 
 6 | on:
 7 |   push:
 8 |     branches: [ main ]
 9 |   pull_request:
10 |     branches: [ main ]
11 | 
12 | jobs:
13 |   build:
14 | 
15 |     runs-on: ubuntu-latest
16 |     strategy:
17 |       matrix:
18 |         python-version: [3.8, 3.9]
19 | 
20 |     steps:
21 |     - uses: actions/checkout@v2
22 |     - name: Set up Python ${{ matrix.python-version }}
23 |       uses: actions/setup-python@v2
24 |       with:
25 |         python-version: ${{ matrix.python-version }}
26 |     - name: Install dependencies
27 |       run: |
28 |         python -m pip install --upgrade pip
29 |         python setup.py install
30 |         python -m pip install pytest
31 |         python -m pip install torch==1.10.0
32 |     - name: Test with pytest
33 |       run: |
34 |         python setup.py pytest
35 | 
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
 1 | # This workflow will upload a Python Package using Twine when a release is created
 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
 3 | 
 4 | name: Upload Python Package
 5 | 
 6 | on:
 7 |   release:
 8 |     types: [created]
 9 | 
10 | jobs:
11 |   deploy:
12 | 
13 |     runs-on: ubuntu-latest
14 | 
15 |     steps:
16 |     - uses: actions/checkout@v2
17 |     - name: Set up Python
18 |       uses: actions/setup-python@v2
19 |       with:
20 |         python-version: '3.x'
21 |     - name: Install dependencies
22 |       run: |
23 |         python -m pip install --upgrade pip
24 |         pip install setuptools wheel twine
25 |     - name: Build and publish
26 |       env:
27 |         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 |         TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 |       run: |
30 |         python setup.py sdist bdist_wheel
31 |         twine upload dist/*
32 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # custom dont-upload files
 7 | custom_tests/*
 8 | 
 9 | # C extensions
10 | *.so
11 | 
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 | 
33 | # PyInstaller
34 | #  Usually these files are written by a python script from a template
35 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include se3_transformer_pytorch/data/J_dense.pt
2 | include se3_transformer_pytorch/data/J_dense.npy
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | ## SE3 Transformer - Pytorch
 4 | 
 5 | Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.
 6 | 
 7 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ICW0DpXfUuVYsnNkt1DHwUyyTduHHvE3?usp=sharing) Example of equivariance
 8 | 
 9 | If you were using any version of SE3 Transformers prior to 0.6.0, please update. A major bug was uncovered by @MattMcPartlon that affects anyone who was not using the sparse adjacency neighbor settings and was instead relying on the nearest-neighbors functionality.
10 | 
11 | Update: It is recommended that you use Equiformer instead.
12 | 
13 | ## Install
14 | 
15 | ```bash
16 | $ pip install se3-transformer-pytorch
17 | ```
18 | 
19 | ## Usage
20 | 
21 | ```python
22 | import torch
23 | from se3_transformer_pytorch import SE3Transformer
24 | 
25 | model = SE3Transformer(
26 |     dim = 512,
27 |     heads = 8,
28 |     depth = 6,
29 |     dim_head = 64,
30 |     num_degrees = 4,
31 |     valid_radius = 10
32 | )
33 | 
34 | feats = torch.randn(1, 1024, 512)
35 | coors = torch.randn(1, 1024, 3)
36 | mask = torch.ones(1, 1024).bool()
37 | 
38 | out = model(feats, coors, mask) # (1, 1024, 512)
39 | ```
40 | 
41 | Potential example usage in Alphafold2, as outlined here:
42 | 
43 | ```python
44 | import torch
45 | from se3_transformer_pytorch import SE3Transformer
46 | 
47 | model = SE3Transformer(
48 |     dim = 64,
49 |     depth = 2,
50 |     input_degrees = 1,
51 |     num_degrees = 2,
52 |     output_degrees = 2,
53 |     reduce_dim_out = True,
54 |     differentiable_coors = True
55 | )
56 | 
57 | atom_feats = torch.randn(2, 32, 64)
58 | coors = torch.randn(2, 32, 3)
59 | mask = torch.ones(2, 32).bool()
60 | 
61 | refined_coors = coors + model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)
62 | ```
63 | 
64 | You can also let the base transformer class take care of embedding the type 0 features being passed in, assuming they are atoms:
65 | 
66 | ```python
67 | import torch
68 | from se3_transformer_pytorch import SE3Transformer
69 | 
70 | model = SE3Transformer(
71 |     num_tokens = 28,       # 28 unique atoms
72 |     dim = 64,
73 |     depth = 2,
74 |     input_degrees = 1,
75 |     num_degrees = 2,
76 |     output_degrees = 2,
77 |     reduce_dim_out = True
78 | )
79 | 
80 | atoms = torch.randint(0, 28, (2, 32))
81 | coors = torch.randn(2, 32, 3)
82 | mask = torch.ones(2, 32).bool()
83 | 
84 | refined_coors = coors + model(atoms, coors, mask, return_type = 1) # (2, 32, 3)
85 | ```
86 | 
87 | If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass them in as follows.
88 | 89 | ```python 90 | import torch 91 | from se3_transformer_pytorch import SE3Transformer 92 | 93 | model = SE3Transformer( 94 | dim = 64, 95 | depth = 2, 96 | input_degrees = 2, 97 | num_degrees = 2, 98 | output_degrees = 2, 99 | reduce_dim_out = True # reduce out the final dimension 100 | ) 101 | 102 | atom_feats = torch.randn(2, 32, 64, 1) # b x n x d x type0 103 | coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1 104 | 105 | # atom features are type 0, predicted coordinates are type 1 106 | features = {'0': atom_feats, '1': coors_feats} 107 | coors = torch.randn(2, 32, 3) 108 | mask = torch.ones(2, 32).bool() 109 | 110 | refined_coors = coors + model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates 111 | ``` 112 | 113 | ## Edges 114 | 115 | To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization. 116 | 117 | ```python 118 | import torch 119 | from se3_transformer_pytorch import SE3Transformer 120 | 121 | model = SE3Transformer( 122 | num_tokens = 28, 123 | dim = 64, 124 | num_edge_tokens = 4, # number of edge type, say 4 bond types 125 | edge_dim = 16, # dimension of edge embedding 126 | depth = 2, 127 | input_degrees = 1, 128 | num_degrees = 3, 129 | output_degrees = 1, 130 | reduce_dim_out = True 131 | ) 132 | 133 | atoms = torch.randint(0, 28, (2, 32)) 134 | bonds = torch.randint(0, 4, (2, 32, 32)) 135 | coors = torch.randn(2, 32, 3) 136 | mask = torch.ones(2, 32).bool() 137 | 138 | pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1) 139 | ``` 140 | 141 | If you would like to pass in continuous values for your edges, you can choose to not set the `num_edge_tokens`, encode your discrete bond types, and then concat it to the fourier features of these continuous values 142 | 143 | ```python 144 | import torch 145 | from se3_transformer_pytorch import SE3Transformer 146 | from se3_transformer_pytorch.utils import fourier_encode 147 | 148 | model = SE3Transformer( 149 | dim = 64, 150 | depth = 1, 151 | attend_self = True, 152 | num_degrees = 2, 153 | output_degrees = 2, 154 | edge_dim = 34 # edge dimension must match the final dimension of the edges being passed in 155 | ) 156 | 157 | feats = torch.randn(1, 32, 64) 158 | coors = torch.randn(1, 32, 3) 159 | mask = torch.ones(1, 32).bool() 160 | 161 | pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2)) # say there are 2 162 | 163 | edges = fourier_encode( 164 | pairwise_continuous_values, 165 | num_encodings = 8, 166 | include_self = True 167 | ) # (1, 32, 32, 34) - {2 * (2 * 8 + 1)} 168 | 169 | out = model(feats, coors, mask, edges = edges, return_type = 1) 170 | ``` 171 | 172 | ## Sparse Neighbors 173 | 174 | If you know the connectivity of your points (say you are working with molecules), you can pass in an adjacency matrix, in the form of a boolean mask (where `True` indicates connectivity). 175 | 176 | ```python 177 | import torch 178 | from se3_transformer_pytorch import SE3Transformer 179 | 180 | model = SE3Transformer( 181 | dim = 32, 182 | heads = 8, 183 | depth = 1, 184 | dim_head = 64, 185 | num_degrees = 2, 186 | valid_radius = 10, 187 | attend_sparse_neighbors = True, # this must be set to true, in which case it will assert that you pass in the adjacency matrix 188 | num_neighbors = 0, # if you set this to 0, it will only consider the connected neighbors as defined by the adjacency matrix. 
but if you set a value greater than 0, it will continue to fetch the closest points up to this many, excluding the ones already specified by the adjacency matrix
189 |     max_sparse_neighbors = 8        # you can cap the number of neighbors, sampled from within your sparse set of neighbors as defined by the adjacency matrix, if specified
190 | )
191 | 
192 | feats = torch.randn(1, 128, 32)
193 | coors = torch.randn(1, 128, 3)
194 | mask = torch.ones(1, 128).bool()
195 | 
196 | # placeholder adjacency matrix
197 | # naively assuming the sequence is one long chain (128, 128)
198 | 
199 | i = torch.arange(128)
200 | adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
201 | 
202 | out = model(feats, coors, mask, adj_mat = adj_mat) # (1, 128, 32)
203 | ```
204 | 
205 | You can also have the network automatically derive the Nth-degree neighbors for you with one extra keyword `num_adj_degrees`. If you would like the system to differentiate between the degrees of the neighbors as edge information, further pass in a non-zero `adj_dim`.
206 | 
207 | ```python
208 | import torch
209 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
210 | 
211 | model = SE3Transformer(
212 |     dim = 64,
213 |     depth = 1,
214 |     attend_self = True,
215 |     num_degrees = 2,
216 |     output_degrees = 2,
217 |     num_neighbors = 0,
218 |     attend_sparse_neighbors = True,
219 |     num_adj_degrees = 2,    # automatically derive 2nd degree neighbors
220 |     adj_dim = 4             # embed 1st and 2nd degree neighbors (as well as null neighbors) with edge embeddings of this dimension
221 | )
222 | 
223 | feats = torch.randn(1, 32, 64)
224 | coors = torch.randn(1, 32, 3)
225 | mask = torch.ones(1, 32).bool()
226 | 
227 | # placeholder adjacency matrix
228 | # naively assuming the sequence is one long chain (32, 32)
229 | 
230 | i = torch.arange(32)
231 | adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
232 | 
233 | out = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1)
234 | ```
235 | 
236 | To have fine control over the dimensionality of each type, you can use the `hidden_fiber_dict` and `out_fiber_dict` keywords to pass in dictionaries with the degrees as keys and the dimensions as values.
237 | 
238 | ```python
239 | import torch
240 | from se3_transformer_pytorch import SE3Transformer
241 | 
242 | model = SE3Transformer(
243 |     num_tokens = 28,
244 |     dim = 64,
245 |     num_edge_tokens = 4,
246 |     edge_dim = 16,
247 |     depth = 2,
248 |     input_degrees = 1,
249 |     num_degrees = 3,
250 |     output_degrees = 1,
251 |     hidden_fiber_dict = {0: 16, 1: 8, 2: 4},
252 |     out_fiber_dict = {0: 16, 1: 1},
253 |     reduce_dim_out = False
254 | )
255 | 
256 | atoms = torch.randint(0, 28, (2, 32))
257 | bonds = torch.randint(0, 4, (2, 32, 32))
258 | coors = torch.randn(2, 32, 3)
259 | mask = torch.ones(2, 32).bool()
260 | 
261 | pred = model(atoms, coors, mask, edges = bonds)
262 | 
263 | pred['0'] # (2, 32, 16)
264 | pred['1'] # (2, 32, 1, 3)
265 | ```
266 | 
267 | ## Neighbors
268 | 
269 | You can further control which nodes can be considered by passing in a neighbor mask. All `False` values will be masked out of consideration.
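One way to derive such a neighbor mask is from pairwise distances. Here is a minimal sketch (the `cutoff` value is a hypothetical choice, not something the library prescribes):

```python
import torch

coors = torch.randn(1, 32, 3)

cutoff = 5.0  # hypothetical distance cutoff

# pairwise euclidean distances between all nodes
dist = torch.cdist(coors, coors)   # (1, 32, 32)

# True = this pair may be attended to, False = masked out of consideration
neighbor_mask = dist < cutoff      # (1, 32, 32) bool
```

A full example with the library follows: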
270 | 271 | ```python 272 | import torch 273 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer 274 | 275 | model = SE3Transformer( 276 | dim = 16, 277 | dim_head = 16, 278 | attend_self = True, 279 | num_degrees = 4, 280 | output_degrees = 2, 281 | num_edge_tokens = 4, 282 | num_neighbors = 8, # make sure you set this value as the maximum number of neighbors set by your neighbor_mask, or it will throw a warning 283 | edge_dim = 2, 284 | depth = 3 285 | ) 286 | 287 | feats = torch.randn(1, 32, 16) 288 | coors = torch.randn(1, 32, 3) 289 | mask = torch.ones(1, 32).bool() 290 | bonds = torch.randint(0, 4, (1, 32, 32)) 291 | 292 | neighbor_mask = torch.ones(1, 32, 32).bool() # set the nodes you wish to be masked out as False 293 | 294 | out = model( 295 | feats, 296 | coors, 297 | mask, 298 | edges = bonds, 299 | neighbor_mask = neighbor_mask, 300 | return_type = 1 301 | ) 302 | ``` 303 | 304 | ## Global Nodes 305 | 306 | This feature allows you to pass in vectors that can be viewed as global nodes that are seen by all other nodes. The idea would be to pool your graph into a few feature vectors, which will be projected to key / values across all the attention layers in the network. All nodes will have full access to global node information, regardless of nearest neighbors or adjacency calculation. 307 | 308 | ```python 309 | import torch 310 | from torch import nn 311 | from se3_transformer_pytorch import SE3Transformer 312 | 313 | model = SE3Transformer( 314 | dim = 64, 315 | depth = 1, 316 | num_degrees = 2, 317 | num_neighbors = 4, 318 | valid_radius = 10, 319 | global_feats_dim = 32 # this must be set to the dimension of the global features, in this example, 32 320 | ) 321 | 322 | feats = torch.randn(1, 32, 64) 323 | coors = torch.randn(1, 32, 3) 324 | mask = torch.ones(1, 32).bool() 325 | 326 | # naively derive global features 327 | # by pooling features and projecting 328 | global_feats = nn.Linear(64, 32)(feats.mean(dim = 1, keepdim = True)) # (1, 1, 32) 329 | 330 | out = model(feats, coors, mask, return_type = 0, global_feats = global_feats) 331 | ``` 332 | 333 | Todo: 334 | 335 | - [ ] allow global nodes to attend to all other nodes, to give the network a global conduit for information. (Similar to BigBird, ETC, Longformer etc) 336 | 337 | ## Autoregressive 338 | 339 | You can use SE3 Transformers autoregressively with just one extra flag 340 | 341 | ```python 342 | import torch 343 | from se3_transformer_pytorch import SE3Transformer 344 | 345 | model = SE3Transformer( 346 | dim = 512, 347 | heads = 8, 348 | depth = 6, 349 | dim_head = 64, 350 | num_degrees = 4, 351 | valid_radius = 10, 352 | causal = True # set this to True 353 | ) 354 | 355 | feats = torch.randn(1, 1024, 512) 356 | coors = torch.randn(1, 1024, 3) 357 | mask = torch.ones(1, 1024).bool() 358 | 359 | out = model(feats, coors, mask) # (1, 1024, 512) 360 | ``` 361 | 362 | ## Experimental Features 363 | 364 | ### Non-pairwise convolved keys 365 | 366 | I've discovered that using linearly projected keys (rather than the pairwise convolution) seems to do ok in a toy denoising task. This leads to 25% memory savings. 
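Under the hood, setting this flag swaps the basis-convolved keys for a plain linear projection. Simplified from `AttentionSE3.__init__` in `se3_transformer_pytorch/se3_transformer_pytorch.py` (some keyword arguments elided for brevity):

```python
# keys via linear projection (cheap) vs. pairwise basis convolution (expensive)
if linear_proj_keys:
    self.to_k = LinearSE3(fiber, hidden_fiber)
elif not tie_key_values:
    self.to_k = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False)
else:
    self.to_k = None  # keys tied to values, see the section further below
```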
You can try this feature by setting `linear_proj_keys = True` 367 | 368 | ```python 369 | import torch 370 | from se3_transformer_pytorch import SE3Transformer 371 | 372 | model = SE3Transformer( 373 | dim = 64, 374 | depth = 1, 375 | num_degrees = 4, 376 | num_neighbors = 8, 377 | valid_radius = 10, 378 | splits = 4, 379 | linear_proj_keys = True # set this to True 380 | ).cuda() 381 | 382 | feats = torch.randn(1, 32, 64).cuda() 383 | coors = torch.randn(1, 32, 3).cuda() 384 | mask = torch.ones(1, 32).bool().cuda() 385 | 386 | out = model(feats, coors, mask, return_type = 0) 387 | ``` 388 | 389 | ### Shared key / values across all heads 390 | 391 | There is a relatively unknown technique for transformers where one can share one key / value head across all the heads of the queries. In my experience in NLP, this usually leads to worse performance, but if you are really in need to tradeoff memory for more depth or higher number of degrees, this may be a good option. 392 | 393 | ```python 394 | import torch 395 | from se3_transformer_pytorch import SE3Transformer 396 | 397 | model = SE3Transformer( 398 | dim = 64, 399 | depth = 8, 400 | num_degrees = 4, 401 | num_neighbors = 8, 402 | valid_radius = 10, 403 | splits = 4, 404 | one_headed_key_values = True # one head of key / values shared across all heads of the queries 405 | ).cuda() 406 | 407 | feats = torch.randn(1, 32, 64).cuda() 408 | coors = torch.randn(1, 32, 3).cuda() 409 | mask = torch.ones(1, 32).bool().cuda() 410 | 411 | out = model(feats, coors, mask, return_type = 0) 412 | ``` 413 | 414 | ### Tied key / values 415 | 416 | You can also tie the key / values (have them be the same), for half memory savings 417 | 418 | ```python 419 | import torch 420 | from se3_transformer_pytorch import SE3Transformer 421 | 422 | model = SE3Transformer( 423 | dim = 64, 424 | depth = 8, 425 | num_degrees = 4, 426 | num_neighbors = 8, 427 | valid_radius = 10, 428 | splits = 4, 429 | tie_key_values = True # set this to True 430 | ).cuda() 431 | 432 | feats = torch.randn(1, 32, 64).cuda() 433 | coors = torch.randn(1, 32, 3).cuda() 434 | mask = torch.ones(1, 32).bool().cuda() 435 | 436 | out = model(feats, coors, mask, return_type = 0) 437 | ``` 438 | 439 | ### Using EGNN 440 | 441 | This is an experimental version of EGNN that works for higher types, and greater dimensionality than just 1 (for the coordinates). The class name is still `SE3Transformer` since it reuses some preexisting logic, so just ignore that for now until I clean it up later. 
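For intuition, the core E(n)-equivariant update from the Satorras et al. paper can be written in a few lines. The following is a minimal standalone sketch of that update rule (illustrative only; it is not this library's internal implementation, which generalizes the idea to higher feature types):

```python
import torch
from torch import nn

class EGNNLayer(nn.Module):
    """Minimal E(n)-equivariant layer (Satorras et al. 2021), for intuition only."""
    def __init__(self, dim, m_dim = 16):
        super().__init__()
        self.edge_mlp = nn.Sequential(nn.Linear(2 * dim + 1, m_dim), nn.SiLU())
        self.coor_mlp = nn.Linear(m_dim, 1)
        self.node_mlp = nn.Linear(dim + m_dim, dim)

    def forward(self, feats, coors):
        # feats: (b, n, d), coors: (b, n, 3)
        n = feats.shape[1]
        rel_pos = coors[:, :, None] - coors[:, None, :]           # (b, n, n, 3)
        rel_dist = (rel_pos ** 2).sum(dim = -1, keepdim = True)   # (b, n, n, 1), invariant

        feats_i = feats[:, :, None].expand(-1, -1, n, -1)
        feats_j = feats[:, None, :].expand(-1, n, -1, -1)
        m_ij = self.edge_mlp(torch.cat((feats_i, feats_j, rel_dist), dim = -1))

        # coordinates move along relative position vectors -> equivariant update
        coors = coors + (rel_pos * self.coor_mlp(m_ij)).mean(dim = 2)

        # node features update from aggregated (invariant) messages
        feats = feats + self.node_mlp(torch.cat((feats, m_ij.sum(dim = 2)), dim = -1))
        return feats, coors
```

Through this library's interface, it is used as follows: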
442 | 
443 | ```python
444 | import torch
445 | from se3_transformer_pytorch import SE3Transformer
446 | 
447 | model = SE3Transformer(
448 |     dim = 32,
449 |     num_neighbors = 8,
450 |     num_edge_tokens = 4,
451 |     edge_dim = 4,
452 |     num_degrees = 4,          # number of higher order types - will use basis on a TCN to project to these dimensions
453 |     use_egnn = True,          # set this to true to use EGNN instead of equivariant attention layers
454 |     egnn_hidden_dim = 64,     # egnn hidden dimension
455 |     depth = 4,                # depth of EGNN
456 |     reduce_dim_out = True     # will project the dimension of the higher types to 1
457 | ).cuda()
458 | 
459 | feats = torch.randn(2, 32, 32).cuda()
460 | coors = torch.randn(2, 32, 3).cuda()
461 | bonds = torch.randint(0, 4, (2, 32, 32)).cuda()
462 | mask = torch.ones(2, 32).bool().cuda()
463 | 
464 | refinement = model(feats, coors, mask, edges = bonds, return_type = 1) # (2, 32, 3)
465 | 
466 | coors = coors + refinement # update coors with refinement
467 | ```
468 | 
469 | If you would like to specify individual dimensions for each of the higher types, just pass in `hidden_fiber_dict`, a dictionary mapping each degree to its dimension (`{<degree>: <dim>}`), instead of `num_degrees`
470 | 
471 | ```python
472 | import torch
473 | from se3_transformer_pytorch import SE3Transformer
474 | 
475 | model = SE3Transformer(
476 |     dim = 32,
477 |     num_neighbors = 8,
478 |     hidden_fiber_dict = {0: 32, 1: 16, 2: 8, 3: 4},
479 |     use_egnn = True,
480 |     depth = 4,
481 |     egnn_hidden_dim = 64,
482 |     egnn_weights_clamp_value = 2,
483 |     reduce_dim_out = True
484 | ).cuda()
485 | 
486 | feats = torch.randn(2, 32, 32).cuda()
487 | coors = torch.randn(2, 32, 3).cuda()
488 | mask = torch.ones(2, 32).bool().cuda()
489 | 
490 | refinement = model(feats, coors, mask, return_type = 1) # (2, 32, 3)
491 | 
492 | coors = coors + refinement # update coors with refinement
493 | ```
494 | 
495 | ## Scaling (wip)
496 | 
497 | This section will list ongoing efforts to make SE3 Transformer scale a little better.
498 | 
499 | Firstly, I have added reversible networks. This allows me to add a little more depth before hitting the usual memory roadblocks. Equivariance preservation is demonstrated in the tests.
500 | 
501 | ```python
502 | import torch
503 | from se3_transformer_pytorch import SE3Transformer
504 | 
505 | model = SE3Transformer(
506 |     num_tokens = 20,
507 |     dim = 32,
508 |     dim_head = 32,
509 |     heads = 4,
510 |     depth = 12,             # 12 layers
511 |     input_degrees = 1,
512 |     num_degrees = 3,
513 |     output_degrees = 1,
514 |     reduce_dim_out = True,
515 |     reversible = True       # set reversible to True
516 | ).cuda()
517 | 
518 | atoms = torch.randint(0, 4, (2, 32)).cuda()
519 | coors = torch.randn(2, 32, 3).cuda()
520 | mask = torch.ones(2, 32).bool().cuda()
521 | 
522 | pred = model(atoms, coors, mask = mask, return_type = 0)
523 | 
524 | loss = pred.sum()
525 | loss.backward()
526 | ```
527 | 
528 | ## Examples
529 | 
530 | First install `sidechainnet`
531 | 
532 | ```bash
533 | $ pip install sidechainnet
534 | ```
535 | 
536 | Then run the protein backbone denoising task
537 | 
538 | ```bash
539 | $ python denoise.py
540 | ```
541 | 
542 | ## Caching
543 | 
544 | By default, the basis vectors are cached.
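For reference, the cache location is resolved at import time in `se3_transformer_pytorch/basis.py`, excerpted below (`default` and `exists` are small helpers from `se3_transformer_pytorch.utils`):

```python
import os

CACHE_PATH = default(os.getenv('CACHE_PATH'), os.path.expanduser('~/.cache.equivariant_attention'))
CACHE_PATH = CACHE_PATH if not exists(os.environ.get('CLEAR_CACHE')) else None
```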
However, if there is ever the need to clear the cache, you simply have to set the environmental flag `CLEAR_CACHE` to some value on initiating the script 545 | 546 | ```bash 547 | $ CLEAR_CACHE=1 python train.py 548 | ``` 549 | 550 | Or you can try deleting the cache directory, which should exist at 551 | 552 | ```bash 553 | $ rm -rf ~/.cache.equivariant_attention 554 | ``` 555 | 556 | You can also designate your own directory where you want the caches to be stored, in the case that the default directory may have permission issues 557 | 558 | ```bash 559 | CACHE_PATH=./path/to/my/cache python train.py 560 | ``` 561 | 562 | ## Testing 563 | 564 | ```bash 565 | $ python setup.py pytest 566 | ``` 567 | 568 | ## Credit 569 | 570 | This library is largely a port of Fabian's official repository, but without the DGL library. 571 | 572 | ## Citations 573 | 574 | ```bibtex 575 | @misc{fuchs2020se3transformers, 576 | title = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 577 | author = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling}, 578 | year = {2020}, 579 | eprint = {2006.10503}, 580 | archivePrefix = {arXiv}, 581 | primaryClass = {cs.LG} 582 | } 583 | ``` 584 | 585 | ```bibtex 586 | @misc{satorras2021en, 587 | title = {E(n) Equivariant Graph Neural Networks}, 588 | author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling}, 589 | year = {2021}, 590 | eprint = {2102.09844}, 591 | archivePrefix = {arXiv}, 592 | primaryClass = {cs.LG} 593 | } 594 | ``` 595 | 596 | ```bibtex 597 | @misc{gomez2017reversible, 598 | title = {The Reversible Residual Network: Backpropagation Without Storing Activations}, 599 | author = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse}, 600 | year = {2017}, 601 | eprint = {1707.04585}, 602 | archivePrefix = {arXiv}, 603 | primaryClass = {cs.CV} 604 | } 605 | ``` 606 | 607 | ```bibtex 608 | @misc{shazeer2019fast, 609 | title = {Fast Transformer Decoding: One Write-Head is All You Need}, 610 | author = {Noam Shazeer}, 611 | year = {2019}, 612 | eprint = {1911.02150}, 613 | archivePrefix = {arXiv}, 614 | primaryClass = {cs.NE} 615 | } 616 | ``` 617 | -------------------------------------------------------------------------------- /denoise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.optim import Adam 4 | 5 | from einops import rearrange, repeat 6 | 7 | import sidechainnet as scn 8 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer 9 | 10 | torch.set_default_dtype(torch.float64) 11 | 12 | BATCH_SIZE = 1 13 | GRADIENT_ACCUMULATE_EVERY = 16 14 | 15 | def cycle(loader, len_thres = 500): 16 | while True: 17 | for data in loader: 18 | if data.seqs.shape[1] > len_thres: 19 | continue 20 | yield data 21 | 22 | transformer = SE3Transformer( 23 | num_tokens = 24, 24 | dim = 8, 25 | dim_head = 8, 26 | heads = 2, 27 | depth = 2, 28 | attend_self = True, 29 | input_degrees = 1, 30 | output_degrees = 2, 31 | reduce_dim_out = True, 32 | differentiable_coors = True, 33 | num_neighbors = 0, 34 | attend_sparse_neighbors = True, 35 | num_adj_degrees = 2, 36 | adj_dim = 4, 37 | num_degrees=2, 38 | ) 39 | 40 | data = scn.load( 41 | casp_version = 12, 42 | thinning = 30, 43 | with_pytorch = 'dataloaders', 44 | batch_size = BATCH_SIZE, 45 | dynamic_batching = False 46 | ) 47 | # Add gaussian noise to the coords 48 | # Testing the refinement algorithm 49 | 50 | dl = 
cycle(data['train']) 51 | optim = Adam(transformer.parameters(), lr=1e-4) 52 | transformer = transformer.cuda() 53 | 54 | for _ in range(10000): 55 | for _ in range(GRADIENT_ACCUMULATE_EVERY): 56 | batch = next(dl) 57 | seqs, coords, masks = batch.seqs, batch.crds, batch.msks 58 | 59 | seqs = seqs.cuda().argmax(dim = -1) 60 | coords = coords.cuda().type(torch.float64) 61 | masks = masks.cuda().bool() 62 | 63 | l = seqs.shape[1] 64 | coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14) 65 | 66 | # Keeping only the backbone coordinates 67 | coords = coords[:, :, 0:3, :] 68 | coords = rearrange(coords, 'b l s c -> b (l s) c') 69 | 70 | seq = repeat(seqs, 'b n -> b (n c)', c = 3) 71 | masks = repeat(masks, 'b n -> b (n c)', c = 3) 72 | 73 | noised_coords = coords + torch.randn_like(coords).cuda() 74 | 75 | i = torch.arange(seq.shape[-1], device = seqs.device) 76 | adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1)) 77 | 78 | out = transformer( 79 | seq, 80 | noised_coords, 81 | mask = masks, 82 | adj_mat = adj_mat, 83 | return_type = 1 84 | ) 85 | 86 | denoised_coords = noised_coords + out 87 | 88 | loss = F.mse_loss(denoised_coords[masks], coords[masks]) 89 | (loss / GRADIENT_ACCUMULATE_EVERY).backward() 90 | 91 | print('loss:', loss.item()) 92 | optim.step() 93 | optim.zero_grad() 94 | -------------------------------------------------------------------------------- /diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/se3-transformer-pytorch/e1669ee69345c271b7aa0a3aeeda452d32e26736/diagram.png -------------------------------------------------------------------------------- /se3_transformer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer 2 | -------------------------------------------------------------------------------- /se3_transformer_pytorch/basis.py: -------------------------------------------------------------------------------- 1 | import os 2 | from math import pi 3 | import torch 4 | from torch import einsum 5 | from einops import rearrange 6 | from itertools import product 7 | from contextlib import contextmanager 8 | 9 | from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics 10 | from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, default, to_order 11 | from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache 12 | 13 | # constants 14 | 15 | CACHE_PATH = default(os.getenv('CACHE_PATH'), os.path.expanduser('~/.cache.equivariant_attention')) 16 | CACHE_PATH = CACHE_PATH if not exists(os.environ.get('CLEAR_CACHE')) else None 17 | 18 | # todo (figure ot why this was hard coded in official repo) 19 | 20 | RANDOM_ANGLES = [ 21 | [4.41301023, 5.56684102, 4.59384642], 22 | [4.93325116, 6.12697327, 4.14574096], 23 | [0.53878964, 4.09050444, 5.36539036], 24 | [2.16017393, 3.48835314, 5.55174441], 25 | [2.52385107, 0.2908958, 3.90040975] 26 | ] 27 | 28 | # helpers 29 | 30 | @contextmanager 31 | def null_context(): 32 | yield 33 | 34 | # functions 35 | 36 | def get_matrix_kernel(A, eps = 1e-10): 37 | ''' 38 | Compute an orthonormal basis of the kernel (x_1, x_2, ...) 
39 |     A x_i = 0
40 |     scalar_product(x_i, x_j) = delta_ij
41 | 
42 |     :param A: matrix
43 |     :return: matrix where each row is a basis vector of the kernel of A
44 |     '''
45 |     _u, s, v = torch.svd(A)
46 |     kernel = v.t()[s < eps]
47 |     return kernel
48 | 
49 | 
50 | def get_matrices_kernel(As, eps = 1e-10):
51 |     '''
52 |     Computes the common kernel of all the As matrices
53 |     '''
54 |     matrix = torch.cat(As, dim=0)
55 |     return get_matrix_kernel(matrix, eps)
56 | 
57 | def get_spherical_from_cartesian(cartesian, divide_radius_by = 1.0):
58 |     """
59 |     # ON ANGLE CONVENTION
60 |     #
61 |     # sh has following convention for angles:
62 |     # :param theta: the colatitude / polar angle, ranging from 0 (North Pole, (X, Y, Z) = (0, 0, 1)) to pi (South Pole, (X, Y, Z) = (0, 0, -1)).
63 |     # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
64 |     #
65 |     # the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
66 |     # beta = pi - theta; ranging from 0 (South Pole, (X, Y, Z) = (0, 0, -1)) to pi (North Pole, (X, Y, Z) = (0, 0, 1)).
67 |     # alpha = phi
68 |     #
69 |     """
70 |     # initialise return array
71 |     spherical = torch.zeros_like(cartesian)
72 | 
73 |     # indices for return array
74 |     ind_radius, ind_alpha, ind_beta = 0, 1, 2
75 | 
76 |     cartesian_x, cartesian_y, cartesian_z = 2, 0, 1
77 | 
78 |     # get projected radius in xy plane
79 |     r_xy = cartesian[..., cartesian_x] ** 2 + cartesian[..., cartesian_y] ** 2
80 | 
81 |     # get second angle
82 |     # version 'elevation angle defined from Z-axis down'
83 |     spherical[..., ind_beta] = torch.atan2(torch.sqrt(r_xy), cartesian[..., cartesian_z])
84 | 
85 |     # get angle in x-y plane
86 |     spherical[..., ind_alpha] = torch.atan2(cartesian[..., cartesian_y], cartesian[..., cartesian_x])
87 | 
88 |     # get overall radius
89 |     radius = torch.sqrt(r_xy + cartesian[..., cartesian_z] ** 2)
90 | 
91 |     if divide_radius_by != 1.0:
92 |         radius /= divide_radius_by
93 | 
94 |     spherical[..., ind_radius] = radius
95 |     return spherical
96 | 
97 | def kron(a, b):
98 |     """
99 |     A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk
100 | 
101 |     Kronecker product of matrices a and b with leading batch dimensions.
102 |     Batch dimensions are broadcast; the number of batch dimensions must match.
103 |     :type a: torch.Tensor
104 |     :type b: torch.Tensor
105 |     :rtype: torch.Tensor
106 |     """
107 |     res = einsum('... i j, ... k l -> ... i k j l', a, b)
108 |     return rearrange(res, '... i j k l -> 
(i j) (k l)')
109 | 
110 | def get_R_tensor(order_out, order_in, a, b, c):
111 |     return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))
112 | 
113 | def sylvester_submatrix(order_out, order_in, J, a, b, c):
114 |     ''' generate Kronecker product matrix for solving the Sylvester equation in subspace J '''
115 |     R_tensor = get_R_tensor(order_out, order_in, a, b, c)  # [m_out * m_in, m_out * m_in]
116 |     R_irrep_J = irr_repr(J, a, b, c)  # [m, m]
117 | 
118 |     R_tensor_identity = torch.eye(R_tensor.shape[0])
119 |     R_irrep_J_identity = torch.eye(R_irrep_J.shape[0])
120 | 
121 |     return kron(R_tensor, R_irrep_J_identity) - kron(R_tensor_identity, R_irrep_J.t())  # [(m_out * m_in) * m, (m_out * m_in) * m]
122 | 
123 | @cache_dir(CACHE_PATH)
124 | @torch_default_dtype(torch.float64)
125 | @torch.no_grad()
126 | def basis_transformation_Q_J(J, order_in, order_out, random_angles = RANDOM_ANGLES):
127 |     """
128 |     :param J: order of the spherical harmonics
129 |     :param order_in: order of the input representation
130 |     :param order_out: order of the output representation
131 |     :return: one part of the Q^-1 matrix of the article
132 |     """
133 |     sylvester_submatrices = [sylvester_submatrix(order_out, order_in, J, a, b, c) for a, b, c in random_angles]
134 |     null_space = get_matrices_kernel(sylvester_submatrices)
135 |     assert null_space.size(0) == 1, null_space.size()  # unique subspace solution
136 |     Q_J = null_space[0]  # [(m_out * m_in) * m]
137 |     Q_J = Q_J.view(to_order(order_out) * to_order(order_in), to_order(J))  # [m_out * m_in, m]
138 |     return Q_J.float()  # [m_out * m_in, m]
139 | 
140 | def precompute_sh(r_ij, max_J):
141 |     """
142 |     pre-compute spherical harmonics up to order max_J
143 | 
144 |     :param r_ij: relative positions
145 |     :param max_J: maximum order used in entire network
146 |     :return: dict where each entry has shape [B, N, K, 2J+1]
147 |     """
148 |     i_alpha, i_beta = 1, 2
149 |     Y_Js = {J: spherical_harmonics(J, r_ij[..., i_alpha], r_ij[..., i_beta]) for J in range(max_J + 1)}
150 |     clear_spherical_harmonics_cache()
151 |     return Y_Js
152 | 
153 | def get_basis(r_ij, max_degree, differentiable = False):
154 |     """Return equivariant weight basis (basis)
155 | 
156 |     Call this function *once* at the start of each forward pass of the model.
157 |     It computes the equivariant weight basis, W_J^lk(x), and internodal
158 |     distances, needed to compute varphi_J^lk(x), of eqn 8 of
159 |     https://arxiv.org/pdf/2006.10503.pdf. The return values of this function
160 |     can be shared as input across all SE(3)-Transformer layers in a model.
161 | 
162 |     Args:
163 |         r_ij: relative positional vectors
164 |         max_degree: non-negative int for degree of highest feature-type
165 |         differentiable: whether r_ij should receive gradients from basis
166 |     Returns:
167 |         dict of equivariant bases, keys are in the form '<d_in>,<d_out>'
168 |     """
169 | 
170 |     # Relative positional encodings (vector)
171 |     context = null_context if not differentiable else torch.no_grad
172 | 
173 |     device, dtype = r_ij.device, r_ij.dtype
174 | 
175 |     with context():
176 |         r_ij = get_spherical_from_cartesian(r_ij)
177 | 
178 |         # Spherical harmonic basis
179 |         Y = precompute_sh(r_ij, 2 * max_degree)
180 | 
181 |         # Equivariant basis (dict['<d_in>,<d_out>'])
182 | 
183 |         basis = {}
184 |         for d_in, d_out in product(range(max_degree + 1), range(max_degree + 1)):
185 |             K_Js = []
186 |             for J in range(abs(d_in - d_out), d_in + d_out + 1):
187 |                 # Get spherical harmonic projection matrices
188 |                 Q_J = basis_transformation_Q_J(J, d_in, d_out)
189 |                 Q_J = Q_J.type(dtype).to(device)
190 | 
191 |                 # Create kernel from spherical harmonics
192 |                 K_J = torch.matmul(Y[J], Q_J.T)
193 |                 K_Js.append(K_J)
194 | 
195 |             # Reshape so can take linear combinations with a dot product
196 |             K_Js = torch.stack(K_Js, dim = -1)
197 |             size = (*r_ij.shape[:-1], 1, to_order(d_out), 1, to_order(d_in), to_order(min(d_in, d_out)))
198 |             basis[f'{d_in},{d_out}'] = K_Js.view(*size)
199 | 
200 |     # extra detach for safe measure
201 |     if not differentiable:
202 |         for k, v in basis.items():
203 |             basis[k] = v.detach()
204 | 
205 |     return basis
206 | 
--------------------------------------------------------------------------------
/se3_transformer_pytorch/data/J_dense.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/se3-transformer-pytorch/e1669ee69345c271b7aa0a3aeeda452d32e26736/se3_transformer_pytorch/data/J_dense.npy
--------------------------------------------------------------------------------
/se3_transformer_pytorch/data/J_dense.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/se3-transformer-pytorch/e1669ee69345c271b7aa0a3aeeda452d32e26736/se3_transformer_pytorch/data/J_dense.pt
--------------------------------------------------------------------------------
/se3_transformer_pytorch/irr_repr.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import numpy as np
 3 | import torch
 4 | from torch import sin, cos, atan2, acos
 5 | from math import pi
 6 | from pathlib import Path
 7 | from functools import wraps
 8 | 
 9 | from se3_transformer_pytorch.utils import exists, default, cast_torch_tensor, to_order
10 | from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics, clear_spherical_harmonics_cache
11 | 
12 | DATA_PATH = path = Path(os.path.dirname(__file__)) / 'data'
13 | 
14 | try:
15 |     path = DATA_PATH / 'J_dense.pt'
16 |     Jd = torch.load(str(path))
17 | except:
18 |     path = DATA_PATH / 'J_dense.npy'
19 |     Jd_np = np.load(str(path), allow_pickle = True)
20 |     Jd = list(map(torch.from_numpy, Jd_np))
21 | 
22 | def wigner_d_matrix(degree, alpha, beta, gamma, dtype = None, device = None):
23 |     """Create Wigner D matrices for a batch of ZYZ Euler angles, for degree l."""
24 |     J = Jd[degree].type(dtype).to(device)
25 |     order = to_order(degree)
26 |     x_a = z_rot_mat(alpha, degree)
27 |     x_b = z_rot_mat(beta, degree)
28 |     x_c = z_rot_mat(gamma, degree)
29 |     res = x_a @ J @ x_b @ J @ x_c
30 |     return res.view(order, order)
31 | 
32
| def z_rot_mat(angle, l): 33 | device, dtype = angle.device, angle.dtype 34 | order = to_order(l) 35 | m = angle.new_zeros((order, order)) 36 | inds = torch.arange(0, order, 1, dtype=torch.long, device=device) 37 | reversed_inds = torch.arange(2 * l, -1, -1, dtype=torch.long, device=device) 38 | frequencies = torch.arange(l, -l - 1, -1, dtype=dtype, device=device)[None] 39 | 40 | m[inds, reversed_inds] = sin(frequencies * angle[None]) 41 | m[inds, inds] = cos(frequencies * angle[None]) 42 | return m 43 | 44 | def irr_repr(order, alpha, beta, gamma, dtype = None): 45 | """ 46 | irreducible representation of SO3 47 | - compatible with compose and spherical_harmonics 48 | """ 49 | cast_ = cast_torch_tensor(lambda t: t) 50 | dtype = default(dtype, torch.get_default_dtype()) 51 | alpha, beta, gamma = map(cast_, (alpha, beta, gamma)) 52 | return wigner_d_matrix(order, alpha, beta, gamma, dtype = dtype) 53 | 54 | @cast_torch_tensor 55 | def rot_z(gamma): 56 | ''' 57 | Rotation around Z axis 58 | ''' 59 | return torch.tensor([ 60 | [cos(gamma), -sin(gamma), 0], 61 | [sin(gamma), cos(gamma), 0], 62 | [0, 0, 1] 63 | ], dtype=gamma.dtype) 64 | 65 | @cast_torch_tensor 66 | def rot_y(beta): 67 | ''' 68 | Rotation around Y axis 69 | ''' 70 | return torch.tensor([ 71 | [cos(beta), 0, sin(beta)], 72 | [0, 1, 0], 73 | [-sin(beta), 0, cos(beta)] 74 | ], dtype=beta.dtype) 75 | 76 | @cast_torch_tensor 77 | def x_to_alpha_beta(x): 78 | ''' 79 | Convert point (x, y, z) on the sphere into (alpha, beta) 80 | ''' 81 | x = x / torch.norm(x) 82 | beta = acos(x[2]) 83 | alpha = atan2(x[1], x[0]) 84 | return (alpha, beta) 85 | 86 | def rot(alpha, beta, gamma): 87 | ''' 88 | ZYZ Euler angles rotation 89 | ''' 90 | return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma) 91 | 92 | def compose(a1, b1, c1, a2, b2, c2): 93 | """ 94 | (a, b, c) = (a1, b1, c1) composed with (a2, b2, c2) 95 | """ 96 | comp = rot(a1, b1, c1) @ rot(a2, b2, c2) 97 | xyz = comp @ torch.tensor([0, 0, 1.]) 98 | a, b = x_to_alpha_beta(xyz) 99 | rotz = rot(0, -b, -a) @ comp 100 | c = atan2(rotz[1, 0], rotz[0, 0]) 101 | return a, b, c 102 | 103 | def spherical_harmonics(order, alpha, beta, dtype = None): 104 | return get_spherical_harmonics(order, theta = (pi - beta), phi = alpha) 105 | -------------------------------------------------------------------------------- /se3_transformer_pytorch/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd.function import Function 4 | from torch.utils.checkpoint import get_device_states, set_device_states 5 | 6 | # helpers 7 | 8 | def map_values(fn, x): 9 | out = {} 10 | for (k, v) in x.items(): 11 | out[k] = fn(v) 12 | return out 13 | 14 | def dict_chunk(x, chunks, dim): 15 | out1 = {} 16 | out2 = {} 17 | for (k, v) in x.items(): 18 | c1, c2 = v.chunk(chunks, dim = dim) 19 | out1[k] = c1 20 | out2[k] = c2 21 | return out1, out2 22 | 23 | def dict_sum(x, y): 24 | out = {} 25 | for k in x.keys(): 26 | out[k] = x[k] + y[k] 27 | return out 28 | 29 | def dict_subtract(x, y): 30 | out = {} 31 | for k in x.keys(): 32 | out[k] = x[k] - y[k] 33 | return out 34 | 35 | def dict_cat(x, y, dim): 36 | out = {} 37 | for k, v1 in x.items(): 38 | v2 = y[k] 39 | out[k] = torch.cat((v1, v2), dim = dim) 40 | return out 41 | 42 | def dict_set_(x, key, value): 43 | for k, v in x.items(): 44 | setattr(v, key, value) 45 | 46 | def dict_backwards_(outputs, grad_tensors): 47 | for k, v in outputs.items(): 48 | torch.autograd.backward(v, 
grad_tensors[k], retain_graph = True) 49 | 50 | def dict_del_(x): 51 | for k, v in x.items(): 52 | del v 53 | del x 54 | 55 | def values(d): 56 | return [v for _, v in d.items()] 57 | 58 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 59 | class Deterministic(nn.Module): 60 | def __init__(self, net): 61 | super().__init__() 62 | self.net = net 63 | self.cpu_state = None 64 | self.cuda_in_fwd = None 65 | self.gpu_devices = None 66 | self.gpu_states = None 67 | 68 | def record_rng(self, *args): 69 | self.cpu_state = torch.get_rng_state() 70 | if torch.cuda._initialized: 71 | self.cuda_in_fwd = True 72 | self.gpu_devices, self.gpu_states = get_device_states(*args) 73 | 74 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 75 | if record_rng: 76 | self.record_rng(*args) 77 | 78 | if not set_rng: 79 | return self.net(*args, **kwargs) 80 | 81 | rng_devices = [] 82 | if self.cuda_in_fwd: 83 | rng_devices = self.gpu_devices 84 | 85 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 86 | torch.set_rng_state(self.cpu_state) 87 | if self.cuda_in_fwd: 88 | set_device_states(self.gpu_devices, self.gpu_states) 89 | return self.net(*args, **kwargs) 90 | 91 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 92 | # once multi-GPU is confirmed working, refactor and send PR back to source 93 | class ReversibleBlock(nn.Module): 94 | def __init__(self, f, g): 95 | super().__init__() 96 | self.f = Deterministic(f) 97 | self.g = Deterministic(g) 98 | 99 | def forward(self, x, **kwargs): 100 | training = self.training 101 | x1, x2 = dict_chunk(x, 2, dim = -1) 102 | y1, y2 = None, None 103 | 104 | with torch.no_grad(): 105 | y1 = dict_sum(x1, self.f(x2, record_rng = training, **kwargs)) 106 | y2 = dict_sum(x2, self.g(y1, record_rng = training)) 107 | 108 | return dict_cat(y1, y2, dim = -1) 109 | 110 | def backward_pass(self, y, dy, **kwargs): 111 | y1, y2 = dict_chunk(y, 2, dim = -1) 112 | dict_del_(y) 113 | 114 | dy1, dy2 = dict_chunk(dy, 2, dim = -1) 115 | dict_del_(dy) 116 | 117 | with torch.enable_grad(): 118 | dict_set_(y1, 'requires_grad', True) 119 | gy1 = self.g(y1, set_rng = True) 120 | dict_backwards_(gy1, dy2) 121 | 122 | with torch.no_grad(): 123 | x2 = dict_subtract(y2, gy1) 124 | dict_del_(y2) 125 | dict_del_(gy1) 126 | 127 | dx1 = dict_sum(dy1, map_values(lambda t: t.grad, y1)) 128 | dict_del_(dy1) 129 | dict_set_(y1, 'grad', None) 130 | 131 | with torch.enable_grad(): 132 | dict_set_(x2, 'requires_grad', True) 133 | fx2 = self.f(x2, set_rng = True, **kwargs) 134 | dict_backwards_(fx2, dx1) 135 | 136 | with torch.no_grad(): 137 | x1 = dict_subtract(y1, fx2) 138 | dict_del_(y1) 139 | dict_del_(fx2) 140 | 141 | dx2 = dict_sum(dy2, map_values(lambda t: t.grad, x2)) 142 | dict_del_(dy2) 143 | dict_set_(x2, 'grad', None) 144 | 145 | x2 = map_values(lambda t: t.detach(), x2) 146 | 147 | x = dict_cat(x1, x2, dim = -1) 148 | dx = dict_cat(dx1, dx2, dim = -1) 149 | 150 | return x, dx 151 | 152 | class _ReversibleFunction(Function): 153 | @staticmethod 154 | def forward(ctx, x, blocks, kwargs): 155 | input_keys = kwargs.pop('input_keys') 156 | split_dims = kwargs.pop('split_dims') 157 | input_values = x.split(split_dims, dim = -1) 158 | x = dict(zip(input_keys, input_values)) 159 | 160 | ctx.kwargs = kwargs 161 | ctx.split_dims = split_dims 162 | ctx.input_keys = input_keys 163 | 164 | for block in blocks: 165 | x = block(x, **kwargs) 166 | 167 | 
ctx.y = map_values(lambda t: t.detach(), x) 168 | ctx.blocks = blocks 169 | 170 | x = torch.cat(values(x), dim = -1) 171 | return x 172 | 173 | @staticmethod 174 | def backward(ctx, dy): 175 | y = ctx.y 176 | kwargs = ctx.kwargs 177 | input_keys = ctx.input_keys 178 | split_dims = ctx.split_dims 179 | 180 | dy = dy.split(split_dims, dim = -1) 181 | dy = dict(zip(input_keys, dy)) 182 | 183 | for block in ctx.blocks[::-1]: 184 | y, dy = block.backward_pass(y, dy, **kwargs) 185 | 186 | dy = torch.cat(values(dy), dim = -1) 187 | return dy, None, None 188 | 189 | class SequentialSequence(nn.Module): 190 | def __init__(self, blocks): 191 | super().__init__() 192 | self.blocks = blocks 193 | 194 | def forward(self, x, **kwargs): 195 | for (attn, ff) in self.blocks: 196 | x = attn(x, **kwargs) 197 | x = ff(x) 198 | return x 199 | 200 | class ReversibleSequence(nn.Module): 201 | def __init__(self, blocks): 202 | super().__init__() 203 | self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks]) 204 | 205 | def forward(self, x, **kwargs): 206 | blocks = self.blocks 207 | 208 | x = map_values(lambda t: torch.cat((t, t), dim = -1), x) 209 | 210 | input_keys = x.keys() 211 | split_dims = tuple(map(lambda t: t.shape[-1], x.values())) 212 | block_kwargs = {'input_keys': input_keys, 'split_dims': split_dims, **kwargs} 213 | 214 | x = torch.cat(values(x), dim = -1) 215 | 216 | x = _ReversibleFunction.apply(x, blocks, block_kwargs) 217 | 218 | x = dict(zip(input_keys, x.split(split_dims, dim = -1))) 219 | x = map_values(lambda t: torch.stack(t.chunk(2, dim = -1)).mean(dim = 0), x) 220 | return x 221 | -------------------------------------------------------------------------------- /se3_transformer_pytorch/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange, repeat 4 | 5 | class SinusoidalEmbeddings(nn.Module): 6 | def __init__(self, dim): 7 | super().__init__() 8 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 9 | self.register_buffer('inv_freq', inv_freq) 10 | 11 | def forward(self, t): 12 | freqs = t[..., None].float() * self.inv_freq[None, :] 13 | return repeat(freqs, '... d -> ... (d r)', r = 2) 14 | 15 | def rotate_half(x): 16 | x = rearrange(x, '... (d j) m -> ... 
d j m', j = 2) 17 | x1, x2 = x.unbind(dim = -2) 18 | return torch.cat((-x2, x1), dim = -2) 19 | 20 | def apply_rotary_pos_emb(t, freqs): 21 | rot_dim = freqs.shape[-2] 22 | t, t_pass = t[..., :rot_dim, :], t[..., rot_dim:, :] 23 | t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) 24 | return torch.cat((t, t_pass), dim = -2) 25 | -------------------------------------------------------------------------------- /se3_transformer_pytorch/se3_transformer_pytorch.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from itertools import product 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn, einsum 8 | 9 | from se3_transformer_pytorch.basis import get_basis 10 | from se3_transformer_pytorch.utils import exists, default, uniq, map_values, batched_index_select, masked_mean, to_order, fourier_encode, cast_tuple, safe_cat, fast_split, rand_uniform, broadcat 11 | from se3_transformer_pytorch.reversible import ReversibleSequence, SequentialSequence 12 | from se3_transformer_pytorch.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb 13 | 14 | from einops import rearrange, repeat 15 | 16 | # fiber helpers 17 | 18 | FiberEl = namedtuple('FiberEl', ['degrees', 'dim']) 19 | 20 | class Fiber(nn.Module): 21 | def __init__( 22 | self, 23 | structure 24 | ): 25 | super().__init__() 26 | if isinstance(structure, dict): 27 | structure = [FiberEl(degree, dim) for degree, dim in structure.items()] 28 | self.structure = structure 29 | 30 | @property 31 | def dims(self): 32 | return uniq(map(lambda t: t[1], self.structure)) 33 | 34 | @property 35 | def degrees(self): 36 | return map(lambda t: t[0], self.structure) 37 | 38 | @staticmethod 39 | def create(num_degrees, dim): 40 | dim_tuple = dim if isinstance(dim, tuple) else ((dim,) * num_degrees) 41 | return Fiber([FiberEl(degree, dim) for degree, dim in zip(range(num_degrees), dim_tuple)]) 42 | 43 | def __getitem__(self, degree): 44 | return dict(self.structure)[degree] 45 | 46 | def __iter__(self): 47 | return iter(self.structure) 48 | 49 | def __mul__(self, fiber): 50 | return product(self.structure, fiber.structure) 51 | 52 | def __and__(self, fiber): 53 | out = [] 54 | degrees_out = fiber.degrees 55 | for degree, dim in self: 56 | if degree in fiber.degrees: 57 | dim_out = fiber[degree] 58 | out.append((degree, dim, dim_out)) 59 | return out 60 | 61 | def get_tensor_device_and_dtype(features): 62 | first_tensor = next(iter(features.items()))[1] 63 | return first_tensor.device, first_tensor.dtype 64 | 65 | # classes 66 | 67 | class ResidualSE3(nn.Module): 68 | """ only support instance where both Fibers are identical """ 69 | def forward(self, x, res): 70 | out = {} 71 | for degree, tensor in x.items(): 72 | degree = str(degree) 73 | out[degree] = tensor 74 | if degree in res: 75 | out[degree] = out[degree] + res[degree] 76 | return out 77 | 78 | class LinearSE3(nn.Module): 79 | def __init__( 80 | self, 81 | fiber_in, 82 | fiber_out 83 | ): 84 | super().__init__() 85 | self.weights = nn.ParameterDict() 86 | 87 | for (degree, dim_in, dim_out) in (fiber_in & fiber_out): 88 | key = str(degree) 89 | self.weights[key] = nn.Parameter(torch.randn(dim_in, dim_out) / sqrt(dim_in)) 90 | 91 | def forward(self, x): 92 | out = {} 93 | for degree, weight in self.weights.items(): 94 | out[degree] = einsum('b n d m, d e -> b n e m', x[degree], weight) 95 | return out 96 | 97 | class NormSE3(nn.Module): 98 | """Norm-based 
SE(3)-equivariant nonlinearity. 99 | 100 | Nonlinearities are important in SE(3) equivariant GCNs. They are also quite 101 | expensive to compute, so it is convenient for them to share resources with 102 | other layers, such as normalization. The general workflow is as follows: 103 | 104 | > for feature type in features: 105 | > norm, phase <- feature 106 | > output = fnc(norm) * phase 107 | 108 | where fnc: {R+}^m -> R^m is a learnable map from m norms to m scalars. 109 | """ 110 | def __init__( 111 | self, 112 | fiber, 113 | nonlin = nn.GELU(), 114 | gated_scale = False, 115 | eps = 1e-12, 116 | ): 117 | super().__init__() 118 | self.fiber = fiber 119 | self.nonlin = nonlin 120 | self.eps = eps 121 | 122 | # Norm mappings: 1 per feature type 123 | self.transform = nn.ModuleDict() 124 | for degree, chan in fiber: 125 | self.transform[str(degree)] = nn.ParameterDict({ 126 | 'scale': nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None, 127 | 'w_gate': nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None 128 | }) 129 | 130 | def forward(self, features): 131 | output = {} 132 | for degree, t in features.items(): 133 | # Compute the norms and normalized features 134 | norm = t.norm(dim = -1, keepdim = True).clamp(min = self.eps) 135 | phase = t / norm 136 | 137 | # Transform on norms 138 | parameters = self.transform[degree] 139 | gate_weights, scale = parameters['w_gate'], parameters['scale'] 140 | 141 | transformed = rearrange(norm, '... () -> ...') 142 | 143 | if not exists(scale): 144 | scale = einsum('b n d, d e -> b n e', transformed, gate_weights) 145 | 146 | transformed = self.nonlin(transformed * scale) 147 | transformed = rearrange(transformed, '... -> ... ()') 148 | 149 | # Nonlinearity on norm 150 | output[degree] = (transformed * phase).view(*t.shape) 151 | 152 | return output 153 | 154 | class ConvSE3(nn.Module): 155 | """A tensor field network layer 156 | 157 | ConvSE3 stands for a Convolution SE(3)-equivariant layer. It is the 158 | equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph 159 | conv layer in a GCN. 160 | 161 | At each node, the activations are split into different "feature types", 162 | indexed by the SE(3) representation type: non-negative integers 0, 1, 2, .. 
163 | """ 164 | def __init__( 165 | self, 166 | fiber_in, 167 | fiber_out, 168 | self_interaction = True, 169 | pool = True, 170 | edge_dim = 0, 171 | fourier_encode_dist = False, 172 | num_fourier_features = 4, 173 | splits = 4 174 | ): 175 | super().__init__() 176 | self.fiber_in = fiber_in 177 | self.fiber_out = fiber_out 178 | self.edge_dim = edge_dim 179 | self.self_interaction = self_interaction 180 | 181 | self.num_fourier_features = num_fourier_features 182 | self.fourier_encode_dist = fourier_encode_dist 183 | 184 | # radial function will assume a dimension of at minimum 1, for the relative distance - extra fourier features must be added to the edge dimension 185 | edge_dim += (0 if not fourier_encode_dist else (num_fourier_features * 2)) 186 | 187 | # Neighbor -> center weights 188 | self.kernel_unary = nn.ModuleDict() 189 | 190 | self.splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage 191 | 192 | for (di, mi), (do, mo) in (self.fiber_in * self.fiber_out): 193 | self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim = edge_dim, splits = splits) 194 | 195 | self.pool = pool 196 | 197 | # Center -> center weights 198 | if self_interaction: 199 | assert self.pool, 'must pool edges if followed with self interaction' 200 | self.self_interact = LinearSE3(fiber_in, fiber_out) 201 | self.self_interact_sum = ResidualSE3() 202 | 203 | def forward( 204 | self, 205 | inp, 206 | edge_info, 207 | rel_dist = None, 208 | basis = None 209 | ): 210 | splits = self.splits 211 | neighbor_indices, neighbor_masks, edges = edge_info 212 | rel_dist = rearrange(rel_dist, 'b m n -> b m n ()') 213 | 214 | kernels = {} 215 | outputs = {} 216 | 217 | if self.fourier_encode_dist: 218 | rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features) 219 | 220 | # split basis 221 | 222 | basis_keys = basis.keys() 223 | split_basis_values = list(zip(*list(map(lambda t: fast_split(t, splits, dim = 1), basis.values())))) 224 | split_basis = list(map(lambda v: dict(zip(basis_keys, v)), split_basis_values)) 225 | 226 | # go through every permutation of input degree type to output degree type 227 | 228 | for degree_out in self.fiber_out.degrees: 229 | output = 0 230 | degree_out_key = str(degree_out) 231 | 232 | for degree_in, m_in in self.fiber_in: 233 | etype = f'({degree_in},{degree_out})' 234 | 235 | x = inp[str(degree_in)] 236 | 237 | x = batched_index_select(x, neighbor_indices, dim = 1) 238 | x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1) 239 | 240 | kernel_fn = self.kernel_unary[etype] 241 | edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist 242 | 243 | output_chunk = None 244 | split_x = fast_split(x, splits, dim = 1) 245 | split_edge_features = fast_split(edge_features, splits, dim = 1) 246 | 247 | # process input, edges, and basis in chunks along the sequence dimension 248 | 249 | for x_chunk, edge_features, basis in zip(split_x, split_edge_features, split_basis): 250 | kernel = kernel_fn(edge_features, basis = basis) 251 | chunk = einsum('... o i, ... i c -> ... 
o c', kernel, x_chunk) 252 | output_chunk = safe_cat(output_chunk, chunk, dim = 1) 253 | 254 | output = output + output_chunk 255 | 256 | if self.pool: 257 | output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2) 258 | 259 | leading_shape = x.shape[:2] if self.pool else x.shape[:3] 260 | output = output.view(*leading_shape, -1, to_order(degree_out)) 261 | 262 | outputs[degree_out_key] = output 263 | 264 | if self.self_interaction: 265 | self_interact_out = self.self_interact(inp) 266 | outputs = self.self_interact_sum(outputs, self_interact_out) 267 | 268 | return outputs 269 | 270 | class RadialFunc(nn.Module): 271 | """NN parameterized radial profile function.""" 272 | def __init__( 273 | self, 274 | num_freq, 275 | in_dim, 276 | out_dim, 277 | edge_dim = None, 278 | mid_dim = 128 279 | ): 280 | super().__init__() 281 | self.num_freq = num_freq 282 | self.in_dim = in_dim 283 | self.mid_dim = mid_dim 284 | self.out_dim = out_dim 285 | self.edge_dim = default(edge_dim, 0) 286 | 287 | self.net = nn.Sequential( 288 | nn.Linear(self.edge_dim + 1, mid_dim), 289 | nn.LayerNorm(mid_dim), 290 | nn.GELU(), 291 | nn.Linear(mid_dim, mid_dim), 292 | nn.LayerNorm(mid_dim), 293 | nn.GELU(), 294 | nn.Linear(mid_dim, num_freq * in_dim * out_dim) 295 | ) 296 | 297 | def forward(self, x): 298 | y = self.net(x) 299 | return rearrange(y, '... (o i f) -> ... o () i () f', i = self.in_dim, o = self.out_dim) 300 | 301 | class PairwiseConv(nn.Module): 302 | """SE(3)-equivariant convolution between two single-type features""" 303 | def __init__( 304 | self, 305 | degree_in, 306 | nc_in, 307 | degree_out, 308 | nc_out, 309 | edge_dim = 0, 310 | splits = 4 311 | ): 312 | super().__init__() 313 | self.degree_in = degree_in 314 | self.degree_out = degree_out 315 | self.nc_in = nc_in 316 | self.nc_out = nc_out 317 | 318 | self.num_freq = to_order(min(degree_in, degree_out)) 319 | self.d_out = to_order(degree_out) 320 | self.edge_dim = edge_dim 321 | 322 | self.rp = RadialFunc(self.num_freq, nc_in, nc_out, edge_dim) 323 | 324 | self.splits = splits 325 | 326 | def forward(self, feat, basis): 327 | splits = self.splits 328 | R = self.rp(feat) 329 | B = basis[f'{self.degree_in},{self.degree_out}'] 330 | 331 | out_shape = (*R.shape[:3], self.d_out * self.nc_out, -1) 332 | 333 | # torch.sum(R * B, dim = -1) is too memory intensive 334 | # needs to be chunked to reduce peak memory usage 335 | 336 | out = 0 337 | for i in range(R.shape[-1]): 338 | out += R[..., i] * B[..., i] 339 | 340 | out = rearrange(out, 'b n h s ... 
-> (b n h s) ...') 341 | 342 | # reshape and out 343 | return out.view(*out_shape) 344 | 345 | # feed forwards 346 | 347 | class FeedForwardSE3(nn.Module): 348 | def __init__( 349 | self, 350 | fiber, 351 | mult = 4 352 | ): 353 | super().__init__() 354 | self.fiber = fiber 355 | fiber_hidden = Fiber(list(map(lambda t: (t[0], t[1] * mult), fiber))) 356 | 357 | self.project_in = LinearSE3(fiber, fiber_hidden) 358 | self.nonlin = NormSE3(fiber_hidden) 359 | self.project_out = LinearSE3(fiber_hidden, fiber) 360 | 361 | def forward(self, features): 362 | outputs = self.project_in(features) 363 | outputs = self.nonlin(outputs) 364 | outputs = self.project_out(outputs) 365 | return outputs 366 | 367 | class FeedForwardBlockSE3(nn.Module): 368 | def __init__( 369 | self, 370 | fiber, 371 | norm_gated_scale = False 372 | ): 373 | super().__init__() 374 | self.fiber = fiber 375 | self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale) 376 | self.feedforward = FeedForwardSE3(fiber) 377 | self.residual = ResidualSE3() 378 | 379 | def forward(self, features): 380 | res = features 381 | out = self.prenorm(features) 382 | out = self.feedforward(out) 383 | return self.residual(out, res) 384 | 385 | # attention 386 | 387 | class AttentionSE3(nn.Module): 388 | def __init__( 389 | self, 390 | fiber, 391 | dim_head = 64, 392 | heads = 8, 393 | attend_self = False, 394 | edge_dim = None, 395 | fourier_encode_dist = False, 396 | rel_dist_num_fourier_features = 4, 397 | use_null_kv = False, 398 | splits = 4, 399 | global_feats_dim = None, 400 | linear_proj_keys = False, 401 | tie_key_values = False 402 | ): 403 | super().__init__() 404 | hidden_dim = dim_head * heads 405 | hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber))) 406 | project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0]) 407 | 408 | self.scale = dim_head ** -0.5 409 | self.heads = heads 410 | 411 | self.linear_proj_keys = linear_proj_keys # whether to linearly project features for keys, rather than convolve with basis 412 | 413 | self.to_q = LinearSE3(fiber, hidden_fiber) 414 | self.to_v = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits) 415 | 416 | assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time' 417 | 418 | if linear_proj_keys: 419 | self.to_k = LinearSE3(fiber, hidden_fiber) 420 | elif not tie_key_values: 421 | self.to_k = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits) 422 | else: 423 | self.to_k = None 424 | 425 | self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity() 426 | 427 | self.use_null_kv = use_null_kv 428 | if use_null_kv: 429 | self.null_keys = nn.ParameterDict() 430 | self.null_values = nn.ParameterDict() 431 | 432 | for degree in fiber.degrees: 433 | m = to_order(degree) 434 | degree_key = str(degree) 435 | self.null_keys[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m)) 436 | self.null_values[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m)) 437 | 438 | self.attend_self = attend_self 439 | if attend_self: 440 | self.to_self_k = LinearSE3(fiber, hidden_fiber) 441 | self.to_self_v = LinearSE3(fiber, hidden_fiber) 442 | 443 | 
self.accept_global_feats = exists(global_feats_dim) 444 | if self.accept_global_feats: 445 | global_input_fiber = Fiber.create(1, global_feats_dim) 446 | global_output_fiber = Fiber.create(1, hidden_fiber[0]) 447 | self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber) 448 | self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber) 449 | 450 | def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None): 451 | h, attend_self = self.heads, self.attend_self 452 | device, dtype = get_tensor_device_and_dtype(features) 453 | neighbor_indices, neighbor_mask, edges = edge_info 454 | 455 | if exists(neighbor_mask): 456 | neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j') 457 | 458 | queries = self.to_q(features) 459 | values = self.to_v(features, edge_info, rel_dist, basis) 460 | 461 | if self.linear_proj_keys: 462 | keys = self.to_k(features) 463 | keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys) 464 | elif not exists(self.to_k): 465 | keys = values 466 | else: 467 | keys = self.to_k(features, edge_info, rel_dist, basis) 468 | 469 | if attend_self: 470 | self_keys, self_values = self.to_self_k(features), self.to_self_v(features) 471 | 472 | if exists(global_feats): 473 | global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats) 474 | 475 | outputs = {} 476 | for degree in features.keys(): 477 | q, k, v = map(lambda t: t[degree], (queries, keys, values)) 478 | 479 | q = rearrange(q, 'b i (h d) m -> b h i d m', h = h) 480 | k, v = map(lambda t: rearrange(t, 'b i j (h d) m -> b h i j d m', h = h), (k, v)) 481 | 482 | if attend_self: 483 | self_k, self_v = map(lambda t: t[degree], (self_keys, self_values)) 484 | self_k, self_v = map(lambda t: rearrange(t, 'b n (h d) m -> b h n () d m', h = h), (self_k, self_v)) 485 | k = torch.cat((self_k, k), dim = 3) 486 | v = torch.cat((self_v, v), dim = 3) 487 | 488 | if exists(pos_emb) and degree == '0': 489 | query_pos_emb, key_pos_emb = pos_emb 490 | query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()') 491 | key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b () i j d ()') 492 | q = apply_rotary_pos_emb(q, query_pos_emb) 493 | k = apply_rotary_pos_emb(k, key_pos_emb) 494 | v = apply_rotary_pos_emb(v, key_pos_emb) 495 | 496 | if self.use_null_kv: 497 | null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values)) 498 | null_k, null_v = map(lambda t: repeat(t, 'h d m -> b h i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v)) 499 | k = torch.cat((null_k, k), dim = 3) 500 | v = torch.cat((null_v, v), dim = 3) 501 | 502 | if exists(global_feats) and degree == '0': 503 | global_k, global_v = map(lambda t: t[degree], (global_keys, global_values)) 504 | global_k, global_v = map(lambda t: repeat(t, 'b j (h d) m -> b h i j d m', h = h, i = k.shape[2]), (global_k, global_v)) 505 | k = torch.cat((global_k, k), dim = 3) 506 | v = torch.cat((global_v, v), dim = 3) 507 | 508 | sim = einsum('b h i d m, b h i j d m -> b h i j', q, k) * self.scale 509 | 510 | if exists(neighbor_mask): 511 | num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1] 512 | mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True) 513 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 514 | 515 | attn = sim.softmax(dim = -1) 516 | out = einsum('b h i j, b h i j d m -> b h i d m', attn, v) 517 | outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m') 518 | 519 | return self.to_out(outputs) 520 
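|

The loop above runs one attention pass per feature type. A toy sketch (hypothetical shapes, not library code) of the degree-wise contraction: the logits reduce over both the channel axis d and the representation axis m, which makes them rotation-invariant, while the values keep their (d, m) layout and so remain equivariant.

import torch
from torch import einsum
from einops import rearrange

b, h, i, j, d, m = 1, 2, 4, 3, 8, 3   # batch, heads, nodes, neighbors, dim_head, 2l + 1
q = torch.randn(b, h, i, d, m)
k = torch.randn(b, h, i, j, d, m)
v = torch.randn(b, h, i, j, d, m)

sim = einsum('b h i d m, b h i j d m -> b h i j', q, k) * (d ** -0.5)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
out = rearrange(out, 'b h n d m -> b n (h d) m')
assert out.shape == (b, i, h * d, m)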
| 521 | # AttentionSE3, but with one key / value projection shared across all query heads 522 | class OneHeadedKVAttentionSE3(nn.Module): 523 | def __init__( 524 | self, 525 | fiber, 526 | dim_head = 64, 527 | heads = 8, 528 | attend_self = False, 529 | edge_dim = None, 530 | fourier_encode_dist = False, 531 | rel_dist_num_fourier_features = 4, 532 | use_null_kv = False, 533 | splits = 4, 534 | global_feats_dim = None, 535 | linear_proj_keys = False, 536 | tie_key_values = False 537 | ): 538 | super().__init__() 539 | hidden_dim = dim_head * heads 540 | hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber))) 541 | kv_hidden_fiber = Fiber(list(map(lambda t: (t[0], dim_head), fiber))) 542 | project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0]) 543 | 544 | self.scale = dim_head ** -0.5 545 | self.heads = heads 546 | 547 | self.linear_proj_keys = linear_proj_keys # whether to linearly project features for keys, rather than convolve with basis 548 | 549 | self.to_q = LinearSE3(fiber, hidden_fiber) 550 | self.to_v = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits) 551 | 552 | assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time' 553 | 554 | if linear_proj_keys: 555 | self.to_k = LinearSE3(fiber, kv_hidden_fiber) 556 | elif not tie_key_values: 557 | self.to_k = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits) 558 | else: 559 | self.to_k = None 560 | 561 | self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity() 562 | 563 | self.use_null_kv = use_null_kv 564 | if use_null_kv: 565 | self.null_keys = nn.ParameterDict() 566 | self.null_values = nn.ParameterDict() 567 | 568 | for degree in fiber.degrees: 569 | m = to_order(degree) 570 | degree_key = str(degree) 571 | self.null_keys[degree_key] = nn.Parameter(torch.zeros(dim_head, m)) 572 | self.null_values[degree_key] = nn.Parameter(torch.zeros(dim_head, m)) 573 | 574 | self.attend_self = attend_self 575 | if attend_self: 576 | self.to_self_k = LinearSE3(fiber, kv_hidden_fiber) 577 | self.to_self_v = LinearSE3(fiber, kv_hidden_fiber) 578 | 579 | self.accept_global_feats = exists(global_feats_dim) 580 | if self.accept_global_feats: 581 | global_input_fiber = Fiber.create(1, global_feats_dim) 582 | global_output_fiber = Fiber.create(1, kv_hidden_fiber[0]) 583 | self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber) 584 | self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber) 585 | 586 | def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None): 587 | h, attend_self = self.heads, self.attend_self 588 | device, dtype = get_tensor_device_and_dtype(features) 589 | neighbor_indices, neighbor_mask, edges = edge_info 590 | 591 | if exists(neighbor_mask): 592 | neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j') 593 | 594 | queries = self.to_q(features) 595 | values = self.to_v(features, edge_info, rel_dist, basis) 596 | 597 | if self.linear_proj_keys: 598 | keys = self.to_k(features) 599 | keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys) 600 | elif not 
exists(self.to_k): 601 | keys = values 602 | else: 603 | keys = self.to_k(features, edge_info, rel_dist, basis) 604 | 605 | if attend_self: 606 | self_keys, self_values = self.to_self_k(features), self.to_self_v(features) 607 | 608 | if exists(global_feats): 609 | global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats) 610 | 611 | outputs = {} 612 | for degree in features.keys(): 613 | q, k, v = map(lambda t: t[degree], (queries, keys, values)) 614 | 615 | q = rearrange(q, 'b i (h d) m -> b h i d m', h = h) 616 | 617 | if attend_self: 618 | self_k, self_v = map(lambda t: t[degree], (self_keys, self_values)) 619 | self_k, self_v = map(lambda t: rearrange(t, 'b n d m -> b n () d m'), (self_k, self_v)) 620 | k = torch.cat((self_k, k), dim = 2) 621 | v = torch.cat((self_v, v), dim = 2) 622 | 623 | if exists(pos_emb) and degree == '0': 624 | query_pos_emb, key_pos_emb = pos_emb 625 | query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()') 626 | key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()') 627 | q = apply_rotary_pos_emb(q, query_pos_emb) 628 | k = apply_rotary_pos_emb(k, key_pos_emb) 629 | v = apply_rotary_pos_emb(v, key_pos_emb) 630 | 631 | if self.use_null_kv: 632 | null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values)) 633 | null_k, null_v = map(lambda t: repeat(t, 'd m -> b i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v)) 634 | k = torch.cat((null_k, k), dim = 2) 635 | v = torch.cat((null_v, v), dim = 2) 636 | 637 | if exists(global_feats) and degree == '0': 638 | global_k, global_v = map(lambda t: t[degree], (global_keys, global_values)) 639 | global_k, global_v = map(lambda t: repeat(t, 'b j d m -> b i j d m', i = k.shape[1]), (global_k, global_v)) 640 | k = torch.cat((global_k, k), dim = 2) 641 | v = torch.cat((global_v, v), dim = 2) 642 | 643 | sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale 644 | 645 | if exists(neighbor_mask): 646 | num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1] 647 | mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True) 648 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 649 | 650 | attn = sim.softmax(dim = -1) 651 | out = einsum('b h i j, b i j d m -> b h i d m', attn, v) 652 | outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m') 653 | 654 | return self.to_out(outputs) 655 | 656 | class AttentionBlockSE3(nn.Module): 657 | def __init__( 658 | self, 659 | fiber, 660 | dim_head = 24, 661 | heads = 8, 662 | attend_self = False, 663 | edge_dim = None, 664 | use_null_kv = False, 665 | fourier_encode_dist = False, 666 | rel_dist_num_fourier_features = 4, 667 | splits = 4, 668 | global_feats_dim = None, 669 | linear_proj_keys = False, 670 | tie_key_values = False, 671 | attention_klass = AttentionSE3, 672 | norm_gated_scale = False 673 | ): 674 | super().__init__() 675 | self.attn = attention_klass(fiber, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, use_null_kv = use_null_kv, rel_dist_num_fourier_features = rel_dist_num_fourier_features, fourier_encode_dist = fourier_encode_dist, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, tie_key_values = tie_key_values) 676 | self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale) 677 | self.residual = ResidualSE3() 678 | 679 | def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None): 680 | res = features 681 | outputs = self.prenorm(features) 682 | 
outputs = self.attn(outputs, edge_info, rel_dist, basis, global_feats, pos_emb, mask) 683 | return self.residual(outputs, res) 684 | 685 | # egnn 686 | 687 | class Swish_(nn.Module): 688 | def forward(self, x): 689 | return x * x.sigmoid() 690 | 691 | SiLU = nn.SiLU if hasattr(nn, 'SiLU') else Swish_ 692 | 693 | class HtypesNorm(nn.Module): 694 | def __init__(self, dim, eps = 1e-8, scale_init = 1e-2, bias_init = 1e-2): 695 | super().__init__() 696 | self.eps = eps 697 | scale = torch.empty(1, 1, 1, dim, 1).fill_(scale_init) 698 | bias = torch.empty(1, 1, 1, dim, 1).fill_(bias_init) 699 | self.scale = nn.Parameter(scale) 700 | self.bias = nn.Parameter(bias) 701 | 702 | def forward(self, coors): 703 | norm = coors.norm(dim = -1, keepdim = True) 704 | normed_coors = coors / norm.clamp(min = self.eps) 705 | return normed_coors * (norm * self.scale + self.bias) 706 | 707 | class EGNN(nn.Module): 708 | def __init__( 709 | self, 710 | fiber, 711 | hidden_dim = 32, 712 | edge_dim = 0, 713 | init_eps = 1e-3, 714 | coor_weights_clamp_value = None 715 | ): 716 | super().__init__() 717 | self.fiber = fiber 718 | node_dim = fiber[0] 719 | 720 | htypes = list(filter(lambda t: t.degrees != 0, fiber)) 721 | num_htypes = len(htypes) 722 | htype_dims = sum([fiberel.dim for fiberel in htypes]) 723 | 724 | edge_input_dim = node_dim * 2 + htype_dims + edge_dim + 1 725 | 726 | self.node_norm = nn.LayerNorm(node_dim) 727 | 728 | self.edge_mlp = nn.Sequential( 729 | nn.Linear(edge_input_dim, edge_input_dim * 2), 730 | SiLU(), 731 | nn.Linear(edge_input_dim * 2, hidden_dim), 732 | SiLU() 733 | ) 734 | 735 | self.htype_norms = nn.ModuleDict({}) 736 | self.htype_gating = nn.ModuleDict({}) 737 | 738 | for degree, dim in fiber: 739 | if degree == 0: 740 | continue 741 | self.htype_norms[str(degree)] = HtypesNorm(dim) 742 | self.htype_gating[str(degree)] = nn.Linear(node_dim, dim) 743 | 744 | self.htypes_mlp = nn.Sequential( 745 | nn.Linear(hidden_dim, hidden_dim * 4), 746 | SiLU(), 747 | nn.Linear(hidden_dim * 4, htype_dims) 748 | ) 749 | 750 | self.node_mlp = nn.Sequential( 751 | nn.Linear(node_dim + hidden_dim, node_dim * 2), 752 | SiLU(), 753 | nn.Linear(node_dim * 2, node_dim) 754 | ) 755 | 756 | self.coor_weights_clamp_value = coor_weights_clamp_value 757 | self.init_eps = init_eps 758 | self.apply(self.init_) 759 | 760 | def init_(self, module): 761 | if type(module) in {nn.Linear}: 762 | nn.init.normal_(module.weight, std = self.init_eps) 763 | 764 | def forward( 765 | self, 766 | features, 767 | edge_info, 768 | rel_dist, 769 | mask = None, 770 | **kwargs 771 | ): 772 | neighbor_indices, neighbor_masks, edges = edge_info 773 | 774 | mask = neighbor_masks 775 | 776 | # type 0 features 777 | 778 | nodes = features['0'] 779 | nodes = rearrange(nodes, '... 
() -> ...') 780 | 781 | # higher types (htype) 782 | 783 | htypes = list(filter(lambda t: t[0] != '0', features.items())) 784 | htype_degrees = list(map(lambda t: t[0], htypes)) 785 | htype_dims = list(map(lambda t: t[1].shape[-2], htypes)) 786 | 787 | # prepare higher types 788 | 789 | rel_htypes = [] 790 | rel_htypes_dists = [] 791 | 792 | for degree, htype in htypes: 793 | rel_htype = rearrange(htype, 'b i d m -> b i () d m') - rearrange(htype, 'b j d m -> b () j d m') 794 | rel_htype_dist = rel_htype.norm(dim = -1) 795 | 796 | rel_htypes.append(rel_htype) 797 | rel_htypes_dists.append(rel_htype_dist) 798 | 799 | # prepare edges for edge MLP 800 | 801 | nodes_i = rearrange(nodes, 'b i d -> b i () d') 802 | nodes_j = batched_index_select(nodes, neighbor_indices, dim = 1) 803 | neighbor_higher_type_dists = map(lambda t: batched_index_select(t, neighbor_indices, dim = 2), rel_htypes_dists) 804 | coor_rel_dist = rearrange(rel_dist, 'b i j -> b i j ()') 805 | 806 | edge_mlp_inputs = broadcat((nodes_i, nodes_j, *neighbor_higher_type_dists, coor_rel_dist), dim = -1) 807 | 808 | if exists(edges): 809 | edge_mlp_inputs = torch.cat((edge_mlp_inputs, edges), dim = -1) 810 | 811 | # get intermediate representation 812 | 813 | m_ij = self.edge_mlp(edge_mlp_inputs) 814 | 815 | # to coordinates 816 | 817 | htype_weights = self.htypes_mlp(m_ij) 818 | 819 | if exists(self.coor_weights_clamp_value): 820 | clamp_value = self.coor_weights_clamp_value 821 | htype_weights.clamp_(min = -clamp_value, max = clamp_value) 822 | 823 | if exists(mask): 824 | htype_mask = rearrange(mask, 'b i j -> b i j ()') 825 | htype_weights = htype_weights.masked_fill(~htype_mask, 0.) 826 | 827 | split_htype_weights = htype_weights.split(htype_dims, dim = -1) # split after masking, so padded neighbors contribute nothing 828 | 829 | htype_updates = [] 830 | 831 | for degree, rel_htype, htype_weight in zip(htype_degrees, rel_htypes, split_htype_weights): 832 | normed_rel_htype = self.htype_norms[str(degree)](rel_htype) 833 | normed_rel_htype = batched_index_select(normed_rel_htype, neighbor_indices, dim = 2) 834 | 835 | htype_update = einsum('b i j d m, b i j d -> b i d m', normed_rel_htype, htype_weight) 836 | htype_updates.append(htype_update) 837 | 838 | # to nodes 839 | 840 | if exists(mask): 841 | m_ij_mask = rearrange(mask, '... -> ... ()') 842 | m_ij = m_ij.masked_fill(~m_ij_mask, 0.) 843 | 844 | m_i = m_ij.sum(dim = -2) 845 | 846 | normed_nodes = self.node_norm(nodes) 847 | node_mlp_input = torch.cat((normed_nodes, m_i), dim = -1) 848 | node_out = self.node_mlp(node_mlp_input) + nodes 849 | 850 | # update nodes 851 | 852 | features['0'] = rearrange(node_out, '... -> ... ()') 853 | 854 | # update higher types 855 | 856 | update_htype_dicts = dict(zip(htype_degrees, htype_updates)) 857 | 858 | for degree, update_htype in update_htype_dicts.items(): 859 | features[degree] = features[degree] + update_htype 860 | 861 | for degree in htype_degrees: 862 | gating = self.htype_gating[str(degree)](node_out).sigmoid() 863 | features[degree] = features[degree] * rearrange(gating, '... -> ... 
()') 864 | 865 | return features 866 | 867 | class EGnnNetwork(nn.Module): 868 | def __init__( 869 | self, 870 | *, 871 | fiber, 872 | depth, 873 | edge_dim = 0, 874 | hidden_dim = 32, 875 | coor_weights_clamp_value = None, 876 | feedforward = False 877 | ): 878 | super().__init__() 879 | self.fiber = fiber 880 | self.layers = nn.ModuleList([]) 881 | for _ in range(depth): 882 | self.layers.append(nn.ModuleList([ 883 | EGNN(fiber = fiber, edge_dim = edge_dim, hidden_dim = hidden_dim, coor_weights_clamp_value = coor_weights_clamp_value), 884 | FeedForwardBlockSE3(fiber) if feedforward else None 885 | ])) 886 | 887 | def forward( 888 | self, 889 | features, 890 | edge_info, 891 | rel_dist, 892 | basis, 893 | global_feats = None, 894 | pos_emb = None, 895 | mask = None, 896 | **kwargs 897 | ): 898 | neighbor_indices, neighbor_masks, edges = edge_info 899 | device = neighbor_indices.device 900 | 901 | # modify neighbors to include self (the SE3 transformer removes each token's attention to itself, but this does not apply to EGNN) 902 | 903 | self_indices = torch.arange(neighbor_indices.shape[1], device = device) 904 | self_indices = rearrange(self_indices, 'i -> () i ()') 905 | neighbor_indices = broadcat((self_indices, neighbor_indices), dim = -1) 906 | 907 | neighbor_masks = F.pad(neighbor_masks, (1, 0), value = True) 908 | rel_dist = F.pad(rel_dist, (1, 0), value = 0.) 909 | 910 | if exists(edges): 911 | edges = F.pad(edges, (0, 0, 1, 0), value = 0.) # make edge of token to itself 0 for now 912 | 913 | edge_info = (neighbor_indices, neighbor_masks, edges) 914 | 915 | # go through layers 916 | 917 | for egnn, ff in self.layers: 918 | features = egnn( 919 | features, 920 | edge_info = edge_info, 921 | rel_dist = rel_dist, 922 | basis = basis, 923 | global_feats = global_feats, 924 | pos_emb = pos_emb, 925 | mask = mask, 926 | **kwargs 927 | ) 928 | 929 | if exists(ff): 930 | features = ff(features) 931 | 932 | return features 933 | 934 | # main class 935 | 936 | class SE3Transformer(nn.Module): 937 | def __init__( 938 | self, 939 | *, 940 | dim, 941 | heads = 8, 942 | dim_head = 24, 943 | depth = 2, 944 | input_degrees = 1, 945 | num_degrees = None, 946 | output_degrees = 1, 947 | valid_radius = 1e5, 948 | reduce_dim_out = False, 949 | num_tokens = None, 950 | num_positions = None, 951 | num_edge_tokens = None, 952 | edge_dim = None, 953 | reversible = False, 954 | attend_self = True, 955 | use_null_kv = False, 956 | differentiable_coors = False, 957 | fourier_encode_dist = False, 958 | rel_dist_num_fourier_features = 4, 959 | num_neighbors = float('inf'), 960 | attend_sparse_neighbors = False, 961 | num_adj_degrees = None, 962 | adj_dim = 0, 963 | max_sparse_neighbors = float('inf'), 964 | dim_in = None, 965 | dim_out = None, 966 | norm_out = False, 967 | num_conv_layers = 0, 968 | causal = False, 969 | splits = 4, 970 | global_feats_dim = None, 971 | linear_proj_keys = False, 972 | one_headed_key_values = False, 973 | tie_key_values = False, 974 | rotary_position = False, 975 | rotary_rel_dist = False, 976 | norm_gated_scale = False, 977 | use_egnn = False, 978 | egnn_hidden_dim = 32, 979 | egnn_weights_clamp_value = None, 980 | egnn_feedforward = False, 981 | hidden_fiber_dict = None, 982 | out_fiber_dict = None 983 | ): 984 | super().__init__() 985 | dim_in = default(dim_in, dim) 986 | self.dim_in = cast_tuple(dim_in, input_degrees) 987 | self.dim = dim 988 | 989 | # token embedding 990 | 991 | self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None 992 |
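Before the remaining options are wired up below, a minimal end-to-end usage sketch, mirroring the test suite at the bottom of this repository (all sizes arbitrary):

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    num_degrees = 2,      # use type-0 and type-1 features internally
    num_neighbors = 4,    # attend to the 4 nearest nodes
    valid_radius = 10
)

feats = torch.randn(1, 32, 64)     # per-node scalar features
coors = torch.randn(1, 32, 3)      # node coordinates
mask = torch.ones(1, 32).bool()    # padding mask

out = model(feats, coors, mask, return_type = 0)   # invariant type-0 output
assert out.shape == (1, 32, 64)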
993 | # positional embedding 994 | 995 | self.num_positions = num_positions 996 | self.pos_emb = nn.Embedding(num_positions, dim) if exists(num_positions) else None 997 | 998 | self.rotary_rel_dist = rotary_rel_dist 999 | self.rotary_position = rotary_position 1000 | 1001 | self.rotary_pos_emb = None 1002 | if rotary_position or rotary_rel_dist: 1003 | num_rotaries = int(rotary_position) + int(rotary_rel_dist) 1004 | self.rotary_pos_emb = SinusoidalEmbeddings(dim_head // num_rotaries) 1005 | 1006 | # edges 1007 | 1008 | assert not (exists(num_edge_tokens) and not exists(edge_dim)), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens' 1009 | 1010 | self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None 1011 | self.has_edges = exists(edge_dim) and edge_dim > 0 1012 | 1013 | self.input_degrees = input_degrees 1014 | 1015 | assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'num_adj_degrees must be at least 1' 1016 | 1017 | self.num_degrees = num_degrees if exists(num_degrees) else (max(hidden_fiber_dict.keys()) + 1) 1018 | 1019 | output_degrees = output_degrees if not use_egnn else None 1020 | self.output_degrees = output_degrees 1021 | 1022 | # whether to differentiate through basis, needed for alphafold2 1023 | 1024 | self.differentiable_coors = differentiable_coors 1025 | 1026 | # neighbors hyperparameters 1027 | 1028 | self.valid_radius = valid_radius 1029 | self.num_neighbors = num_neighbors 1030 | 1031 | # sparse neighbors, derived from adjacency matrix or edges being passed in 1032 | 1033 | self.attend_sparse_neighbors = attend_sparse_neighbors 1034 | self.max_sparse_neighbors = max_sparse_neighbors 1035 | 1036 | # adjacent neighbor derivation and embed 1037 | 1038 | self.num_adj_degrees = num_adj_degrees 1039 | self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None 1040 | 1041 | edge_dim = (edge_dim if self.has_edges else 0) + (adj_dim if exists(self.adj_emb) else 0) 1042 | 1043 | # define fibers and dimensionality 1044 | 1045 | dim_in = default(dim_in, dim) 1046 | dim_out = default(dim_out, dim) 1047 | 1048 | assert exists(num_degrees) or exists(hidden_fiber_dict), 'either num_degrees or hidden_fiber_dict must be specified' 1049 | 1050 | fiber_in = Fiber.create(input_degrees, dim_in) 1051 | 1052 | if exists(hidden_fiber_dict): 1053 | fiber_hidden = Fiber(hidden_fiber_dict) 1054 | elif exists(num_degrees): 1055 | fiber_hidden = Fiber.create(num_degrees, dim) 1056 | 1057 | if exists(out_fiber_dict): 1058 | fiber_out = Fiber(out_fiber_dict) 1059 | self.output_degrees = max(out_fiber_dict.keys()) + 1 1060 | elif exists(output_degrees): 1061 | fiber_out = Fiber.create(output_degrees, dim_out) 1062 | else: 1063 | fiber_out = None 1064 | 1065 | conv_kwargs = dict(edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits) 1066 | 1067 | # causal 1068 | 1069 | assert not (causal and not attend_self), 'attending to self must be turned on if in autoregressive mode (for the first token)' 1070 | self.causal = causal 1071 | 1072 | # main network 1073 | 1074 | self.conv_in = ConvSE3(fiber_in, fiber_hidden, **conv_kwargs) 1075 | 1076 | # pre-convs 1077 | 1078 | self.convs = nn.ModuleList([]) 1079 | for _ in range(num_conv_layers): 1080 | self.convs.append(nn.ModuleList([ 1081 | ConvSE3(fiber_hidden, fiber_hidden, **conv_kwargs), 1082 | NormSE3(fiber_hidden, gated_scale =
norm_gated_scale) 1083 | ])) 1084 | 1085 | # global features 1086 | 1087 | self.accept_global_feats = exists(global_feats_dim) 1088 | assert not (reversible and self.accept_global_feats), 'reversibility and global features are not compatible' 1089 | 1090 | # trunk 1091 | 1092 | self.attend_self = attend_self 1093 | 1094 | default_attention_klass = OneHeadedKVAttentionSE3 if one_headed_key_values else AttentionSE3 1095 | 1096 | if use_egnn: 1097 | self.net = EGnnNetwork(fiber = fiber_hidden, depth = depth, edge_dim = edge_dim, hidden_dim = egnn_hidden_dim, coor_weights_clamp_value = egnn_weights_clamp_value, feedforward = egnn_feedforward) 1098 | else: 1099 | layers = nn.ModuleList([]) 1100 | for ind in range(depth): 1101 | attention_klass = default_attention_klass 1102 | 1103 | layers.append(nn.ModuleList([ 1104 | AttentionBlockSE3(fiber_hidden, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, rel_dist_num_fourier_features = rel_dist_num_fourier_features, use_null_kv = use_null_kv, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, attention_klass = attention_klass, tie_key_values = tie_key_values, norm_gated_scale = norm_gated_scale), 1105 | FeedForwardBlockSE3(fiber_hidden, norm_gated_scale = norm_gated_scale) 1106 | ])) 1107 | 1108 | execution_class = ReversibleSequence if reversible else SequentialSequence 1109 | self.net = execution_class(layers) 1110 | 1111 | # out 1112 | 1113 | self.conv_out = ConvSE3(fiber_hidden, fiber_out, **conv_kwargs) if exists(fiber_out) else None 1114 | 1115 | self.norm = NormSE3(fiber_out, gated_scale = norm_gated_scale, nonlin = nn.Identity()) if (norm_out or reversible) and exists(fiber_out) else nn.Identity() 1116 | 1117 | final_fiber = default(fiber_out, fiber_hidden) 1118 | 1119 | self.linear_out = LinearSE3( 1120 | final_fiber, 1121 | Fiber(list(map(lambda t: FiberEl(degrees = t[0], dim = 1), final_fiber))) 1122 | ) if reduce_dim_out else None 1123 | 1124 | def forward( 1125 | self, 1126 | feats, 1127 | coors, 1128 | mask = None, 1129 | adj_mat = None, 1130 | edges = None, 1131 | return_type = None, 1132 | return_pooled = False, 1133 | neighbor_mask = None, 1134 | global_feats = None 1135 | ): 1136 | assert not (self.accept_global_feats ^ exists(global_feats)), 'global features must be passed in if (and only if) the model was initialized with global_feats_dim' 1137 | 1138 | _mask = mask 1139 | 1140 | if self.output_degrees == 1: 1141 | return_type = 0 1142 | 1143 | if exists(self.token_emb): 1144 | feats = self.token_emb(feats) 1145 | 1146 | if exists(self.pos_emb): 1147 | assert feats.shape[1] <= self.num_positions, 'feature sequence length must not exceed the number of positions given at init' 1148 | pos_emb = self.pos_emb(torch.arange(feats.shape[1], device = feats.device)) 1149 | feats += rearrange(pos_emb, 'n d -> () n d') 1150 | 1151 | assert not (self.attend_sparse_neighbors and not exists(adj_mat)), 'adjacency matrix (adj_mat) or edges (edges) must be passed in' 1152 | assert not (self.has_edges and not exists(edges)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one is to train on edge types' 1153 | 1154 | if torch.is_tensor(feats): 1155 | feats = {'0': feats[..., None]} 1156 | 1157 | if torch.is_tensor(global_feats): 1158 | global_feats = {'0': global_feats[..., None]} 1159 | 1160 | b, n, d, *_, device = *feats['0'].shape, feats['0'].device 1161 | 1162 | assert d == self.dim_in[0], f'feature dimension {d} must be equal to 
dimension given at init {self.dim_in[0]}' 1163 | assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degrees' 1164 | 1165 | num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius 1166 | 1167 | assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0' 1168 | 1169 | # se3 transformer by default cannot have a node attend to itself 1170 | 1171 | exclude_self_mask = rearrange(~torch.eye(n, dtype = torch.bool, device = device), 'i j -> () i j') 1172 | remove_self = lambda t: t.masked_select(exclude_self_mask).reshape(b, n, n - 1) 1173 | get_max_value = lambda t: torch.finfo(t.dtype).max 1174 | 1175 | # create N-degrees adjacent matrix from 1st degree connections 1176 | 1177 | if exists(self.num_adj_degrees): 1178 | if len(adj_mat.shape) == 2: 1179 | adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b) 1180 | 1181 | adj_indices = adj_mat.clone().long() 1182 | 1183 | for ind in range(self.num_adj_degrees - 1): 1184 | degree = ind + 2 1185 | 1186 | next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0 1187 | next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool() 1188 | adj_indices = adj_indices.masked_fill(next_degree_mask, degree) 1189 | adj_mat = next_degree_adj_mat.clone() 1190 | 1191 | adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1) 1192 | 1193 | # calculate sparsely connected neighbors 1194 | 1195 | sparse_neighbor_mask = None 1196 | num_sparse_neighbors = 0 1197 | 1198 | if self.attend_sparse_neighbors: 1199 | assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)' 1200 | 1201 | if exists(adj_mat): 1202 | if len(adj_mat.shape) == 2: 1203 | adj_mat = repeat(adj_mat, 'i j -> b i j', b = b) 1204 | 1205 | adj_mat = remove_self(adj_mat) 1206 | 1207 | adj_mat_values = adj_mat.float() 1208 | adj_mat_max_neighbors = adj_mat_values.sum(dim = -1).max().item() 1209 | 1210 | if max_sparse_neighbors < adj_mat_max_neighbors: 1211 | noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01) 1212 | adj_mat_values += noise 1213 | 1214 | num_sparse_neighbors = int(min(max_sparse_neighbors, adj_mat_max_neighbors)) 1215 | values, indices = adj_mat_values.topk(num_sparse_neighbors, dim = -1) 1216 | sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(-1, indices, values) 1217 | sparse_neighbor_mask = sparse_neighbor_mask > 0.5 1218 | 1219 | # exclude edge of token to itself 1220 | 1221 | indices = repeat(torch.arange(n, device = device), 'j -> b i j', b = b, i = n) 1222 | rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d') 1223 | 1224 | indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1) 1225 | rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, 3) 1226 | 1227 | if exists(mask): 1228 | mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j') 1229 | mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1) 1230 | 1231 | if exists(edges): 1232 | if exists(self.edge_emb): 1233 | edges = self.edge_emb(edges) 1234 | 1235 | edges = edges.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, -1) 1236 | 1237 | if exists(self.adj_emb): 1238 | adj_emb = self.adj_emb(adj_indices) 1239 | edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb 1240 |
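A toy sketch of the N-degree adjacency expansion above: repeated boolean matrix products flag pairs reachable in more hops, and newly reached pairs are labeled with their hop degree. Note the expansion compares exact k-step reachability against the current adjacency, so the graph is assumed connected enough (here via self-loops, as in the test suite) for 1-hop pairs to keep their label.

import torch

n = 5
seq = torch.arange(n)
adj_mat = (seq[:, None] - seq[None, :]).abs() <= 1    # a chain, with self-loops

adj_indices = adj_mat.long()                          # 1-hop pairs labeled 1

two_hop = (adj_mat.float() @ adj_mat.float()) > 0     # reachable within 2 hops
newly_reached = (two_hop.float() - adj_mat.float()).bool()
adj_indices = adj_indices.masked_fill(newly_reached, 2)

assert adj_indices[0, 2].item() == 2                  # two hops apart
assert adj_indices[0, 1].item() == 1                  # still one hop apart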
1241 | rel_dist = rel_pos.norm(dim = -1) 1242 | 1243 | # rel_dist gets modified using adjacency or neighbor mask 1244 | 1245 | modified_rel_dist = rel_dist.clone() 1246 | max_value = get_max_value(modified_rel_dist) # for masking out nodes from being considered as neighbors 1247 | 1248 | # neighbors 1249 | 1250 | if exists(neighbor_mask): 1251 | neighbor_mask = remove_self(neighbor_mask) 1252 | 1253 | max_neighbors = neighbor_mask.sum(dim = -1).max().item() 1254 | if max_neighbors > neighbors: 1255 | print(f'neighbor_mask shows maximum number of neighbors as {max_neighbors} but specified number of neighbors is {neighbors}') 1256 | 1257 | modified_rel_dist = modified_rel_dist.masked_fill(~neighbor_mask, max_value) 1258 | 1259 | # use sparse neighbor mask to give priority to bonded neighbors 1260 | 1261 | if exists(sparse_neighbor_mask): 1262 | modified_rel_dist = modified_rel_dist.masked_fill(sparse_neighbor_mask, 0.) 1263 | 1264 | # mask out future nodes to high distance if causal turned on 1265 | 1266 | if self.causal: 1267 | causal_mask = torch.ones(n, n - 1, device = device).triu().bool() 1268 | modified_rel_dist = modified_rel_dist.masked_fill(causal_mask[None, ...], max_value) 1269 | 1270 | # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix 1271 | 1272 | if neighbors == 0: 1273 | valid_radius = 0 1274 | 1275 | # get neighbors and neighbor mask, excluding self 1276 | 1277 | neighbors = int(min(neighbors, n - 1)) 1278 | total_neighbors = int(neighbors + num_sparse_neighbors) 1279 | assert total_neighbors > 0, 'you must be fetching at least 1 neighbor' 1280 | 1281 | total_neighbors = int(min(total_neighbors, n - 1)) # make sure total neighbors does not exceed the length of the sequence itself 1282 | 1283 | dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False) 1284 | neighbor_mask = dist_values <= valid_radius 1285 | 1286 | neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2) 1287 | neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2) 1288 | neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2) 1289 | 1290 | if exists(mask): 1291 | neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2) 1292 | 1293 | if exists(edges): 1294 | edges = batched_index_select(edges, nearest_indices, dim = 2) 1295 | 1296 | # calculate rotary pos emb 1297 | 1298 | rotary_pos_emb = None 1299 | rotary_query_pos_emb = None 1300 | rotary_key_pos_emb = None 1301 | 1302 | if self.rotary_position: 1303 | seq = torch.arange(n, device = device) 1304 | seq_pos_emb = self.rotary_pos_emb(seq) 1305 | self_indices = torch.arange(neighbor_indices.shape[1], device = device) 1306 | self_indices = repeat(self_indices, 'i -> b i ()', b = b) 1307 | neighbor_indices_with_self = torch.cat((self_indices, neighbor_indices), dim = 2) 1308 | pos_emb = batched_index_select(seq_pos_emb, neighbor_indices_with_self, dim = 0) 1309 | 1310 | rotary_key_pos_emb = pos_emb 1311 | rotary_query_pos_emb = repeat(seq_pos_emb, 'n d -> b n d', b = b) 1312 | 1313 | if self.rotary_rel_dist: 1314 | neighbor_rel_dist_with_self = F.pad(neighbor_rel_dist, (1, 0), value = 0) * 1e2 1315 | rel_dist_pos_emb = self.rotary_pos_emb(neighbor_rel_dist_with_self) 1316 | rotary_key_pos_emb = safe_cat(rotary_key_pos_emb, rel_dist_pos_emb, dim = -1) 1317 | 1318 | query_dist = torch.zeros(n, device = device) 1319 | query_pos_emb =
self.rotary_pos_emb(query_dist) 1320 | query_pos_emb = repeat(query_pos_emb, 'n d -> b n d', b = b) 1321 | 1322 | rotary_query_pos_emb = safe_cat(rotary_query_pos_emb, query_pos_emb, dim = -1) 1323 | 1324 | if exists(rotary_query_pos_emb) and exists(rotary_key_pos_emb): 1325 | rotary_pos_emb = (rotary_query_pos_emb, rotary_key_pos_emb) 1326 | 1327 | # calculate basis 1328 | 1329 | basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable = self.differentiable_coors) 1330 | 1331 | # main logic 1332 | 1333 | edge_info = (neighbor_indices, neighbor_mask, edges) 1334 | x = feats 1335 | 1336 | # project in 1337 | 1338 | x = self.conv_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis) 1339 | 1340 | # preconvolution layers 1341 | 1342 | for conv, nonlin in self.convs: 1343 | x = nonlin(x) 1344 | x = conv(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis) 1345 | 1346 | # transformer layers 1347 | 1348 | x = self.net(x, edge_info = edge_info, rel_dist = neighbor_rel_dist, basis = basis, global_feats = global_feats, pos_emb = rotary_pos_emb, mask = _mask) 1349 | 1350 | # project out 1351 | 1352 | if exists(self.conv_out): 1353 | x = self.conv_out(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis) 1354 | 1355 | # norm 1356 | 1357 | x = self.norm(x) 1358 | 1359 | # reduce dim if specified 1360 | 1361 | if exists(self.linear_out): 1362 | x = self.linear_out(x) 1363 | x = map_values(lambda t: t.squeeze(dim = 2), x) 1364 | 1365 | if return_pooled: 1366 | mask_fn = (lambda t: masked_mean(t, _mask, dim = 1)) if exists(_mask) else (lambda t: t.mean(dim = 1)) 1367 | x = map_values(mask_fn, x) 1368 | 1369 | if '0' in x: 1370 | x['0'] = x['0'].squeeze(dim = -1) 1371 | 1372 | if exists(return_type): 1373 | return x[str(return_type)] 1374 | 1375 | return x 1376 | -------------------------------------------------------------------------------- /se3_transformer_pytorch/spherical_harmonics.py: -------------------------------------------------------------------------------- 1 | from math import pi, sqrt 2 | from functools import reduce 3 | from operator import mul 4 | import torch 5 | 6 | from functools import lru_cache 7 | from se3_transformer_pytorch.utils import cache 8 | 9 | # constants 10 | 11 | CACHE = {} 12 | 13 | def clear_spherical_harmonics_cache(): 14 | CACHE.clear() 15 | 16 | def lpmv_cache_key_fn(l, m, x): 17 | return (l, m) 18 | 19 | # spherical harmonics 20 | 21 | @lru_cache(maxsize = 1000) 22 | def semifactorial(x): 23 | return reduce(mul, range(x, 1, -2), 1.) 24 | 25 | @lru_cache(maxsize = 1000) 26 | def pochhammer(x, k): 27 | return reduce(mul, range(x + 1, x + k), float(x)) 28 | 29 | def negative_lpmv(l, m, y): 30 | if m < 0: 31 | y *= ((-1) ** m / pochhammer(l + m + 1, -2 * m)) 32 | return y 33 | 34 | @cache(cache = CACHE, key_fn = lpmv_cache_key_fn) 35 | def lpmv(l, m, x): 36 | """Associated Legendre function including Condon-Shortley phase. 
37 | 38 | Args: 39 | l: int degree 40 | m: int order 41 | x: float argument tensor 42 | Returns: 43 | tensor with the same shape as x 44 | """ 45 | # Check memoized versions 46 | m_abs = abs(m) 47 | 48 | if m_abs > l: 49 | return None 50 | 51 | if l == 0: 52 | return torch.ones_like(x) 53 | 54 | # Check if on boundary else recurse solution down to boundary 55 | if m_abs == l: 56 | # Compute P_m^m 57 | y = (-1)**m_abs * semifactorial(2*m_abs-1) 58 | y *= torch.pow(1-x*x, m_abs/2) 59 | return negative_lpmv(l, m, y) 60 | 61 | # Recursively precompute lower degree harmonics 62 | lpmv(l-1, m, x) 63 | 64 | # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m 65 | # Inplace speedup 66 | y = ((2*l-1) / (l-m_abs)) * x * lpmv(l-1, m_abs, x) 67 | 68 | if l - m_abs > 1: 69 | y -= ((l+m_abs-1)/(l-m_abs)) * CACHE[(l-2, m_abs)] 70 | 71 | if m < 0: 72 | y = negative_lpmv(l, m, y) 73 | return y 74 | 75 | def get_spherical_harmonics_element(l, m, theta, phi): 76 | """Tesseral spherical harmonic with Condon-Shortley phase. 77 | 78 | The Tesseral spherical harmonics are also known as the real spherical 79 | harmonics. 80 | 81 | Args: 82 | l: int for degree 83 | m: int for order, where -l <= m <= l 84 | theta: colatitude or polar angle 85 | phi: longitude or azimuth 86 | Returns: 87 | tensor with the same shape as theta 88 | """ 89 | m_abs = abs(m) 90 | assert m_abs <= l, "absolute value of order m must be <= degree l" 91 | 92 | N = sqrt((2*l + 1) / (4 * pi)) 93 | leg = lpmv(l, m_abs, torch.cos(theta)) 94 | 95 | if m == 0: 96 | return N * leg 97 | 98 | if m > 0: 99 | Y = torch.cos(m * phi) 100 | else: 101 | Y = torch.sin(m_abs * phi) 102 | 103 | Y *= leg 104 | N *= sqrt(2. / pochhammer(l - m_abs + 1, 2 * m_abs)) 105 | Y *= N 106 | return Y 107 | 108 | def get_spherical_harmonics(l, theta, phi): 109 | """ Tesseral harmonic with Condon-Shortley phase. 110 | 111 | The Tesseral spherical harmonics are also known as the real spherical 112 | harmonics. 113 | 114 | Args: 115 | l: int for degree 116 | theta: colatitude or polar angle 117 | phi: longitude or azimuth 118 | Returns: 119 | tensor of shape [*theta.shape, 2*l+1] 120 | """ 121 | return torch.stack([ get_spherical_harmonics_element(l, m, theta, phi) \ 122 | for m in range(-l, l+1) ], 123 | dim = -1) 124 |
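A small verification sketch for the two functions above: for l = 0 the single tesseral harmonic is the constant 1/sqrt(4*pi), and for general l the output stacks the 2l + 1 orders along the last axis. Note the module-level CACHE is keyed on (l, m) only, so it must be cleared whenever the input angles change.

import torch
from math import pi, sqrt
from se3_transformer_pytorch.spherical_harmonics import (
    get_spherical_harmonics,
    clear_spherical_harmonics_cache,
)

theta = torch.rand(16) * pi      # polar angle in [0, pi]
phi = torch.rand(16) * 2 * pi    # azimuth in [0, 2*pi]

clear_spherical_harmonics_cache()
y0 = get_spherical_harmonics(0, theta, phi)
assert y0.shape == (16, 1)
assert torch.allclose(y0, torch.full_like(y0, 1 / sqrt(4 * pi)))

y2 = get_spherical_harmonics(2, theta, phi)   # same angles, so the cache stays valid
assert y2.shape == (16, 5)                    # orders m = -2 .. 2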
-------------------------------------------------------------------------------- /se3_transformer_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import pickle 5 | import gzip 6 | import torch 7 | import contextlib 8 | from functools import wraps, lru_cache 9 | from filelock import FileLock 10 | 11 | from einops import rearrange 12 | 13 | # helper functions 14 | 15 | def exists(val): 16 | return val is not None 17 | 18 | def default(val, d): 19 | return val if exists(val) else d 20 | 21 | def uniq(arr): 22 | return list({el: True for el in arr}.keys()) 23 | 24 | def to_order(degree): 25 | return 2 * degree + 1 26 | 27 | def map_values(fn, d): 28 | return {k: fn(v) for k, v in d.items()} 29 | 30 | def safe_cat(arr, el, dim): 31 | if not exists(arr): 32 | return el 33 | return torch.cat((arr, el), dim = dim) 34 | 35 | def cast_tuple(val, depth): 36 | return val if isinstance(val, tuple) else (val,) * depth 37 | 38 | def broadcat(tensors, dim = -1): 39 | num_tensors = len(tensors) 40 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 41 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 42 | shape_len = list(shape_lens)[0] 43 | 44 | dim = (dim + shape_len) if dim < 0 else dim 45 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 46 | 47 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 48 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation' 49 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 50 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 51 | expanded_dims.insert(dim, (dim, dims[dim])) 52 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 53 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 54 | return torch.cat(tensors, dim = dim) 55 | 56 | def batched_index_select(values, indices, dim = 1): 57 | value_dims = values.shape[(dim + 1):] 58 | values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices)) 59 | indices = indices[(..., *((None,) * len(value_dims)))] 60 | indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims) 61 | value_expand_len = len(indices_shape) - (dim + 1) 62 | values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)] 63 | 64 | value_expand_shape = [-1] * len(values.shape) 65 | expand_slice = slice(dim, (dim + value_expand_len)) 66 | value_expand_shape[expand_slice] = indices.shape[expand_slice] 67 | values = values.expand(*value_expand_shape) 68 | 69 | dim += value_expand_len 70 | return values.gather(dim, indices) 71 | 72 | def masked_mean(tensor, mask, dim = -1): 73 | diff_len = len(tensor.shape) - len(mask.shape) 74 | mask = mask[(..., *((None,) * diff_len))] 75 | tensor = tensor.masked_fill(~mask, 0.) # avoid mutating the caller's tensor in place 76 | 77 | total_el = mask.sum(dim = dim) 78 | mean = tensor.sum(dim = dim) / total_el.clamp(min = 1.) 79 | mean.masked_fill_(total_el == 0, 0.) 80 | return mean 81 |
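A toy sketch of batched_index_select, the gather helper used throughout the layers above to pull each node's neighbor features (hypothetical shapes):

import torch
from se3_transformer_pytorch.utils import batched_index_select

b, n, k, d = 2, 5, 3, 7
values = torch.randn(b, n, d)               # per-node features
indices = torch.randint(0, n, (b, n, k))    # k neighbor ids per node

neighbors = batched_index_select(values, indices, dim = 1)
assert neighbors.shape == (b, n, k, d)
assert torch.equal(neighbors[0, 0, 0], values[0, indices[0, 0, 0]])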
82 | def rand_uniform(size, min_val, max_val): 83 | return torch.empty(size).uniform_(min_val, max_val) 84 | 85 | def fast_split(arr, splits, dim=0): 86 | axis_len = arr.shape[dim] 87 | splits = min(axis_len, max(splits, 1)) 88 | chunk_size = axis_len // splits 89 | remainder = axis_len - chunk_size * splits 90 | s = 0 91 | for i in range(splits): 92 | adjust, remainder = 1 if remainder > 0 else 0, remainder - 1 93 | yield torch.narrow(arr, dim, s, chunk_size + adjust) 94 | s += chunk_size + adjust 95 | 96 | def fourier_encode(x, num_encodings = 4, include_self = True, flatten = True): 97 | x = x.unsqueeze(-1) 98 | device, dtype, orig_x = x.device, x.dtype, x 99 | scales = 2 ** torch.arange(num_encodings, device = device, dtype = dtype) 100 | x = x / scales 101 | x = torch.cat([x.sin(), x.cos()], dim=-1) 102 | x = torch.cat((x, orig_x), dim = -1) if include_self else x 103 | x = rearrange(x, 'b m n ... -> b m n (...)') if flatten else x 104 | return x 105 | 106 | # default dtype context manager 107 | 108 | @contextlib.contextmanager 109 | def torch_default_dtype(dtype): 110 | prev_dtype = torch.get_default_dtype() 111 | torch.set_default_dtype(dtype) 112 | yield 113 | torch.set_default_dtype(prev_dtype) 114 | 115 | def cast_torch_tensor(fn): 116 | @wraps(fn) 117 | def inner(t): 118 | if not torch.is_tensor(t): 119 | t = torch.tensor(t, dtype = torch.get_default_dtype()) 120 | return fn(t) 121 | return inner 122 | 123 | # benchmark tool 124 | 125 | def benchmark(fn): 126 | def inner(*args, **kwargs): 127 | start = time.time() 128 | res = fn(*args, **kwargs) 129 | diff = time.time() - start 130 | return diff, res 131 | return inner 132 | 133 | # caching functions 134 | 135 | def cache(cache, key_fn): 136 | def cache_inner(fn): 137 | @wraps(fn) 138 | def inner(*args, **kwargs): 139 | key_name = key_fn(*args, **kwargs) 140 | if key_name in cache: 141 | return cache[key_name] 142 | res = fn(*args, **kwargs) 143 | cache[key_name] = res 144 | return res 145 | 146 | return inner 147 | return cache_inner 148 | 149 | # cache in directory 150 | 151 | def cache_dir(dirname, maxsize=128): 152 | ''' 153 | Cache a function with a directory 154 | 155 | :param dirname: the directory path 156 | :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache) 157 | ''' 158 | def decorator(func): 159 | 160 | @lru_cache(maxsize=maxsize) 161 | @wraps(func) 162 | def wrapper(*args, **kwargs): 163 | if not exists(dirname): 164 | return func(*args, **kwargs) 165 | 166 | os.makedirs(dirname, exist_ok = True) 167 | 168 | indexfile = os.path.join(dirname, "index.pkl") 169 | lock = FileLock(os.path.join(dirname, "mutex")) 170 | 171 | with lock: 172 | index = {} 173 | if os.path.exists(indexfile): 174 | with open(indexfile, "rb") as file: 175 | index = pickle.load(file) 176 | 177 | key = (args, frozenset(kwargs.items()), func.__defaults__) # include kwarg values, not just keys, in the cache key 178 | 179 | if key in index: 180 | filename = index[key] 181 | else: 182 | index[key] = filename = f"{len(index)}.pkl.gz" 183 | with open(indexfile, "wb") as file: 184 | pickle.dump(index, file) 185 | 186 | filepath = os.path.join(dirname, filename) 187 | 188 | if os.path.exists(filepath): 189 | with lock: 190 | with gzip.open(filepath, "rb") as file: 191 | result = pickle.load(file) 192 | return result 193 | 194 | print(f"compute {filename}... ", end="", flush = True) 195 | result = func(*args, **kwargs) 196 | print(f"save {filename}... 
", end="", flush = True) 197 | 198 | with lock: 199 | with gzip.open(filepath, "wb") as file: 200 | pickle.dump(result, file) 201 | 202 | print("done") 203 | 204 | return result 205 | return wrapper 206 | return decorator 207 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --verbose 6 | python_files = tests/*.py 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'se3-transformer-pytorch', 5 | packages = find_packages(), 6 | include_package_data = True, 7 | version = '0.9.0', 8 | license='MIT', 9 | description = 'SE3 Transformer - Pytorch', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | url = 'https://github.com/lucidrains/se3-transformer-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'attention mechanism', 16 | 'transformers', 17 | 'equivariance', 18 | 'SE3' 19 | ], 20 | install_requires=[ 21 | 'einops>=0.3', 22 | 'filelock', 23 | 'numpy', 24 | 'torch>=1.6' 25 | ], 26 | setup_requires=[ 27 | 'pytest-runner', 28 | ], 29 | tests_require=[ 30 | 'pytest', 31 | 'lie_learn', 32 | 'numpy', 33 | ], 34 | classifiers=[ 35 | 'Development Status :: 4 - Beta', 36 | 'Intended Audience :: Developers', 37 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 38 | 'License :: OSI Approved :: MIT License', 39 | 'Programming Language :: Python :: 3.6', 40 | ], 41 | ) 42 | -------------------------------------------------------------------------------- /tests/test_basis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from se3_transformer_pytorch.basis import get_basis, get_R_tensor, basis_transformation_Q_J 3 | from se3_transformer_pytorch.irr_repr import irr_repr 4 | 5 | def test_basis(): 6 | max_degree = 3 7 | x = torch.randn(2, 1024, 3) 8 | basis = get_basis(x, max_degree) 9 | assert len(basis.keys()) == (max_degree + 1) ** 2, 'correct number of basis kernels' 10 | 11 | def test_basis_transformation_Q_J(): 12 | rand_angles = torch.rand(4, 3) 13 | J, order_out, order_in = 1, 1, 1 14 | Q_J = basis_transformation_Q_J(J, order_in, order_out).float() 15 | assert all(torch.allclose(get_R_tensor(order_out, order_in, a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in rand_angles) 16 | -------------------------------------------------------------------------------- /tests/test_equivariance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer 3 | from se3_transformer_pytorch.irr_repr import rot 4 | from se3_transformer_pytorch.utils import torch_default_dtype, fourier_encode 5 | 6 | def test_transformer(): 7 | model = SE3Transformer( 8 | dim = 64, 9 | depth = 1, 10 | num_degrees = 2, 11 | num_neighbors = 4, 12 | valid_radius = 10 13 | ) 14 | 15 | feats = torch.randn(1, 32, 64) 16 | coors = torch.randn(1, 32, 3) 17 | mask = torch.ones(1, 32).bool() 18 | 19 | out = model(feats, coors, mask, return_type = 0) 20 | assert out.shape == (1, 32, 64), 'output must be of the right shape' 21 | 22 | def test_causal_se3_transformer(): 23 | model = SE3Transformer( 24 | dim = 64, 25 | depth = 1, 26 
| num_degrees = 2, 27 | num_neighbors = 4, 28 | valid_radius = 10, 29 | causal = True 30 | ) 31 | 32 | feats = torch.randn(1, 32, 64) 33 | coors = torch.randn(1, 32, 3) 34 | mask = torch.ones(1, 32).bool() 35 | 36 | out = model(feats, coors, mask, return_type = 0) 37 | assert out.shape == (1, 32, 64), 'output must be of the right shape' 38 | 39 | def test_se3_transformer_with_global_nodes(): 40 | model = SE3Transformer( 41 | dim = 64, 42 | depth = 1, 43 | num_degrees = 2, 44 | num_neighbors = 4, 45 | valid_radius = 10, 46 | global_feats_dim = 16 47 | ) 48 | 49 | feats = torch.randn(1, 32, 64) 50 | coors = torch.randn(1, 32, 3) 51 | mask = torch.ones(1, 32).bool() 52 | 53 | global_feats = torch.randn(1, 2, 16) 54 | 55 | out = model(feats, coors, mask, return_type = 0, global_feats = global_feats) 56 | assert out.shape == (1, 32, 64), 'output must be of the right shape' 57 | 58 | def test_one_headed_key_values_se3_transformer_with_global_nodes(): 59 | model = SE3Transformer( 60 | dim = 64, 61 | depth = 1, 62 | num_degrees = 2, 63 | num_neighbors = 4, 64 | valid_radius = 10, 65 | global_feats_dim = 16, 66 | one_headed_key_values = True 67 | ) 68 | 69 | feats = torch.randn(1, 32, 64) 70 | coors = torch.randn(1, 32, 3) 71 | mask = torch.ones(1, 32).bool() 72 | 73 | global_feats = torch.randn(1, 2, 16) 74 | 75 | out = model(feats, coors, mask, return_type = 0, global_feats = global_feats) 76 | assert out.shape == (1, 32, 64), 'output must be of the right shape' 77 | 78 | def test_transformer_with_edges(): 79 | model = SE3Transformer( 80 | dim = 64, 81 | depth = 1, 82 | num_degrees = 2, 83 | num_neighbors = 4, 84 | edge_dim = 4, 85 | num_edge_tokens = 4 86 | ) 87 | 88 | feats = torch.randn(1, 32, 64) 89 | edges = torch.randint(0, 4, (1, 32)) 90 | coors = torch.randn(1, 32, 3) 91 | mask = torch.ones(1, 32).bool() 92 | 93 | out = model(feats, coors, mask, edges = edges, return_type = 0) 94 | assert out.shape == (1, 32, 64), 'output must be of the right shape' 95 | 96 | def test_transformer_with_continuous_edges(): 97 | model = SE3Transformer( 98 | dim = 64, 99 | depth = 1, 100 | attend_self = True, 101 | num_degrees = 2, 102 | output_degrees = 2, 103 | edge_dim = 34 104 | ) 105 | 106 | feats = torch.randn(1, 32, 64) 107 | coors = torch.randn(1, 32, 3) 108 | mask = torch.ones(1, 32).bool() 109 | 110 | pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2)) 111 | 112 | edges = fourier_encode( 113 | pairwise_continuous_values, 114 | num_encodings = 8, 115 | include_self = True 116 | ) 117 | 118 | out = model(feats, coors, mask, edges = edges, return_type = 1) 119 | assert True 120 | 121 | def test_different_input_dimensions_for_types(): 122 | model = SE3Transformer( 123 | dim_in = (4, 2), 124 | dim = 4, 125 | depth = 1, 126 | input_degrees = 2, 127 | num_degrees = 2, 128 | output_degrees = 2, 129 | reduce_dim_out = True 130 | ) 131 | 132 | atom_feats = torch.randn(2, 32, 4, 1) 133 | coors_feats = torch.randn(2, 32, 2, 3) 134 | 135 | features = {'0': atom_feats, '1': coors_feats} 136 | coors = torch.randn(2, 32, 3) 137 | mask = torch.ones(2, 32).bool() 138 | 139 | refined_coors = coors + model(features, coors, mask, return_type = 1) 140 | assert True 141 | 142 | def test_equivariance(): 143 | model = SE3Transformer( 144 | dim = 64, 145 | depth = 1, 146 | attend_self = True, 147 | num_neighbors = 4, 148 | num_degrees = 2, 149 | output_degrees = 2, 150 | fourier_encode_dist = True 151 | ) 152 | 153 | feats = torch.randn(1, 32, 64) 154 | coors = torch.randn(1, 32, 3) 155 | mask = torch.ones(1, 
156 | 157 | R = rot(15, 0, 45) 158 | out1 = model(feats, coors @ R, mask, return_type = 1) 159 | out2 = model(feats, coors, mask, return_type = 1) @ R 160 | 161 | diff = (out1 - out2).abs().max() 162 | assert diff < 1e-4, 'is not equivariant' 163 | 164 | def test_equivariance_with_egnn_backbone(): 165 | model = SE3Transformer( 166 | dim = 64, 167 | depth = 1, 168 | attend_self = True, 169 | num_neighbors = 4, 170 | num_degrees = 2, 171 | output_degrees = 2, 172 | fourier_encode_dist = True, 173 | use_egnn = True 174 | ) 175 | 176 | feats = torch.randn(1, 32, 64) 177 | coors = torch.randn(1, 32, 3) 178 | mask = torch.ones(1, 32).bool() 179 | 180 | R = rot(15, 0, 45) 181 | out1 = model(feats, coors @ R, mask, return_type = 1) 182 | out2 = model(feats, coors, mask, return_type = 1) @ R 183 | 184 | diff = (out1 - out2).abs().max() 185 | assert diff < 1e-4, 'is not equivariant' 186 | 187 | def test_rotary(): 188 | model = SE3Transformer( 189 | dim = 64, 190 | depth = 1, 191 | attend_self = True, 192 | num_neighbors = 4, 193 | num_degrees = 2, 194 | output_degrees = 2, 195 | fourier_encode_dist = True, 196 | rotary_position = True, 197 | rotary_rel_dist = True 198 | ) 199 | 200 | feats = torch.randn(1, 32, 64) 201 | coors = torch.randn(1, 32, 3) 202 | mask = torch.ones(1, 32).bool() 203 | 204 | R = rot(15, 0, 45) 205 | out1 = model(feats, coors @ R, mask, return_type = 1) 206 | out2 = model(feats, coors, mask, return_type = 1) @ R 207 | 208 | diff = (out1 - out2).abs().max() 209 | assert diff < 1e-4, 'is not equivariant' 210 | 211 | def test_equivariance_linear_proj_keys(): 212 | model = SE3Transformer( 213 | dim = 64, 214 | depth = 1, 215 | attend_self = True, 216 | num_neighbors = 4, 217 | num_degrees = 2, 218 | output_degrees = 2, 219 | fourier_encode_dist = True, 220 | linear_proj_keys = True 221 | ) 222 | 223 | feats = torch.randn(1, 32, 64) 224 | coors = torch.randn(1, 32, 3) 225 | mask = torch.ones(1, 32).bool() 226 | 227 | R = rot(15, 0, 45) 228 | out1 = model(feats, coors @ R, mask, return_type = 1) 229 | out2 = model(feats, coors, mask, return_type = 1) @ R 230 | 231 | diff = (out1 - out2).abs().max() 232 | assert diff < 1e-4, 'is not equivariant' 233 | 234 | @torch_default_dtype(torch.float64) 235 | def test_equivariance_only_sparse_neighbors(): 236 | model = SE3Transformer( 237 | dim = 64, 238 | depth = 1, 239 | attend_self = True, 240 | num_degrees = 2, 241 | output_degrees = 2, 242 | num_neighbors = 0, 243 | attend_sparse_neighbors = True, 244 | num_adj_degrees = 2, 245 | adj_dim = 4 246 | ) 247 | 248 | feats = torch.randn(1, 32, 64) 249 | coors = torch.randn(1, 32, 3) 250 | mask = torch.ones(1, 32).bool() 251 | 252 | seq = torch.arange(32) 253 | adj_mat = (seq[:, None] >= (seq[None, :] - 1)) & (seq[:, None] <= (seq[None, :] + 1)) 254 | 255 | R = rot(15, 0, 45) 256 | out1 = model(feats, coors @ R, mask, adj_mat = adj_mat, return_type = 1) 257 | out2 = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1) @ R 258 | 259 | diff = (out1 - out2).abs().max() 260 | assert diff < 1e-4, 'is not equivariant' 261 | 262 | def test_equivariance_with_reversible_network(): 263 | model = SE3Transformer( 264 | dim = 64, 265 | depth = 1, 266 | attend_self = True, 267 | num_neighbors = 4, 268 | num_degrees = 2, 269 | output_degrees = 2, 270 | reversible = True 271 | ) 272 | 273 | feats = torch.randn(1, 32, 64) 274 | coors = torch.randn(1, 32, 3) 275 | mask = torch.ones(1, 32).bool() 276 | 277 | R = rot(15, 0, 45) 278 | out1 = model(feats, coors @ R, mask, return_type = 1) 279 | out2 = model(feats, coors, mask, return_type = 1) @ R
280 | 281 | diff = (out1 - out2).abs().max() 282 | assert diff < 1e-4, 'is not equivariant' 283 | 284 | def test_equivariance_with_type_one_input(): 285 | model = SE3Transformer( 286 | dim = 64, 287 | depth = 1, 288 | attend_self = True, 289 | num_neighbors = 4, 290 | num_degrees = 2, 291 | input_degrees = 2, 292 | output_degrees = 2 293 | ) 294 | 295 | atom_features = torch.randn(1, 32, 64, 1) 296 | pred_coors = torch.randn(1, 32, 64, 3) 297 | 298 | coors = torch.randn(1, 32, 3) 299 | mask = torch.ones(1, 32).bool() 300 | 301 | R = rot(15, 0, 45) 302 | out1 = model({'0': atom_features, '1': pred_coors @ R}, coors @ R, mask, return_type = 1) 303 | out2 = model({'0': atom_features, '1': pred_coors}, coors, mask, return_type = 1) @ R 304 | 305 | diff = (out1 - out2).abs().max() 306 | assert diff < 1e-4, 'is not equivariant' 307 | -------------------------------------------------------------------------------- /tests/test_irrep_repr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache 3 | from se3_transformer_pytorch.irr_repr import spherical_harmonics, irr_repr, compose 4 | from se3_transformer_pytorch.utils import torch_default_dtype 5 | 6 | @torch_default_dtype(torch.float64) 7 | def test_irr_repr(): 8 | """ 9 | This test checks that 10 | - irr_repr 11 | - compose 12 | - spherical_harmonics 13 | are compatible 14 | 15 | Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x) 16 | with x = Z(a) Y(b) eta 17 | """ 18 | for order in range(7): 19 | a, b = torch.rand(2) 20 | alpha, beta, gamma = torch.rand(3) 21 | 22 | ra, rb, _ = compose(alpha, beta, gamma, a, b, 0) 23 | Yrx = spherical_harmonics(order, ra, rb) 24 | clear_spherical_harmonics_cache() 25 | 26 | Y = spherical_harmonics(order, a, b) 27 | clear_spherical_harmonics_cache() 28 | 29 | DrY = irr_repr(order, alpha, beta, gamma) @ Y 30 | 31 | d, r = (Yrx - DrY).abs().max(), Y.abs().max() 32 | print(d.item(), r.item()) 33 | assert d < 1e-10 * r, d / r 34 | --------------------------------------------------------------------------------
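The identity exercised by test_irr_repr above is the defining property of the Wigner D-matrices: rotating the input of the spherical harmonics is the same as applying the degree-l representation to their output. A small standalone sanity check along the same lines is sketched below; it assumes that irr_repr(l, alpha, beta, gamma) returns a real (2l + 1) x (2l + 1) matrix that is orthogonal up to numerical error, hence the loose tolerance:

import torch
from se3_transformer_pytorch.irr_repr import irr_repr

torch.set_default_dtype(torch.float64)  # match the precision the test above runs at

alpha, beta, gamma = torch.rand(3)
for l in range(3):
    D = irr_repr(l, alpha, beta, gamma)
    # the degree-l representation acts on a (2l + 1)-dimensional space
    assert D.shape == (2 * l + 1, 2 * l + 1)
    # rotations are orthogonal, and so are their irreducible representations
    assert torch.allclose(D @ D.T, torch.eye(2 * l + 1), atol = 1e-9)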
/tests/test_spherical_harmonics.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | 5 | from lie_learn.representations.SO3.spherical_harmonics import sh 6 | 7 | from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics_element 8 | from se3_transformer_pytorch.utils import benchmark 9 | 10 | def test_spherical_harmonics(): 11 | dtype = torch.float64 12 | 13 | theta = 0.1 * torch.randn(32, 1024, 10, dtype=dtype) 14 | phi = 0.1 * torch.randn(32, 1024, 10, dtype=dtype) 15 | 16 | s0 = s1 = 0 17 | max_error = -1. 18 | 19 | for l in range(8): 20 | for m in range(-l, l + 1): 21 | start = time.time() 22 | 23 | diff, y = benchmark(get_spherical_harmonics_element)(l, m, theta, phi) 24 | y = y.type(torch.float32) 25 | s0 += diff 26 | 27 | diff, z = benchmark(sh)(l, m, theta, phi) 28 | s1 += diff 29 | 30 | error = np.mean(np.abs((y.cpu().numpy() - z) / z)) 31 | max_error = max(max_error, error) 32 | print(f"l: {l}, m: {m} ", error) 33 | 34 | time_diff_ratio = s0 / s1 35 | 36 | assert max_error < 1e-4, 'maximum error must be less than 1e-4' 37 | assert time_diff_ratio < 1., 'spherical harmonics must be faster than the implementation offered by lie_learn' 38 | 39 | print(f"Max error: {max_error}") 40 | print(f"Time diff: {time_diff_ratio}") 41 | --------------------------------------------------------------------------------
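The equivariance tests in tests/test_equivariance.py all instantiate the same contract: rotating the input coordinates and then applying the model must match applying the model first and rotating its type-1 outputs afterwards. Below is a minimal standalone version of that check, assuming the package is installed; the 1e-4 tolerance mirrors the tests and is an assumption at float32 precision, not a guarantee:

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch.irr_repr import rot

model = SE3Transformer(
    dim = 64,
    depth = 1,
    attend_self = True,
    num_neighbors = 4,
    num_degrees = 2,
    output_degrees = 2
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask = torch.ones(1, 32).bool()

R = rot(15, 0, 45)  # rotation matrix built from three Euler angles

# f(x @ R) should equal f(x) @ R for the type-1 (vector) outputs
out1 = model(feats, coors @ R, mask, return_type = 1)
out2 = model(feats, coors, mask, return_type = 1) @ R

assert (out1 - out2).abs().max() < 1e-4, 'model is not equivariant'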