├── .github └── workflows │ ├── python-publish.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── frame-averaging.png ├── frame_averaging_pytorch ├── __init__.py └── frame_averaging.py ├── pyproject.toml └── tests └── test_frame_average.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Pytest 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | build: 6 | 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - uses: actions/checkout@v4 11 | - name: Set up Python 3.10 12 | uses: actions/setup-python@v5 13 | with: 14 | python-version: "3.10" 15 | - name: Install dependencies 16 | run: | 17 | python -m 
pip install --upgrade pip 18 | python -m pip install -e .[test] 19 | - name: Test with pytest 20 | run: | 21 | python -m pytest tests/ 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to 
include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Frame Averaging - Pytorch 4 | 5 | PyTorch implementation of a simple way to enable (Stochastic) Frame Averaging for any network. This technique was recently adopted by Prescient Design in AbDiffuser. 6 | 7 | ## Install 8 | 9 | ```bash 10 | $ pip install frame-averaging-pytorch 11 | ``` 12 | 13 | ## Usage 14 | 15 | ```python 16 | import torch 17 | from frame_averaging_pytorch import FrameAverage 18 | 19 | # contrived neural network 20 | 21 | net = torch.nn.Linear(3, 3) 22 | 23 | # wrap the network with FrameAverage 24 | 25 | net = FrameAverage( 26 | net, 27 | dim = 3, # defaults to 3 for spatial, but can be any integer greater than 1 28 | stochastic = True # whether to use stochastic variant from FAENet (one frame sampled at random) 29 | ) 30 | 31 | # pass your input to the network as usual 32 | 33 | points = torch.randn(4, 1024, 3) 34 | mask = torch.ones(4, 1024).bool() 35 | 36 | out = net(points, frame_average_mask = mask) 37 | 38 | out.shape # (4, 1024, 3) 39 | 40 | # frame averaging is automatically taken care of, as though the network were unwrapped 41 | ``` 42 | 43 | Or you can carry out the frame averaging manually: 44 | 45 | ```python 46 | import torch 47 | from frame_averaging_pytorch import FrameAverage 48 | 49 | # contrived neural network 50 | 51 | net = torch.nn.Linear(3, 3) 52 | 53 | # frame average 
module without passing in network 54 | 55 | fa = FrameAverage() 56 | 57 | # pass the 3d points and mask to FrameAverage forward 58 | 59 | points = torch.randn(4, 1024, 3) 60 | mask = torch.ones(4, 1024).bool() 61 | 62 | framed_inputs, frame_average_fn = fa(points, frame_average_mask = mask) 63 | 64 | # network forward 65 | 66 | net_out = net(framed_inputs) 67 | 68 | # frame average 69 | 70 | frame_averaged = frame_average_fn(net_out) 71 | 72 | frame_averaged.shape # (4, 1024, 3) 73 | ``` 74 | 75 | ## Citations 76 | 77 | ```bibtex 78 | @article{Puny2021FrameAF, 79 | title = {Frame Averaging for Invariant and Equivariant Network Design}, 80 | author = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman}, 81 | journal = {ArXiv}, 82 | year = {2021}, 83 | volume = {abs/2110.03336}, 84 | url = {https://api.semanticscholar.org/CorpusID:238419638} 85 | } 86 | ``` 87 | 88 | ```bibtex 89 | @article{Duval2023FAENetFA, 90 | title = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling}, 91 | author = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. 
Malliaros and Yoshua Bengio and David Rolnick}, 92 | journal = {ArXiv}, 93 | year = {2023}, 94 | volume = {abs/2305.05577}, 95 | url = {https://api.semanticscholar.org/CorpusID:258564608} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /frame-averaging.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/frame-averaging-pytorch/746032ba57b44038ac3fbfaa1632af86e69b1050/frame-averaging.png -------------------------------------------------------------------------------- /frame_averaging_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from frame_averaging_pytorch.frame_averaging import ( 2 | FrameAverage 3 | ) 4 | 5 | __all__ = [ 6 | FrameAverage 7 | ] 8 | -------------------------------------------------------------------------------- /frame_averaging_pytorch/frame_averaging.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from random import randrange 4 | 5 | import torch 6 | from torch.nn import Module 7 | from torch.utils._pytree import tree_map 8 | 9 | from einops import rearrange, repeat, reduce, einsum 10 | 11 | # helper functions 12 | 13 | def exists(v): 14 | return v is not None 15 | 16 | def default(v, d): 17 | return v if exists(v) else d 18 | 19 | # main class 20 | 21 | class FrameAverage(Module): 22 | def __init__( 23 | self, 24 | net: Module | None = None, 25 | dim = 3, 26 | stochastic = False, 27 | invariant_output = False, 28 | return_stochastic_as_augmented_pos = False # will simply return points as augmented points of same shape on forward 29 | ): 30 | super().__init__() 31 | self.net = net 32 | 33 | assert dim > 1 34 | 35 | self.dim = dim 36 | self.num_frames = 2 ** dim 37 | 38 | # frames are all permutations of the positive (+1) and negative (-1) eigenvectors for each dimension, 
iiuc 39 | # so there will be 2 ^ dim frames 40 | 41 | directions = torch.tensor([-1, 1]) 42 | 43 | colon = slice(None) 44 | accum = [] 45 | 46 | for ind in range(dim): 47 | dim_slice = [None] * dim 48 | dim_slice[ind] = colon 49 | 50 | accum.append(directions[dim_slice]) 51 | 52 | accum = torch.broadcast_tensors(*accum) 53 | operations = torch.stack(accum, dim = -1) 54 | operations = rearrange(operations, '... d -> (...) d') 55 | 56 | assert operations.shape == (self.num_frames, dim) 57 | 58 | self.register_buffer('operations', operations) 59 | 60 | # whether to use stochastic frame averaging 61 | # proposed in https://arxiv.org/abs/2305.05577 62 | # one frame is selected at random 63 | 64 | self.stochastic = stochastic 65 | self.return_stochastic_as_augmented_pos = return_stochastic_as_augmented_pos 66 | 67 | # invariant output setting 68 | 69 | self.invariant_output = invariant_output 70 | 71 | def forward( 72 | self, 73 | points, 74 | *args, 75 | frame_average_mask = None, 76 | return_framed_inputs_and_averaging_function = False, 77 | **kwargs, 78 | ): 79 | """ 80 | b - batch 81 | n - sequence 82 | d - dimension (input or source) 83 | e - dimension (target) 84 | f - frames 85 | """ 86 | 87 | assert points.shape[-1] == self.dim, f'expected points of dimension {self.dim}, but received {points.shape[-1]}' 88 | 89 | # account for variable lengthed points 90 | 91 | if exists(frame_average_mask): 92 | frame_average_mask = rearrange(frame_average_mask, '... -> ... 
1') 93 | points = points * frame_average_mask 94 | 95 | # shape must end with (batch, seq, dim) 96 | 97 | batch, seq_dim, input_dim = points.shape 98 | 99 | # frame averaging logic 100 | 101 | if exists(frame_average_mask): 102 | num = reduce(points, 'b n d -> b 1 d', 'sum') 103 | den = reduce(frame_average_mask.float(), 'b n 1 -> b 1 1', 'sum') 104 | centroid = num / den.clamp(min = 1) 105 | else: 106 | centroid = reduce(points, 'b n d -> b 1 d', 'mean') 107 | 108 | centered_points = points - centroid 109 | 110 | if exists(frame_average_mask): 111 | centered_points = centered_points * frame_average_mask 112 | 113 | covariance = einsum(centered_points, centered_points, 'b n d, b n e -> b d e') 114 | 115 | _, eigenvectors = torch.linalg.eigh(covariance) 116 | 117 | # if stochastic, just select one random operation 118 | 119 | num_frames = self.num_frames 120 | operations = self.operations 121 | 122 | if self.stochastic: 123 | rand_frame_index = randrange(self.num_frames) 124 | 125 | operations = operations[rand_frame_index:(rand_frame_index + 1)] 126 | num_frames = 1 127 | 128 | # frames 129 | 130 | frames = rearrange(eigenvectors, 'b d e -> b 1 d e') * rearrange(operations, 'f e -> f 1 e') 131 | 132 | # inverse frame op 133 | 134 | inputs = einsum(frames, centered_points, 'b f d e, b n d -> b f n e') 135 | 136 | # define the frame averaging function 137 | 138 | def frame_average(out): 139 | if not self.invariant_output: 140 | # apply frames 141 | 142 | out = einsum(frames, out, 'b f d e, b f ... e -> b f ... d') 143 | 144 | if not self.stochastic: 145 | # averaging across frames, thus "frame averaging" 146 | 147 | out = reduce(out, 'b f ... -> b ...', 'mean') 148 | else: 149 | out = rearrange(out, 'b 1 ... 
-> b ...') 150 | 151 | return out 152 | 153 | # if one wants to handle the framed inputs externally 154 | 155 | if return_framed_inputs_and_averaging_function or not exists(self.net): 156 | 157 | if self.stochastic and self.return_stochastic_as_augmented_pos: 158 | return rearrange(inputs, 'b 1 ... -> b ...') 159 | 160 | return inputs, frame_average 161 | 162 | # merge frames into batch 163 | 164 | inputs = rearrange(inputs, 'b f ... -> (b f) ...') 165 | 166 | # if batch is expanded by number of frames, any tensor being passed in for args and kwargs needed to be expanded as well 167 | # automatically take care of this 168 | 169 | if not self.stochastic: 170 | args, kwargs = tree_map( 171 | lambda el: ( 172 | repeat(el, 'b ... -> (b f) ...', f = num_frames) 173 | if torch.is_tensor(el) 174 | else el 175 | ) 176 | , (args, kwargs)) 177 | 178 | # main network forward 179 | 180 | out = self.net(inputs, *args, **kwargs) 181 | 182 | # use tree map to handle multiple outputs 183 | 184 | out = tree_map(lambda t: rearrange(t, '(b f) ... 
-> b f ...', f = num_frames) if torch.is_tensor(t) else t, out) 185 | out = tree_map(lambda t: frame_average(t) if torch.is_tensor(t) else t, out) 186 | 187 | return out 188 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "frame-averaging-pytorch" 3 | version = "0.1.2" 4 | description = "Frame Averaging" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.9" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'geometric learning', 15 | ] 16 | 17 | classifiers=[ 18 | 'Development Status :: 4 - Beta', 19 | 'Intended Audience :: Developers', 20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python :: 3.9', 23 | ] 24 | 25 | dependencies = [ 26 | "torch>=2.0", 27 | "einops>=0.8.0", 28 | ] 29 | 30 | [project.urls] 31 | Homepage = "https://pypi.org/project/frame-averaging-pytorch/" 32 | Repository = "https://github.com/lucidrains/frame-averaging-pytorch" 33 | 34 | [project.optional-dependencies] 35 | examples = [] 36 | test = [ 37 | "pytest" 38 | ] 39 | 40 | [tool.pytest.ini_options] 41 | pythonpath = [ 42 | "." 
43 | ] 44 | 45 | [build-system] 46 | requires = ["hatchling"] 47 | build-backend = "hatchling.build" 48 | 49 | [tool.rye] 50 | managed = true 51 | dev-dependencies = [] 52 | 53 | [tool.hatch.metadata] 54 | allow-direct-references = true 55 | 56 | [tool.hatch.build.targets.wheel] 57 | packages = ["frame_averaging_pytorch"] 58 | -------------------------------------------------------------------------------- /tests/test_frame_average.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import Module 6 | from frame_averaging_pytorch import FrameAverage 7 | 8 | @pytest.mark.parametrize('stochastic', (True, False)) 9 | @pytest.mark.parametrize('dim', (2, 3, 4)) 10 | @pytest.mark.parametrize('has_mask', (True, False)) 11 | def test_frame_average( 12 | stochastic: bool, 13 | dim: int, 14 | has_mask: bool 15 | ): 16 | 17 | net = torch.nn.Linear(dim, dim) 18 | 19 | net = FrameAverage( 20 | net, 21 | dim = dim, 22 | stochastic = stochastic 23 | ) 24 | 25 | points = torch.randn(4, 1024, dim) 26 | 27 | mask = None 28 | if has_mask: 29 | mask = torch.ones(4, 1024).bool() 30 | 31 | out = net(points, frame_average_mask = mask) 32 | assert out.shape == points.shape 33 | 34 | def test_frame_average_manual(): 35 | 36 | net = torch.nn.Linear(3, 3) 37 | 38 | fa = FrameAverage() 39 | points = torch.randn(4, 1024, 3) 40 | 41 | framed_inputs, frame_average_fn = fa(points) 42 | 43 | net_out = net(framed_inputs) 44 | 45 | frame_averaged = frame_average_fn(net_out) 46 | 47 | assert frame_averaged.shape == points.shape 48 | 49 | def test_frame_average_multiple_inputs_and_outputs(): 50 | 51 | class Network(Module): 52 | def __init__(self): 53 | super().__init__() 54 | self.net = nn.Linear(3, 3) 55 | self.to_out1 = nn.Linear(3, 3) 56 | self.to_out2 = nn.Linear(3, 3) 57 | 58 | def forward(self, x, mask): 59 | x = x.masked_fill(~mask[..., None], 0.) 
60 | hidden = self.net(x) 61 | return 0., self.to_out1(hidden), self.to_out2(hidden) 62 | 63 | net = Network() 64 | net = FrameAverage(net) 65 | 66 | points = torch.randn(4, 1024, 3) 67 | mask = torch.ones(4, 1024).bool() 68 | 69 | _, out1, out2 = net(points, mask, frame_average_mask = mask) 70 | 71 | assert out1.shape == out2.shape == points.shape 72 | --------------------------------------------------------------------------------