├── .deepsource.toml
├── .github
│   └── workflows
│       └── python-package.yml
├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs
│   ├── CHANGELOG.md
│   ├── Makefile
│   ├── _static
│   │   ├── emlp_frontfig_bitmap.png
│   │   ├── emlp_logo4x.png
│   │   └── style.css
│   ├── _templates
│   │   └── layout.html
│   ├── conf.py
│   ├── documentation.md
│   ├── index.rst
│   ├── inner_workings.md
│   ├── manual_make.sh
│   ├── notebooks
│   │   ├── 1quickstart.ipynb
│   │   ├── 2building_a_model.ipynb
│   │   ├── 3new_groups.ipynb
│   │   ├── 4new_representations.ipynb
│   │   ├── 5mixed_tensors.ipynb
│   │   ├── 6multilinear_maps.ipynb
│   │   ├── _colab_preamble.ipynb
│   │   ├── colabs
│   │   │   ├── 1quickstart.ipynb
│   │   │   ├── 2building_a_model.ipynb
│   │   │   ├── 3new_groups.ipynb
│   │   │   ├── 4new_representations.ipynb
│   │   │   ├── 5mixed_tensors.ipynb
│   │   │   ├── 6multilinear_maps.ipynb
│   │   │   ├── 7pytorch_support.ipynb
│   │   │   ├── all.ipynb
│   │   │   ├── flax_support.ipynb
│   │   │   ├── haiku_support.ipynb
│   │   │   ├── imgs
│   │   │   │   ├── EMLP_fig.png
│   │   │   │   └── imgs
│   │   │   │       └── EMLP_fig.png
│   │   │   └── pytorch_support.ipynb
│   │   ├── flax_support.ipynb
│   │   ├── haiku_support.ipynb
│   │   ├── imgs
│   │   │   └── EMLP_fig.png
│   │   ├── merge_nbs.sh
│   │   └── pytorch_support.ipynb
│   ├── package
│   │   ├── emlp.groups.rst
│   │   ├── emlp.nn.rst
│   │   └── emlp.reps.rst
│   ├── requirements.txt
│   └── testing.md
├── emlp
│   ├── __init__.py
│   ├── datasets.py
│   ├── groups.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── flax.py
│   │   ├── haiku.py
│   │   ├── objax.py
│   │   └── pytorch.py
│   ├── reps
│   │   ├── __init__.py
│   │   ├── linear_operator_base.py
│   │   ├── linear_operators.py
│   │   ├── product_sum_reps.py
│   │   └── representation.py
│   └── utils.py
├── experiments
│   ├── data_efficiency.py
│   ├── datasets
│   │   └── batchnorm.py
│   ├── depreciated
│   │   ├── train_cube_simple.py
│   │   ├── train_rubiks.py
│   │   └── train_tagging.py
│   ├── hnn.py
│   ├── hnn_expt.py
│   ├── neuralode.py
│   ├── neuralode_expt.py
│   ├── notebooks
│   │   ├── additional_appendix_figs.ipynb
│   │   ├── make_tables.ipynb
│   │   ├── paper_figures.ipynb
│   │   └── synthetic_results_all.csv
│   ├── train_regression.py
│   └── trainer
│       ├── classifier.py
│       ├── hamiltonian_dynamics.py
│       ├── model_trainer.py
│       ├── trainer.py
│       └── utils.py
├── setup.py
└── tests
    ├── equivariance_tests.py
    ├── model_tests.py
    └── product_groups_tests.py
/.deepsource.toml:
--------------------------------------------------------------------------------
1 | version = 1
2 |
3 | test_patterns = ["tests/*tests.py"]
4 |
5 | exclude_patterns = [
6 | "*.ipynb",
7 | "experiments/*.py",
8 | "docs/*"
9 | ]
10 |
11 | [[analyzers]]
12 | name = "python"
13 | enabled = true
14 |
15 | [analyzers.meta]
16 | runtime_version = "3.x.x"
17 |
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches:
6 | - '**' # matches every branch
7 | pull_request:
8 | branches:
9 | - '**' # matches every branch
10 |
11 | workflow_dispatch:
12 |
13 | jobs:
14 | build:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v2
18 | - uses: actions/setup-python@v1
19 | with:
20 | python-version: 3.7
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | python -m pip install pytest
25 | python -m pip install pytest-cov
26 | pip install git+https://github.com/deepmind/dm-haiku
27 | python -m pip install flax
28 | pip install -e .[EXPTS]
29 | - name: Test coverage.
30 | run: |
31 | pytest --cov emlp --cov-report xml:cov.xml tests/*.py
32 | - name: Upload coverage to Codecov
33 | uses: codecov/codecov-action@v1
34 | with:
35 | files: ./cov.xml
36 | name: codecov-umbrella
37 | path_to_write_report: ./coverage/codecov_report.txt
38 | verbose: true
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Random other stuff
132 | *.pyc
133 | *.pdf
134 | *.state
135 | *.tfevents*
136 | *.df
137 | *.s
138 |
139 | # additional excludes specific to the repo
140 | *.t
141 | *.dat
142 | *.h5
143 | *.pkl
144 | *.json
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | build:
9 | image: latest
10 |
11 |
12 | # Build documentation in the docs/ directory with Sphinx
13 | sphinx:
14 | builder: html
15 | configuration: docs/conf.py
16 |
17 | # Optionally build your docs in additional formats such as PDF
18 | formats: []
19 | # - pdf
20 |
21 | # Optionally set the version of Python and requirements required to build your docs
22 | python:
23 | version: 3.8
24 | install:
25 | - requirements: docs/requirements.txt
26 | - method: pip
27 | path: .
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 mfinzi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups
6 | [Documentation](https://emlp.readthedocs.io/en/latest/) | [Paper](https://arxiv.org/abs/2104.09459) | [Open in Colab](https://colab.research.google.com/github/mfinzi/equivariant-MLP/blob/master/docs/notebooks/colabs/all.ipynb) | [Codecov](https://codecov.io/github/mfinzi/equivariant-MLP) | [PyPI](https://pypi.org/project/emlp/)
9 |
10 |
11 |
12 |
13 | *EMLP* is a JAX library for the automated construction of equivariant layers in deep learning, based on the ICML 2021 paper [A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups](https://arxiv.org/abs/2104.09459). You can read the documentation [here](https://emlp.readthedocs.io/en/latest/).
14 |
15 |
16 | ## What EMLP is great at doing
17 |
18 | - Computing equivariant linear layers between finite-dimensional
19 | representations. You specify the symmetry group (discrete, continuous,
20 | non-compact, complex) and the representations (tensors, irreducibles, induced representations, etc.), and we compute the basis of equivariant
21 | maps from one to the other.
22 |
23 | - Automatic construction of fully equivariant models for small data. E.g.
24 | if your inputs and outputs (and intended features) are a small collection of elements like scalars, vectors, tensors, or irreps with a total dimension less than 1000, then you can likely use EMLP as a turnkey solution for the model, or at least as a strong baseline.
25 |
26 | - Serving as a building block in larger models where EMLP is just one component, for example as the convolution kernel in an equivariant PointConv network.
27 |
28 | ## What EMLP is not great at doing
29 |
30 | - Efficiently implementing CNNs, Deep Sets, typical translation + rotation equivariant GCNNs, or graph neural networks.
31 |
32 | - Handling large data like images, voxel grids, medium-large graphs, or point clouds.
33 |
34 | Given the current approach, EMLP can only ever be as fast as an MLP. So if the inputs, flattened into a single vector, would be too large to train an MLP on, they will also be too large to train EMLP on.
35 |
36 | --------------------------------------------------------------------------------
37 |
38 | # Showcasing some examples of computing equivariant bases
39 |
40 | We provide a type system for representations. With the operators ρᵤ⊗ρᵥ, ρᵤ⊕ρᵥ, and ρ* (implemented as `*`, `+`, and `.T`) you can build up different representations. The basic building blocks are the base vector representation `V` and the tensor representations `T(p,q) = V**p*V.T**q`.
41 |
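To make these operators concrete, here is a minimal sketch (not part of the original README; it assumes `emlp` is installed and uses `SO(3)` purely as an example group):

```python
# A minimal sketch of the type-system operators (SO(3) is just an example group).
from emlp.reps import V, T
from emlp.groups import SO

G = SO(3)
print(V(G) + V(G))  # direct sum rho_u ⊕ rho_v via `+`
print(V(G) * V(G))  # tensor product rho_u ⊗ rho_v via `*`
print(V(G).T)       # dual representation rho* via `.T`
print(T(1, 1)(G))   # T(p,q) = V**p*V.T**q, here V⊗V*
```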
42 | For any given matrix group and representation formed in our type system, you can get the equivariant basis with [`rep.equivariant_basis()`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.equivariant_basis) or a matrix which projects to that subspace with [`rep.equivariant_projector()`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.equivariant_projector).
43 |
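As a quick illustration of the projector (a sketch, not from the original README), you can project an arbitrary matrix onto the equivariant subspace:

```python
# Sketch: project a random matrix onto the SO(3)-equivariant maps V -> V.
# For SO(3) the projected result is a multiple of the identity.
import numpy as np
from emlp.reps import V
from emlp.groups import SO

rep_map = V(SO(3)) >> V(SO(3))       # representation of linear maps V -> V
P = rep_map.equivariant_projector()  # lazy (9,9) projection operator
W = np.random.randn(rep_map.size())  # a random flattened 3x3 matrix
W_equiv = (P @ W).reshape(3, 3)      # ≈ c*I, the nearest equivariant map
print(W_equiv)
```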
44 | For example, to find all O(1,3) (Lorentz) equivariant linear maps from a 4-vector Xᶜ to a rank (2,1) tensor Mᵇᵈₐ, you can run
45 |
46 | ```python
47 | from emlp.reps import V,T
48 | from emlp.groups import *
49 |
50 | G = O13()
51 | Q = (T(1,0)>>T(2,1))(G).equivariant_basis()
52 | ```
53 |
54 | or how about equivariant maps from one Rubik's cube to another?
55 | ```python
56 | G = RubiksCube()
57 |
58 | Q = (V(G)>>V(G)).equivariant_basis()
59 | ```
60 |
61 | Using `+` and `*` you can put together composite representations (where multiple representations are concatenated together). For example, let's find all equivariant linear maps from 10 node features and 5 edge features to 3 global invariants and 1 edge feature of a graph of size n=5:
62 | ```python
63 | G=S(5)
64 |
65 | repin = 10*T(1)+5*T(2)
66 | repout = 3*T(0)+T(2)
67 | Q = (repin(G)>>repout(G)).equivariant_basis()
68 | ```
69 |
70 | As the examples above suggest, there are many different but equivalent ways of writing a representation like `10*T(1)+5*T(2)`:
71 | `10*T(1)+5*T(2)` = `10*V+5*V**2` = `5*V*(2+V)`
72 |
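A quick way to convince yourself (a sketch, not in the original README) is to instantiate each spelling with a concrete group and compare, e.g. their sizes:

```python
# Sketch: the three spellings describe the same representation (here with G = S(5)).
from emlp.reps import V, T
from emlp.groups import S

G = S(5)
r1, r2, r3 = (10*T(1)+5*T(2))(G), (10*V+5*V**2)(G), (5*V*(2+V))(G)
print(r1.size(), r2.size(), r3.size())  # all 175 = 10*5 + 5*25
```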
79 |
80 | You can even mix and match representations from different groups, for example the cyclic group ℤ₃, the permutation group 𝕊₄, and the orthogonal group O(3):
81 |
82 | ```python
83 | rep = 2*V(Z(3))*V(S(4))+V(O(3))**2
84 | Q = (rep>>rep).equivariant_basis()
85 | ```
86 |
87 | Beyond these tensor representations, our type system works with any finite-dimensional linear representation, and you can even build your own bespoke representations following the instructions [here](https://emlp.readthedocs.io/en/latest/notebooks/4new_representations.html).
88 |
89 | You can visualize these equivariant bases with [`vis(repin,repout)`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.vis), as with the three examples above.
90 |
91 |
92 |
94 |
95 |
96 | Check out our [documentation](https://emlp.readthedocs.io/en/latest/) to see how to use our system and some worked examples.
97 |
98 | # Simple example of using EMLP as a full equivariant model
99 |
100 | Suppose we want to construct a Lorentz-equivariant model for particle physics data that takes in the input and output 4-momenta of two particles
101 | in a collision, as well as some metadata about these particles like their charge, and we want to classify the collision
102 | as belonging to 3 distinct classes. Since the outputs are simple logits, they should be unchanged by
103 | Lorentz transformations, and similarly for the charges.
104 |
105 | ```python
106 | import emlp
107 | from emlp.reps import T
108 | from emlp.groups import Lorentz
109 | import numpy as np
110 |
111 | repin = 4*T(1)+2*T(0) # 4 four vectors and 2 scalars for the charges
112 | repout = 3*T(0) # 3 output logits for the 3 classes of collisions
113 | group = Lorentz()
114 | model = emlp.nn.EMLP(repin,repout,group=group,num_layers=3,ch=384)
115 |
116 | x = np.random.randn(32,repin(group).size()) # Create a minibatch of data
117 | y = model(x) # Outputs the 3 class logits
118 | ```
119 |
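As a sanity check (a sketch, not part of the original README), the logits should be invariant to a random Lorentz transformation of the inputs, since `repout` consists only of scalars. This assumes the concrete rep exposes a dense `rho_dense` map, as in recent versions:

```python
# Sketch: invariance check for the model above (rho_dense is assumed available).
g = group.sample()                   # random Lorentz transformation
rho_in = repin(group).rho_dense(g)   # dense input representation matrix
y_transformed = model(x @ rho_in.T)  # act on each row of the minibatch
print(np.max(np.abs(np.asarray(y) - np.asarray(y_transformed))))  # ≈ 0
```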
120 | Here we have used the default Objax EMLP, but you can also use our [PyTorch](https://emlp.readthedocs.io/en/latest/notebooks/pytorch_support.html), [Haiku](https://emlp.readthedocs.io/en/latest/notebooks/haiku_support.html), or [Flax](https://emlp.readthedocs.io/en/latest/notebooks/flax_support.html) versions of the models. To see more examples, or how to use your own representations or symmetry groups, check out the documentation.
121 |
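For instance, a hedged sketch of the PyTorch route (see the PyTorch support notebook for the authoritative version; we assume `emlp.nn.pytorch.EMLP` mirrors the Objax constructor):

```python
# Sketch only: see the "Using EMLP with PyTorch" notebook for the real API.
import numpy as np
import torch
from emlp.nn.pytorch import EMLP  # assumed to mirror the Objax constructor
from emlp.reps import T
from emlp.groups import Lorentz

group = Lorentz()
repin, repout = 4*T(1)+2*T(0), 3*T(0)
model = EMLP(repin, repout, group=group, num_layers=3, ch=384)
x = torch.from_numpy(np.random.randn(32, repin(group).size()).astype(np.float32))
y = model(x)  # (32, 3) class logits, trainable with the usual torch optimizers
```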
122 | # Installation instructions
123 |
124 | To install as a package, run
125 | ```bash
126 | pip install emlp
127 | ```
128 |
129 | To run the scripts you will instead need to clone the repo and install it locally which you can do with
130 |
131 | ```bash
132 | git clone https://github.com/mfinzi/equivariant-MLP.git
133 | cd equivariant-MLP
134 | pip install -e .[EXPTS]
135 | ```
136 |
137 | # Experimental Results from Paper
138 |
139 | Assuming you have installed the repo locally, you can run the experiments we described in the paper.
140 |
141 | To train the regression models on one of the `Inertia`, `O5Synthetic`, or `ParticleInteraction` datasets found in [`emlp.datasets.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/emlp/datasets.py), you can run the script [`experiments/train_regression.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/experiments/train_regression.py) with command-line arguments specifying the dataset, network, and symmetry group. For example, to train [`EMLP`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.EMLP) with [`SO(3)`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.SO) equivariance on the `Inertia` dataset, you can run
142 |
143 | ```
144 | python experiments/train_regression.py --dataset Inertia --network EMLP --group "SO(3)"
145 | ```
146 |
147 | or to train the MLP baseline you can run
148 |
149 | ```
150 | python experiments/train_regression.py --dataset Inertia --network MLP
151 | ```
152 | Other command-line arguments are available, such as `--aug=True` for data augmentation and `--ch=512` for the number of hidden units; you can browse the options and their defaults with `python experiments/train_regression.py -h`. If no group is specified, EMLP will automatically choose the one matched to the dataset, but you can also go crazy with any of the other groups implemented in [`groups.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/emlp/groups.py) provided the dimensions match the data (e.g. for the 3D inertia dataset you could do `--group=` [`"Z(3)"`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.Z) or [`"DkeR3(3)"`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.DkeR3) but not [`"Sp(2)"`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.Sp) or [`"SU(5)"`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.SU)).
153 |
154 | For the dynamical systems modeling experiments you can use the scripts
155 | [`experiments/neuralode.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/experiments/neuralode.py) to train (equivariant) Neural ODEs and [`experiments/hnn.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/experiments/hnn.py) to train (equivariant) Hamiltonian Neural Networks.
156 |
157 |
158 | For the dynamical system task, the Neural ODE and HNN models have special names: [`EMLPode`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.EMLPode) and [`MLPode`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.MLPode) for the Neural ODEs in `neuralode.py`, and [`EMLPH`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.EMLPH) and [`MLPH`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.MLPH) for the HNNs in `hnn.py`. For example,
159 |
160 | ```
161 | python experiments/neuralode.py --network EMLPode --group="O2eR3()"
162 | ```
163 | or
164 |
165 | ```
166 | python experiments/hnn.py --network EMLPH --group="DkeR3(6)"
167 | ```
168 |
169 | These models are trained to fit a double-spring dynamical system. Below we show 30s rollouts of the dataset, the rollout error on these trajectories, and the conservation of angular momentum.
170 |
171 |
172 |
173 |
177 |
178 | If you find our work helpful, please cite it with
179 | ```bibtex
180 | @article{finzi2021emlp,
181 | title={A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups},
182 | author={Finzi, Marc and Welling, Max and Wilson, Andrew Gordon},
183 |   journal={arXiv preprint arXiv:2104.09459},
184 | year={2021}
185 | }
186 | ```
187 |
189 |
--------------------------------------------------------------------------------
/docs/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Change Log
2 |
3 |
7 |
8 | ## EMLP 1.0.0
9 | * New Features
10 | * Flax support (see `using EMLP with Flax`)
11 | * Auto generated `size()`, `__eq__`, `__hash__`, and `.T` methods for new representations
12 | * You can now use ints in place of `Scalars` for direct sum, e.g. add `3+V`
13 | * Codebase improvements
14 |     * Streamlined product_sum_reps direct sum and product rules, now with plum dispatch
15 | * More general `Dual(Rep)` implementation that now works with any kind of Rep, not just `V`
16 | * CI setup and with more tests
17 |
18 | ## EMLP 0.9.0
19 | * Cross Platform Support:
20 | * You can now use EMLP in PyTorch, check out `Using EMLP with PyTorch`
21 | * You can also use EMLP with Haiku in jax, check out `Using EMLP with Haiku`
22 |
23 | * Bug Fixes
24 | * Fixed broken constraints with Trivial group
25 |
26 | ## EMLP 0.8.0 (Unreleased)
27 |
28 | * New features:
29 | * Fallback autograd jvp implementation of drho to make implementing new reps easier.
30 | * Mixed group representations (now working and tested)
31 | * Experimental support of complex groups and representations
32 | * Bug Fixes:
33 | * Element ordering of mixed groups is now correctly maintained in the solution
34 | * Fixed edge case of {func}`lazy_direct_matmat` when concatenating matrices of size 0
35 | affecting {func}`emlp.reps.Rep.equivariant_basis` but not
36 | {func}`emlp.reps.Rep.equivariant_projector`
37 | * API Changes:
38 | * `emlp.solver.representation` -> `emlp.reps`
39 | * `emlp.solver.groups` -> `emlp.groups`
40 | * `emlp.models.mlp` -> `emlp.nn`
41 | * `rep.symmetric_basis()` -> `rep.equivariant_basis()`
42 | * `rep.symmetric_projector()` -> `rep.equivariant_projector()`
43 | * Tests and experiments separated from package and api
44 |
45 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/emlp_frontfig_bitmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/_static/emlp_frontfig_bitmap.png
--------------------------------------------------------------------------------
/docs/_static/emlp_logo4x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/_static/emlp_logo4x.png
--------------------------------------------------------------------------------
/docs/_static/style.css:
--------------------------------------------------------------------------------
1 | @import url("theme.css");
2 |
3 | .wy-side-nav-search {
4 | background-color: #fff;
5 | }
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 | {% set css_files = css_files + ["_static/style.css"] %}
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 | import os
18 | import sys
19 | import re
20 | sys.path.insert(0, os.path.abspath('..'))
21 |
22 | RE_VERSION = re.compile(r'^__version__ \= \'(\d+\.\d+\.\d+(?:\w+\d+)?)\'$', re.MULTILINE)
23 | PROJECTDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
24 | sys.path.insert(0, PROJECTDIR)
25 | def get_release():
26 | with open(os.path.join(PROJECTDIR, 'emlp', '__init__.py')) as f:
27 | version = re.search(RE_VERSION, f.read())
28 | assert version is not None, "can't parse __version__ from __init__.py"
29 | return version.group(1)
30 |
31 | # -- Project information -----------------------------------------------------
32 |
33 | project = 'EMLP'
34 | copyright = '2021, Marc Finzi'
35 | author = 'Marc Finzi'
36 |
37 | # import emlp
38 | release = get_release()
39 |
40 |
41 | sys.path.append(os.path.abspath('sphinxext'))
42 | extensions = [
43 | 'nbsphinx',
44 | 'sphinx.ext.mathjax',
45 | 'recommonmark',
46 | 'sphinx.ext.autodoc',
47 | 'sphinx.ext.autosummary',
48 | 'sphinx.ext.intersphinx',
49 | 'sphinx.ext.napoleon',
50 | 'sphinx.ext.viewcode',
51 | 'matplotlib.sphinxext.plot_directive',
52 | 'sphinx_autodoc_typehints',
53 | 'sphinx.ext.autosectionlabel',
54 | ]
55 | autosummary_generate = True
56 | autodoc_default_options = {'autosummary': True}
57 | autodoc_member_order = 'bysource'
58 |
59 | intersphinx_mapping = {
60 | 'python': ('https://docs.python.org/3/', None),
61 | 'numpy': ('https://numpy.org/doc/stable/', None),
62 | 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None),
63 | 'objax': ('https://objax.readthedocs.io/en/latest/', None),
64 | }
65 |
66 | suppress_warnings = [
67 | 'ref.citation', # Many duplicated citations in numpy/scipy docstrings.
68 | 'ref.footnote', # Many unreferenced footnotes in numpy/scipy docstrings
69 | ]
70 |
71 | # Add any paths that contain templates here, relative to this directory.
72 | templates_path = ['_templates']
73 |
74 | # The suffix(es) of source filenames.
75 | # Note: important to list ipynb before md here: we have both md and ipynb
76 | # copies of each notebook, and myst will choose which to convert based on
77 | # the order in the source_suffix list. Notebooks which are not executed have
78 | # outputs stored in ipynb but not in md, so we must convert the ipynb.
79 | source_suffix = ['.rst', '.ipynb', '.md']
80 |
81 | # The master toctree document.
82 | main_doc = 'index'
83 |
84 | # The language for content autogenerated by Sphinx. Refer to documentation
85 | # for a list of supported languages.
86 | #
87 | # This is also used if you do content translation via gettext catalogs.
88 | # Usually you set "language" from the command line for these cases.
89 | language = None
90 |
91 | # List of patterns, relative to source directory, that match files and
92 | # directories to ignore when looking for source files.
93 | # This pattern also affects html_static_path and html_extra_path.
94 | exclude_patterns = [
95 | # Sometimes sphinx reads its own outputs as inputs!
96 | 'build/html',
97 | 'build/jupyter_execute',
98 | 'notebooks/README.md',
99 |     'notebooks/colabs/*.ipynb',
100 | 'README.md',
101 | # Ignore markdown source for notebooks; myst-nb builds from the ipynb
102 | 'notebooks/*.md'
103 | ]
104 |
105 | # The name of the Pygments (syntax highlighting) style to use.
106 | pygments_style = None
107 |
108 |
109 | autosummary_generate = True
110 | napoleon_use_rtype = False
111 |
112 | # mathjax_config = {
113 | # 'TeX': {'equationNumbers': {'autoNumber': 'AMS', 'useLabelIds': True}},
114 | # }
115 |
116 | # Additional files needed for generating LaTeX/PDF output:
117 | # latex_additional_files = ['references.bib']
118 |
119 | # -- Options for HTML output -------------------------------------------------
120 |
121 | # The theme to use for HTML and HTML Help pages. See the documentation for
122 | # a list of builtin themes.
123 | #
124 | html_theme = 'sphinx_rtd_theme'
125 |
126 | # html_theme_options = {
127 | # 'logo_only': False,
128 | # 'display_version': True,
129 | # 'prev_next_buttons_location': 'bottom',
130 | # 'style_external_links': False,
131 | # }
132 |
133 | # Theme options are theme-specific and customize the look and feel of a theme
134 | # further. For a list of options available for each theme, see the
135 | # documentation.
136 |
137 | # The name of an image file (relative to this directory) to place at the top
138 | # of the sidebar.
139 |
140 | # Add any paths that contain custom static files (such as style sheets) here,
141 | # relative to this directory. They are copied after the builtin static files,
142 | # so a file named "default.css" will overwrite the builtin "default.css".
143 | html_static_path = ['_static']
144 | html_logo = '_static/emlp_logo4x.png'
145 | # Custom sidebar templates, must be a dictionary that maps document names
146 | # to template names.
147 | #
148 | # The default sidebars (for documents that don't match any pattern) are
149 | # defined by theme itself. Builtin themes are using these templates by
150 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
151 | # 'searchbox.html']``.
152 | #
153 | # html_sidebars = {}
154 |
155 | # # commented out notebook execution
156 | # # -- Options for myst ----------------------------------------------
157 | # jupyter_execute_notebooks = "off"# changed from "force"
158 | # #jupyter_execute_notebooks = "cache"
159 | # execution_allow_errors = False
160 | # execution_fail_on_error = True # Requires https://github.com/executablebooks/MyST-NB/pull/296
161 |
162 | # # Notebook cell execution timeout; defaults to 30.
163 | # execution_timeout = 100
164 |
165 | # -- Options for HTMLHelp output ---------------------------------------------
166 |
167 | # Output file base name for HTML help builder.
168 | htmlhelp_basename = 'EMLPdoc'
169 |
170 |
171 | # -- Options for LaTeX output ------------------------------------------------
172 |
173 | latex_elements = {
174 | # The paper size ('letterpaper' or 'a4paper').
175 | #
176 | # 'papersize': 'letterpaper',
177 |
178 | # The font size ('10pt', '11pt' or '12pt').
179 | #
180 | # 'pointsize': '10pt',
181 |
182 | # Additional stuff for the LaTeX preamble.
183 | #
184 | # 'preamble': '',
185 |
186 | # Latex figure (float) alignment
187 | #
188 | # 'figure_align': 'htbp',
189 | }
190 |
191 | # # Grouping the document tree into LaTeX files. List of tuples
192 | # # (source start file, target name, title,
193 | # # author, documentclass [howto, manual, or own class]).
194 | # latex_documents = [
195 | # (main_doc, 'JAX.tex', 'JAX Documentation',
196 | # 'The JAX authors', 'manual'),
197 | # ]
198 |
199 |
200 | # # -- Options for manual page output ------------------------------------------
201 |
202 | # # One entry per manual page. List of tuples
203 | # # (source start file, name, description, authors, manual section).
204 | # man_pages = [
205 | # (main_doc, 'jax', 'JAX Documentation',
206 | # [author], 1)
207 | # ]
208 |
209 |
210 | # # -- Options for Texinfo output ----------------------------------------------
211 |
212 | # # Grouping the document tree into Texinfo files. List of tuples
213 | # # (source start file, target name, title, author,
214 | # # dir menu entry, description, category)
215 | # texinfo_documents = [
216 | # (main_doc, 'JAX', 'JAX Documentation',
217 | # author, 'JAX', 'One line description of project.',
218 | # 'Miscellaneous'),
219 | # ]
220 |
221 |
222 | # -- Options for Epub output -------------------------------------------------
223 |
224 | # Bibliographic Dublin Core info.
225 | epub_title = project
226 |
227 | # The unique identifier of the text. This can be a ISBN number
228 | # or the project homepage.
229 | #
230 | # epub_identifier = ''
231 |
232 | # A unique identification for the text.
233 | #
234 | # epub_uid = ''
235 |
236 | # A list of files that should not be packed into the epub file.
237 | epub_exclude_files = ['search.html']
238 |
239 |
240 | # -- Extension configuration -------------------------------------------------
241 |
242 | # Tell sphinx-autodoc-typehints to generate stub parameter annotations including
243 | # types, even if the parameters aren't explicitly documented.
244 | always_document_param_types = True
245 |
246 | # -- Options for nbsphinx -----------------------------------------------------
247 |
248 | # Execute notebooks before conversion: 'always', 'never', 'auto' (default)
249 | # We never execute notebooks to avoid problems if nbsphinx won't find all dependencies.
250 | nbsphinx_execute = 'never'
251 |
252 | # If True, the build process is continued even if an exception occurs:
253 | nbsphinx_allow_errors = True
254 |
255 | # Controls when a cell will time out (defaults to 30; use -1 for no timeout):
256 | nbsphinx_timeout = 180
257 |
258 | # Default Pygments lexer for syntax highlighting in code cells:
259 | nbsphinx_codecell_lexer = 'ipython3'
260 |
261 | nbsphinx_prolog = r"""
262 | {% set docname = 'docs/notebooks/colabs/' + env.doc2path(env.docname, base=None).split('/')[-1] %}
263 |
264 | .. only:: html
265 |
266 | .. role:: raw-html(raw)
267 | :format: html
268 |
269 | .. nbinfo::
270 | Interactive online version:
271 | :raw-html:`
`
272 |
273 | """
274 |
275 | # nbsphinx_prolog = """
276 | # ----
277 |
278 | # Generated by nbsphinx_ from a Jupyter_ notebook.
279 |
280 | # .. _nbsphinx: http://nbsphinx.readthedocs.io/
281 | # .. _Jupyter: https://jupyter.org/
282 | # """
283 |
284 | # nbsphinx_prolog = """
285 | # ----
286 | # {% set docname = 'notebooks/docs/' + env.doc2path(env.docname, base=None) %}
287 | # .. only:: html
288 | # .. role:: raw-html(raw)
289 | # :format: html
290 | # .. nbinfo::
291 | # Interactive online version:
292 | # :raw-html:`
`
293 | # """
294 |
--------------------------------------------------------------------------------
/docs/documentation.md:
--------------------------------------------------------------------------------
1 | # Update documentation
2 |
3 | To rebuild the documentation, you need to install the required packages:
4 | ```
5 | pip install -r docs/requirements.txt
6 | ```
7 | And then run:
8 | ```
9 | sphinx-build -b html docs docs/build/html
10 | ```
11 | This can take a long time because it executes many of the notebooks in the documentation source;
12 | if you'd prefer to build the docs without executing the notebooks, you can run:
13 | ```
14 | sphinx-build -b html -D jupyter_execute_notebooks=off docs docs/build/html
15 | ```
16 | You can then see the generated documentation in `docs/build/html/index.html`.
17 |
18 |
19 | ### Editing ipynb
20 |
21 | To edit notebooks in the Colab interface,
22 | open Colab and `Upload` the notebook from your local repo.
23 | Update it as needed, `Run all cells`, then `Download ipynb` (for editing and running in Colab you will need to add
24 | `!pip install git+https://github.com/mfinzi/equivariant-MLP.git`).
25 | You may want to test that it executes properly, using `sphinx-build` as explained above.
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. EMLP documentation master file, created by
2 | sphinx-quickstart on Mon Feb 22 18:41:05 2021.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 |
7 | EMLP reference documentation
8 | ============================
9 |
10 | A type system for the automated construction of equivariant layers.
11 | EMLP is designed to make constructing equivariant layers with different matrix groups
12 | and representations an easy task, and one that does not require knowledge of analytic solutions.
13 |
14 |
15 | .. toctree::
16 | :maxdepth: 1
17 | :caption: Getting Started
18 |
19 | notebooks/1quickstart.ipynb
20 | notebooks/2building_a_model.ipynb
21 | notebooks/3new_groups.ipynb
22 |
23 | .. toctree::
24 | :maxdepth: 1
25 | :caption: Advanced Features
26 |
27 | notebooks/4new_representations.ipynb
28 | notebooks/5mixed_tensors.ipynb
29 | notebooks/6multilinear_maps.ipynb
30 |
31 | .. toctree::
32 | :maxdepth: 1
33 | :caption: Cross Platform Support
34 |
35 | notebooks/pytorch_support.ipynb
36 | notebooks/haiku_support.ipynb
37 | notebooks/flax_support.ipynb
38 |
39 | .. toctree::
40 | :maxdepth: 1
41 | :caption: Examples
42 |
43 | .. toctree::
44 | :glob:
45 | :maxdepth: 1
46 | :caption: API Reference
47 |
48 | package/emlp.groups
49 | package/emlp.reps
50 | package/emlp.nn
51 |
52 | .. toctree::
53 | :maxdepth: 1
54 | :caption: Notes
55 |
56 | testing.md
57 | inner_workings.md
58 | CHANGELOG.md
59 |
60 | .. toctree::
61 | :maxdepth: 2
62 | :caption: Developer Documentation
63 |
64 | documentation.md
65 |
66 | Indices and tables
67 | ==================
68 |
69 | * :ref:`genindex`
70 | * :ref:`modindex`
71 |
--------------------------------------------------------------------------------
/docs/inner_workings.md:
--------------------------------------------------------------------------------
1 | # Inner Workings
2 |
3 | It's a mystery, even to me
--------------------------------------------------------------------------------
/docs/manual_make.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # jupytext --sync ./notebooks/*
3 | ./notebooks/merge_nbs.sh
4 | sphinx-build -b html . ./build/html
--------------------------------------------------------------------------------
/docs/notebooks/5mixed_tensors.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Combining Representations from Different Groups (experimental)\n",
8 | "## Direct Product groups"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 |     "It is possible to combine representations from different groups. These kinds of representations are relevant when there are multiple structures in the data. For example, a point cloud is a set of vectors which transforms both under permutations and rotations of the vectors. We can formalize this as $V_{S_n}\\otimes V_{SO(3)}$ or in tensor notation $T_1^{S_n}\\otimes T_1^{SO(3)}$. While this object could be expressed as the representation of the product group $V_{S_n \\times SO(3)}$, other objects like $T_k^{S_n}\\otimes T_j^{SO(3)}$ cannot be expressed so easily.\n",
16 | "\n",
17 | "Nevertheless, we can calculate the symmetric bases for these objects. For example, maps from vector edge features $T_1^{SO(3)}\\otimes T_2^{S_n}$ to matrix node features $T_2^{SO(3)}\\otimes T_1^{S_n}$ can be computed:"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 1,
23 | "metadata": {},
24 | "outputs": [
25 | {
26 | "name": "stdout",
27 | "output_type": "stream",
28 | "text": [
29 | "V_SO(3)⊗V²_S(4) --> V²_SO(3)⊗V_S(4)\n"
30 | ]
31 | }
32 | ],
33 | "source": [
34 | "from emlp.groups import *\n",
35 | "from emlp.reps import T,vis,V,Scalar\n",
36 | "\n",
37 | "repin,repout = T(1)(SO(3))*T(2)(S(4)),T(2)(SO(3))*T(1)(S(4))\n",
38 | "print(repin,\"-->\", repout)"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 7,
44 | "metadata": {},
45 | "outputs": [
46 | {
47 | "data": {
48 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAATAAAADnCAYAAACZtwrQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAH/UlEQVR4nO3dMYhdRRsG4JPEIoKtIBjcbCxFLJU0AUurkCZiF7AIIqwhhUgKEQuxCMs2VoJWkjRilVKwESyDWGl2jdhZplkCm7UQxGLmmLnMuee+Z5+nnOueuZvL//4D7879Th0fHw8AiU7P/QYAViXAgFgCDIglwIBYAgyI9czYi9t7t1WUwKwOdm6eqr3mBAbEEmBALAEGxBJgQCwBBsQSYEAsAQbEEmBALAEGxBJgQCwBBsQSYEAsAQbEEmBALAEGxBJgQCwBBsQSYEAsAQbEEmBALAEGxBJgQCwBBsQanQtZ1WtaZG3aW+351elwncy17yp7+wzWsu+z5x8V11+69bjLtvtvP199bevjH4vrf350sWmPo7Pl9TOH5fVzn7Xt2/r8mtq+w079Z5zAgFgCDIglwIBYAgyIJcCAWOMt5NSN0NTN2zraw6n5DGb14pVfiuv7n5Qbudbm7cKdv6qv/fHtK8X1rVv1nyl5eLncdNbaw1a137lXOznGCQyIJcCAWAIMiCXAgFgCDIi12l3IXqa+h9erMZvzjuTUfAYr6dW81RrCYRiGc1fKdwNbG9Ct79pay9b2s/Y7tP4brcIJDIglwIBYAgyIJcCAWAIMiDVvCzm1Xo3Z1N+Kuq495pDyGXRqM9dxL7B1j7Gms6S1/Ry7z9mi1n6OcQIDYgkwIJYAA2IJMCCWAANinTo+rtc423u3yy/ONZMw5fkd93juYfn/Yz59/+suz9/54Z3i+sFbXxbXX/7+WtPznxyeKa6fPntUXH/w5ldN+7Y+v6a6793rTc+hv4Odm9X/pTmBAbEEGBBLgAGxBBgQS4ABsfrOhWxt3pb6/I573P/wi+L69r13i+ut7dvepW+K6699/l5xfbex/bzx09Xieq09bFX7fXu1k2w2JzAglgADYgkwIJYAA2IJMCCWAANirfaV0kv984eel78n3qPXnw/U/szhQac/39h9/W5xvab1zzdq77/134dMTmBALAEGxBJgQCwBBsQSYECsvoNt09vDTWwnG019ubn1+bWWsKa1/axdRm9Vaz+H80+6PJ9pOIEBsQQYEEuAAbEEGBBLgAGxsgbbdmrqqubad2zvmnW8J4bnfp9nsPAwGC7873Ne+NVgW2B5BBgQS4ABsQQYEEuAAbH6DrZtNfXdwyU0dXP9DnM2shtkrsHCw2C48FPt3e1JAGsmwIBYAgyIJcCAWAIMiNX3G1lbTX0XsldreRIbOe3nqKnncg6D2ZxPwwkMiCXAgFgCDIglwIBYAgyINW8LObVerWXPuZA1G9ayzWbT2s9G67gXeNJmc/68W/8ZJzAglgADYgkwIJYAA2IJMCCWuZCbsO8qe/sMlrEv/+tg56a5kMDyCDAglgADYgkwIJYAA2KZC7npfAZQ5QQGxBJgQCwBBsQSYEAsAQbEMhfyaZ6/5HtyPgOCOYEBsQQYEEuAAbEEGBBLgAGxzIX8r02cC9lzjzmkfAbazEhOYEAsAQbEEmBALAEGxBJgQKzRFvLZ84+K6y/detxl8/23ny+ub338Y3H9z48uNj3/6Gx5/cxhef3cZ2371p4/tkdNbe/fdt9oe1Avc7VyrfvO9T7X0X6etBmZK/ybOoEBsQQYEEuAAbEEGBBLgAGxRlvIF6/8Ulzf/6TcyrU2bxfu/FVc/+PbV4rrW7fK/33Nw8vllnOsPWwx9vu2NqBsqLnmco7tYTbnv5zAgFgCDIglwIBYAgyIJcCAWCt9I2utSWtt3mot4bkr5XuBre3n1ndtrWVr+1l7/2PvqVcDykKs4y7kgmdzOoEBsQQYEEuAAbEEGBBLgAGxus6F7NVO9nr+WEtY0tp+1u5yrqLWgA4Pu23BSbXg2ZxOYEAsAQbEEmBALAEGxBJgQKxTx8f1CmF773b5xalbh/Tnd9zDbM7xfXs13Bs3l3MYNm8+40z7Hnxws7qzExgQS4ABsQQYEEuAAbEEGBBr/C7k1Heilvr8jnuYzTlu6vu3s34L6YLnOfbiBAbEEmBALAEGxBJgQCwBBsRa7RtZl9oe9rw7OfEeZnP+o/b+FzGXc64GNKj9dAIDYgkwIJYAA2IJMCCWAANidZ0LGd8ebmI72chsztVU53L+3uXxWYLaTycwIJYAA2IJMCCWAANiCTAglgADYo0Otn31xm7xxU/f/7rL5js/vFNcP3jry+L6y99fa3r+k8MzxfXTZ4+K6w/e/Kpp39rzx/aoqe5993rTc7rZsOGmi92X/3WwY7AtsEACDIglwIBYAgyIJcCAWKOXue9/+EVxffveu8X11uZt79I3xfXXPn+vuL7b2H7e+OlqcX2sPWwx9vu2NqBVQRdrYd2cwIBYAgyIJcCAWAIMiCXAgFgrfaV0rUlrbd5qLeGDTu3n7ut3i+s1re1n7f2PvadeDejktJ8EcAIDYgkwIJYAA2IJMCCWAANidR1s26ud7PX8sZawpLX9rN3lXEWtAR22nnTbI4L2kwZOYEAsAQbEEmBALAEGxBJgQKzRuZDbe7fLL9Z/pHH3yvpJnA1oLiEUmQsJLJIAA2IJMCCWAANiCTAg1vhdyKmbsdbnt7afGjxYNCcwIJYAA2IJMCCWAANiCTAgVtdvZG029V3IXq2le4qwkZzAgFgCDIglwIBYAgyIJcCAWPO2kFPr1Vr2ups5RqMJzZzAgFgCDIglwIBYAgyIJcCAWKNzIQE2mRMYEEuAAbEEGBBLgAGxBBgQS4ABsf4G+uFQ5DwSeDQAAAAASUVORK5CYII=\n",
49 | "text/plain": [
50 | ""
51 | ]
52 | },
53 | "metadata": {
54 | "needs_background": "light"
55 | },
56 | "output_type": "display_data"
57 | }
58 | ],
59 | "source": [
60 | "vis(repin,repout,cluster=False)"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 |     "Or perhaps you would like equivariant maps from two sequences of sets and a matrix (under 3D rotations) to itself. You can go wild."
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 2,
73 | "metadata": {},
74 | "outputs": [
75 | {
76 | "name": "stdout",
77 | "output_type": "stream",
78 | "text": [
79 | "Rep: V²+2V_S(4)⊗V_Z(3)\n",
80 | "Linear maps: V⁴+4V_S(4)⊗V_Z(3)⊗V²_SO(3)+4V²_S(4)⊗V²_Z(3)\n"
81 | ]
82 | }
83 | ],
84 | "source": [
85 | "rep = 2*T(1)(Z(3))*T(1)(S(4))+T(2)(SO(3))\n",
86 | "print(f\"Rep: {rep}\")\n",
87 | "print(f\"Linear maps: {rep>>rep}\")"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 7,
93 | "metadata": {
94 | "scrolled": true
95 | },
96 | "outputs": [
97 | {
98 | "data": {
99 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAI+UlEQVR4nO3dX8jfVR0H8LMhudrQLP/xYG4p1Voxc2GR0KVYEYWVrfBKKCo0NU0iqLCsJFdGGUQs6EqSjCKoXOpFNy2a2DJLZur8Uww2ZyPZTENaF7vo5nvOw/fsu+/ee57X6/J7OJ4H5psD55zf57Pi8OHDBciz8nj/AcAw4YRQwgmhhBNCCSeEOqk1+KHtn6we5e58+jWT/RGbNzxQHbv/2bWTrVNKKU/sGP67X/ed3ZOuc+Pv7q2Ofe6mT0y61qGFFYPfz7r/xUnXedn+Q9WxXZ86ZdK1NmzZN/j9pd1PTrrOSeetq45NuVZrnbsf2zL4D2jnhFDCCaGEE0IJJ4QSTgglnBCqeZXSui658Ny/j55T07ouuejVT3XNq6ldmTx67Xmj57S0rku+cdMPRs9pqV2Z7L3o5NFzWlrXJeu//9zoObTZOSGUcEIo4YRQwgmhhBNCNU9rW2qnsrVT3Naclp6T3ClPcUuZ7yS3dorbmtPSOpGd6yS3dorbmsMRdk4IJZwQSjghlHBCKOGEUMIJoVa0Kr6ff+fXJi0HX7tm+deLL59ymeZj+Z3vWph0rdo1y2kPT7pM85rl5o9fOelatWuWQ+temnSd1jXLyoMvDH5XQwg47oQTQgknhBJOCCWcEKr58H3FU/VT1J4H34cqy536039X5/Q8lt+99/Tq2OZtw9XleyvLv7YMn0Cv/PGa6pyeB99X3vOx6tiPtv5w8Htv2ZPVe2qH9PX/XXoey//n9NXVsVWV09rlxM4JoYQTQgknhBJOCCWcEEo4IVTzKqWnrk7PFUtPZfnF5tXUrkymriw/Z3X05JpEUzfwXU7snBBKOCGUcEIo4YRQwgmhhBNCdbdjmKsJ7VJs4DtX64KeBr6LzauZq4HvcmLnhFDCCaGEE0IJJ4QSTgjVrPj+7oWrJ634XjvJPby2XkOoR+ux/FzV5e/4/TsmXad1kvvPC06bdK3aSe5cleVLKWXtXXsGv6v4Dhx3wgmhhBNCCSeEEk4IJZwQqvnw/cJtw8fZpfQ9+K61LnhiR/2h+pRtH0qpt37oeSxfSr31w+on52td8KoHDwx+761JVGv9UGv7UErfY/l62wdKsXNCLOGEUMIJoYQTQgknhGqe1vaU7ug5xe2pLL/YvJraqezUleXnrI6eXPakt4Evdk6IJZwQSjghlHBCKOGEUMIJoborvs/VhHYpNvCdqzp6TwPfxebVzNXAdzmxc0Io4YRQwgmhhBNCCSeEEk4I1X2VUtPzS5adZaFrrZ4GvqWMb/3Qc83ybFk3ep2eX7Ic8dLotXquWXraPvQ28N1yyftGr7XU2DkhlHBCKOGEUMIJoYQTQjWb52687tvVwbN2HJzsj3hs8yuqY2e8Yf9k65RSyqYz/jH4/YHvXjjpOr/6+jerY++//vpJ1zrl138Z/L7re+snXeeWi39WHdvyrY9MutbzZw/2ky3nfmX7pOs8/aWLq2NTrtVa55GbPqN5LpxIhBNCCSeEEk4IJZwQSjghVPPhe+u6ZO/b1oyeU9O6LnnmkeHmtIvNq6ldmbz1mp2j57S0rkt+cdtto+e01K5M1l+9a/ScltZ1yY033Dl6Dm12TgglnBBKOCGUcEIo4YRQ3WVKaqeytVPc1pyWnpPcKU9xS5nvJLd2itua09I6kZ3rJLd2ituawxF2TgglnBBKOCGUcEIo4YRQwgmhJq/43vNYflV5vmut2pVJ67F8faSu55rlwVsvGL1Oz2P5Ukq56L5rRq/Vc82y74qNo9fpeSxfSilfvuOjo9daauycEEo4IZRwQijhhFDCCaGEE0I1r1LOv/1v1bGeX2OcfGC4u8ML286szun5JcupjauZWuuH3rYPf3zmnMHv1331J9U5Pb/GuPTmz1bH7v/icOuH3ppEz73nzYPfz7zjz9U5Pb9k+fz2D1THVo3+ry09dk4IJZwQSjghlHBCKOGEUM3T2p4H3z2nuD2P5RebV9PzWL7nJHfO6ujJNYmmbuC7nNg5IZRwQijhhFDCCaGEE0IJJ4TqriE0VxPapdjAd67WBb01iaa8Zpm67cNyYueEUMIJoYQTQgknhBJOCCWcEGrydgw9v2R5/NOv71qrp7t2T+uHnmuWV45epb91we1f+PDotXquWeZq+1BKKU/eML6dxVJj54RQwgmhhBNCCSeEEk4I1Tyt3b9puEJ7KX0PvmvV0TfNVFm+lHp1+Z7H8qXUq8uvfOjx6pypq6PfUqku31uTqFZdvlZZvpS+x/K1yvIcYeeEUMIJoYQTQgknhBJOCCWcEKp5ldLz4LvniqXnsfxi82p6Hsv3XLPM2boguSZRbwNf7JwQSzghlHBCKOGEUMIJobrLlMzVhHYpNvCdqzp6b9mTKU9yp64sv5zYOSGUcEIo4YRQwgmhhBNCCSeEmrzie89j+frlS1tPA9+e6vJ91ywvjl6ntzr6vis2jl6r55plrsrypZTyzq03jl5rqbFzQijhhFDCCaGEE0IJJ4QSTgi14vDheuuCh54+pz5I01VXjW80y//9duvWwe+XLrxl0nXe+9cD1bFfvum0Wda59o33rRj6bueEUMIJoYQTQgknhBJOCNU8rb1k5eVOa4ly8PK3V8fW3PWHwe97fr6hOmfhsoeP+m86Wvf+9y6ntXAiEU4IJZwQSjghlHBCKOGEUJPXEIJjqXZdUkr9mmXhsvqcHnNdzdg5IZRwQijhhFDCCaGEE0I5reWE0j4pHT6V7Xks3zLXY3k7J4QSTgglnBBKOCGUcEIo4YRQKr4fIyq+H50P3nrP4Pcpq7CXUspv9vypOjZldXkV32EJEU4IJZwQSjghlHBCKOGEUNoxQEXt1yytX7LUfjXT+iWLdgxwghFOCCWcEEo4IZRwQig1hFgyek5KW2qnsq2aRFNWl7dzQijhhFDCCaGEE0IJJ4QSTgjlKoUlY+o2CfWrmfENfHvaPtg5IZRwQijhhFDCCaGEE0Kp+H6MqPh+dOaq+N6qxD7lWq3K8ivPflSZEjiRCCeEEk4IJZwQSjghlHBCqOZVCnD82DkhlHBCKOGEUMIJoYQTQgknhPof2soBg+pZzXoAAAAASUVORK5CYII=\n",
100 | "text/plain": [
101 | ""
102 | ]
103 | },
104 | "metadata": {
105 | "needs_background": "light"
106 | },
107 | "output_type": "display_data"
108 | }
109 | ],
110 | "source": [
111 | "vis(rep,rep)"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 |     "The Kronecker product of the individual solutions for the different groups can be seen in these basis matrices: in the top left-hand corner we have a circulant matrix of deep set solutions (identity + background)"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 3,
124 | "metadata": {},
125 | "outputs": [
126 | {
127 | "data": {
128 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEpklEQVR4nO3cPWuedRiH4TttoNAvoHPFRYJkKmSU6l4EwZZu7dSl2MmX2ZcpkiUfIQiCe1F06BBIEVIpLmKHDsWv4NA+foHk6XLVng8cx/jccJHl5A8Zflur1WoBei686T8AOJs4IUqcECVOiBInRG2v+/jlHx+P/Cv3+P7ViTNJz+68GLlz5ebpyJ2aB89PR+7sHNwduVP057efbZ31u5cTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LWLiFMLRjs7Z+M3CkuKkwtGDw92h25U1tUmFoweHLvcOTOsmzOqoKXE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROi1i4hTLGo8GoWFdabXC+YWlV43YsKXk6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiNparVbnfvzgw+/O/7jBphYVlmVZHu1eHLtVMrWocOnx5ZE7RVOLChfe/mvrzN9HrgPjxAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUWtnSl7+8+7ITMnOwd2JM0nXbzwcuXN8/+rInZpnd16M3Lly83TkTtHPL380UwKbRJwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IWrtEsJ7X3w/soTw5N7hxJnkosJbj/4dubO3fzJyp7aosP3r7yN3nh7tjtxZlt6qgiUE2DDihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR/8sSwpTiosLUEsKU2qLC1BLCpKlVhalFBUsIsGHECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROiNmoJYcrUosKyLMu1W7fHbpVMLSo82r04cqdoalHh70+/soQAm0ScECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFq7RLCOz98PbKEcOXm6cSZpAfPT0fu7BzcHblTc/3Gw5E7x/evjtwp+u2Xzy0hwCYRJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6I2l73cWrB4OnR7sid4qLC1ILBk3uHI3dqiwpTCwZ7+ycjd5Zlc1YVvJwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiFo7UzLF3MmrmTtZb3JaZGry5HXPnXg5IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFqa7VanfvxowufnP9xg00tKizLslx6fHnsVsnUosK1W7dH7hRNLSp88/5PW2f97uWEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihKi1SwjAm+PlhChxQpQ4IUqcECVOiBInRP0H3NO3PQoCCbMAAAAASUVORK5CYII=\n",
129 | "text/plain": [
130 | ""
131 | ]
132 | },
133 | "metadata": {
134 | "needs_background": "light"
135 | },
136 | "output_type": "display_data"
137 | }
138 | ],
139 | "source": [
140 | "vis(V(Z(3))*V(S(4)),V(Z(3))*V(S(4)))"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {},
146 | "source": [
147 |     "And in the bottom left we have maps $V_{Z_3}\\otimes V_{S_4} \\rightarrow V_{SO(3)}^{\\otimes 2}$, of which the solutions are the product of $V_{Z_3}\\otimes V_{S_4} \\rightarrow \\mathbb{R}$ (the vector $\\mathbf{1}$) and $\\mathbb{R} \\rightarrow V_{SO(3)}^{\\otimes 2}$ (the flattened $I_{3\\times 3}$)"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 5,
153 | "metadata": {},
154 | "outputs": [
155 | {
156 | "data": {
157 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAATAAAADnCAYAAACZtwrQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAADZUlEQVR4nO3dsY3CQBRF0bVFEYicnC4omlJogCo82wDjYIXMXumc9CU/uhrJgZcxxg9A0frtAwD+SsCALAEDsgQMyBIwIOu0N26vq0+UwFet5+cy3Y48BOCTBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBr97+Q98vtoDMA3nts880LDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7KWMcZ03F7X+QhwgPX8XKbbkYcAfJKAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZJ32xvvldtAZAO89tvnmBQZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZyxhjOm6v63wEOMB6fi7T7chDAD5JwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7J2/wsJ8J95gQFZAgZkCRiQJWBAloABWQIGZP0C0FAeWUQoKl0AAAAASUVORK5CYII=\n",
158 | "text/plain": [
159 | ""
160 | ]
161 | },
162 | "metadata": {
163 | "needs_background": "light"
164 | },
165 | "output_type": "display_data"
166 | }
167 | ],
168 | "source": [
169 | "vis(V(Z(3))*V(S(4)),T(2)(SO(3)),False)"
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {},
175 | "source": [
176 | "## Wreath Products (coming soon)"
177 | ]
178 | },
179 | {
180 | "cell_type": "markdown",
181 | "metadata": {},
182 | "source": [
183 |     "These are all examples of tensor products of representations, which are in turn a kind of representation on the direct product group $G=G_1\\times G_2$. There are other ways of combining groups however, and for many hierarchical structures there is a larger group of symmetries $G_1\\wr G_2$ that contains $G_1 \\times G_2$ but also other elements. These so-called wreath products are described in the very nice paper [Equivariant Maps for Hierarchical Structures](https://arxiv.org/pdf/2006.03627.pdf). Support for the wreath product of representations is not yet implemented, but if this is something that would be useful to you, send me an email."
184 | ]
185 | }
186 | ],
187 | "metadata": {
188 | "kernelspec": {
189 | "display_name": "Python 3",
190 | "language": "python",
191 | "name": "python3"
192 | },
193 | "language_info": {
194 | "codemirror_mode": {
195 | "name": "ipython",
196 | "version": 3
197 | },
198 | "file_extension": ".py",
199 | "mimetype": "text/x-python",
200 | "name": "python",
201 | "nbconvert_exporter": "python",
202 | "pygments_lexer": "ipython3",
203 | "version": "3.8.5"
204 | }
205 | },
206 | "nbformat": 4,
207 | "nbformat_minor": 2
208 | }
209 |
--------------------------------------------------------------------------------
/docs/notebooks/6multilinear_maps.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Multilinear Maps"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Our codebase extends trivially to multilinear maps, since these maps are in fact just linear maps in disguise.\n",
15 | "\n",
16 |     "If we have a sequence of representations $R_1$, $R_2$, $R_3$ for example, we can write the (bi)linear maps $R_1\\rightarrow R_2\\rightarrow R_3$. This way of thinking about maps of multiple variables, borrowed from programming languages and curried functions, is very powerful."
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "We can think of such an object $R_1\\rightarrow R_2\\rightarrow R_3$ either as $R_1 \\rightarrow (R_2\\rightarrow R_3)$: a linear map from $R_1$ to linear maps from $R_2$ to $R_3$ or as\n",
24 | "$(R_1\\times R_2) \\rightarrow R_3$: a bilinear map from $R_1$ and $R_2$ to $R_3$. Since linear maps from one representation to another are just another representation in our type system, you can use this way of thinking to find the equivariant solutions to arbitrary multilinear maps."
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "For example, we can get the bilinear $SO(4)$ equivariant maps $(R_1\\times R_2) \\rightarrow R_3$ with the code below."
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 1,
37 | "metadata": {},
38 | "outputs": [
39 | {
40 | "name": "stdout",
41 | "output_type": "stream",
42 | "text": [
43 | "(2940, 27)\n"
44 | ]
45 | }
46 | ],
47 | "source": [
48 | "from emlp.groups import SO,rel_err\n",
49 | "from emlp.reps import V\n",
50 | "\n",
51 | "G = SO(4)\n",
52 | "W = V(G)\n",
53 | "R1 = 3*W+W**2 # some example representations\n",
54 | "R2 = W.T+W**0\n",
55 | "R3 = W**0 +W**2 +W\n",
56 | "\n",
57 | "Q = (R1>>(R2>>R3)).equivariant_basis()\n",
58 | "print(Q.shape)"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "And we can verify that these multilinear solutions are indeed equivariant"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 2,
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "data": {
75 | "text/plain": [
76 | "DeviceArray(1.2422272e-07, dtype=float32)"
77 | ]
78 | },
79 | "execution_count": 2,
80 | "metadata": {},
81 | "output_type": "execute_result"
82 | }
83 | ],
84 | "source": [
85 | "import numpy as np\n",
86 | "\n",
87 | "example_map = (Q@np.random.randn(Q.shape[-1]))\n",
88 | "example_map = example_map.reshape(R3.size(),R2.size(),R1.size())\n",
89 | "\n",
90 | "x1 = np.random.randn(R1.size())\n",
91 | "x2 = np.random.randn(R2.size())\n",
92 | "g = G.sample()\n",
93 | "\n",
94 | "out1 = np.einsum(\"ijk,j,k\",example_map,R2.rho(g)@x2,R1.rho(g)@x1)\n",
95 | "out2 = R3.rho(g)@np.einsum(\"ijk,j,k\",example_map,x2,x1)\n",
96 | "rel_err(out1,out2)"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "Note that the output mapping is of shape $(\\mathrm{dim}(R_3),\\mathrm{dim}(R_2),\\mathrm{dim}(R_1))$\n",
104 | "with the inputs to the right as you would expect with a matrix. \n",
105 | "\n",
106 |     "Note the parentheses in the expression `(R1>>(R2>>R3))`, since the python `>>` associates to the left.\n",
107 |     "The notation $R_1\\rightarrow R_2 \\rightarrow R_3$ or `(R1>>(R2>>R3))` can be a bit confusing since the inputs are on the right. It can be easier in this context to instead reverse the arrows and express the same object as $R_3\\leftarrow R_2\\leftarrow R_1$ or `R3<<R2<<R1`. You can use `<<` in place of `>>` wherever you like, and it is usually more intuitive."
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 3,
115 | "metadata": {},
116 | "outputs": [
117 | {
118 | "data": {
119 | "text/plain": [
120 | "True"
121 | ]
122 | },
123 | "execution_count": 3,
124 | "metadata": {},
125 | "output_type": "execute_result"
126 | }
127 | ],
128 | "source": [
129 | "R3<<R2<<R1 == (R1>>(R2>>R3))"
130 | ]
131 | }
132 | ],
133 | "metadata": {
134 | "kernelspec": {
135 | "display_name": "Python 3",
136 | "language": "python",
137 | "name": "python3"
138 | },
139 | "language_info": {
140 | "codemirror_mode": {
141 | "name": "ipython",
142 | "version": 3
143 | },
144 | "file_extension": ".py",
145 | "mimetype": "text/x-python",
146 | "name": "python",
147 | "nbconvert_exporter": "python",
148 | "pygments_lexer": "ipython3",
149 | "version": "3.8.5"
150 | }
151 | },
152 | "nbformat": 4,
153 | "nbformat_minor": 4
154 | }
155 |
--------------------------------------------------------------------------------
/docs/notebooks/_colab_preamble.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "First install the repo and requirements."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git"
17 | ]
18 | }
19 | ],
20 | "metadata": {
21 | "kernelspec": {
22 | "display_name": "Python 3",
23 | "language": "python",
24 | "name": "python3"
25 | },
26 | "language_info": {
27 | "codemirror_mode": {
28 | "name": "ipython",
29 | "version": 3
30 | },
31 | "file_extension": ".py",
32 | "mimetype": "text/x-python",
33 | "name": "python",
34 | "nbconvert_exporter": "python",
35 | "pygments_lexer": "ipython3",
36 | "version": "3.8.5"
37 | },
38 | "accelerator": "GPU"
39 | },
40 | "nbformat": 4,
41 | "nbformat_minor": 4
42 | }
43 |
--------------------------------------------------------------------------------
/docs/notebooks/colabs/5mixed_tensors.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "First install the repo and requirements."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "# Combining Representations from Different Groups (experimental)\n",
24 | "## Direct Product groups"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "It is possible to combine representations from different groups. These kinds of representations are relevant when there are multiple structures in the data. For example, a point cloud is a set of vectors which transforms under both permutations of the set and rotations of the vectors. We can formalize this as $V_{S_n}\otimes V_{SO(3)}$ or in tensor notation $T_1^{S_n}\otimes T_1^{SO(3)}$. While this object could be expressed as the representation of the product group $V_{S_n \times SO(3)}$, other objects like $T_k^{S_n}\otimes T_j^{SO(3)}$ cannot be expressed so easily.\n",
32 | "\n",
33 | "Nevertheless, we can calculate the symmetric bases for these objects. For example, maps from vector edge features $T_1^{SO(3)}\\otimes T_2^{S_n}$ to matrix node features $T_2^{SO(3)}\\otimes T_1^{S_n}$ can be computed:"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 1,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "name": "stdout",
43 | "output_type": "stream",
44 | "text": [
45 | "V_SO(3)⊗V²_S(4) --> V²_SO(3)⊗V_S(4)\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "from emlp.groups import *\n",
51 | "from emlp.reps import T,vis,V,Scalar\n",
52 | "\n",
53 | "repin,repout = T(1)(SO(3))*T(2)(S(4)),T(2)(SO(3))*T(1)(S(4))\n",
54 | "print(repin,\"-->\", repout)"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 7,
60 | "metadata": {},
61 | "outputs": [
62 | {
63 | "data": {
64 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAATAAAADnCAYAAACZtwrQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAH/UlEQVR4nO3dMYhdRRsG4JPEIoKtIBjcbCxFLJU0AUurkCZiF7AIIqwhhUgKEQuxCMs2VoJWkjRilVKwESyDWGl2jdhZplkCm7UQxGLmmLnMuee+Z5+nnOueuZvL//4D7879Th0fHw8AiU7P/QYAViXAgFgCDIglwIBYAgyI9czYi9t7t1WUwKwOdm6eqr3mBAbEEmBALAEGxBJgQCwBBsQSYEAsAQbEEmBALAEGxBJgQCwBBsQSYEAsAQbEEmBALAEGxBJgQCwBBsQSYEAsAQbEEmBALAEGxBJgQCwBBsQanQtZ1WtaZG3aW+351elwncy17yp7+wzWsu+z5x8V11+69bjLtvtvP199bevjH4vrf350sWmPo7Pl9TOH5fVzn7Xt2/r8mtq+w079Z5zAgFgCDIglwIBYAgyIJcCAWOMt5NSN0NTN2zraw6n5DGb14pVfiuv7n5Qbudbm7cKdv6qv/fHtK8X1rVv1nyl5eLncdNbaw1a137lXOznGCQyIJcCAWAIMiCXAgFgCDIi12l3IXqa+h9erMZvzjuTUfAYr6dW81RrCYRiGc1fKdwNbG9Ct79pay9b2s/Y7tP4brcIJDIglwIBYAgyIJcCAWAIMiDVvCzm1Xo3Z1N+Kuq495pDyGXRqM9dxL7B1j7Gms6S1/Ry7z9mi1n6OcQIDYgkwIJYAA2IJMCCWAANinTo+rtc423u3yy/ONZMw5fkd93juYfn/Yz59/+suz9/54Z3i+sFbXxbXX/7+WtPznxyeKa6fPntUXH/w5ldN+7Y+v6a6793rTc+hv4Odm9X/pTmBAbEEGBBLgAGxBBgQS4ABsfrOhWxt3pb6/I573P/wi+L69r13i+ut7dvepW+K6699/l5xfbex/bzx09Xieq09bFX7fXu1k2w2JzAglgADYgkwIJYAA2IJMCCWAANirfaV0kv984eel78n3qPXnw/U/szhQac/39h9/W5xvab1zzdq77/134dMTmBALAEGxBJgQCwBBsQSYECsvoNt09vDTWwnG019ubn1+bWWsKa1/axdRm9Vaz+H80+6PJ9pOIEBsQQYEEuAAbEEGBBLgAGxsgbbdmrqqubad2zvmnW8J4bnfp9nsPAwGC7873Ne+NVgW2B5BBgQS4ABsQQYEEuAAbH6DrZtNfXdwyU0dXP9DnM2shtkrsHCw2C48FPt3e1JAGsmwIBYAgyIJcCAWAIMiNX3G1lbTX0XsldreRIbOe3nqKnncg6D2ZxPwwkMiCXAgFgCDIglwIBYAgyINW8LObVerWXPuZA1G9ayzWbT2s9G67gXeNJmc/68W/8ZJzAglgADYgkwIJYAA2IJMCCWuZCbsO8qe/sMlrEv/+tg56a5kMDyCDAglgADYgkwIJYAA2KZC7npfAZQ5QQGxBJgQCwBBsQSYEAsAQbEMhfyaZ6/5HtyPgOCOYEBsQQYEEuAAbEEGBBLgAGxzIX8r02cC9lzjzmkfAbazEhOYEAsAQbEEmBALAEGxBJgQKzRFvLZ84+K6y/detxl8/23ny+ub338Y3H9z48uNj3/6Gx5/cxhef3cZ2371p4/tkdNbe/fdt9oe1Avc7VyrfvO9T7X0X6etBmZK/ybOoEBsQQYEEuAAbEEGBBLgAGxRlvIF6/8Ulzf/6TcyrU2bxfu/FVc/+PbV4rrW7fK/33Nw8vllnOsPWwx9vu2NqBsqLnmco7tYTbnv5zAgFgCDIglwIBYAgyIJcCAWCt9I2utSWtt3mot4bkr5XuBre3n1ndtrWVr+1l7/2PvqVcDykKs4y7kgmdzOoEBsQQYEEuAAbEEGBBLgAGxus6F7NVO9nr+WEtY0tp+1u5yrqLWgA4Pu23BSbXg2ZxOYEAsAQbEEmBALAEGxBJgQKxTx8f1CmF773b5xalbh/Tnd9zDbM7xfXs13Bs3l3MYNm8+40z7Hnxws7qzExgQS4ABsQQYEEuAAbEEGBBr/C7k1Heilvr8jnuYzTlu6vu3s34L6YLnOfbiBAbEEmBALAEGxBJgQCwBBsRa7RtZl9oe9rw7OfEeZnP+o/b+FzGXc64GNKj9dAIDYgkwIJYAA2IJMCCWAANidZ0LGd8ebmI72chsztVU53L+3uXxWYLaTycwIJYAA2IJMCCWAANiCTAglgADYo0Otn31xm7xxU/f/7rL5js/vFNcP3jry+L6y99fa3r+k8MzxfXTZ4+K6w/e/Kpp39rzx/aoqe5993rTc7rZsOGmi92X/3WwY7AtsEACDIglwIBYAgyIJcCAWKOXue9/+EVxffveu8X11uZt79I3xfXXPn+vuL7b2H7e+OlqcX2sPWwx9vu2NqBVQRdrYd2cwIBYAgyIJcCAWAIMiCXAgFgrfaV0rUlrbd5qLeGDTu3n7ut3i+s1re1n7f2PvadeDejktJ8EcAIDYgkwIJYAA2IJMCCWAANidR1s26ud7PX8sZawpLX9rN3lXEWtAR22nnTbI4L2kwZOYEAsAQbEEmBALAEGxBJgQKzRuZDbe7fLL9Z/pHH3yvpJnA1oLiEUmQsJLJIAA2IJMCCWAANiCTAg1vhdyKmbsdbnt7afGjxYNCcwIJYAA2IJMCCWAANiCTAgVtdvZG029V3IXq2le4qwkZzAgFgCDIglwIBYAgyIJcCAWPO2kFPr1Vr2ups5RqMJzZzAgFgCDIglwIBYAgyIJcCAWKNzIQE2mRMYEEuAAbEEGBBLgAGxBBgQS4ABsf4G+uFQ5DwSeDQAAAAASUVORK5CYII=\n",
65 | "text/plain": [
66 | ""
67 | ]
68 | },
69 | "metadata": {
70 | "needs_background": "light"
71 | },
72 | "output_type": "display_data"
73 | }
74 | ],
75 | "source": [
76 | "vis(repin,repout,cluster=False)"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "Or perhaps you would like equivariant maps from two sequences of sets and a matrix (transforming under 3D rotations) back to the same type. You can go wild."
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 2,
89 | "metadata": {},
90 | "outputs": [
91 | {
92 | "name": "stdout",
93 | "output_type": "stream",
94 | "text": [
95 | "Rep: V²+2V_S(4)⊗V_Z(3)\n",
96 | "Linear maps: V⁴+4V_S(4)⊗V_Z(3)⊗V²_SO(3)+4V²_S(4)⊗V²_Z(3)\n"
97 | ]
98 | }
99 | ],
100 | "source": [
101 | "rep = 2*T(1)(Z(3))*T(1)(S(4))+T(2)(SO(3))\n",
102 | "print(f\"Rep: {rep}\")\n",
103 | "print(f\"Linear maps: {rep>>rep}\")"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 7,
109 | "metadata": {
110 | "scrolled": true
111 | },
112 | "outputs": [
113 | {
114 | "data": {
115 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAI+UlEQVR4nO3dX8jfVR0H8LMhudrQLP/xYG4p1Voxc2GR0KVYEYWVrfBKKCo0NU0iqLCsJFdGGUQs6EqSjCKoXOpFNy2a2DJLZur8Uww2ZyPZTENaF7vo5nvOw/fsu+/ee57X6/J7OJ4H5psD55zf57Pi8OHDBciz8nj/AcAw4YRQwgmhhBNCCSeEOqk1+KHtn6we5e58+jWT/RGbNzxQHbv/2bWTrVNKKU/sGP67X/ed3ZOuc+Pv7q2Ofe6mT0y61qGFFYPfz7r/xUnXedn+Q9WxXZ86ZdK1NmzZN/j9pd1PTrrOSeetq45NuVZrnbsf2zL4D2jnhFDCCaGEE0IJJ4QSTgglnBCqeZXSui658Ny/j55T07ouuejVT3XNq6ldmTx67Xmj57S0rku+cdMPRs9pqV2Z7L3o5NFzWlrXJeu//9zoObTZOSGUcEIo4YRQwgmhhBNCNU9rW2qnsrVT3Naclp6T3ClPcUuZ7yS3dorbmtPSOpGd6yS3dorbmsMRdk4IJZwQSjghlHBCKOGEUMIJoVa0Kr6ff+fXJi0HX7tm+deLL59ymeZj+Z3vWph0rdo1y2kPT7pM85rl5o9fOelatWuWQ+temnSd1jXLyoMvDH5XQwg47oQTQgknhBJOCCWcEKr58H3FU/VT1J4H34cqy536039X5/Q8lt+99/Tq2OZtw9XleyvLv7YMn0Cv/PGa6pyeB99X3vOx6tiPtv5w8Htv2ZPVe2qH9PX/XXoey//n9NXVsVWV09rlxM4JoYQTQgknhBJOCCWcEEo4IVTzKqWnrk7PFUtPZfnF5tXUrkymriw/Z3X05JpEUzfwXU7snBBKOCGUcEIo4YRQwgmhhBNCdbdjmKsJ7VJs4DtX64KeBr6LzauZq4HvcmLnhFDCCaGEE0IJJ4QSTgjVrPj+7oWrJ634XjvJPby2XkOoR+ux/FzV5e/4/TsmXad1kvvPC06bdK3aSe5cleVLKWXtXXsGv6v4Dhx3wgmhhBNCCSeEEk4IJZwQqvnw/cJtw8fZpfQ9+K61LnhiR/2h+pRtH0qpt37oeSxfSr31w+on52td8KoHDwx+761JVGv9UGv7UErfY/l62wdKsXNCLOGEUMIJoYQTQgknhGqe1vaU7ug5xe2pLL/YvJraqezUleXnrI6eXPakt4Evdk6IJZwQSjghlHBCKOGEUMIJoborvs/VhHYpNvCdqzp6TwPfxebVzNXAdzmxc0Io4YRQwgmhhBNCCSeEEk4I1X2VUtPzS5adZaFrrZ4GvqWMb/3Qc83ybFk3ep2eX7Ic8dLotXquWXraPvQ28N1yyftGr7XU2DkhlHBCKOGEUMIJoYQTQjWb52687tvVwbN2HJzsj3hs8yuqY2e8Yf9k65RSyqYz/jH4/YHvXjjpOr/6+jerY++//vpJ1zrl138Z/L7re+snXeeWi39WHdvyrY9MutbzZw/2ky3nfmX7pOs8/aWLq2NTrtVa55GbPqN5LpxIhBNCCSeEEk4IJZwQSjghVPPhe+u6ZO/b1oyeU9O6LnnmkeHmtIvNq6ldmbz1mp2j57S0rkt+cdtto+e01K5M1l+9a/ScltZ1yY033Dl6Dm12TgglnBBKOCGUcEIo4YRQ3WVKaqeytVPc1pyWnpPcKU9xS5nvJLd2itua09I6kZ3rJLd2ituawxF2TgglnBBKOCGUcEIo4YRQwgmhJq/43vNYflV5vmut2pVJ67F8faSu55rlwVsvGL1Oz2P5Ukq56L5rRq/Vc82y74qNo9fpeSxfSilfvuOjo9daauycEEo4IZRwQijhhFDCCaGEE0I1r1LOv/1v1bGeX2OcfGC4u8ML286szun5JcupjauZWuuH3rYPf3zmnMHv1331J9U5Pb/GuPTmz1bH7v/icOuH3ppEz73nzYPfz7zjz9U5Pb9k+fz2D1THVo3+ry09dk4IJZwQSjghlHBCKOGEUM3T2p4H3z2nuD2P5RebV9PzWL7nJHfO6ujJNYmmbuC7nNg5IZRwQijhhFDCCaGEE0IJJ4TqriE0VxPapdjAd67WBb01iaa8Zpm67cNyYueEUMIJoYQTQgknhBJOCCWcEGrydgw9v2R5/NOv71qrp7t2T+uHnmuWV45epb91we1f+PDotXquWeZq+1BKKU/eML6dxVJj54RQwgmhhBNCCSeEEk4I1Tyt3b9puEJ7KX0PvmvV0TfNVFm+lHp1+Z7H8qXUq8uvfOjx6pypq6PfUqku31uTqFZdvlZZvpS+x/K1yvIcYeeEUMIJoYQTQgknhBJOCCWcEKp5ldLz4LvniqXnsfxi82p6Hsv3XLPM2boguSZRbwNf7JwQSzghlHBCKOGEUMIJobrLlMzVhHYpNvCdqzp6b9mTKU9yp64sv5zYOSGUcEIo4YRQwgmhhBNCCSeEmrzie89j+frlS1tPA9+e6vJ91ywvjl6ntzr6vis2jl6r55plrsrypZTyzq03jl5rqbFzQijhhFDCCaGEE0IJJ4QSTgi14vDheuuCh54+pz5I01VXjW80y//9duvWwe+XLrxl0nXe+9cD1bFfvum0Wda59o33rRj6bueEUMIJoYQTQgknhBJOCNU8rb1k5eVOa4ly8PK3V8fW3PWHwe97fr6hOmfhsoeP+m86Wvf+9y6ntXAiEU4IJZwQSjghlHBCKOGEUJPXEIJjqXZdUkr9mmXhsvqcHnNdzdg5IZRwQijhhFDCCaGEE0I5reWE0j4pHT6V7Xks3zLXY3k7J4QSTgglnBBKOCGUcEIo4YRQKr4fIyq+H50P3nrP4Pcpq7CXUspv9vypOjZldXkV32EJEU4IJZwQSjghlHBCKOGEUNoxQEXt1yytX7LUfjXT+iWLdgxwghFOCCWcEEo4IZRwQig1hFgyek5KW2qnsq2aRFNWl7dzQijhhFDCCaGEE0IJJ4QSTgjlKoUlY+o2CfWrmfENfHvaPtg5IZRwQijhhFDCCaGEE0Kp+H6MqPh+dOaq+N6qxD7lWq3K8ivPflSZEjiRCCeEEk4IJZwQSjghlHBCqOZVCnD82DkhlHBCKOGEUMIJoYQTQgknhPof2soBg+pZzXoAAAAASUVORK5CYII=\n",
116 | "text/plain": [
117 | ""
118 | ]
119 | },
120 | "metadata": {
121 | "needs_background": "light"
122 | },
123 | "output_type": "display_data"
124 | }
125 | ],
126 | "source": [
127 | "vis(rep,rep)"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {},
133 | "source": [
134 | "The Kronecker product of the individual solutions for the different groups can be seen in these basis matrices: in the top left-hand corner we have a circulant matrix of deep set solutions (identity + background)."
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 3,
140 | "metadata": {},
141 | "outputs": [
142 | {
143 | "data": {
144 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEpklEQVR4nO3cPWuedRiH4TttoNAvoHPFRYJkKmSU6l4EwZZu7dSl2MmX2ZcpkiUfIQiCe1F06BBIEVIpLmKHDsWv4NA+foHk6XLVng8cx/jccJHl5A8Zflur1WoBei686T8AOJs4IUqcECVOiBInRG2v+/jlHx+P/Cv3+P7ViTNJz+68GLlz5ebpyJ2aB89PR+7sHNwduVP057efbZ31u5cTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LWLiFMLRjs7Z+M3CkuKkwtGDw92h25U1tUmFoweHLvcOTOsmzOqoKXE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROi1i4hTLGo8GoWFdabXC+YWlV43YsKXk6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiNparVbnfvzgw+/O/7jBphYVlmVZHu1eHLtVMrWocOnx5ZE7RVOLChfe/mvrzN9HrgPjxAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUWtnSl7+8+7ITMnOwd2JM0nXbzwcuXN8/+rInZpnd16M3Lly83TkTtHPL380UwKbRJwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IWrtEsJ7X3w/soTw5N7hxJnkosJbj/4dubO3fzJyp7aosP3r7yN3nh7tjtxZlt6qgiUE2DDihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR/8sSwpTiosLUEsKU2qLC1BLCpKlVhalFBUsIsGHECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROiNmoJYcrUosKyLMu1W7fHbpVMLSo82r04cqdoalHh70+/soQAm0ScECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFq7RLCOz98PbKEcOXm6cSZpAfPT0fu7BzcHblTc/3Gw5E7x/evjtwp+u2Xzy0hwCYRJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6I2l73cWrB4OnR7sid4qLC1ILBk3uHI3dqiwpTCwZ7+ycjd5Zlc1YVvJwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiFo7UzLF3MmrmTtZb3JaZGry5HXPnXg5IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFqa7VanfvxowufnP9xg00tKizLslx6fHnsVsnUosK1W7dH7hRNLSp88/5PW2f97uWEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihKi1SwjAm+PlhChxQpQ4IUqcECVOiBInRP0H3NO3PQoCCbMAAAAASUVORK5CYII=\n",
145 | "text/plain": [
146 | ""
147 | ]
148 | },
149 | "metadata": {
150 | "needs_background": "light"
151 | },
152 | "output_type": "display_data"
153 | }
154 | ],
155 | "source": [
156 | "vis(V(Z(3))*V(S(4)),V(Z(3))*V(S(4)))"
157 | ]
158 | },
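159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "The solution counts also multiply across the factors: the deep set maps `V(S(4))>>V(S(4))` span 2 dimensions (identity and all-ones), the circulant maps `V(Z(3))>>V(Z(3))` span 3, so the combined maps should span 2*3 = 6. Here is a small check of this, a sketch using only the objects imported above:"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "Qs = (V(S(4))>>V(S(4))).equivariant_basis()  # deep sets: identity + all-ones\n",
173 | "Qz = (V(Z(3))>>V(Z(3))).equivariant_basis()  # circulant matrices\n",
174 | "Qp = ((V(Z(3))*V(S(4)))>>(V(Z(3))*V(S(4)))).equivariant_basis()\n",
175 | "print(Qs.shape[-1], Qz.shape[-1], Qp.shape[-1])  # expect 2, 3, and 2*3 = 6"
176 | ]
177 | },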
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "And in the bottom left we have maps $V_{Z_3}\otimes V_{S_4} \rightarrow V_{SO(3)}^{\otimes 2}$ of which the solutions are the product of $V_{Z_3}\otimes V_{S_4} \rightarrow \mathbb{R}$ (the vector $\mathbf{1}$) and $\mathbb{R} \rightarrow V_{SO(3)}^{\otimes 2}$ (the flattened $I_{3\times 3}$)."
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 5,
169 | "metadata": {},
170 | "outputs": [
171 | {
172 | "data": {
173 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAATAAAADnCAYAAACZtwrQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAADZUlEQVR4nO3dsY3CQBRF0bVFEYicnC4omlJogCo82wDjYIXMXumc9CU/uhrJgZcxxg9A0frtAwD+SsCALAEDsgQMyBIwIOu0N26vq0+UwFet5+cy3Y48BOCTBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBr97+Q98vtoDMA3nts880LDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7KWMcZ03F7X+QhwgPX8XKbbkYcAfJKAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZJ32xvvldtAZAO89tvnmBQZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZyxhjOm6v63wEOMB6fi7T7chDAD5JwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7J2/wsJ8J95gQFZAgZkCRiQJWBAloABWQIGZP0C0FAeWUQoKl0AAAAASUVORK5CYII=\n",
174 | "text/plain": [
175 | ""
176 | ]
177 | },
178 | "metadata": {
179 | "needs_background": "light"
180 | },
181 | "output_type": "display_data"
182 | }
183 | ],
184 | "source": [
185 | "vis(V(Z(3))*V(S(4)),T(2)(SO(3)),False)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "## Wreath Products (coming soon)"
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "metadata": {},
198 | "source": [
199 | "These are all examples of tensor products of representations, which are in turn a kind of representation on the direct product group $G=G_1\times G_2$. There are other ways of combining groups however, and for many hierarchical structures there is a larger group of symmetries $G_1\wr G_2$ that contains $G_1 \times G_2$ but also other elements. These so-called wreath products are studied in the very nice paper [Equivariant Maps for Hierarchical Structures](https://arxiv.org/pdf/2006.03627.pdf). Support for the wreath product of representations is not yet implemented, but if this is something that would be useful to you, send me an email."
200 | ]
201 | }
202 | ],
203 | "metadata": {
204 | "accelerator": "GPU",
205 | "kernelspec": {
206 | "display_name": "Python 3",
207 | "language": "python",
208 | "name": "python3"
209 | },
210 | "language_info": {
211 | "codemirror_mode": {
212 | "name": "ipython",
213 | "version": 3
214 | },
215 | "file_extension": ".py",
216 | "mimetype": "text/x-python",
217 | "name": "python",
218 | "nbconvert_exporter": "python",
219 | "pygments_lexer": "ipython3",
220 | "version": "3.8.5"
221 | }
222 | },
223 | "nbformat": 4,
224 | "nbformat_minor": 4
225 | }
226 |
--------------------------------------------------------------------------------
/docs/notebooks/colabs/6multilinear_maps.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "First install the repo and requirements."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "# Multilinear Maps"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "Our codebase extends trivially to multilinear maps, since these maps are in fact just linear maps in disguise.\n",
31 | "\n",
32 | "If we have a sequence of representations $R_1$, $R_2$, $R_3$ for example, we can write the (bi)linear maps $R_1\\rightarrow R_2\\rightarrow R_3$. This way of thinking about maps of multiple variables, borrowed from curried functions in programming languages, is very powerful."
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "We can think of such an object $R_1\\rightarrow R_2\\rightarrow R_3$ either as $R_1 \\rightarrow (R_2\\rightarrow R_3)$: a linear map from $R_1$ to linear maps from $R_2$ to $R_3$, or as\n",
40 | "$(R_1\\times R_2) \\rightarrow R_3$: a bilinear map from $R_1$ and $R_2$ to $R_3$. Since linear maps from one representation to another are just another representation in our type system, you can use this way of thinking to find the equivariant solutions to arbitrary multilinear maps."
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": [
47 | "For example, we can get the bilinear $SO(4)$ equivariant maps $(R_1\\times R_2) \\rightarrow R_3$ with the code below."
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 1,
53 | "metadata": {},
54 | "outputs": [
55 | {
56 | "name": "stdout",
57 | "output_type": "stream",
58 | "text": [
59 | "(2940, 27)\n"
60 | ]
61 | }
62 | ],
63 | "source": [
64 | "from emlp.groups import SO,rel_err\n",
65 | "from emlp.reps import V\n",
66 | "\n",
67 | "G = SO(4)\n",
68 | "W = V(G)\n",
69 | "R1 = 3*W+W**2 # some example representations\n",
70 | "R2 = W.T+W**0\n",
71 | "R3 = W**0 +W**2 +W\n",
72 | "\n",
73 | "Q = (R1>>(R2>>R3)).equivariant_basis()\n",
74 | "print(Q.shape)"
75 | ]
76 | },
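77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": [
81 | "As a quick sanity check on the shape above: $Q$ spans a subspace of the space of all multilinear maps, which has $\\mathrm{dim}(R_3)\\cdot\\mathrm{dim}(R_2)\\cdot\\mathrm{dim}(R_1) = 21\\cdot 5\\cdot 28 = 2940$ entries, and its $27$ columns are the independent equivariant solutions. We can verify the first factor directly from the representations defined above."
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "# Sanity check: the rows of Q enumerate every entry of the multilinear map\n",
91 | "assert Q.shape[0] == R1.size()*R2.size()*R3.size()"
92 | ]
93 | },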
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": [
81 | "And we can verify that these multilinear solutions are indeed equivariant:"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 2,
87 | "metadata": {},
88 | "outputs": [
89 | {
90 | "data": {
91 | "text/plain": [
92 | "DeviceArray(1.2422272e-07, dtype=float32)"
93 | ]
94 | },
95 | "execution_count": 2,
96 | "metadata": {},
97 | "output_type": "execute_result"
98 | }
99 | ],
100 | "source": [
101 | "import numpy as np\n",
102 | "\n",
103 | "example_map = (Q@np.random.randn(Q.shape[-1]))\n",
104 | "example_map = example_map.reshape(R3.size(),R2.size(),R1.size())\n",
105 | "\n",
106 | "x1 = np.random.randn(R1.size())\n",
107 | "x2 = np.random.randn(R2.size())\n",
108 | "g = G.sample()\n",
109 | "\n",
110 | "out1 = np.einsum(\"ijk,j,k\",example_map,R2.rho(g)@x2,R1.rho(g)@x1)\n",
111 | "out2 = R3.rho(g)@np.einsum(\"ijk,j,k\",example_map,x2,x1)\n",
112 | "rel_err(out1,out2)"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "Note that the output mapping is of shape $(\\mathrm{dim}(R_3),\\mathrm{dim}(R_2),\\mathrm{dim}(R_1))$\n",
120 | "with the inputs to the right as you would expect with a matrix. \n",
121 | "\n",
122 | "Note the parentheses in the expression `(R1>>(R2>>R3))`: they are necessary because the python `>>` operator associates to the left.\n",
123 | "The notation $R_1\\rightarrow R_2 \\rightarrow R_3$ or `(R1>>(R2>>R3))` can be a bit confusing since the inputs are on the right. It can be easier in this context to instead reverse the arrows and express the same object as $R_3\\leftarrow R_2\\leftarrow R_1$ or `R3<<R2<<R1`, which needs no extra parentheses since `<<` also associates to the left.\n",
124 | "\n",
125 | "You can use `R3<<R2` in place of `R2>>R3` wherever you like, and it is usually more intuitive."
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": 3,
131 | "metadata": {},
132 | "outputs": [
133 | {
134 | "data": {
135 | "text/plain": [
136 | "True"
137 | ]
138 | },
139 | "execution_count": 3,
140 | "metadata": {},
141 | "output_type": "execute_result"
142 | }
143 | ],
144 | "source": [
145 | "R3<<R2<<R1 == (R1>>(R2>>R3))"
146 | ]
147 | }
148 | ],
149 | "metadata": {
150 | "accelerator": "GPU",
151 | "kernelspec": {
152 | "display_name": "Python 3",
153 | "language": "python",
154 | "name": "python3"
155 | },
156 | "language_info": {
157 | "codemirror_mode": {
158 | "name": "ipython",
159 | "version": 3
160 | },
161 | "file_extension": ".py",
162 | "mimetype": "text/x-python",
163 | "name": "python",
164 | "nbconvert_exporter": "python",
165 | "pygments_lexer": "ipython3",
166 | "version": "3.8.5"
167 | }
168 | },
169 | "nbformat": 4,
170 | "nbformat_minor": 4
171 | }
172 |
--------------------------------------------------------------------------------
/docs/notebooks/colabs/7pytorch_support.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "First install the repo and requirements."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "# Limited Pytorch Support"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "We strongly recommend that users of our library write native Jax code. However, we understand that due to existing code and/or workplace constraints, it is sometimes unavoidable to use other frameworks like PyTorch. \n",
31 | "\n",
32 | "To service these requirements, we have added a way that PyTorch users can make use of the equivariant bases $Q\\in \\mathbb{R}^{n\\times r}$ and projection matrices $P = QQ^\\top$ that are computed by our solver. Since these objects are implicitly defined through `LinearOperators`, it is not as straightforward as simply calling `torch.from_numpy(Q)`. However, there is a way to use these operators within PyTorch code while preserving any gradients of the operation. We provide the function `emlp.reps.pytorch_support.torchify_fn` to do this."
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 1,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "import torch\n",
42 | "import jax\n",
43 | "import jax.numpy as jnp\n",
44 | "from emlp.reps import V\n",
45 | "from emlp.groups import S\n",
46 | "\n",
47 | "W =V(S(4))\n",
48 | "rep = 3*W+W**2"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 2,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "Q = (rep>>rep).equivariant_basis()\n",
58 | "P = (rep>>rep).equivariant_projector()"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 3,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "applyQ = lambda v: Q@v\n",
68 | "applyP = lambda v: P@v"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "The key is to wrap the desired operations in a function, to which we can then apply `torchify_fn`. Now, instead of taking jax objects as inputs and outputting jax objects, these functions take in PyTorch objects and output PyTorch objects."
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 4,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "from emlp.reps.pytorch_support import torchify_fn\n",
85 | "applyQ_torch = torchify_fn(applyQ)\n",
86 | "applyP_torch = torchify_fn(applyP)"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 5,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "x_torch = torch.arange(Q.shape[-1]).float().cuda()\n",
96 | "x_torch.requires_grad=True\n",
97 | "x_jax = jnp.asarray(x_torch.cpu().data.numpy()) "
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 6,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "name": "stdout",
107 | "output_type": "stream",
108 | "text": [
109 | "jax output: [0.48484263 0.07053992 0.07053989 0.07053995 1.6988853 ]\n",
110 | "torch output: tensor([0.4848, 0.0705, 0.0705, 0.0705, 1.6989], device='cuda:0',\n",
111 | " grad_fn=)\n"
112 | ]
113 | }
114 | ],
115 | "source": [
116 | "Qx1 = applyQ(x_jax)\n",
117 | "Qx2 = applyQ_torch(x_torch)\n",
118 | "print(\"jax output: \",Qx1[:5])\n",
119 | "print(\"torch output: \",Qx2[:5])"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "The outputs match; note that the torch outputs will be on whichever device is the jax default. Similarly, the gradients of the two objects also match:"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 7,
132 | "metadata": {},
133 | "outputs": [
134 | {
135 | "data": {
136 | "text/plain": [
137 | "tensor([-2.8704, 2.7858, -2.8704, 2.7858, -2.8704], device='cuda:0')"
138 | ]
139 | },
140 | "execution_count": 7,
141 | "metadata": {},
142 | "output_type": "execute_result"
143 | }
144 | ],
145 | "source": [
146 | "torch.autograd.grad(Qx2.sum(),x_torch)[0][:5]"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 8,
152 | "metadata": {},
153 | "outputs": [
154 | {
155 | "data": {
156 | "text/plain": [
157 | "DeviceArray([-2.8703732, 2.7858496, -2.8703732, 2.7858496, -2.8703732], dtype=float32)"
158 | ]
159 | },
160 | "execution_count": 8,
161 | "metadata": {},
162 | "output_type": "execute_result"
163 | }
164 | ],
165 | "source": [
166 | "jax.grad(lambda x: (Q@x).sum())(x_jax)[:5]"
167 | ]
168 | },
169 | {
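170 | {
171 | "cell_type": "markdown",
172 | "metadata": {},
173 | "source": [
174 | "For instance, here is a minimal sketch of an equivariant linear layer built this way: an unconstrained PyTorch weight is re-projected onto the equivariant subspace on every forward pass. `ProjectedLinear` is purely illustrative (it is not part of the library) and reuses `torchify_fn` together with the representations defined above."
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "import torch.nn as tnn\n",
184 | "\n",
185 | "class ProjectedLinear(tnn.Module):\n",
186 | "    \"\"\"Illustrative sketch: the weight is projected with P = QQ^T before use.\"\"\"\n",
187 | "    def __init__(self, repin, repout):\n",
188 | "        super().__init__()\n",
189 | "        nin, nout = repin.size(), repout.size()\n",
190 | "        self.shape = (nout, nin)\n",
191 | "        self.weight = tnn.Parameter(torch.randn(nout*nin)/nin**0.5)\n",
192 | "        Pw = (repin>>repout).equivariant_projector()  # jax LinearOperator, as above\n",
193 | "        self.project = torchify_fn(lambda w: Pw@w)  # torch tensors in and out\n",
194 | "    def forward(self, x):\n",
195 | "        W = self.project(self.weight).reshape(self.shape)\n",
196 | "        return x@W.t()\n",
197 | "\n",
198 | "layer = ProjectedLinear(rep, rep).cuda()  # rep = 3*W+W**2 from the cells above\n",
199 | "out = layer(torch.randn(5, rep.size()).cuda())"
200 | ]
201 | },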
170 | "cell_type": "markdown",
171 | "metadata": {},
172 | "source": [
173 | "So you can safely use these torchified functions within your model, and still compute the gradients correctly."
174 | ]
175 | }
176 | ],
177 | "metadata": {
178 | "accelerator": "GPU",
179 | "kernelspec": {
180 | "display_name": "Python 3",
181 | "language": "python",
182 | "name": "python3"
183 | },
184 | "language_info": {
185 | "codemirror_mode": {
186 | "name": "ipython",
187 | "version": 3
188 | },
189 | "file_extension": ".py",
190 | "mimetype": "text/x-python",
191 | "name": "python",
192 | "nbconvert_exporter": "python",
193 | "pygments_lexer": "ipython3",
194 | "version": "3.8.5"
195 | }
196 | },
197 | "nbformat": 4,
198 | "nbformat_minor": 4
199 | }
200 |
--------------------------------------------------------------------------------
/docs/notebooks/colabs/flax_support.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "First install the repo and requirements."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "# Using EMLP with Flax"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "Using EMLP with [Flax](https://github.com/google/flax) is pretty similar to using it with Objax or Haiku. Just make sure to import from the flax implementation: `emlp.nn.flax`."
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 1,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "from jax import random\n",
40 | "import numpy as np\n",
41 | "import emlp.nn.flax as nn # import from the flax implementation\n",
42 | "from emlp.reps import T,V\n",
43 | "from emlp.groups import SO\n",
44 | "\n",
45 | "repin= 4*V # Setup some example data representations\n",
46 | "repout = V\n",
47 | "G = SO(3)\n",
48 | "\n",
49 | "x = np.random.randn(5,repin(G).size()) # generate some random data"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 2,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "model = nn.EMLP(repin,repout,G)\n",
59 | "\n",
60 | "key = random.PRNGKey(0)\n",
61 | "params = model.init(random.PRNGKey(42), x)\n",
62 | "\n",
63 | "y = model.apply(params, x) # Forward pass with inputs x and parameters"
64 | ]
65 | },
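66 | {
67 | "cell_type": "markdown",
68 | "metadata": {},
69 | "source": [
70 | "Since `model` is an ordinary flax module, gradients with respect to the parameters work as usual. A minimal sketch, with a placeholder loss chosen purely for illustration:"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "import jax\n",
80 | "\n",
81 | "def loss_fn(params, x):\n",
82 | "    y = model.apply(params, x)  # forward pass, as in the cell above\n",
83 | "    return (y**2).mean()  # placeholder loss, purely for illustration\n",
84 | "\n",
85 | "grads = jax.grad(loss_fn)(params, x)"
86 | ]
87 | },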
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {},
69 | "source": [
70 | "And indeed, the parameters of the model are registered as expected."
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 3,
76 | "metadata": {},
77 | "outputs": [
78 | {
79 | "data": {
80 | "text/plain": [
81 | "['modules_0', 'modules_1', 'modules_2', 'modules_3']"
82 | ]
83 | },
84 | "execution_count": 3,
85 | "metadata": {},
86 | "output_type": "execute_result"
87 | }
88 | ],
89 | "source": [
90 | "list(params['params'].keys())"
91 | ]
92 | }
93 | ],
94 | "metadata": {
95 | "accelerator": "GPU",
96 | "interpreter": {
97 | "hash": "ec74566b76234e57f2cd5bb0818dcd91369c1a3af290381c3b6efeb6aea6cdd5"
98 | },
99 | "kernelspec": {
100 | "display_name": "Python 3",
101 | "language": "python",
102 | "name": "python3"
103 | },
104 | "language_info": {
105 | "codemirror_mode": {
106 | "name": "ipython",
107 | "version": 3
108 | },
109 | "file_extension": ".py",
110 | "mimetype": "text/x-python",
111 | "name": "python",
112 | "nbconvert_exporter": "python",
113 | "pygments_lexer": "ipython3",
114 | "version": "3.8.5"
115 | }
116 | },
117 | "nbformat": 4,
118 | "nbformat_minor": 4
119 | }
120 |
--------------------------------------------------------------------------------
/docs/notebooks/colabs/haiku_support.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "First install the repo and requirements."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "# Using EMLP with Haiku"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "There are many neural network frameworks for jax, and they are often incompatible. Since most of the functionality of this package is written in pure jax, it can be used with flax, trax, linen, haiku, objax, or whatever your favorite jax NN framework happens to be.\n",
31 | "\n",
32 | "\n",
33 | "However, the equivariant neural network layers provided in the [Layers and Models](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html) are made for objax.\n",
34 | "If we try to use them with the popular [Haiku framework](https://dm-haiku.readthedocs.io/en/latest/), things will not work as expected.\n",
35 | "\n",
36 | "## Don't Do This:"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 1,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "import haiku as hk\n",
46 | "from jax import random\n",
47 | "import numpy as np\n",
48 | "import emlp.nn as nn\n",
49 | "from emlp.reps import T,V\n",
50 | "from emlp.groups import SO\n",
51 | "\n",
52 | "repin= 4*V # Setup some example data representations\n",
53 | "repout = V\n",
54 | "G = SO(3)\n",
55 | "\n",
56 | "x = np.random.randn(10,repin(G).size()) # generate some random data"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 2,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "model = nn.EMLP(repin,repout,G)\n",
66 | "net = hk.without_apply_rng(hk.transform(model))\n",
67 | "\n",
68 | "key = random.PRNGKey(0)\n",
69 | "params = net.init(random.PRNGKey(42), x)\n",
70 | "\n",
71 | "y = net.apply(params, x)"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "Although the code executes, we see that Haiku does not recognize the model parameters and treats the network as if it were a stateless jax function."
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 3,
84 | "metadata": {},
85 | "outputs": [
86 | {
87 | "data": {
88 | "text/plain": [
89 | "FlatMapping({})"
90 | ]
91 | },
92 | "execution_count": 3,
93 | "metadata": {},
94 | "output_type": "execute_result"
95 | }
96 | ],
97 | "source": [
98 | "params"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "It's not hard to build EMLP layers in Haiku, and for each of the nn layers in [Layers and Models](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html) we have implemented a Haiku version with the same arguments. These layers are accessible via `emlp.nn.haiku` rather than `emlp.nn`. To use EMLP models and equivariant layers with Haiku, instead of the above you should import from `emlp.nn.haiku`."
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "## Instead, Do This:"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 4,
118 | "metadata": {
119 | "scrolled": true
120 | },
121 | "outputs": [],
122 | "source": [
123 | "import emlp.nn.haiku as ehk\n",
124 | "\n",
125 | "model = ehk.EMLP(repin,repout,SO(3))\n",
126 | "net = hk.without_apply_rng(hk.transform(model))\n",
127 | "\n",
128 | "key = random.PRNGKey(0)\n",
129 | "params = net.init(random.PRNGKey(42), x)\n",
130 | "y = net.apply(params, x)"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 5,
136 | "metadata": {},
137 | "outputs": [
138 | {
139 | "data": {
140 | "text/plain": [
141 | "KeysOnlyKeysView(['sequential/hk_linear', 'sequential/hk_bi_linear', 'sequential/hk_linear_1', 'sequential/hk_bi_linear_1', 'sequential/hk_linear_2', 'sequential/hk_bi_linear_2', 'sequential/hk_linear_3'])"
142 | ]
143 | },
144 | "execution_count": 5,
145 | "metadata": {},
146 | "output_type": "execute_result"
147 | }
148 | ],
149 | "source": [
150 | "params.keys()"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {},
156 | "source": [
157 | "With this Haiku EMLP, parameters are registered as expected.\n",
158 | "\n",
159 | "If your favorite deep learning framework is not one of objax, haiku, or pytorch, don't panic. It's possible to use EMLP with other jax frameworks without much trouble, similar to the objax and haiku implementations. If you need help with this, start a pull request and we can send over some pointers."
160 | ]
161 | }
162 | ],
163 | "metadata": {
164 | "accelerator": "GPU",
165 | "kernelspec": {
166 | "display_name": "Python 3",
167 | "language": "python",
168 | "name": "python3"
169 | },
170 | "language_info": {
171 | "codemirror_mode": {
172 | "name": "ipython",
173 | "version": 3
174 | },
175 | "file_extension": ".py",
176 | "mimetype": "text/x-python",
177 | "name": "python",
178 | "nbconvert_exporter": "python",
179 | "pygments_lexer": "ipython3",
180 | "version": "3.8.5"
181 | }
182 | },
183 | "nbformat": 4,
184 | "nbformat_minor": 4
185 | }
186 |
--------------------------------------------------------------------------------
/docs/notebooks/colabs/imgs/EMLP_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/notebooks/colabs/imgs/EMLP_fig.png
--------------------------------------------------------------------------------
/docs/notebooks/colabs/imgs/imgs/EMLP_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/notebooks/colabs/imgs/imgs/EMLP_fig.png
--------------------------------------------------------------------------------
/docs/notebooks/flax_support.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Using EMLP with Flax"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Using EMLP with [Flax](https://github.com/google/flax) is pretty similar to using it with Objax or Haiku. Just make sure to import from the flax implementation: `emlp.nn.flax`."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from jax import random\n",
24 | "import numpy as np\n",
25 | "import emlp.nn.flax as nn # import from the flax implementation\n",
26 | "from emlp.reps import T,V\n",
27 | "from emlp.groups import SO\n",
28 | "\n",
29 | "repin= 4*V # Setup some example data representations\n",
30 | "repout = V\n",
31 | "G = SO(3)\n",
32 | "\n",
33 | "x = np.random.randn(5,repin(G).size()) # generate some random data"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "model = nn.EMLP(repin,repout,G)\n",
43 | "\n",
44 | "key = random.PRNGKey(0)\n",
45 | "params = model.init(random.PRNGKey(42), x)\n",
46 | "\n",
47 | "y = model.apply(params, x) # Forward pass with inputs x and parameters"
48 | ]
49 | },
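50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
54 | "Since `model` is an ordinary flax module, gradients with respect to the parameters work as usual. A minimal sketch, with a placeholder loss chosen purely for illustration:"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "import jax\n",
64 | "\n",
65 | "def loss_fn(params, x):\n",
66 | "    y = model.apply(params, x)  # forward pass, as in the cell above\n",
67 | "    return (y**2).mean()  # placeholder loss, purely for illustration\n",
68 | "\n",
69 | "grads = jax.grad(loss_fn)(params, x)"
70 | ]
71 | },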
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
54 | "And indeed, the parameters of the model are registered as expected."
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 3,
60 | "metadata": {},
61 | "outputs": [
62 | {
63 | "output_type": "execute_result",
64 | "data": {
65 | "text/plain": [
66 | "['modules_0', 'modules_1', 'modules_2', 'modules_3']"
67 | ]
68 | },
69 | "metadata": {},
70 | "execution_count": 3
71 | }
72 | ],
73 | "source": [
74 | "list(params['params'].keys())"
75 | ]
76 | }
77 | ],
78 | "metadata": {
79 | "kernelspec": {
80 | "name": "python3",
81 | "display_name": "Python 3.8.5 64-bit ('freshenv': conda)"
82 | },
83 | "language_info": {
84 | "codemirror_mode": {
85 | "name": "ipython",
86 | "version": 3
87 | },
88 | "file_extension": ".py",
89 | "mimetype": "text/x-python",
90 | "name": "python",
91 | "nbconvert_exporter": "python",
92 | "pygments_lexer": "ipython3",
93 | "version": "3.8.5"
94 | },
95 | "interpreter": {
96 | "hash": "ec74566b76234e57f2cd5bb0818dcd91369c1a3af290381c3b6efeb6aea6cdd5"
97 | }
98 | },
99 | "nbformat": 4,
100 | "nbformat_minor": 4
101 | }
--------------------------------------------------------------------------------
/docs/notebooks/haiku_support.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Using EMLP with Haiku"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "There are many neural network frameworks for jax, and they are often incompatible. Since most of the functionality of this package is written in pure jax, it can be used with flax, trax, linen, haiku, objax, or whatever your favorite jax NN framework happens to be.\n",
15 | "\n",
16 | "\n",
17 | "However, the equivariant neural network layers provided in the [Layers and Models](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html) are made for objax.\n",
18 | "If we try to use them with the popular [Haiku framework](https://dm-haiku.readthedocs.io/en/latest/), things will not work as expected.\n",
19 | "\n",
20 | "## Don't Do This:"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 1,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "import haiku as hk\n",
30 | "from jax import random\n",
31 | "import numpy as np\n",
32 | "import emlp.nn as nn\n",
33 | "from emlp.reps import T,V\n",
34 | "from emlp.groups import SO\n",
35 | "\n",
36 | "repin= 4*V # Setup some example data representations\n",
37 | "repout = V\n",
38 | "G = SO(3)\n",
39 | "\n",
40 | "x = np.random.randn(10,repin(G).size()) # generate some random data"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 2,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "model = nn.EMLP(repin,repout,G)\n",
50 | "net = hk.without_apply_rng(hk.transform(model))\n",
51 | "\n",
52 | "key = random.PRNGKey(0)\n",
53 | "params = net.init(random.PRNGKey(42), x)\n",
54 | "\n",
55 | "y = net.apply(params, x)"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "Although the code executes, we see that Haiku does not recognize the model parameters and treats the network as if it were a stateless jax function."
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 3,
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "data": {
72 | "text/plain": [
73 | "FlatMapping({})"
74 | ]
75 | },
76 | "execution_count": 3,
77 | "metadata": {},
78 | "output_type": "execute_result"
79 | }
80 | ],
81 | "source": [
82 | "params"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {},
88 | "source": [
89 | "It's not hard to build EMLP layers in Haiku, and for each of the nn layers in [Layers and Models](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html) we have implemented a Haiku version with the same arguments. These layers are accessible via `emlp.nn.haiku` rather than `emlp.nn`. To use EMLP models and equivariant layers with Haiku, instead of the above you should import from `emlp.nn.haiku`."
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {},
95 | "source": [
96 | "## Instead, Do This:"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 4,
102 | "metadata": {
103 | "scrolled": true
104 | },
105 | "outputs": [],
106 | "source": [
107 | "import emlp.nn.haiku as ehk\n",
108 | "\n",
109 | "model = ehk.EMLP(repin,repout,SO(3))\n",
110 | "net = hk.without_apply_rng(hk.transform(model))\n",
111 | "\n",
112 | "key = random.PRNGKey(0)\n",
113 | "params = net.init(random.PRNGKey(42), x)\n",
114 | "y = net.apply(params, x)"
115 | ]
116 | },
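117 | {
118 | "cell_type": "markdown",
119 | "metadata": {},
120 | "source": [
121 | "Since `net` is a standard haiku-transformed function, gradients with respect to these parameters work as usual. A minimal sketch, with a placeholder loss chosen purely for illustration:"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "import jax\n",
131 | "\n",
132 | "def loss_fn(params, x):\n",
133 | "    y = net.apply(params, x)  # forward pass, as in the cell above\n",
134 | "    return (y**2).mean()  # placeholder loss, purely for illustration\n",
135 | "\n",
136 | "grads = jax.grad(loss_fn)(params, x)"
137 | ]
138 | },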
117 | {
118 | "cell_type": "code",
119 | "execution_count": 5,
120 | "metadata": {},
121 | "outputs": [
122 | {
123 | "data": {
124 | "text/plain": [
125 | "KeysOnlyKeysView(['sequential/hk_linear', 'sequential/hk_bi_linear', 'sequential/hk_linear_1', 'sequential/hk_bi_linear_1', 'sequential/hk_linear_2', 'sequential/hk_bi_linear_2', 'sequential/hk_linear_3'])"
126 | ]
127 | },
128 | "execution_count": 5,
129 | "metadata": {},
130 | "output_type": "execute_result"
131 | }
132 | ],
133 | "source": [
134 | "params.keys()"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "metadata": {},
140 | "source": [
141 | "With this Haiku EMLP, parameters are registered as expected.\n",
142 | "\n",
143 | "If your favorite deep learning framework is not one of objax, haiku, or pytorch, don't panic. It's possible to use EMLP with other jax frameworks without much trouble, similar to the objax and haiku implementations. If you need help with this, start a pull request and we can send over some pointers."
144 | ]
145 | }
146 | ],
147 | "metadata": {
148 | "kernelspec": {
149 | "display_name": "Python 3",
150 | "language": "python",
151 | "name": "python3"
152 | },
153 | "language_info": {
154 | "codemirror_mode": {
155 | "name": "ipython",
156 | "version": 3
157 | },
158 | "file_extension": ".py",
159 | "mimetype": "text/x-python",
160 | "name": "python",
161 | "nbconvert_exporter": "python",
162 | "pygments_lexer": "ipython3",
163 | "version": "3.8.5"
164 | }
165 | },
166 | "nbformat": 4,
167 | "nbformat_minor": 4
168 | }
169 |
--------------------------------------------------------------------------------
/docs/notebooks/imgs/EMLP_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/notebooks/imgs/EMLP_fig.png
--------------------------------------------------------------------------------
/docs/notebooks/merge_nbs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #set -x
3 | dir_path=$(dirname $(realpath $0))
4 | PREAMBLE=$dir_path/_colab_preamble.ipynb
5 | NOTEBOOKS=$dir_path/[0-9]*.ipynb
6 | NOTEBOOKSALL=$dir_path/[^_]*.ipynb
7 | mkdir -p $dir_path/colabs/imgs
8 | cp -r $dir_path/imgs/. $dir_path/colabs/imgs # copy directory contents so reruns don't create a nested imgs/imgs
9 | for nb in $NOTEBOOKSALL
10 | do
11 | nbmerge $PREAMBLE $nb -o $dir_path/colabs/$(basename $nb)
12 | done
13 | nbmerge $PREAMBLE $NOTEBOOKS -o $dir_path/colabs/all.ipynb
14 | # nbmerge -i notebooks/[^_]*.ipynb -o notebooks/_colab_all.ipynb
--------------------------------------------------------------------------------
/docs/package/emlp.groups.rst:
--------------------------------------------------------------------------------
1 | Groups
2 | ======
3 |
4 | .. automodule:: emlp.groups
5 | :members:
--------------------------------------------------------------------------------
/docs/package/emlp.nn.rst:
--------------------------------------------------------------------------------
1 | Layers and Models
2 | =================
3 |
4 | .. automodule:: emlp.nn
5 | :members:
6 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/package/emlp.reps.rst:
--------------------------------------------------------------------------------
1 | Representations
2 | ===============
3 |
4 | .. automodule:: emlp.reps
5 | :members:
6 | :exclude-members: Rep
7 |
8 | .. autoclass:: Rep
9 | :members: size,rho,drho,equivariant_basis,equivariant_projector, rho_dense,drho_dense
10 |
11 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | # sphinx <4 required by myst-nb v0.12.0 (Feb 2021)
2 | # sphinx >=3 required by sphinx-autodoc-typehints v1.11.1 (Oct 2020)
3 | Jinja2 < 3.1
4 | sphinx >=3, <4
5 | sphinx_rtd_theme
6 | sphinx-autodoc-typehints
7 | myst-nb
8 | nbmerge
9 | nbsphinx == 0.8.6
10 | recommonmark
11 | # Packages used for CI tests.
12 | pytest
13 | pytest-xdist
14 | #jupytext
15 | # Packages used for notebook execution
16 | matplotlib
17 | #. # install emlp itself from current directory.
18 |
--------------------------------------------------------------------------------
/docs/testing.md:
--------------------------------------------------------------------------------
1 | # Running the Tests
2 |
3 | We have a number of tests checking the equivariance of representations constructed in
4 | different ways (`.T`, `*`, `+`) for the groups that have been implemented (`Z(n)`,`S(n)`,`D(k)`,`SO(n)`, `O(n)`,`Sp(n)`,`SO13()`,`O13()`,`SU(n)`).
5 | We use pytest and some of the tests are automatically generated. Because there is a large number of tests and it can take quite some time to run them all (about 15 minutes),
6 | you can run a subset using pytest's built-in `-k` argument, which filters the test cases by matches on their names.
7 |
8 | For example, to run `test_prod` with all the groups you can run
9 | ```pytest tests/equivariance_tests.py -k "test_prod"```
10 |
11 | To run the test case for a specific group, you could use the filter `-k "test_prod and SO3"`, and to run all tests with that group
12 | you could run
13 | ```pytest tests/equivariance_tests.py -k "SO3"```
14 |
15 | Due to pytest parsing limitations, all parentheses in the test names are stripped.
16 | To list all available tests (or those that match certain `-k` arguments), use `--co` (for collect-only).
17 |
18 | The usual pytest command line arguments apply (like `-v` for verbose).
19 |
20 |
21 | Similarly, you can find tests for "mixed" representations containing sub-representations from different groups in `tests/product_groups_tests.py`.
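22 |
23 | For example, to list (without running) the mixed-representation tests that involve a particular group, you can combine the two flags above (the `SO3` filter here is just an illustration):
24 | ```pytest tests/product_groups_tests.py --co -k "SO3"```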
--------------------------------------------------------------------------------
/emlp/__init__.py:
--------------------------------------------------------------------------------
1 | # import importlib
2 | # import pkgutil
3 | # __all__ = []
4 | # for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
5 | # module = importlib.import_module('.'+module_name,package=__name__)
6 | # try:
7 | # globals().update({k: getattr(module, k) for k in module.__all__})
8 | # __all__ += module.__all__
9 | # except AttributeError: continue
10 | # # concatenate the __all__ from each of the submodules (expose to user)
11 |
12 | __version__ = '1.0.3'
13 | from .nn import *
14 | from .groups import *
15 | from .reps import *
16 |
--------------------------------------------------------------------------------
/emlp/datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax.numpy as jnp
3 | from emlp.reps import Scalar,Vector,T
4 | from emlp.utils import export,Named
5 | from emlp.groups import SO,O,Trivial,Lorentz,RubiksCube,Cube
6 | from functools import partial
7 | import itertools
8 | from jax import vmap,jit
9 | from objax import Module
10 |
11 | @export
12 | class Inertia(object):
13 | def __init__(self,N=1024,k=5):
14 | super().__init__()
15 | self.dim = (1+3)*k
16 | self.X = np.random.randn(N,self.dim)
17 | self.X[:,:k] = np.log(1+np.exp(self.X[:,:k])) # Masses
18 | mi = self.X[:,:k]
19 | ri = self.X[:,k:].reshape(-1,k,3)
20 | I = np.eye(3)
21 | r2 = (ri**2).sum(-1)[...,None,None]
22 | inertia = (mi[:,:,None,None]*(r2*I - ri[...,None]*ri[...,None,:])).sum(1)
23 | self.Y = inertia.reshape(-1,9)
24 | self.rep_in = k*Scalar+k*Vector
25 | self.rep_out = T(2)
26 | self.symmetry = O(3)
27 | # One has to be careful computing offset and scale in a way so that standardizing
28 | # does not violate equivariance
29 | Xmean = self.X.mean(0)
30 | Xmean[k:] = 0
31 | Xstd = np.zeros_like(Xmean)
32 | Xstd[:k] = np.abs(self.X[:,:k]).mean(0)#.std(0)
33 | #Xstd[k:] = (np.sqrt((self.X[:,k:].reshape(N,k,3)**2).mean((0,2))[:,None]) + np.zeros((k,3))).reshape(k*3)
34 | Xstd[k:] = (np.abs(self.X[:,k:].reshape(N,k,3)).mean((0,2))[:,None] + np.zeros((k,3))).reshape(k*3)
35 | Ymean = 0*self.Y.mean(0)
36 | #Ystd = np.sqrt(((self.Y-Ymean)**2).mean((0,1)))+ np.zeros_like(Ymean)
37 | Ystd = np.abs(self.Y-Ymean).mean((0,1)) + np.zeros_like(Ymean)
38 | self.stats =0,1,0,1#Xmean,Xstd,Ymean,Ystd
39 |
40 | def __getitem__(self,i):
41 | return (self.X[i],self.Y[i])
42 | def __len__(self):
43 | return self.X.shape[0]
44 | def default_aug(self,model):
45 | return GroupAugmentation(model,self.rep_in,self.rep_out,self.symmetry)
46 |
47 | @export
48 | class O5Synthetic(object):
49 | def __init__(self,N=1024):
50 | super().__init__()
51 | d=5
52 | self.dim = 2*d
53 | self.X = np.random.randn(N,self.dim)
54 | ri = self.X.reshape(-1,2,5)
55 | r1,r2 = ri.transpose(1,0,2)
56 | self.Y = np.sin(np.sqrt((r1**2).sum(-1)))-.5*np.sqrt((r2**2).sum(-1))**3 + (r1*r2).sum(-1)/(np.sqrt((r1**2).sum(-1))*np.sqrt((r2**2).sum(-1)))
57 | self.rep_in = 2*Vector
58 | self.rep_out = Scalar
59 | self.symmetry = O(d)
60 | self.Y = self.Y[...,None]
61 | # One has to be careful computing mean and std in a way so that standardizing
62 | # does not violate equivariance
63 | Xmean = self.X.mean(0) # can add and subtract arbitrary tensors
64 | Xscale = (np.sqrt((self.X.reshape(N,2,d)**2).mean((0,2)))[:,None]+0*ri[0]).reshape(self.dim)
65 | self.stats = 0,Xscale,self.Y.mean(axis=0),self.Y.std(axis=0)
66 |
67 | def __getitem__(self,i):
68 | return (self.X[i],self.Y[i])
69 | def __len__(self):
70 | return self.X.shape[0]
71 | def default_aug(self,model):
72 | return GroupAugmentation(model,self.rep_in,self.rep_out,self.symmetry)
73 |
74 | @export
75 | class ParticleInteraction(object):
76 | """ Electron muon e^4 interaction"""
77 | def __init__(self,N=1024):
78 | super().__init__()
79 | self.dim = 4*4
80 | self.rep_in = 4*Vector
81 | self.rep_out = Scalar
82 | self.X = np.random.randn(N,self.dim)/4
83 | P = self.X.reshape(N,4,4)
84 | p1,p2,p3,p4 = P.transpose(1,0,2)
85 | 𝜂 = np.diag(np.array([1.,-1.,-1.,-1.]))
86 | dot = lambda v1,v2: ((v1@𝜂)*v2).sum(-1)
87 | Le = (p1[:,:,None]*p3[:,None,:] - (dot(p1,p3)-dot(p1,p1))[:,None,None]*𝜂)
88 | L𝜇 = ((p2@𝜂)[:,:,None]*(p4@𝜂)[:,None,:] - (dot(p2,p4)-dot(p2,p2))[:,None,None]*𝜂)
89 | M = 4*(Le*L𝜇).sum(-1).sum(-1)
90 | self.Y = M
91 | self.symmetry = Lorentz()
92 | self.Y = self.Y[...,None]
93 | # One has to be careful computing mean and std in a way so that standardizing
94 | # does not violate equivariance
95 | self.Xscale = np.sqrt((np.abs((self.X.reshape(N,4,4)@𝜂)*self.X.reshape(N,4,4)).mean(-1)).mean(0))
96 | self.Xscale = (self.Xscale[:,None]+np.zeros((4,4))).reshape(-1)
97 | self.stats = 0,self.Xscale,self.Y.mean(axis=0),self.Y.std(axis=0)#self.X.mean(axis=0),self.X.std(axis=0),
98 | def __getitem__(self,i):
99 | return (self.X[i],self.Y[i])
100 | def __len__(self):
101 | return self.X.shape[0]
102 | def default_aug(self,model):
103 | return GroupAugmentation(model,self.rep_in,self.rep_out,self.symmetry)
104 |
105 | class GroupAugmentation(Module):
106 | def __init__(self,network,rep_in,rep_out,group):
107 | super().__init__()
108 | self.rep_in=rep_in
109 | self.rep_out=rep_out
110 | self.G=group
111 | self.rho_in = jit(vmap(self.rep_in.rho))
112 | self.rho_out = jit(vmap(self.rep_out.rho))
113 | self.model = network
114 | def __call__(self,x,training=True):
115 | if training:
116 | gs = self.G.samples(x.shape[0])
117 | rhout_inv = jnp.linalg.inv(self.rho_out(gs))
118 | return (rhout_inv@self.model((self.rho_in(gs)@x[...,None])[...,0],training)[...,None])[...,0]
119 | else:
120 | return self.model(x,False)
121 |
122 | @export
123 | class InvertedCube(object):
124 | def __init__(self,train=True):
125 |         # TODO: finish implementing this simple dataset
126 | solved_state = np.eye(6)
127 | parity_perm = np.array([5,3,4,1,2,0])
128 | parity_state = solved_state[:,parity_perm]
129 |
130 | labels = np.array([1,0]).astype(int)
131 | self.X = np.zeros((2,6*6))
132 | self.X[0] = solved_state.reshape(-1)
133 | self.X[1] = parity_state.reshape(-1)
134 | self.Y = labels
135 | self.symmetry = Cube()
136 | self.rep_in = 6*Vector
137 | self.rep_out = 2*Scalar
138 | self.stats = (0,1)
139 | if train==False: # Scramble the cubes for test time
140 | gs = self.symmetry.samples(100)
141 | self.X = np.repeat(self.X,50,axis=0).reshape(100,6,6)@gs
142 | self.Y = np.repeat(self.Y,50,axis=0)
143 | p = np.random.permutation(100)
144 | self.X = self.X[p].reshape((100,-1))
145 | self.Y = self.Y[p].reshape((100,))
146 | self.X = np.array(self.X)
147 | self.Y = np.array(self.Y)
148 |
149 | def __getitem__(self,i):
150 | return (self.X[i],self.Y[i])
151 | def __len__(self):
152 | return self.X.shape[0]
153 |
154 |
155 | #### Ways of constructing invalid Rubik's cubes
156 | # (see https://ruwix.com/the-rubiks-cube/unsolvable-rubiks-cube-invalid-scramble/)
157 |
158 | def UBedge_flip(state):
159 |     """ Invalid edge flip on a Rubik's cube state of shape (6,48). """
160 | UB_edge = np.array([1,3*8+1]) # Top middle of U, top middle of B
161 | edge_flip = state.copy()
162 | edge_flip[:,UB_edge] = edge_flip[:,np.roll(UB_edge,1)]
163 | return edge_flip
164 |
165 | def ULBcorner_rot(state,i=1):
166 |     """ Invalid state obtained by rotating the ULB corner i times. """
167 | ULB_corner_ids = np.array([0,4*8,3*8+2]) # top left of U, top left of L, top right of B
168 | rotated_corner_state = state.copy()
169 | rotated_corner_state[:,ULB_corner_ids] = rotated_corner_state[:,np.roll(ULB_corner_ids,i)]
170 | return rotated_corner_state
171 |
172 | def LBface_swap(state):
173 | """ Invalid piece swap between L center top and B center top """
174 | L_B_center_top_faces = np.array([4*8+1,3*8+1])
175 | piece_swap = state.copy()
176 | piece_swap[:,L_B_center_top_faces] = piece_swap[:,np.roll(L_B_center_top_faces,1)]
177 | return piece_swap
178 |
179 |
180 | @export
181 | class BrokenRubiksCube(object):
182 |     """ Binary classification problem: predict whether a Rubik's cube configuration
183 |         is solvable, or 'broken' and unreachable by transformations from the group
184 |         (e.g. a cube where a corner was removed, twisted, and replaced).
185 |         The dataset is generated by taking several hand-identified simple instances of
186 |         solvable and unsolvable cubes and then scrambling them.
187 | 
188 |         Features are represented as 6xT(1) tensors of the Rubik's cube group (one hot per color)"""
189 | def __init__(self,train=True):
190 | super().__init__()
191 | # start with a valid configuration
192 |
193 | solved_state = np.zeros((6,48))
194 | for i in range(6):
195 | solved_state[i,8*i:8*(i+1)] = 1
196 |
197 | Id = lambda x: x
198 | transforms = [Id,itertools.product([Id,ULBcorner_rot,partial(ULBcorner_rot,i=2)],
199 | [Id,UBedge_flip],
200 | [Id,LBface_swap])]
201 | #equivalence_classes = np.vstack([t3(t2(t1(solved_state))) for t1,t2,t3 in transforms])
202 | labels = np.zeros((22,))
203 | labels[:11]=1 # First configuration is solvable
204 | labels[11:]=0 # all others are not
205 |         self.X = np.zeros((22,6*48)) # duplicate solvable example 11 times for convenience (class balance)
206 | self.X[:11] = solved_state.reshape(-1)#equivalence_classes.reshape(12,-1)[:1]
207 | parity_perm = np.array([5,3,4,1,2,0])
208 | self.X[11:] = solved_state.reshape(6,6,8)[:,parity_perm].reshape(-1)#equivalence_classes.reshape(12,-1)[1:]
209 | self.Y = labels
210 | self.symmetry = RubiksCube()
211 | self.rep_in = 6*Vector
212 | self.rep_out = 2*Scalar
213 | self.stats = (0,1)
214 | if train==False: # Scramble the cubes for test time
215 | gs = self.symmetry.samples(440)
216 | self.X = np.repeat(self.X,20,axis=0).reshape(440,6,48)@gs
217 | self.Y = np.repeat(self.Y,20,axis=0)
218 | p = np.random.permutation(440)
219 | self.X = self.X[p].reshape((440,-1))
220 | self.Y = self.Y[p].reshape((440,))
221 | self.X = np.array(self.X)
222 | self.Y = np.array(self.Y)
223 |
224 | def __getitem__(self,i):
225 | return (self.X[i],self.Y[i])
226 | def __len__(self):
227 | return self.X.shape[0]
228 |
229 | # @export
230 | # class BrokenRubiksCube2x2(object):
231 | # """ Binary classification problem of predicting whether a Rubik's cube configuration
232 | # is solvable or 'broken' and not able to be solved by transformations from the group
233 | # e.g. by removing and twisting a corner before replacing.
234 | # Dataset is generated by taking several hand identified simple instances of solvable
235 | # and unsolvable cubes, and then scrambling them.
236 |
237 | # Features are represented as 6xT(1) tensors of the Rubiks Group (one hot for each color)"""
238 | # def __init__(self,train=True):
239 | # super().__init__()
240 | # # start with a valid configuration
241 |
242 | # solved_state = np.zeros((6,24))
243 | # for i in range(6):
244 | # solved_state[i,4*i:4*(i+1)] = 1
245 |
246 | # ULB_corner_ids = np.array([0,4*4,3*4+2]) # top left of U, top left of L, top right of B
247 | # rotated_corner_state = solved_state.copy()
248 | # rotated_corner_state[:,ULB_corner_ids] = rotated_corner_state[:,np.roll(ULB_corner_ids,1)]
249 |
250 | # labels = np.zeros((2,))
251 | # labels[:1]=1 # First configuration is solvable
252 | # labels[1:]=0 # all others are not
253 | #         self.X = np.zeros((2,6*24)) # duplicate solvable example for convenience (balance)
254 | # self.X[0] = solved_state.reshape(-1)
255 | # parity_perm = np.array([5,3,4,1,2,0])
256 | # self.X[1] = rotated_corner_state.reshape(6,6,4)[:,parity_perm].reshape(-1)
257 | # self.Y = labels
258 | # self.X = np.repeat(self.X,10,axis=0)
259 | # self.Y = np.repeat(self.Y,10,axis=0)
260 | # self.symmetry = RubiksCube2x2()
261 | # self.rep_in = 6*Vector
262 | # self.rep_out = 2*Scalar
263 | # self.stats = (0,1)
264 | # if train==False: # Scramble the cubes for test time
265 | # N = 200
266 | # gs = self.symmetry.samples(N)
267 | # self.X = np.repeat(self.X,10,axis=0).reshape(N,6,24)@gs
268 | # self.Y = np.repeat(self.Y,10,axis=0)
269 | # p = np.random.permutation(N)
270 | # self.X = self.X[p].reshape((N,-1))
271 | # self.Y = self.Y[p].reshape((N,))
272 | # self.X = np.array(self.X)
273 | # self.Y = np.array(self.Y)
274 |
275 | # def __getitem__(self,i):
276 | # return (self.X[i],self.Y[i])
277 | # def __len__(self):
278 | # return self.X.shape[0]
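279 | 
280 | # --- Illustrative usage (editor's sketch, not from the original source) ---
281 | # Each dataset exposes (rep_in, rep_out, symmetry, stats), so an equivariant model
282 | # and a standardization wrapper can be built without hardcoding shapes. A minimal
283 | # sketch, assuming the objax EMLP and Standardize from emlp.nn:
284 | #
285 | # from emlp.nn import EMLP, Standardize
286 | # ds = ParticleInteraction(N=1024)
287 | # model = Standardize(EMLP(ds.rep_in, ds.rep_out, group=ds.symmetry), ds.stats)
288 | # x, y = ds[0]
289 | # pred = model(x[None], training=False)          # shape (1, 1)
290 | # # ds.default_aug(mlp) wraps a non-equivariant model with GroupAugmentation,
291 | # # which applies rho_out(g)^-1 . model . rho_in(g) for random group samples g.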
--------------------------------------------------------------------------------
/emlp/nn/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import pkgutil
3 | __all__ = [] # expose objax implementation as base nn
4 | module = importlib.import_module('.'+'objax',package=__name__)
5 | globals().update({k: getattr(module, k) for k in module.__all__})
6 | __all__ += module.__all__
7 |
--------------------------------------------------------------------------------
/emlp/nn/flax.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | from emlp.reps import T,Rep,Scalar
5 | from emlp.reps import bilinear_weights
6 | #from emlp.reps import LinearOperator # why does this not work?
7 | from emlp.reps.linear_operator_base import LinearOperator
8 | from emlp.reps.product_sum_reps import SumRep
9 | from emlp.groups import Group
10 | from emlp.utils import Named,export
11 | from flax import linen as nn
12 | import logging
13 | from emlp.nn import gated,gate_indices,uniform_rep
14 | from typing import Union,Iterable,Optional
15 | # def Sequential(*args):
16 | # """ Wrapped to mimic pytorch syntax"""
17 | # return nn.Sequential(args)
18 |
19 |
20 |
21 |
22 | @export
23 | def Linear(repin,repout):
24 | """ Basic equivariant Linear layer from repin to repout."""
25 | cin =repin.size()
26 | cout = repout.size()
27 | rep_W = repin>>repout
28 | Pw = rep_W.equivariant_projector()
29 | Pb = repout.equivariant_projector()
30 | logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}")
31 | return _Linear(Pw,Pb,cout)
32 |
33 | class _Linear(nn.Module):
34 | Pw:LinearOperator
35 | Pb:LinearOperator
36 | cout:int
37 | @nn.compact
38 | def __call__(self,x):
39 | w = self.param('w',nn.initializers.lecun_normal(),(self.cout,x.shape[-1]))
40 | b = self.param('b',nn.initializers.zeros,(self.cout,))
41 | W = (self.Pw@w.reshape(-1)).reshape(*w.shape)
42 | B = self.Pb@b
43 | return x@W.T+B
44 |
45 | @export
46 | def BiLinear(repin,repout):
47 | """ Cheap bilinear layer (adds parameters for each part of the input which can be
48 | interpreted as a linear map from a part of the input to the output representation)."""
49 | Wdim, weight_proj = bilinear_weights(repout,repin)
50 | #self.w = TrainVar(objax.random.normal((Wdim,)))#xavier_normal((Wdim,))) #TODO: revert to xavier
51 | logging.info(f"BiW components: dim:{Wdim}")
52 | return _BiLinear(Wdim,weight_proj)
53 |
54 | class _BiLinear(nn.Module):
55 | Wdim:int
56 | weight_proj:callable
57 |
58 | @nn.compact
59 | def __call__(self, x):
60 | w = self.param('w',nn.initializers.normal(),(self.Wdim,)) #TODO: change to standard normal
61 | W = self.weight_proj(w,x)
62 | out= .1*(W@x[...,None])[...,0]
63 | return out
64 |
65 |
66 | @export
67 | class GatedNonlinearity(nn.Module): #TODO: add support for mixed tensors and non sumreps
68 | """ Gated nonlinearity. Requires input to have the additional gate scalars
69 | for every non regular and non scalar rep. Applies swish to regular and
70 | scalar reps. (Right now assumes rep is a SumRep)"""
71 | rep:Rep
72 | def __call__(self,values):
73 | gate_scalars = values[..., gate_indices(self.rep)]
74 | activations = jax.nn.sigmoid(gate_scalars) * values[..., :self.rep.size()]
75 | return activations
76 |
77 | @export
78 | def EMLPBlock(rep_in,rep_out):
79 | """ Basic building block of EMLP consisting of G-Linear, biLinear,
80 | and gated nonlinearity. """
81 | linear = Linear(rep_in,gated(rep_out))
82 | bilinear = BiLinear(gated(rep_out),gated(rep_out))
83 | nonlinearity = GatedNonlinearity(rep_out)
84 | return _EMLPBlock(linear,bilinear,nonlinearity)
85 |
86 | class _EMLPBlock(nn.Module):
87 | linear:nn.Module
88 | bilinear:nn.Module
89 | nonlinearity:nn.Module
90 |
91 | def __call__(self,x):
92 | lin = self.linear(x)
93 | preact =self.bilinear(lin)+lin
94 | return self.nonlinearity(preact)
95 |
96 | @export
97 | def EMLP(rep_in,rep_out,group,ch=384,num_layers=3):
98 | """ Equivariant MultiLayer Perceptron.
99 | If the input ch argument is an int, uses the hands off uniform_rep heuristic.
100 | If the ch argument is a representation, uses this representation for the hidden layers.
101 | Individual layer representations can be set explicitly by using a list of ints or a list of
102 | representations, rather than use the same for each hidden layer.
103 |
104 | Args:
105 | rep_in (Rep): input representation
106 | rep_out (Rep): output representation
107 | group (Group): symmetry group
108 | ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers
109 | num_layers (int): number of hidden layers
110 |
111 | Returns:
112 |         Module: the EMLP flax module."""
113 | logging.info("Initing EMLP (flax)")
114 | rep_in = rep_in(group)
115 | rep_out = rep_out(group)
116 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]
117 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)]
118 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]
119 | reps = [rep_in]+middle_layers
120 | logging.info(f"Reps: {reps}")
121 | return Sequential(*[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],Linear(reps[-1],rep_out))
122 |
123 | def swish(x):
124 | return jax.nn.sigmoid(x)*x
125 |
126 | class _Sequential(nn.Module):
127 | modules:Iterable[callable]
128 | def __call__(self,x):
129 | for module in self.modules:
130 | x = module(x)
131 | return x
132 |
133 | def Sequential(*layers):
134 | return _Sequential(layers)
135 |
136 | def MLPBlock(cout):
137 | return Sequential(nn.Dense(cout),swish) # ,nn.BatchNorm0D(cout,momentum=.9),swish)#,
138 |
139 | @export
140 | class MLP(nn.Module,metaclass=Named):
141 | """ Standard baseline MLP. Representations and group are used for shapes only. """
142 | rep_in:Rep
143 | rep_out:Rep
144 | group:Group
145 |     ch: Optional[int]=384
146 | num_layers:Optional[int]=3
147 | def setup(self):
148 | logging.info("Initing MLP (flax)")
149 | cout = self.rep_out(self.group).size()
150 | self.modules = [MLPBlock(self.ch) for _ in range(self.num_layers)]+[nn.Dense(cout)]
151 | def __call__(self,x):
152 | for module in self.modules:
153 | x = module(x)
154 | return x
155 |
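156 | # --- Illustrative usage (editor's sketch, not from the original source) ---
157 | # EMLP here returns a flax linen Module, trained with the usual init/apply
158 | # pattern. Assuming emlp.groups.SO is available:
159 | #
160 | # from emlp.groups import SO
161 | # model = EMLP(T(1), T(0), group=SO(3), ch=128, num_layers=2)
162 | # x = jnp.zeros((8, 3))
163 | # params = model.init(jax.random.PRNGKey(0), x)
164 | # y = model.apply(params, x)                     # shape (8, 1)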
--------------------------------------------------------------------------------
/emlp/nn/haiku.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | from emlp.reps import Rep
5 | from emlp.reps import bilinear_weights
6 | from emlp.utils import export
7 | import logging
8 | import haiku as hk
9 | from emlp.nn import gated,gate_indices,uniform_rep
10 |
11 | def Sequential(*args):
12 | """ Wrapped to mimic pytorch syntax"""
13 | return lambda x: hk.Sequential(args)(x)
14 |
15 | @export
16 | def Linear(repin,repout):
17 | rep_W = repout << repin
18 | logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}")
19 | rep_bias = repout
20 | Pw = rep_W.equivariant_projector()
21 | Pb = rep_bias.equivariant_projector()
22 | return lambda x: hkLinear(Pw,Pb,(repout.size(),repin.size()))(x)
23 |
24 | class hkLinear(hk.Module):
25 | """ Basic equivariant Linear layer from repin to repout."""
26 | def __init__(self, Pw,Pb,shape,name=None):
27 | super().__init__(name=name)
28 | self.Pw = Pw
29 | self.Pb = Pb
30 | self.shape=shape
31 |
32 | def __call__(self, x): # (cin) -> (cout)
33 | i,j = self.shape
34 |         w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j)) # scale by fan-in (j = input size)
35 | w = hk.get_parameter("w", shape=self.shape, dtype=x.dtype, init=w_init)
36 | b = hk.get_parameter("b", shape=[i], dtype=x.dtype, init=w_init)
37 | W = (self.Pw@w.reshape(-1)).reshape(*self.shape)
38 | b = self.Pb@b
39 | return x@W.T+b
40 |
41 | @export
42 | def BiLinear(repin,repout):
43 | """ Cheap bilinear layer (adds parameters for each part of the input which can be
44 | interpreted as a linear map from a part of the input to the output representation)."""
45 | Wdim, weight_proj = bilinear_weights(repout,repin)
46 | return lambda x: hkBiLinear(weight_proj,Wdim)(x)
47 |
48 |
49 | class hkBiLinear(hk.Module):
50 | def __init__(self, weight_proj,Wdim,name=None):
51 | super().__init__(name=name)
52 | self.weight_proj=weight_proj
53 | self.Wdim=Wdim
54 |
55 | def __call__(self, x):
56 | # compatible with non sumreps? need to check
57 | w_init = hk.initializers.TruncatedNormal(1.)
58 | w = hk.get_parameter("w", shape=[self.Wdim], dtype=x.dtype, init=w_init)
59 | W = self.weight_proj(w,x)
60 | return .1*(W@x[...,None])[...,0]
61 |
62 | @export
63 | class GatedNonlinearity(object): # TODO: add support for mixed tensors and non sumreps
64 | """ Gated nonlinearity. Requires input to have the additional gate scalars
65 | for every non regular and non scalar rep. Applies swish to regular and
66 | scalar reps. (Right now assumes rep is a SumRep)"""
67 | def __init__(self,rep,name=None):
68 | super().__init__()
69 | self.rep=rep
70 |
71 | def __call__(self,values):
72 | gate_scalars = values[..., gate_indices(self.rep)]
73 | activations = jax.nn.sigmoid(gate_scalars) * values[..., :self.rep.size()]
74 | return activations
75 |
76 |
77 | @export
78 | def EMLPBlock(repin,repout):
79 | """ Basic building block of EMLP consisting of G-Linear, biLinear,
80 | and gated nonlinearity. """
81 | linear = Linear(repin,gated(repout))
82 | bilinear = BiLinear(gated(repout),gated(repout))
83 | nonlinearity = GatedNonlinearity(repout)
84 | def block(x):
85 | lin = linear(x)
86 | preact =bilinear(lin)+lin
87 | return nonlinearity(preact)
88 | return block
89 |
90 |
91 | @export
92 | def EMLP(rep_in,rep_out,group,ch=384,num_layers=3):
93 | """ Equivariant MultiLayer Perceptron.
94 | If the input ch argument is an int, uses the hands off uniform_rep heuristic.
95 | If the ch argument is a representation, uses this representation for the hidden layers.
96 | Individual layer representations can be set explicitly by using a list of ints or a list of
97 | representations, rather than use the same for each hidden layer.
98 |
99 | Args:
100 | rep_in (Rep): input representation
101 | rep_out (Rep): output representation
102 | group (Group): symmetry group
103 | ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers
104 | num_layers (int): number of hidden layers
105 |
106 | Returns:
107 |         Module: the EMLP haiku module."""
108 | logging.info("Initing EMLP (Haiku)")
109 | rep_in =rep_in(group)
110 | rep_out = rep_out(group)
111 | # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps
112 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]
113 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)]
114 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]
115 | # assert all((not rep.G is None) for rep in middle_layers[0].reps)
116 | reps = [rep_in]+middle_layers
117 | # logging.info(f"Reps: {reps}")
118 | network = Sequential(
119 | *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
120 | Linear(reps[-1],rep_out)
121 | )
122 | return network
123 |
124 | @export
125 | def MLP(rep_in,rep_out,group,ch=384,num_layers=3):
126 | """ Standard baseline MLP. Representations and group are used for shapes only. """
127 | cout = rep_out(group).size()
128 | mlp = lambda x: Sequential(
129 | *[Sequential(hk.Linear(ch),jax.nn.swish) for _ in range(num_layers)],
130 | hk.Linear(cout)
131 | )(x)
132 | return mlp
133 |
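134 | # --- Illustrative usage (editor's sketch, not from the original source) ---
135 | # The haiku EMLP builds hk.Modules at trace time, so it must be called inside
136 | # hk.transform. Assuming emlp.groups.SO and emlp.reps.T are available:
137 | #
138 | # from emlp.groups import SO
139 | # from emlp.reps import T
140 | # net = hk.without_apply_rng(hk.transform(
141 | #     lambda x: EMLP(T(1), T(0), SO(3), ch=128, num_layers=2)(x)))
142 | # x = jnp.zeros((8, 3))
143 | # params = net.init(jax.random.PRNGKey(0), x)
144 | # y = net.apply(params, x)                       # shape (8, 1)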
--------------------------------------------------------------------------------
/emlp/nn/objax.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import objax.nn as nn
4 | import objax.functional as F
5 | import numpy as np
6 | from emlp.reps import T,Rep,Scalar
7 | from emlp.reps import bilinear_weights
8 | from emlp.reps.product_sum_reps import SumRep
9 | import collections
10 | from emlp.utils import Named,export
11 | import scipy as sp
12 | import scipy.special
13 | import random
14 | import logging
15 | from objax.variable import TrainVar, StateVar
16 | from objax.nn.init import kaiming_normal, xavier_normal
17 | from objax.module import Module
18 | import objax
19 | from objax.nn.init import orthogonal
20 | from scipy.special import binom
21 | from jax import jit,vmap
22 | from functools import lru_cache as cache
23 |
24 | def Sequential(*args):
25 | """ Wrapped to mimic pytorch syntax"""
26 | return nn.Sequential(args)
27 |
28 | @export
29 | class Linear(nn.Linear):
30 | """ Basic equivariant Linear layer from repin to repout."""
31 | def __init__(self, repin, repout):
32 | nin,nout = repin.size(),repout.size()
33 | super().__init__(nin,nout)
34 | self.b = TrainVar(objax.random.uniform((nout,))/jnp.sqrt(nout))
35 | self.w = TrainVar(orthogonal((nout, nin)))
36 | self.rep_W = rep_W = repout*repin.T
37 |
38 | rep_bias = repout
39 | self.Pw = rep_W.equivariant_projector()
40 | self.Pb = rep_bias.equivariant_projector()
41 | logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}")
42 | def __call__(self, x): # (cin) -> (cout)
43 | logging.debug(f"linear in shape: {x.shape}")
44 | W = (self.Pw@self.w.value.reshape(-1)).reshape(*self.w.value.shape)
45 | b = self.Pb@self.b.value
46 | out = x@W.T+b
47 | logging.debug(f"linear out shape:{out.shape}")
48 | return out
49 |
50 | @export
51 | class BiLinear(Module):
52 | """ Cheap bilinear layer (adds parameters for each part of the input which can be
53 | interpreted as a linear map from a part of the input to the output representation)."""
54 | def __init__(self, repin, repout):
55 | super().__init__()
56 | Wdim, weight_proj = bilinear_weights(repout,repin)
57 | self.weight_proj = jit(weight_proj)
58 | self.w = TrainVar(objax.random.normal((Wdim,)))
59 | logging.info(f"BiW components: dim:{Wdim}")
60 |
61 | def __call__(self, x,training=True):
62 | # compatible with non sumreps? need to check
63 | W = self.weight_proj(self.w.value,x)
64 | out= .1*(W@x[...,None])[...,0]
65 | return out
66 |
67 | @export
68 | def gated(ch_rep:Rep) -> Rep:
69 | """ Returns the rep with an additional scalar 'gate' for each of the nonscalars and non regular
70 | reps in the input. To be used as the output for linear (and or bilinear) layers directly
71 | before a :func:`GatedNonlinearity` to produce its scalar gates. """
72 | if isinstance(ch_rep,SumRep):
73 | return ch_rep+sum([Scalar(rep.G) for rep in ch_rep if rep!=Scalar and not rep.is_permutation])
74 | else:
75 | return ch_rep+Scalar(ch_rep.G) if not ch_rep.is_permutation else ch_rep
76 |
77 | @export
78 | class GatedNonlinearity(Module):
79 | """ Gated nonlinearity. Requires input to have the additional gate scalars
80 | for every non regular and non scalar rep. Applies swish to regular and
81 | scalar reps. """
82 | def __init__(self,rep):
83 | super().__init__()
84 | self.rep=rep
85 | def __call__(self,values):
86 | gate_scalars = values[..., gate_indices(self.rep)]
87 | activations = jax.nn.sigmoid(gate_scalars) * values[..., :self.rep.size()]
88 | return activations
89 |
90 | @export
91 | class EMLPBlock(Module):
92 | """ Basic building block of EMLP consisting of G-Linear, biLinear,
93 | and gated nonlinearity. """
94 | def __init__(self,rep_in,rep_out):
95 | super().__init__()
96 | self.linear = Linear(rep_in,gated(rep_out))
97 | self.bilinear = BiLinear(gated(rep_out),gated(rep_out))
98 | self.nonlinearity = GatedNonlinearity(rep_out)
99 | def __call__(self,x):
100 | lin = self.linear(x)
101 | preact =self.bilinear(lin)+lin
102 | return self.nonlinearity(preact)
103 |
104 | def uniform_rep_general(ch,*rep_types):
105 | """ adds all combinations of (powers of) rep_types up to
106 | a total size of ch channels. """
107 | raise NotImplementedError
108 |
109 |
110 |
111 |
112 | @export
113 | def uniform_rep(ch,group):
114 | """ A heuristic method for allocating a given number of channels (ch)
115 | into tensor types. Attempts to distribute the channels evenly across
116 | the different tensor types. Useful for hands off layer construction.
117 |
118 | Args:
119 | ch (int): total number of channels
120 | group (Group): symmetry group
121 |
122 | Returns:
123 | SumRep: The direct sum representation with dim(V)=ch
124 | """
125 | d = group.d
126 | Ns = np.zeros((lambertW(ch,d)+1,),int) # number of tensors of each rank
127 | while ch>0:
128 | max_rank = lambertW(ch,d) # compute the max rank tensor that can fit up to
129 | Ns[:max_rank+1] += np.array([d**(max_rank-r) for r in range(max_rank+1)],dtype=int)
130 | ch -= (max_rank+1)*d**max_rank # compute leftover channels
131 | sum_rep = sum([binomial_allocation(nr,r,group) for r,nr in enumerate(Ns)])
132 | sum_rep,perm = sum_rep.canonicalize()
133 | return sum_rep
134 |
135 | def lambertW(ch,d):
136 |     """ Returns the largest integer x such that (x+1)*d^x <= ch."""
137 | max_rank=0
138 | while (max_rank+1)*d**max_rank <= ch:
139 | max_rank += 1
140 | max_rank -= 1
141 | return max_rank
142 |
143 | def binomial_allocation(N,rank,G):
144 |     """ Allocates N tensors of total rank r=(p+q) into
145 | T(k,r-k) for k=0,1,...,r to match the binomial distribution.
146 | For orthogonal representations there is no
147 | distinction between p and q, so this op is equivalent to N*T(rank)."""
148 | if N==0: return 0
149 | n_binoms = N//(2**rank)
150 | n_leftover = N%(2**rank)
151 | even_split = sum([n_binoms*int(binom(rank,k))*T(k,rank-k,G) for k in range(rank+1)])
152 | ps = np.random.binomial(rank,.5,n_leftover)
153 | ragged = sum([T(int(p),rank-int(p),G) for p in ps])
154 | out = even_split+ragged
155 | return out
156 |
157 | def uniform_allocation(N,rank):
158 |     """ Uniformly allocates N tensors of total rank r=(p+q) into
159 | T(k,r-k) for k=0,1,...,r. For orthogonal representations there is no
160 | distinction between p and q, so this op is equivalent to N*T(rank)."""
161 | if N==0: return 0
162 | even_split = sum((N//(rank+1))*T(k,rank-k) for k in range(rank+1))
163 | ragged = sum(random.sample([T(k,rank-k) for k in range(rank+1)],N%(rank+1)))
164 | return even_split+ragged
165 |
166 | @export
167 | class EMLP(Module,metaclass=Named):
168 | """ Equivariant MultiLayer Perceptron.
169 | If the input ch argument is an int, uses the hands off uniform_rep heuristic.
170 | If the ch argument is a representation, uses this representation for the hidden layers.
171 | Individual layer representations can be set explicitly by using a list of ints or a list of
172 | representations, rather than use the same for each hidden layer.
173 |
174 | Args:
175 | rep_in (Rep): input representation
176 | rep_out (Rep): output representation
177 | group (Group): symmetry group
178 | ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers
179 | num_layers (int): number of hidden layers
180 |
181 | Returns:
182 | Module: the EMLP objax module."""
183 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):#@
184 | super().__init__()
185 | logging.info("Initing EMLP (objax)")
186 | self.rep_in =rep_in(group)
187 | self.rep_out = rep_out(group)
188 |
189 | self.G=group
190 | # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps
191 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]#[uniform_rep(ch,group) for _ in range(num_layers)]
192 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)]
193 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]
194 | #assert all((not rep.G is None) for rep in middle_layers[0].reps)
195 | reps = [self.rep_in]+middle_layers
196 | logging.info(f"Reps: {reps}")
197 | self.network = Sequential(
198 | *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
199 | Linear(reps[-1],self.rep_out)
200 | )
201 | def __call__(self,x,training=True):
202 | return self.network(x)
203 |
204 | def swish(x):
205 | return jax.nn.sigmoid(x)*x
206 |
207 | def MLPBlock(cin,cout):
208 | return Sequential(nn.Linear(cin,cout),swish)#,nn.BatchNorm0D(cout,momentum=.9),swish)#,
209 |
210 | @export
211 | class MLP(Module,metaclass=Named):
212 | """ Standard baseline MLP. Representations and group are used for shapes only. """
213 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
214 | super().__init__()
215 | self.rep_in =rep_in(group)
216 | self.rep_out = rep_out(group)
217 | self.G = group
218 | chs = [self.rep_in.size()] + num_layers*[ch]
219 | cout = self.rep_out.size()
220 | logging.info("Initing MLP")
221 | self.net = Sequential(
222 | *[MLPBlock(cin,cout) for cin,cout in zip(chs,chs[1:])],
223 | nn.Linear(chs[-1],cout)
224 | )
225 | def __call__(self,x,training=True):
226 | y = self.net(x)
227 | return y
228 |
229 | @export
230 | class Standardize(Module):
231 | """ A convenience module to wrap a given module, normalize its input
232 | by some dataset x mean and std stats, and unnormalize its output by
233 | the dataset y mean and std stats.
234 |
235 | Args:
236 | model (Module): model to wrap
237 | ds_stats ((μx,σx,μy,σy) or (μx,σx)): tuple of the normalization stats
238 |
239 | Returns:
240 | Module: Wrapped model with input normalization (and output unnormalization)"""
241 | def __init__(self,model,ds_stats):
242 | super().__init__()
243 | self.model = model
244 | self.ds_stats=ds_stats
245 | def __call__(self,x,training):
246 | if len(self.ds_stats)==2:
247 | muin,sin = self.ds_stats
248 | return self.model((x-muin)/sin,training=training)
249 | else:
250 | muin,sin,muout,sout = self.ds_stats
251 | y = sout*self.model((x-muin)/sin,training=training)+muout
252 | return y
253 |
254 |
255 |
256 | # Networks for hamiltonian dynamics (need to sum for batched Hamiltonian grads)
257 | @export
258 | class MLPode(Module,metaclass=Named):
259 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
260 | super().__init__()
261 | self.rep_in =rep_in(group)
262 | self.rep_out = rep_out(group)
263 | self.G = group
264 | chs = [self.rep_in.size()] + num_layers*[ch]
265 | cout = self.rep_out.size()
266 | logging.info("Initing MLP")
267 | self.net = Sequential(
268 | *[Sequential(nn.Linear(cin,cout),swish) for cin,cout in zip(chs,chs[1:])],
269 | nn.Linear(chs[-1],cout)
270 | )
271 | def __call__(self,z,t):
272 | return self.net(z)
273 |
274 | @export
275 | class EMLPode(EMLP):
276 | """ Neural ODE Equivariant MLP. Same args as EMLP."""
277 | #__doc__ += EMLP.__doc__.split('.')[1]
278 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):#@
279 | #super().__init__()
280 | logging.info("Initing EMLP")
281 | self.rep_in =rep_in(group)
282 | self.rep_out = rep_out(group)
283 | self.G=group
284 | # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps
285 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]#[uniform_rep(ch,group) for _ in range(num_layers)]
286 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)]
287 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]
288 | #print(middle_layers[0].reps[0].G)
289 | #print(self.rep_in.G)
290 | reps = [self.rep_in]+middle_layers
291 | logging.info(f"Reps: {reps}")
292 | self.network = Sequential(
293 | *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
294 | Linear(reps[-1],self.rep_out)
295 | )
296 | def __call__(self,z,t):
297 | return self.network(z)
298 |
299 | # Networks for hamiltonian dynamics (need to sum for batched Hamiltonian grads)
300 | @export
301 | class MLPH(Module,metaclass=Named):
302 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
303 | super().__init__()
304 | self.rep_in =rep_in(group)
305 | self.rep_out = rep_out(group)
306 | self.G = group
307 | chs = [self.rep_in.size()] + num_layers*[ch]
308 | cout = self.rep_out.size()
309 | logging.info("Initing MLP")
310 | self.net = Sequential(
311 | *[Sequential(nn.Linear(cin,cout),swish) for cin,cout in zip(chs,chs[1:])],
312 | nn.Linear(chs[-1],cout)
313 | )
314 | def H(self,x):#,training=True):
315 | y = self.net(x).sum()
316 | return y
317 | def __call__(self,x):
318 | return self.H(x)
319 |
320 | @export
321 | class EMLPH(EMLP):
322 | """ Equivariant EMLP modeling a Hamiltonian for HNN. Same args as EMLP"""
323 | #__doc__ += EMLP.__doc__.split('.')[1]
324 | def H(self,x):#,training=True):
325 | y = self.network(x)
326 | return y.sum()
327 | def __call__(self,x):
328 | return self.H(x)
329 |
330 | @export
331 | @cache(maxsize=None)
332 | def gate_indices(ch_rep:Rep) -> jnp.ndarray:
333 | """ Indices for scalars, and also additional scalar gates
334 | added by gated(sumrep)"""
335 | channels = ch_rep.size()
336 | perm = ch_rep.perm
337 | indices = np.arange(channels)
338 |
339 |     if not isinstance(ch_rep,SumRep): # If just a single rep, only one scalar gate at the end
340 |         return indices if ch_rep.is_permutation else np.full(ch_rep.size(),ch_rep.size(),dtype=int)
341 |
342 | num_nonscalars = 0
343 | i=0
344 | for rep in ch_rep:
345 | if rep!=Scalar and not rep.is_permutation:
346 | indices[perm[i:i+rep.size()]] = channels+num_nonscalars
347 | num_nonscalars+=1
348 | i+=rep.size()
349 | return indices
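350 | 
351 | # --- Worked example for uniform_rep (editor's note) ---
352 | # For ch=384 and group=SO(3) (d=3), the greedy loop yields Ns=[102,34,11,3],
353 | # i.e. roughly 102*T(0) + 34*T(1) + 11*T(2) + 3*T(3), with total dimension
354 | # 102*1 + 34*3 + 11*9 + 3*27 = 384.
355 | #
356 | # --- Illustrative end-to-end usage (editor's sketch, not from the original source) ---
357 | # An O(5)-equivariant model mapping two vectors to a scalar (cf. the O5Synthetic dataset):
358 | # from emlp.groups import O
359 | # model = EMLP(2*T(1), T(0), group=O(5), ch=384, num_layers=3)
360 | # x = jnp.ones((8, 10))                          # two 5-dim vectors per example
361 | # y = model(x)                                   # shape (8, 1)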
--------------------------------------------------------------------------------
/emlp/nn/pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | import jax
4 | from jax import jit
5 | from jax.tree_util import tree_flatten, tree_unflatten
6 | import types
7 | import copy
8 | import jax.numpy as jnp
9 | import numpy as np
10 | import types
11 | from functools import partial
12 | from emlp.reps import T,Rep,Scalar
13 | from emlp.reps import bilinear_weights
14 | from emlp.utils import Named,export
15 | import logging
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from emlp.nn import gated,gate_indices,uniform_rep
20 |
21 | def torch2jax(arr):
22 | if isinstance(arr,torch.Tensor):
23 | return jnp.asarray(arr.cpu().data.numpy())
24 | else:
25 | return arr
26 |
27 | def jax2torch(arr):
28 | if isinstance(arr,(jnp.ndarray,np.ndarray)):
29 | if jax.devices()[0].platform=='gpu':
30 | device = torch.device('cuda')
31 | else: device = torch.device('cpu')
32 | return torch.from_numpy(np.array(arr)).to(device)
33 | else:
34 | return arr
35 |
36 | def to_jax(pytree):
37 | flat_values, tree_type = tree_flatten(pytree)
38 | transformed_flat = [torch2jax(v) for v in flat_values]
39 | return tree_unflatten(tree_type, transformed_flat)
40 |
41 |
42 | def to_pytorch(pytree):
43 | flat_values, tree_type = tree_flatten(pytree)
44 | transformed_flat = [jax2torch(v) for v in flat_values]
45 | return tree_unflatten(tree_type, transformed_flat)
46 |
47 | @export
48 | def torchify_fn(function):
49 | """ A method to enable interopability between jax and pytorch autograd.
50 | Calling torchify on a given function that has pytrees of jax.ndarray
51 | objects as inputs and outputs will return a function that first converts
52 | the inputs to jax, runs through the jax function, and converts the output
53 | back to torch but preserving the gradients of the operation to be called
54 | with pytorch autograd. """
55 | vjp = jit(lambda *args: jax.vjp(function,*args))
56 | class torched_fn(Function):
57 | @staticmethod
58 | def forward(ctx,*args):
59 | if any(ctx.needs_input_grad):
60 | y,ctx.vjp_fn = vjp(*to_jax(args))#jax.vjp(function,*to_jax(args))
61 | return to_pytorch(y)
62 | return to_pytorch(function(*to_jax(args)))
63 | @staticmethod
64 | def backward(ctx,*grad_outputs):
65 | return to_pytorch(ctx.vjp_fn(*to_jax(grad_outputs)))
66 | return torched_fn.apply #TORCHED #Roasted
67 |
68 |
69 | @export
70 | class Linear(nn.Linear):
71 | """ Basic equivariant Linear layer from repin to repout."""
72 | def __init__(self, repin, repout):
73 | nin,nout = repin.size(),repout.size()
74 | super().__init__(nin,nout)
75 | rep_W = repout*repin.T
76 | rep_bias = repout
77 | Pw = rep_W.equivariant_projector()
78 | Pb = rep_bias.equivariant_projector()
79 | self.proj_b = torchify_fn(jit(lambda b: Pb@b))
80 | self.proj_w = torchify_fn(jit(lambda w:(Pw@w.reshape(-1)).reshape(nout,nin)))
81 | logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}")
82 |
83 | def forward(self, x): # (cin) -> (cout)
84 | return F.linear(x,self.proj_w(self.weight),self.proj_b(self.bias))
85 |
86 | @export
87 | class BiLinear(nn.Module):
88 | """ Cheap bilinear layer (adds parameters for each part of the input which can be
89 | interpreted as a linear map from a part of the input to the output representation)."""
90 | def __init__(self, repin, repout):
91 | super().__init__()
92 | Wdim, weight_proj = bilinear_weights(repout,repin)
93 | self.weight_proj = torchify_fn(jit(weight_proj))
94 | self.bi_params = nn.Parameter(torch.randn(Wdim))
95 | logging.info(f"BiW components: dim:{Wdim}")
96 |
97 | def forward(self, x,training=True):
98 | # compatible with non sumreps? need to check
99 | W = self.weight_proj(self.bi_params,x)
100 | out= .1*(W@x[...,None])[...,0]
101 | return out
102 |
103 | @export
104 | class GatedNonlinearity(nn.Module): #TODO: add support for mixed tensors and non sumreps
105 | """ Gated nonlinearity. Requires input to have the additional gate scalars
106 | for every non regular and non scalar rep. Applies swish to regular and
107 | scalar reps. (Right now assumes rep is a SumRep)"""
108 | def __init__(self,rep):
109 | super().__init__()
110 | self.rep=rep
111 | def forward(self,values):
112 | gate_scalars = values[..., gate_indices(self.rep)]
113 | activations = gate_scalars.sigmoid() * values[..., :self.rep.size()]
114 | return activations
115 |
116 | @export
117 | class EMLPBlock(nn.Module):
118 | """ Basic building block of EMLP consisting of G-Linear, biLinear,
119 | and gated nonlinearity. """
120 | def __init__(self,rep_in,rep_out):
121 | super().__init__()
122 | self.linear = Linear(rep_in,gated(rep_out))
123 | self.bilinear = BiLinear(gated(rep_out),gated(rep_out))
124 | self.nonlinearity = GatedNonlinearity(rep_out)
125 |
126 | def forward(self,x):
127 | lin = self.linear(x)
128 | preact =self.bilinear(lin)+lin
129 | return self.nonlinearity(preact)
130 |
131 | @export
132 | class EMLP(nn.Module):
133 | """ Equivariant MultiLayer Perceptron.
134 | If the input ch argument is an int, uses the hands off uniform_rep heuristic.
135 | If the ch argument is a representation, uses this representation for the hidden layers.
136 | Individual layer representations can be set explicitly by using a list of ints or a list of
137 | representations, rather than use the same for each hidden layer.
138 |
139 | Args:
140 | rep_in (Rep): input representation
141 | rep_out (Rep): output representation
142 | group (Group): symmetry group
143 | ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers
144 | num_layers (int): number of hidden layers
145 |
146 | Returns:
147 |         Module: the EMLP pytorch module."""
148 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
149 | super().__init__()
150 | logging.info("Initing EMLP (PyTorch)")
151 | self.rep_in =rep_in(group)
152 | self.rep_out = rep_out(group)
153 |
154 | self.G=group
155 | # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps
156 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]#[uniform_rep(ch,group) for _ in range(num_layers)]
157 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)]
158 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]
159 | #assert all((not rep.G is None) for rep in middle_layers[0].reps)
160 | reps = [self.rep_in]+middle_layers
161 | #logging.info(f"Reps: {reps}")
162 | self.network = nn.Sequential(
163 | *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
164 | Linear(reps[-1],self.rep_out)
165 | )
166 | def forward(self,x):
167 | return self.network(x)
168 |
169 | class Swish(nn.Module):
170 | def forward(self,x):
171 | return x.sigmoid()*x
172 |
173 | def MLPBlock(cin,cout):
174 | return nn.Sequential(nn.Linear(cin,cout),Swish())#,nn.BatchNorm0D(cout,momentum=.9),swish)#,
175 |
176 | @export
177 | class MLP(nn.Module):
178 | """ Standard baseline MLP. Representations and group are used for shapes only. """
179 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
180 | super().__init__()
181 | self.rep_in =rep_in(group)
182 | self.rep_out = rep_out(group)
183 | self.G = group
184 | chs = [self.rep_in.size()] + num_layers*[ch]
185 | cout = self.rep_out.size()
186 | logging.info("Initing MLP")
187 | self.net = nn.Sequential(
188 | *[MLPBlock(cin,cout) for cin,cout in zip(chs,chs[1:])],
189 | nn.Linear(chs[-1],cout)
190 | )
191 |
192 | def forward(self,x):
193 | y = self.net(x)
194 | return y
195 |
196 | @export
197 | class Standardize(nn.Module):
198 | """ A convenience module to wrap a given module, normalize its input
199 | by some dataset x mean and std stats, and unnormalize its output by
200 | the dataset y mean and std stats.
201 |
202 | Args:
203 | model (Module): model to wrap
204 | ds_stats ((μx,σx,μy,σy) or (μx,σx)): tuple of the normalization stats
205 |
206 | Returns:
207 | Module: Wrapped model with input normalization (and output unnormalization)"""
208 | def __init__(self,model,ds_stats):
209 | super().__init__()
210 | self.model = model
211 | self.ds_stats=ds_stats
212 |
213 | def forward(self,x,training):
214 | if len(self.ds_stats)==2:
215 | muin,sin = self.ds_stats
216 | return self.model((x-muin)/sin,training=training)
217 | else:
218 | muin,sin,muout,sout = self.ds_stats
219 | y = sout*self.model((x-muin)/sin,training=training)+muout
220 | return y
221 |
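222 | # --- Illustrative usage (editor's sketch, not from the original source) ---
223 | # torchify_fn lets a jax function participate in pytorch autograd:
224 | # f = lambda x: jnp.sin(x).sum()
225 | # torch_f = torchify_fn(f)
226 | # x = torch.randn(3, requires_grad=True)
227 | # torch_f(x).backward()                          # gradient flows via the jax vjp
228 | #
229 | # The EMLP class is used like any other nn.Module. Assuming emlp.groups.SO:
230 | # from emlp.groups import SO
231 | # model = EMLP(T(1), T(0), group=SO(3), ch=128, num_layers=2)
232 | # y = model(torch.randn(8, 3))                   # shape (8, 1)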
--------------------------------------------------------------------------------
/emlp/reps/__init__.py:
--------------------------------------------------------------------------------
1 | # from .product_sum_reps import SumRep,DeferredSumRep,ProductRep,DeferredProductRep,DirectProduct
2 | # __all__=["SumRep","DeferredSumRep","ProductRep","DeferredProductRep","DirectProduct"]
3 | import importlib
4 | import pkgutil
5 | __all__ = []
6 | for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
7 | module = importlib.import_module('.'+module_name,package=__name__)
8 | try:
9 | globals().update({k: getattr(module, k) for k in module.__all__})
10 | __all__ += module.__all__
11 | except AttributeError: continue
12 |
13 | # concatenate __all__ from each of the modules
--------------------------------------------------------------------------------
/emlp/reps/linear_operators.py:
--------------------------------------------------------------------------------
1 | from .linear_operator_base import LinearOperator,Lazy
2 | import jax.numpy as jnp
3 | import numpy as np
4 | from jax import jit
5 | import jax
6 | from functools import reduce
7 |
8 | product = lambda c: reduce(lambda a,b:a*b,c)
9 |
10 | def lazify(x):
11 | if isinstance(x,LinearOperator): return x
12 | elif isinstance(x,(jnp.ndarray,np.ndarray)): return Lazy(x)
13 | else: raise NotImplementedError
14 |
15 | def densify(x):
16 | if isinstance(x,LinearOperator): return x.to_dense()
17 | elif isinstance(x,(jnp.ndarray,np.ndarray)): return x
18 | else: raise NotImplementedError
19 |
20 | class I(LinearOperator):
21 | def __init__(self,d):
22 | shape = (d,d)
23 | super().__init__(None, shape)
24 | def _matmat(self,V): #(c,k)
25 | return V
26 | def _matvec(self,V):
27 | return V
28 | def _adjoint(self):
29 | return self
30 | def invT(self):
31 | return self
32 |
33 | class LazyKron(LinearOperator):
34 |
35 | def __init__(self,Ms):
36 | self.Ms = Ms
37 | shape = product([Mi.shape[0] for Mi in Ms]), product([Mi.shape[1] for Mi in Ms])
38 | #self.dtype=Ms[0].dtype
39 | super().__init__(None,shape)
40 |
41 | def _matvec(self,v):
42 | return self._matmat(v).reshape(-1)
43 | def _matmat(self,v):
44 | ev = v.reshape(*[Mi.shape[-1] for Mi in self.Ms],-1)
45 | for i,M in enumerate(self.Ms):
46 | ev_front = jnp.moveaxis(ev,i,0)
47 | Mev_front = (M@ev_front.reshape(M.shape[-1],-1)).reshape(M.shape[0],*ev_front.shape[1:])
48 | ev = jnp.moveaxis(Mev_front,0,i)
49 | return ev.reshape(self.shape[0],ev.shape[-1])
50 | def _adjoint(self):
51 | return LazyKron([Mi.T for Mi in self.Ms])
52 | def invT(self):
53 | return LazyKron([M.invT() for M in self.Ms])
54 | def to_dense(self):
55 | Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms]
56 | return reduce(jnp.kron,Ms)
57 | def __new__(cls,Ms):
58 | if len(Ms)==1: return Ms[0]
59 | return super().__new__(cls)
60 |
61 | #@jit
62 | def kronsum(A,B):
63 | return jnp.kron(A,jnp.eye(B.shape[-1])) + jnp.kron(jnp.eye(A.shape[-1]),B)
64 |
65 |
66 | class LazyKronsum(LinearOperator):
67 |
68 | def __init__(self,Ms):
69 | self.Ms = Ms
70 | shape = product([Mi.shape[0] for Mi in Ms]), product([Mi.shape[1] for Mi in Ms])
71 | #self.dtype=Ms[0].dtype
72 | dtype=jnp.dtype('float32')
73 | super().__init__(dtype,shape)
74 |
75 | def _matvec(self,v):
76 | return self._matmat(v).reshape(-1)
77 |
78 | def _matmat(self,v):
79 | ev = v.reshape(*[Mi.shape[-1] for Mi in self.Ms],-1)
80 | out = 0*ev
81 | for i,M in enumerate(self.Ms):
82 | ev_front = jnp.moveaxis(ev,i,0)
83 | Mev_front = (M@ev_front.reshape(M.shape[-1],-1)).reshape(M.shape[0],*ev_front.shape[1:])
84 | out += jnp.moveaxis(Mev_front,0,i)
85 | return out.reshape(self.shape[0],ev.shape[-1])
86 |
87 | def _adjoint(self):
88 | return LazyKronsum([Mi.T for Mi in self.Ms])
89 | def to_dense(self):
90 | Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms]
91 | return reduce(kronsum,Ms)
92 | def __new__(cls,Ms):
93 | if len(Ms)==1: return Ms[0]
94 | return super().__new__(cls)
95 | ## could also be implemented as follows, but fusing the sum into a single LinearOperator is faster
96 | # def lazy_kronsum(Ms):
97 | # n = len(Ms)
98 | # lprod = np.cumprod([1]+[mi.shape[-1] for mi in Ms])
99 | # rprod = np.cumprod([1]+[mi.shape[-1] for mi in reversed(Ms)])[::-1]
100 | # return reduce(lambda a,b: a+b,[lazy_kron([I(lprod[i]),Mi,I(rprod[i+1])]) for i,Mi in enumerate(Ms)])
101 |
102 |
103 | class LazyJVP(LinearOperator): # lazy Jacobian-vector product: d/dt operator_fn(X+t*TX)|_{t=0}, applied to v
104 |     def __init__(self,operator_fn,X,TX):
105 |         self.shape = operator_fn(X).shape
106 |         self.jvp = lambda v: jax.jvp(lambda x: operator_fn(x)@v,[X],[TX])[1] # renamed from vjp: this is a jvp
107 |         self.jvp_T = lambda v: jax.jvp(lambda x: operator_fn(x).T@v,[X],[TX])[1]
108 |         self.dtype=jnp.dtype('float32')
109 |     def _matmat(self,v):
110 |         return self.jvp(v)
111 |     def _matvec(self,v):
112 |         return self.jvp(v)
113 |     def _rmatmat(self,v):
114 |         return self.jvp_T(v)
115 |
116 |
117 | class ConcatLazy(LinearOperator):
118 | """ Produces a linear operator equivalent to concatenating
119 | a collection of matrices Ms along axis=0 """
120 | def __init__(self,Ms):
121 | self.Ms = Ms
122 |         assert all(M.shape[1]==Ms[0].shape[1] for M in Ms),\
123 |             f"Trying to concatenate matrices with different numbers of columns {[M.shape for M in Ms]}"
124 | shape = (sum(M.shape[0] for M in Ms),Ms[0].shape[1])
125 | super().__init__(None,shape)
126 |
127 | def _matmat(self,V):
128 | return jnp.concatenate([M@V for M in self.Ms],axis=0)
129 |     def _rmatmat(self,V):
130 |         Vs = jnp.split(V,np.cumsum([M.shape[0] for M in self.Ms])[:-1]) # split at each block's row boundary
131 |         return sum([self.Ms[i].T@Vs[i] for i in range(len(self.Ms))])
132 | def to_dense(self):
133 | dense_Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms]
134 | return jnp.concatenate(dense_Ms,axis=0)
135 |
136 | class LazyDirectSum(LinearOperator):
137 | def __init__(self,Ms,multiplicities=None):
138 | self.Ms = [jax.device_put(M.astype(np.float32)) if isinstance(M,(np.ndarray)) else M for M in Ms]
139 | self.multiplicities = [1 for M in Ms] if multiplicities is None else multiplicities
140 |         shape = (sum(Mi.shape[0]*c for Mi,c in zip(Ms,self.multiplicities)),
141 |                  sum(Mi.shape[1]*c for Mi,c in zip(Ms,self.multiplicities))) # use self.multiplicities: the arg may be None
142 | super().__init__(None,shape)
143 | #self.dtype=Ms[0].dtype
144 | #self.dtype=jnp.dtype('float32')
145 |
146 | def _matvec(self,v):
147 | return lazy_direct_matmat(v,self.Ms,self.multiplicities)
148 |
149 | def _matmat(self,v): # (n,k)
150 | return lazy_direct_matmat(v,self.Ms,self.multiplicities)
151 |     def _adjoint(self):
152 |         return LazyDirectSum([Mi.T for Mi in self.Ms],self.multiplicities)
153 |     def invT(self):
154 |         return LazyDirectSum([M.invT() for M in self.Ms],self.multiplicities)
155 | def to_dense(self):
156 | Ms_all = [M for M,c in zip(self.Ms,self.multiplicities) for _ in range(c)]
157 | Ms_all = [Mi.to_dense() if isinstance(Mi,LinearOperator) else Mi for Mi in Ms_all]
158 | return jax.scipy.linalg.block_diag(*Ms_all)
159 | # def __new__(cls,Ms,multiplicities=None):
160 | # if len(Ms)==1 and multiplicities is None: return Ms[0]
161 | # return super().__new__(cls)
162 |
163 | def lazy_direct_matmat(v,Ms,mults):
164 | n = v.shape[0]
165 | k = v.shape[1] if len(v.shape)>1 else 1
166 | i=0
167 | y = []
168 | for M, multiplicity in zip(Ms,mults):
169 | i_end = i+multiplicity*M.shape[-1]
170 | elems = M@v[i:i_end].T.reshape(k*multiplicity,M.shape[-1]).T
171 | y.append(elems.T.reshape(k,multiplicity*M.shape[0]).T)
172 | i = i_end
173 | y = jnp.concatenate(y,axis=0) #concatenate over rep axis
174 | return y
175 |
176 |
177 | class LazyPerm(LinearOperator):
178 | def __init__(self,perm):
179 | self.perm=perm
180 | shape = (len(perm),len(perm))
181 | super().__init__(None,shape)
182 |
183 | def _matmat(self,V):
184 | return V[self.perm]
185 | def _matvec(self,V):
186 | return V[self.perm]
187 | def _adjoint(self):
188 | return LazyPerm(np.argsort(self.perm))
189 | def invT(self):
190 | return self
191 |
192 | class LazyShift(LinearOperator):
193 | def __init__(self,n,k=1):
194 | self.k=k
195 | shape = (n,n)
196 | super().__init__(None,shape)
197 |
198 | def _matmat(self,V): #(c,k) #Still needs to be tested??
199 | return jnp.roll(V,self.k,axis=0)
200 | def _matvec(self,V):
201 | return jnp.roll(V,self.k,axis=0)
202 | def _adjoint(self):
203 | return LazyShift(self.shape[0],-self.k)
204 | def invT(self):
205 | return self
206 |
207 | class SwapMatrix(LinearOperator):
208 | def __init__(self,swaprows,n):
209 | self.swaprows=swaprows
210 | shape = (n,n)
211 | super().__init__(None,shape)
212 | def _matmat(self,V): #(c,k)
213 | V = jax.ops.index_update(V, jax.ops.index[self.swaprows], V[self.swaprows[::-1]])
214 | return V
215 | def _matvec(self,V):
216 | return self._matmat(V)
217 | def _adjoint(self):
218 | return self
219 | def invT(self):
220 | return self
221 |
222 | class Rot90(LinearOperator):
223 | def __init__(self,n,k):
224 | shape = (n*n,n*n)
225 | self.n=n
226 | self.k = k
227 | super().__init__(None,shape)
228 | def _matmat(self,V): #(c,k)
229 | return jnp.rot90(V.reshape((self.n,self.n,-1)),self.k).reshape(V.shape)
230 | def _matvec(self,V):
231 | return jnp.rot90(V.reshape((self.n,self.n,-1)),self.k).reshape(V.shape)
232 | def invT(self):
233 | return self
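234 | 
235 | # --- Illustrative sanity checks (editor's sketch, not from the original source) ---
236 | # LazyKron applies (M1 (x) M2 (x) ...) to a vector by reshaping it into a tensor
237 | # and contracting each factor along its own axis, never forming the dense product:
238 | # A, B = np.random.randn(2,2), np.random.randn(3,3)
239 | # K = LazyKron([lazify(jnp.asarray(A)), lazify(jnp.asarray(B))])
240 | # V = np.random.randn(6,2)
241 | # assert np.allclose(K@V, np.kron(A,B)@V, atol=1e-4)
242 | #
243 | # LazyDirectSum acts as a block-diagonal matrix:
244 | # D = LazyDirectSum([jnp.asarray(A), jnp.asarray(B)])
245 | # W = np.random.randn(5,2)
246 | # assert np.allclose(D@W, D.to_dense()@W, atol=1e-4)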
--------------------------------------------------------------------------------
/emlp/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | class Named(type):
4 | def __str__(self):
5 | return self.__name__
6 | def __repr__(self):
7 | return self.__name__
8 |
9 | def export(fn):
10 | mod = sys.modules[fn.__module__]
11 | if hasattr(mod, '__all__'):
12 | mod.__all__.append(fn.__name__)
13 | else:
14 | mod.__all__ = [fn.__name__]
15 | return fn
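16 | 
17 | # Example (editor's note): @export just registers the decorated object in its
18 | # defining module's __all__, so `from module import *` and the docs pick it up:
19 | #
20 | # @export
21 | # def foo(): ...
22 | # # sys.modules[foo.__module__].__all__ now contains "foo"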
--------------------------------------------------------------------------------
/experiments/data_efficiency.py:
--------------------------------------------------------------------------------
1 | from emlp.nn import MLP,EMLP#,LinearBNSwish
2 | from emlp.datasets import O5Synthetic, ParticleInteraction, Inertia
3 | from emlp.reps import T,Scalar
4 | from emlp.groups import SO, O, Trivial, O13, SO13, SO13p
5 | from oil.tuning.study import train_trial, Study
6 |
7 | from oil.tuning.args import argupdated_config
8 | from experiments.train_regression import makeTrainer
9 | import copy
10 | import emlp.datasets
11 |
12 | if __name__=="__main__":
13 | Trial = train_trial(makeTrainer)
14 | config_spec = copy.deepcopy(makeTrainer.__kwdefaults__)
15 | name = 'data_efficiency_nobn'
16 | config_spec['ndata'] = 30000+5000+1000
17 | config_spec['log_level'] = 'warning'
18 | # Run MLP baseline on datasets
19 | config_spec.update({
20 | 'dataset':ParticleInteraction,#[O5Synthetic,Inertia,ParticleInteraction],
21 | 'network':MLP,'aug':[False,True],
22 | 'num_epochs':(lambda cfg: min(int(30*30000/cfg['split']['train']),1000)),
23 | 'split':{'train':[30,100,300,1000,3000,10000,30000],'test':5000,'val':1000},
24 | })
25 |     config_spec = argupdated_config(config_spec,namespace=emlp.datasets)
26 | name = f"{name}_{config_spec['dataset']}"
27 | thestudy = Study(Trial,{},study_name=name,base_log_dir=config_spec['trainer_config'].get('log_dir',None))
28 | thestudy.run(num_trials=-3,new_config_spec=config_spec,ordered=True)
29 |
30 | # Now run the EMLP (with appropriate group) on the datasets
31 | config_spec['network']=EMLP
32 | config_spec['aug'] = False
33 | groups = {O5Synthetic:[SO(5),O(5)],Inertia:[SO(3),O(3)],ParticleInteraction:[SO13p(),SO13(),O13()]}
34 | config_spec['net_config']['group'] = groups[config_spec['dataset']]
35 | thestudy.run(num_trials=-3,new_config_spec=config_spec,ordered=True)
36 | print(thestudy.results_df())
37 |
--------------------------------------------------------------------------------
/experiments/datasets/batchnorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import copy
4 | import numpy as np
5 | import objax.nn as nn # note: shadows the torch.nn import above; objax layers are what's used below
6 | import jax
7 | from jax import jit
8 | import jax.numpy as jnp
9 | from emlp.reps import Scalar
10 | from emlp.reps.product_sum_reps import SumRep
11 | import logging
12 | import objax.functional as F
13 | from functools import partial
14 | import objax
15 | from functools import lru_cache as cache
16 | from emlp.utils import export # needed for the @export decorator used below
17 |
18 |
19 | @cache(maxsize=None)
20 | def gate_indices(sumrep): #TODO: add support for mixed_tensors
21 | """ Indices for scalars, and also additional scalar gates
22 | added by gated(sumrep)"""
23 | assert isinstance(sumrep,SumRep), f"unexpected type for gate indices {type(sumrep)}"
24 | channels = sumrep.size()
25 | perm = sumrep.perm
26 | indices = np.arange(channels)
27 | num_nonscalars = 0
28 | i=0
29 | for rep in sumrep:
30 | if rep!=Scalar and not rep.is_regular:
31 | indices[perm[i:i+rep.size()]] = channels+num_nonscalars
32 | num_nonscalars+=1
33 | i+=rep.size()
34 | return indices
35 |
36 | @cache(maxsize=None)
37 | def scalar_mask(sumrep):
38 | channels = sumrep.size()
39 | perm = sumrep.perm
40 | mask = np.ones(channels)>0
41 | i=0
42 | for rep in sumrep:
43 | if rep!=Scalar: mask[perm[i:i+rep.size()]] = False
44 | i+=rep.size()
45 | return mask
46 |
47 | @cache(maxsize=None)
48 | def regular_mask(sumrep):
49 |     channels = sumrep.size()
50 |     perm = sumrep.perm # fix: perm was referenced below but never defined
51 |     mask = np.ones(channels)<0; i=0
52 | for rep in sumrep:
53 | if rep.is_regular: mask[perm[i:i+rep.size()]] = True
54 | i+=rep.size()
55 | return mask
56 |
57 | @export
58 | class TensorBN(nn.BatchNorm0D): #TODO: add suport for mixed tensors.
59 | """ Equivariant Batchnorm for tensor representations.
60 | Applies BN on Scalar channels and Mean only BN on others """
61 | def __init__(self,rep):
62 | super().__init__(rep.size(),momentum=0.9)
63 | self.rep=rep
64 | def __call__(self,x,training): #TODO: support elementwise for regular reps
65 | #return x #DISABLE BN, harms performance!! !!
66 | smask = jax.device_put(scalar_mask(self.rep))
67 | if training:
68 | m = x.mean(self.redux, keepdims=True)
69 | v = (x ** 2).mean(self.redux, keepdims=True) - m ** 2
70 | v = jnp.where(smask,v,ragged_gather_scatter((x ** 2).mean(self.redux),self.rep)) #in non scalar indices, divide by sum squared
71 | self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
72 | self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
73 | else:
74 | m, v = self.running_mean.value, self.running_var.value
75 | y = jnp.where(smask,self.gamma.value * (x - m) * F.rsqrt(v + self.eps) + self.beta.value,x*F.rsqrt(v+self.eps))#(x-m)*F.rsqrt(v + self.eps))
76 | return y # switch to or (x-m)
77 |
78 | class MaskBN(nn.BatchNorm0D):
79 | """ Equivariant Batchnorm for tensor representations.
80 | Applies BN on Scalar channels and Mean only BN on others """
81 | def __init__(self,ch):
82 | super().__init__(ch,momentum=0.9)
83 |
84 | def __call__(self,vals,mask,training=True):
85 | sum_dims = list(range(len(vals.shape[:-1])))
86 | x_or_zero = jnp.where(mask[...,None],vals,0*vals)
87 | if training:
88 | num_valid = mask.sum(sum_dims)
89 | m = x_or_zero.sum(sum_dims)/num_valid
90 | v = (x_or_zero ** 2).sum(sum_dims)/num_valid - m ** 2
91 | self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
92 | self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
93 | else:
94 | m, v = self.running_mean.value, self.running_var.value
95 | return ((x_or_zero-m)*self.gamma.value*F.rsqrt(v + self.eps) + self.beta.value,mask)
96 |
97 | class TensorMaskBN(nn.BatchNorm0D): #TODO find discrepancies with pytorch version
98 | """ Equivariant Batchnorm for tensor representations.
99 | Applies BN on Scalar channels and Mean only BN on others """
100 | def __init__(self,rep):
101 | super().__init__(rep.size(),momentum=0.9)
102 | self.rep=rep
103 |     def __call__(self,x,mask,training):
104 |         sum_dims = list(range(len(x.shape[:-1])))
105 |         x_or_zero = jnp.where(mask[...,None],x,0*x)
106 | smask = jax.device_put(scalar_mask(self.rep))
107 | if training:
108 | num_valid = mask.sum(sum_dims)
109 | m = x_or_zero.sum(sum_dims)/num_valid
110 | x2 = (x_or_zero ** 2).sum(sum_dims)/num_valid
111 | v = x2 - m ** 2
112 | v = jnp.where(smask,v,ragged_gather_scatter(x2,self.rep))
113 | self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
114 | self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
115 | else:
116 | m, v = self.running_mean.value, self.running_var.value
117 | y = jnp.where(smask,self.gamma.value * (x_or_zero - m) * F.rsqrt(v + self.eps) + \
118 | self.beta.value,x_or_zero*F.rsqrt(v+self.eps))
119 | return y,mask # switch to or (x-m)
120 |
121 | # @partial(jit,static_argnums=(1,))
122 | # def ragged_gather_scatter(x,x_rep):
123 | # y = []
124 | # i=0
125 | # for rep in x_rep.reps: # sum -> mean
126 | # y.append(x[i:i+rep.size()].mean(keepdims=True).repeat(rep.size(),axis=-1))
127 | # i+=rep.size()
128 | # return jnp.concatenate(y,-1)
129 |
130 | @partial(jit,static_argnums=(1,))
131 | def ragged_gather_scatter(x,x_rep):
132 | perm = x_rep.argsort()
133 | invperm = np.argsort(perm)
134 | x_sorted = x[perm]
135 | i=0
136 | y=[]
137 | for rep, multiplicity in x_rep.multiplicities().items():
138 | i_end = i+multiplicity*rep.size()
139 | y.append(x_sorted[i:i_end].reshape(multiplicity,rep.size()).mean(-1,keepdims=True).repeat(rep.size(),axis=-1).reshape(-1))
140 | i = i_end
141 | return jnp.concatenate(y)[invperm]
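142 | 
143 | # Illustrative sketch (assumes emlp's T and SO are in scope): for a sum rep such
144 | # as (2*T(0)+T(1))(SO(3)) with block sizes 1,1,3, ragged_gather_scatter
145 | # mean-pools each irrep block and broadcasts the mean back over the block:
146 | #   x = jnp.array([1.,2.,3.,4.,5.])
147 | #   ragged_gather_scatter(x,(2*T(0)+T(1))(SO(3)))  # -> [1., 2., 4., 4., 4.]
148 | # (up to the canonical ordering handled by argsort/invperm)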
--------------------------------------------------------------------------------
/experiments/depreciated/train_cube_simple.py:
--------------------------------------------------------------------------------
1 |
2 | from emlp.models.datasets import InvertedCube
3 | from emlp.solver.groups import Cube
4 | from emlp.models.mlp import MLP,EMLP,Standardize
5 | from emlp.models.model_trainer import ClassifierPlus
6 | from emlp.solver.representation import T
7 | from slax.model_trainers import Classifier
8 | from torch.utils.data import DataLoader
9 | from oil.utils.utils import cosLr, islice, export,FixedNumpySeed,FixedPytorchSeed
10 | from slax.utils import LoaderTo
11 | from oil.tuning.study import train_trial
12 | from oil.datasetup.datasets import split_dataset
13 | from oil.tuning.args import argupdated_config
14 | import objax
15 | import logging
16 | import emlp.models
17 |
18 | # intermediate_rep1 = (100*T(0)+5*T(1))#+T(2))
19 | # intermediate_rep2 = 100*T(0)+10*T(1)
20 | #middle_reps = [intermediate_rep1,intermediate_rep2,intermediate_rep1]
21 | rep = 138*T(0)+23*T(1)+T(2)
22 | def makeTrainer(*,network=EMLP,num_epochs=500,seed=2020,aug=False,
23 | bs=50,lr=1e-3,device='cuda',
24 | net_config={'num_layers':3,'ch':rep,'group':Cube()},log_level='info',
25 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':50}},save=False):
26 | levels = {'critical': logging.CRITICAL,'error': logging.ERROR,
27 | 'warn': logging.WARNING,'warning': logging.WARNING,
28 | 'info': logging.INFO,'debug': logging.DEBUG}
29 | logging.getLogger().setLevel(levels[log_level])
30 | # Prep the datasets splits, model, and dataloaders
31 | with FixedNumpySeed(seed),FixedPytorchSeed(seed):
32 | datasets = {'train':InvertedCube(train=True),'test':InvertedCube(train=False)}
33 | model = Standardize(network(datasets['train'].rep_in,datasets['train'].rep_out,**net_config),datasets['train'].stats)
34 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'),
35 | num_workers=0,pin_memory=False)) for k,v in datasets.items()}
36 | dataloaders['Train'] = dataloaders['train']
37 | opt_constr = objax.optimizer.Adam
38 | lr_sched = lambda e: lr*cosLr(num_epochs)(e)
39 | return ClassifierPlus(model,dataloaders,opt_constr,lr_sched,**trainer_config)
40 |
41 | if __name__ == "__main__":
42 | Trial = train_trial(makeTrainer)
43 | Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.solver.groups,emlp.models.datasets,emlp.models.mlp)))
--------------------------------------------------------------------------------
/experiments/depreciated/train_rubiks.py:
--------------------------------------------------------------------------------
1 |
2 | from emlp.models.datasets import BrokenRubiksCube,BrokenRubiksCube2x2
3 | from emlp.solver.groups import RubiksCube,RubiksCube2x2
4 | from emlp.models.mlp import MLP,EMLP,Standardize
5 | from emlp.models.model_trainer import ClassifierPlus
6 | from emlp.solver.representation import T
7 | from slax.model_trainers import Classifier
8 | from torch.utils.data import DataLoader
9 | from oil.utils.utils import cosLr, islice, export,FixedNumpySeed,FixedPytorchSeed
10 | from slax.utils import LoaderTo
11 | from oil.tuning.study import train_trial
12 | from oil.datasetup.datasets import split_dataset
13 | from oil.tuning.args import argupdated_config
14 | import objax
15 | import logging
16 | import emlp.models
17 |
18 | # intermediate_rep1 = (100*T(0)+5*T(1))#+T(2))
19 | rep = 50*T(0)+5*T(1)+T(2)
20 | #middle_reps = [intermediate_rep1,intermediate_rep2,intermediate_rep1]
21 | def makeTrainer(*,network=EMLP,num_epochs=500,seed=2020,aug=False,
22 | bs=50,lr=1e-3,device='cuda',
23 | net_config={'num_layers':3,'ch':rep,'group':RubiksCube2x2()},log_level='info',
24 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':50}},save=False):
25 | levels = {'critical': logging.CRITICAL,'error': logging.ERROR,
26 | 'warn': logging.WARNING,'warning': logging.WARNING,
27 | 'info': logging.INFO,'debug': logging.DEBUG}
28 | logging.getLogger().setLevel(levels[log_level])
29 | # Prep the datasets splits, model, and dataloaders
30 | with FixedNumpySeed(seed),FixedPytorchSeed(seed):
31 | datasets = {'train':BrokenRubiksCube2x2(train=True),'test':BrokenRubiksCube2x2(train=False)}
32 | model = Standardize(network(datasets['train'].rep_in,datasets['train'].rep_out,**net_config),datasets['train'].stats)
33 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'),
34 | num_workers=0,pin_memory=False)) for k,v in datasets.items()}
35 | dataloaders['Train'] = dataloaders['train']
36 | opt_constr = objax.optimizer.Adam
37 | lr_sched = lambda e: lr*cosLr(num_epochs)(e)#*min(1,e/(num_epochs/10))
38 | return ClassifierPlus(model,dataloaders,opt_constr,lr_sched,**trainer_config)
39 |
40 | if __name__ == "__main__":
41 | Trial = train_trial(makeTrainer)
42 | Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.solver.groups,emlp.models.datasets,emlp.models.mlp)))
--------------------------------------------------------------------------------
/experiments/depreciated/train_tagging.py:
--------------------------------------------------------------------------------
1 | from emlp.models.mlp import MLP,EMLP#,LinearBNSwish
2 | from emlp.models.datasets import Fr,ParticleInteraction
3 | import jax.numpy as jnp
4 | import jax
5 | from emlp.solver.representation import T,Scalar,Matrix,Vector,Quad,repsize
6 | from emlp.solver.groups import SO,O,Trivial,Lorentz,O13,SO13,SO13p
7 | from emlp.models.mlp import EMLP,LieLinear,Standardize,EMLP2
8 | from emlp.models.model_trainer import RegressorPlus
9 | import itertools
10 | import numpy as np
11 | import torch
12 | from emlp.models.datasets import Inertia,Fr,ParticleInteraction
13 | from emlp.models.particle_dataset import TopTagging,collate_fn
14 | import objax
15 | import torch
16 | from torch.utils.data import DataLoader
17 | from oil.utils.utils import cosLr, islice, export,FixedNumpySeed,FixedPytorchSeed
18 | from slax.utils import LoaderTo
19 | from oil.tuning.study import train_trial
20 | from oil.datasetup.datasets import split_dataset
21 | from oil.tuning.args import argupdated_config
22 | from slax.model_trainers import Classifier,Trainer
23 | from functools import partial
24 | import torch.nn as nn
25 | import logging
26 | import emlp.models
27 | from emlp.models.pointconv_base import ResNet
28 | from emlp.models.model_trainer import ClassifierPlus
29 |
30 | def makeTrainer(*,network=ResNet,num_epochs=5,seed=2020,aug=False,
31 | bs=30,lr=1e-3,device='cuda',split={'train':-1,'val':10000},
32 | net_config={'k':512,'num_layers':4},log_level='info',
33 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':.2}},save=False):
34 | # Prep the datasets splits, model, and dataloaders
35 | datasets = {split:TopTagging(split=split) for split in ['train','val']}
36 | model = network(4,2,**net_config)
37 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=bs,shuffle=(k=='train'),
38 | num_workers=0,pin_memory=False,collate_fn=collate_fn,drop_last=True)) for k,v in datasets.items()}
39 | dataloaders['Train'] = islice(dataloaders['train'],0,None,10) # for logging, subsample the train set by 10x
40 | #equivariance_test(model,dataloaders['train'],net_config['group'])
41 | opt_constr = objax.optimizer.Adam
42 | lr_sched = lambda e: lr*cosLr(num_epochs)(e)
43 | return ClassifierPlus(model,dataloaders,opt_constr,lr_sched,**trainer_config)
44 |
45 | if __name__ == "__main__":
46 | Trial = train_trial(makeTrainer)
47 | Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.solver.groups,emlp.models.datasets,emlp.models.mlp)))
--------------------------------------------------------------------------------
/experiments/hnn.py:
--------------------------------------------------------------------------------
1 | from emlp.nn import MLP,EMLP,MLPH,EMLPH
2 | from emlp.groups import SO2eR3,O2eR3,DkeR3,Trivial
3 | from emlp.reps import Scalar
4 | from trainer.hamiltonian_dynamics import IntegratedDynamicsTrainer,DoubleSpringPendulum,hnn_trial
5 | from torch.utils.data import DataLoader
6 | from oil.utils.utils import cosLr,FixedNumpySeed,FixedPytorchSeed
7 | from trainer.utils import LoaderTo
8 | from oil.datasetup.datasets import split_dataset
9 | from oil.tuning.args import argupdated_config
10 | import torch.nn as nn
11 | import logging
12 | import emlp
13 | import emlp.reps
14 | import objax
15 |
16 | levels = {'critical': logging.CRITICAL,'error': logging.ERROR,
17 | 'warn': logging.WARNING,'warning': logging.WARNING,
18 | 'info': logging.INFO,'debug': logging.DEBUG}
19 |
20 | def makeTrainer(*,dataset=DoubleSpringPendulum,network=EMLPH,num_epochs=2000,ndata=5000,seed=2021,aug=False,
21 | bs=500,lr=3e-3,device='cuda',split={'train':500,'val':.1,'test':.1},
22 | net_config={'num_layers':3,'ch':128,'group':O2eR3()},log_level='info',
23 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':.75},},#'early_stop_metric':'val_MSE'},
24 | save=False,):
25 |
26 | logging.getLogger().setLevel(levels[log_level])
27 | # Prep the datasets splits, model, and dataloaders
28 | with FixedNumpySeed(seed),FixedPytorchSeed(seed):
29 | base_ds = dataset(n_systems=ndata,chunk_len=5)
30 | datasets = split_dataset(base_ds,splits=split)
31 | if net_config['group'] is None: net_config['group']=base_ds.symmetry
32 | model = network(base_ds.rep_in,Scalar,**net_config)
33 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'),
34 | num_workers=0,pin_memory=False)) for k,v in datasets.items()}
35 | dataloaders['Train'] = dataloaders['train']
36 | #equivariance_test(model,dataloaders['train'],net_config['group'])
37 | opt_constr = objax.optimizer.Adam
38 | lr_sched = lambda e: lr # constant learning rate (cosine schedule and warmup disabled)
39 | return IntegratedDynamicsTrainer(model,dataloaders,opt_constr,lr_sched,**trainer_config)
40 |
41 | if __name__ == "__main__":
42 | Trial = hnn_trial(makeTrainer)
43 | cfg,outcome = Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.groups,emlp.nn)))
44 | print(outcome)
45 |
46 |
--------------------------------------------------------------------------------
/experiments/hnn_expt.py:
--------------------------------------------------------------------------------
1 | from oil.tuning.study import Study
2 | from emlp.nn import MLP,EMLP,MLPH,EMLPH
3 | from emlp.groups import SO2eR3,O2eR3,DkeR3,Trivial
4 |
5 | import copy
6 | from trainer.hamiltonian_dynamics import hnn_trial
7 | from hnn import makeTrainer
8 |
9 | if __name__=="__main__":
10 | Trial = hnn_trial(makeTrainer)
11 | config_spec = copy.deepcopy(makeTrainer.__kwdefaults__)
12 | name = "hnn_expt"#config_spec.pop('study_name')
13 | #name = f"{name}_{config_spec['dataset']}"
14 | thestudy = Study(Trial,{},study_name=name,base_log_dir=config_spec['trainer_config'].get('log_dir',None))
15 | config_spec['network'] = EMLPH
16 | config_spec['net_config']['group'] = [O2eR3(),SO2eR3(),DkeR3(6),DkeR3(2)]
17 | thestudy.run(num_trials=-5,new_config_spec=config_spec,ordered=True)
18 | config_spec['network'] = MLPH
19 | config_spec['net_config']['group'] = None
20 | thestudy.run(num_trials=-3,new_config_spec=config_spec,ordered=True)
21 | print(thestudy.results_df())
22 |
--------------------------------------------------------------------------------
/experiments/neuralode.py:
--------------------------------------------------------------------------------
1 | from emlp.nn import MLP,EMLP,MLPH,EMLPH,EMLPode,MLPode#,LinearBNSwish
2 | from emlp.groups import SO2eR3,O2eR3,DkeR3,Trivial
3 | from trainer.hamiltonian_dynamics import IntegratedODETrainer,DoubleSpringPendulum,ode_trial
4 | from torch.utils.data import DataLoader
5 | from oil.utils.utils import cosLr, islice, FixedNumpySeed,FixedPytorchSeed
6 | from trainer.utils import LoaderTo
7 | from oil.tuning.study import train_trial
8 | from oil.datasetup.datasets import split_dataset
9 | from oil.tuning.args import argupdated_config
10 | import logging
11 | import emlp.nn
12 | import emlp.groups
13 | import objax
14 |
15 | levels = {'critical': logging.CRITICAL,'error': logging.ERROR,
16 | 'warn': logging.WARNING,'warning': logging.WARNING,
17 | 'info': logging.INFO,'debug': logging.DEBUG}
18 |
19 | def makeTrainer(*,dataset=DoubleSpringPendulum,network=EMLPode,num_epochs=2000,ndata=5000,seed=2021,aug=False,
20 | bs=500,lr=3e-3,device='cuda',split={'train':500,'val':.1,'test':.1},
21 | net_config={'num_layers':3,'ch':128,'group':O2eR3()},log_level='warn',
22 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':.75},},#'early_stop_metric':'val_MSE'},
23 | save=False,):
24 |
25 | logging.getLogger().setLevel(levels[log_level])
26 | # Prep the datasets splits, model, and dataloaders
27 | with FixedNumpySeed(seed),FixedPytorchSeed(seed):
28 | base_ds = dataset(n_systems=ndata,chunk_len=5)
29 | datasets = split_dataset(base_ds,splits=split)
30 | if net_config['group'] is None: net_config['group']=base_ds.symmetry
31 | model = network(base_ds.rep_in,base_ds.rep_in,**net_config)
32 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'),
33 | num_workers=0,pin_memory=False)) for k,v in datasets.items()}
34 | dataloaders['Train'] = dataloaders['train']
35 | #equivariance_test(model,dataloaders['train'],net_config['group'])
36 | opt_constr = objax.optimizer.Adam
37 | lr_sched = lambda e: lr # constant learning rate (cosine schedule and warmup disabled)
38 | return IntegratedODETrainer(model,dataloaders,opt_constr,lr_sched,**trainer_config)
39 |
40 | if __name__ == "__main__":
41 | Trial = ode_trial(makeTrainer)
42 | cfg,outcome = Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.groups,emlp.nn)))
43 | print(outcome)
44 |
--------------------------------------------------------------------------------
/experiments/neuralode_expt.py:
--------------------------------------------------------------------------------
1 | from emlp.nn import MLP,EMLP,MLPH,EMLPH,EMLPode,MLPode#,LinearBNSwish
2 | from emlp.groups import SO2eR3,O2eR3,DkeR3,Trivial
3 | from oil.tuning.study import Study
4 |
5 | import copy
6 | from trainer.hamiltonian_dynamics import ode_trial
7 | from neuralode import makeTrainer
8 |
9 | if __name__ == "__main__":
10 | Trial = ode_trial(makeTrainer)
11 | config_spec = copy.deepcopy(makeTrainer.__kwdefaults__)
12 | name = "ode_expt"#config_spec.pop('study_name')
13 |
14 | #name = f"{name}_{config_spec['dataset']}"
15 | thestudy = Study(Trial,{},study_name=name,base_log_dir=config_spec['trainer_config'].get('log_dir',None))
16 | config_spec['network'] = EMLPode
17 | config_spec['net_config']['group'] = [O2eR3(),SO2eR3(),DkeR3(6),DkeR3(2)]
18 | thestudy.run(num_trials=-5,new_config_spec=config_spec,ordered=True)
19 | config_spec['network'] = MLPode
20 | config_spec['net_config']['group'] = None
21 | thestudy.run(num_trials=-3,new_config_spec=config_spec,ordered=True)
22 | print(thestudy.results_df())
--------------------------------------------------------------------------------
/experiments/train_regression.py:
--------------------------------------------------------------------------------
1 | from emlp.nn import MLP, EMLP, Standardize
2 | from trainer.model_trainer import RegressorPlus
3 | from torch.utils.data import DataLoader
4 | from oil.utils.utils import cosLr, FixedNumpySeed, FixedPytorchSeed
5 | from trainer.utils import LoaderTo
6 | from oil.tuning.study import train_trial
7 | from oil.datasetup.datasets import split_dataset
8 | from oil.tuning.args import argupdated_config
9 | import logging
10 | import emlp.nn
11 | import emlp.reps
12 | import emlp.groups
13 | import objax
14 | import emlp.datasets
15 | from emlp.datasets import Inertia,O5Synthetic,ParticleInteraction
16 |
17 | log_levels = {'critical': logging.CRITICAL,'error': logging.ERROR,
18 | 'warn': logging.WARNING,'warning': logging.WARNING,
19 | 'info': logging.INFO,'debug': logging.DEBUG}
20 |
21 | def makeTrainer(*,dataset=Inertia,network=EMLP,num_epochs=300,ndata=1000+2000,seed=2021,aug=False,
22 | bs=500,lr=3e-3,device='cuda',split={'train':-1,'val':1000,'test':1000},
23 | net_config={'num_layers':3,'ch':384,'group':None},log_level='info',
24 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':.75},
25 | 'early_stop_metric':'val_MSE'},save=False,):
26 |
27 | logging.getLogger().setLevel(log_levels[log_level])
28 | # Prep the datasets splits, model, and dataloaders
29 | with FixedNumpySeed(seed),FixedPytorchSeed(seed):
30 | base_dataset = dataset(ndata)
31 | datasets = split_dataset(base_dataset,splits=split)
32 | if net_config['group'] is None: net_config['group']=base_dataset.symmetry
33 | model = network(base_dataset.rep_in,base_dataset.rep_out,**net_config)
34 | if aug: model = base_dataset.default_aug(model)
35 | model = Standardize(model,datasets['train'].stats)
36 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'),
37 | num_workers=0,pin_memory=False)) for k,v in datasets.items()}
38 | dataloaders['Train'] = dataloaders['train']
39 | opt_constr = objax.optimizer.Adam
40 | lr_sched = lambda e: lr # constant learning rate (warmup disabled)
41 | return RegressorPlus(model,dataloaders,opt_constr,lr_sched,**trainer_config)
42 |
43 | if __name__ == "__main__":
44 | cfg = argupdated_config(makeTrainer.__kwdefaults__,
45 | namespace=(emlp.groups,emlp.datasets,emlp.nn))
46 | trainer = makeTrainer(**cfg)
47 | trainer.train(cfg['num_epochs'])
48 |
--------------------------------------------------------------------------------
/experiments/trainer/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from oil.utils.utils import export
4 | from .trainer import Trainer
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 |
9 | def cross_entropy(logprobs, targets):
10 | ll = jnp.take_along_axis(logprobs, jnp.expand_dims(targets, axis=1), axis=1)
11 | ce = -jnp.mean(ll)
12 | return ce
13 |
14 | @export
15 | class Classifier(Trainer):
16 | """ Trainer subclass. Implements loss (crossentropy), batchAccuracy
17 | and getAccuracy (full dataset) """
18 |
19 | def loss(self,minibatch):
20 | """ Standard cross-entropy loss """ #TODO: support class weights
21 | x,y = minibatch
22 | logits = self.model(x,training=True)
23 | logp = jax.nn.log_softmax(logits)
24 | return cross_entropy(logp,y)
25 |
26 | def metrics(self,loader):
27 | acc = lambda mb: np.asarray(jax.device_get(jnp.mean(jnp.argmax(self.model.predict(mb[0]),axis=-1)==mb[1])))
28 | return {'Acc':self.evalAverageMetrics(loader,acc)}
29 |
30 | @export
31 | class Regressor(Trainer):
32 | """ Trainer subclass. Implements loss (crossentropy), batchAccuracy
33 | and getAccuracy (full dataset) """
34 |
35 | def loss(self,minibatch):
36 | """ Standard cross-entropy loss """
37 | x,y = minibatch
38 | mse = jnp.mean((self.model(x,training=True)-y)**2)
39 | return mse
40 |
41 | def metrics(self,loader):
42 | mse = lambda mb: np.asarray(jax.device_get(jnp.mean((self.model.predict(mb[0])-mb[1])**2)))
43 | return {'MSE':self.evalAverageMetrics(loader,mse)}
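44 | 
45 | # Illustrative sketch: cross_entropy expects log-probabilities of shape
46 | # (batch, classes) and integer targets of shape (batch,). For example,
47 | #   logp = jax.nn.log_softmax(jnp.array([[2.0, 0.0], [0.0, 2.0]]))
48 | #   cross_entropy(logp, jnp.array([0, 1]))  # ~0.127, the mean negative log-likelihood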
--------------------------------------------------------------------------------
/experiments/trainer/model_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from oil.utils.utils import export
4 | import jax
5 | from jax import vmap
6 | import jax.numpy as jnp
7 | import numpy as np
8 | import objax
9 | from .classifier import Regressor,Classifier
10 | from functools import partial
11 | from itertools import islice
12 |
13 | def rel_err(a,b):
14 | return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean()))#
15 |
16 | def scale_adjusted_rel_err(a,b,g):
17 | return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean())+jnp.abs(g-jnp.eye(g.shape[-1])).mean())
18 |
19 | def equivariance_err(model,mb,group=None):
20 | x,y = mb
21 | group = model.model.G if group is None else group
22 | gs = group.samples(x.shape[0])
23 | rho_gin = vmap(model.model.rep_in.rho_dense)(gs)
24 | rho_gout = vmap(model.model.rep_out.rho_dense)(gs)
25 | y1 = model.predict((rho_gin@x[...,None])[...,0])
26 | y2 = (rho_gout@model.predict(x)[...,None])[...,0]
27 | return np.asarray(scale_adjusted_rel_err(y1,y2,gs))
28 |
29 | @export
30 | class RegressorPlus(Regressor):
31 | """ Trainer subclass. Implements loss (crossentropy), batchAccuracy
32 | and getAccuracy (full dataset) """
33 | def __init__(self,model,*args,**kwargs):
34 | super().__init__(model,*args,**kwargs)
35 | fastloss = objax.Jit(self.loss,model.vars())
36 | self.gradvals = objax.Jit(objax.GradValues(fastloss,model.vars()),model.vars())
37 | self.model.predict = objax.Jit(objax.ForceArgs(model.__call__,training=False),model.vars())
38 | #self.model.predict = lambda x: self.model(x,training=False)
39 | def loss(self,minibatch):
40 | """ Standard cross-entropy loss """
41 | x,y = minibatch
42 | mse = jnp.mean((self.model(x,training=True)-y)**2)
43 | return mse
44 |
45 | def metrics(self,loader):
46 | mse = lambda mb: np.asarray(jax.device_get(jnp.mean((self.model.predict(mb[0])-mb[1])**2)))
47 | return {'MSE':self.evalAverageMetrics(loader,mse)}
48 | def logStuff(self, step, minibatch=None):
49 | metrics = {}
50 | metrics['test_equivar_err'] = self.evalAverageMetrics(islice(self.dataloaders['test'],0,None,5),
51 | partial(equivariance_err,self.model)) # subsample by 5x so it doesn't take too long
52 | self.logger.add_scalars('metrics', metrics, step)
53 | super().logStuff(step,minibatch)
54 |
55 | @export
56 | class ClassifierPlus(Classifier):
57 | """ Trainer subclass. Implements loss (crossentropy), batchAccuracy
58 | and getAccuracy (full dataset) """
59 | def __init__(self,model,*args,**kwargs):
60 | super().__init__(model,*args,**kwargs)
61 |
62 | fastloss = objax.Jit(self.loss,model.vars())
63 | self.gradvals = objax.Jit(objax.GradValues(fastloss,model.vars()),model.vars())
64 | self.model.predict = objax.Jit(objax.ForceArgs(model.__call__,training=False),model.vars())
65 | #self.model.predict = lambda x: self.model(x,training=False)
66 |
67 |
68 | def logStuff(self, step, minibatch=None):
69 | metrics = {}
70 | metrics['test_equivar_err'] = self.evalAverageMetrics(islice(self.dataloaders['test'],0,None,5),
71 | partial(equivariance_err,self.model)) # subsample by 5x so it doesn't take too long
72 | self.logger.add_scalars('metrics', metrics, step)
73 | super().logStuff(step,minibatch)
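74 | 
75 | # Illustrative sketch (the `trainer` name is hypothetical): equivariance_err
76 | # compares model(rho_in(g) x) against rho_out(g) model(x) on sampled group
77 | # elements, so for an exactly equivariant model it is ~0 up to numerics:
78 | #   mb = next(iter(trainer.dataloaders['test']))
79 | #   equivariance_err(trainer.model, mb)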
--------------------------------------------------------------------------------
/experiments/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import dill
2 | from oil.logging.lazyLogger import LazyLogger
3 | from oil.utils.utils import Eval, Named
4 | from oil.utils.mytqdm import tqdm
5 | from oil.tuning.study import guess_metric_sign
6 | import copy, os, random
7 | import glob
8 | import numpy as np
9 | from natsort import natsorted
10 | import jax
11 | import logging
12 | from functools import partial
13 | import objax
14 |
15 | class Trainer(object,metaclass=Named):
16 | """ Base trainer
17 | """
18 | def __init__(self, model, dataloaders, optim = objax.optimizer.Adam,lr_sched =lambda e:1,
19 | log_dir=None, log_suffix='',log_args={},early_stop_metric=None):
20 | # Setup model, optimizer, and dataloaders
21 | self.model = model#
22 | #self.model= objax.Jit(objax.ForceArgs(model,training=True)) #TODO: figure out static nums
23 | #self.model.predict = objax.Jit(objax.ForceArgs(model.__call__,training=False),model.vars())
24 | #self.model.predict = objax.ForceArgs(model.__call__,training=False)
25 | #self._model = model
26 | #self.model = objax.ForceArgs(model,training=True)
27 | #self.model.predict = objax.ForceArgs(model.__call__,training=False)
28 | #self.model = objax.Jit(lambda x, training: model(x,training=training),model.vars(),static_argnums=(1,))
29 | #self.model = objax.Jit(model,static_argnums=(1,))
30 |
31 | self.optimizer = optim(model.vars())
32 | self.lr_sched= lr_sched
33 | self.dataloaders = dataloaders # A dictionary of dataloaders
34 | self.epoch = 0
35 |
36 | self.logger = LazyLogger(log_dir, log_suffix, **log_args)
37 | #self.logger.add_text('ModelSpec','model: {}'.format(model))
38 | self.hypers = {}
39 | self.ckpt = None # TODO: fix model saving (copy.deepcopy(self.state_dict()) disabled for now)
40 | self.early_stop_metric = early_stop_metric
41 | #fastloss = objax.Jit(self.loss,model.vars())
42 |
43 | self.gradvals = objax.GradValues(self.loss,self.model.vars())
44 | #self.gradvals = objax.Jit(objax.GradValues(fastloss,model.vars()),model.vars())
45 | def metrics(self,loader):
46 | return {}
47 |
48 | def loss(self,minibatch):
49 | raise NotImplementedError
50 |
51 | def train_to(self, final_epoch=100):
52 | assert final_epoch>=self.epoch, "final_epoch is less than the current epoch"
53 | self.train(final_epoch-self.epoch)
54 |
55 | def train(self, num_epochs=100):
56 | """ The main training loop"""
57 | start_epoch = self.epoch
58 | steps_per_epoch = len(self.dataloaders['train']); step=0
59 | for self.epoch in tqdm(range(start_epoch, start_epoch + num_epochs),desc='train'):
60 | for i, minibatch in enumerate(self.dataloaders['train']):
61 | step = i + self.epoch*steps_per_epoch
62 | self.step(self.epoch+i/steps_per_epoch,minibatch)
63 | with self.logger as do_log:
64 | if do_log: self.logStuff(step, minibatch)
65 | self.epoch+=1
66 | self.logStuff(step)
67 |
68 | def step(self, epoch, minibatch):
69 | grad,loss = self.gradvals(minibatch)
70 | self.optimizer(self.lr_sched(epoch),grad)
71 | return loss
72 |
73 | def logStuff(self, step, minibatch=None):
74 | metrics = {}
75 | # if minibatch is not None and hasattr(self,'loss'):
76 | # try: metrics['Minibatch_Loss'] = self.loss(self.model_params,minibatch)
77 | # except (NotImplementedError, TypeError): pass
78 | for loader_name,dloader in self.dataloaders.items(): # Ignore metrics on train
79 | if loader_name=='train' or len(dloader)==0 or loader_name[0]=='_': continue
80 | for metric_name, metric_value in self.metrics(dloader).items():
81 | metrics[loader_name+'_'+metric_name] = metric_value
82 | self.logger.add_scalars('metrics', metrics, step)
83 | # for name,m in self.model.named_modules():
84 | # if hasattr(m, 'log_data'):
85 | # m.log_data(self.logger,step,name)
86 | self.logger.report()
87 | # update the best checkpoint
88 | if self.early_stop_metric is not None:
89 | maximize = guess_metric_sign(self.early_stop_metric)
90 | sign = 2*maximize-1
91 | best = (sign*self.logger.scalar_frame[self.early_stop_metric].values).max()
92 | current = sign*self.logger.scalar_frame[self.early_stop_metric].iloc[-1]
93 | if current >= best: self.ckpt = copy.deepcopy(self.state_dict())
94 | else: self.ckpt = copy.deepcopy(self.state_dict())
95 |
96 | def evalAverageMetrics(self,loader,metrics):
97 | num_total, loss_totals = 0, 0
98 | for minibatch in loader:
99 | try: mb_size = loader.batch_size
100 | except AttributeError: mb_size=1
101 | loss_totals += mb_size*metrics(minibatch)
102 | num_total += mb_size
103 | if num_total==0: raise KeyError("dataloader is empty")
104 | return loss_totals/num_total
105 |
106 | def state_dict(self):
107 | #TODO: handle saving and loading state
108 | state = {
109 | 'outcome':self.logger.scalar_frame[-1:],
110 | 'epoch':self.epoch,
111 | # 'model_state':self.model.state_dict(),
112 | # 'optim_state':self.optimizer.state_dict(),
113 | # 'logger_state':self.logger.state_dict(),
114 | }
115 | return state
116 |
117 | # def load_state_dict(self,state):
118 | # self.epoch = state['epoch']
119 | # self.model.load_state_dict(state['model_state'])
120 | # self.optimizer.load_state_dict(state['optim_state'])
121 | # self.logger.load_state_dict(state['logger_state'])
122 |
123 | # def load_checkpoint(self,path=None):
124 | # """ Loads the checkpoint from path, if None gets the highest epoch checkpoint"""
125 | # if not path:
126 | # chkpts = glob.glob(os.path.join(self.logger.log_dirr,'checkpoints/c*.state'))
127 | # path = natsorted(chkpts)[-1] # get most recent checkpoint
128 | # print(f"loading checkpoint {path}")
129 | # with open(path,'rb') as f:
130 | # self.load_state_dict(dill.load(f))
131 |
132 | # def save_checkpoint(self):
133 | # return self.logger.save_object(self.ckpt,suffix=f'checkpoints/c{self.epoch}.state')
134 |
135 |
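136 | # Illustrative sketch (assumes jax.numpy imported as jnp): a concrete subclass
137 | # only needs to implement loss(minibatch); metrics is optional. classifier.py
138 | # follows exactly this pattern:
139 | #
140 | # class MSETrainer(Trainer):
141 | #     def loss(self, minibatch):
142 | #         x, y = minibatch
143 | #         return jnp.mean((self.model(x, training=True) - y) ** 2)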
--------------------------------------------------------------------------------
/experiments/trainer/utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import functools
3 | from oil.utils.utils import imap
4 | import numpy as np
5 | def minibatch_to(mb):
6 | try:
7 | if isinstance(mb,np.ndarray):
8 | return jax.device_put(mb)
9 | return jax.device_put(mb.numpy())
10 | except AttributeError:
11 | if isinstance(mb,dict):
12 | return type(mb)(((k,minibatch_to(v)) for k,v in mb.items()))
13 | else:
14 | return type(mb)(minibatch_to(elem) for elem in mb)
15 |
16 | def LoaderTo(loader):
17 | return imap(functools.partial(minibatch_to),loader)
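18 | 
19 | # Illustrative sketch (the TensorDataset below is hypothetical): LoaderTo
20 | # re-yields a torch DataLoader with each minibatch moved to the jax device:
21 | #   from torch.utils.data import DataLoader, TensorDataset
22 | #   import torch
23 | #   ds = TensorDataset(torch.randn(8, 3), torch.randn(8, 1))
24 | #   for x, y in LoaderTo(DataLoader(ds, batch_size=4)):
25 | #       ...  # x and y are jax arrays (via jax.device_put(mb.numpy()))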
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup,find_packages
2 | import sys, os, re
3 |
4 | README_FILE = 'README.md'
5 |
6 | def get_property(prop, project):
7 | result = re.search(r'{}\s*=\s*[\'"]([^\'"]*)[\'"]'.format(prop), open(project + '/__init__.py').read())
8 | return result.group(1)
9 |
10 | project_name = "emlp"
11 | setup(name=project_name,
12 | description="A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups",
13 | version= get_property('__version__',project_name),
14 | author='Marc Finzi',
15 | author_email='maf820@nyu.edu',
16 | license='MIT',
17 | python_requires='>=3.6',
18 | install_requires=['h5py','objax','pytest','plum-dispatch',
19 | 'optax','tqdm>=4.38','matplotlib','scikit-learn'],
20 | extras_require = {
21 | 'EXPTS':['olive-oil-ml']
22 | },
23 | packages=find_packages(),
24 | long_description=open(README_FILE, encoding='UTF-8').read(),
25 | long_description_content_type='text/markdown',
26 | url='https://github.com/mfinzi/equivariant-MLP',
27 | classifiers=[
28 | 'Development Status :: 4 - Beta',
29 | 'Intended Audience :: Developers',
30 | 'Intended Audience :: Science/Research',
31 | 'Programming Language :: Python :: 3',
32 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
33 | ],
34 | keywords=[
35 | 'equivariance','MLP','symmetry','group','AI','neural network',
36 | 'representation','group theory','deep learning','machine learning',
37 | 'rotation','Lorentz invariance',
38 | ],
39 |
40 | )
41 |
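42 | # Usage note: `pip install -e .` installs from a checkout; `pip install -e .[EXPTS]`
43 | # additionally pulls in the experiment dependency olive-oil-ml.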
--------------------------------------------------------------------------------
/tests/equivariance_tests.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np#
3 | import copy
4 | from emlp.reps import *
5 | from emlp.groups import *
6 | from emlp.nn import uniform_rep
7 | import pytest#import unittest
8 | from jax import vmap
9 | import jax.numpy as jnp
10 | import logging
11 | import argparse
12 | import sys
13 | import copy
14 | import inspect
15 | #from functools import partialmethod,partial
16 |
17 | def rel_error(t1,t2):
18 | error = jnp.sqrt(jnp.mean(jnp.abs(t1-t2)**2))
19 | scale = jnp.sqrt(jnp.mean(jnp.abs(t1)**2)) + jnp.sqrt(jnp.mean(jnp.abs(t2)**2))
20 | return error/jnp.maximum(scale,1e-7)
21 |
22 | def scale_adjusted_rel_error(t1,t2,g):
23 | error = jnp.sqrt(jnp.mean(jnp.abs(t1-t2)**2))
24 | tscale = jnp.sqrt(jnp.mean(jnp.abs(t1)**2)) + jnp.sqrt(jnp.mean(jnp.abs(t2)**2))
25 | gscale = jnp.sqrt(jnp.mean(jnp.abs(g-jnp.eye(g.shape[-1]))**2))
26 | scale = jnp.maximum(tscale,gscale)
27 | return error/jnp.maximum(scale,1e-7)
28 |
29 | def equivariance_error(W,repin,repout,G):
30 | """ Computes the equivariance relative error rel_err(Wρ₁(g),ρ₂(g)W)
31 | of the matrix W (dim(repout),dim(repin))
32 | according to the input and output representations and group G. """
33 | N=5
34 | x = np.random.rand(N,repin.size())
35 | gs = G.samples(N)
36 | ring = vmap(repin.rho_dense)(gs)
37 | routg = vmap(repout.rho_dense)(gs)
38 | equiv_err = scale_adjusted_rel_error(W@ring,routg@W,gs)
39 | return equiv_err
40 |
41 | def strip_parens(string):
42 | return string.replace('(','').replace(')','')
43 |
44 | def parametrize(cases,ids=None):
45 | """ Expands test cases with pytest.mark.parametrize but with argnames
46 | assumed and ids given by the ids=[str(case) for case in cases] """
47 | def decorator(test_fn):
48 | argnames = ','.join(inspect.getfullargspec(test_fn).args)
49 | theids = [strip_parens(str(case)) for case in cases] if ids is None else ids
50 | return pytest.mark.parametrize(argnames,cases,ids=theids)(test_fn)
51 | return decorator
52 | # def expand_cases(cls,argseq):
53 | # def class_decorator(testcase):
54 | # for args in argseq:
55 | # setattr(cls, f"{testcase.__name__}_{args}", partialmethod(testcase,*tuplify(args)))
56 | # #setattr(cls,f"{testcase.__name__}_{args}",types.MethodType(partial(testcase,*tuplify(args)),cls))
57 | # return testcase
58 | # return class_decorator
59 |
60 | # def tuplify(x):
61 | # if not isinstance(x, tuple): return (x,)
62 | # return x
63 |
64 | test_groups = [SO(n) for n in [2,3,4]]+[O(n) for n in [2,3,4]]+\
65 | [SU(n) for n in [2,3,4]] + [U(n) for n in [2,3,4]] + \
66 | [SL(n) for n in [2,3,4]] + [GL(n) for n in [2,3,4]] + \
67 | [C(k) for k in [2,3,4,8]]+[D(k) for k in [2,3,4,8]]+\
68 | [S(n) for n in [2,4,6]]+[Z(n) for n in [2,4,6]]+\
69 | [SO11p(),SO13p(),SO13(),O13()] +[Sp(n) for n in [1,3]]+\
70 | [RubiksCube(),Cube(),ZksZnxZn(2,2),ZksZnxZn(4,4)]
71 | # class TestRepresentationSubspace(unittest.TestCase): pass
72 | # expand_test_cases = partial(expand_cases,TestRepresentationSubspace)
73 |
74 |
75 | #@pytest.mark.parametrize("G",test_groups,ids=test_group_names)
76 | @parametrize(test_groups)
77 | def test_sum(G):
78 | N=5
79 | rep = T(0,2)+3*(T(0,0)+T(1,0))+T(0,0)+T(1,1)+2*T(1,0)+T(0,2)+T(0,1)+3*T(0,2)+T(2,0)
80 | rep = rep(G)
81 | if G.num_constraints()*rep.size()>1e11 or rep.size()**2>10**7: return
82 | P = rep.equivariant_projector()
83 | v = np.random.rand(rep.size())
84 | v = P@v
85 | gs = G.samples(N)
86 | gv = (vmap(rep.rho_dense)(gs)*v).sum(-1)
87 | err = vmap(scale_adjusted_rel_error)(gv,v+jnp.zeros_like(gv),gs).mean()
88 | assert err<1e-4,f"Symmetric vector fails err {err:.3e} with G={G}"
89 |
90 | #@pytest.mark.parametrize("G",test_groups,ids=test_group_names)
91 | @parametrize([group for group in test_groups if group.d<5])
92 | def test_prod(G):
93 | N=5
94 | rep = T(0,1)*T(0,0)*T(1,0)**2*T(1,0)*T(0,0)**3*T(0,1)
95 | rep = rep(G)
96 | if G.num_constraints()*rep.size()>1e11 or rep.size()**2>10**7: return
97 | # P = rep.equivariant_projector()
98 | # v = np.random.rand(rep.size())
99 | # v = P@v
100 | Q = rep.equivariant_basis()
101 | v = Q@np.random.rand(Q.shape[-1])
102 | gs = G.samples(N)
103 | gv = (vmap(rep.rho_dense)(gs)*v).sum(-1)
104 |
105 | #print(f"g {gs[0]} and rho_dense {rep.rho_dense(gs[0])} {rep.rho_dense(gs[0]).shape}")
106 | err = vmap(scale_adjusted_rel_error)(gv,v+jnp.zeros_like(gv),gs).mean()
107 | assert err<1e-4,f"Symmetric vector fails err {err:.3e} with G={G}"
108 |
109 | #@pytest.mark.parametrize("G",test_groups,ids=test_group_names)
110 | @parametrize(test_groups)
111 | def test_high_rank_representations(G):
112 | N=5
113 | r = 10
114 | for p in range(r+1):
115 | for q in range(r-p+1):
116 | if G.num_constraints()*G.d**(3*(p+q))>1e11: continue
117 | if G.is_orthogonal and q>0: continue
118 | #try:
119 | #logging.info(f"{p},{q},{T(p,q)}")
120 | rep = T(p,q)(G)
121 | P = rep.equivariant_projector()
122 | v = np.random.rand(rep.size())
123 | v = P@v
124 | g = vmap(rep.rho_dense)(G.samples(N))
125 | gv = (g*v).sum(-1)
126 | #print(f"v{v.shape}, g{g.shape},gv{gv.shape},{G},T{p,q}")
127 | err = vmap(scale_adjusted_rel_error)(gv,v+jnp.zeros_like(gv),g).mean()
128 | if np.isnan(err): continue # deal with nans on cpu later
129 | assert err<1e-4,f"Symmetric vector fails err {err:.3e} with T{p,q} and G={G}"
130 | logging.info(f"Success with T{p,q} and G={G}")
131 | # except Exception as e:
132 | # print(f"Failed with G={G} and T({p,q})")
133 | # raise e
134 |
135 | @parametrize([
136 | (SO(3),T(1)+2*T(0),T(1)+T(2)+2*T(0)+T(1)),
137 | (SO(3),5*T(0)+5*T(1),3*T(0)+T(2)+2*T(1)),
138 | (SO(3),5*(T(0)+T(1)),2*(T(0)+T(1))+T(2)+T(1)),
139 | (SO(4), T(1)+2*T(2),(T(0)+T(3))*T(0)),
140 | (SO13p(),T(2)+4*T(1,0)+T(0,1),10*T(0)+3*T(1,0)+3*T(0,1)+T(0,2)+T(2,0)+T(1,1)),
141 | (Sp(2),(V+2*V**2)*(V.T+1*V).T + V.T,3*V**0 + V + V*V.T),
142 | (SU(3),T(2,0)+T(1,1)+T(0)+2*T(0,1),T(1,1)+V+V.T+T(0)+T(2,0)+T(0,2))
143 | ])
144 | def test_equivariant_matrix(G,repin,repout):
145 | N=5
146 | repin = repin(G)
147 | repout = repout(G)
148 | #repW = repout*repin.T
149 | repW = repin>>repout
150 | P = repW.equivariant_projector()
151 | W = np.random.rand(repout.size(),repin.size())
152 | W = (P@W.reshape(-1)).reshape(*W.shape)
153 |
154 | x = np.random.rand(N,repin.size())
155 | gs = G.samples(N)
156 | ring = vmap(repin.rho_dense)(gs)
157 | routg = vmap(repout.rho_dense)(gs)
158 | gx = (ring@x[...,None])[...,0]
159 | Wgx =gx@W.T
160 | #print(g.shape,(x@W.T).shape)
161 | gWx = (routg@(x@W.T)[...,None])[...,0]
162 | equiv_err = rel_error(Wgx,gWx)
163 | assert equiv_err<1e-4,f"Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G}"
164 |
165 | # print(f"R {repW.rho(gs[0])}")
166 | # print(f"R1 x R2 {jnp.kron(routg[0],jnp.linalg.inv(ring[0]).T)}")
167 | gvecW = (vmap(repW.rho_dense)(gs)*W.reshape(-1)).sum(-1)
168 | gWerr =vmap(scale_adjusted_rel_error)(gvecW,W.reshape(-1)+jnp.zeros_like(gvecW),gs).mean()
169 | assert gWerr<1e-4,f"Symmetric gvec(W)=vec(W) fails err {gWerr:.3e} with G={G}"
170 |
171 | @parametrize([(SO(3),5*T(0)+5*T(1),3*T(0)+T(2)+2*T(1)),
172 | (SO13p(),4*T(1,0),10*T(0)+3*T(1,0)+3*T(0,1)+T(0,2)+T(2,0)+T(1,1))])
173 | def test_bilinear_layer(G,repin,repout):
174 | N=5
175 | repin = repin(G)
176 | repout = repout(G)
177 | repW = repout*repin.T
178 | Wdim,P = bilinear_weights(repout,repin)
179 | x = np.random.rand(N,repin.size())
180 | gs = G.samples(N)
181 | ring = vmap(repin.rho_dense)(gs)
182 | routg = vmap(repout.rho_dense)(gs)
183 | gx = (ring@x[...,None])[...,0]
184 |
185 | W = np.random.rand(Wdim)
186 | W_x = P(W,x)
187 | Wxx = (W_x@x[...,None])[...,0]
188 | gWxx = (routg@Wxx[...,None])[...,0]
189 | Wgxgx =(P(W,gx)@gx[...,None])[...,0]
190 | equiv_err = rel_error(Wgxgx,gWxx)
191 | assert equiv_err<1e-4,f"Bilinear Equivariance fails err {equiv_err:.3e} with G={G}"
192 |
193 | @parametrize(test_groups)
194 | def test_large_representations(G):
195 | N=5
196 | ch = 256
197 | rep =repin=repout= uniform_rep(ch,G)
198 | repW = rep>>rep
199 | P = repW.equivariant_projector()
200 | W = np.random.rand(repout.size(),repin.size())
201 | W = (P@W.reshape(-1)).reshape(*W.shape)
202 |
203 | x = np.random.rand(N,repin.size())
204 | gs = G.samples(N)
205 | ring = vmap(repin.rho_dense)(gs)
206 | routg = vmap(repout.rho_dense)(gs)
207 | gx = (ring@x[...,None])[...,0]
208 | Wgx =gx@W.T
209 | #print(g.shape,(x@W.T).shape)
210 | gWx = (routg@(x@W.T)[...,None])[...,0]
211 | equiv_err = vmap(scale_adjusted_rel_error)(Wgx,gWx,gs).mean()
212 | assert equiv_err<1e-4,f"Large Rep Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G}"
213 | logging.info(f"Success with G={G}")
214 |
215 | # #print(dir(TestRepresentationSubspace))
216 | # if __name__ == '__main__':
217 | # parser = argparse.ArgumentParser()
218 | # parser.add_argument("--log", default="warning",help=("Logging Level Example --log debug', default='warning'"))
219 | # options,unknown_args = parser.parse_known_args()#["--log"])
220 | # levels = {'critical': logging.CRITICAL,'error': logging.ERROR,'warn': logging.WARNING,'warning': logging.WARNING,
221 | # 'info': logging.INFO,'debug': logging.DEBUG}
222 | # level = levels.get(options.log.lower())
223 | # logging.getLogger().setLevel(level)
224 | # unit_argv = [sys.argv[0]] + unknown_args
225 | # unittest.main(argv=unit_argv)
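226 | 
227 | # Illustrative sketch of equivariance_error (all names defined or imported above):
228 | #   G = SO(3); repin = repout = V(G)
229 | #   P = (repin >> repout).equivariant_projector()
230 | #   W = (P @ np.random.rand(repout.size()*repin.size())).reshape(repout.size(), repin.size())
231 | #   assert equivariance_error(W, repin, repout, G) < 1e-4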
--------------------------------------------------------------------------------
/tests/model_tests.py:
--------------------------------------------------------------------------------
1 | import jax
2 | from jax import vmap
3 | import numpy as np
4 | import pytest
5 | from torch.utils.data import DataLoader
6 | from emlp.nn import uniform_rep,MLP,EMLP,Standardize
7 | import emlp
8 | from equivariance_tests import parametrize,rel_error,scale_adjusted_rel_error
9 | from oil.utils.utils import FixedNumpySeed, FixedPytorchSeed
10 | from emlp.datasets import Inertia,O5Synthetic,ParticleInteraction,InvertedCube
11 | from jax import random
12 |
13 |
14 | # def rel_err(a,b):
15 | # return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean()))#
16 |
17 | # def scale_adjusted_rel_err(a,b,g):
18 | # return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean())+jnp.abs(g-jnp.eye(g.shape[-1])).mean())
19 |
20 | def equivariance_err(model,mb,repin,repout,group):
21 | x,y = mb
22 | gs = group.samples(x.shape[0])
23 | rho_gin = vmap(repin(group).rho_dense)(gs)
24 | rho_gout = vmap(repout(group).rho_dense)(gs)
25 | y1 = model((rho_gin@x[...,None])[...,0])
26 | y2 = (rho_gout@model(x)[...,None])[...,0]
27 | return np.asarray(scale_adjusted_rel_error(y1,y2,gs))
28 |
29 | def get_dsmb(dsclass):
30 | seed=2021
31 | bs=50
32 | with FixedNumpySeed(seed),FixedPytorchSeed(seed):
33 | ds = dsclass(100)
34 | dataloader = DataLoader(ds,batch_size=min(bs,len(ds)),num_workers=0,pin_memory=False)
35 | mb = next(iter(dataloader))
36 | mb = jax.device_put(mb[0].numpy()),jax.device_put(mb[1].numpy())
37 | return ds,mb
38 |
39 | @parametrize([Inertia,O5Synthetic,ParticleInteraction,InvertedCube])
40 | def test_init_forward_and_equivariance(dsclass):
41 | network=emlp.nn.objax.EMLP
42 | ds,mb = get_dsmb(dsclass)
43 | model = network(ds.rep_in,ds.rep_out,group=ds.symmetry)
44 | assert equivariance_err(model,mb,ds.rep_in,ds.rep_out,ds.symmetry) < 1e-4, "Objax EMLP failed equivariance test"
45 |
46 | @parametrize([Inertia])
47 | def test_haiku_emlp(dsclass):
48 | import haiku as hk
49 | from emlp.nn.haiku import EMLP as hkEMLP
50 | network = hkEMLP
51 | ds,mb = get_dsmb(dsclass)
52 | net = network(ds.rep_in,ds.rep_out,group=ds.symmetry)
53 | net = hk.without_apply_rng(hk.transform(net))
54 | params = net.init(random.PRNGKey(42),mb[0])
55 | model = lambda x: net.apply(params,x)
56 | assert equivariance_err(model,mb,ds.rep_in,ds.rep_out,ds.symmetry) < 1e-4, "Haiku EMLP failed equivariance test"
57 |
58 | @parametrize([Inertia])
59 | def test_flax_emlp(dsclass):
60 | from emlp.nn.flax import EMLP as flaxEMLP
61 | network = flaxEMLP
62 | ds,mb = get_dsmb(dsclass)
63 | net = network(ds.rep_in,ds.rep_out,group=ds.symmetry)
64 | params = net.init(random.PRNGKey(42),mb[0])
65 | model = lambda x: net.apply(params,x)
66 | assert equivariance_err(model,mb,ds.rep_in,ds.rep_out,ds.symmetry) < 1e-4, "flax EMLP failed equivariance test"
67 |
68 | @parametrize([Inertia])
69 | def test_pytorch_emlp(dsclass):
70 | import torch
71 | from emlp.nn.pytorch import EMLP as ptEMLP
72 | network=ptEMLP
73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
74 | ds,mb = get_dsmb(dsclass)
75 | net = network(ds.rep_in,ds.rep_out,group=ds.symmetry).to(device)
76 | model = lambda x: jax.device_put(net(torch.from_numpy(np.asarray(x)).to(device)).cpu().data.numpy())
77 | assert equivariance_err(model,mb,ds.rep_in,ds.rep_out,ds.symmetry) < 1e-4, "Pytorch EMLP failed equivariance test"
78 |
79 |
80 | from emlp.reps import vis, sparsify_basis, V,Rep
81 | from emlp.groups import S,SO
82 |
83 | def test_utilities():
84 | W = V(SO(3))
85 | vis(W,W)
86 | Q = (W**2>>W).equivariant_basis()
87 | SQ = sparsify_basis(Q)
88 | A = SQ@(1+np.arange(SQ.shape[-1]))
89 | nunique = len(np.unique(np.abs(A)))
90 | assert nunique in (SQ.shape[-1],SQ.shape[-1]+1), "Sparsify fails on SO(3) T3"
91 |
92 |
93 | def test_bespoke_representations():
94 | class ProductSubRep(Rep):
95 | def __init__(self,G,subgroup_id,size):
96 | """ Produces the representation of the subgroup of G = G1 x G2
97 | with the index subgroup_id in {0,1} specifying G1 or G2.
98 | Also requires specifying the size of the representation given by G1.d or G2.d """
99 | self.G = G
100 | self.index = subgroup_id
101 | self._size = size
102 | def __str__(self):
103 | return "V_"+str(self.G).split('x')[self.index]
104 | def size(self):
105 | return self._size
106 | def rho(self,M):
107 | # Given that M is a LazyKron object, we can just get the argument
108 | return M.Ms[self.index]
109 | def drho(self,A):
110 | return A.Ms[self.index]
111 | def __call__(self,G):
112 | # adding this will probably not be necessary in a future release,
113 | # necessary now because rep is __call__ed in nn.EMLP constructor
114 | assert self.G==G
115 | return self
116 | G1,G2 = SO(3),S(5)
117 | G = G1 * G2
118 |
119 | VSO3 = ProductSubRep(G,0,G1.d)
120 | VS5 = ProductSubRep(G,1,G2.d)
121 | Vin = VS5 + V(G)
122 | Vout = VSO3
123 | str(Vin>>Vout)
124 | model = emlp.nn.EMLP(Vin, Vout, group=G)
125 | input_point = np.random.randn(Vin.size())*10
126 | from emlp.reps.linear_operators import LazyKron
127 | lazy_G_sample = LazyKron([G1.sample(),G2.sample()])
128 |
129 | out1 = model(Vin.rho(lazy_G_sample)@input_point)
130 | out2 = Vout.rho(lazy_G_sample)@model(input_point)
131 | assert rel_error(out1,out2) < 1e-4, "EMLP equivariance fails on bespoke productsubrep"
132 |
--------------------------------------------------------------------------------
/tests/product_groups_tests.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np#
3 | import copy
4 | from emlp.reps import *
5 | from emlp.groups import *
6 | from emlp.nn import uniform_rep
7 | from equivariance_tests import parametrize,rel_error,scale_adjusted_rel_error
8 | import unittest
9 | from jax import vmap
10 | import jax.numpy as jnp
11 | import logging
12 |
13 |
14 | @parametrize([(SO(3),S(5)),(S(5),SO(3))])
15 | def test_symmetric_mixed_tensor(G1,G2):
16 | N=5
17 | rep = T(2)(G1)*T(1)(G2)
18 | P = rep.equivariant_projector()
19 | v = np.random.rand(rep.size())
20 | v = P@v
21 | samples = {G1:G1.samples(N),G2:G2.samples(N)}
22 | gv = (vmap(rep.rho_dense)(samples)*v).sum(-1)
23 | err = rel_error(gv,v+jnp.zeros_like(gv))
24 | assert err<3e-5,f"Symmetric vector fails err {err:.3e} with G={G1}x{G2}"
25 |
26 |
27 | @parametrize([(SO(3),S(5)),(S(5),SO(3))])
28 | def test_symmetric_mixed_tensor_sum(G1,G2):
29 | N=5
30 | rep = T(2)(G1)*T(1)(G2) + 2*T(0)(G1)*T(2)(G2)+T(1)(G1) +T(1)(G2)
31 | P = rep.equivariant_projector()
32 | v = np.random.rand(rep.size())
33 | v = P@v
34 | samples = {G1:G1.samples(N),G2:G2.samples(N)}
35 | gv = (vmap(rep.rho_dense)(samples)*v).sum(-1)
36 | err = rel_error(gv,v+jnp.zeros_like(gv))
37 | assert err<3e-5,f"Symmetric vector fails err {err:.3e} with G={G1}x{G2}"
38 |
39 |
40 | @parametrize([(SO(3),S(5)),(S(5),SO(3))])
41 | def test_symmetric_mixed_products(G1,G2):
42 | N=5
43 | rep1 = (T(0)+2*T(1)+T(2))(G1)
44 | rep2 = (T(0)+T(1))(G2)
45 | rep = rep2*rep1.T
46 | P = rep.equivariant_projector()
47 | v = np.random.rand(rep.size())
48 | v = P@v
49 | W = v.reshape((rep2.size(),rep1.size()))
50 | x = np.random.rand(N,rep1.size())
51 | g1s = G1.samples(N)
52 | g2s = G2.samples(N)
53 | ring = vmap(rep1.rho_dense)(g1s)
54 | routg = vmap(rep2.rho_dense)(g2s)
55 | gx = (ring@x[...,None])[...,0]
56 | Wgx =gx@W.T
57 | gWx = (routg@(x@W.T)[...,None])[...,0]
58 | equiv_err = rel_error(Wgx,gWx)
59 | assert equiv_err<1e-5,f"Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G1}x{G2}"
60 | samples = {G1:g1s,G2:g2s}
61 | gv = (vmap(rep.rho_dense)(samples)*v).sum(-1)
62 | err = rel_error(gv,v+jnp.zeros_like(gv))
63 | assert err<3e-5,f"Symmetric vector fails err {err:.3e} with G={G1}x{G2}"
64 |
65 | @parametrize([(SO(3),S(5)),(S(5),SO(3))])
66 | def test_equivariant_matrix(G1,G2):
67 | N=5
68 | repin = T(2)(G2) + 3*T(0)(G1) + T(1)(G2)+2*T(2)(G1)*T(1)(G2)
69 | repout = (T(1)(G1) + T(2)(G1)*T(0)(G2) + T(1)(G1)*T(1)(G2) + T(0)(G1)+T(2)(G1)*T(1)(G2))
70 | repW = repout*repin.T
71 | P = repW.equivariant_projector()
72 | W = np.random.rand(repout.size(),repin.size())
73 | W = (P@W.reshape(-1)).reshape(*W.shape)
74 |
75 | x = np.random.rand(N,repin.size())
76 | samples = {G1:G1.samples(N),G2:G2.samples(N)}
77 | ring = vmap(repin.rho_dense)(samples)
78 | routg = vmap(repout.rho_dense)(samples)
79 | gx = (ring@x[...,None])[...,0]
80 | Wgx =gx@W.T
81 | #print(g.shape,(x@W.T).shape)
82 | gWx = (routg@(x@W.T)[...,None])[...,0]
83 | equiv_err = rel_error(Wgx,gWx)
84 | assert equiv_err<3e-5,f"Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G1}x{G2}"
85 | # too much memory to run
86 | # gvecW = (vmap(repW.rho_dense)(samples)*W.reshape(-1)).sum(-1)
87 | # for i in range(N):
88 | # gWerr = rel_error(gvecW[i],W.reshape(-1))
89 | # assert gWerr<1e-5,f"Symmetric gvec(W)=vec(W) fails err {gWerr:.3e} with G={G1}x{G2}"
--------------------------------------------------------------------------------