├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── frame-averaging.png
├── frame_averaging_pytorch
│   ├── __init__.py
│   └── frame_averaging.py
├── pyproject.toml
└── tests
    └── test_frame_average.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Pytest
2 | on: [push, pull_request]
3 |
4 | jobs:
5 | build:
6 |
7 | runs-on: ubuntu-latest
8 |
9 | steps:
10 | - uses: actions/checkout@v4
11 | - name: Set up Python 3.10
12 | uses: actions/setup-python@v5
13 | with:
14 | python-version: "3.10"
15 | - name: Install dependencies
16 | run: |
17 | python -m pip install --upgrade pip
18 | python -m pip install -e .[test]
19 | - name: Test with pytest
20 | run: |
21 | python -m pytest tests/
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Frame Averaging - Pytorch
4 |
5 | PyTorch implementation of a simple way to enable (Stochastic) Frame Averaging for any network. This technique was recently adopted by Prescient Design in AbDiffuser.
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install frame-averaging-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | ```python
16 | import torch
17 | from frame_averaging_pytorch import FrameAverage
18 |
19 | # contrived neural network
20 |
21 | net = torch.nn.Linear(3, 3)
22 |
23 | # wrap the network with FrameAverage
24 |
25 | net = FrameAverage(
26 | net,
27 | dim = 3, # defaults to 3 for spatial, but can be any value
28 | stochastic = True # whether to use stochastic variant from FAENet (one frame sampled at random)
29 | )
30 |
31 | # pass your input to the network as usual
32 |
33 | points = torch.randn(4, 1024, 3)
34 | mask = torch.ones(4, 1024).bool()
35 |
36 | out = net(points, frame_average_mask = mask)
37 |
38 | out.shape # (4, 1024, 3)
39 |
40 | # frame averaging is automatically taken care of, as though the network were unwrapped
41 | ```
42 |
43 | or you can also carry it out manually
44 |
45 | ```python
46 | import torch
47 | from frame_averaging_pytorch import FrameAverage
48 |
49 | # contrived neural network
50 |
51 | net = torch.nn.Linear(3, 3)
52 |
53 | # frame average module without passing in network
54 |
55 | fa = FrameAverage()
56 |
57 | # pass the 3d points and mask to FrameAverage forward
58 |
59 | points = torch.randn(4, 1024, 3)
60 | mask = torch.ones(4, 1024).bool()
61 |
62 | framed_inputs, frame_average_fn = fa(points, frame_average_mask = mask)
63 |
64 | # network forward
65 |
66 | net_out = net(framed_inputs)
67 |
68 | # frame average
69 |
70 | frame_averaged = frame_average_fn(net_out)
71 |
72 | frame_averaged.shape # (4, 1024, 3)
73 | ```
74 |
75 | ## Citations
76 |
77 | ```bibtex
78 | @article{Puny2021FrameAF,
79 | title = {Frame Averaging for Invariant and Equivariant Network Design},
80 | author = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman},
81 | journal = {ArXiv},
82 | year = {2021},
83 | volume = {abs/2110.03336},
84 | url = {https://api.semanticscholar.org/CorpusID:238419638}
85 | }
86 | ```
87 |
88 | ```bibtex
89 | @article{Duval2023FAENetFA,
90 | title = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling},
91 | author = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick},
92 | journal = {ArXiv},
93 | year = {2023},
94 | volume = {abs/2305.05577},
95 | url = {https://api.semanticscholar.org/CorpusID:258564608}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
/frame-averaging.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/frame-averaging-pytorch/746032ba57b44038ac3fbfaa1632af86e69b1050/frame-averaging.png
--------------------------------------------------------------------------------
/frame_averaging_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from frame_averaging_pytorch.frame_averaging import (
2 | FrameAverage
3 | )
4 |
# public names exported by a star-import of this package
# entries must be strings - a class object here makes the star-import raise TypeError
__all__ = [
    'FrameAverage'
]
8 |
--------------------------------------------------------------------------------
/frame_averaging_pytorch/frame_averaging.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from random import randrange
4 |
5 | import torch
6 | from torch.nn import Module
7 | from torch.utils._pytree import tree_map
8 |
9 | from einops import rearrange, repeat, reduce, einsum
10 |
11 | # helper functions
12 |
def exists(v):
    """Tell whether ``v`` holds a value, i.e. is anything other than ``None``."""
    return not (v is None)
15 |
def default(v, d):
    """Return ``v`` when it is set (not ``None``), otherwise the fallback ``d``."""
    if exists(v):
        return v

    return d
18 |
19 | # main class
20 |
class FrameAverage(Module):
    """
    Frame averaging wrapper.

    The input points are centered on their (optionally masked) centroid and projected
    onto frames built from the eigenvectors of their covariance matrix. Each frame is the
    eigenvector matrix with its columns multiplied by one of the 2 ^ dim combinations of
    signs (-1, +1).

    If a network is given, it is run on every framed input (frames folded into the batch)
    and every tensor it returns is averaged across frames. If no network is given, the
    framed inputs and the averaging function are returned so the caller can do this manually.

    Args:
        net: optional network to wrap. When None, forward returns the framed inputs
            together with the frame averaging function.
        dim: dimension of the points (size of their last axis), must be greater than 1.
        stochastic: use a single frame selected at random on each forward
            (https://arxiv.org/abs/2305.05577) instead of all 2 ^ dim frames.
        invariant_output: when True, outputs are not mapped back through the frames
            before being averaged.
        return_stochastic_as_augmented_pos: when stochastic and the framed inputs are
            returned to the caller, return only the framed points, without the frame axis
            and without the averaging function.
    """

    def __init__(
        self,
        net: Module | None = None,
        dim = 3,
        stochastic = False,
        invariant_output = False,
        return_stochastic_as_augmented_pos = False # will simply return points as augmented points of same shape on forward
    ):
        super().__init__()
        self.net = net

        assert dim > 1, f'dim must be greater than 1, but received {dim}'

        self.dim = dim
        self.num_frames = 2 ** dim

        # frames are all combinations of the positive (+1) and negative (-1) eigenvectors for each dimension
        # so there will be 2 ^ dim frames

        directions = torch.tensor([-1, 1])

        colon = slice(None)
        accum = []

        for ind in range(dim):
            dim_slice = [None] * dim
            dim_slice[ind] = colon

            # index with a tuple - indexing a tensor with a list of None / slice is deprecated in pytorch

            accum.append(directions[tuple(dim_slice)])

        accum = torch.broadcast_tensors(*accum)
        operations = torch.stack(accum, dim = -1)
        operations = rearrange(operations, '... d -> (...) d')

        assert operations.shape == (self.num_frames, dim)

        # buffer, so the sign table follows the module across devices

        self.register_buffer('operations', operations)

        # whether to use stochastic frame averaging
        # proposed in https://arxiv.org/abs/2305.05577
        # one frame is selected at random

        self.stochastic = stochastic
        self.return_stochastic_as_augmented_pos = return_stochastic_as_augmented_pos

        # invariant output setting

        self.invariant_output = invariant_output

    def forward(
        self,
        points,
        *args,
        frame_average_mask = None,
        return_framed_inputs_and_averaging_function = False,
        **kwargs,
    ):
        """
        Frame the points, and either run the wrapped network with frame averaging or
        hand the framed points back to the caller.

        Args:
            points: tensor of shape (b, n, d) with d equal to `self.dim`.
            *args: extra positional arguments forwarded to the wrapped network.
            frame_average_mask: optional mask of shape (b, n). Points are multiplied by it,
                so positions where it is False / 0 are zeroed and left out of the centroid
                and covariance.
            return_framed_inputs_and_averaging_function: return the framed inputs and the
                averaging function even when a network is wrapped.
            **kwargs: extra keyword arguments forwarded to the wrapped network.

        Returns:
            - the framed points of shape (b, n, d), if the framed inputs are requested (or no
              network is wrapped) and both `stochastic` and `return_stochastic_as_augmented_pos` are set
            - a tuple of framed inputs of shape (b, f, n, d) and the frame averaging function,
              if the framed inputs are requested (or no network is wrapped)
            - otherwise the output of the wrapped network, with every tensor in it frame averaged

        Raises:
            ValueError: if `points` does not have exactly 3 dimensions.

        einstein notation

        b - batch
        n - sequence
        d - dimension (input or source)
        e - dimension (target)
        f - frames
        """

        assert points.shape[-1] == self.dim, f'expected points of dimension {self.dim}, but received {points.shape[-1]}'

        # shape must be exactly (batch, seq, dim)

        if points.ndim != 3:
            raise ValueError(f'expected points of shape (batch, seq, dim), but received shape {tuple(points.shape)}')

        # account for variable length points

        if exists(frame_average_mask):
            frame_average_mask = rearrange(frame_average_mask, '... -> ... 1')
            points = points * frame_average_mask

        # frame averaging logic

        if exists(frame_average_mask):
            num = reduce(points, 'b n d -> b 1 d', 'sum')
            den = reduce(frame_average_mask.float(), 'b n 1 -> b 1 1', 'sum')

            # clamp so a fully masked out sample does not divide by zero

            centroid = num / den.clamp(min = 1)
        else:
            centroid = reduce(points, 'b n d -> b 1 d', 'mean')

        centered_points = points - centroid

        # masked out positions became -centroid after centering, zero them again

        if exists(frame_average_mask):
            centered_points = centered_points * frame_average_mask

        covariance = einsum(centered_points, centered_points, 'b n d, b n e -> b d e')

        _, eigenvectors = torch.linalg.eigh(covariance)

        # if stochastic, just select one random operation

        num_frames = self.num_frames
        operations = self.operations

        if self.stochastic:
            rand_frame_index = randrange(self.num_frames)

            # slice rather than index, to keep the frame axis

            operations = operations[rand_frame_index:(rand_frame_index + 1)]
            num_frames = 1

        # frames - eigenvector columns multiplied by each combination of signs

        frames = rearrange(eigenvectors, 'b d e -> b 1 d e') * rearrange(operations, 'f e -> f 1 e')

        # inverse frame op

        inputs = einsum(frames, centered_points, 'b f d e, b n d -> b f n e')

        # define the frame averaging function

        def frame_average(out):
            """Map an output of shape (b, f, ..., e) back through the frames (unless invariant) and reduce the frame axis."""

            if not self.invariant_output:
                # apply frames

                out = einsum(frames, out, 'b f d e, b f ... e -> b f ... d')

            if not self.stochastic:
                # averaging across frames, thus "frame averaging"

                out = reduce(out, 'b f ... -> b ...', 'mean')
            else:
                out = rearrange(out, 'b 1 ... -> b ...')

            return out

        # if one wants to handle the framed inputs externally

        if return_framed_inputs_and_averaging_function or not exists(self.net):

            if self.stochastic and self.return_stochastic_as_augmented_pos:
                return rearrange(inputs, 'b 1 ... -> b ...')

            return inputs, frame_average

        # merge frames into batch

        inputs = rearrange(inputs, 'b f ... -> (b f) ...')

        # if batch is expanded by number of frames, any tensor being passed in for args and kwargs needed to be expanded as well
        # automatically take care of this

        if not self.stochastic:

            def expand_to_frames(el):
                # non-tensors and 0-dim tensors have no batch axis to repeat, pass them through

                if not torch.is_tensor(el) or el.ndim == 0:
                    return el

                return repeat(el, 'b ... -> (b f) ...', f = num_frames)

            args, kwargs = tree_map(expand_to_frames, (args, kwargs))

        # main network forward

        out = self.net(inputs, *args, **kwargs)

        # use tree map to handle multiple outputs
        # non-tensor outputs are returned untouched

        def split_frames_and_average(t):
            if not torch.is_tensor(t):
                return t

            t = rearrange(t, '(b f) ... -> b f ...', f = num_frames)
            return frame_average(t)

        out = tree_map(split_frames_and_average, out)

        return out
188 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "frame-averaging-pytorch"
3 | version = "0.1.2"
4 | description = "Frame Averaging"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.9"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'geometric learning',
15 | ]
16 |
17 | classifiers=[
18 | 'Development Status :: 4 - Beta',
19 | 'Intended Audience :: Developers',
20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
21 | 'License :: OSI Approved :: MIT License',
22 | 'Programming Language :: Python :: 3.9',
23 | ]
24 |
25 | dependencies = [
26 | "torch>=2.0",
27 | "einops>=0.8.0",
28 | ]
29 |
30 | [project.urls]
31 | Homepage = "https://pypi.org/project/frame-averaging-pytorch/"
32 | Repository = "https://github.com/lucidrains/frame-averaging-pytorch"
33 |
34 | [project.optional-dependencies]
35 | examples = []
36 | test = [
37 | "pytest"
38 | ]
39 |
40 | [tool.pytest.ini_options]
41 | pythonpath = [
42 | "."
43 | ]
44 |
45 | [build-system]
46 | requires = ["hatchling"]
47 | build-backend = "hatchling.build"
48 |
49 | [tool.rye]
50 | managed = true
51 | dev-dependencies = []
52 |
53 | [tool.hatch.metadata]
54 | allow-direct-references = true
55 |
56 | [tool.hatch.build.targets.wheel]
57 | packages = ["frame_averaging_pytorch"]
58 |
--------------------------------------------------------------------------------
/tests/test_frame_average.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import Module
6 | from frame_averaging_pytorch import FrameAverage
7 |
@pytest.mark.parametrize('stochastic', (True, False))
@pytest.mark.parametrize('dim', (2, 3, 4))
@pytest.mark.parametrize('has_mask', (True, False))
def test_frame_average(
    stochastic: bool,
    dim: int,
    has_mask: bool
):
    """The wrapped network returns an output shaped like the input points, for every parametrization."""

    batch_size, num_points = 4, 1024

    # wrap a plain linear layer, frame averaging is then handled inside forward

    wrapped = FrameAverage(
        nn.Linear(dim, dim),
        dim = dim,
        stochastic = stochastic
    )

    points = torch.randn(batch_size, num_points, dim)

    # all-true mask when requested, otherwise no mask at all

    mask = torch.ones(batch_size, num_points).bool() if has_mask else None

    out = wrapped(points, frame_average_mask = mask)
    assert out.shape == points.shape
33 |
def test_frame_average_manual():
    """Framing the inputs and averaging the network output by hand gives back the shape of the points."""

    linear = nn.Linear(3, 3)
    frame_averager = FrameAverage()

    points = torch.randn(4, 1024, 3)

    # no network was wrapped, so forward returns the framed inputs and the averaging function

    framed_inputs, frame_average_fn = frame_averager(points)

    # network forward on the framed inputs, then frame average

    frame_averaged = frame_average_fn(linear(framed_inputs))

    assert frame_averaged.shape == points.shape
48 |
def test_frame_average_multiple_inputs_and_outputs():
    """An extra tensor argument and a tuple of mixed outputs both pass through the wrapper."""

    class Network(Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(3, 3)
            self.to_out1 = nn.Linear(3, 3)
            self.to_out2 = nn.Linear(3, 3)

        def forward(self, x, mask):
            keep = mask[..., None]
            hidden = self.net(x.masked_fill(~keep, 0.))

            # first output is a plain float, the other two are tensors

            return 0., self.to_out1(hidden), self.to_out2(hidden)

    wrapped = FrameAverage(Network())

    points = torch.randn(4, 1024, 3)
    mask = torch.ones(4, 1024).bool()

    # the mask is passed twice - once for the network, once for the frame averaging

    _, out1, out2 = wrapped(points, mask, frame_average_mask = mask)

    assert out1.shape == points.shape
    assert out2.shape == points.shape
72 |
--------------------------------------------------------------------------------