├── diagram.png
├── diagram-2.png
├── lie_transformer_pytorch
│   ├── __init__.py
│   ├── reversible.py
│   ├── se3.py
│   └── lie_transformer_pytorch.py
├── setup.cfg
├── tests.py
├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── python-package.yml
├── setup.py
├── LICENSE
├── .gitignore
└── README.md

/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/lie-transformer-pytorch/HEAD/diagram.png
--------------------------------------------------------------------------------
/diagram-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/lie-transformer-pytorch/HEAD/diagram-2.png
--------------------------------------------------------------------------------
/lie_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from lie_transformer_pytorch.lie_transformer_pytorch import LieTransformer
2 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 | 
4 | [tool:pytest]
5 | addopts = --verbose
6 | python_files = tests.py
7 | 
--------------------------------------------------------------------------------
/tests.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lie_transformer_pytorch import LieTransformer
3 | 
4 | def test_transformer():
5 |     model = LieTransformer(
6 |         dim = 512,
7 |         depth = 1
8 |     )
9 | 
10 |     feats = torch.randn(1, 64, 512)
11 |     coors = torch.randn(1, 64, 3)
12 |     mask = torch.ones(1, 64).bool()
13 | 
14 |     out = model(feats, coors, mask = mask)
15 |     assert out.shape == (1, 256, 512), 'output must have shape (batch, seq * liftsamples, dim) - here 64 points x default 4 liftsamples = 256'
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 | 
4 | name: Upload Python Package
5 | 
6 | on:
7 |   release:
8 |     types: [created]
9 | 
10 | jobs:
11 |   deploy:
12 | 
13 |     runs-on: ubuntu-latest
14 | 
15 |     steps:
16 |     - uses: actions/checkout@v2
17 |     - name: Set up Python
18 |       uses: actions/setup-python@v2
19 |       with:
20 |         python-version: '3.x'
21 |     - name: Install dependencies
22 |       run: |
23 |         python -m pip install --upgrade pip
24 |         pip install setuptools wheel twine
25 |     - name: Build and publish
26 |       env:
27 |         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 |         TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 |       run: |
30 |         python setup.py sdist bdist_wheel
31 |         twine upload dist/*
32 | 
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 | 
4 | name: Python package
5 | 
6 | on:
7 |   push:
8 |     branches: [ main ]
9 |   pull_request:
10 |     branches: [ main ]
11 | 
12 | jobs:
13 |   build:
14 | 
15 |     runs-on: ubuntu-latest
16 |     strategy:
17 |       matrix:
18 |         python-version: [3.7, 3.8, 3.9]
19 | 
20 |     steps:
21 |     - uses: actions/checkout@v2
22 |     - name: Set up Python ${{ matrix.python-version }}
23 |       uses: actions/setup-python@v2
24 |       with:
25 |         python-version: ${{ matrix.python-version }}
26 |     - name: Install dependencies
27 |       run: |
28 |         python -m pip install --upgrade pip
29 |         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
30 |     - name: Test with pytest
31 |       run: |
32 |         python setup.py test
33 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |   name = 'lie-transformer-pytorch',
5 |   packages = find_packages(),
6 |   version = '0.0.17',
7 |   license='MIT',
8 |   description = 'Lie Transformer - Pytorch',
9 |   author = 'Phil Wang',
10 |   author_email = 'lucidrains@gmail.com',
11 |   url = 'https://github.com/lucidrains/lie-transformer-pytorch',
12 |   keywords = [
13 |     'artificial intelligence',
14 |     'attention mechanism',
15 |     'transformers',
16 |     'equivariance',
17 |     'lifting',
18 |     'lie groups'
19 |   ],
20 |   install_requires=[
21 |     'torch>=1.6',
22 |     'einops>=0.3'
23 |   ],
24 |   setup_requires=[
25 |     'pytest-runner',
26 |   ],
27 |   tests_require=[
28 |     'pytest'
29 |   ],
30 |   classifiers=[
31 |     'Development Status :: 4 - Beta',
32 |     'Intended Audience :: Developers',
33 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
34 |     'License :: OSI Approved :: MIT License',
35 |     'Programming Language :: Python :: 3.6',
36 |   ],
37 | )
38 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Phil Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | ## Lie Transformer - Pytorch
6 | 
7 | Implementation of Lie Transformer, Equivariant Self-Attention, in Pytorch. Only the SE3 version will be present in this repository, as it may be needed for Alphafold2 replication.
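
In addition to per-point features, the forward pass can return a single pooled vector per input set (a masked mean over all lifted locations) by passing `return_pooled = True`. A minimal sketch of this flag, using the same shapes as the usage examples further below:

```python
import torch
from lie_transformer_pytorch import LieTransformer

model = LieTransformer(dim = 512, depth = 1)

feats = torch.randn(1, 64, 512)
coors = torch.randn(1, 64, 3)
mask = torch.ones(1, 64).bool()

pooled = model(feats, coors, mask = mask, return_pooled = True) # (1, 512)
```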
8 | 
9 | ## Install
10 | 
11 | ```bash
12 | $ pip install lie-transformer-pytorch
13 | ```
14 | 
15 | ## Usage
16 | 
17 | ```python
18 | import torch
19 | from lie_transformer_pytorch import LieTransformer
20 | 
21 | model = LieTransformer(
22 |     dim = 512,
23 |     depth = 2,
24 |     heads = 8,
25 |     dim_head = 64,
26 |     liftsamples = 4
27 | )
28 | 
29 | coors = torch.randn(1, 64, 3)
30 | features = torch.randn(1, 64, 512)
31 | mask = torch.ones(1, 64).bool()
32 | 
33 | out = model(features, coors, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)
34 | ```
35 | 
36 | To let the Lie Transformer take care of embedding the features, just specify the number of unique tokens (node types).
37 | 
38 | ```python
39 | import torch
40 | from lie_transformer_pytorch import LieTransformer
41 | 
42 | model = LieTransformer(
43 |     num_tokens = 28,    # say 28 different types of atoms
44 |     dim = 512,
45 |     depth = 2,
46 |     heads = 8,
47 |     dim_head = 64,
48 |     liftsamples = 4
49 | )
50 | 
51 | atoms = torch.randint(0, 28, (1, 64))
52 | coors = torch.randn(1, 64, 3)
53 | mask = torch.ones(1, 64).bool()
54 | 
55 | out = model(atoms, coors, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)
56 | ```
57 | 
58 | Although it was not in the paper, I decided to allow for passing in edge information as well (bond types). The edge information will be embedded to the dimension specified, concatenated with the lifted pairwise locations, and passed through an MLP before being summed with the attention matrix.
59 | 
60 | Simply set two more keyword arguments on initialization of the transformer, and then pass in the specific bond types as a tensor of shape `b x seq x seq`.
61 | 
62 | ```python
63 | import torch
64 | from lie_transformer_pytorch import LieTransformer
65 | 
66 | model = LieTransformer(
67 |     num_tokens = 28,      # say 28 different types of atoms
68 |     num_edge_types = 4,   # number of different edge types
69 |     edge_dim = 16,        # dimension of edges
70 |     dim = 512,
71 |     depth = 2,
72 |     heads = 8,
73 |     dim_head = 64,
74 |     liftsamples = 4
75 | )
76 | 
77 | atoms = torch.randint(0, 28, (1, 64))
78 | bonds = torch.randint(0, 4, (1, 64, 64))
79 | coors = torch.randn(1, 64, 3)
80 | mask = torch.ones(1, 64).bool()
81 | 
82 | out = model(atoms, coors, edges = bonds, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)
83 | ```
84 | 
85 | ## Credit
86 | 
87 | This repository is largely adapted from LieConv, cited below.
88 | 
89 | ## Citations
90 | 
91 | ```bibtex
92 | @misc{hutchinson2020lietransformer,
93 |     title = {LieTransformer: Equivariant self-attention for Lie Groups},
94 |     author = {Michael Hutchinson and Charline Le Lan and Sheheryar Zaidi and Emilien Dupont and Yee Whye Teh and Hyunjik Kim},
95 |     year = {2020},
96 |     eprint = {2012.10885},
97 |     archivePrefix = {arXiv},
98 |     primaryClass = {cs.LG}
99 | }
100 | ```
101 | 
102 | ```bibtex
103 | @misc{finzi2020generalizing,
104 |     title = {Generalizing Convolutional Neural Networks for Equivariance to Lie Groups on Arbitrary Continuous Data},
105 |     author = {Marc Finzi and Samuel Stanton and Pavel Izmailov and Andrew Gordon Wilson},
106 |     year = {2020},
107 |     eprint = {2002.12880},
108 |     archivePrefix = {arXiv},
109 |     primaryClass = {stat.ML}
110 | }
111 | ```
112 | 
--------------------------------------------------------------------------------
/lie_transformer_pytorch/reversible.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd.function import Function
4 | from torch.utils.checkpoint import get_device_states, set_device_states
5 | 
6 | # helpers
7 | 
8 | def sum_tuple(x, y, dim = 1):
9 |     x = list(x)
10 |     x[dim] += y[dim]
11 |     return tuple(x)
12 | 
13 | def subtract_tuple(x, y, dim = 1):
14 |     x = list(x)
15 |     x[dim] -= y[dim]
16 |     return tuple(x)
17 | 
18 | def set_tuple(x, dim, value):
19 |     x = list(x)
20 |     x[dim] = value
21 |     return tuple(x)
22 | 
23 | def map_tuple(fn, x, dim = 1):
24 |     x = list(x)
25 |     x[dim] = fn(x[dim])
26 |     return tuple(x)
27 | 
28 | def chunk_tuple(fn, x, dim = 1):
29 |     x = list(x)
30 |     value = x[dim]
31 |     chunks = fn(value)
32 |     return tuple(map(lambda t: set_tuple(x, dim, t), chunks))
33 | 
34 | def cat_tuple(x, y, dim = 1, cat_dim = -1):
35 |     x = list(x)
36 |     y = list(y)
37 |     x[dim] = torch.cat((x[dim], y[dim]), dim = cat_dim)
38 |     return tuple(x)
39 | 
40 | def del_tuple(x):
41 |     for el in x:
42 |         if el is not None:
43 |             del el
44 | 
45 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
46 | class Deterministic(nn.Module):
47 |     def __init__(self, net):
48 |         super().__init__()
49 |         self.net = net
50 |         self.cpu_state = None
51 |         self.cuda_in_fwd = None
52 |         self.gpu_devices = None
53 |         self.gpu_states = None
54 | 
55 |     def record_rng(self, *args):
56 |         self.cpu_state = torch.get_rng_state()
57 |         if torch.cuda._initialized:
58 |             self.cuda_in_fwd = True
59 |             self.gpu_devices, self.gpu_states = get_device_states(*args)
60 | 
61 |     def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
62 |         if record_rng:
63 |             self.record_rng(*args)
64 | 
65 |         if not set_rng:
66 |             return self.net(*args, **kwargs)
67 | 
68 |         rng_devices = []
69 |         if self.cuda_in_fwd:
70 |             rng_devices = self.gpu_devices
71 | 
72 |         with torch.random.fork_rng(devices=rng_devices, enabled=True):
73 |             torch.set_rng_state(self.cpu_state)
74 |             if self.cuda_in_fwd:
75 |                 set_device_states(self.gpu_devices, self.gpu_states)
76 |             return self.net(*args, **kwargs)
77 | 
78 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
79 | # once multi-GPU is confirmed working, refactor and send PR back to source
80 | class ReversibleBlock(nn.Module):
81 |     def __init__(self, f, g):
82 |         super().__init__()
83 |         self.f = Deterministic(f)
84 |         self.g = Deterministic(g)
85 | 
86 |     def forward(self, x, f_args = {}, g_args = {}):
87 |         training = self.training
88 |         x1, x2 = chunk_tuple(lambda t: torch.chunk(t, 2, dim=2), x)
89 |         y1, y2 = None, None
90 | 
91 |         with torch.no_grad():
92 |             y1 = sum_tuple(self.f(x2, record_rng = training, **f_args), x1)
93 |             y2 = sum_tuple(self.g(y1, record_rng = training, **g_args), x2)
94 | 
95 |         return cat_tuple(y1, y2, cat_dim = 2)
96 | 
97 |     def backward_pass(self, y, dy, f_args = {}, g_args = {}):
98 |         y1, y2 = chunk_tuple(lambda t: torch.chunk(t, 2, dim=2), y)
99 |         del_tuple(y)
100 | 
101 |         dy1, dy2 = torch.chunk(dy, 2, dim=2)
102 |         del dy
103 | 
104 |         with torch.enable_grad():
105 |             y1[1].requires_grad = True
106 |             gy1 = self.g(y1, set_rng=True, **g_args)
107 |             torch.autograd.backward(gy1[1], dy2)
108 | 
109 |         with torch.no_grad():
110 |             x2 = subtract_tuple(y2, gy1)
111 |             del_tuple(y2)
112 |             del gy1
113 | 
114 |             dx1 = dy1 + y1[1].grad
115 |             del dy1
116 |             y1[1].grad = None
117 | 
118 |         with torch.enable_grad():
119 |             x2[1].requires_grad = True
120 |             fx2 = self.f(x2, set_rng = True, **f_args)
121 |             torch.autograd.backward(fx2[1], dx1)
122 | 
123 |         with torch.no_grad():
124 |             x1 = subtract_tuple(y1, fx2)
125 |             del fx2
126 |             del_tuple(y1)
127 | 
128 |             dx2 = dy2 + x2[1].grad
129 |             del dy2
130 |             x2[1].grad = None
131 | 
132 |         x2 = map_tuple(lambda t: t.detach(), x2)
133 |         x = cat_tuple(x1, x2, cat_dim = -1)
134 |         dx = torch.cat((dx1, dx2), dim=2)
135 | 
136 |         return x, dx
137 | 
138 | class _ReversibleFunction(Function):
139 |     @staticmethod
140 |     def forward(ctx, x, blocks, kwargs):
141 |         ctx.kwargs = kwargs
142 |         x = (kwargs.pop('coords'), x, kwargs.pop('mask'), kwargs.pop('edges'))
143 | 
144 |         for block in blocks:
145 |             x = block(x, **kwargs)
146 | 
147 |         ctx.y = map_tuple(lambda t: t.detach(), x, dim = 1)
148 |         ctx.blocks = blocks
149 |         return x[1]
150 | 
151 |     @staticmethod
152 |     def backward(ctx, dy):
153 |         y = ctx.y
154 |         kwargs = ctx.kwargs
155 | 
156 |         for block in ctx.blocks[::-1]:
157 |             y, dy = block.backward_pass(y, dy, **kwargs)
158 |         return dy, None, None
159 | 
160 | class SequentialSequence(nn.Module):
161 |     def __init__(self, blocks):
162 |         super().__init__()
163 |         self.blocks = blocks
164 | 
165 |     def forward(self, x):
166 |         for (f, g) in self.blocks:
167 |             x = sum_tuple(f(x), x, dim = 1)
168 |             x = sum_tuple(g(x), x, dim = 1)
169 |         return x
170 | 
171 | class ReversibleSequence(nn.Module):
172 |     def __init__(self, blocks):
173 |         super().__init__()
174 |         self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])
175 | 
176 |     def forward(self, x, **kwargs):
177 |         x = map_tuple(lambda t: torch.cat((t, t), dim = -1), x)
178 | 
179 |         blocks = self.blocks
180 | 
181 |         coords, values, mask, edges = x
182 |         kwargs = {'coords': coords, 'mask': mask, 'edges': edges, **kwargs}
183 |         x = _ReversibleFunction.apply(values, blocks, kwargs)
184 | 
185 |         x = (coords, x, mask, edges)
186 |         return map_tuple(lambda t: sum(t.chunk(2, dim = -1)) * 0.5, x)
187 | 
--------------------------------------------------------------------------------
/lie_transformer_pytorch/se3.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 | import torch
3 | from functools import wraps
4 | from torch import acos, atan2, cos, sin
5 | from einops import rearrange, repeat
6 | 
7 | # constants
8 | 
9 | THRES = 7e-2
10 | 
11 | # helper functions
12 | 
13 | def exists(val):
14 |     return val is not None
15 | 
16 | def to(t):
17 |     return {'device': t.device, 'dtype': t.dtype}
18 | 
19 | def taylor(thres):
20 |     def outer(fn):
21 |         @wraps(fn)
22 |         def inner(x):
23 |             usetaylor = x.abs() < thres
24 |             taylor_expanded, full = fn(x, x * x)
25 |             return torch.where(usetaylor, taylor_expanded, full)
26 |         return inner
27 |     return outer
28 | 
29 | # Helper functions for analytic exponential maps. Uses taylor expansions near x=0
30 | # See http://ethaneade.com/lie_groups.pdf for derivations.
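# Illustrative note (an added sketch, not original library code): each function
# below returns both a truncated Taylor series and the analytic expression, and
# the `taylor` decorator picks elementwise between them based on |x| < thres,
# so that e.g. sinc evaluates stably at x = 0 where sin(x)/x is 0/0:
#
#   sinc(torch.tensor([0., 1.]))  # -> tensor([1.0000, 0.8415]) (approximately)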
31 | 
32 | @taylor(THRES)
33 | def sinc(x, x2):
34 |     """ sin(x)/x """
35 |     texpand = 1-x2/6*(1-x2/20*(1-x2/42))
36 |     full = sin(x) / x
37 |     return texpand, full
38 | 
39 | @taylor(THRES)
40 | def sincc(x, x2):
41 |     """ (1-sinc(x))/x^2 """
42 |     texpand = 1/6*(1-x2/20*(1-x2/42*(1-x2/72)))
43 |     full = (x-sin(x)) / x**3
44 |     return texpand, full
45 | 
46 | @taylor(THRES)
47 | def cosc(x, x2):
48 |     """ (1-cos(x))/x^2 """
49 |     texpand = 1/2*(1-x2/12*(1-x2/30*(1-x2/56)))
50 |     full = (1-cos(x)) / x2
51 |     return texpand, full
52 | 
53 | @taylor(THRES)
54 | def coscc(x, x2):
55 |     #assert not torch.any(torch.isinf(x2)), f"infs in x2 log"
56 |     texpand = 1/12*(1+x2/60*(1+x2/42*(1+x2/40)))
57 |     costerm = (2*(1-cos(x))).clamp(min=1e-6)
58 |     full = (1-x*sin(x)/costerm) / x2 # nans can come up here when cos = 1
59 |     return texpand, full
60 | 
61 | @taylor(THRES)
62 | def sinc_inv(x, _):
63 |     texpand = 1+(1/6)*x**2 +(7/360)*x**4
64 |     full = x / sin(x)
65 |     assert not torch.any(torch.isinf(texpand)|torch.isnan(texpand)), 'sinc_inv texpand contains inf or nan'
66 |     return texpand, full
67 | 
68 | ## Lie Groups acting on R3
69 | 
70 | # Hodge star on R3
71 | def cross_matrix(k):
72 |     """Application of hodge star on R3, mapping Λ^1 R3 -> Λ^2 R3"""
73 |     K = torch.zeros(*k.shape[:-1], 3, 3, **to(k))
74 |     K[...,0,1] = -k[...,2]
75 |     K[...,0,2] = k[...,1]
76 |     K[...,1,0] = k[...,2]
77 |     K[...,1,2] = -k[...,0]
78 |     K[...,2,0] = -k[...,1]
79 |     K[...,2,1] = k[...,0]
80 |     return K
81 | 
82 | def uncross_matrix(K):
83 |     """Application of hodge star on R3, mapping Λ^2 R3 -> Λ^1 R3"""
84 |     k = torch.zeros(*K.shape[:-1], **to(K))
85 |     k[...,0] = (K[...,2,1] - K[...,1,2])/2
86 |     k[...,1] = (K[...,0,2] - K[...,2,0])/2
87 |     k[...,2] = (K[...,1,0] - K[...,0,1])/2
88 |     return k
89 | 
90 | class SO3:
91 |     lie_dim = 3
92 |     rep_dim = 3
93 |     q_dim = 1
94 | 
95 |     def __init__(self, alpha = .2):
96 |         super().__init__()
97 |         self.alpha = alpha
98 | 
99 |     def exp(self, w):
100 |         """ Computes (matrix) exponential Lie algebra elements (in a given basis).
101 |         ie out = exp(\sum_i a_i A_i) where A_i are the exponential generators of G.
102 |         Input: [a (*,lie_dim)] where * is arbitrarily shaped
103 |         Output: [exp(a) (*,rep_dim,rep_dim)] returns the matrix for each."""
104 | 
105 |         """ Rodrigues' formula, assuming shape (*,3)
106 |         where components 1,2,3 are the generators for xrot,yrot,zrot"""
107 |         theta = w.norm(dim=-1)[..., None, None]
108 |         K = cross_matrix(w)
109 |         I = torch.eye(3, **to(K))
110 |         Rs = I + K * sinc(theta) + (K @ K) * cosc(theta)
111 |         return Rs
112 | 
113 |     def log(self, R):
114 |         """ Computes components in terms of generators rx,ry,rz. Shape (*,3,3)"""
115 | 
116 |         """ Computes (matrix) logarithm for collection of matrices and converts to Lie algebra basis.
117 |         Input [u (*,rep_dim,rep_dim)]
118 |         Output [coeffs of log(u) in basis (*,d)] """
119 |         trR = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
120 |         costheta = ((trR-1) / 2).clamp(max=1, min=-1).unsqueeze(-1)
121 |         theta = acos(costheta)
122 |         logR = uncross_matrix(R) * sinc_inv(theta)
123 |         return logR
124 | 
125 |     def inv(self, g):
126 |         """ We can compute the inverse of elements g (*,rep_dim,rep_dim) as exp(-log(g))"""
127 |         return self.exp(-self.log(g))
128 | 
129 |     def elems2pairs(self, a):
130 |         """ computes log(e^-b e^a) for all a b pairs along n dimension of input.
131 |         inputs: [a (bs,n,d)] outputs: [pairs_ab (bs,n,n,d)] """
132 |         vinv = self.exp(-a.unsqueeze(-3))
133 |         u = self.exp(a.unsqueeze(-2))
134 |         return self.log(vinv@u) # ((bs,1,n,d) -> (bs,1,n,r,r))@((bs,n,1,d) -> (bs,n,1,r,r))
135 | 
136 |     def lift(self, x, nsamples, **kwargs):
137 |         """ assumes p has shape (*,n,3), vals has shape (*,n,c), mask has shape (*,n)
138 |         returns (a,v) with shapes [(*,n*nsamples,lie_dim),(*,n*nsamples,c)] """
139 |         p, v, m, e = x
140 |         expanded_a = self.lifted_elems(p,nsamples,**kwargs) # (bs,n*ns,d), (bs,n*ns,qd)
141 |         nsamples = expanded_a.shape[-2]//m.shape[-1]
142 |         # expand v and mask like q
143 |         expanded_v = repeat(v, 'b n c -> b (n m) c', m = nsamples) # (bs,n,c) -> (bs,n,1,c) -> (bs,n,ns,c) -> (bs,n*ns,c)
144 |         expanded_mask = repeat(m, 'b n -> b (n m)', m = nsamples) # (bs,n) -> (bs,n,ns) -> (bs,n*ns)
145 |         expanded_e = repeat(e, 'b n1 n2 c -> b (n1 m1) (n2 m2) c', m1 = nsamples, m2 = nsamples) if exists(e) else None
146 | 
147 |         # convert from elems to pairs
148 |         paired_a = self.elems2pairs(expanded_a) # (bs,n*ns,d) -> (bs,n*ns,n*ns,d)
149 |         embedded_locations = paired_a
150 |         return (embedded_locations, expanded_v, expanded_mask, expanded_e)
151 | 
152 | class SE3(SO3):
153 |     lie_dim = 6
154 |     rep_dim = 4
155 |     q_dim = 0
156 | 
157 |     def __init__(self, alpha=.2, per_point=True):
158 |         super().__init__()
159 |         self.alpha = alpha
160 |         self.per_point = per_point
161 | 
162 |     def exp(self, w):
163 |         dd_kwargs = to(w)
164 |         theta = w[...,:3].norm(dim=-1)[...,None,None]
165 |         K = cross_matrix(w[...,:3])
166 |         R = super().exp(w[...,:3])
167 |         I = torch.eye(3, **dd_kwargs)
168 |         V = I + cosc(theta)*K + sincc(theta)*(K@K)
169 |         U = torch.zeros(*w.shape[:-1],4,4, **dd_kwargs)
170 |         U[...,:3,:3] = R
171 |         U[...,:3,3] = (V@w[...,3:].unsqueeze(-1)).squeeze(-1)
172 |         U[...,3,3] = 1
173 |         return U
174 | 
175 |     def log(self, U):
176 |         w = super().log(U[..., :3, :3])
177 |         I = torch.eye(3, **to(w))
178 |         K = cross_matrix(w[..., :3])
179 |         theta = w.norm(dim=-1)[..., None, None] #%(2*pi)
180 |         #theta[theta>pi] -= 2*pi
181 |         cosccc = coscc(theta)
182 |         Vinv = I - K/2 + cosccc*(K@K)
183 |         u = (Vinv @ U[..., :3, 3].unsqueeze(-1)).squeeze(-1)
184 |         #assert not torch.any(torch.isnan(u)), f"nans in u log {torch.isnan(u).sum()}, {torch.where(torch.isnan(u))}"
185 |         return torch.cat([w, u], dim=-1)
186 | 
187 |     def lifted_elems(self, pt, nsamples):
188 |         """ pt (bs,n,D) mask (bs,n), per_point specifies whether to
189 |         use a different group element per atom in the molecule """
190 |         #return farthest_lift(self,pt,mask,nsamples,alpha)
191 |         # same lifts for each point right now
192 |         bs, n = pt.shape[:2]
193 |         dd_kwargs = to(pt)
194 | 
195 |         q = torch.randn(bs, (n if self.per_point else 1), nsamples, 4, **dd_kwargs)
196 |         q /= q.norm(dim=-1).unsqueeze(-1)
197 | 
198 |         theta_2 = atan2(q[..., 1:].norm(dim=-1), q[..., 0])[..., None]
199 |         so3_elem = 2 * sinc_inv(theta_2) * q[...,1:] # (sin(x/2)u -> xu) for x angle and u direction
200 |         se3_elem = torch.cat([so3_elem, torch.zeros_like(so3_elem)], dim=-1)
201 |         R = self.exp(se3_elem)
202 | 
203 |         T = torch.zeros(bs, n, nsamples, 4, 4, **dd_kwargs) # (bs,n,nsamples,4,4)
204 |         T[..., :, :] = torch.eye(4, **dd_kwargs)
205 |         T[..., :3, 3] = pt[..., None, :] # (bs,n,1,3)
206 | 
207 |         a = self.log(T @ R) # (bs,n,nsamples,6)
208 |         return a.reshape(bs, n * nsamples, 6)
209 | 
210 |     def distance(self, abq_pairs):
211 |         dist_rot = abq_pairs[...,:3].norm(dim=-1)
212 |         dist_trans = abq_pairs[...,3:].norm(dim=-1)
213 |         return dist_rot * self.alpha + (1 - self.alpha) * dist_trans
214 | 
--------------------------------------------------------------------------------
/lie_transformer_pytorch/lie_transformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn, einsum
7 | 
8 | from lie_transformer_pytorch.se3 import SE3
9 | from einops import rearrange, repeat
10 | 
11 | from lie_transformer_pytorch.reversible import SequentialSequence, ReversibleSequence
12 | 
13 | # helpers
14 | 
15 | def exists(val):
16 |     return val is not None
17 | 
18 | def cast_tuple(val, depth):
19 |     return val if isinstance(val, tuple) else ((val,) * depth)
20 | 
21 | def default(val, d):
22 |     return val if exists(val) else d
23 | 
24 | def batched_index_select(values, indices, dim = 1):
25 |     value_dims = values.shape[(dim + 1):]
26 |     values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
27 |     indices = indices[(..., *((None,) * len(value_dims)))]
28 |     indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
29 |     value_expand_len = len(indices_shape) - (dim + 1)
30 |     values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
31 | 
32 |     value_expand_shape = [-1] * len(values.shape)
33 |     expand_slice = slice(dim, (dim + value_expand_len))
34 |     value_expand_shape[expand_slice] = indices.shape[expand_slice]
35 |     values = values.expand(*value_expand_shape)
36 | 
37 |     dim += value_expand_len
38 |     return values.gather(dim, indices)
39 | 
40 | # helper classes
41 | 
42 | class Pass(nn.Module):
43 |     def __init__(self, fn, dim = 1):
44 |         super().__init__()
45 |         self.fn = fn
46 |         self.dim = dim
47 | 
48 |     def forward(self, x):
49 |         dim = self.dim
50 |         xs = list(x)
51 |         xs[dim] = self.fn(xs[dim])
52 |         return xs
53 | 
54 | class Lambda(nn.Module):
55 |     def __init__(self, fn):
56 |         super().__init__()
57 |         self.fn = fn
58 | 
59 |     def forward(self, x):
60 |         return self.fn(x)
61 | 
62 | 
63 | class GlobalPool(nn.Module):
64 |     """computes values reduced over all spatial locations (& group elements) in the mask"""
65 |     def __init__(self, mean = False):
66 |         super().__init__()
67 |         self.mean = mean
68 | 
69 |     def forward(self, x):
70 |         coords, vals, mask, *_ = x
71 | 
72 |         if not exists(mask):
73 |             return vals.mean(dim = 1)
74 | 
75 |         masked_vals = vals.masked_fill(~mask[..., None], 0.)
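        # sum over all lifted locations; for mean pooling, divide by the number
        # of valid (unmasked) positions below rather than the raw sequence length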
76 |         summed = masked_vals.sum(dim = 1)
77 | 
78 |         if not self.mean:
79 |             return summed
80 | 
81 |         count = mask.sum(-1).unsqueeze(-1)
82 |         return summed / count
83 | 
84 | # subsampling code
85 | 
86 | def FPSindices(dists, frac, mask):
87 |     """ inputs: pairwise distances DISTS (bs,n,n), downsample_frac (float), valid atom mask (bs,n)
88 |     outputs: chosen_indices (bs,m) """
89 |     m = int(round(frac * dists.shape[1]))
90 |     bs, n, device = *dists.shape[:2], dists.device
91 |     dd_kwargs = {'device': device, 'dtype': torch.long}
92 |     B = torch.arange(bs, **dd_kwargs)
93 | 
94 |     chosen_indices = torch.zeros(bs, m, **dd_kwargs)
95 |     distances = torch.ones(bs, n, device=device) * 1e8
96 |     a = torch.randint(0, n, (bs,), **dd_kwargs) # choose random start
97 |     idx = a % mask.sum(-1) + torch.cat([torch.zeros(1, **dd_kwargs), torch.cumsum(mask.sum(-1), dim=0)[:-1]], dim=0)
98 |     farthest = torch.where(mask)[1][idx]
99 | 
100 |     for i in range(m):
101 |         chosen_indices[:, i] = farthest                    # add point that is farthest to chosen
102 |         dist = dists[B, farthest].masked_fill(~mask, -100) # (bs,n) compute distance from new point to all others
103 |         closer = dist < distances                          # if dist from new point is smaller than chosen points so far
104 |         distances[closer] = dist[closer]                   # update the chosen set's distance to all other points
105 |         farthest = torch.max(distances, -1)[1]             # select the point that is farthest from the set
106 | 
107 |     return chosen_indices
108 | 
109 | 
110 | class FPSsubsample(nn.Module):
111 |     def __init__(self, ds_frac, cache = False, group = None):
112 |         super().__init__()
113 |         self.ds_frac = ds_frac
114 |         self.cache = cache
115 |         self.cached_indices = None
116 |         self.group = default(group, SE3())
117 | 
118 |     def get_query_indices(self, abq_pairs, mask):
119 |         if self.cache and exists(self.cached_indices):
120 |             return self.cached_indices
121 | 
122 |         dist = self.group.distance if self.group else lambda ab: ab.norm(dim=-1)
123 |         value = FPSindices(dist(abq_pairs), self.ds_frac, mask).detach()
124 | 
125 |         if self.cache:
126 |             self.cached_indices = value
127 | 
128 |         return value
129 | 
130 |     def forward(self, inp, withquery=False):
131 |         abq_pairs, vals, mask, edges = inp
132 |         device = vals.device
133 | 
134 |         if self.ds_frac != 1:
135 |             query_idx = self.get_query_indices(abq_pairs, mask)
136 | 
137 |             B = torch.arange(query_idx.shape[0], device = device).long()[:,None]
138 |             subsampled_abq_pairs = abq_pairs[B, query_idx][B, :, query_idx]
139 |             subsampled_values = batched_index_select(vals, query_idx, dim = 1)
140 |             subsampled_mask = batched_index_select(mask, query_idx, dim = 1)
141 |             subsampled_edges = edges[B, query_idx][B, :, query_idx] if exists(edges) else None
142 |         else:
143 |             subsampled_abq_pairs = abq_pairs
144 |             subsampled_values = vals
145 |             subsampled_mask = mask
146 |             subsampled_edges = edges
147 |             query_idx = None
148 | 
149 |         ret = (
150 |             subsampled_abq_pairs,
151 |             subsampled_values,
152 |             subsampled_mask,
153 |             subsampled_edges
154 |         )
155 | 
156 |         if withquery:
157 |             ret = (*ret, query_idx)
158 | 
159 |         return ret
160 | 
161 | # lie attention
162 | 
163 | class LieSelfAttention(nn.Module):
164 |     def __init__(
165 |         self,
166 |         dim,
167 |         edge_dim = None,
168 |         group = None,
169 |         mc_samples = 32,
170 |         ds_frac = 1,
171 |         fill = 1 / 3,
172 |         dim_head = 64,
173 |         heads = 8,
174 |         cache = False
175 |     ):
176 |         super().__init__()
177 |         self.dim = dim
178 | 
179 |         self.mc_samples = mc_samples # number of samples to use to estimate convolution
180 |         self.group = default(group, SE3()) # Equivariance group for LieConv
181 |         self.register_buffer('r', torch.tensor(2.)) # Internal variable for local_neighborhood radius, set by fill
182 |         self.fill_frac = min(fill, 1.) # Average Fraction of the input which enters into local_neighborhood, determines r
183 | 
184 |         self.subsample = FPSsubsample(ds_frac, cache = cache, group = self.group)
185 |         self.coeff = .5 # Internal coefficient used for updating r
186 | 
187 |         self.fill_frac_ema = fill # Keeps track of average fill frac, used for logging only
188 | 
189 |         # attention related parameters
190 | 
191 |         inner_dim = dim_head * heads
192 |         self.heads = heads
193 | 
194 |         self.to_q = nn.Linear(dim, inner_dim, bias = False)
195 |         self.to_k = nn.Linear(dim, inner_dim, bias = False)
196 |         self.to_v = nn.Linear(dim, inner_dim, bias = False)
197 |         self.to_out = nn.Linear(inner_dim, dim)
198 | 
199 |         edge_dim = default(edge_dim, 0)
200 |         edge_dim_in = self.group.lie_dim + edge_dim
201 | 
202 |         self.loc_attn_mlp = nn.Sequential(
203 |             nn.Linear(edge_dim_in, edge_dim_in * 4),
204 |             nn.ReLU(),
205 |             nn.Linear(edge_dim_in * 4, 1),
206 |         )
207 | 
208 |     def extract_neighborhood(self, inp, query_indices):
209 |         """ inputs: [pairs_abq (bs,n,n,d), inp_vals (bs,n,c), mask (bs,n), query_indices (bs,m)]
210 |         outputs: [neighbor_abq (bs,m,mc_samples,d), neighbor_vals (bs,m,mc_samples,c)] """
211 | 
212 |         # Subsample pairs_ab, inp_vals, mask to the query_indices
213 |         pairs_abq, inp_vals, mask, edges = inp
214 |         device = inp_vals.device
215 | 
216 |         if exists(query_indices):
217 |             abq_at_query = batched_index_select(pairs_abq, query_indices, dim = 1)
218 |             mask_at_query = batched_index_select(mask, query_indices, dim = 1)
219 |             edges_at_query = batched_index_select(edges, query_indices, dim = 1) if exists(edges) else None
220 |         else:
221 |             abq_at_query = pairs_abq
222 |             mask_at_query = mask
223 |             edges_at_query = edges
224 | 
225 |         mask_at_query = mask_at_query[..., None]
226 | 
227 |         vals_at_query = inp_vals
228 |         dists = self.group.distance(abq_at_query) # (bs,m,n,d) -> (bs,m,n)
229 |         mask_value = torch.finfo(dists.dtype).max
230 |         dists = dists.masked_fill(~mask[:,None,:], mask_value)
231 | 
232 |         k = min(self.mc_samples, inp_vals.shape[1])
233 | 
234 |         # NBHD: Sampled Distance Ball
235 |         bs, m, n = dists.shape
236 |         within_ball = (dists < self.r) & mask[:,None,:] & mask_at_query # (bs,m,n)
237 |         noise = torch.zeros((bs, m, n), device = device).uniform_(0, 1)
238 |         valid_within_ball, nbhd_idx = torch.topk(within_ball + noise, k, dim=-1, sorted=False)
239 |         valid_within_ball = (valid_within_ball > 1)
240 | 
241 |         # Retrieve abq_pairs, values, and mask at the nbhd locations
242 | 
243 |         nbhd_abq = batched_index_select(abq_at_query, nbhd_idx, dim = 2)
244 |         nbhd_vals = batched_index_select(vals_at_query, nbhd_idx, dim = 1)
245 |         nbhd_mask = batched_index_select(mask, nbhd_idx, dim = 1)
246 |         nbhd_edges = batched_index_select(edges_at_query, nbhd_idx, dim = 2) if exists(edges) else None
247 | 
248 |         if self.training: # update ball radius to match fraction fill_frac inside
249 |             navg = (within_ball.float()).sum(-1).sum() / mask_at_query.sum()
250 |             avg_fill = (navg / mask.sum(-1).float().mean()).cpu().item()
251 |             self.r += self.coeff * (self.fill_frac - avg_fill)
252 |             self.fill_frac_ema += .1 * (avg_fill - self.fill_frac_ema)
253 | 
254 |         nbhd_mask &= valid_within_ball.bool()
255 | 
256 |         return nbhd_abq, nbhd_vals, nbhd_mask, nbhd_edges, nbhd_idx
257 | 
258 |     def forward(self, inp):
259 |         """ inputs: [pairs_abq (bs,n,n,d)], [inp_vals (bs,n,ci)], [query_indices (bs,m)]
260 |         outputs: [subsampled_abq (bs,m,m,d)], [convolved_vals (bs,m,co)] """
261 |         sub_abq, sub_vals, sub_mask, sub_edges, query_indices = self.subsample(inp, withquery = True)
262 |         nbhd_abq, nbhd_vals, nbhd_mask, nbhd_edges, nbhd_indices = self.extract_neighborhood(inp, query_indices)
263 | 
264 |         h, b, n, d, device = self.heads, *sub_vals.shape, sub_vals.device
265 | 
266 |         q, k, v = self.to_q(sub_vals), self.to_k(nbhd_vals), self.to_v(nbhd_vals)
267 | 
268 |         q = rearrange(q, 'b n (h d) -> b h n d', h = h)
269 |         k, v = map(lambda t: rearrange(t, 'b n m (h d) -> b h n m d', h = h), (k, v))
270 | 
271 |         sim = einsum('b h i d, b h i j d -> b h i j', q, k) * (q.shape[-1] ** -0.5)
272 | 
273 |         edges = nbhd_abq
274 |         if exists(nbhd_edges):
275 |             edges = torch.cat((nbhd_abq, nbhd_edges), dim = -1)
276 | 
277 |         loc_attn = self.loc_attn_mlp(edges)
278 |         loc_attn = rearrange(loc_attn, 'b i j () -> b () i j')
279 |         sim = sim + loc_attn
280 | 
281 |         mask_value = -torch.finfo(sim.dtype).max
282 | 
283 |         sim.masked_fill_(~rearrange(nbhd_mask, 'b n m -> b () n m'), mask_value)
284 | 
285 |         attn = sim.softmax(dim = -1)
286 |         out = einsum('b h i j, b h i j d -> b h i d', attn, v)
287 |         out = rearrange(out, 'b h n d -> b n (h d)', h = h)
288 |         combined = self.to_out(out)
289 | 
290 |         return sub_abq, combined, sub_mask, sub_edges
291 | 
292 | class LieSelfAttentionWrapper(nn.Module):
293 |     def __init__(self, dim, attn):
294 |         super().__init__()
295 |         self.dim = dim
296 |         self.attn = attn
297 | 
298 |         self.net = nn.Sequential(
299 |             Pass(nn.LayerNorm(dim)),
300 |             self.attn
301 |         )
302 | 
303 |     def forward(self, inp):
304 |         sub_coords, sub_values, mask, edges = self.attn.subsample(inp)
305 |         new_coords, new_values, mask, edges = self.net(inp)
306 |         new_values[..., :self.dim] += sub_values
307 |         return new_coords, new_values, mask, edges
308 | 
309 | class FeedForward(nn.Module):
310 |     def __init__(self, dim, mult = 4):
311 |         super().__init__()
312 |         self.dim = dim
313 | 
314 |         self.net = nn.Sequential(
315 |             Pass(nn.LayerNorm(dim)),
316 |             Pass(nn.Linear(dim, mult * dim)),
317 |             Pass(nn.GELU()),
318 |             Pass(nn.Linear(mult * dim, dim)),
319 |         )
320 | 
321 |     def forward(self, inp):
322 |         sub_coords, sub_values, mask, edges = inp
323 |         new_coords, new_values, mask, edges = self.net(inp)
324 |         new_values = new_values + sub_values
325 |         return new_coords, new_values, mask, edges
326 | 
327 | # transformer class
328 | 
329 | class LieTransformer(nn.Module):
330 |     """
331 |     [fill] specifies the fraction of the input which is included in local neighborhood.
332 |     (can be array to specify a different value for each layer)
333 |     [nbhd] number of samples to use for Monte Carlo estimation (p)
334 |     [dim] number of input channels: 1 for MNIST, 3 for RGB images, other for non images
335 |     [ds_frac] total downsampling to perform throughout the layers of the net. In (0,1)
336 |     [depth] number of attention / feedforward layer pairs in the network
337 |     [k] channel width for the network. Can be int (same for all) or array to specify individually.
338 |     [liftsamples] number of samples to use in lifting. 1 for all groups with trivial stabilizer. Otherwise 2+
339 |     [Group] Chosen group to be equivariant to.
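    [cache] whether to cache the farthest point sampling indices used for downsampling across forward passes
    [reversible] whether to wrap the layers in reversible residual blocks (see reversible.py), trading compute for memory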
340 |     """
341 |     def __init__(
342 |         self,
343 |         dim,
344 |         num_tokens = None,
345 |         num_edge_types = None,
346 |         edge_dim = None,
347 |         heads = 8,
348 |         dim_head = 64,
349 |         depth = 2,
350 |         ds_frac = 1.,
351 |         dim_out = None,
352 |         k = 1536,
353 |         nbhd = 128,
354 |         mean = True,
355 |         per_point = True,
356 |         liftsamples = 4,
357 |         fill = 1 / 4,
358 |         cache = False,
359 |         reversible = False,
360 |         **kwargs
361 |     ):
362 |         super().__init__()
363 |         assert not (ds_frac < 1 and reversible), 'must not downsample if network is reversible'
364 | 
365 |         dim_out = default(dim_out, dim)
366 |         self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
367 |         self.edge_emb = nn.Embedding(num_edge_types, edge_dim) if exists(num_edge_types) else None
368 | 
369 |         group = SE3()
370 |         self.group = group
371 |         self.liftsamples = liftsamples
372 | 
373 |         layers_fill = cast_tuple(fill, depth)
374 |         layers = nn.ModuleList([])
375 | 
376 |         for _, layer_fill in zip(range(depth), layers_fill):
377 |             layers.append(nn.ModuleList([
378 |                 LieSelfAttentionWrapper(dim, LieSelfAttention(dim, heads = heads, dim_head = dim_head, edge_dim = edge_dim, mc_samples = nbhd, ds_frac = ds_frac, group = group, fill = layer_fill, cache = cache, **kwargs)),
379 |                 FeedForward(dim)
380 |             ]))
381 | 
382 |         execute_class = ReversibleSequence if reversible else SequentialSequence
383 |         self.net = execute_class(layers)
384 | 
385 |         self.to_logits = nn.Sequential(
386 |             Pass(nn.LayerNorm(dim)),
387 |             Pass(nn.Linear(dim, dim_out))
388 |         )
389 | 
390 |         self.pool = GlobalPool(mean = mean)
391 | 
392 |     def forward(self, feats, coors, edges = None, mask = None, return_pooled = False):
393 |         b, n, *_ = feats.shape
394 | 
395 |         if exists(self.token_emb):
396 |             feats = self.token_emb(feats)
397 | 
398 |         if exists(self.edge_emb):
399 |             assert exists(edges), 'edges must be passed in on forward'
400 |             assert edges.shape[1] == edges.shape[2] and edges.shape[1] == n, f'edges must be of the shape ({b}, {n}, {n})'
401 |             edges = self.edge_emb(edges)
402 | 
403 |         inps = (coors, feats, mask, edges)
404 | 
405 |         lifted_x = self.group.lift(inps, self.liftsamples)
406 |         out = self.net(lifted_x)
407 | 
408 |         out = self.to_logits(out)
409 | 
410 |         if not return_pooled:
411 |             features = out[1]
412 |             return features
413 | 
414 |         return self.pool(out)
415 | 
--------------------------------------------------------------------------------