├── .deepsource.toml ├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs ├── CHANGELOG.md ├── Makefile ├── _static │ ├── emlp_frontfig_bitmap.png │ ├── emlp_logo4x.png │ └── style.css ├── _templates │ └── layout.html ├── conf.py ├── documentation.md ├── index.rst ├── inner_workings.md ├── manual_make.sh ├── notebooks │ ├── 1quickstart.ipynb │ ├── 2building_a_model.ipynb │ ├── 3new_groups.ipynb │ ├── 4new_representations.ipynb │ ├── 5mixed_tensors.ipynb │ ├── 6multilinear_maps.ipynb │ ├── _colab_preamble.ipynb │ ├── colabs │ │ ├── 1quickstart.ipynb │ │ ├── 2building_a_model.ipynb │ │ ├── 3new_groups.ipynb │ │ ├── 4new_representations.ipynb │ │ ├── 5mixed_tensors.ipynb │ │ ├── 6multilinear_maps.ipynb │ │ ├── 7pytorch_support.ipynb │ │ ├── all.ipynb │ │ ├── flax_support.ipynb │ │ ├── haiku_support.ipynb │ │ ├── imgs │ │ │ ├── EMLP_fig.png │ │ │ └── imgs │ │ │ │ └── EMLP_fig.png │ │ └── pytorch_support.ipynb │ ├── flax_support.ipynb │ ├── haiku_support.ipynb │ ├── imgs │ │ └── EMLP_fig.png │ ├── merge_nbs.sh │ └── pytorch_support.ipynb ├── package │ ├── emlp.groups.rst │ ├── emlp.nn.rst │ └── emlp.reps.rst ├── requirements.txt └── testing.md ├── emlp ├── __init__.py ├── datasets.py ├── groups.py ├── nn │ ├── __init__.py │ ├── flax.py │ ├── haiku.py │ ├── objax.py │ └── pytorch.py ├── reps │ ├── __init__.py │ ├── linear_operator_base.py │ ├── linear_operators.py │ ├── product_sum_reps.py │ └── representation.py └── utils.py ├── experiments ├── data_efficiency.py ├── datasets │ └── batchnorm.py ├── depreciated │ ├── train_cube_simple.py │ ├── train_rubiks.py │ └── train_tagging.py ├── hnn.py ├── hnn_expt.py ├── neuralode.py ├── neuralode_expt.py ├── notebooks │ ├── additional_appendix_figs.ipynb │ ├── make_tables.ipynb │ ├── paper_figures.ipynb │ └── synthetic_results_all.csv ├── train_regression.py └── trainer │ ├── classifier.py │ ├── hamiltonian_dynamics.py │ ├── model_trainer.py │ ├── trainer.py │ └── utils.py ├── setup.py └── tests ├── equivariance_tests.py ├── model_tests.py └── product_groups_tests.py /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | test_patterns = ["tests/*tests.py"] 4 | 5 | exclude_patterns = [ 6 | "*.ipynb", 7 | "experiments/*.py", 8 | "docs/*" 9 | ] 10 | 11 | [[analyzers]] 12 | name = "python" 13 | enabled = true 14 | 15 | [analyzers.meta] 16 | runtime_version = "3.x.x" 17 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - '**' # matches every branch 7 | pull_request: 8 | branches: 9 | - '**' # matches every branch 10 | 11 | workflow_dispatch: 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v2 18 | - uses: actions/setup-python@v1 19 | with: 20 | python-version: 3.7 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | python -m pip install pytest 25 | python -m pip install pytest-cov 26 | pip install git+https://github.com/deepmind/dm-haiku 27 | python -m pip install flax 28 | pip install -e .[EXPTS] 29 | - name: Test coverage. 
30 | run: | 31 | pytest --cov emlp --cov-report xml:cov.xml tests/*.py 32 | - name: Upload coverage to Codecov 33 | uses: codecov/codecov-action@v1 34 | with: 35 | files: ./cov.xml 36 | name: codecov-umbrella 37 | path_to_write_report: ./coverage/codecov_report.txt 38 | verbose: true 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Random other stuff 132 | *.pyc 133 | *.pdf 134 | *.state 135 | *.tfevents* 136 | *.df 137 | *.s 138 | 139 | # additional excludes specific to the repo 140 | *.t 141 | *.dat 142 | *.h5 143 | *.pkl 144 | *.json -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | image: latest 10 | 11 | 12 | # Build documentation in the docs/ directory with Sphinx 13 | sphinx: 14 | builder: html 15 | configuration: docs/conf.py 16 | 17 | # Optionally build your docs in additional formats such as PDF 18 | formats: [] 19 | # - pdf 20 | 21 | # Optionally set the version of Python and requirements required to build your docs 22 | python: 23 | version: 3.8 24 | install: 25 | - requirements: docs/requirements.txt 26 | - method: pip 27 | path: . -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 mfinzi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | logo 3 |
4 | 5 | # A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups 6 | [![Documentation](https://readthedocs.org/projects/emlp/badge/)](https://emlp.readthedocs.io/en/latest/) | [![Paper](https://img.shields.io/badge/arXiv-2104.09459-red)](https://arxiv.org/abs/2104.09459) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mfinzi/equivariant-MLP/blob/master/docs/notebooks/colabs/all.ipynb) | 7 | [![codecov.io](https://codecov.io/github/mfinzi/equivariant-MLP/coverage.svg)](https://codecov.io/github/mfinzi/equivariant-MLP) 8 | | [![PyPI version](https://img.shields.io/pypi/v/emlp)](https://pypi.org/project/emlp/) 9 | 10 | 11 | 12 | 13 | *EMLP* is a JAX library for the automated construction of equivariant layers in deep learning, based on the ICML 2021 paper [A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups](https://arxiv.org/abs/2104.09459). You can read the documentation [here](https://emlp.readthedocs.io/en/latest/). 14 | 15 | 16 | ## What EMLP is great at doing 17 | 18 | - Computing equivariant linear layers between finite dimensional 19 | representations. You specify the symmetry group (discrete, continuous, 20 | non-compact, complex) and the representations (tensors, irreducibles, induced representations, etc.), and we will compute the basis of equivariant 21 | maps from one to the other. 22 | 23 | - Automatic construction of full equivariant models for small data. E.g. 24 | if your inputs and outputs (and intended features) are a small collection of elements like scalars, vectors, tensors, and irreps with a total dimension less than 1000, then you will likely be able to use EMLP as a turnkey solution for building the model, or at least as a strong baseline. 25 | 26 | - Serving as a tool for building larger models, where EMLP is just one component in a larger system. For example, using EMLP as the convolution kernel in an equivariant PointConv network. 27 | 28 | ## What EMLP is not great at doing 29 | 30 | - Efficiently implementing CNNs, Deep Sets, typical translation + rotation equivariant GCNNs, or graph neural networks. 31 | 32 | - Handling large data like images, voxel grids, medium-large graphs, or point clouds. 33 | 34 | Given the current approach, EMLP can only ever be as fast as an MLP. So if the flattened input vector would be too large to train an MLP on, it will also be too large to train EMLP on. 35 | 36 | -------------------------------------------------------------------------------- 37 | 38 | # Showcasing some examples of computing equivariant bases 39 | 40 | We provide a type system for representations. With the operators ρᵤ⊗ρᵥ, ρᵤ⊕ρᵥ, ρ*, implemented as `*`, `+`, and `.T`, you can build up different representations. The basic building blocks for representations are the base vector representation `V` and the tensor representations `T(p,q) = V**p*V.T**q`. 41 | 42 | For any given matrix group and representation formed in our type system, you can get the equivariant basis with [`rep.equivariant_basis()`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.equivariant_basis) or a matrix which projects onto that subspace with [`rep.equivariant_projector()`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.equivariant_projector). 
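For instance, here is a minimal sketch of using the projector (the choice of group `SO(3)` and the random test vector are illustrative assumptions of ours, not part of the worked examples below):

```python
import numpy as np
from emlp.reps import V
from emlp.groups import SO

rep = (V>>V)(SO(3))               # representation of linear maps V -> V under SO(3)
P = rep.equivariant_projector()   # lazy matrix projecting onto the equivariant subspace
w = np.random.randn(rep.size())   # arbitrary coefficients of a (non-equivariant) linear map
w_equiv = P @ w                   # projected onto the space of SO(3) equivariant maps
```

The projected coefficients `w_equiv` then parametrize a linear map that commutes with the group action, which is the property the examples below exploit.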
43 | 44 | For example, to find all O(1,3) (Lorentz) equivariant linear maps from a 4-vector Xᶜ to a rank (2,1) tensor Mᵇᵈₐ, you can run 45 | 46 | ```python 47 | from emlp.reps import V,T 48 | from emlp.groups import * 49 | 50 | G = O13() 51 | Q = (T(1,0)>>T(2,1))(G).equivariant_basis() 52 | ``` 53 | 54 | or how about equivariant maps from one Rubik's cube to another? 55 | ```python 56 | G = RubiksCube() 57 | 58 | Q = (V(G)>>V(G)).equivariant_basis() 59 | ``` 60 | 61 | Using `+` and `*` you can put together composite representations (where multiple representations are concatenated together). For example, let's find all equivariant linear maps from 5 node features and 2 edge features to 3 global invariants and 1 edge feature of a graph of size n=5: 62 | ```python 63 | G=S(5) 64 | 65 | repin = 10*T(1)+5*T(2) 66 | repout = 3*T(0)+T(2) 67 | Q = (repin(G)>>repout(G)).equivariant_basis() 68 | ``` 69 | 70 | As the examples above suggest, there are many different ways of writing a representation like `10*T(1)+5*T(2)`, all of which are equivalent: 71 | `10*T(1)+5*T(2)` = `10*V+5*V**2` = `5*V*(2+V)` 79 | 80 | You can even mix and match representations from different groups. For example, with the cyclic group ℤ₃, the permutation group 𝕊₄, and the orthogonal group O(3): 81 | 82 | ```python 83 | rep = 2*V(Z(3))*V(S(4))+V(O(3))**2 84 | Q = (rep>>rep).equivariant_basis() 85 | ``` 86 | 87 | Beyond these tensor representations, our type system works with any finite dimensional linear representation, and you can even build your own bespoke representations following the instructions [here](https://emlp.readthedocs.io/en/latest/notebooks/4new_representations.html). 88 | 89 | You can visualize these equivariant bases with [`vis(repin,repout)`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.vis), such as with the three examples above. 90 | 91 | 92 | 94 | 95 | 96 | Check out our [documentation](https://emlp.readthedocs.io/en/latest/) to see how to use our system and some worked examples. 97 | 98 | # Simple example of using EMLP as a full equivariant model 99 | 100 | Suppose we want to construct a Lorentz equivariant model for particle physics data that takes in the input and output 4-momentum of two particles 101 | in a collision, as well as some metadata about these particles like their charge, and we want to classify the output 102 | as belonging to 3 distinct classes of collisions. Since the outputs are simple logits, they should be unchanged by 103 | Lorentz transformations, and similarly for the charges. 104 | 105 | ```python 106 | import emlp 107 | from emlp.reps import T 108 | from emlp.groups import Lorentz 109 | import numpy as np 110 | 111 | repin = 4*T(1)+2*T(0) # 4 four-vectors and 2 scalars for the charges 112 | repout = 3*T(0) # 3 output logits for the 3 classes of collisions 113 | group = Lorentz() 114 | model = emlp.nn.EMLP(repin,repout,group=group,num_layers=3,ch=384) 115 | 116 | x = np.random.randn(32,repin(group).size()) # Create a minibatch of data 117 | y = model(x) # Outputs the 3 class logits 118 | ``` 119 | 120 | Here we have used the default Objax EMLP, but you can also use our [PyTorch](https://emlp.readthedocs.io/en/latest/notebooks/pytorch_support.html), [Haiku](https://emlp.readthedocs.io/en/latest/notebooks/haiku_support.html), or [Flax](https://emlp.readthedocs.io/en/latest/notebooks/flax_support.html) versions of the models. To see more examples, or how to use your own representations or symmetry groups, check out the documentation. 
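A model like this can then be fit with an ordinary Objax training loop. The snippet below (continuing from the model definition above) is only a rough sketch on random stand-in data; the sparse cross-entropy loss, Adam optimizer, learning rate, batch size, and epoch count are placeholder choices of ours, not settings from the paper:

```python
import objax
import numpy as np

# Random placeholder data standing in for a real collision dataset
X = np.random.randn(1000, repin(group).size())
Y = np.random.randint(0, 3, size=(1000,))  # one of the 3 collision classes

opt = objax.optimizer.Adam(model.vars())

def loss(x, y):
    # Sparse cross entropy on the 3 class logits produced by the model
    return objax.functional.loss.cross_entropy_logits_sparse(model(x), y).mean()

gv = objax.GradValues(loss, model.vars())

def train_op(x, y, lr):
    g, v = gv(x, y)  # gradients and loss value
    opt(lr, g)       # optimizer update
    return v

train_op = objax.Jit(train_op, gv.vars() + opt.vars())

for epoch in range(10):
    for i in range(0, len(X), 32):
        train_op(X[i:i+32], Y[i:i+32], 1e-3)
```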
121 | 122 | # Installation instructions 123 | 124 | To install as a package, run 125 | ```bash 126 | pip install emlp 127 | ``` 128 | 129 | To run the scripts you will instead need to clone the repo and install it locally, which you can do with 130 | 131 | ```bash 132 | git clone https://github.com/mfinzi/equivariant-MLP.git 133 | cd equivariant-MLP 134 | pip install -e .[EXPTS] 135 | ``` 136 | 137 | # Experimental Results from Paper 138 | 139 | Assuming you have installed the repo locally, you can run the experiments we described in the paper. 140 | 141 | To train the regression models on one of the `Inertia`, `O5Synthetic`, or `ParticleInteraction` datasets found in [`emlp.datasets.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/emlp/datasets.py), you can run the script [`experiments/train_regression.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/experiments/train_regression.py) with command line arguments specifying the dataset, network, and symmetry group. For example, to train [`EMLP`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.EMLP) with [`SO(3)`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.SO) equivariance on the `Inertia` dataset, you can run 142 | 143 | ``` 144 | python experiments/train_regression.py --dataset Inertia --network EMLP --group "SO(3)" 145 | ``` 146 | 147 | or to train the MLP baseline you can run 148 | 149 | ``` 150 | python experiments/train_regression.py --dataset Inertia --network MLP 151 | ``` 152 | Other command line arguments are available, such as `--aug=True` for data augmentation or `--ch=512` for the number of hidden units, and you can browse the options and their defaults with `python experiments/train_regression.py -h`. If no group is specified, EMLP will automatically choose the one matched to the dataset, but you can also go crazy with any of the other groups implemented in [`groups.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/emlp/groups.py) provided the dimensions match the data (e.g. for the 3D inertia dataset you could do `--group=` [`"Z(3)"`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.Z) or [`"DkeR3(3)"`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.DkeR3) but not [`"Sp(2)"`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.Sp) or [`"SU(5)"`](https://emlp.readthedocs.io/en/latest/package/emlp.groups.html#emlp.groups.SU)). 153 | 154 | For the dynamical systems modeling experiments, you can use the scripts 155 | [`experiments/neuralode.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/experiments/neuralode.py) to train (equivariant) Neural ODEs and [`experiments/hnn.py`](https://github.com/mfinzi/equivariant-MLP/blob/master/experiments/hnn.py) to train (equivariant) Hamiltonian Neural Networks. 156 | 157 | 158 | For the dynamical system task, the Neural ODE and HNN models have special names: [`EMLPode`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.EMLPode) and [`MLPode`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.MLPode) for the Neural ODEs in `neuralode.py`, and [`EMLPH`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.EMLPH) and [`MLPH`](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.MLPH) for the HNNs in `hnn.py`. 
For example, 159 | 160 | ``` 161 | python experiments/neuralode.py --network EMLPode --group="O2eR3()" 162 | ``` 163 | or 164 | 165 | ``` 166 | python experiments/hnn.py --network EMLPH --group="DkeR3(6)" 167 | ``` 168 | 169 | These models are trained to fit a double spring dynamical system. 30s rollouts of the dataset, along with the rollout error on these trajectories and the conservation of angular momentum, are shown below. 170 | 171 | 172 | 173 | 177 | 178 | If you find our work helpful, please cite it with 179 | ```bibtex 180 | @article{finzi2021emlp, 181 | title={A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups}, 182 | author={Finzi, Marc and Welling, Max and Wilson, Andrew Gordon}, 183 | journal={Arxiv}, 184 | year={2021} 185 | } 186 | ``` 187 | 189 | -------------------------------------------------------------------------------- /docs/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change Log 2 | 3 | 7 | 8 | ## EMLP 1.0.0 9 | * New Features 10 | * Flax support (see `using EMLP with Flax`) 11 | * Auto generated `size()`, `__eq__`, `__hash__`, and `.T` methods for new representations 12 | * You can now use ints in place of `Scalars` for direct sums, e.g. `3+V` 13 | * Codebase improvements 14 | * Streamlined product_sum_reps direct sum and product rules, now with plum dispatch 15 | * More general `Dual(Rep)` implementation that now works with any kind of Rep, not just `V` 16 | * CI setup with more tests 17 | 18 | ## EMLP 0.9.0 19 | * Cross Platform Support: 20 | * You can now use EMLP in PyTorch, check out `Using EMLP with PyTorch` 21 | * You can also use EMLP with Haiku in jax, check out `Using EMLP with Haiku` 22 | 23 | * Bug Fixes 24 | * Fixed broken constraints with the Trivial group 25 | 26 | ## EMLP 0.8.0 (Unreleased) 27 | 28 | * New features: 29 | * Fallback autograd jvp implementation of drho to make implementing new reps easier. 30 | * Mixed group representations (now working and tested) 31 | * Experimental support for complex groups and representations 32 | * Bug Fixes: 33 | * Element ordering of mixed groups is now correctly maintained in the solution 34 | * Fixed edge case of {func}`lazy_direct_matmat` when concatenating matrices of size 0, 35 | affecting {func}`emlp.reps.Rep.equivariant_basis` but not 36 | {func}`emlp.reps.Rep.equivariant_projector` 37 | * API Changes: 38 | * `emlp.solver.representation` -> `emlp.reps` 39 | * `emlp.solver.groups` -> `emlp.groups` 40 | * `emlp.models.mlp` -> `emlp.nn` 41 | * `rep.symmetric_basis()` -> `rep.equivariant_basis()` 42 | * `rep.symmetric_projector()` -> `rep.equivariant_projector()` 43 | * Tests and experiments separated from the package and API 44 | 45 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/emlp_frontfig_bitmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/_static/emlp_frontfig_bitmap.png -------------------------------------------------------------------------------- /docs/_static/emlp_logo4x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/_static/emlp_logo4x.png -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .wy-side-nav-search { 4 | background-color: #fff; 5 | } -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% set css_files = css_files + ["_static/style.css"] %} -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | import os 18 | import sys 19 | import re 20 | sys.path.insert(0, os.path.abspath('..')) 21 | 22 | RE_VERSION = re.compile(r'^__version__ \= \'(\d+\.\d+\.\d+(?:\w+\d+)?)\'$', re.MULTILINE) 23 | PROJECTDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 24 | sys.path.insert(0, PROJECTDIR) 25 | def get_release(): 26 | with open(os.path.join(PROJECTDIR, 'emlp', '__init__.py')) as f: 27 | version = re.search(RE_VERSION, f.read()) 28 | assert version is not None, "can't parse __version__ from __init__.py" 29 | return version.group(1) 30 | 31 | # -- Project information ----------------------------------------------------- 32 | 33 | project = 'EMLP' 34 | copyright = '2021, Marc Finzi' 35 | author = 'Marc Finzi' 36 | 37 | # import emlp 38 | release = get_release() 39 | 40 | 41 | sys.path.append(os.path.abspath('sphinxext')) 42 | extensions = [ 43 | 'nbsphinx', 44 | 'sphinx.ext.mathjax', 45 | 'recommonmark', 46 | 'sphinx.ext.autodoc', 47 | 'sphinx.ext.autosummary', 48 | 'sphinx.ext.intersphinx', 49 | 'sphinx.ext.napoleon', 50 | 'sphinx.ext.viewcode', 51 | 'matplotlib.sphinxext.plot_directive', 52 | 'sphinx_autodoc_typehints', 53 | 'sphinx.ext.autosectionlabel', 54 | ] 55 | autosummary_generate = True 56 | autodoc_default_options = {'autosummary': True} 57 | autodoc_member_order = 'bysource' 58 | 59 | intersphinx_mapping = { 60 | 'python': ('https://docs.python.org/3/', None), 61 | 'numpy': ('https://numpy.org/doc/stable/', None), 62 | 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), 63 | 'objax': ('https://objax.readthedocs.io/en/latest/', None), 64 | } 65 | 66 | suppress_warnings = [ 67 | 'ref.citation', # Many duplicated citations in numpy/scipy docstrings. 68 | 'ref.footnote', # Many unreferenced footnotes in numpy/scipy docstrings 69 | ] 70 | 71 | # Add any paths that contain templates here, relative to this directory. 72 | templates_path = ['_templates'] 73 | 74 | # The suffix(es) of source filenames. 75 | # Note: important to list ipynb before md here: we have both md and ipynb 76 | # copies of each notebook, and myst will choose which to convert based on 77 | # the order in the source_suffix list. Notebooks which are not executed have 78 | # outputs stored in ipynb but not in md, so we must convert the ipynb. 79 | source_suffix = ['.rst', '.ipynb', '.md'] 80 | 81 | # The master toctree document. 82 | main_doc = 'index' 83 | 84 | # The language for content autogenerated by Sphinx. Refer to documentation 85 | # for a list of supported languages. 86 | # 87 | # This is also used if you do content translation via gettext catalogs. 88 | # Usually you set "language" from the command line for these cases. 89 | language = None 90 | 91 | # List of patterns, relative to source directory, that match files and 92 | # directories to ignore when looking for source files. 93 | # This pattern also affects html_static_path and html_extra_path. 94 | exclude_patterns = [ 95 | # Sometimes sphinx reads its own outputs as inputs! 96 | 'build/html', 97 | 'build/jupyter_execute', 98 | 'notebooks/README.md', 99 | 'notebooks/colabs/*.ipynb', 100 | 'README.md', 101 | # Ignore markdown source for notebooks; myst-nb builds from the ipynb 102 | 'notebooks/*.md' 103 | ] 104 | 105 | # The name of the Pygments (syntax highlighting) style to use. 
106 | pygments_style = None 107 | 108 | 109 | autosummary_generate = True 110 | napoleon_use_rtype = False 111 | 112 | # mathjax_config = { 113 | # 'TeX': {'equationNumbers': {'autoNumber': 'AMS', 'useLabelIds': True}}, 114 | # } 115 | 116 | # Additional files needed for generating LaTeX/PDF output: 117 | # latex_additional_files = ['references.bib'] 118 | 119 | # -- Options for HTML output ------------------------------------------------- 120 | 121 | # The theme to use for HTML and HTML Help pages. See the documentation for 122 | # a list of builtin themes. 123 | # 124 | html_theme = 'sphinx_rtd_theme' 125 | 126 | # html_theme_options = { 127 | # 'logo_only': False, 128 | # 'display_version': True, 129 | # 'prev_next_buttons_location': 'bottom', 130 | # 'style_external_links': False, 131 | # } 132 | 133 | # Theme options are theme-specific and customize the look and feel of a theme 134 | # further. For a list of options available for each theme, see the 135 | # documentation. 136 | 137 | # The name of an image file (relative to this directory) to place at the top 138 | # of the sidebar. 139 | 140 | # Add any paths that contain custom static files (such as style sheets) here, 141 | # relative to this directory. They are copied after the builtin static files, 142 | # so a file named "default.css" will overwrite the builtin "default.css". 143 | html_static_path = ['_static'] 144 | html_logo = '_static/emlp_logo4x.png' 145 | # Custom sidebar templates, must be a dictionary that maps document names 146 | # to template names. 147 | # 148 | # The default sidebars (for documents that don't match any pattern) are 149 | # defined by theme itself. Builtin themes are using these templates by 150 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 151 | # 'searchbox.html']``. 152 | # 153 | # html_sidebars = {} 154 | 155 | # # commented out notebook execution 156 | # # -- Options for myst ---------------------------------------------- 157 | # jupyter_execute_notebooks = "off"# changed from "force" 158 | # #jupyter_execute_notebooks = "cache" 159 | # execution_allow_errors = False 160 | # execution_fail_on_error = True # Requires https://github.com/executablebooks/MyST-NB/pull/296 161 | 162 | # # Notebook cell execution timeout; defaults to 30. 163 | # execution_timeout = 100 164 | 165 | # -- Options for HTMLHelp output --------------------------------------------- 166 | 167 | # Output file base name for HTML help builder. 168 | htmlhelp_basename = 'EMLPdoc' 169 | 170 | 171 | # -- Options for LaTeX output ------------------------------------------------ 172 | 173 | latex_elements = { 174 | # The paper size ('letterpaper' or 'a4paper'). 175 | # 176 | # 'papersize': 'letterpaper', 177 | 178 | # The font size ('10pt', '11pt' or '12pt'). 179 | # 180 | # 'pointsize': '10pt', 181 | 182 | # Additional stuff for the LaTeX preamble. 183 | # 184 | # 'preamble': '', 185 | 186 | # Latex figure (float) alignment 187 | # 188 | # 'figure_align': 'htbp', 189 | } 190 | 191 | # # Grouping the document tree into LaTeX files. List of tuples 192 | # # (source start file, target name, title, 193 | # # author, documentclass [howto, manual, or own class]). 194 | # latex_documents = [ 195 | # (main_doc, 'JAX.tex', 'JAX Documentation', 196 | # 'The JAX authors', 'manual'), 197 | # ] 198 | 199 | 200 | # # -- Options for manual page output ------------------------------------------ 201 | 202 | # # One entry per manual page. List of tuples 203 | # # (source start file, name, description, authors, manual section). 
204 | # man_pages = [ 205 | # (main_doc, 'jax', 'JAX Documentation', 206 | # [author], 1) 207 | # ] 208 | 209 | 210 | # # -- Options for Texinfo output ---------------------------------------------- 211 | 212 | # # Grouping the document tree into Texinfo files. List of tuples 213 | # # (source start file, target name, title, author, 214 | # # dir menu entry, description, category) 215 | # texinfo_documents = [ 216 | # (main_doc, 'JAX', 'JAX Documentation', 217 | # author, 'JAX', 'One line description of project.', 218 | # 'Miscellaneous'), 219 | # ] 220 | 221 | 222 | # -- Options for Epub output ------------------------------------------------- 223 | 224 | # Bibliographic Dublin Core info. 225 | epub_title = project 226 | 227 | # The unique identifier of the text. This can be an ISBN number 228 | # or the project homepage. 229 | # 230 | # epub_identifier = '' 231 | 232 | # A unique identification for the text. 233 | # 234 | # epub_uid = '' 235 | 236 | # A list of files that should not be packed into the epub file. 237 | epub_exclude_files = ['search.html'] 238 | 239 | 240 | # -- Extension configuration ------------------------------------------------- 241 | 242 | # Tell sphinx-autodoc-typehints to generate stub parameter annotations including 243 | # types, even if the parameters aren't explicitly documented. 244 | always_document_param_types = True 245 | 246 | # -- Options for nbsphinx ----------------------------------------------------- 247 | 248 | # Execute notebooks before conversion: 'always', 'never', 'auto' (default) 249 | # We never execute notebooks to avoid problems if nbsphinx won't find all dependencies. 250 | nbsphinx_execute = 'never' 251 | 252 | # If True, the build process is continued even if an exception occurs: 253 | nbsphinx_allow_errors = True 254 | 255 | # Controls when a cell will time out (defaults to 30; use -1 for no timeout): 256 | nbsphinx_timeout = 180 257 | 258 | # Default Pygments lexer for syntax highlighting in code cells: 259 | nbsphinx_codecell_lexer = 'ipython3' 260 | 261 | nbsphinx_prolog = r""" 262 | {% set docname = 'docs/notebooks/colabs/' + env.doc2path(env.docname, base=None).split('/')[-1] %} 263 | 264 | .. only:: html 265 | 266 | .. role:: raw-html(raw) 267 | :format: html 268 | 269 | .. nbinfo:: 270 | Interactive online version: 271 | :raw-html:`Open In Colab` 272 | 273 | """ 274 | 275 | # nbsphinx_prolog = """ 276 | # ---- 277 | 278 | # Generated by nbsphinx_ from a Jupyter_ notebook. 279 | 280 | # .. _nbsphinx: http://nbsphinx.readthedocs.io/ 281 | # .. _Jupyter: https://jupyter.org/ 282 | # """ 283 | 284 | # nbsphinx_prolog = """ 285 | # ---- 286 | # {% set docname = 'notebooks/docs/' + env.doc2path(env.docname, base=None) %} 287 | # .. only:: html 288 | # .. role:: raw-html(raw) 289 | # :format: html 290 | # .. 
nbinfo:: 291 | # Interactive online version: 292 | # :raw-html:`Open In Colab` 293 | # """ 294 | -------------------------------------------------------------------------------- /docs/documentation.md: -------------------------------------------------------------------------------- 1 | # Update documentation 2 | 3 | To rebuild the documentation, you need to install the required packages: 4 | ``` 5 | pip install -r docs/requirements.txt 6 | ``` 7 | And then run: 8 | ``` 9 | sphinx-build -b html docs docs/build/html 10 | ``` 11 | This can take a long time because it executes many of the notebooks in the documentation source; 12 | if you'd prefer to build the docs without executing the notebooks, you can run: 13 | ``` 14 | sphinx-build -b html -D jupyter_execute_notebooks=off docs docs/build/html 15 | ``` 16 | You can then see the generated documentation in `docs/build/html/index.html`. 17 | 18 | 19 | ### Editing ipynb 20 | 21 | To edit notebooks in the Colab interface, 22 | open Colab and `Upload` the notebook from your local repo. 23 | Update it as needed, `Run all cells`, then `Download ipynb` (for editing and running in Colab you will need to add 24 | `!pip install git+https://github.com/mfinzi/equivariant-MLP.git`). 25 | You may want to test that it executes properly, using `sphinx-build` as explained above. -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. EMLP documentation master file, created by 2 | sphinx-quickstart on Mon Feb 22 18:41:05 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | 7 | EMLP reference documentation 8 | ============================ 9 | 10 | A type system for the automated construction of equivariant layers. 11 | EMLP is designed to make constructing equivariant layers with different matrix groups 12 | and representations an easy task, and one that does not require knowledge of analytic solutions. 13 | 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | :caption: Getting Started 18 | 19 | notebooks/1quickstart.ipynb 20 | notebooks/2building_a_model.ipynb 21 | notebooks/3new_groups.ipynb 22 | 23 | .. toctree:: 24 | :maxdepth: 1 25 | :caption: Advanced Features 26 | 27 | notebooks/4new_representations.ipynb 28 | notebooks/5mixed_tensors.ipynb 29 | notebooks/6multilinear_maps.ipynb 30 | 31 | .. toctree:: 32 | :maxdepth: 1 33 | :caption: Cross Platform Support 34 | 35 | notebooks/pytorch_support.ipynb 36 | notebooks/haiku_support.ipynb 37 | notebooks/flax_support.ipynb 38 | 39 | .. toctree:: 40 | :maxdepth: 1 41 | :caption: Examples 42 | 43 | .. toctree:: 44 | :glob: 45 | :maxdepth: 1 46 | :caption: API Reference 47 | 48 | package/emlp.groups 49 | package/emlp.reps 50 | package/emlp.nn 51 | 52 | .. toctree:: 53 | :maxdepth: 1 54 | :caption: Notes 55 | 56 | testing.md 57 | inner_workings.md 58 | CHANGELOG.md 59 | 60 | .. 
toctree:: 61 | :maxdepth: 2 62 | :caption: Developer Documentation 63 | 64 | documentation.md 65 | 66 | Indices and tables 67 | ================== 68 | 69 | * :ref:`genindex` 70 | * :ref:`modindex` 71 | -------------------------------------------------------------------------------- /docs/inner_workings.md: -------------------------------------------------------------------------------- 1 | # Inner Workings 2 | 3 | It's a mystery, even to me -------------------------------------------------------------------------------- /docs/manual_make.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # jupytext --sync ./notebooks/* 3 | ./notebooks/merge_nbs.sh 4 | sphinx-build -b html . ./build/html -------------------------------------------------------------------------------- /docs/notebooks/5mixed_tensors.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Combining Representations from Different Groups (experimental)\n", 8 | "## Direct Product groups" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "It is possible to combine representations from different groups. These kinds of representations are relevant when there are multiple structures in the data. For example, a point cloud is a set of vectors which transforms both under permutations and rotations of the vectors. We can formalize this as $V_{S_n}\\otimes V_{SO(3)}$ or in tensor notation $T_1^{S_n}\\otimes T_1^{SO(3)}$. While this object could be expressed as the representation of the product group $V_{S_n \\times SO(3)}$, other objects like $T_k^{S_n}\\otimes T_j^{SO(3)}$ cannot be expressed so easily.\n", 16 | "\n", 17 | "Nevertheless, we can calculate the symmetric bases for these objects. 
For example, maps from vector edge features $T_1^{SO(3)}\\otimes T_2^{S_n}$ to matrix node features $T_2^{SO(3)}\\otimes T_1^{S_n}$ can be computed:" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "V_SO(3)⊗V²_S(4) --> V²_SO(3)⊗V_S(4)\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "from emlp.groups import *\n", 35 | "from emlp.reps import T,vis,V,Scalar\n", 36 | "\n", 37 | "repin,repout = T(1)(SO(3))*T(2)(S(4)),T(2)(SO(3))*T(1)(S(4))\n", 38 | "print(repin,\"-->\", repout)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 7, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "image/png": "...(base64-encoded PNG of the equivariant basis visualization omitted)...", 49 | "text/plain": [ 50 | "
" 51 | ] 52 | }, 53 | "metadata": { 54 | "needs_background": "light" 55 | }, 56 | "output_type": "display_data" 57 | } 58 | ], 59 | "source": [ 60 | "vis(repin,repout,cluster=False)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "Or perhaps you would like equivariant maps from two sequence of sets and a matrix (under 3D rotations) to itself. You can go wild." 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 2, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "Rep: V²+2V_S(4)⊗V_Z(3)\n", 80 | "Linear maps: V⁴+4V_S(4)⊗V_Z(3)⊗V²_SO(3)+4V²_S(4)⊗V²_Z(3)\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "rep = 2*T(1)(Z(3))*T(1)(S(4))+T(2)(SO(3))\n", 86 | "print(f\"Rep: {rep}\")\n", 87 | "print(f\"Linear maps: {rep>>rep}\")" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 7, 93 | "metadata": { 94 | "scrolled": true 95 | }, 96 | "outputs": [ 97 | { 98 | "data": { 99 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAI+UlEQVR4nO3dX8jfVR0H8LMhudrQLP/xYG4p1Voxc2GR0KVYEYWVrfBKKCo0NU0iqLCsJFdGGUQs6EqSjCKoXOpFNy2a2DJLZur8Uww2ZyPZTENaF7vo5nvOw/fsu+/ee57X6/J7OJ4H5psD55zf57Pi8OHDBciz8nj/AcAw4YRQwgmhhBNCCSeEOqk1+KHtn6we5e58+jWT/RGbNzxQHbv/2bWTrVNKKU/sGP67X/ed3ZOuc+Pv7q2Ofe6mT0y61qGFFYPfz7r/xUnXedn+Q9WxXZ86ZdK1NmzZN/j9pd1PTrrOSeetq45NuVZrnbsf2zL4D2jnhFDCCaGEE0IJJ4QSTgglnBCqeZXSui658Ny/j55T07ouuejVT3XNq6ldmTx67Xmj57S0rku+cdMPRs9pqV2Z7L3o5NFzWlrXJeu//9zoObTZOSGUcEIo4YRQwgmhhBNCNU9rW2qnsrVT3Naclp6T3ClPcUuZ7yS3dorbmtPSOpGd6yS3dorbmsMRdk4IJZwQSjghlHBCKOGEUMIJoVa0Kr6ff+fXJi0HX7tm+deLL59ymeZj+Z3vWph0rdo1y2kPT7pM85rl5o9fOelatWuWQ+temnSd1jXLyoMvDH5XQwg47oQTQgknhBJOCCWcEKr58H3FU/VT1J4H34cqy536039X5/Q8lt+99/Tq2OZtw9XleyvLv7YMn0Cv/PGa6pyeB99X3vOx6tiPtv5w8Htv2ZPVe2qH9PX/XXoey//n9NXVsVWV09rlxM4JoYQTQgknhBJOCCWcEEo4IVTzKqWnrk7PFUtPZfnF5tXUrkymriw/Z3X05JpEUzfwXU7snBBKOCGUcEIo4YRQwgmhhBNCdbdjmKsJ7VJs4DtX64KeBr6LzauZq4HvcmLnhFDCCaGEE0IJJ4QSTgjVrPj+7oWrJ634XjvJPby2XkOoR+ux/FzV5e/4/TsmXad1kvvPC06bdK3aSe5cleVLKWXtXXsGv6v4Dhx3wgmhhBNCCSeEEk4IJZwQqvnw/cJtw8fZpfQ9+K61LnhiR/2h+pRtH0qpt37oeSxfSr31w+on52td8KoHDwx+761JVGv9UGv7UErfY/l62wdKsXNCLOGEUMIJoYQTQgknhGqe1vaU7ug5xe2pLL/YvJraqezUleXnrI6eXPakt4Evdk6IJZwQSjghlHBCKOGEUMIJoborvs/VhHYpNvCdqzp6TwPfxebVzNXAdzmxc0Io4YRQwgmhhBNCCSeEEk4I1X2VUtPzS5adZaFrrZ4GvqWMb/3Qc83ybFk3ep2eX7Ic8dLotXquWXraPvQ28N1yyftGr7XU2DkhlHBCKOGEUMIJoYQTQjWb52687tvVwbN2HJzsj3hs8yuqY2e8Yf9k65RSyqYz/jH4/YHvXjjpOr/6+jerY++//vpJ1zrl138Z/L7re+snXeeWi39WHdvyrY9MutbzZw/2ky3nfmX7pOs8/aWLq2NTrtVa55GbPqN5LpxIhBNCCSeEEk4IJZwQSjghVPPhe+u6ZO/b1oyeU9O6LnnmkeHmtIvNq6ldmbz1mp2j57S0rkt+cdtto+e01K5M1l+9a/ScltZ1yY033Dl6Dm12TgglnBBKOCGUcEIo4YRQ3WVKaqeytVPc1pyWnpPcKU9xS5nvJLd2itua09I6kZ3rJLd2ituawxF2TgglnBBKOCGUcEIo4YRQwgmhJq/43vNYflV5vmut2pVJ67F8faSu55rlwVsvGL1Oz2P5Ukq56L5rRq/Vc82y74qNo9fpeSxfSilfvuOjo9daauycEEo4IZRwQijhhFDCCaGEE0I1r1LOv/1v1bGeX2OcfGC4u8ML286szun5JcupjauZWuuH3rYPf3zmnMHv1331J9U5Pb/GuPTmz1bH7v/icOuH3ppEz73nzYPfz7zjz9U5Pb9k+fz2D1THVo3+ry09dk4IJZwQSjghlHBCKOGEUM3T2p4H3z2nuD2P5RebV9PzWL7nJHfO6ujJNYmmbuC7nNg5IZRwQijhhFDCCaGEE0IJJ4TqriE0VxPapdjAd67WBb01iaa8Zpm67cNyYueEUMIJoYQTQgknhBJOCCWcEGrydgw9v2R5/NOv71qrp7t2T+uHnmuWV45epb91we1f+PDotXquWeZq+1BKKU/eML6dxVJj54RQwgmhhBNCCSeEEk4I1Tyt3b9puEJ7KX0PvmvV0TfNVFm+lHp1+Z7H8qXUq8uvfOjx6pypq6PfUqku31uTqFZdvlZZvpS+x/K1yvIcYeeEUMIJoYQTQgknhBJOCCWcEKp5ldLz4LvniqXnsfxi82p6Hsv3XLPM2boguSZRbwNf7JwQSzghlH
BCKOGEUMIJobrLlMzVhHYpNvCdqzp6b9mTKU9yp64sv5zYOSGUcEIo4YRQwgmhhBNCCSeEmrzie89j+frlS1tPA9+e6vJ91ywvjl6ntzr6vis2jl6r55plrsrypZTyzq03jl5rqbFzQijhhFDCCaGEE0IJJ4QSTgi14vDheuuCh54+pz5I01VXjW80y//9duvWwe+XLrxl0nXe+9cD1bFfvum0Wda59o33rRj6bueEUMIJoYQTQgknhBJOCNU8rb1k5eVOa4ly8PK3V8fW3PWHwe97fr6hOmfhsoeP+m86Wvf+9y6ntXAiEU4IJZwQSjghlHBCKOGEUJPXEIJjqXZdUkr9mmXhsvqcHnNdzdg5IZRwQijhhFDCCaGEE0I5reWE0j4pHT6V7Xks3zLXY3k7J4QSTgglnBBKOCGUcEIo4YRQKr4fIyq+H50P3nrP4Pcpq7CXUspv9vypOjZldXkV32EJEU4IJZwQSjghlHBCKOGEUNoxQEXt1yytX7LUfjXT+iWLdgxwghFOCCWcEEo4IZRwQig1hFgyek5KW2qnsq2aRFNWl7dzQijhhFDCCaGEE0IJJ4QSTgjlKoUlY+o2CfWrmfENfHvaPtg5IZRwQijhhFDCCaGEE0Kp+H6MqPh+dOaq+N6qxD7lWq3K8ivPflSZEjiRCCeEEk4IJZwQSjghlHBCqOZVCnD82DkhlHBCKOGEUMIJoYQTQgknhPof2soBg+pZzXoAAAAASUVORK5CYII=\n", 100 | "text/plain": [ 101 | "
" 102 | ] 103 | }, 104 | "metadata": { 105 | "needs_background": "light" 106 | }, 107 | "output_type": "display_data" 108 | } 109 | ], 110 | "source": [ 111 | "vis(rep,rep)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "The kronecker product of the individual solutions for different groups can be seen in these basis matrices, in the top left hand corner we have a circulant matrix of deep set solutions (identity + background)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 3, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEpklEQVR4nO3cPWuedRiH4TttoNAvoHPFRYJkKmSU6l4EwZZu7dSl2MmX2ZcpkiUfIQiCe1F06BBIEVIpLmKHDsWv4NA+foHk6XLVng8cx/jccJHl5A8Zflur1WoBei686T8AOJs4IUqcECVOiBInRG2v+/jlHx+P/Cv3+P7ViTNJz+68GLlz5ebpyJ2aB89PR+7sHNwduVP057efbZ31u5cTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LWLiFMLRjs7Z+M3CkuKkwtGDw92h25U1tUmFoweHLvcOTOsmzOqoKXE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROi1i4hTLGo8GoWFdabXC+YWlV43YsKXk6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiNparVbnfvzgw+/O/7jBphYVlmVZHu1eHLtVMrWocOnx5ZE7RVOLChfe/mvrzN9HrgPjxAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUWtnSl7+8+7ITMnOwd2JM0nXbzwcuXN8/+rInZpnd16M3Lly83TkTtHPL380UwKbRJwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IWrtEsJ7X3w/soTw5N7hxJnkosJbj/4dubO3fzJyp7aosP3r7yN3nh7tjtxZlt6qgiUE2DDihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR/8sSwpTiosLUEsKU2qLC1BLCpKlVhalFBUsIsGHECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROiNmoJYcrUosKyLMu1W7fHbpVMLSo82r04cqdoalHh70+/soQAm0ScECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFq7RLCOz98PbKEcOXm6cSZpAfPT0fu7BzcHblTc/3Gw5E7x/evjtwp+u2Xzy0hwCYRJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6I2l73cWrB4OnR7sid4qLC1ILBk3uHI3dqiwpTCwZ7+ycjd5Zlc1YVvJwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiFo7UzLF3MmrmTtZb3JaZGry5HXPnXg5IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFqa7VanfvxowufnP9xg00tKizLslx6fHnsVsnUosK1W7dH7hRNLSp88/5PW2f97uWEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihKi1SwjAm+PlhChxQpQ4IUqcECVOiBInRP0H3NO3PQoCCbMAAAAASUVORK5CYII=\n", 129 | "text/plain": [ 130 | "
" 131 | ] 132 | }, 133 | "metadata": { 134 | "needs_background": "light" 135 | }, 136 | "output_type": "display_data" 137 | } 138 | ], 139 | "source": [ 140 | "vis(V(Z(3))*V(S(4)),V(Z(3))*V(S(4)))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "And in the bottom left we have maps $V_{Z_3}\\otimes V_{S_4} \\rightarrow V_{SO(2)}^{\\otimes 2}$ of which the solutions are the product of $V_{Z_3}\\otimes V_{S_4} \\rightarrow \\mathbb{R}$ (the vector $\\mathbf{1}$) and $\\mathbb{R} \\rightarrow V_{SO(2)}^{\\otimes 2}$ (the flattened $I_{3\\times 3}$)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 5, 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAATAAAADnCAYAAACZtwrQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAADZUlEQVR4nO3dsY3CQBRF0bVFEYicnC4omlJogCo82wDjYIXMXumc9CU/uhrJgZcxxg9A0frtAwD+SsCALAEDsgQMyBIwIOu0N26vq0+UwFet5+cy3Y48BOCTBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBr97+Q98vtoDMA3nts880LDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7KWMcZ03F7X+QhwgPX8XKbbkYcAfJKAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZJ32xvvldtAZAO89tvnmBQZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZyxhjOm6v63wEOMB6fi7T7chDAD5JwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7J2/wsJ8J95gQFZAgZkCRiQJWBAloABWQIGZP0C0FAeWUQoKl0AAAAASUVORK5CYII=\n", 158 | "text/plain": [ 159 | "
" 160 | ] 161 | }, 162 | "metadata": { 163 | "needs_background": "light" 164 | }, 165 | "output_type": "display_data" 166 | } 167 | ], 168 | "source": [ 169 | "vis(V(Z(3))*V(S(4)),T(2)(SO(3)),False)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "## Wreath Products (coming soon)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "These are all examples of tensor products of representations, which are in turn a kind of representation on the direct product group $G=G_1\\times G_2$. There are other ways of combining groups however, and for many hierarchical structures there is a larger group of symmetries $G_1\\wr G_2$ that contains $G_1 \\times G_2$ but also other elements. These so called wreath products in the very nice paper [Equivariant Maps for Hierarchical Structures](https://arxiv.org/pdf/2006.03627.pdf). Support for the wreath product of representations is not yet implemented but if this is something that would be useful to you, send me an email." 184 | ] 185 | } 186 | ], 187 | "metadata": { 188 | "kernelspec": { 189 | "display_name": "Python 3", 190 | "language": "python", 191 | "name": "python3" 192 | }, 193 | "language_info": { 194 | "codemirror_mode": { 195 | "name": "ipython", 196 | "version": 3 197 | }, 198 | "file_extension": ".py", 199 | "mimetype": "text/x-python", 200 | "name": "python", 201 | "nbconvert_exporter": "python", 202 | "pygments_lexer": "ipython3", 203 | "version": "3.8.5" 204 | } 205 | }, 206 | "nbformat": 4, 207 | "nbformat_minor": 2 208 | } 209 | -------------------------------------------------------------------------------- /docs/notebooks/6multilinear_maps.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Multilinear Maps" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Our codebase extends trivially to multilinear maps, since these maps are in fact just linear maps in disguise.\n", 15 | "\n", 16 | "If we have a sequence of representations $R_1$, $R_2$, $R_3$ for example, we can write the (bi)linear maps $R_1\\rightarrow R_2\\rightarrow R_3$. This way of thinking about maps of multiple variables borrowed from programming languages and curried functions is very powerful." 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "We can think of such an object $R_1\\rightarrow R_2\\rightarrow R_3$ either as $R_1 \\rightarrow (R_2\\rightarrow R_3)$: a linear map from $R_1$ to linear maps from $R_2$ to $R_3$ or as\n", 24 | "$(R_1\\times R_2) \\rightarrow R_3$: a bilinear map from $R_1$ and $R_2$ to $R_3$. Since linear maps from one representation to another are just another representation in our type system, you can use this way of thinking to find the equivariant solutions to arbitrary multilinear maps." 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "For example, we can get the bilinear $SO(4)$ equivariant maps $(R_1\\times R_2) \\rightarrow R_3$ with the code below." 
32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 1, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "(2940, 27)\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "from emlp.groups import SO,rel_err\n", 49 | "from emlp.reps import V\n", 50 | "\n", 51 | "G = SO(4)\n", 52 | "W = V(G)\n", 53 | "R1 = 3*W+W**2 # some example representations\n", 54 | "R2 = W.T+W**0\n", 55 | "R3 = W**0 +W**2 +W\n", 56 | "\n", 57 | "Q = (R1>>(R2>>R3)).equivariant_basis()\n", 58 | "print(Q.shape)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "And we can verify that these multilinear solutions are indeed equivariant" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 2, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "DeviceArray(1.2422272e-07, dtype=float32)" 77 | ] 78 | }, 79 | "execution_count": 2, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "import numpy as np\n", 86 | "\n", 87 | "example_map = (Q@np.random.randn(Q.shape[-1]))\n", 88 | "example_map = example_map.reshape(R3.size(),R2.size(),R1.size())\n", 89 | "\n", 90 | "x1 = np.random.randn(R1.size())\n", 91 | "x2 = np.random.randn(R2.size())\n", 92 | "g = G.sample()\n", 93 | "\n", 94 | "out1 = np.einsum(\"ijk,j,k\",example_map,R2.rho(g)@x2,R1.rho(g)@x1)\n", 95 | "out2 = R3.rho(g)@np.einsum(\"ijk,j,k\",example_map,x2,x1)\n", 96 | "rel_err(out1,out2)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "Note that the output mapping is of shape $(\\mathrm{dim}(R_3),\\mathrm{dim}(R_2),\\mathrm{dim}(R_1))$\n", 104 | "with the inputs to the right, as you would expect with a matrix. \n", 105 | "\n", 106 | "Note the parentheses in the expression `(R1>>(R2>>R3))`, since the python `>>` associates to the left.\n", 107 | "The notation $R_1\\rightarrow R_2 \\rightarrow R_3$ or `(R1>>(R2>>R3))` can be a bit confusing since the inputs are on the right. It can be easier in this context to instead reverse the arrows and express the same object as $R_3\\leftarrow R_2\\leftarrow R_1$ or `R3<<R2<<R1`, using `<<` in place of `>>` wherever you like, and it is usually more intuitive." 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 3, 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "text/plain": [ 118 | "True" 119 | ] 120 | }, 121 | "execution_count": 3, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | } 125 | ], 126 | "source": [ 127 | "R3<<R2<<R1 == (R1>>(R2>>R3))" 128 | ] 129 | } 130 | ], 131 | "metadata": { 132 | "kernelspec": { 133 | "display_name": "Python 3", 134 | "language": "python", 135 | "name": "python3" 136 | }, 137 | "language_info": { 138 | "codemirror_mode": { 139 | "name": "ipython", 140 | "version": 3 141 | }, 142 | "file_extension": ".py", 143 | "mimetype": "text/x-python", 144 | "name": "python", 145 | "nbconvert_exporter": "python", 146 | "pygments_lexer": "ipython3", 147 | "version": "3.8.5" 148 | } 149 | }, 150 | "nbformat": 4, 151 | "nbformat_minor": 4 152 | } 153 | -------------------------------------------------------------------------------- /docs/notebooks/_colab_preamble.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "First install the repo and requirements." 
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git" 17 | ] 18 | } 19 | ], 20 | "metadata": { 21 | "kernelspec": { 22 | "display_name": "Python 3", 23 | "language": "python", 24 | "name": "python3" 25 | }, 26 | "language_info": { 27 | "codemirror_mode": { 28 | "name": "ipython", 29 | "version": 3 30 | }, 31 | "file_extension": ".py", 32 | "mimetype": "text/x-python", 33 | "name": "python", 34 | "nbconvert_exporter": "python", 35 | "pygments_lexer": "ipython3", 36 | "version": "3.8.5" 37 | }, 38 | "accelerator": "GPU" 39 | }, 40 | "nbformat": 4, 41 | "nbformat_minor": 4 42 | } 43 | -------------------------------------------------------------------------------- /docs/notebooks/colabs/5mixed_tensors.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "First install the repo and requirements." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "# Combining Representations from Different Groups (experimental)\n", 24 | "## Direct Product groups" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "It is possible to combine representations from different groups. These kinds of representations are relevant when there are multiple structures in the data. For example, a point cloud is a set of vectors which transforms both under permutations and rotations of the vectors. We can formalize this as $V_{S_n}\\otimes V_{SO(3)}$ or in tensor notation $T_1^{S_n}\\otimes T_1^{SO(3)}$. While this object could be expressed as the representation of the product group $V_{S_n \\times SO(3)}$, other objects like $T_k^{S_n}\\otimes T_j^{SO(3)}$ can not be so easily.\n", 32 | "\n", 33 | "Nevertheless, we can calculate the symmetric bases for these objects. 
For example, maps from vector edge features $T_1^{SO(3)}\\otimes T_2^{S_n}$ to matrix node features $T_2^{SO(3)}\\otimes T_1^{S_n}$ can be computed:" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "V_SO(3)⊗V²_S(4) --> V²_SO(3)⊗V_S(4)\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "from emlp.groups import *\n", 51 | "from emlp.reps import T,vis,V,Scalar\n", 52 | "\n", 53 | "repin,repout = T(1)(SO(3))*T(2)(S(4)),T(2)(SO(3))*T(1)(S(4))\n", 54 | "print(repin,\"-->\", repout)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 7, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAATAAAADnCAYAAACZtwrQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAH/UlEQVR4nO3dMYhdRRsG4JPEIoKtIBjcbCxFLJU0AUurkCZiF7AIIqwhhUgKEQuxCMs2VoJWkjRilVKwESyDWGl2jdhZplkCm7UQxGLmmLnMuee+Z5+nnOueuZvL//4D7879Th0fHw8AiU7P/QYAViXAgFgCDIglwIBYAgyI9czYi9t7t1WUwKwOdm6eqr3mBAbEEmBALAEGxBJgQCwBBsQSYEAsAQbEEmBALAEGxBJgQCwBBsQSYEAsAQbEEmBALAEGxBJgQCwBBsQSYEAsAQbEEmBALAEGxBJgQCwBBsQanQtZ1WtaZG3aW+351elwncy17yp7+wzWsu+z5x8V11+69bjLtvtvP199bevjH4vrf350sWmPo7Pl9TOH5fVzn7Xt2/r8mtq+w079Z5zAgFgCDIglwIBYAgyIJcCAWOMt5NSN0NTN2zraw6n5DGb14pVfiuv7n5Qbudbm7cKdv6qv/fHtK8X1rVv1nyl5eLncdNbaw1a137lXOznGCQyIJcCAWAIMiCXAgFgCDIi12l3IXqa+h9erMZvzjuTUfAYr6dW81RrCYRiGc1fKdwNbG9Ct79pay9b2s/Y7tP4brcIJDIglwIBYAgyIJcCAWAIMiDVvCzm1Xo3Z1N+Kuq495pDyGXRqM9dxL7B1j7Gms6S1/Ry7z9mi1n6OcQIDYgkwIJYAA2IJMCCWAANinTo+rtc423u3yy/ONZMw5fkd93juYfn/Yz59/+suz9/54Z3i+sFbXxbXX/7+WtPznxyeKa6fPntUXH/w5ldN+7Y+v6a6793rTc+hv4Odm9X/pTmBAbEEGBBLgAGxBBgQS4ABsfrOhWxt3pb6/I573P/wi+L69r13i+ut7dvepW+K6699/l5xfbex/bzx09Xieq09bFX7fXu1k2w2JzAglgADYgkwIJYAA2IJMCCWAANirfaV0kv984eel78n3qPXnw/U/szhQac/39h9/W5xvab1zzdq77/134dMTmBALAEGxBJgQCwBBsQSYECsvoNt09vDTWwnG019ubn1+bWWsKa1/axdRm9Vaz+H80+6PJ9pOIEBsQQYEEuAAbEEGBBLgAGxsgbbdmrqqubad2zvmnW8J4bnfp9nsPAwGC7873Ne+NVgW2B5BBgQS4ABsQQYEEuAAbH6DrZtNfXdwyU0dXP9DnM2shtkrsHCw2C48FPt3e1JAGsmwIBYAgyIJcCAWAIMiNX3G1lbTX0XsldreRIbOe3nqKnncg6D2ZxPwwkMiCXAgFgCDIglwIBYAgyINW8LObVerWXPuZA1G9ayzWbT2s9G67gXeNJmc/68W/8ZJzAglgADYgkwIJYAA2IJMCCWuZCbsO8qe/sMlrEv/+tg56a5kMDyCDAglgADYgkwIJYAA2KZC7npfAZQ5QQGxBJgQCwBBsQSYEAsAQbEMhfyaZ6/5HtyPgOCOYEBsQQYEEuAAbEEGBBLgAGxzIX8r02cC9lzjzmkfAbazEhOYEAsAQbEEmBALAEGxBJgQKzRFvLZ84+K6y/detxl8/23ny+ub338Y3H9z48uNj3/6Gx5/cxhef3cZ2371p4/tkdNbe/fdt9oe1Avc7VyrfvO9T7X0X6etBmZK/ybOoEBsQQYEEuAAbEEGBBLgAGxRlvIF6/8Ulzf/6TcyrU2bxfu/FVc/+PbV4rrW7fK/33Nw8vllnOsPWwx9vu2NqBsqLnmco7tYTbnv5zAgFgCDIglwIBYAgyIJcCAWCt9I2utSWtt3mot4bkr5XuBre3n1ndtrWVr+1l7/2PvqVcDykKs4y7kgmdzOoEBsQQYEEuAAbEEGBBLgAGxus6F7NVO9nr+WEtY0tp+1u5yrqLWgA4Pu23BSbXg2ZxOYEAsAQbEEmBALAEGxBJgQKxTx8f1CmF773b5xalbh/Tnd9zDbM7xfXs13Bs3l3MYNm8+40z7Hnxws7qzExgQS4ABsQQYEEuAAbEEGBBr/C7k1Heilvr8jnuYzTlu6vu3s34L6YLnOfbiBAbEEmBALAEGxBJgQCwBBsRa7RtZl9oe9rw7OfEeZnP+o/b+FzGXc64GNKj9dAIDYgkwIJYAA2IJMCCWAANidZ0LGd8ebmI72chsztVU53L+3uXxWYLaTycwIJYAA2IJMCCWAANiCTAglgADYo0Otn31xm7xxU/f/7rL5js/vFNcP3jry+L6y99fa3r+k8MzxfXTZ4+K6w/e/Kpp39rzx/aoqe5993rTc7rZsOGmi92X/3WwY7AtsEACDIglwIBYAgyIJcCAWKOXue9/+EVxffveu8X11uZt79I3xfXXPn+vuL7b2H7e+OlqcX2sPWwx9vu2NqBVQRdrYd2cwIBYAgyIJcCAWAIMiCXAgFgrfaV0rUlrbd5qLeGDTu3n7ut3i+s1re1n7f2PvadeDejktJ8EcAIDYgkwIJYAA2IJMCCWAANidR1s26ud7PX8sZawpLX9rN3lXEWtAR22nnTbI4L2kwZOYEAsAQbEEmBALAEGxBJgQKzRuZDbe7fLL9Z/pHH3yvpJnA1oLiEUmQsJLJIAA2IJMCCWAANiCTAg1vhdyKmbsdbnt7afGjxYNCcwIJYAA2IJMCCWAANiCTAgVtdvZG02
9V3IXq2le4qwkZzAgFgCDIglwIBYAgyIJcCAWPO2kFPr1Vr2ups5RqMJzZzAgFgCDIglwIBYAgyIJcCAWKNzIQE2mRMYEEuAAbEEGBBLgAGxBBgQS4ABsf4G+uFQ5DwSeDQAAAAASUVORK5CYII=\n", 65 | "text/plain": [ 66 | "
" 67 | ] 68 | }, 69 | "metadata": { 70 | "needs_background": "light" 71 | }, 72 | "output_type": "display_data" 73 | } 74 | ], 75 | "source": [ 76 | "vis(repin,repout,cluster=False)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "Or perhaps you would like equivariant maps from two sequence of sets and a matrix (under 3D rotations) to itself. You can go wild." 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 2, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "Rep: V²+2V_S(4)⊗V_Z(3)\n", 96 | "Linear maps: V⁴+4V_S(4)⊗V_Z(3)⊗V²_SO(3)+4V²_S(4)⊗V²_Z(3)\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "rep = 2*T(1)(Z(3))*T(1)(S(4))+T(2)(SO(3))\n", 102 | "print(f\"Rep: {rep}\")\n", 103 | "print(f\"Linear maps: {rep>>rep}\")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 7, 109 | "metadata": { 110 | "scrolled": true 111 | }, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAI+UlEQVR4nO3dX8jfVR0H8LMhudrQLP/xYG4p1Voxc2GR0KVYEYWVrfBKKCo0NU0iqLCsJFdGGUQs6EqSjCKoXOpFNy2a2DJLZur8Uww2ZyPZTENaF7vo5nvOw/fsu+/ee57X6/J7OJ4H5psD55zf57Pi8OHDBciz8nj/AcAw4YRQwgmhhBNCCSeEOqk1+KHtn6we5e58+jWT/RGbNzxQHbv/2bWTrVNKKU/sGP67X/ed3ZOuc+Pv7q2Ofe6mT0y61qGFFYPfz7r/xUnXedn+Q9WxXZ86ZdK1NmzZN/j9pd1PTrrOSeetq45NuVZrnbsf2zL4D2jnhFDCCaGEE0IJJ4QSTgglnBCqeZXSui658Ny/j55T07ouuejVT3XNq6ldmTx67Xmj57S0rku+cdMPRs9pqV2Z7L3o5NFzWlrXJeu//9zoObTZOSGUcEIo4YRQwgmhhBNCNU9rW2qnsrVT3Naclp6T3ClPcUuZ7yS3dorbmtPSOpGd6yS3dorbmsMRdk4IJZwQSjghlHBCKOGEUMIJoVa0Kr6ff+fXJi0HX7tm+deLL59ymeZj+Z3vWph0rdo1y2kPT7pM85rl5o9fOelatWuWQ+temnSd1jXLyoMvDH5XQwg47oQTQgknhBJOCCWcEKr58H3FU/VT1J4H34cqy536039X5/Q8lt+99/Tq2OZtw9XleyvLv7YMn0Cv/PGa6pyeB99X3vOx6tiPtv5w8Htv2ZPVe2qH9PX/XXoey//n9NXVsVWV09rlxM4JoYQTQgknhBJOCCWcEEo4IVTzKqWnrk7PFUtPZfnF5tXUrkymriw/Z3X05JpEUzfwXU7snBBKOCGUcEIo4YRQwgmhhBNCdbdjmKsJ7VJs4DtX64KeBr6LzauZq4HvcmLnhFDCCaGEE0IJJ4QSTgjVrPj+7oWrJ634XjvJPby2XkOoR+ux/FzV5e/4/TsmXad1kvvPC06bdK3aSe5cleVLKWXtXXsGv6v4Dhx3wgmhhBNCCSeEEk4IJZwQqvnw/cJtw8fZpfQ9+K61LnhiR/2h+pRtH0qpt37oeSxfSr31w+on52td8KoHDwx+761JVGv9UGv7UErfY/l62wdKsXNCLOGEUMIJoYQTQgknhGqe1vaU7ug5xe2pLL/YvJraqezUleXnrI6eXPakt4Evdk6IJZwQSjghlHBCKOGEUMIJoborvs/VhHYpNvCdqzp6TwPfxebVzNXAdzmxc0Io4YRQwgmhhBNCCSeEEk4I1X2VUtPzS5adZaFrrZ4GvqWMb/3Qc83ybFk3ep2eX7Ic8dLotXquWXraPvQ28N1yyftGr7XU2DkhlHBCKOGEUMIJoYQTQjWb52687tvVwbN2HJzsj3hs8yuqY2e8Yf9k65RSyqYz/jH4/YHvXjjpOr/6+jerY++//vpJ1zrl138Z/L7re+snXeeWi39WHdvyrY9MutbzZw/2ky3nfmX7pOs8/aWLq2NTrtVa55GbPqN5LpxIhBNCCSeEEk4IJZwQSjghVPPhe+u6ZO/b1oyeU9O6LnnmkeHmtIvNq6ldmbz1mp2j57S0rkt+cdtto+e01K5M1l+9a/ScltZ1yY033Dl6Dm12TgglnBBKOCGUcEIo4YRQ3WVKaqeytVPc1pyWnpPcKU9xS5nvJLd2itua09I6kZ3rJLd2ituawxF2TgglnBBKOCGUcEIo4YRQwgmhJq/43vNYflV5vmut2pVJ67F8faSu55rlwVsvGL1Oz2P5Ukq56L5rRq/Vc82y74qNo9fpeSxfSilfvuOjo9daauycEEo4IZRwQijhhFDCCaGEE0I1r1LOv/1v1bGeX2OcfGC4u8ML286szun5JcupjauZWuuH3rYPf3zmnMHv1331J9U5Pb/GuPTmz1bH7v/icOuH3ppEz73nzYPfz7zjz9U5Pb9k+fz2D1THVo3+ry09dk4IJZwQSjghlHBCKOGEUM3T2p4H3z2nuD2P5RebV9PzWL7nJHfO6ujJNYmmbuC7nNg5IZRwQijhhFDCCaGEE0IJJ4TqriE0VxPapdjAd67WBb01iaa8Zpm67cNyYueEUMIJoYQTQgknhBJOCCWcEGrydgw9v2R5/NOv71qrp7t2T+uHnmuWV45epb91we1f+PDotXquWeZq+1BKKU/eML6dxVJj54RQwgmhhBNCCSeEEk4I1Tyt3b9puEJ7KX0PvmvV0TfNVFm+lHp1+Z7H8qXUq8uvfOjx6pypq6PfUqku31uTqFZdvlZZvpS+x/K1yvIcYeeEUMIJoYQTQgknhBJOCCWcEKp5ldLz4LvniqXnsfxi82p6Hsv3XLPM2boguS
ZRbwNf7JwQSzghlHBCKOGEUMIJobrLlMzVhHYpNvCdqzp6b9mTKU9yp64sv5zYOSGUcEIo4YRQwgmhhBNCCSeEmrzie89j+frlS1tPA9+e6vJ91ywvjl6ntzr6vis2jl6r55plrsrypZTyzq03jl5rqbFzQijhhFDCCaGEE0IJJ4QSTgi14vDheuuCh54+pz5I01VXjW80y//9duvWwe+XLrxl0nXe+9cD1bFfvum0Wda59o33rRj6bueEUMIJoYQTQgknhBJOCNU8rb1k5eVOa4ly8PK3V8fW3PWHwe97fr6hOmfhsoeP+m86Wvf+9y6ntXAiEU4IJZwQSjghlHBCKOGEUJPXEIJjqXZdUkr9mmXhsvqcHnNdzdg5IZRwQijhhFDCCaGEE0I5reWE0j4pHT6V7Xks3zLXY3k7J4QSTgglnBBKOCGUcEIo4YRQKr4fIyq+H50P3nrP4Pcpq7CXUspv9vypOjZldXkV32EJEU4IJZwQSjghlHBCKOGEUNoxQEXt1yytX7LUfjXT+iWLdgxwghFOCCWcEEo4IZRwQig1hFgyek5KW2qnsq2aRFNWl7dzQijhhFDCCaGEE0IJJ4QSTgjlKoUlY+o2CfWrmfENfHvaPtg5IZRwQijhhFDCCaGEE0Kp+H6MqPh+dOaq+N6qxD7lWq3K8ivPflSZEjiRCCeEEk4IJZwQSjghlHBCqOZVCnD82DkhlHBCKOGEUMIJoYQTQgknhPof2soBg+pZzXoAAAAASUVORK5CYII=\n", 116 | "text/plain": [ 117 | "
" 118 | ] 119 | }, 120 | "metadata": { 121 | "needs_background": "light" 122 | }, 123 | "output_type": "display_data" 124 | } 125 | ], 126 | "source": [ 127 | "vis(rep,rep)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "The kronecker product of the individual solutions for different groups can be seen in these basis matrices, in the top left hand corner we have a circulant matrix of deep set solutions (identity + background)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 3, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEpklEQVR4nO3cPWuedRiH4TttoNAvoHPFRYJkKmSU6l4EwZZu7dSl2MmX2ZcpkiUfIQiCe1F06BBIEVIpLmKHDsWv4NA+foHk6XLVng8cx/jccJHl5A8Zflur1WoBei686T8AOJs4IUqcECVOiBInRG2v+/jlHx+P/Cv3+P7ViTNJz+68GLlz5ebpyJ2aB89PR+7sHNwduVP057efbZ31u5cTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LWLiFMLRjs7Z+M3CkuKkwtGDw92h25U1tUmFoweHLvcOTOsmzOqoKXE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROi1i4hTLGo8GoWFdabXC+YWlV43YsKXk6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiNparVbnfvzgw+/O/7jBphYVlmVZHu1eHLtVMrWocOnx5ZE7RVOLChfe/mvrzN9HrgPjxAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUWtnSl7+8+7ITMnOwd2JM0nXbzwcuXN8/+rInZpnd16M3Lly83TkTtHPL380UwKbRJwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IWrtEsJ7X3w/soTw5N7hxJnkosJbj/4dubO3fzJyp7aosP3r7yN3nh7tjtxZlt6qgiUE2DDihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR/8sSwpTiosLUEsKU2qLC1BLCpKlVhalFBUsIsGHECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROiNmoJYcrUosKyLMu1W7fHbpVMLSo82r04cqdoalHh70+/soQAm0ScECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFq7RLCOz98PbKEcOXm6cSZpAfPT0fu7BzcHblTc/3Gw5E7x/evjtwp+u2Xzy0hwCYRJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6I2l73cWrB4OnR7sid4qLC1ILBk3uHI3dqiwpTCwZ7+ycjd5Zlc1YVvJwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiFo7UzLF3MmrmTtZb3JaZGry5HXPnXg5IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFqa7VanfvxowufnP9xg00tKizLslx6fHnsVsnUosK1W7dH7hRNLSp88/5PW2f97uWEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihKi1SwjAm+PlhChxQpQ4IUqcECVOiBInRP0H3NO3PQoCCbMAAAAASUVORK5CYII=\n", 145 | "text/plain": [ 146 | "
" 147 | ] 148 | }, 149 | "metadata": { 150 | "needs_background": "light" 151 | }, 152 | "output_type": "display_data" 153 | } 154 | ], 155 | "source": [ 156 | "vis(V(Z(3))*V(S(4)),V(Z(3))*V(S(4)))" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "And in the bottom left we have maps $V_{Z_3}\\otimes V_{S_4} \\rightarrow V_{SO(2)}^{\\otimes 2}$ of which the solutions are the product of $V_{Z_3}\\otimes V_{S_4} \\rightarrow \\mathbb{R}$ (the vector $\\mathbf{1}$) and $\\mathbb{R} \\rightarrow V_{SO(2)}^{\\otimes 2}$ (the flattened $I_{3\\times 3}$)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 5, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "data": { 173 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAATAAAADnCAYAAACZtwrQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAADZUlEQVR4nO3dsY3CQBRF0bVFEYicnC4omlJogCo82wDjYIXMXumc9CU/uhrJgZcxxg9A0frtAwD+SsCALAEDsgQMyBIwIOu0N26vq0+UwFet5+cy3Y48BOCTBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBr97+Q98vtoDMA3nts880LDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7KWMcZ03F7X+QhwgPX8XKbbkYcAfJKAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZJ32xvvldtAZAO89tvnmBQZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZyxhjOm6v63wEOMB6fi7T7chDAD5JwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7J2/wsJ8J95gQFZAgZkCRiQJWBAloABWQIGZP0C0FAeWUQoKl0AAAAASUVORK5CYII=\n", 174 | "text/plain": [ 175 | "
" 176 | ] 177 | }, 178 | "metadata": { 179 | "needs_background": "light" 180 | }, 181 | "output_type": "display_data" 182 | } 183 | ], 184 | "source": [ 185 | "vis(V(Z(3))*V(S(4)),T(2)(SO(3)),False)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "## Wreath Products (coming soon)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "These are all examples of tensor products of representations, which are in turn a kind of representation on the direct product group $G=G_1\\times G_2$. There are other ways of combining groups however, and for many hierarchical structures there is a larger group of symmetries $G_1\\wr G_2$ that contains $G_1 \\times G_2$ but also other elements. These so called wreath products in the very nice paper [Equivariant Maps for Hierarchical Structures](https://arxiv.org/pdf/2006.03627.pdf). Support for the wreath product of representations is not yet implemented but if this is something that would be useful to you, send me an email." 200 | ] 201 | } 202 | ], 203 | "metadata": { 204 | "accelerator": "GPU", 205 | "kernelspec": { 206 | "display_name": "Python 3", 207 | "language": "python", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.8.5" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 4 225 | } 226 | -------------------------------------------------------------------------------- /docs/notebooks/colabs/6multilinear_maps.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "First install the repo and requirements." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "# Multilinear Maps" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "Our codebase extends trivially to multilinear maps, since these maps are in fact just linear maps in disguise.\n", 31 | "\n", 32 | "If we have a sequence of representations $R_1$, $R_2$, $R_3$ for example, we can write the (bi)linear maps $R_1\\rightarrow R_2\\rightarrow R_3$. This way of thinking about maps of multiple variables borrowed from programming languages and curried functions is very powerful." 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "We can think of such an object $R_1\\rightarrow R_2\\rightarrow R_3$ either as $R_1 \\rightarrow (R_2\\rightarrow R_3)$: a linear map from $R_1$ to linear maps from $R_2$ to $R_3$ or as\n", 40 | "$(R_1\\times R_2) \\rightarrow R_3$: a bilinear map from $R_1$ and $R_2$ to $R_3$. Since linear maps from one representation to another are just another representation in our type system, you can use this way of thinking to find the equivariant solutions to arbitrary multilinear maps." 
41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "For example, we can get the bilinear $SO(4)$-equivariant maps $(R_1\\times R_2) \\rightarrow R_3$ with the code below." 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 1, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "(2940, 27)\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "from emlp.groups import SO,rel_err\n", 65 | "from emlp.reps import V\n", 66 | "\n", 67 | "G = SO(4)\n", 68 | "W = V(G)\n", 69 | "R1 = 3*W+W**2 # some example representations\n", 70 | "R2 = W.T+W**0\n", 71 | "R3 = W**0 +W**2 +W\n", 72 | "\n", 73 | "Q = (R1>>(R2>>R3)).equivariant_basis()\n", 74 | "print(Q.shape)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "And we can verify that these multilinear solutions are indeed equivariant" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 2, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "DeviceArray(1.2422272e-07, dtype=float32)" 93 | ] 94 | }, 95 | "execution_count": 2, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "import numpy as np\n", 102 | "\n", 103 | "example_map = (Q@np.random.randn(Q.shape[-1]))\n", 104 | "example_map = example_map.reshape(R3.size(),R2.size(),R1.size())\n", 105 | "\n", 106 | "x1 = np.random.randn(R1.size())\n", 107 | "x2 = np.random.randn(R2.size())\n", 108 | "g = G.sample()\n", 109 | "\n", 110 | "out1 = np.einsum(\"ijk,j,k\",example_map,R2.rho(g)@x2,R1.rho(g)@x1)\n", 111 | "out2 = R3.rho(g)@np.einsum(\"ijk,j,k\",example_map,x2,x1)\n", 112 | "rel_err(out1,out2)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "Note that the output mapping is of shape $(\\mathrm{dim}(R_3),\\mathrm{dim}(R_2),\\mathrm{dim}(R_1))$\n", 120 | "with the inputs to the right as you would expect with a matrix. \n", 121 | "\n", 122 | "Note the parentheses in the expression `(R1>>(R2>>R3))`: they are needed since the python `>>` associates to the left.\n", 123 | "The notation $R_1\\rightarrow R_2 \\rightarrow R_3$ or `(R1>>(R2>>R3))` can be a bit confusing since the inputs are on the right. It can be easier in this context to instead reverse the arrows and express the same object as $R_3\\leftarrow R_2\\leftarrow R_1$ or `R3<<R2<<R1`.\n", 124 | "\n", 125 | "You can use `<<` in place of `>>` wherever you like, and it is usually more intuitive."
126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 3, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "data": { 135 | "text/plain": [ 136 | "True" 137 | ] 138 | }, 139 | "execution_count": 3, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "R3<<R2<<R1 == (R1>>(R2>>R3))" 146 | ] 147 | } 148 | ], 149 | "metadata": { 150 | "accelerator": "GPU", 151 | "kernelspec": { 152 | "display_name": "Python 3", 153 | "language": "python", 154 | "name": "python3" 155 | }, 156 | "language_info": { 157 | "codemirror_mode": { 158 | "name": "ipython", 159 | "version": 3 160 | }, 161 | "file_extension": ".py", 162 | "mimetype": "text/x-python", 163 | "name": "python", 164 | "nbconvert_exporter": "python", 165 | "pygments_lexer": "ipython3", 166 | "version": "3.8.5" 167 | } 168 | }, 169 | "nbformat": 4, 170 | "nbformat_minor": 4 171 | } 172 | -------------------------------------------------------------------------------- /docs/notebooks/colabs/7pytorch_support.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "First install the repo and requirements." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "# Limited Pytorch Support" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "We strongly recommend that users of our library write native Jax code. However, we understand that due to existing code and/or constraints from an employer, it is sometimes unavoidable to use other frameworks like PyTorch. \n", 31 | "\n", 32 | "To service these requirements, we have added a way that PyTorch users can make use of the equivariant bases $Q\\in \\mathbb{R}^{n\\times r}$ and projection matrices $P = QQ^\\top$ that are computed by our solver. Since these objects are implicitly defined through `LinearOperators`, it is not as straightforward as simply calling `torch.from_numpy(Q)`. However, there is a way to use these operators within PyTorch code while preserving any gradients of the operation. We provide the function `emlp.reps.pytorch_support.torchify_fn` to do this." 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 1, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import torch\n", 42 | "import jax\n", 43 | "import jax.numpy as jnp\n", 44 | "from emlp.reps import V\n", 45 | "from emlp.groups import S\n", 46 | "\n", 47 | "W =V(S(4))\n", 48 | "rep = 3*W+W**2" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "Q = (rep>>rep).equivariant_basis()\n", 58 | "P = (rep>>rep).equivariant_projector()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "applyQ = lambda v: Q@v\n", 68 | "applyP = lambda v: P@v" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "The key is to wrap the desired operations as a function, and then we can apply `torchify_fn`. 
Now instead of taking jax objects as inputs and outputing jax objects, these functions take in PyTorch objects and output PyTorch objects." 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "from emlp.reps.pytorch_support import torchify_fn\n", 85 | "applyQ_torch = torchify_fn(applyQ)\n", 86 | "applyP_torch = torchify_fn(applyP)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 5, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "x_torch = torch.arange(Q.shape[-1]).float().cuda()\n", 96 | "x_torch.requires_grad=True\n", 97 | "x_jax = jnp.asarray(x_torch.cpu().data.numpy()) " 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 6, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "jax output: [0.48484263 0.07053992 0.07053989 0.07053995 1.6988853 ]\n", 110 | "torch output: tensor([0.4848, 0.0705, 0.0705, 0.0705, 1.6989], device='cuda:0',\n", 111 | " grad_fn=)\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "Qx1 = applyQ(x_jax)\n", 117 | "Qx2 = applyQ_torch(x_torch)\n", 118 | "print(\"jax output: \",Qx1[:5])\n", 119 | "print(\"torch output: \",Qx2[:5])" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "The outputs match, and note that the torch outputs will be on whichever is the default jax device. Similarly, the gradients of the two objects also match:" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 7, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "data": { 136 | "text/plain": [ 137 | "tensor([-2.8704, 2.7858, -2.8704, 2.7858, -2.8704], device='cuda:0')" 138 | ] 139 | }, 140 | "execution_count": 7, 141 | "metadata": {}, 142 | "output_type": "execute_result" 143 | } 144 | ], 145 | "source": [ 146 | "torch.autograd.grad(Qx2.sum(),x_torch)[0][:5]" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 8, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "data": { 156 | "text/plain": [ 157 | "DeviceArray([-2.8703732, 2.7858496, -2.8703732, 2.7858496, -2.8703732], dtype=float32)" 158 | ] 159 | }, 160 | "execution_count": 8, 161 | "metadata": {}, 162 | "output_type": "execute_result" 163 | } 164 | ], 165 | "source": [ 166 | "jax.grad(lambda x: (Q@x).sum())(x_jax)[:5]" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "So you can safely use these torchified functions within your model, and still compute the gradients correctly." 
174 | ] 175 | } 176 | ], 177 | "metadata": { 178 | "accelerator": "GPU", 179 | "kernelspec": { 180 | "display_name": "Python 3", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.8.5" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 4 199 | } 200 | -------------------------------------------------------------------------------- /docs/notebooks/colabs/flax_support.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "First install the repo and requirements." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "# Using EMLP with Flax" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "To use EMLP with [Flax](https://github.com/google/flax) is pretty similar to Objax or Haiku. Just make sure to import from the flax implementation `emlp.nn.flax`" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "from jax import random\n", 40 | "import numpy as np\n", 41 | "import emlp.nn.flax as nn # import from the flax implementation\n", 42 | "from emlp.reps import T,V\n", 43 | "from emlp.groups import SO\n", 44 | "\n", 45 | "repin= 4*V # Setup some example data representations\n", 46 | "repout = V\n", 47 | "G = SO(3)\n", 48 | "\n", 49 | "x = np.random.randn(5,repin(G).size()) # generate some random data" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "model = nn.EMLP(repin,repout,G)\n", 59 | "\n", 60 | "key = random.PRNGKey(0)\n", 61 | "params = model.init(random.PRNGKey(42), x)\n", 62 | "\n", 63 | "y = model.apply(params, x) # Forward pass with inputs x and parameters" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "And indeed, the parameters of the model are registered as expected." 
71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "['modules_0', 'modules_1', 'modules_2', 'modules_3']" 82 | ] 83 | }, 84 | "execution_count": 3, 85 | "metadata": {}, 86 | "output_type": "execute_result" 87 | } 88 | ], 89 | "source": [ 90 | "list(params['params'].keys())" 91 | ] 92 | } 93 | ], 94 | "metadata": { 95 | "accelerator": "GPU", 96 | "interpreter": { 97 | "hash": "ec74566b76234e57f2cd5bb0818dcd91369c1a3af290381c3b6efeb6aea6cdd5" 98 | }, 99 | "kernelspec": { 100 | "display_name": "Python 3", 101 | "language": "python", 102 | "name": "python3" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | "file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.8.5" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 4 119 | } 120 | -------------------------------------------------------------------------------- /docs/notebooks/colabs/haiku_support.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "First install the repo and requirements." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "# Using EMLP with Haiku" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "There are many neural network frameworks for jax, and they are often incompatible. 
Since most of the functionality of this package is written in pure jax, it can be used with flax, trax, linen, haiku, objax, or whatever your favorite jax NN framework.\n", 31 | "\n", 32 | "\n", 33 | "However, the equivariant neural network layers provided in the [Layers and Models](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html) are made for objax.\n", 34 | "If we try to use them with the popular [Haiku framework](https://dm-haiku.readthedocs.io/en/latest/), things will not work as expected.\n", 35 | "\n", 36 | "## Dont Do This:" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 1, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "import haiku as hk\n", 46 | "from jax import random\n", 47 | "import numpy as np\n", 48 | "import emlp.nn as nn\n", 49 | "from emlp.reps import T,V\n", 50 | "from emlp.groups import SO\n", 51 | "\n", 52 | "repin= 4*V # Setup some example data representations\n", 53 | "repout = V\n", 54 | "G = SO(3)\n", 55 | "\n", 56 | "x = np.random.randn(10,repin(G).size()) # generate some random data" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "model = nn.EMLP(repin,repout,G)\n", 66 | "net = hk.without_apply_rng(hk.transform(model))\n", 67 | "\n", 68 | "key = random.PRNGKey(0)\n", 69 | "params = net.init(random.PRNGKey(42), x)\n", 70 | "\n", 71 | "y = net.apply(params, x)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "Although the code executes, we see that Haiku does not recognize the model parameters and treats the network as if it is a stateless jax function." 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 3, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "text/plain": [ 89 | "FlatMapping({})" 90 | ] 91 | }, 92 | "execution_count": 3, 93 | "metadata": {}, 94 | "output_type": "execute_result" 95 | } 96 | ], 97 | "source": [ 98 | "params" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "It's not hard to build EMLP layers in Haiku, and for each of the nn layers in [Layers and Models](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html) we have implemented a Haiku version with the same arguments. These layers are accessible via `emlp.nn.haiku` rather than `emlp.nn`. To use EMLP models and equivariant layers with Haiku, instead of the above you should import from `emlp.nn.haiku`." 
106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "## Instead, Do This:" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 4, 118 | "metadata": { 119 | "scrolled": true 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "import emlp.nn.haiku as ehk\n", 124 | "\n", 125 | "model = ehk.EMLP(repin,repout,SO(3))\n", 126 | "net = hk.without_apply_rng(hk.transform(model))\n", 127 | "\n", 128 | "key = random.PRNGKey(0)\n", 129 | "params = net.init(random.PRNGKey(42), x)\n", 130 | "y = net.apply(params, x)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 5, 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "data": { 140 | "text/plain": [ 141 | "KeysOnlyKeysView(['sequential/hk_linear', 'sequential/hk_bi_linear', 'sequential/hk_linear_1', 'sequential/hk_bi_linear_1', 'sequential/hk_linear_2', 'sequential/hk_bi_linear_2', 'sequential/hk_linear_3'])" 142 | ] 143 | }, 144 | "execution_count": 5, 145 | "metadata": {}, 146 | "output_type": "execute_result" 147 | } 148 | ], 149 | "source": [ 150 | "params.keys()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "With this Haiku EMLP, paramaters are registered as expected.\n", 158 | "\n", 159 | "If your favorite deep learning framework is not one of objax, haiku, or pytorch, don't panic. It's possible to use EMLP with other jax frameworks without much trouble, similar to the objax and haiku implementations. If you need help with this, start a pull request and we can send over some pointers." 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "accelerator": "GPU", 165 | "kernelspec": { 166 | "display_name": "Python 3", 167 | "language": "python", 168 | "name": "python3" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.8.5" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 4 185 | } 186 | -------------------------------------------------------------------------------- /docs/notebooks/colabs/imgs/EMLP_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/notebooks/colabs/imgs/EMLP_fig.png -------------------------------------------------------------------------------- /docs/notebooks/colabs/imgs/imgs/EMLP_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/notebooks/colabs/imgs/imgs/EMLP_fig.png -------------------------------------------------------------------------------- /docs/notebooks/flax_support.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Using EMLP with Flax" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "To use EMLP with [Flax](https://github.com/google/flax) is pretty similar to Objax or Haiku. 
Just make sure to import from the flax implementation `emlp.nn.flax`" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from jax import random\n", 24 | "import numpy as np\n", 25 | "import emlp.nn.flax as nn # import from the flax implementation\n", 26 | "from emlp.reps import T,V\n", 27 | "from emlp.groups import SO\n", 28 | "\n", 29 | "repin= 4*V # Setup some example data representations\n", 30 | "repout = V\n", 31 | "G = SO(3)\n", 32 | "\n", 33 | "x = np.random.randn(5,repin(G).size()) # generate some random data" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "model = nn.EMLP(repin,repout,G)\n", 43 | "\n", 44 | "key = random.PRNGKey(0)\n", 45 | "params = model.init(random.PRNGKey(42), x)\n", 46 | "\n", 47 | "y = model.apply(params, x) # Forward pass with inputs x and parameters" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "And indeed, the parameters of the model are registered as expected." 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "output_type": "execute_result", 64 | "data": { 65 | "text/plain": [ 66 | "['modules_0', 'modules_1', 'modules_2', 'modules_3']" 67 | ] 68 | }, 69 | "metadata": {}, 70 | "execution_count": 3 71 | } 72 | ], 73 | "source": [ 74 | "list(params['params'].keys())" 75 | ] 76 | } 77 | ], 78 | "metadata": { 79 | "kernelspec": { 80 | "name": "python3", 81 | "display_name": "Python 3.8.5 64-bit ('freshenv': conda)" 82 | }, 83 | "language_info": { 84 | "codemirror_mode": { 85 | "name": "ipython", 86 | "version": 3 87 | }, 88 | "file_extension": ".py", 89 | "mimetype": "text/x-python", 90 | "name": "python", 91 | "nbconvert_exporter": "python", 92 | "pygments_lexer": "ipython3", 93 | "version": "3.8.5" 94 | }, 95 | "interpreter": { 96 | "hash": "ec74566b76234e57f2cd5bb0818dcd91369c1a3af290381c3b6efeb6aea6cdd5" 97 | } 98 | }, 99 | "nbformat": 4, 100 | "nbformat_minor": 4 101 | } -------------------------------------------------------------------------------- /docs/notebooks/haiku_support.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Using EMLP with Haiku" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "There are many neural network frameworks for jax, and they are often incompatible. 
Since most of the functionality of this package is written in pure jax, it can be used with flax, trax, linen, haiku, objax, or whatever your favorite jax NN framework is.\n", 15 | "\n", 16 | "\n", 17 | "However, the equivariant neural network layers provided in the [Layers and Models](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html) are made for objax.\n", 18 | "If we try to use them with the popular [Haiku framework](https://dm-haiku.readthedocs.io/en/latest/), things will not work as expected.\n", 19 | "\n", 20 | "## Don't Do This:" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import haiku as hk\n", 30 | "from jax import random\n", 31 | "import numpy as np\n", 32 | "import emlp.nn as nn\n", 33 | "from emlp.reps import T,V\n", 34 | "from emlp.groups import SO\n", 35 | "\n", 36 | "repin= 4*V # Setup some example data representations\n", 37 | "repout = V\n", 38 | "G = SO(3)\n", 39 | "\n", 40 | "x = np.random.randn(10,repin(G).size()) # generate some random data" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "model = nn.EMLP(repin,repout,G)\n", 50 | "net = hk.without_apply_rng(hk.transform(model))\n", 51 | "\n", 52 | "key = random.PRNGKey(0)\n", 53 | "params = net.init(random.PRNGKey(42), x)\n", 54 | "\n", 55 | "y = net.apply(params, x)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "Although the code executes, we see that Haiku does not recognize the model parameters and treats the network as if it were a stateless jax function." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "FlatMapping({})" 74 | ] 75 | }, 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "output_type": "execute_result" 79 | } 80 | ], 81 | "source": [ 82 | "params" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "It's not hard to build EMLP layers in Haiku, and for each of the nn layers in [Layers and Models](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html) we have implemented a Haiku version with the same arguments. These layers are accessible via `emlp.nn.haiku` rather than `emlp.nn`. To use EMLP models and equivariant layers with Haiku, instead of the above you should import from `emlp.nn.haiku`."
90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "## Instead, Do This:" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 4, 102 | "metadata": { 103 | "scrolled": true 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "import emlp.nn.haiku as ehk\n", 108 | "\n", 109 | "model = ehk.EMLP(repin,repout,SO(3))\n", 110 | "net = hk.without_apply_rng(hk.transform(model))\n", 111 | "\n", 112 | "key = random.PRNGKey(0)\n", 113 | "params = net.init(random.PRNGKey(42), x)\n", 114 | "y = net.apply(params, x)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 5, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/plain": [ 125 | "KeysOnlyKeysView(['sequential/hk_linear', 'sequential/hk_bi_linear', 'sequential/hk_linear_1', 'sequential/hk_bi_linear_1', 'sequential/hk_linear_2', 'sequential/hk_bi_linear_2', 'sequential/hk_linear_3'])" 126 | ] 127 | }, 128 | "execution_count": 5, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "params.keys()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "With this Haiku EMLP, parameters are registered as expected.\n", 142 | "\n", 143 | "If your favorite deep learning framework is not one of objax, haiku, or pytorch, don't panic. It's possible to use EMLP with other jax frameworks without much trouble, similar to the objax and haiku implementations. If you need help with this, start a pull request and we can send over some pointers." 144 | ] 145 | } 146 | ], 147 | "metadata": { 148 | "kernelspec": { 149 | "display_name": "Python 3", 150 | "language": "python", 151 | "name": "python3" 152 | }, 153 | "language_info": { 154 | "codemirror_mode": { 155 | "name": "ipython", 156 | "version": 3 157 | }, 158 | "file_extension": ".py", 159 | "mimetype": "text/x-python", 160 | "name": "python", 161 | "nbconvert_exporter": "python", 162 | "pygments_lexer": "ipython3", 163 | "version": "3.8.5" 164 | } 165 | }, 166 | "nbformat": 4, 167 | "nbformat_minor": 4 168 | } 169 | -------------------------------------------------------------------------------- /docs/notebooks/imgs/EMLP_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfinzi/equivariant-MLP/b80815bd4c1ca52b37f3dacc0874e70fdd7b413f/docs/notebooks/imgs/EMLP_fig.png -------------------------------------------------------------------------------- /docs/notebooks/merge_nbs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #set -x 3 | dir_path=$(dirname $(realpath $0)) 4 | PREAMBLE=$dir_path/_colab_preamble.ipynb 5 | NOTEBOOKS=$dir_path/[0-9]*.ipynb 6 | NOTEBOOKSALL=$dir_path/[^_]*.ipynb 7 | mkdir -p $dir_path/colabs 8 | cp -r $dir_path/imgs $dir_path/colabs/imgs 9 | for nb in $NOTEBOOKSALL 10 | do 11 | nbmerge $PREAMBLE $nb -o $dir_path/colabs/$(basename $nb) 12 | done 13 | nbmerge $PREAMBLE $NOTEBOOKS -o $dir_path/colabs/all.ipynb 14 | # nbmerge -i notebooks/[^_]*.ipynb -o notebooks/_colab_all.ipynb -------------------------------------------------------------------------------- /docs/package/emlp.groups.rst: -------------------------------------------------------------------------------- 1 | Groups 2 | ====== 3 | 4 | .. 
automodule:: emlp.groups 5 | :members: -------------------------------------------------------------------------------- /docs/package/emlp.nn.rst: -------------------------------------------------------------------------------- 1 | Layers and Models 2 | ================= 3 | 4 | .. automodule:: emlp.nn 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/package/emlp.reps.rst: -------------------------------------------------------------------------------- 1 | Representations 2 | =============== 3 | 4 | .. automodule:: emlp.reps 5 | :members: 6 | :exclude-members: Rep 7 | 8 | .. autoclass:: Rep 9 | :members: size,rho,drho,equivariant_basis,equivariant_projector,rho_dense,drho_dense 10 | 11 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # sphinx <4 required by myst-nb v0.12.0 (Feb 2021) 2 | # sphinx >=3 required by sphinx-autodoc-typehints v1.11.1 (Oct 2020) 3 | Jinja2 < 3.1 4 | sphinx >=3, <4 5 | sphinx_rtd_theme 6 | sphinx-autodoc-typehints 7 | myst-nb 8 | nbmerge 9 | nbsphinx == 0.8.6 10 | recommonmark 11 | # Packages used for CI tests. 12 | pytest 13 | pytest-xdist 14 | #jupytext 15 | # Packages used for notebook execution 16 | matplotlib 17 | #. # install emlp itself from current directory. 18 | -------------------------------------------------------------------------------- /docs/testing.md: -------------------------------------------------------------------------------- 1 | # Running the Tests 2 | 3 | We have a number of tests checking the equivariance of representations constructed in 4 | different ways (`.T`, `*`, `+`) for the groups that have been implemented (`Z(n)`,`S(n)`,`D(k)`,`SO(n)`, `O(n)`,`Sp(n)`,`SO13()`,`O13()`,`SU(n)`). 5 | We use pytest, and some of the tests are automatically generated. Because there is a large number of tests and it can take quite some time to run them all (about 15 minutes), 6 | you can run a subset using pytest's built-in features to filter tests by matching on the test case name with the `-k` argument. 7 | 8 | For example, to run `test_prod` with all the groups you can run 9 | ```pytest tests/equivariance_tests.py -k "test_prod"``` 10 | 11 | To run the test case for a specific group you could use the filter `-k "test_prod and SO3"`, and to run all tests with that group 12 | you could run 13 | ```pytest tests/equivariance_tests.py -k "SO3"``` 14 | 15 | Due to pytest parsing limitations, all parentheses in the test names are stripped. 16 | To list all available tests (or those that match certain `-k` arguments), use `--co` (for collect only). 17 | 18 | The usual pytest command line arguments apply (like `-v` for verbose). 19 | 20 | 21 | Similarly, you can find tests for "mixed" representations containing sub-representations from different groups in `tests/product_groups_tests.py`. 
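22 |  23 | For example, to preview exactly which tests a given filter will select before running them, you can combine `-k` with `--co` (here `-q` is pytest's standard quiet flag, included only to keep the collected list compact): 24 | ```pytest tests/equivariance_tests.py -k "SO3" --co -q```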
-------------------------------------------------------------------------------- /emlp/__init__.py: -------------------------------------------------------------------------------- 1 | # import importlib 2 | # import pkgutil 3 | # __all__ = [] 4 | # for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): 5 | # module = importlib.import_module('.'+module_name,package=__name__) 6 | # try: 7 | # globals().update({k: getattr(module, k) for k in module.__all__}) 8 | # __all__ += module.__all__ 9 | # except AttributeError: continue 10 | # # concatenate the __all__ from each of the submodules (expose to user) 11 | 12 | __version__ = '1.0.3' 13 | from .nn import * 14 | from .groups import * 15 | from .reps import * 16 | -------------------------------------------------------------------------------- /emlp/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.numpy as jnp 3 | from emlp.reps import Scalar,Vector,T 4 | from emlp.utils import export,Named 5 | from emlp.groups import SO,O,Trivial,Lorentz,RubiksCube,Cube 6 | from functools import partial 7 | import itertools 8 | from jax import vmap,jit 9 | from objax import Module 10 | 11 | @export 12 | class Inertia(object): 13 | def __init__(self,N=1024,k=5): 14 | super().__init__() 15 | self.dim = (1+3)*k 16 | self.X = np.random.randn(N,self.dim) 17 | self.X[:,:k] = np.log(1+np.exp(self.X[:,:k])) # Masses 18 | mi = self.X[:,:k] 19 | ri = self.X[:,k:].reshape(-1,k,3) 20 | I = np.eye(3) 21 | r2 = (ri**2).sum(-1)[...,None,None] 22 | inertia = (mi[:,:,None,None]*(r2*I - ri[...,None]*ri[...,None,:])).sum(1) 23 | self.Y = inertia.reshape(-1,9) 24 | self.rep_in = k*Scalar+k*Vector 25 | self.rep_out = T(2) 26 | self.symmetry = O(3) 27 | # One has to be careful computing offset and scale in a way so that standardizing 28 | # does not violate equivariance 29 | Xmean = self.X.mean(0) 30 | Xmean[k:] = 0 31 | Xstd = np.zeros_like(Xmean) 32 | Xstd[:k] = np.abs(self.X[:,:k]).mean(0)#.std(0) 33 | #Xstd[k:] = (np.sqrt((self.X[:,k:].reshape(N,k,3)**2).mean((0,2))[:,None]) + np.zeros((k,3))).reshape(k*3) 34 | Xstd[k:] = (np.abs(self.X[:,k:].reshape(N,k,3)).mean((0,2))[:,None] + np.zeros((k,3))).reshape(k*3) 35 | Ymean = 0*self.Y.mean(0) 36 | #Ystd = np.sqrt(((self.Y-Ymean)**2).mean((0,1)))+ np.zeros_like(Ymean) 37 | Ystd = np.abs(self.Y-Ymean).mean((0,1)) + np.zeros_like(Ymean) 38 | self.stats =0,1,0,1#Xmean,Xstd,Ymean,Ystd 39 | 40 | def __getitem__(self,i): 41 | return (self.X[i],self.Y[i]) 42 | def __len__(self): 43 | return self.X.shape[0] 44 | def default_aug(self,model): 45 | return GroupAugmentation(model,self.rep_in,self.rep_out,self.symmetry) 46 | 47 | @export 48 | class O5Synthetic(object): 49 | def __init__(self,N=1024): 50 | super().__init__() 51 | d=5 52 | self.dim = 2*d 53 | self.X = np.random.randn(N,self.dim) 54 | ri = self.X.reshape(-1,2,5) 55 | r1,r2 = ri.transpose(1,0,2) 56 | self.Y = np.sin(np.sqrt((r1**2).sum(-1)))-.5*np.sqrt((r2**2).sum(-1))**3 + (r1*r2).sum(-1)/(np.sqrt((r1**2).sum(-1))*np.sqrt((r2**2).sum(-1))) 57 | self.rep_in = 2*Vector 58 | self.rep_out = Scalar 59 | self.symmetry = O(d) 60 | self.Y = self.Y[...,None] 61 | # One has to be careful computing mean and std in a way so that standardizing 62 | # does not violate equivariance 63 | Xmean = self.X.mean(0) # can add and subtract arbitrary tensors 64 | Xscale = (np.sqrt((self.X.reshape(N,2,d)**2).mean((0,2)))[:,None]+0*ri[0]).reshape(self.dim) 65 | self.stats = 
0,Xscale,self.Y.mean(axis=0),self.Y.std(axis=0) 66 | 67 | def __getitem__(self,i): 68 | return (self.X[i],self.Y[i]) 69 | def __len__(self): 70 | return self.X.shape[0] 71 | def default_aug(self,model): 72 | return GroupAugmentation(model,self.rep_in,self.rep_out,self.symmetry) 73 | 74 | @export 75 | class ParticleInteraction(object): 76 | """ Electron muon e^4 interaction""" 77 | def __init__(self,N=1024): 78 | super().__init__() 79 | self.dim = 4*4 80 | self.rep_in = 4*Vector 81 | self.rep_out = Scalar 82 | self.X = np.random.randn(N,self.dim)/4 83 | P = self.X.reshape(N,4,4) 84 | p1,p2,p3,p4 = P.transpose(1,0,2) 85 | 𝜂 = np.diag(np.array([1.,-1.,-1.,-1.])) 86 | dot = lambda v1,v2: ((v1@𝜂)*v2).sum(-1) 87 | Le = (p1[:,:,None]*p3[:,None,:] - (dot(p1,p3)-dot(p1,p1))[:,None,None]*𝜂) 88 | L𝜇 = ((p2@𝜂)[:,:,None]*(p4@𝜂)[:,None,:] - (dot(p2,p4)-dot(p2,p2))[:,None,None]*𝜂) 89 | M = 4*(Le*L𝜇).sum(-1).sum(-1) 90 | self.Y = M 91 | self.symmetry = Lorentz() 92 | self.Y = self.Y[...,None] 93 | # One has to be careful computing mean and std in a way so that standardizing 94 | # does not violate equivariance 95 | self.Xscale = np.sqrt((np.abs((self.X.reshape(N,4,4)@𝜂)*self.X.reshape(N,4,4)).mean(-1)).mean(0)) 96 | self.Xscale = (self.Xscale[:,None]+np.zeros((4,4))).reshape(-1) 97 | self.stats = 0,self.Xscale,self.Y.mean(axis=0),self.Y.std(axis=0)#self.X.mean(axis=0),self.X.std(axis=0), 98 | def __getitem__(self,i): 99 | return (self.X[i],self.Y[i]) 100 | def __len__(self): 101 | return self.X.shape[0] 102 | def default_aug(self,model): 103 | return GroupAugmentation(model,self.rep_in,self.rep_out,self.symmetry) 104 | 105 | class GroupAugmentation(Module): 106 | def __init__(self,network,rep_in,rep_out,group): 107 | super().__init__() 108 | self.rep_in=rep_in 109 | self.rep_out=rep_out 110 | self.G=group 111 | self.rho_in = jit(vmap(self.rep_in.rho)) 112 | self.rho_out = jit(vmap(self.rep_out.rho)) 113 | self.model = network 114 | def __call__(self,x,training=True): 115 | if training: 116 | gs = self.G.samples(x.shape[0]) 117 | rhout_inv = jnp.linalg.inv(self.rho_out(gs)) 118 | return (rhout_inv@self.model((self.rho_in(gs)@x[...,None])[...,0],training)[...,None])[...,0] 119 | else: 120 | return self.model(x,False) 121 | 122 | @export 123 | class InvertedCube(object): 124 | def __init__(self,train=True): 125 | pass #TODO: finish implementing this simple dataset 126 | solved_state = np.eye(6) 127 | parity_perm = np.array([5,3,4,1,2,0]) 128 | parity_state = solved_state[:,parity_perm] 129 | 130 | labels = np.array([1,0]).astype(int) 131 | self.X = np.zeros((2,6*6)) 132 | self.X[0] = solved_state.reshape(-1) 133 | self.X[1] = parity_state.reshape(-1) 134 | self.Y = labels 135 | self.symmetry = Cube() 136 | self.rep_in = 6*Vector 137 | self.rep_out = 2*Scalar 138 | self.stats = (0,1) 139 | if train==False: # Scramble the cubes for test time 140 | gs = self.symmetry.samples(100) 141 | self.X = np.repeat(self.X,50,axis=0).reshape(100,6,6)@gs 142 | self.Y = np.repeat(self.Y,50,axis=0) 143 | p = np.random.permutation(100) 144 | self.X = self.X[p].reshape((100,-1)) 145 | self.Y = self.Y[p].reshape((100,)) 146 | self.X = np.array(self.X) 147 | self.Y = np.array(self.Y) 148 | 149 | def __getitem__(self,i): 150 | return (self.X[i],self.Y[i]) 151 | def __len__(self): 152 | return self.X.shape[0] 153 | 154 | 155 | #### Ways of constructing invalid Rubik's cubes 156 | # (see https://ruwix.com/the-rubiks-cube/unsolvable-rubiks-cube-invalid-scramble/) 157 | 158 | def UBedge_flip(state): 159 | """ # invalid edge flip on 
Rubiks state (6,48)""" 160 | UB_edge = np.array([1,3*8+1]) # Top middle of U, top middle of B 161 | edge_flip = state.copy() 162 | edge_flip[:,UB_edge] = edge_flip[:,np.roll(UB_edge,1)] 163 | return edge_flip 164 | 165 | def ULBcorner_rot(state,i=1): 166 | """ Invalid rotated corner with ULB corner rotation """ 167 | ULB_corner_ids = np.array([0,4*8,3*8+2]) # top left of U, top left of L, top right of B 168 | rotated_corner_state = state.copy() 169 | rotated_corner_state[:,ULB_corner_ids] = rotated_corner_state[:,np.roll(ULB_corner_ids,i)] 170 | return rotated_corner_state 171 | 172 | def LBface_swap(state): 173 | """ Invalid piece swap between L center top and B center top """ 174 | L_B_center_top_faces = np.array([4*8+1,3*8+1]) 175 | piece_swap = state.copy() 176 | piece_swap[:,L_B_center_top_faces] = piece_swap[:,np.roll(L_B_center_top_faces,1)] 177 | return piece_swap 178 | 179 | 180 | @export 181 | class BrokenRubiksCube(object): 182 | """ Binary classification problem of predicting whether a Rubik's cube configuration 183 | is solvable or 'broken' and not able to be solved by transformations from the group 184 | e.g. by removing and twisting a corner before replacing. 185 | Dataset is generated by taking several hand identified simple instances of solvable 186 | and unsolvable cubes, and then scrambling them. 187 | 188 | Features are represented as 6xT(1) tensors of the Rubiks Group (one hot for each color)""" 189 | def __init__(self,train=True): 190 | super().__init__() 191 | # start with a valid configuration 192 | 193 | solved_state = np.zeros((6,48)) 194 | for i in range(6): 195 | solved_state[i,8*i:8*(i+1)] = 1 196 | 197 | Id = lambda x: x 198 | transforms = [Id,itertools.product([Id,ULBcorner_rot,partial(ULBcorner_rot,i=2)], 199 | [Id,UBedge_flip], 200 | [Id,LBface_swap])] 201 | #equivalence_classes = np.vstack([t3(t2(t1(solved_state))) for t1,t2,t3 in transforms]) 202 | labels = np.zeros((22,)) 203 | labels[:11]=1 # First configuration is solvable 204 | labels[11:]=0 # all others are not 205 | self.X = np.zeros((22,6*48)) # duplicate solvable example 11 times for convencience (balance) 206 | self.X[:11] = solved_state.reshape(-1)#equivalence_classes.reshape(12,-1)[:1] 207 | parity_perm = np.array([5,3,4,1,2,0]) 208 | self.X[11:] = solved_state.reshape(6,6,8)[:,parity_perm].reshape(-1)#equivalence_classes.reshape(12,-1)[1:] 209 | self.Y = labels 210 | self.symmetry = RubiksCube() 211 | self.rep_in = 6*Vector 212 | self.rep_out = 2*Scalar 213 | self.stats = (0,1) 214 | if train==False: # Scramble the cubes for test time 215 | gs = self.symmetry.samples(440) 216 | self.X = np.repeat(self.X,20,axis=0).reshape(440,6,48)@gs 217 | self.Y = np.repeat(self.Y,20,axis=0) 218 | p = np.random.permutation(440) 219 | self.X = self.X[p].reshape((440,-1)) 220 | self.Y = self.Y[p].reshape((440,)) 221 | self.X = np.array(self.X) 222 | self.Y = np.array(self.Y) 223 | 224 | def __getitem__(self,i): 225 | return (self.X[i],self.Y[i]) 226 | def __len__(self): 227 | return self.X.shape[0] 228 | 229 | # @export 230 | # class BrokenRubiksCube2x2(object): 231 | # """ Binary classification problem of predicting whether a Rubik's cube configuration 232 | # is solvable or 'broken' and not able to be solved by transformations from the group 233 | # e.g. by removing and twisting a corner before replacing. 234 | # Dataset is generated by taking several hand identified simple instances of solvable 235 | # and unsolvable cubes, and then scrambling them. 
236 | 237 | # Features are represented as 6xT(1) tensors of the Rubiks Group (one hot for each color)""" 238 | # def __init__(self,train=True): 239 | # super().__init__() 240 | # # start with a valid configuration 241 | 242 | # solved_state = np.zeros((6,24)) 243 | # for i in range(6): 244 | # solved_state[i,4*i:4*(i+1)] = 1 245 | 246 | # ULB_corner_ids = np.array([0,4*4,3*4+2]) # top left of U, top left of L, top right of B 247 | # rotated_corner_state = solved_state.copy() 248 | # rotated_corner_state[:,ULB_corner_ids] = rotated_corner_state[:,np.roll(ULB_corner_ids,1)] 249 | 250 | # labels = np.zeros((2,)) 251 | # labels[:1]=1 # First configuration is solvable 252 | # labels[1:]=0 # all others are not 253 | # self.X = np.zeros((2,6*24)) # duplicate solvable example 11 times for convencience (balance) 254 | # self.X[0] = solved_state.reshape(-1) 255 | # parity_perm = np.array([5,3,4,1,2,0]) 256 | # self.X[1] = rotated_corner_state.reshape(6,6,4)[:,parity_perm].reshape(-1) 257 | # self.Y = labels 258 | # self.X = np.repeat(self.X,10,axis=0) 259 | # self.Y = np.repeat(self.Y,10,axis=0) 260 | # self.symmetry = RubiksCube2x2() 261 | # self.rep_in = 6*Vector 262 | # self.rep_out = 2*Scalar 263 | # self.stats = (0,1) 264 | # if train==False: # Scramble the cubes for test time 265 | # N = 200 266 | # gs = self.symmetry.samples(N) 267 | # self.X = np.repeat(self.X,10,axis=0).reshape(N,6,24)@gs 268 | # self.Y = np.repeat(self.Y,10,axis=0) 269 | # p = np.random.permutation(N) 270 | # self.X = self.X[p].reshape((N,-1)) 271 | # self.Y = self.Y[p].reshape((N,)) 272 | # self.X = np.array(self.X) 273 | # self.Y = np.array(self.Y) 274 | 275 | # def __getitem__(self,i): 276 | # return (self.X[i],self.Y[i]) 277 | # def __len__(self): 278 | # return self.X.shape[0] -------------------------------------------------------------------------------- /emlp/nn/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | __all__ = [] # expose objax implementation as base nn 4 | module = importlib.import_module('.'+'objax',package=__name__) 5 | globals().update({k: getattr(module, k) for k in module.__all__}) 6 | __all__ += module.__all__ 7 | -------------------------------------------------------------------------------- /emlp/nn/flax.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from emlp.reps import T,Rep,Scalar 5 | from emlp.reps import bilinear_weights 6 | #from emlp.reps import LinearOperator # why does this not work? 
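# A hedged aside on the package layout (an inference from the __init__ above):
# emlp/nn/__init__.py re-exports the objax implementation as the default backend, so
# `from emlp.nn import EMLP` resolves to the objax class, while the flax, haiku, and
# pytorch variants below must be imported from their submodules explicitly.
import emlp.nn
import emlp.nn.objax
assert emlp.nn.EMLP is emlp.nn.objax.EMLP  # default backend is objax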
7 | from emlp.reps.linear_operator_base import LinearOperator 8 | from emlp.reps.product_sum_reps import SumRep 9 | from emlp.groups import Group 10 | from emlp.utils import Named,export 11 | from flax import linen as nn 12 | import logging 13 | from emlp.nn import gated,gate_indices,uniform_rep 14 | from typing import Union,Iterable,Optional 15 | # def Sequential(*args): 16 | # """ Wrapped to mimic pytorch syntax""" 17 | # return nn.Sequential(args) 18 | 19 | 20 | 21 | 22 | @export 23 | def Linear(repin,repout): 24 | """ Basic equivariant Linear layer from repin to repout.""" 25 | cin =repin.size() 26 | cout = repout.size() 27 | rep_W = repin>>repout 28 | Pw = rep_W.equivariant_projector() 29 | Pb = repout.equivariant_projector() 30 | logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}") 31 | return _Linear(Pw,Pb,cout) 32 | 33 | class _Linear(nn.Module): 34 | Pw:LinearOperator 35 | Pb:LinearOperator 36 | cout:int 37 | @nn.compact 38 | def __call__(self,x): 39 | w = self.param('w',nn.initializers.lecun_normal(),(self.cout,x.shape[-1])) 40 | b = self.param('b',nn.initializers.zeros,(self.cout,)) 41 | W = (self.Pw@w.reshape(-1)).reshape(*w.shape) 42 | B = self.Pb@b 43 | return x@W.T+B 44 | 45 | @export 46 | def BiLinear(repin,repout): 47 | """ Cheap bilinear layer (adds parameters for each part of the input which can be 48 | interpreted as a linear map from a part of the input to the output representation).""" 49 | Wdim, weight_proj = bilinear_weights(repout,repin) 50 | #self.w = TrainVar(objax.random.normal((Wdim,)))#xavier_normal((Wdim,))) #TODO: revert to xavier 51 | logging.info(f"BiW components: dim:{Wdim}") 52 | return _BiLinear(Wdim,weight_proj) 53 | 54 | class _BiLinear(nn.Module): 55 | Wdim:int 56 | weight_proj:callable 57 | 58 | @nn.compact 59 | def __call__(self, x): 60 | w = self.param('w',nn.initializers.normal(),(self.Wdim,)) #TODO: change to standard normal 61 | W = self.weight_proj(w,x) 62 | out= .1*(W@x[...,None])[...,0] 63 | return out 64 | 65 | 66 | @export 67 | class GatedNonlinearity(nn.Module): #TODO: add support for mixed tensors and non sumreps 68 | """ Gated nonlinearity. Requires input to have the additional gate scalars 69 | for every non regular and non scalar rep. Applies swish to regular and 70 | scalar reps. (Right now assumes rep is a SumRep)""" 71 | rep:Rep 72 | def __call__(self,values): 73 | gate_scalars = values[..., gate_indices(self.rep)] 74 | activations = jax.nn.sigmoid(gate_scalars) * values[..., :self.rep.size()] 75 | return activations 76 | 77 | @export 78 | def EMLPBlock(rep_in,rep_out): 79 | """ Basic building block of EMLP consisting of G-Linear, biLinear, 80 | and gated nonlinearity. """ 81 | linear = Linear(rep_in,gated(rep_out)) 82 | bilinear = BiLinear(gated(rep_out),gated(rep_out)) 83 | nonlinearity = GatedNonlinearity(rep_out) 84 | return _EMLPBlock(linear,bilinear,nonlinearity) 85 | 86 | class _EMLPBlock(nn.Module): 87 | linear:nn.Module 88 | bilinear:nn.Module 89 | nonlinearity:nn.Module 90 | 91 | def __call__(self,x): 92 | lin = self.linear(x) 93 | preact =self.bilinear(lin)+lin 94 | return self.nonlinearity(preact) 95 | 96 | @export 97 | def EMLP(rep_in,rep_out,group,ch=384,num_layers=3): 98 | """ Equivariant MultiLayer Perceptron. 99 | If the input ch argument is an int, uses the hands off uniform_rep heuristic. 100 | If the ch argument is a representation, uses this representation for the hidden layers. 
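    For example (a hedged sketch of the flax calling convention, assuming the reps/groups
    API used throughout this repo):

        import jax
        import numpy as np
        from emlp.reps import T
        from emlp.groups import SO
        import emlp.nn.flax as enn

        model = enn.EMLP(T(1), T(0), group=SO(3), ch=128, num_layers=2)
        x = np.random.randn(5, 3)
        params = model.init(jax.random.PRNGKey(0), x)
        y = model.apply(params, x)  # shape (5, 1)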
101 | Individual layer representations can be set explicitly by using a list of ints or a list of 102 | representations, rather than using the same one for each hidden layer. 103 | 104 | Args: 105 | rep_in (Rep): input representation 106 | rep_out (Rep): output representation 107 | group (Group): symmetry group 108 | ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers 109 | num_layers (int): number of hidden layers 110 | 111 | Returns: 112 | Module: the EMLP flax module.""" 113 | logging.info("Initing EMLP (flax)") 114 | rep_in = rep_in(group) 115 | rep_out = rep_out(group) 116 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)] 117 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)] 118 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch] 119 | reps = [rep_in]+middle_layers 120 | logging.info(f"Reps: {reps}") 121 | return Sequential(*[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],Linear(reps[-1],rep_out)) 122 | 123 | def swish(x): 124 | return jax.nn.sigmoid(x)*x 125 | 126 | class _Sequential(nn.Module): 127 | modules:Iterable[callable] 128 | def __call__(self,x): 129 | for module in self.modules: 130 | x = module(x) 131 | return x 132 | 133 | def Sequential(*layers): 134 | return _Sequential(layers) 135 | 136 | def MLPBlock(cout): 137 | return Sequential(nn.Dense(cout),swish) # (BatchNorm variant disabled) 138 | 139 | @export 140 | class MLP(nn.Module,metaclass=Named): 141 | """ Standard baseline MLP. Representations and group are used for shapes only. """ 142 | rep_in:Rep 143 | rep_out:Rep 144 | group:Group 145 | ch: Optional[int]=384 146 | num_layers:Optional[int]=3 147 | def setup(self): 148 | logging.info("Initing MLP (flax)") 149 | cout = self.rep_out(self.group).size() 150 | self.modules = [MLPBlock(self.ch) for _ in range(self.num_layers)]+[nn.Dense(cout)] 151 | def __call__(self,x): 152 | for module in self.modules: 153 | x = module(x) 154 | return x 155 | -------------------------------------------------------------------------------- /emlp/nn/haiku.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from emlp.reps import Rep 5 | from emlp.reps import bilinear_weights 6 | from emlp.utils import export 7 | import logging 8 | import haiku as hk 9 | from emlp.nn import gated,gate_indices,uniform_rep 10 | 11 | def Sequential(*args): 12 | """ Wrapped to mimic pytorch syntax""" 13 | return lambda x: hk.Sequential(args)(x) 14 | 15 | @export 16 | def Linear(repin,repout): 17 | rep_W = repout << repin 18 | logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}") 19 | rep_bias = repout 20 | Pw = rep_W.equivariant_projector() 21 | Pb = rep_bias.equivariant_projector() 22 | return lambda x: hkLinear(Pw,Pb,(repout.size(),repin.size()))(x) 23 | 24 | class hkLinear(hk.Module): 25 | """ Basic equivariant Linear layer from repin to repout.""" 26 | def __init__(self, Pw,Pb,shape,name=None): 27 | super().__init__(name=name) 28 | self.Pw = Pw 29 | self.Pb = Pb 30 | self.shape=shape 31 | 32 | def __call__(self, x): # (cin) -> (cout) 33 | i,j = self.shape 34 | w_init = hk.initializers.TruncatedNormal(1.
/ np.sqrt(i)) 35 | w = hk.get_parameter("w", shape=self.shape, dtype=x.dtype, init=w_init) 36 | b = hk.get_parameter("b", shape=[i], dtype=x.dtype, init=w_init) 37 | W = (self.Pw@w.reshape(-1)).reshape(*self.shape) 38 | b = self.Pb@b 39 | return x@W.T+b 40 | 41 | @export 42 | def BiLinear(repin,repout): 43 | """ Cheap bilinear layer (adds parameters for each part of the input which can be 44 | interpreted as a linear map from a part of the input to the output representation).""" 45 | Wdim, weight_proj = bilinear_weights(repout,repin) 46 | return lambda x: hkBiLinear(weight_proj,Wdim)(x) 47 | 48 | 49 | class hkBiLinear(hk.Module): 50 | def __init__(self, weight_proj,Wdim,name=None): 51 | super().__init__(name=name) 52 | self.weight_proj=weight_proj 53 | self.Wdim=Wdim 54 | 55 | def __call__(self, x): 56 | # compatible with non sumreps? need to check 57 | w_init = hk.initializers.TruncatedNormal(1.) 58 | w = hk.get_parameter("w", shape=[self.Wdim], dtype=x.dtype, init=w_init) 59 | W = self.weight_proj(w,x) 60 | return .1*(W@x[...,None])[...,0] 61 | 62 | @export 63 | class GatedNonlinearity(object): # TODO: add support for mixed tensors and non sumreps 64 | """ Gated nonlinearity. Requires input to have the additional gate scalars 65 | for every non regular and non scalar rep. Applies swish to regular and 66 | scalar reps. (Right now assumes rep is a SumRep)""" 67 | def __init__(self,rep,name=None): 68 | super().__init__() 69 | self.rep=rep 70 | 71 | def __call__(self,values): 72 | gate_scalars = values[..., gate_indices(self.rep)] 73 | activations = jax.nn.sigmoid(gate_scalars) * values[..., :self.rep.size()] 74 | return activations 75 | 76 | 77 | @export 78 | def EMLPBlock(repin,repout): 79 | """ Basic building block of EMLP consisting of G-Linear, biLinear, 80 | and gated nonlinearity. """ 81 | linear = Linear(repin,gated(repout)) 82 | bilinear = BiLinear(gated(repout),gated(repout)) 83 | nonlinearity = GatedNonlinearity(repout) 84 | def block(x): 85 | lin = linear(x) 86 | preact =bilinear(lin)+lin 87 | return nonlinearity(preact) 88 | return block 89 | 90 | 91 | @export 92 | def EMLP(rep_in,rep_out,group,ch=384,num_layers=3): 93 | """ Equivariant MultiLayer Perceptron. 94 | If the input ch argument is an int, uses the hands off uniform_rep heuristic. 95 | If the ch argument is a representation, uses this representation for the hidden layers. 96 | Individual layer representations can be set explicitly by using a list of ints or a list of 97 | representations, rather than use the same for each hidden layer. 
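    For example (a hedged sketch; haiku modules must be built and applied inside
    hk.transform):

        import jax
        import numpy as np
        import haiku as hk
        from emlp.reps import T
        from emlp.groups import SO
        import emlp.nn.haiku as ehk

        net = hk.without_apply_rng(hk.transform(ehk.EMLP(T(1), T(0), group=SO(3), ch=128, num_layers=2)))
        x = np.random.randn(5, 3)
        params = net.init(jax.random.PRNGKey(0), x)
        y = net.apply(params, x)  # shape (5, 1)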
98 | 99 | Args: 100 | rep_in (Rep): input representation 101 | rep_out (Rep): output representation 102 | group (Group): symmetry group 103 | ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers 104 | num_layers (int): number of hidden layers 105 | 106 | Returns: 107 | Module: the EMLP haiku module.""" 108 | logging.info("Initing EMLP (Haiku)") 109 | rep_in = rep_in(group) 110 | rep_out = rep_out(group) 111 | # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps 112 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)] 113 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)] 114 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch] 115 | # assert all((not rep.G is None) for rep in middle_layers[0].reps) 116 | reps = [rep_in]+middle_layers 117 | # logging.info(f"Reps: {reps}") 118 | network = Sequential( 119 | *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])], 120 | Linear(reps[-1],rep_out) 121 | ) 122 | return network 123 | 124 | @export 125 | def MLP(rep_in,rep_out,group,ch=384,num_layers=3): 126 | """ Standard baseline MLP. Representations and group are used for shapes only. """ 127 | cout = rep_out(group).size() 128 | mlp = lambda x: Sequential( 129 | *[Sequential(hk.Linear(ch),jax.nn.swish) for _ in range(num_layers)], 130 | hk.Linear(cout) 131 | )(x) 132 | return mlp 133 | -------------------------------------------------------------------------------- /emlp/nn/objax.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import objax.nn as nn 4 | import objax.functional as F 5 | import numpy as np 6 | from emlp.reps import T,Rep,Scalar 7 | from emlp.reps import bilinear_weights 8 | from emlp.reps.product_sum_reps import SumRep 9 | import collections 10 | from emlp.utils import Named,export 11 | import scipy as sp 12 | import scipy.special 13 | import random 14 | import logging 15 | from objax.variable import TrainVar, StateVar 16 | from objax.nn.init import kaiming_normal, xavier_normal 17 | from objax.module import Module 18 | import objax 19 | from objax.nn.init import orthogonal 20 | from scipy.special import binom 21 | from jax import jit,vmap 22 | from functools import lru_cache as cache 23 | 24 | def Sequential(*args): 25 | """ Wrapped to mimic pytorch syntax""" 26 | return nn.Sequential(args) 27 | 28 | @export 29 | class Linear(nn.Linear): 30 | """ Basic equivariant Linear layer from repin to repout.""" 31 | def __init__(self, repin, repout): 32 | nin,nout = repin.size(),repout.size() 33 | super().__init__(nin,nout) 34 | self.b = TrainVar(objax.random.uniform((nout,))/jnp.sqrt(nout)) 35 | self.w = TrainVar(orthogonal((nout, nin))) 36 | self.rep_W = rep_W = repout*repin.T 37 | 38 | rep_bias = repout 39 | self.Pw = rep_W.equivariant_projector() 40 | self.Pb = rep_bias.equivariant_projector() 41 | logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}") 42 | def __call__(self, x): # (cin) -> (cout) 43 | logging.debug(f"linear in shape: {x.shape}") 44 | W = (self.Pw@self.w.value.reshape(-1)).reshape(*self.w.value.shape) 45 | b = self.Pb@self.b.value 46 | out = x@W.T+b 47 | logging.debug(f"linear out shape:{out.shape}") 48 | return out 49 | 50 | @export 51 | class BiLinear(Module): 52 | """ Cheap bilinear layer (adds parameters for each part of the input which can be 53 | interpreted as a linear map from a part of the input to the output representation).""" 54 |
def __init__(self, repin, repout): 55 | super().__init__() 56 | Wdim, weight_proj = bilinear_weights(repout,repin) 57 | self.weight_proj = jit(weight_proj) 58 | self.w = TrainVar(objax.random.normal((Wdim,))) 59 | logging.info(f"BiW components: dim:{Wdim}") 60 | 61 | def __call__(self, x,training=True): 62 | # compatible with non sumreps? need to check 63 | W = self.weight_proj(self.w.value,x) 64 | out= .1*(W@x[...,None])[...,0] 65 | return out 66 | 67 | @export 68 | def gated(ch_rep:Rep) -> Rep: 69 | """ Returns the rep with an additional scalar 'gate' for each of the nonscalars and non regular 70 | reps in the input. To be used as the output for linear (and or bilinear) layers directly 71 | before a :func:`GatedNonlinearity` to produce its scalar gates. """ 72 | if isinstance(ch_rep,SumRep): 73 | return ch_rep+sum([Scalar(rep.G) for rep in ch_rep if rep!=Scalar and not rep.is_permutation]) 74 | else: 75 | return ch_rep+Scalar(ch_rep.G) if not ch_rep.is_permutation else ch_rep 76 | 77 | @export 78 | class GatedNonlinearity(Module): 79 | """ Gated nonlinearity. Requires input to have the additional gate scalars 80 | for every non regular and non scalar rep. Applies swish to regular and 81 | scalar reps. """ 82 | def __init__(self,rep): 83 | super().__init__() 84 | self.rep=rep 85 | def __call__(self,values): 86 | gate_scalars = values[..., gate_indices(self.rep)] 87 | activations = jax.nn.sigmoid(gate_scalars) * values[..., :self.rep.size()] 88 | return activations 89 | 90 | @export 91 | class EMLPBlock(Module): 92 | """ Basic building block of EMLP consisting of G-Linear, biLinear, 93 | and gated nonlinearity. """ 94 | def __init__(self,rep_in,rep_out): 95 | super().__init__() 96 | self.linear = Linear(rep_in,gated(rep_out)) 97 | self.bilinear = BiLinear(gated(rep_out),gated(rep_out)) 98 | self.nonlinearity = GatedNonlinearity(rep_out) 99 | def __call__(self,x): 100 | lin = self.linear(x) 101 | preact =self.bilinear(lin)+lin 102 | return self.nonlinearity(preact) 103 | 104 | def uniform_rep_general(ch,*rep_types): 105 | """ adds all combinations of (powers of) rep_types up to 106 | a total size of ch channels. """ 107 | raise NotImplementedError 108 | 109 | 110 | 111 | 112 | @export 113 | def uniform_rep(ch,group): 114 | """ A heuristic method for allocating a given number of channels (ch) 115 | into tensor types. Attempts to distribute the channels evenly across 116 | the different tensor types. Useful for hands off layer construction. 117 | 118 | Args: 119 | ch (int): total number of channels 120 | group (Group): symmetry group 121 | 122 | Returns: 123 | SumRep: The direct sum representation with dim(V)=ch 124 | """ 125 | d = group.d 126 | Ns = np.zeros((lambertW(ch,d)+1,),int) # number of tensors of each rank 127 | while ch>0: 128 | max_rank = lambertW(ch,d) # compute the max rank tensor that can fit up to 129 | Ns[:max_rank+1] += np.array([d**(max_rank-r) for r in range(max_rank+1)],dtype=int) 130 | ch -= (max_rank+1)*d**max_rank # compute leftover channels 131 | sum_rep = sum([binomial_allocation(nr,r,group) for r,nr in enumerate(Ns)]) 132 | sum_rep,perm = sum_rep.canonicalize() 133 | return sum_rep 134 | 135 | def lambertW(ch,d): 136 | """ Returns solution to x*d^x = ch rounded down.""" 137 | max_rank=0 138 | while (max_rank+1)*d**max_rank <= ch: 139 | max_rank += 1 140 | max_rank -= 1 141 | return max_rank 142 | 143 | def binomial_allocation(N,rank,G): 144 | """ Allocates N of tensors of total rank r=(p+q) into 145 | T(k,r-k) for k=0,1,...,r to match the binomial distribution. 
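    A hedged worked example of the uniform_rep heuristic above, for d=3 (e.g. G=SO(3)):
    each pass finds the largest rank r with (r+1)*d**r <= the remaining channels and adds
    d**(r-k) copies of each T(k), k<=r. Starting from ch=384 the remainder goes
    384 -> 276 -> 168 -> 60 -> 33 -> 6 -> 0, which works out to
    102*T(0) + 34*T(1) + 11*T(2) + 3*T(3)  (102*1 + 34*3 + 11*9 + 3*27 = 384 channels):

        from emlp.groups import SO
        from emlp.nn import uniform_rep
        print(uniform_rep(384, SO(3)))  # expected: 102 scalars, 34 vectors, 11 matrices, 3 rank-3 tensors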
146 | For orthogonal representations there is no 147 | distinction between p and q, so this op is equivalent to N*T(rank).""" 148 | if N==0: return 0 149 | n_binoms = N//(2**rank) 150 | n_leftover = N%(2**rank) 151 | even_split = sum([n_binoms*int(binom(rank,k))*T(k,rank-k,G) for k in range(rank+1)]) 152 | ps = np.random.binomial(rank,.5,n_leftover) 153 | ragged = sum([T(int(p),rank-int(p),G) for p in ps]) 154 | out = even_split+ragged 155 | return out 156 | 157 | def uniform_allocation(N,rank): 158 | """ Uniformly allocates N of tensors of total rank r=(p+q) into 159 | T(k,r-k) for k=0,1,...,r. For orthogonal representations there is no 160 | distinction between p and q, so this op is equivalent to N*T(rank).""" 161 | if N==0: return 0 162 | even_split = sum((N//(rank+1))*T(k,rank-k) for k in range(rank+1)) 163 | ragged = sum(random.sample([T(k,rank-k) for k in range(rank+1)],N%(rank+1))) 164 | return even_split+ragged 165 | 166 | @export 167 | class EMLP(Module,metaclass=Named): 168 | """ Equivariant MultiLayer Perceptron. 169 | If the input ch argument is an int, uses the hands off uniform_rep heuristic. 170 | If the ch argument is a representation, uses this representation for the hidden layers. 171 | Individual layer representations can be set explicitly by using a list of ints or a list of 172 | representations, rather than use the same for each hidden layer. 173 | 174 | Args: 175 | rep_in (Rep): input representation 176 | rep_out (Rep): output representation 177 | group (Group): symmetry group 178 | ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers 179 | num_layers (int): number of hidden layers 180 | 181 | Returns: 182 | Module: the EMLP objax module.""" 183 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):#@ 184 | super().__init__() 185 | logging.info("Initing EMLP (objax)") 186 | self.rep_in =rep_in(group) 187 | self.rep_out = rep_out(group) 188 | 189 | self.G=group 190 | # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps 191 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]#[uniform_rep(ch,group) for _ in range(num_layers)] 192 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)] 193 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch] 194 | #assert all((not rep.G is None) for rep in middle_layers[0].reps) 195 | reps = [self.rep_in]+middle_layers 196 | logging.info(f"Reps: {reps}") 197 | self.network = Sequential( 198 | *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])], 199 | Linear(reps[-1],self.rep_out) 200 | ) 201 | def __call__(self,x,training=True): 202 | return self.network(x) 203 | 204 | def swish(x): 205 | return jax.nn.sigmoid(x)*x 206 | 207 | def MLPBlock(cin,cout): 208 | return Sequential(nn.Linear(cin,cout),swish)#,nn.BatchNorm0D(cout,momentum=.9),swish)#, 209 | 210 | @export 211 | class MLP(Module,metaclass=Named): 212 | """ Standard baseline MLP. Representations and group are used for shapes only. 
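    Example (a hedged sketch comparing the equivariant model with this baseline):

        import numpy as np
        from emlp.reps import T
        from emlp.groups import O
        from emlp.nn import EMLP, MLP

        repin, repout, G = 4*T(1), T(0), O(3)
        model_eq  = EMLP(repin, repout, group=G, ch=256, num_layers=3)
        model_mlp = MLP(repin, repout, group=G, ch=256, num_layers=3)  # same sizes, no constraint
        x = np.random.randn(8, repin(G).size())  # (8, 12)
        y = model_eq(x)                          # (8, 1)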
""" 213 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3): 214 | super().__init__() 215 | self.rep_in =rep_in(group) 216 | self.rep_out = rep_out(group) 217 | self.G = group 218 | chs = [self.rep_in.size()] + num_layers*[ch] 219 | cout = self.rep_out.size() 220 | logging.info("Initing MLP") 221 | self.net = Sequential( 222 | *[MLPBlock(cin,cout) for cin,cout in zip(chs,chs[1:])], 223 | nn.Linear(chs[-1],cout) 224 | ) 225 | def __call__(self,x,training=True): 226 | y = self.net(x) 227 | return y 228 | 229 | @export 230 | class Standardize(Module): 231 | """ A convenience module to wrap a given module, normalize its input 232 | by some dataset x mean and std stats, and unnormalize its output by 233 | the dataset y mean and std stats. 234 | 235 | Args: 236 | model (Module): model to wrap 237 | ds_stats ((μx,σx,μy,σy) or (μx,σx)): tuple of the normalization stats 238 | 239 | Returns: 240 | Module: Wrapped model with input normalization (and output unnormalization)""" 241 | def __init__(self,model,ds_stats): 242 | super().__init__() 243 | self.model = model 244 | self.ds_stats=ds_stats 245 | def __call__(self,x,training): 246 | if len(self.ds_stats)==2: 247 | muin,sin = self.ds_stats 248 | return self.model((x-muin)/sin,training=training) 249 | else: 250 | muin,sin,muout,sout = self.ds_stats 251 | y = sout*self.model((x-muin)/sin,training=training)+muout 252 | return y 253 | 254 | 255 | 256 | # Networks for hamiltonian dynamics (need to sum for batched Hamiltonian grads) 257 | @export 258 | class MLPode(Module,metaclass=Named): 259 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3): 260 | super().__init__() 261 | self.rep_in =rep_in(group) 262 | self.rep_out = rep_out(group) 263 | self.G = group 264 | chs = [self.rep_in.size()] + num_layers*[ch] 265 | cout = self.rep_out.size() 266 | logging.info("Initing MLP") 267 | self.net = Sequential( 268 | *[Sequential(nn.Linear(cin,cout),swish) for cin,cout in zip(chs,chs[1:])], 269 | nn.Linear(chs[-1],cout) 270 | ) 271 | def __call__(self,z,t): 272 | return self.net(z) 273 | 274 | @export 275 | class EMLPode(EMLP): 276 | """ Neural ODE Equivariant MLP. 
Same args as EMLP.""" 277 | #__doc__ += EMLP.__doc__.split('.')[1] 278 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):#@ 279 | #super().__init__() 280 | logging.info("Initing EMLP") 281 | self.rep_in =rep_in(group) 282 | self.rep_out = rep_out(group) 283 | self.G=group 284 | # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps 285 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]#[uniform_rep(ch,group) for _ in range(num_layers)] 286 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)] 287 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch] 288 | #print(middle_layers[0].reps[0].G) 289 | #print(self.rep_in.G) 290 | reps = [self.rep_in]+middle_layers 291 | logging.info(f"Reps: {reps}") 292 | self.network = Sequential( 293 | *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])], 294 | Linear(reps[-1],self.rep_out) 295 | ) 296 | def __call__(self,z,t): 297 | return self.network(z) 298 | 299 | # Networks for hamiltonian dynamics (need to sum for batched Hamiltonian grads) 300 | @export 301 | class MLPH(Module,metaclass=Named): 302 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3): 303 | super().__init__() 304 | self.rep_in =rep_in(group) 305 | self.rep_out = rep_out(group) 306 | self.G = group 307 | chs = [self.rep_in.size()] + num_layers*[ch] 308 | cout = self.rep_out.size() 309 | logging.info("Initing MLP") 310 | self.net = Sequential( 311 | *[Sequential(nn.Linear(cin,cout),swish) for cin,cout in zip(chs,chs[1:])], 312 | nn.Linear(chs[-1],cout) 313 | ) 314 | def H(self,x):#,training=True): 315 | y = self.net(x).sum() 316 | return y 317 | def __call__(self,x): 318 | return self.H(x) 319 | 320 | @export 321 | class EMLPH(EMLP): 322 | """ Equivariant EMLP modeling a Hamiltonian for HNN. 
Same args as EMLP""" 323 | #__doc__ += EMLP.__doc__.split('.')[1] 324 | def H(self,x):#,training=True): 325 | y = self.network(x) 326 | return y.sum() 327 | def __call__(self,x): 328 | return self.H(x) 329 | 330 | @export 331 | @cache(maxsize=None) 332 | def gate_indices(ch_rep:Rep) -> jnp.ndarray: 333 | """ Indices for scalars, and also additional scalar gates 334 | added by gated(sumrep)""" 335 | channels = ch_rep.size() 336 | perm = ch_rep.perm 337 | indices = np.arange(channels) 338 | 339 | if not isinstance(ch_rep,SumRep): # If just a single rep, only one scalar gate at the end 340 | return indices if ch_rep.is_permutation else np.full(ch_rep.size(),ch_rep.size()) 341 | 342 | num_nonscalars = 0 343 | i=0 344 | for rep in ch_rep: 345 | if rep!=Scalar and not rep.is_permutation: 346 | indices[perm[i:i+rep.size()]] = channels+num_nonscalars 347 | num_nonscalars+=1 348 | i+=rep.size() 349 | return indices -------------------------------------------------------------------------------- /emlp/nn/pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import jax 4 | from jax import jit 5 | from jax.tree_util import tree_flatten, tree_unflatten 6 | import types 7 | import copy 8 | import jax.numpy as jnp 9 | import numpy as np 10 | 11 | from functools import partial 12 | from emlp.reps import T,Rep,Scalar 13 | from emlp.reps import bilinear_weights 14 | from emlp.utils import Named,export 15 | import logging 16 | 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from emlp.nn import gated,gate_indices,uniform_rep 20 | 21 | def torch2jax(arr): 22 | if isinstance(arr,torch.Tensor): 23 | return jnp.asarray(arr.cpu().data.numpy()) 24 | else: 25 | return arr 26 | 27 | def jax2torch(arr): 28 | if isinstance(arr,(jnp.ndarray,np.ndarray)): 29 | if jax.devices()[0].platform=='gpu': 30 | device = torch.device('cuda') 31 | else: device = torch.device('cpu') 32 | return torch.from_numpy(np.array(arr)).to(device) 33 | else: 34 | return arr 35 | 36 | def to_jax(pytree): 37 | flat_values, tree_type = tree_flatten(pytree) 38 | transformed_flat = [torch2jax(v) for v in flat_values] 39 | return tree_unflatten(tree_type, transformed_flat) 40 | 41 | 42 | def to_pytorch(pytree): 43 | flat_values, tree_type = tree_flatten(pytree) 44 | transformed_flat = [jax2torch(v) for v in flat_values] 45 | return tree_unflatten(tree_type, transformed_flat) 46 | 47 | @export 48 | def torchify_fn(function): 49 | """ A function to enable interoperability between jax and pytorch autograd. 50 | Calling torchify_fn on a given function that has pytrees of jax.ndarray 51 | objects as inputs and outputs will return a function that first converts 52 | the inputs to jax, runs through the jax function, and converts the output 53 | back to torch but preserving the gradients of the operation to be called 54 | with pytorch autograd.
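    Example (a hedged sketch of the intended use):

        import torch
        import jax.numpy as jnp

        f = torchify_fn(lambda x: jnp.sin(x).sum())
        x = torch.randn(5, requires_grad=True)
        y = f(x)       # evaluated in jax, returned as a torch scalar
        y.backward()   # gradient flows back through the jax vjp
        # x.grad should now match torch.cos(x.detach())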
""" 55 | vjp = jit(lambda *args: jax.vjp(function,*args)) 56 | class torched_fn(Function): 57 | @staticmethod 58 | def forward(ctx,*args): 59 | if any(ctx.needs_input_grad): 60 | y,ctx.vjp_fn = vjp(*to_jax(args))#jax.vjp(function,*to_jax(args)) 61 | return to_pytorch(y) 62 | return to_pytorch(function(*to_jax(args))) 63 | @staticmethod 64 | def backward(ctx,*grad_outputs): 65 | return to_pytorch(ctx.vjp_fn(*to_jax(grad_outputs))) 66 | return torched_fn.apply #TORCHED #Roasted 67 | 68 | 69 | @export 70 | class Linear(nn.Linear): 71 | """ Basic equivariant Linear layer from repin to repout.""" 72 | def __init__(self, repin, repout): 73 | nin,nout = repin.size(),repout.size() 74 | super().__init__(nin,nout) 75 | rep_W = repout*repin.T 76 | rep_bias = repout 77 | Pw = rep_W.equivariant_projector() 78 | Pb = rep_bias.equivariant_projector() 79 | self.proj_b = torchify_fn(jit(lambda b: Pb@b)) 80 | self.proj_w = torchify_fn(jit(lambda w:(Pw@w.reshape(-1)).reshape(nout,nin))) 81 | logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}") 82 | 83 | def forward(self, x): # (cin) -> (cout) 84 | return F.linear(x,self.proj_w(self.weight),self.proj_b(self.bias)) 85 | 86 | @export 87 | class BiLinear(nn.Module): 88 | """ Cheap bilinear layer (adds parameters for each part of the input which can be 89 | interpreted as a linear map from a part of the input to the output representation).""" 90 | def __init__(self, repin, repout): 91 | super().__init__() 92 | Wdim, weight_proj = bilinear_weights(repout,repin) 93 | self.weight_proj = torchify_fn(jit(weight_proj)) 94 | self.bi_params = nn.Parameter(torch.randn(Wdim)) 95 | logging.info(f"BiW components: dim:{Wdim}") 96 | 97 | def forward(self, x,training=True): 98 | # compatible with non sumreps? need to check 99 | W = self.weight_proj(self.bi_params,x) 100 | out= .1*(W@x[...,None])[...,0] 101 | return out 102 | 103 | @export 104 | class GatedNonlinearity(nn.Module): #TODO: add support for mixed tensors and non sumreps 105 | """ Gated nonlinearity. Requires input to have the additional gate scalars 106 | for every non regular and non scalar rep. Applies swish to regular and 107 | scalar reps. (Right now assumes rep is a SumRep)""" 108 | def __init__(self,rep): 109 | super().__init__() 110 | self.rep=rep 111 | def forward(self,values): 112 | gate_scalars = values[..., gate_indices(self.rep)] 113 | activations = gate_scalars.sigmoid() * values[..., :self.rep.size()] 114 | return activations 115 | 116 | @export 117 | class EMLPBlock(nn.Module): 118 | """ Basic building block of EMLP consisting of G-Linear, biLinear, 119 | and gated nonlinearity. """ 120 | def __init__(self,rep_in,rep_out): 121 | super().__init__() 122 | self.linear = Linear(rep_in,gated(rep_out)) 123 | self.bilinear = BiLinear(gated(rep_out),gated(rep_out)) 124 | self.nonlinearity = GatedNonlinearity(rep_out) 125 | 126 | def forward(self,x): 127 | lin = self.linear(x) 128 | preact =self.bilinear(lin)+lin 129 | return self.nonlinearity(preact) 130 | 131 | @export 132 | class EMLP(nn.Module): 133 | """ Equivariant MultiLayer Perceptron. 134 | If the input ch argument is an int, uses the hands off uniform_rep heuristic. 135 | If the ch argument is a representation, uses this representation for the hidden layers. 136 | Individual layer representations can be set explicitly by using a list of ints or a list of 137 | representations, rather than use the same for each hidden layer. 
138 | 139 | Args: 140 | rep_in (Rep): input representation 141 | rep_out (Rep): output representation 142 | group (Group): symmetry group 143 | ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers 144 | num_layers (int): number of hidden layers 145 | 146 | Returns: 147 | Module: the EMLP PyTorch module.""" 148 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3): 149 | super().__init__() 150 | logging.info("Initing EMLP (PyTorch)") 151 | self.rep_in = rep_in(group) 152 | self.rep_out = rep_out(group) 153 | 154 | self.G=group 155 | # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps 156 | if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]#[uniform_rep(ch,group) for _ in range(num_layers)] 157 | elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)] 158 | else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch] 159 | #assert all((not rep.G is None) for rep in middle_layers[0].reps) 160 | reps = [self.rep_in]+middle_layers 161 | #logging.info(f"Reps: {reps}") 162 | self.network = nn.Sequential( 163 | *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])], 164 | Linear(reps[-1],self.rep_out) 165 | ) 166 | def forward(self,x): 167 | return self.network(x) 168 | 169 | class Swish(nn.Module): 170 | def forward(self,x): 171 | return x.sigmoid()*x 172 | 173 | def MLPBlock(cin,cout): 174 | return nn.Sequential(nn.Linear(cin,cout),Swish()) # (BatchNorm variant disabled) 175 | 176 | @export 177 | class MLP(nn.Module): 178 | """ Standard baseline MLP. Representations and group are used for shapes only. """ 179 | def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3): 180 | super().__init__() 181 | self.rep_in = rep_in(group) 182 | self.rep_out = rep_out(group) 183 | self.G = group 184 | chs = [self.rep_in.size()] + num_layers*[ch] 185 | cout = self.rep_out.size() 186 | logging.info("Initing MLP") 187 | self.net = nn.Sequential( 188 | *[MLPBlock(cin,cout) for cin,cout in zip(chs,chs[1:])], 189 | nn.Linear(chs[-1],cout) 190 | ) 191 | 192 | def forward(self,x): 193 | y = self.net(x) 194 | return y 195 | 196 | @export 197 | class Standardize(nn.Module): 198 | """ A convenience module to wrap a given module, normalize its input 199 | by some dataset x mean and std stats, and unnormalize its output by 200 | the dataset y mean and std stats.
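    For example (a hedged sketch, shown with the objax backend whose datasets use
    this stats convention):

        from emlp.nn import EMLP, Standardize
        from emlp.datasets import ParticleInteraction

        ds = ParticleInteraction(512)
        model = Standardize(EMLP(ds.rep_in, ds.rep_out, group=ds.symmetry), ds.stats)
        yhat = model(ds.X[:16], training=False)  # input normalized, output unnormalized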
201 | 202 | Args: 203 | model (Module): model to wrap 204 | ds_stats ((μx,σx,μy,σy) or (μx,σx)): tuple of the normalization stats 205 | 206 | Returns: 207 | Module: Wrapped model with input normalization (and output unnormalization)""" 208 | def __init__(self,model,ds_stats): 209 | super().__init__() 210 | self.model = model 211 | self.ds_stats=ds_stats 212 | 213 | def forward(self,x,training): 214 | if len(self.ds_stats)==2: 215 | muin,sin = self.ds_stats 216 | return self.model((x-muin)/sin,training=training) 217 | else: 218 | muin,sin,muout,sout = self.ds_stats 219 | y = sout*self.model((x-muin)/sin,training=training)+muout 220 | return y 221 | -------------------------------------------------------------------------------- /emlp/reps/__init__.py: -------------------------------------------------------------------------------- 1 | # from .product_sum_reps import SumRep,DeferredSumRep,ProductRep,DeferredProductRep,DirectProduct 2 | # __all__=["SumRep","DeferredSumRep","ProductRep","DeferredProductRep","DirectProduct"] 3 | import importlib 4 | import pkgutil 5 | __all__ = [] 6 | for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): 7 | module = importlib.import_module('.'+module_name,package=__name__) 8 | try: 9 | globals().update({k: getattr(module, k) for k in module.__all__}) 10 | __all__ += module.__all__ 11 | except AttributeError: continue 12 | 13 | # concatenate __all__ from each of the modules -------------------------------------------------------------------------------- /emlp/reps/linear_operators.py: -------------------------------------------------------------------------------- 1 | from .linear_operator_base import LinearOperator,Lazy 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from jax import jit 5 | import jax 6 | from functools import reduce 7 | 8 | product = lambda c: reduce(lambda a,b:a*b,c) 9 | 10 | def lazify(x): 11 | if isinstance(x,LinearOperator): return x 12 | elif isinstance(x,(jnp.ndarray,np.ndarray)): return Lazy(x) 13 | else: raise NotImplementedError 14 | 15 | def densify(x): 16 | if isinstance(x,LinearOperator): return x.to_dense() 17 | elif isinstance(x,(jnp.ndarray,np.ndarray)): return x 18 | else: raise NotImplementedError 19 | 20 | class I(LinearOperator): 21 | def __init__(self,d): 22 | shape = (d,d) 23 | super().__init__(None, shape) 24 | def _matmat(self,V): #(c,k) 25 | return V 26 | def _matvec(self,V): 27 | return V 28 | def _adjoint(self): 29 | return self 30 | def invT(self): 31 | return self 32 | 33 | class LazyKron(LinearOperator): 34 | 35 | def __init__(self,Ms): 36 | self.Ms = Ms 37 | shape = product([Mi.shape[0] for Mi in Ms]), product([Mi.shape[1] for Mi in Ms]) 38 | #self.dtype=Ms[0].dtype 39 | super().__init__(None,shape) 40 | 41 | def _matvec(self,v): 42 | return self._matmat(v).reshape(-1) 43 | def _matmat(self,v): 44 | ev = v.reshape(*[Mi.shape[-1] for Mi in self.Ms],-1) 45 | for i,M in enumerate(self.Ms): 46 | ev_front = jnp.moveaxis(ev,i,0) 47 | Mev_front = (M@ev_front.reshape(M.shape[-1],-1)).reshape(M.shape[0],*ev_front.shape[1:]) 48 | ev = jnp.moveaxis(Mev_front,0,i) 49 | return ev.reshape(self.shape[0],ev.shape[-1]) 50 | def _adjoint(self): 51 | return LazyKron([Mi.T for Mi in self.Ms]) 52 | def invT(self): 53 | return LazyKron([M.invT() for M in self.Ms]) 54 | def to_dense(self): 55 | Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms] 56 | return reduce(jnp.kron,Ms) 57 | def __new__(cls,Ms): 58 | if len(Ms)==1: return Ms[0] 59 | return super().__new__(cls) 60 | 61 | #@jit 62 | def 
kronsum(A,B): 63 | return jnp.kron(A,jnp.eye(B.shape[-1])) + jnp.kron(jnp.eye(A.shape[-1]),B) 64 | 65 | 66 | class LazyKronsum(LinearOperator): 67 | 68 | def __init__(self,Ms): 69 | self.Ms = Ms 70 | shape = product([Mi.shape[0] for Mi in Ms]), product([Mi.shape[1] for Mi in Ms]) 71 | #self.dtype=Ms[0].dtype 72 | dtype=jnp.dtype('float32') 73 | super().__init__(dtype,shape) 74 | 75 | def _matvec(self,v): 76 | return self._matmat(v).reshape(-1) 77 | 78 | def _matmat(self,v): 79 | ev = v.reshape(*[Mi.shape[-1] for Mi in self.Ms],-1) 80 | out = 0*ev 81 | for i,M in enumerate(self.Ms): 82 | ev_front = jnp.moveaxis(ev,i,0) 83 | Mev_front = (M@ev_front.reshape(M.shape[-1],-1)).reshape(M.shape[0],*ev_front.shape[1:]) 84 | out += jnp.moveaxis(Mev_front,0,i) 85 | return out.reshape(self.shape[0],ev.shape[-1]) 86 | 87 | def _adjoint(self): 88 | return LazyKronsum([Mi.T for Mi in self.Ms]) 89 | def to_dense(self): 90 | Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms] 91 | return reduce(kronsum,Ms) 92 | def __new__(cls,Ms): 93 | if len(Ms)==1: return Ms[0] 94 | return super().__new__(cls) 95 | ## could also be implemented as follows, but the fusing the sum into a single linearOperator is faster 96 | # def lazy_kronsum(Ms): 97 | # n = len(Ms) 98 | # lprod = np.cumprod([1]+[mi.shape[-1] for mi in Ms]) 99 | # rprod = np.cumprod([1]+[mi.shape[-1] for mi in reversed(Ms)])[::-1] 100 | # return reduce(lambda a,b: a+b,[lazy_kron([I(lprod[i]),Mi,I(rprod[i+1])]) for i,Mi in enumerate(Ms)]) 101 | 102 | 103 | class LazyJVP(LinearOperator): 104 | def __init__(self,operator_fn,X,TX): 105 | self.shape = operator_fn(X).shape 106 | self.vjp = lambda v: jax.jvp(lambda x: operator_fn(x)@v,[X],[TX])[1] 107 | self.vjp_T = lambda v: jax.jvp(lambda x: operator_fn(x).T@v,[X],[TX])[1] 108 | self.dtype=jnp.dtype('float32') 109 | def _matmat(self,v): 110 | return self.vjp(v) 111 | def _matvec(self,v): 112 | return self.vjp(v) 113 | def _rmatmat(self,v): 114 | return self.vjp_T(v) 115 | 116 | 117 | class ConcatLazy(LinearOperator): 118 | """ Produces a linear operator equivalent to concatenating 119 | a collection of matrices Ms along axis=0 """ 120 | def __init__(self,Ms): 121 | self.Ms = Ms 122 | assert all(M.shape[0]==Ms[0].shape[0] for M in Ms),\ 123 | f"Trying to concatenate matrices of different sizes {[M.shape for M in Ms]}" 124 | shape = (sum(M.shape[0] for M in Ms),Ms[0].shape[1]) 125 | super().__init__(None,shape) 126 | 127 | def _matmat(self,V): 128 | return jnp.concatenate([M@V for M in self.Ms],axis=0) 129 | def _rmatmat(self,V): 130 | Vs = jnp.split(V,len(self.Ms)) 131 | return sum([self.Ms[i].T@Vs[i] for i in range(len(self.Ms))]) 132 | def to_dense(self): 133 | dense_Ms = [M.to_dense() if isinstance(M,LinearOperator) else M for M in self.Ms] 134 | return jnp.concatenate(dense_Ms,axis=0) 135 | 136 | class LazyDirectSum(LinearOperator): 137 | def __init__(self,Ms,multiplicities=None): 138 | self.Ms = [jax.device_put(M.astype(np.float32)) if isinstance(M,(np.ndarray)) else M for M in Ms] 139 | self.multiplicities = [1 for M in Ms] if multiplicities is None else multiplicities 140 | shape = (sum(Mi.shape[0]*c for Mi,c in zip(Ms,multiplicities)), 141 | sum(Mi.shape[0]*c for Mi,c in zip(Ms,multiplicities))) 142 | super().__init__(None,shape) 143 | #self.dtype=Ms[0].dtype 144 | #self.dtype=jnp.dtype('float32') 145 | 146 | def _matvec(self,v): 147 | return lazy_direct_matmat(v,self.Ms,self.multiplicities) 148 | 149 | def _matmat(self,v): # (n,k) 150 | return 
lazy_direct_matmat(v,self.Ms,self.multiplicities) 151 | def _adjoint(self): 152 | return LazyDirectSum([Mi.T for Mi in self.Ms]) 153 | def invT(self): 154 | return LazyDirectSum([M.invT() for M in self.Ms]) 155 | def to_dense(self): 156 | Ms_all = [M for M,c in zip(self.Ms,self.multiplicities) for _ in range(c)] 157 | Ms_all = [Mi.to_dense() if isinstance(Mi,LinearOperator) else Mi for Mi in Ms_all] 158 | return jax.scipy.linalg.block_diag(*Ms_all) 159 | # def __new__(cls,Ms,multiplicities=None): 160 | # if len(Ms)==1 and multiplicities is None: return Ms[0] 161 | # return super().__new__(cls) 162 | 163 | def lazy_direct_matmat(v,Ms,mults): 164 | n = v.shape[0] 165 | k = v.shape[1] if len(v.shape)>1 else 1 166 | i=0 167 | y = [] 168 | for M, multiplicity in zip(Ms,mults): 169 | i_end = i+multiplicity*M.shape[-1] 170 | elems = M@v[i:i_end].T.reshape(k*multiplicity,M.shape[-1]).T 171 | y.append(elems.T.reshape(k,multiplicity*M.shape[0]).T) 172 | i = i_end 173 | y = jnp.concatenate(y,axis=0) #concatenate over rep axis 174 | return y 175 | 176 | 177 | class LazyPerm(LinearOperator): 178 | def __init__(self,perm): 179 | self.perm=perm 180 | shape = (len(perm),len(perm)) 181 | super().__init__(None,shape) 182 | 183 | def _matmat(self,V): 184 | return V[self.perm] 185 | def _matvec(self,V): 186 | return V[self.perm] 187 | def _adjoint(self): 188 | return LazyPerm(np.argsort(self.perm)) 189 | def invT(self): 190 | return self 191 | 192 | class LazyShift(LinearOperator): 193 | def __init__(self,n,k=1): 194 | self.k=k 195 | shape = (n,n) 196 | super().__init__(None,shape) 197 | 198 | def _matmat(self,V): #(c,k) #Still needs to be tested?? 199 | return jnp.roll(V,self.k,axis=0) 200 | def _matvec(self,V): 201 | return jnp.roll(V,self.k,axis=0) 202 | def _adjoint(self): 203 | return LazyShift(self.shape[0],-self.k) 204 | def invT(self): 205 | return self 206 | 207 | class SwapMatrix(LinearOperator): 208 | def __init__(self,swaprows,n): 209 | self.swaprows=swaprows 210 | shape = (n,n) 211 | super().__init__(None,shape) 212 | def _matmat(self,V): #(c,k) 213 | V = jax.ops.index_update(V, jax.ops.index[self.swaprows], V[self.swaprows[::-1]]) 214 | return V 215 | def _matvec(self,V): 216 | return self._matmat(V) 217 | def _adjoint(self): 218 | return self 219 | def invT(self): 220 | return self 221 | 222 | class Rot90(LinearOperator): 223 | def __init__(self,n,k): 224 | shape = (n*n,n*n) 225 | self.n=n 226 | self.k = k 227 | super().__init__(None,shape) 228 | def _matmat(self,V): #(c,k) 229 | return jnp.rot90(V.reshape((self.n,self.n,-1)),self.k).reshape(V.shape) 230 | def _matvec(self,V): 231 | return jnp.rot90(V.reshape((self.n,self.n,-1)),self.k).reshape(V.shape) 232 | def invT(self): 233 | return self -------------------------------------------------------------------------------- /emlp/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | class Named(type): 4 | def __str__(self): 5 | return self.__name__ 6 | def __repr__(self): 7 | return self.__name__ 8 | 9 | def export(fn): 10 | mod = sys.modules[fn.__module__] 11 | if hasattr(mod, '__all__'): 12 | mod.__all__.append(fn.__name__) 13 | else: 14 | mod.__all__ = [fn.__name__] 15 | return fn -------------------------------------------------------------------------------- /experiments/data_efficiency.py: -------------------------------------------------------------------------------- 1 | from emlp.nn import MLP,EMLP#,LinearBNSwish 2 | from emlp.datasets import O5Synthetic, ParticleInteraction, 
Inertia 3 | from emlp.reps import T,Scalar 4 | from emlp.groups import SO, O, Trivial, O13, SO13, SO13p 5 | from oil.tuning.study import train_trial, Study 6 | 7 | from oil.tuning.args import argupdated_config 8 | from emlp.experiments.train_regression import makeTrainer 9 | import copy 10 | import emlp.datasets 11 | 12 | if __name__=="__main__": 13 | Trial = train_trial(makeTrainer) 14 | config_spec = copy.deepcopy(makeTrainer.__kwdefaults__) 15 | name = 'data_efficiency_nobn' 16 | config_spec['ndata'] = 30000+5000+1000 17 | config_spec['log_level'] = 'warning' 18 | # Run MLP baseline on datasets 19 | config_spec.update({ 20 | 'dataset':ParticleInteraction,#[O5Synthetic,Inertia,ParticleInteraction], 21 | 'network':MLP,'aug':[False,True], 22 | 'num_epochs':(lambda cfg: min(int(30*30000/cfg['split']['train']),1000)), 23 | 'split':{'train':[30,100,300,1000,3000,10000,30000],'test':5000,'val':1000}, 24 | }) 25 | config_spec = argupdated_config(config_spec,namespace=emlp.datasets) 26 | name = f"{name}_{config_spec['dataset']}" 27 | thestudy = Study(Trial,{},study_name=name,base_log_dir=config_spec['trainer_config'].get('log_dir',None)) 28 | thestudy.run(num_trials=-3,new_config_spec=config_spec,ordered=True) 29 | 30 | # Now run the EMLP (with appropriate group) on the datasets 31 | config_spec['network']=EMLP 32 | config_spec['aug'] = False 33 | groups = {O5Synthetic:[SO(5),O(5)],Inertia:[SO(3),O(3)],ParticleInteraction:[SO13p(),SO13(),O13()]} 34 | config_spec['net_config']['group'] = groups[config_spec['dataset']] 35 | thestudy.run(num_trials=-3,new_config_spec=config_spec,ordered=True) 36 | print(thestudy.results_df()) 37 | -------------------------------------------------------------------------------- /experiments/datasets/batchnorm.py: -------------------------------------------------------------------------------- 1 | from emlp.utils import export 2 | 3 | import copy 4 | import numpy as np 5 | import objax.nn as nn 6 | import jax 7 | from jax import jit 8 | import jax.numpy as jnp 9 | from emlp.reps import Scalar 10 | from emlp.reps.product_sum_reps import SumRep 11 | import logging 12 | import objax.functional as F 13 | from functools import partial 14 | import objax 15 | from functools import lru_cache as cache 16 | 17 | 18 | 19 | @cache(maxsize=None) 20 | def gate_indices(sumrep): #TODO: add support for mixed_tensors 21 | """ Indices for scalars, and also additional scalar gates 22 | added by gated(sumrep)""" 23 | assert isinstance(sumrep,SumRep), f"unexpected type for gate indices {type(sumrep)}" 24 | channels = sumrep.size() 25 | perm = sumrep.perm 26 | indices = np.arange(channels) 27 | num_nonscalars = 0 28 | i=0 29 | for rep in sumrep: 30 | if rep!=Scalar and not rep.is_regular: 31 | indices[perm[i:i+rep.size()]] = channels+num_nonscalars 32 | num_nonscalars+=1 33 | i+=rep.size() 34 | return indices 35 | 36 | @cache(maxsize=None) 37 | def scalar_mask(sumrep): 38 | channels = sumrep.size() 39 | perm = sumrep.perm 40 | mask = np.ones(channels)>0 41 | i=0 42 | for rep in sumrep: 43 | if rep!=Scalar: mask[perm[i:i+rep.size()]] = False 44 | i+=rep.size() 45 | return mask 46 | 47 | @cache(maxsize=None) 48 | def regular_mask(sumrep): 49 | channels, perm = sumrep.size(), sumrep.perm 50 | mask = np.ones(channels)<0 51 | i=0 52 | for rep in sumrep: 53 | if rep.is_regular: mask[perm[i:i+rep.size()]] = True 54 | i+=rep.size() 55 | return mask 56 | 57 | @export 58 | class TensorBN(nn.BatchNorm0D): #TODO: add support for mixed tensors. 59 | """ Equivariant Batchnorm for tensor representations.
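    Scalar channels are located with the scalar_mask helper above; a hedged sketch:

        from emlp.reps import T
        from emlp.groups import SO
        rep = (2*T(0)+T(1))(SO(3))  # 2 scalars + 1 vector = 5 channels
        scalar_mask(rep)            # ~ array([True, True, False, False, False]), up to rep.perm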
60 | Applies BN on Scalar channels and Mean only BN on others """ 61 | def __init__(self,rep): 62 | super().__init__(rep.size(),momentum=0.9) 63 | self.rep=rep 64 | def __call__(self,x,training): #TODO: support elementwise for regular reps 65 | #return x #DISABLE BN, harms performance!! !! 66 | smask = jax.device_put(scalar_mask(self.rep)) 67 | if training: 68 | m = x.mean(self.redux, keepdims=True) 69 | v = (x ** 2).mean(self.redux, keepdims=True) - m ** 2 70 | v = jnp.where(smask,v,ragged_gather_scatter((x ** 2).mean(self.redux),self.rep)) #in non scalar indices, divide by sum squared 71 | self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value) 72 | self.running_var.value += (1 - self.momentum) * (v - self.running_var.value) 73 | else: 74 | m, v = self.running_mean.value, self.running_var.value 75 | y = jnp.where(smask,self.gamma.value * (x - m) * F.rsqrt(v + self.eps) + self.beta.value,x*F.rsqrt(v+self.eps))#(x-m)*F.rsqrt(v + self.eps)) 76 | return y # switch to or (x-m) 77 | 78 | class MaskBN(nn.BatchNorm0D): 79 | """ Equivariant Batchnorm for tensor representations. 80 | Applies BN on Scalar channels and Mean only BN on others """ 81 | def __init__(self,ch): 82 | super().__init__(ch,momentum=0.9) 83 | 84 | def __call__(self,vals,mask,training=True): 85 | sum_dims = list(range(len(vals.shape[:-1]))) 86 | x_or_zero = jnp.where(mask[...,None],vals,0*vals) 87 | if training: 88 | num_valid = mask.sum(sum_dims) 89 | m = x_or_zero.sum(sum_dims)/num_valid 90 | v = (x_or_zero ** 2).sum(sum_dims)/num_valid - m ** 2 91 | self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value) 92 | self.running_var.value += (1 - self.momentum) * (v - self.running_var.value) 93 | else: 94 | m, v = self.running_mean.value, self.running_var.value 95 | return ((x_or_zero-m)*self.gamma.value*F.rsqrt(v + self.eps) + self.beta.value,mask) 96 | 97 | class TensorMaskBN(nn.BatchNorm0D): #TODO find discrepancies with pytorch version 98 | """ Equivariant Batchnorm for tensor representations. 
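    Combines the per-rep treatment of TensorBN with the padding mask of MaskBN; a hedged
    usage sketch (vals: (batch, n, channels) with rep channels last, mask: (batch, n) bool):

        bn = TensorMaskBN(rep)
        y, mask = bn(vals, mask, training=True)  # statistics taken over unmasked entries only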
99 | Applies BN on Scalar channels and Mean only BN on others """ 100 | def __init__(self,rep): 101 | super().__init__(rep.size(),momentum=0.9) 102 | self.rep=rep 103 | def __call__(self,vals,mask,training): 104 | sum_dims = list(range(len(vals.shape[:-1]))) 105 | x_or_zero = jnp.where(mask[...,None],vals,0*vals) 106 | smask = jax.device_put(scalar_mask(self.rep)) 107 | if training: 108 | num_valid = mask.sum(sum_dims) 109 | m = x_or_zero.sum(sum_dims)/num_valid 110 | x2 = (x_or_zero ** 2).sum(sum_dims)/num_valid 111 | v = x2 - m ** 2 112 | v = jnp.where(smask,v,ragged_gather_scatter(x2,self.rep)) 113 | self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value) 114 | self.running_var.value += (1 - self.momentum) * (v - self.running_var.value) 115 | else: 116 | m, v = self.running_mean.value, self.running_var.value 117 | y = jnp.where(smask,self.gamma.value * (x_or_zero - m) * F.rsqrt(v + self.eps) + \ 118 | self.beta.value,x_or_zero*F.rsqrt(v+self.eps)) 119 | return y,mask 120 | 121 | # @partial(jit,static_argnums=(1,)) 122 | # def ragged_gather_scatter(x,x_rep): 123 | # y = [] 124 | # i=0 125 | # for rep in x_rep.reps: # sum -> mean 126 | # y.append(x[i:i+rep.size()].mean(keepdims=True).repeat(rep.size(),axis=-1)) 127 | # i+=rep.size() 128 | # return jnp.concatenate(y,-1) 129 | 130 | @partial(jit,static_argnums=(1,)) 131 | def ragged_gather_scatter(x,x_rep): 132 | perm = x_rep.argsort() 133 | invperm = np.argsort(perm) 134 | x_sorted = x[perm] 135 | i=0 136 | y=[] 137 | for rep, multiplicity in x_rep.multiplicities().items(): 138 | i_end = i+multiplicity*rep.size() 139 | y.append(x_sorted[i:i_end].reshape(multiplicity,rep.size()).mean(-1,keepdims=True).repeat(rep.size(),axis=-1).reshape(-1)) 140 | i = i_end 141 | return jnp.concatenate(y)[invperm] -------------------------------------------------------------------------------- /experiments/depreciated/train_cube_simple.py: -------------------------------------------------------------------------------- 1 | 2 | from emlp.models.datasets import InvertedCube 3 | from emlp.solver.groups import Cube 4 | from emlp.models.mlp import MLP,EMLP,Standardize 5 | from emlp.models.model_trainer import ClassifierPlus 6 | from emlp.solver.representation import T 7 | from slax.model_trainers import Classifier 8 | from torch.utils.data import DataLoader 9 | from oil.utils.utils import cosLr, islice, export,FixedNumpySeed,FixedPytorchSeed 10 | from slax.utils import LoaderTo 11 | from oil.tuning.study import train_trial 12 | from oil.datasetup.datasets import split_dataset 13 | from oil.tuning.args import argupdated_config 14 | import objax 15 | import logging 16 | import emlp.models 17 | 18 | # intermediate_rep1 = (100*T(0)+5*T(1))#+T(2)) 19 | # intermediate_rep2 = 100*T(0)+10*T(1) 20 | #middle_reps = [intermediate_rep1,intermediate_rep2,intermediate_rep1] 21 | rep = 138*T(0)+23*T(1)+T(2) 22 | def makeTrainer(*,network=EMLP,num_epochs=500,seed=2020,aug=False, 23 | bs=50,lr=1e-3,device='cuda', 24 | net_config={'num_layers':3,'ch':rep,'group':Cube()},log_level='info', 25 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':50}},save=False): 26 | levels = {'critical': logging.CRITICAL,'error': logging.ERROR, 27 | 'warn': logging.WARNING,'warning': logging.WARNING, 28 | 'info': logging.INFO,'debug': logging.DEBUG} 29 | logging.getLogger().setLevel(levels[log_level]) 30 | # Prep the datasets splits, model, and dataloaders 31 | with FixedNumpySeed(seed),FixedPytorchSeed(seed): 32 | datasets =
{'train':InvertedCube(train=True),'test':InvertedCube(train=False)} 33 | model = Standardize(network(datasets['train'].rep_in,datasets['train'].rep_out,**net_config),datasets['train'].stats) 34 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'), 35 | num_workers=0,pin_memory=False)) for k,v in datasets.items()} 36 | dataloaders['Train'] = dataloaders['train'] 37 | opt_constr = objax.optimizer.Adam 38 | lr_sched = lambda e: lr*cosLr(num_epochs)(e) 39 | return ClassifierPlus(model,dataloaders,opt_constr,lr_sched,**trainer_config) 40 | 41 | if __name__ == "__main__": 42 | Trial = train_trial(makeTrainer) 43 | Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.solver.groups,emlp.models.datasets,emlp.models.mlp))) -------------------------------------------------------------------------------- /experiments/depreciated/train_rubiks.py: -------------------------------------------------------------------------------- 1 | 2 | from emlp.models.datasets import BrokenRubiksCube,BrokenRubiksCube2x2 3 | from emlp.solver.groups import RubiksCube,RubiksCube2x2 4 | from emlp.models.mlp import MLP,EMLP,Standardize 5 | from emlp.models.model_trainer import ClassifierPlus 6 | from emlp.solver.representation import T 7 | from slax.model_trainers import Classifier 8 | from torch.utils.data import DataLoader 9 | from oil.utils.utils import cosLr, islice, export,FixedNumpySeed,FixedPytorchSeed 10 | from slax.utils import LoaderTo 11 | from oil.tuning.study import train_trial 12 | from oil.datasetup.datasets import split_dataset 13 | from oil.tuning.args import argupdated_config 14 | import objax 15 | import logging 16 | import emlp.models 17 | 18 | # intermediate_rep1 = (100*T(0)+5*T(1))#+T(2)) 19 | rep = 50*T(0)+5*T(1)+T(2) 20 | #middle_reps = [intermediate_rep1,intermediate_rep2,intermediate_rep1] 21 | def makeTrainer(*,network=EMLP,num_epochs=500,seed=2020,aug=False, 22 | bs=50,lr=1e-3,device='cuda', 23 | net_config={'num_layers':3,'ch':rep,'group':RubiksCube2x2()},log_level='info', 24 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':50}},save=False): 25 | levels = {'critical': logging.CRITICAL,'error': logging.ERROR, 26 | 'warn': logging.WARNING,'warning': logging.WARNING, 27 | 'info': logging.INFO,'debug': logging.DEBUG} 28 | logging.getLogger().setLevel(levels[log_level]) 29 | # Prep the datasets splits, model, and dataloaders 30 | with FixedNumpySeed(seed),FixedPytorchSeed(seed): 31 | datasets = {'train':BrokenRubiksCube2x2(train=False),'test':BrokenRubiksCube2x2(train=False)} 32 | model = Standardize(network(datasets['train'].rep_in,datasets['train'].rep_out,**net_config),datasets['train'].stats) 33 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'), 34 | num_workers=0,pin_memory=False)) for k,v in datasets.items()} 35 | dataloaders['Train'] = dataloaders['train'] 36 | opt_constr = objax.optimizer.Adam 37 | lr_sched = lambda e: lr*cosLr(num_epochs)(e)#*min(1,e/(num_epochs/10)) 38 | return ClassifierPlus(model,dataloaders,opt_constr,lr_sched,**trainer_config) 39 | 40 | if __name__ == "__main__": 41 | Trial = train_trial(makeTrainer) 42 | Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.solver.groups,emlp.models.datasets,emlp.models.mlp))) -------------------------------------------------------------------------------- /experiments/depreciated/train_tagging.py: -------------------------------------------------------------------------------- 1 | from emlp.models.mlp import 
MLP,EMLP#,LinearBNSwish 2 | from emlp.models.datasets import Fr,ParticleInteraction 3 | import jax.numpy as jnp 4 | import jax 5 | from emlp.solver.representation import T,Scalar,Matrix,Vector,Quad,repsize 6 | from emlp.solver.groups import SO,O,Trivial,Lorentz,O13,SO13,SO13p 7 | from emlp.models.mlp import EMLP,LieLinear,Standardize,EMLP2 8 | from emlp.models.model_trainer import RegressorPlus 9 | import itertools 10 | import numpy as np 11 | import torch 12 | from emlp.models.datasets import Inertia,Fr,ParticleInteraction 13 | from emlp.models.particle_dataset import TopTagging,collate_fn 14 | import objax 15 | import torch 16 | from torch.utils.data import DataLoader 17 | from oil.utils.utils import cosLr, islice, export,FixedNumpySeed,FixedPytorchSeed 18 | from slax.utils import LoaderTo 19 | from oil.tuning.study import train_trial 20 | from oil.datasetup.datasets import split_dataset 21 | from oil.tuning.args import argupdated_config 22 | from slax.model_trainers import Classifier,Trainer 23 | from functools import partial 24 | import torch.nn as nn 25 | import logging 26 | import emlp.models 27 | from emlp.models.pointconv_base import ResNet 28 | from emlp.models.model_trainer import ClassifierPlus 29 | 30 | def makeTrainer(*,network=ResNet,num_epochs=5,seed=2020,aug=False, 31 | bs=30,lr=1e-3,device='cuda',split={'train':-1,'val':10000}, 32 | net_config={'k':512,'num_layers':4},log_level='info', 33 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':.2}},save=False): 34 | # Prep the datasets splits, model, and dataloaders 35 | datasets = {split:TopTagging(split=split) for split in ['train','val']} 36 | model = network(4,2,**net_config) 37 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=bs,shuffle=(k=='train'), 38 | num_workers=0,pin_memory=False,collate_fn=collate_fn,drop_last=True)) for k,v in datasets.items()} 39 | dataloaders['Train'] = islice(dataloaders['train'],0,None,10) #for logging, subsample dataset by 10x 40 | #equivariance_test(model,dataloaders['train'],net_config['group']) 41 | opt_constr = objax.optimizer.Adam 42 | lr_sched = lambda e: lr*cosLr(num_epochs)(e) 43 | return ClassifierPlus(model,dataloaders,opt_constr,lr_sched,**trainer_config) 44 | 45 | if __name__ == "__main__": 46 | Trial = train_trial(makeTrainer) 47 | Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.solver.groups,emlp.models.datasets,emlp.models.mlp))) -------------------------------------------------------------------------------- /experiments/hnn.py: -------------------------------------------------------------------------------- 1 | from emlp.nn import MLP,EMLP,MLPH,EMLPH 2 | from emlp.groups import SO2eR3,O2eR3,DkeR3,Trivial 3 | from emlp.reps import Scalar 4 | from trainer.hamiltonian_dynamics import IntegratedDynamicsTrainer,DoubleSpringPendulum,hnn_trial 5 | from torch.utils.data import DataLoader 6 | from oil.utils.utils import cosLr,FixedNumpySeed,FixedPytorchSeed 7 | from trainer.utils import LoaderTo 8 | from oil.datasetup.datasets import split_dataset 9 | from oil.tuning.args import argupdated_config 10 | import torch.nn as nn 11 | import logging 12 | import emlp 13 | import emlp.reps 14 | import objax 15 | 16 | levels = {'critical': logging.CRITICAL,'error': logging.ERROR, 17 | 'warn': logging.WARNING,'warning': logging.WARNING, 18 | 'info': logging.INFO,'debug': logging.DEBUG} 19 | 20 | def makeTrainer(*,dataset=DoubleSpringPendulum,network=EMLPH,num_epochs=2000,ndata=5000,seed=2021,aug=False, 21 | 
bs=500,lr=3e-3,device='cuda',split={'train':500,'val':.1,'test':.1}, 22 | net_config={'num_layers':3,'ch':128,'group':O2eR3()},log_level='info', 23 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':.75},},#'early_stop_metric':'val_MSE'}, 24 | save=False,): 25 | 26 | logging.getLogger().setLevel(levels[log_level]) 27 | # Prep the datasets splits, model, and dataloaders 28 | with FixedNumpySeed(seed),FixedPytorchSeed(seed): 29 | base_ds = dataset(n_systems=ndata,chunk_len=5) 30 | datasets = split_dataset(base_ds,splits=split) 31 | if net_config['group'] is None: net_config['group']=base_ds.symmetry 32 | model = network(base_ds.rep_in,Scalar,**net_config) 33 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'), 34 | num_workers=0,pin_memory=False)) for k,v in datasets.items()} 35 | dataloaders['Train'] = dataloaders['train'] 36 | #equivariance_test(model,dataloaders['train'],net_config['group']) 37 | opt_constr = objax.optimizer.Adam 38 | lr_sched = lambda e: lr#*cosLr(num_epochs)(e)#*min(1,e/(num_epochs/10)) 39 | return IntegratedDynamicsTrainer(model,dataloaders,opt_constr,lr_sched,**trainer_config) 40 | 41 | if __name__ == "__main__": 42 | Trial = hnn_trial(makeTrainer) 43 | cfg,outcome = Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.groups,emlp.nn))) 44 | print(outcome) 45 | 46 | -------------------------------------------------------------------------------- /experiments/hnn_expt.py: -------------------------------------------------------------------------------- 1 | from oil.tuning.study import Study 2 | from emlp.nn import MLP,EMLP,MLPH,EMLPH 3 | from emlp.groups import SO2eR3,O2eR3,DkeR3,Trivial 4 | 5 | import copy 6 | from trainer.hamiltonian_dynamics import hnn_trial 7 | from hnn import makeTrainer 8 | 9 | if __name__=="__main__": 10 | Trial = hnn_trial(makeTrainer) 11 | config_spec = copy.deepcopy(makeTrainer.__kwdefaults__) 12 | name = "hnn_expt"#config_spec.pop('study_name') 13 | #name = f"{name}_{config_spec['dataset']}" 14 | thestudy = Study(Trial,{},study_name=name,base_log_dir=config_spec['trainer_config'].get('log_dir',None)) 15 | config_spec['network'] = EMLPH 16 | config_spec['net_config']['group'] = [O2eR3(),SO2eR3(),DkeR3(6),DkeR3(2)] 17 | thestudy.run(num_trials=-5,new_config_spec=config_spec,ordered=True) 18 | config_spec['network'] = MLPH 19 | config_spec['net_config']['group'] = None 20 | thestudy.run(num_trials=-3,new_config_spec=config_spec,ordered=True) 21 | print(thestudy.results_df()) 22 | -------------------------------------------------------------------------------- /experiments/neuralode.py: -------------------------------------------------------------------------------- 1 | from emlp.nn import MLP,EMLP,MLPH,EMLPH,EMLPode,MLPode#,LinearBNSwish 2 | from emlp.groups import SO2eR3,O2eR3,DkeR3,Trivial 3 | from trainer.hamiltonian_dynamics import IntegratedODETrainer,DoubleSpringPendulum,ode_trial 4 | from torch.utils.data import DataLoader 5 | from oil.utils.utils import cosLr, islice, FixedNumpySeed,FixedPytorchSeed 6 | from trainer.utils import LoaderTo 7 | from oil.tuning.study import train_trial 8 | from oil.datasetup.datasets import split_dataset 9 | from oil.tuning.args import argupdated_config 10 | import logging 11 | import emlp.nn 12 | import emlp.groups 13 | import objax 14 | 15 | levels = {'critical': logging.CRITICAL,'error': logging.ERROR, 16 | 'warn': logging.WARNING,'warning': logging.WARNING, 17 | 'info': logging.INFO,'debug': logging.DEBUG} 18 | 19 | def 
makeTrainer(*,dataset=DoubleSpringPendulum,network=EMLPode,num_epochs=2000,ndata=5000,seed=2021,aug=False, 20 | bs=500,lr=3e-3,device='cuda',split={'train':500,'val':.1,'test':.1}, 21 | net_config={'num_layers':3,'ch':128,'group':O2eR3()},log_level='warn', 22 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':.75},},#'early_stop_metric':'val_MSE'}, 23 | save=False,): 24 | 25 | logging.getLogger().setLevel(levels[log_level]) 26 | # Prep the datasets splits, model, and dataloaders 27 | with FixedNumpySeed(seed),FixedPytorchSeed(seed): 28 | base_ds = dataset(n_systems=ndata,chunk_len=5) 29 | datasets = split_dataset(base_ds,splits=split) 30 | if net_config['group'] is None: net_config['group']=base_ds.symmetry 31 | model = network(base_ds.rep_in,base_ds.rep_in,**net_config) 32 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'), 33 | num_workers=0,pin_memory=False)) for k,v in datasets.items()} 34 | dataloaders['Train'] = dataloaders['train'] 35 | #equivariance_test(model,dataloaders['train'],net_config['group']) 36 | opt_constr = objax.optimizer.Adam 37 | lr_sched = lambda e: lr#*cosLr(num_epochs)(e)#*min(1,e/(num_epochs/10)) 38 | return IntegratedODETrainer(model,dataloaders,opt_constr,lr_sched,**trainer_config) 39 | 40 | if __name__ == "__main__": 41 | Trial = ode_trial(makeTrainer) 42 | cfg,outcome = Trial(argupdated_config(makeTrainer.__kwdefaults__,namespace=(emlp.groups,emlp.nn))) 43 | print(outcome) 44 | -------------------------------------------------------------------------------- /experiments/neuralode_expt.py: -------------------------------------------------------------------------------- 1 | from emlp.nn import MLP,EMLP,MLPH,EMLPH,EMLPode,MLPode#,LinearBNSwish 2 | from emlp.groups import SO2eR3,O2eR3,DkeR3,Trivial 3 | from oil.tuning.study import Study 4 | 5 | import copy 6 | from trainer.hamiltonian_dynamics import ode_trial 7 | from neuralode import makeTrainer 8 | 9 | if __name__ == "__main__": 10 | Trial = ode_trial(makeTrainer) 11 | config_spec = copy.deepcopy(makeTrainer.__kwdefaults__) 12 | name = "ode_expt"#config_spec.pop('study_name') 13 | 14 | #name = f"{name}_{config_spec['dataset']}" 15 | thestudy = Study(Trial,{},study_name=name,base_log_dir=config_spec['trainer_config'].get('log_dir',None)) 16 | config_spec['network'] = EMLPode 17 | config_spec['net_config']['group'] = [O2eR3(),SO2eR3(),DkeR3(6),DkeR3(2)] 18 | thestudy.run(num_trials=-5,new_config_spec=config_spec,ordered=True) 19 | config_spec['network'] = MLPode 20 | config_spec['net_config']['group'] = None 21 | thestudy.run(num_trials=-3,new_config_spec=config_spec,ordered=True) 22 | print(thestudy.results_df()) -------------------------------------------------------------------------------- /experiments/train_regression.py: -------------------------------------------------------------------------------- 1 | from emlp.nn import MLP, EMLP, Standardize 2 | from trainer.model_trainer import RegressorPlus 3 | from torch.utils.data import DataLoader 4 | from oil.utils.utils import cosLr, FixedNumpySeed, FixedPytorchSeed 5 | from trainer.utils import LoaderTo 6 | from oil.tuning.study import train_trial 7 | from oil.datasetup.datasets import split_dataset 8 | from oil.tuning.args import argupdated_config 9 | import logging 10 | import emlp.nn 11 | import emlp.reps 12 | import emlp.groups 13 | import objax 14 | import emlp.datasets 15 | from emlp.datasets import Inertia,O5Synthetic,ParticleInteraction 16 | 17 | log_levels = {'critical': 
logging.CRITICAL,'error': logging.ERROR, 18 | 'warn': logging.WARNING,'warning': logging.WARNING, 19 | 'info': logging.INFO,'debug': logging.DEBUG} 20 | 21 | def makeTrainer(*,dataset=Inertia,network=EMLP,num_epochs=300,ndata=1000+2000,seed=2021,aug=False, 22 | bs=500,lr=3e-3,device='cuda',split={'train':-1,'val':1000,'test':1000}, 23 | net_config={'num_layers':3,'ch':384,'group':None},log_level='info', 24 | trainer_config={'log_dir':None,'log_args':{'minPeriod':.02,'timeFrac':.75}, 25 | 'early_stop_metric':'val_MSE'},save=False,): 26 | 27 | logging.getLogger().setLevel(log_levels[log_level]) 28 | # Prep the datasets splits, model, and dataloaders 29 | with FixedNumpySeed(seed),FixedPytorchSeed(seed): 30 | base_dataset = dataset(ndata) 31 | datasets = split_dataset(base_dataset,splits=split) 32 | if net_config['group'] is None: net_config['group']=base_dataset.symmetry 33 | model = network(base_dataset.rep_in,base_dataset.rep_out,**net_config) 34 | if aug: model = base_dataset.default_aug(model) 35 | model = Standardize(model,datasets['train'].stats) 36 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,len(v)),shuffle=(k=='train'), 37 | num_workers=0,pin_memory=False)) for k,v in datasets.items()} 38 | dataloaders['Train'] = dataloaders['train'] 39 | opt_constr = objax.optimizer.Adam 40 | lr_sched = lambda e: lr#*min(1,e/(num_epochs/10)) # constant lr; the commented-out factor would add learning rate warmup 41 | return RegressorPlus(model,dataloaders,opt_constr,lr_sched,**trainer_config) 42 | 43 | if __name__ == "__main__": 44 | cfg = argupdated_config(makeTrainer.__kwdefaults__, 45 | namespace=(emlp.groups,emlp.datasets,emlp.nn)) 46 | trainer = makeTrainer(**cfg) 47 | trainer.train(cfg['num_epochs']) 48 | -------------------------------------------------------------------------------- /experiments/trainer/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from oil.utils.utils import export 4 | from .trainer import Trainer 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | def cross_entropy(logprobs, targets): # mean negative log-likelihood of the target class 10 | ll = jnp.take_along_axis(logprobs, jnp.expand_dims(targets, axis=1), axis=1) 11 | ce = -jnp.mean(ll) 12 | return ce 13 | 14 | @export 15 | class Classifier(Trainer): 16 | """ Trainer subclass. Implements a cross-entropy loss and an accuracy 17 | metric evaluated over a full dataloader """ 18 | 19 | def loss(self,minibatch): 20 | """ Standard cross-entropy loss """ #TODO: support class weights 21 | x,y = minibatch 22 | logits = self.model(x,training=True) 23 | logp = jax.nn.log_softmax(logits) 24 | return cross_entropy(logp,y) 25 | 26 | def metrics(self,loader): 27 | acc = lambda mb: np.asarray(jax.device_get(jnp.mean(jnp.argmax(self.model.predict(mb[0]),axis=-1)==mb[1]))) 28 | return {'Acc':self.evalAverageMetrics(loader,acc)} 29 | 30 | @export 31 | class Regressor(Trainer): 32 | """ Trainer subclass. 
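Illustrative construction (hypothetical model and dataloaders dict, as in the experiment scripts above): Regressor(model, dataloaders, objax.optimizer.Adam, lr_sched=lambda e: 1e-3).train(100).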
Implements an MSE loss and an MSE 33 | metric evaluated over a full dataloader """ 34 | 35 | def loss(self,minibatch): 36 | """ Standard mean-squared-error (MSE) loss """ 37 | x,y = minibatch 38 | mse = jnp.mean((self.model(x,training=True)-y)**2) 39 | return mse 40 | 41 | def metrics(self,loader): 42 | mse = lambda mb: np.asarray(jax.device_get(jnp.mean((self.model.predict(mb[0])-mb[1])**2))) 43 | return {'MSE':self.evalAverageMetrics(loader,mse)} -------------------------------------------------------------------------------- /experiments/trainer/model_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from oil.utils.utils import export 4 | import jax 5 | from jax import vmap 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import objax 9 | from .classifier import Regressor,Classifier 10 | from functools import partial 11 | from itertools import islice 12 | 13 | def rel_err(a,b): 14 | return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean()))# 15 | 16 | def scale_adjusted_rel_err(a,b,g): 17 | return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean())+jnp.abs(g-jnp.eye(g.shape[-1])).mean()) 18 | 19 | def equivariance_err(model,mb,group=None): 20 | x,y = mb 21 | group = model.model.G if group is None else group 22 | gs = group.samples(x.shape[0]) 23 | rho_gin = vmap(model.model.rep_in.rho_dense)(gs) 24 | rho_gout = vmap(model.model.rep_out.rho_dense)(gs) 25 | y1 = model.predict((rho_gin@x[...,None])[...,0]) 26 | y2 = (rho_gout@model.predict(x)[...,None])[...,0] 27 | return np.asarray(scale_adjusted_rel_err(y1,y2,gs)) 28 | 29 | @export 30 | class RegressorPlus(Regressor): 31 | """ Trainer subclass. Implements an MSE loss, an MSE metric, 32 | and equivariance-error logging; jit-compiles the loss and gradients """ 33 | def __init__(self,model,*args,**kwargs): 34 | super().__init__(model,*args,**kwargs) 35 | fastloss = objax.Jit(self.loss,model.vars()) 36 | self.gradvals = objax.Jit(objax.GradValues(fastloss,model.vars()),model.vars()) 37 | self.model.predict = objax.Jit(objax.ForceArgs(model.__call__,training=False),model.vars()) 38 | #self.model.predict = lambda x: self.model(x,training=False) 39 | def loss(self,minibatch): 40 | """ Standard mean-squared-error (MSE) loss """ 41 | x,y = minibatch 42 | mse = jnp.mean((self.model(x,training=True)-y)**2)#jnp.mean(jnp.abs(self.model(x,training=True)-y)) 43 | return mse 44 | 45 | def metrics(self,loader): 46 | mse = lambda mb: np.asarray(jax.device_get(jnp.mean((self.model.predict(mb[0])-mb[1])**2))) 47 | return {'MSE':self.evalAverageMetrics(loader,mse)} 48 | def logStuff(self, step, minibatch=None): 49 | metrics = {} 50 | metrics['test_equivar_err'] = self.evalAverageMetrics(islice(self.dataloaders['test'],0,None,5), 51 | partial(equivariance_err,self.model)) # subsample by 5x so it doesn't take too long 52 | self.logger.add_scalars('metrics', metrics, step) 53 | super().logStuff(step,minibatch) 54 | 55 | @export 56 | class ClassifierPlus(Classifier): 57 | """ Trainer subclass. 
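Jit-compiles the loss and gradients in __init__ (mirroring RegressorPlus above).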
Implements a cross-entropy loss, an accuracy 58 | metric, and equivariance-error logging """ 59 | def __init__(self,model,*args,**kwargs): 60 | super().__init__(model,*args,**kwargs) 61 | 62 | fastloss = objax.Jit(self.loss,model.vars()) 63 | self.gradvals = objax.Jit(objax.GradValues(fastloss,model.vars()),model.vars()) 64 | self.model.predict = objax.Jit(objax.ForceArgs(model.__call__,training=False),model.vars()) 65 | #self.model.predict = lambda x: self.model(x,training=False) 66 | 67 | 68 | def logStuff(self, step, minibatch=None): 69 | metrics = {} 70 | metrics['test_equivar_err'] = self.evalAverageMetrics(islice(self.dataloaders['test'],0,None,5), 71 | partial(equivariance_err,self.model)) # subsample by 5x so it doesn't take too long 72 | self.logger.add_scalars('metrics', metrics, step) 73 | super().logStuff(step,minibatch) -------------------------------------------------------------------------------- /experiments/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import dill 2 | from oil.logging.lazyLogger import LazyLogger 3 | from oil.utils.utils import Eval, Named 4 | from oil.utils.mytqdm import tqdm 5 | from oil.tuning.study import guess_metric_sign 6 | import copy, os, random 7 | import glob 8 | import numpy as np 9 | from natsort import natsorted 10 | import jax 11 | import logging 12 | from functools import partial 13 | import objax 14 | 15 | class Trainer(object,metaclass=Named): 16 | """ Base trainer 17 | """ 18 | def __init__(self, model, dataloaders, optim = objax.optimizer.Adam,lr_sched =lambda e:1, 19 | log_dir=None, log_suffix='',log_args={},early_stop_metric=None): 20 | # Setup model, optimizer, and dataloaders 21 | self.model = model# 22 | #self.model= objax.Jit(objax.ForceArgs(model,training=True)) #TODO: figure out static nums 23 | #self.model.predict = objax.Jit(objax.ForceArgs(model.__call__,training=False),model.vars()) 24 | #self.model.predict = objax.ForceArgs(model.__call__,training=False) 25 | #self._model = model 26 | #self.model = objax.ForceArgs(model,training=True) 27 | #self.model.predict = objax.ForceArgs(model.__call__,training=False) 28 | #self.model = objax.Jit(lambda x, training: model(x,training=training),model.vars(),static_argnums=(1,)) 29 | #self.model = objax.Jit(model,static_argnums=(1,)) 30 | 31 | self.optimizer = optim(model.vars()) 32 | self.lr_sched= lr_sched 33 | self.dataloaders = dataloaders # A dictionary of dataloaders 34 | self.epoch = 0 35 | 36 | self.logger = LazyLogger(log_dir, log_suffix, **log_args) 37 | #self.logger.add_text('ModelSpec','model: {}'.format(model)) 38 | self.hypers = {} 39 | self.ckpt = None# copy.deepcopy(self.state_dict()) #TODO fix model saving 40 | self.early_stop_metric = early_stop_metric 41 | #fastloss = objax.Jit(self.loss,model.vars()) 42 | 43 | self.gradvals = objax.GradValues(self.loss,self.model.vars()) 44 | #self.gradvals = objax.Jit(objax.GradValues(fastloss,model.vars()),model.vars()) 45 | def metrics(self,loader): 46 | return {} 47 | 48 | def loss(self,minibatch): 49 | raise NotImplementedError 50 | 51 | def train_to(self, final_epoch=100): 52 | assert final_epoch>=self.epoch, "trying to train less than already trained" 53 | self.train(final_epoch-self.epoch) 54 | 55 | def train(self, num_epochs=100): 56 | """ The main training loop""" 57 | start_epoch = self.epoch 58 | steps_per_epoch = len(self.dataloaders['train']); step=0 59 | for self.epoch in tqdm(range(start_epoch, start_epoch + num_epochs),desc='train'): 60 | for i, minibatch in 
enumerate(self.dataloaders['train']): 61 | step = i + self.epoch*steps_per_epoch 62 | self.step(self.epoch+i/steps_per_epoch,minibatch) 63 | with self.logger as do_log: 64 | if do_log: self.logStuff(step, minibatch) 65 | self.epoch+=1 66 | self.logStuff(step) 67 | 68 | def step(self, epoch, minibatch): 69 | grad,loss = self.gradvals(minibatch) 70 | self.optimizer(self.lr_sched(epoch),grad) 71 | return loss 72 | 73 | def logStuff(self, step, minibatch=None): 74 | metrics = {} 75 | # if minibatch is not None and hasattr(self,'loss'): 76 | # try: metrics['Minibatch_Loss'] = self.loss(self.model_params,minibatch) 77 | # except (NotImplementedError, TypeError): pass 78 | for loader_name,dloader in self.dataloaders.items(): # Ignore metrics on train 79 | if loader_name=='train' or len(dloader)==0 or loader_name[0]=='_': continue 80 | for metric_name, metric_value in self.metrics(dloader).items(): 81 | metrics[loader_name+'_'+metric_name] = metric_value 82 | self.logger.add_scalars('metrics', metrics, step) 83 | # for name,m in self.model.named_modules(): 84 | # if hasattr(m, 'log_data'): 85 | # m.log_data(self.logger,step,name) 86 | self.logger.report() 87 | # update the best checkpoint 88 | if self.early_stop_metric is not None: 89 | maximize = guess_metric_sign(self.early_stop_metric) 90 | sign = 2*maximize-1 91 | best = (sign*self.logger.scalar_frame[self.early_stop_metric].values).max() 92 | current = sign*self.logger.scalar_frame[self.early_stop_metric].iloc[-1] 93 | if current >= best: self.ckpt = copy.deepcopy(self.state_dict()) 94 | else: self.ckpt = copy.deepcopy(self.state_dict()) 95 | 96 | def evalAverageMetrics(self,loader,metrics): 97 | num_total, loss_totals = 0, 0 98 | for minibatch in loader: 99 | try: mb_size = loader.batch_size 100 | except AttributeError: mb_size=1 101 | loss_totals += mb_size*metrics(minibatch) 102 | num_total += mb_size 103 | if num_total==0: raise KeyError("dataloader is empty") 104 | return loss_totals/num_total 105 | 106 | def state_dict(self): 107 | #TODO: handle saving and loading state 108 | state = { 109 | 'outcome':self.logger.scalar_frame[-1:], 110 | 'epoch':self.epoch, 111 | # 'model_state':self.model.state_dict(), 112 | # 'optim_state':self.optimizer.state_dict(), 113 | # 'logger_state':self.logger.state_dict(), 114 | } 115 | return state 116 | 117 | # def load_state_dict(self,state): 118 | # self.epoch = state['epoch'] 119 | # self.model.load_state_dict(state['model_state']) 120 | # self.optimizer.load_state_dict(state['optim_state']) 121 | # self.logger.load_state_dict(state['logger_state']) 122 | 123 | # def load_checkpoint(self,path=None): 124 | # """ Loads the checkpoint from path, if None gets the highest epoch checkpoint""" 125 | # if not path: 126 | # chkpts = glob.glob(os.path.join(self.logger.log_dirr,'checkpoints/c*.state')) 127 | # path = natsorted(chkpts)[-1] # get most recent checkpoint 128 | # print(f"loading checkpoint {path}") 129 | # with open(path,'rb') as f: 130 | # self.load_state_dict(dill.load(f)) 131 | 132 | # def save_checkpoint(self): 133 | # return self.logger.save_object(self.ckpt,suffix=f'checkpoints/c{self.epoch}.state') 134 | 135 | -------------------------------------------------------------------------------- /experiments/trainer/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import functools 3 | from oil.utils.utils import imap 4 | import numpy as np 5 | def minibatch_to(mb): 6 | try: 7 | if isinstance(mb,np.ndarray): 8 | return jax.device_put(mb) 9 | 
return jax.device_put(mb.numpy()) 10 | except AttributeError: 11 | if isinstance(mb,dict): 12 | return type(mb)(((k,minibatch_to(v)) for k,v in mb.items())) 13 | else: 14 | return type(mb)(minibatch_to(elem) for elem in mb) 15 | 16 | def LoaderTo(loader): 17 | return imap(functools.partial(minibatch_to),loader) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup,find_packages 2 | import sys, os, re 3 | 4 | README_FILE = 'README.md' 5 | 6 | def get_property(prop, project): 7 | result = re.search(r'{}\s*=\s*[\'"]([^\'"]*)[\'"]'.format(prop), open(project + '/__init__.py').read()) 8 | return result.group(1) 9 | 10 | project_name = "emlp" 11 | setup(name=project_name, 12 | description="A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups", 13 | version= get_property('__version__',project_name), 14 | author='Marc Finzi', 15 | author_email='maf820@nyu.edu', 16 | license='MIT', 17 | python_requires='>=3.6', 18 | install_requires=['h5py','objax','pytest','plum-dispatch', 19 | 'optax','tqdm>=4.38','matplotlib','scikit-learn'], 20 | extras_require = { 21 | 'EXPTS':['olive-oil-ml'] 22 | }, 23 | packages=find_packages(), 24 | long_description=open('README.md', encoding='UTF-8').read(), 25 | long_description_content_type='text/markdown', 26 | url='https://github.com/mfinzi/equivariant-MLP', 27 | classifiers=[ 28 | 'Development Status :: 4 - Beta', 29 | 'Intended Audience :: Developers', 30 | 'Intended Audience :: Science/Research', 31 | 'Programming Language :: Python :: 3', 32 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 33 | ], 34 | keywords=[ 35 | 'equivariance','MLP','symmetry','group','AI','neural network', 36 | 'representation','group theory','deep learning','machine learning', 37 | 'rotation','Lorentz invariance', 38 | ], 39 | 40 | ) 41 | -------------------------------------------------------------------------------- /tests/equivariance_tests.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np# 3 | import copy 4 | from emlp.reps import * 5 | from emlp.groups import * 6 | from emlp.nn import uniform_rep 7 | import pytest#import unittest 8 | from jax import vmap 9 | import jax.numpy as jnp 10 | import logging 11 | import argparse 12 | import sys 13 | import copy 14 | import inspect 15 | #from functools import partialmethod,partial 16 | 17 | def rel_error(t1,t2): 18 | error = jnp.sqrt(jnp.mean(jnp.abs(t1-t2)**2)) 19 | scale = jnp.sqrt(jnp.mean(jnp.abs(t1)**2)) + jnp.sqrt(jnp.mean(jnp.abs(t2)**2)) 20 | return error/jnp.maximum(scale,1e-7) 21 | 22 | def scale_adjusted_rel_error(t1,t2,g): 23 | error = jnp.sqrt(jnp.mean(jnp.abs(t1-t2)**2)) 24 | tscale = jnp.sqrt(jnp.mean(jnp.abs(t1)**2)) + jnp.sqrt(jnp.mean(jnp.abs(t2)**2)) 25 | gscale = jnp.sqrt(jnp.mean(jnp.abs(g-jnp.eye(g.shape[-1]))**2)) 26 | scale = jnp.maximum(tscale,gscale) 27 | return error/jnp.maximum(scale,1e-7) 28 | 29 | def equivariance_error(W,repin,repout,G): 30 | """ Computes the equivariance relative error rel_err(Wρ₁(g),ρ₂(g)W) 31 | of the matrix W (dim(repout),dim(repin)) 32 | according to the input and output representations and group G. 
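Illustrative call (hypothetical values): with G = SO(3), repin = T(1)(G), repout = T(2)(G),
and W = np.random.rand(repout.size(),repin.size()), equivariance_error(W,repin,repout,G)
is O(1) for the random W but near zero once W is first projected with
(repin>>repout).equivariant_projector(), as in test_equivariant_matrix below.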
""" 33 | N=5 34 | x = np.random.rand(N,repin.size()) 35 | gs = G.samples(N) 36 | ring = vmap(repin.rho_dense)(gs) 37 | routg = vmap(repout.rho_dense)(gs) 38 | equiv_err = scale_adjusted_rel_error(W@ring,routg@W,gs) 39 | return equiv_err 40 | 41 | def strip_parens(string): 42 | return string.replace('(','').replace(')','') 43 | 44 | def parametrize(cases,ids=None): 45 | """ Expands test cases with pytest.mark.parametrize but with argnames 46 | assumed and ids given by the ids=[str(case) for case in cases] """ 47 | def decorator(test_fn): 48 | argnames = ','.join(inspect.getfullargspec(test_fn).args) 49 | theids = [strip_parens(str(case)) for case in cases] if ids is None else ids 50 | return pytest.mark.parametrize(argnames,cases,ids=theids)(test_fn) 51 | return decorator 52 | # def expand_cases(cls,argseq): 53 | # def class_decorator(testcase): 54 | # for args in argseq: 55 | # setattr(cls, f"{testcase.__name__}_{args}", partialmethod(testcase,*tuplify(args))) 56 | # #setattr(cls,f"{testcase.__name__}_{args}",types.MethodType(partial(testcase,*tuplify(args)),cls)) 57 | # return testcase 58 | # return class_decorator 59 | 60 | # def tuplify(x): 61 | # if not isinstance(x, tuple): return (x,) 62 | # return x 63 | 64 | test_groups = [SO(n) for n in [2,3,4]]+[O(n) for n in [2,3,4]]+\ 65 | [SU(n) for n in [2,3,4]] + [U(n) for n in [2,3,4]] + \ 66 | [SL(n) for n in [2,3,4]] + [GL(n) for n in [2,3,4]] + \ 67 | [C(k) for k in [2,3,4,8]]+[D(k) for k in [2,3,4,8]]+\ 68 | [S(n) for n in [2,4,6]]+[Z(n) for n in [2,4,6]]+\ 69 | [SO11p(),SO13p(),SO13(),O13()] +[Sp(n) for n in [1,3]]+\ 70 | [RubiksCube(),Cube(),ZksZnxZn(2,2),ZksZnxZn(4,4)] 71 | # class TestRepresentationSubspace(unittest.TestCase): pass 72 | # expand_test_cases = partial(expand_cases,TestRepresentationSubspace) 73 | 74 | 75 | #@pytest.mark.parametrize("G",test_groups,ids=test_group_names) 76 | @parametrize(test_groups) 77 | def test_sum(G): 78 | N=5 79 | rep = T(0,2)+3*(T(0,0)+T(1,0))+T(0,0)+T(1,1)+2*T(1,0)+T(0,2)+T(0,1)+3*T(0,2)+T(2,0) 80 | rep = rep(G) 81 | if G.num_constraints()*rep.size()>1e11 or rep.size()**2>10**7: return 82 | P = rep.equivariant_projector() 83 | v = np.random.rand(rep.size()) 84 | v = P@v 85 | gs = G.samples(N) 86 | gv = (vmap(rep.rho_dense)(gs)*v).sum(-1) 87 | err = vmap(scale_adjusted_rel_error)(gv,v+jnp.zeros_like(gv),gs).mean() 88 | assert err<1e-4,f"Symmetric vector fails err {err:.3e} with G={G}" 89 | 90 | #@pytest.mark.parametrize("G",test_groups,ids=test_group_names) 91 | @parametrize([group for group in test_groups if group.d<5]) 92 | def test_prod(G): 93 | N=5 94 | rep = T(0,1)*T(0,0)*T(1,0)**2*T(1,0)*T(0,0)**3*T(0,1) 95 | rep = rep(G) 96 | if G.num_constraints()*rep.size()>1e11 or rep.size()**2>10**7: return 97 | # P = rep.equivariant_projector() 98 | # v = np.random.rand(rep.size()) 99 | # v = P@v 100 | Q = rep.equivariant_basis() 101 | v = Q@np.random.rand(Q.shape[-1]) 102 | gs = G.samples(N) 103 | gv = (vmap(rep.rho_dense)(gs)*v).sum(-1) 104 | 105 | #print(f"g {gs[0]} and rho_dense {rep.rho_dense(gs[0])} {rep.rho_dense(gs[0]).shape}") 106 | err = vmap(scale_adjusted_rel_error)(gv,v+jnp.zeros_like(gv),gs).mean() 107 | assert err<1e-4,f"Symmetric vector fails err {err:.3e} with G={G}" 108 | 109 | #@pytest.mark.parametrize("G",test_groups,ids=test_group_names) 110 | @parametrize(test_groups) 111 | def test_high_rank_representations(G): 112 | N=5 113 | r = 10 114 | for p in range(r+1): 115 | for q in range(r-p+1): 116 | if G.num_constraints()*G.d**(3*(p+q))>1e11: continue 117 | if G.is_orthogonal and 
q>0: continue 118 | #try: 119 | #logging.info(f"{p},{q},{T(p,q)}") 120 | rep = T(p,q)(G) 121 | P = rep.equivariant_projector() 122 | v = np.random.rand(rep.size()) 123 | v = P@v 124 | g = vmap(rep.rho_dense)(G.samples(N)) 125 | gv = (g*v).sum(-1) 126 | #print(f"v{v.shape}, g{g.shape},gv{gv.shape},{G},T{p,q}") 127 | err = vmap(scale_adjusted_rel_error)(gv,v+jnp.zeros_like(gv),g).mean() 128 | if np.isnan(err): continue # deal with nans on cpu later 129 | assert err<1e-4,f"Symmetric vector fails err {err:.3e} with T{p,q} and G={G}" 130 | logging.info(f"Success with T{p,q} and G={G}") 131 | # except Exception as e: 132 | # print(f"Failed with G={G} and T({p,q})") 133 | # raise e 134 | 135 | @parametrize([ 136 | (SO(3),T(1)+2*T(0),T(1)+T(2)+2*T(0)+T(1)), 137 | (SO(3),5*T(0)+5*T(1),3*T(0)+T(2)+2*T(1)), 138 | (SO(3),5*(T(0)+T(1)),2*(T(0)+T(1))+T(2)+T(1)), 139 | (SO(4), T(1)+2*T(2),(T(0)+T(3))*T(0)), 140 | (SO13p(),T(2)+4*T(1,0)+T(0,1),10*T(0)+3*T(1,0)+3*T(0,1)+T(0,2)+T(2,0)+T(1,1)), 141 | (Sp(2),(V+2*V**2)*(V.T+1*V).T + V.T,3*V**0 + V + V*V.T), 142 | (SU(3),T(2,0)+T(1,1)+T(0)+2*T(0,1),T(1,1)+V+V.T+T(0)+T(2,0)+T(0,2)) 143 | ]) 144 | def test_equivariant_matrix(G,repin,repout): 145 | N=5 146 | repin = repin(G) 147 | repout = repout(G) 148 | #repW = repout*repin.T 149 | repW = repin>>repout 150 | P = repW.equivariant_projector() 151 | W = np.random.rand(repout.size(),repin.size()) 152 | W = (P@W.reshape(-1)).reshape(*W.shape) 153 | 154 | x = np.random.rand(N,repin.size()) 155 | gs = G.samples(N) 156 | ring = vmap(repin.rho_dense)(gs) 157 | routg = vmap(repout.rho_dense)(gs) 158 | gx = (ring@x[...,None])[...,0] 159 | Wgx =gx@W.T 160 | #print(g.shape,(x@W.T).shape) 161 | gWx = (routg@(x@W.T)[...,None])[...,0] 162 | equiv_err = rel_error(Wgx,gWx) 163 | assert equiv_err<1e-4,f"Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G}" 164 | 165 | # print(f"R {repW.rho(gs[0])}") 166 | # print(f"R1 x R2 {jnp.kron(routg[0],jnp.linalg.inv(ring[0]).T)}") 167 | gvecW = (vmap(repW.rho_dense)(gs)*W.reshape(-1)).sum(-1) 168 | gWerr =vmap(scale_adjusted_rel_error)(gvecW,W.reshape(-1)+jnp.zeros_like(gvecW),gs).mean() 169 | assert gWerr<1e-4,f"Symmetric gvec(W)=vec(W) fails err {gWerr:.3e} with G={G}" 170 | 171 | @parametrize([(SO(3),5*T(0)+5*T(1),3*T(0)+T(2)+2*T(1)), 172 | (SO13p(),4*T(1,0),10*T(0)+3*T(1,0)+3*T(0,1)+T(0,2)+T(2,0)+T(1,1))]) 173 | def test_bilinear_layer(G,repin,repout): 174 | N=5 175 | repin = repin(G) 176 | repout = repout(G) 177 | repW = repout*repin.T 178 | Wdim,P = bilinear_weights(repout,repin) 179 | x = np.random.rand(N,repin.size()) 180 | gs = G.samples(N) 181 | ring = vmap(repin.rho_dense)(gs) 182 | routg = vmap(repout.rho_dense)(gs) 183 | gx = (ring@x[...,None])[...,0] 184 | 185 | W = np.random.rand(Wdim) 186 | W_x = P(W,x) 187 | Wxx = (W_x@x[...,None])[...,0] 188 | gWxx = (routg@Wxx[...,None])[...,0] 189 | Wgxgx =(P(W,gx)@gx[...,None])[...,0] 190 | equiv_err = rel_error(Wgxgx,gWxx) 191 | assert equiv_err<1e-4,f"Bilinear Equivariance fails err {equiv_err:.3e} with G={G}" 192 | 193 | @parametrize(test_groups) 194 | def test_large_representations(G): 195 | N=5 196 | ch = 256 197 | rep =repin=repout= uniform_rep(ch,G) 198 | repW = rep>>rep 199 | P = repW.equivariant_projector() 200 | W = np.random.rand(repout.size(),repin.size()) 201 | W = (P@W.reshape(-1)).reshape(*W.shape) 202 | 203 | x = np.random.rand(N,repin.size()) 204 | gs = G.samples(N) 205 | ring = vmap(repin.rho_dense)(gs) 206 | routg = vmap(repout.rho_dense)(gs) 207 | gx = (ring@x[...,None])[...,0] 208 | Wgx =gx@W.T 209 | 
#print(g.shape,(x@W.T).shape) 210 | gWx = (routg@(x@W.T)[...,None])[...,0] 211 | equiv_err = vmap(scale_adjusted_rel_error)(Wgx,gWx,gs).mean() 212 | assert equiv_err<1e-4,f"Large Rep Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G}" 213 | logging.info(f"Success with G={G}") 214 | 215 | # #print(dir(TestRepresentationSubspace)) 216 | # if __name__ == '__main__': 217 | # parser = argparse.ArgumentParser() 218 | # parser.add_argument("--log", default="warning",help=("Logging Level Example --log debug', default='warning'")) 219 | # options,unknown_args = parser.parse_known_args()#["--log"]) 220 | # levels = {'critical': logging.CRITICAL,'error': logging.ERROR,'warn': logging.WARNING,'warning': logging.WARNING, 221 | # 'info': logging.INFO,'debug': logging.DEBUG} 222 | # level = levels.get(options.log.lower()) 223 | # logging.getLogger().setLevel(level) 224 | # unit_argv = [sys.argv[0]] + unknown_args 225 | # unittest.main(argv=unit_argv) -------------------------------------------------------------------------------- /tests/model_tests.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import vmap 3 | import numpy as np 4 | import pytest 5 | from torch.utils.data import DataLoader 6 | from emlp.nn import uniform_rep,MLP,EMLP,Standardize 7 | import emlp 8 | from equivariance_tests import parametrize,rel_error,scale_adjusted_rel_error 9 | from oil.utils.utils import FixedNumpySeed, FixedPytorchSeed 10 | from emlp.datasets import Inertia,O5Synthetic,ParticleInteraction,InvertedCube 11 | from jax import random 12 | 13 | 14 | # def rel_err(a,b): 15 | # return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean()))# 16 | 17 | # def scale_adjusted_rel_err(a,b,g): 18 | # return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean())+jnp.abs(g-jnp.eye(g.shape[-1])).mean()) 19 | 20 | def equivariance_err(model,mb,repin,repout,group): 21 | x,y = mb 22 | gs = group.samples(x.shape[0]) 23 | rho_gin = vmap(repin(group).rho_dense)(gs) 24 | rho_gout = vmap(repout(group).rho_dense)(gs) 25 | y1 = model((rho_gin@x[...,None])[...,0]) 26 | y2 = (rho_gout@model(x)[...,None])[...,0] 27 | return np.asarray(scale_adjusted_rel_error(y1,y2,gs)) 28 | 29 | def get_dsmb(dsclass): 30 | seed=2021 31 | bs=50 32 | with FixedNumpySeed(seed),FixedPytorchSeed(seed): 33 | ds = dsclass(100) 34 | dataloader = DataLoader(ds,batch_size=min(bs,len(ds)),num_workers=0,pin_memory=False) 35 | mb = next(iter(dataloader)) 36 | mb = jax.device_put(mb[0].numpy()),jax.device_put(mb[1].numpy()) 37 | return ds,mb 38 | 39 | @parametrize([Inertia,O5Synthetic,ParticleInteraction,InvertedCube]) 40 | def test_init_forward_and_equivariance(dsclass): 41 | network=emlp.nn.objax.EMLP 42 | ds,mb = get_dsmb(dsclass) 43 | model = network(ds.rep_in,ds.rep_out,group=ds.symmetry) 44 | assert equivariance_err(model,mb,ds.rep_in,ds.rep_out,ds.symmetry) < 1e-4, "Objax EMLP failed equivariance test" 45 | 46 | @parametrize([Inertia]) 47 | def test_haiku_emlp(dsclass): 48 | import haiku as hk 49 | from emlp.nn.haiku import EMLP as hkEMLP 50 | network = hkEMLP 51 | ds,mb = get_dsmb(dsclass) 52 | net = network(ds.rep_in,ds.rep_out,group=ds.symmetry) 53 | net = hk.without_apply_rng(hk.transform(net)) 54 | params = net.init(random.PRNGKey(42),mb[0]) 55 | model = lambda x: net.apply(params,x) 56 | assert equivariance_err(model,mb,ds.rep_in,ds.rep_out,ds.symmetry) < 1e-4, "Haiku EMLP failed equivariance test" 57 | 58 | @parametrize([Inertia]) 59 | def 
test_flax_emlp(dsclass): 60 | from emlp.nn.flax import EMLP as flaxEMLP 61 | network = flaxEMLP 62 | ds,mb = get_dsmb(dsclass) 63 | net = network(ds.rep_in,ds.rep_out,group=ds.symmetry) 64 | params = net.init(random.PRNGKey(42),mb[0]) 65 | model = lambda x: net.apply(params,x) 66 | assert equivariance_err(model,mb,ds.rep_in,ds.rep_out,ds.symmetry) < 1e-4, "flax EMLP failed equivariance test" 67 | 68 | @parametrize([Inertia]) 69 | def test_pytorch_emlp(dsclass): 70 | import torch 71 | from emlp.nn.pytorch import EMLP as ptEMLP 72 | network=ptEMLP 73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 74 | ds,mb = get_dsmb(dsclass) 75 | net = network(ds.rep_in,ds.rep_out,group=ds.symmetry).to(device) 76 | model = lambda x: jax.device_put(net(torch.from_numpy(np.asarray(x)).to(device)).cpu().data.numpy()) 77 | assert equivariance_err(model,mb,ds.rep_in,ds.rep_out,ds.symmetry) < 1e-4, "Pytorch EMLP failed equivariance test" 78 | 79 | 80 | from emlp.reps import vis, sparsify_basis, V,Rep 81 | from emlp.groups import S,SO 82 | 83 | def test_utilities(): 84 | W = V(SO(3)) 85 | vis(W,W) 86 | Q = (W**2>>W).equivariant_basis() 87 | SQ = sparsify_basis(Q) 88 | A = SQ@(1+np.arange(SQ.shape[-1])) 89 | nunique = len(np.unique(np.abs(A))) 90 | assert nunique in (SQ.shape[-1],SQ.shape[-1]+1), "Sparsify fails on SO(3) T3" 91 | 92 | 93 | def test_bespoke_representations(): 94 | class ProductSubRep(Rep): 95 | def __init__(self,G,subgroup_id,size): 96 | """ Produces the representation of the subgroup of G = G1 x G2 97 | with the index subgroup_id in {0,1} specifying G1 or G2. 98 | Also requires specifying the size of the representation given by G1.d or G2.d """ 99 | self.G = G 100 | self.index = subgroup_id 101 | self._size = size 102 | def __str__(self): 103 | return "V_"+str(self.G).split('x')[self.index] 104 | def size(self): 105 | return self._size 106 | def rho(self,M): 107 | # Given that M is a LazyKron object, we can just get the argument 108 | return M.Ms[self.index] 109 | def drho(self,A): 110 | return A.Ms[self.index] 111 | def __call__(self,G): 112 | # adding this will probably not be necessary in a future release, 113 | # necessary now because rep is __call__ed in nn.EMLP constructor 114 | assert self.G==G 115 | return self 116 | G1,G2 = SO(3),S(5) 117 | G = G1 * G2 118 | 119 | VSO3 = ProductSubRep(G,0,G1.d) 120 | VS5 = ProductSubRep(G,1,G2.d) 121 | Vin = VS5 + V(G) 122 | Vout = VSO3 123 | str(Vin>>Vout) 124 | model = emlp.nn.EMLP(Vin, Vout, group=G) 125 | input_point = np.random.randn(Vin.size())*10 126 | from emlp.reps.linear_operators import LazyKron 127 | lazy_G_sample = LazyKron([G1.sample(),G2.sample()]) 128 | 129 | out1 = model(Vin.rho(lazy_G_sample)@input_point) 130 | out2 = Vout.rho(lazy_G_sample)@model(input_point) 131 | assert rel_error(out1,out2) < 1e-4, "EMLP equivariance fails on bespoke productsubrep" 132 | -------------------------------------------------------------------------------- /tests/product_groups_tests.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np# 3 | import copy 4 | from emlp.reps import * 5 | from emlp.groups import * 6 | from emlp.nn import uniform_rep 7 | from equivariance_tests import parametrize,rel_error,scale_adjusted_rel_error 8 | import unittest 9 | from jax import vmap 10 | import jax.numpy as jnp 11 | import logging 12 | 13 | 14 | @parametrize([(SO(3),S(5)),(S(5),SO(3))]) 15 | def test_symmetric_mixed_tensor(G1,G2): 16 | N=5 17 | rep = T(2)(G1)*T(1)(G2) 18 | P = 
rep.equivariant_projector() 19 | v = np.random.rand(rep.size()) 20 | v = P@v 21 | samples = {G1:G1.samples(N),G2:G2.samples(N)} 22 | gv = (vmap(rep.rho_dense)(samples)*v).sum(-1) 23 | err = rel_error(gv,v+jnp.zeros_like(gv)) 24 | assert err<3e-5,f"Symmetric vector fails err {err:.3e} with G={G1}x{G2}" 25 | 26 | 27 | @parametrize([(SO(3),S(5)),(S(5),SO(3))]) 28 | def test_symmetric_mixed_tensor_sum(G1,G2): 29 | N=5 30 | rep = T(2)(G1)*T(1)(G2) + 2*T(0)(G1)*T(2)(G2)+T(1)(G1) +T(1)(G2) 31 | P = rep.equivariant_projector() 32 | v = np.random.rand(rep.size()) 33 | v = P@v 34 | samples = {G1:G1.samples(N),G2:G2.samples(N)} 35 | gv = (vmap(rep.rho_dense)(samples)*v).sum(-1) 36 | err = rel_error(gv,v+jnp.zeros_like(gv)) 37 | assert err<3e-5,f"Symmetric vector fails err {err:.3e} with G={G1}x{G2}" 38 | 39 | 40 | @parametrize([(SO(3),S(5)),(S(5),SO(3))]) 41 | def test_symmetric_mixed_products(G1,G2): 42 | N=5 43 | rep1 = (T(0)+2*T(1)+T(2))(G1) 44 | rep2 = (T(0)+T(1))(G2) 45 | rep = rep2*rep1.T 46 | P = rep.equivariant_projector() 47 | v = np.random.rand(rep.size()) 48 | v = P@v 49 | W = v.reshape((rep2.size(),rep1.size())) 50 | x = np.random.rand(N,rep1.size()) 51 | g1s = G1.samples(N) 52 | g2s = G2.samples(N) 53 | ring = vmap(rep1.rho_dense)(g1s) 54 | routg = vmap(rep2.rho_dense)(g2s) 55 | gx = (ring@x[...,None])[...,0] 56 | Wgx =gx@W.T 57 | gWx = (routg@(x@W.T)[...,None])[...,0] 58 | equiv_err = rel_error(Wgx,gWx) 59 | assert equiv_err<1e-5,f"Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G1}x{G2}" 60 | samples = {G1:g1s,G2:g2s} 61 | gv = (vmap(rep.rho_dense)(samples)*v).sum(-1) 62 | err = rel_error(gv,v+jnp.zeros_like(gv)) 63 | assert err<3e-5,f"Symmetric vector fails err {err:.3e} with G={G1}x{G2}" 64 | 65 | @parametrize([(SO(3),S(5)),(S(5),SO(3))]) 66 | def test_equivariant_matrix(G1,G2): 67 | N=5 68 | repin = T(2)(G2) + 3*T(0)(G1) + T(1)(G2)+2*T(2)(G1)*T(1)(G2) 69 | repout = (T(1)(G1) + T(2)(G1)*T(0)(G2) + T(1)(G1)*T(1)(G2) + T(0)(G1)+T(2)(G1)*T(1)(G2)) 70 | repW = repout*repin.T 71 | P = repW.equivariant_projector() 72 | W = np.random.rand(repout.size(),repin.size()) 73 | W = (P@W.reshape(-1)).reshape(*W.shape) 74 | 75 | x = np.random.rand(N,repin.size()) 76 | samples = {G1:G1.samples(N),G2:G2.samples(N)} 77 | ring = vmap(repin.rho_dense)(samples) 78 | routg = vmap(repout.rho_dense)(samples) 79 | gx = (ring@x[...,None])[...,0] 80 | Wgx =gx@W.T 81 | #print(g.shape,(x@W.T).shape) 82 | gWx = (routg@(x@W.T)[...,None])[...,0] 83 | equiv_err = rel_error(Wgx,gWx) 84 | assert equiv_err<3e-5,f"Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G1}x{G2}" 85 | # too much memory to run 86 | # gvecW = (vmap(repW.rho_dense)(samples)*W.reshape(-1)).sum(-1) 87 | # for i in range(N): 88 | # gWerr = rel_error(gvecW[i],W.reshape(-1)) 89 | # assert gWerr<1e-5,f"Symmetric gvec(W)=vec(W) fails err {gWerr:.3e} with G={G1}x{G2}" --------------------------------------------------------------------------------
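Addendum: a minimal usage sketch (illustrative only, not a file from this repository) of the equivariant-projector pattern that the tests above exercise repeatedly, assuming only the public emlp API (emlp.groups.SO, emlp.reps.T):

import numpy as np
from emlp.groups import SO
from emlp.reps import T

G = SO(3)
repin, repout = T(1)(G), T(2)(G)              # maps from vectors to rank-2 tensors
repW = repin >> repout                        # representation acting on linear maps W
P = repW.equivariant_projector()              # lazy projector onto the equivariant subspace
W = np.random.rand(repout.size(), repin.size())
W = (P @ W.reshape(-1)).reshape(W.shape)      # project a random W onto the equivariant maps
g = G.sample()                                # check W rho_in(g) ~= rho_out(g) W on a random group sample
err = np.abs(np.asarray(repout.rho_dense(g) @ W - W @ repin.rho_dense(g))).max()
# err should be ~0; the tests above use tolerances around 1e-4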