├── .github └── workflows │ ├── build.yml │ ├── coverage.yml │ └── pypi.yml ├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── README.rst ├── docs ├── Makefile ├── README.md ├── make.bat └── source │ ├── conf.py │ ├── constraints.rst │ ├── index.rst │ ├── invertibility │ ├── glp.rst │ ├── index.rst │ └── sl.rst │ ├── lowrank │ ├── fixedrank.rst │ ├── index.rst │ └── lowrank.rst │ ├── orthogonal │ ├── almostorthogonal.rst │ ├── grassmannian.rst │ ├── index.rst │ ├── so.rst │ ├── sphere.rst │ └── stiefel.rst │ ├── product.rst │ ├── psd │ ├── index.rst │ ├── psd.rst │ ├── pssd.rst │ ├── pssdfixedrank.rst │ └── pssdlowrank.rst │ ├── spelling_wordlist.txt │ └── vector_spaces │ ├── index.rst │ ├── reals.rst │ ├── skew.rst │ └── symmetric.rst ├── examples ├── __init__.py ├── copying_problem.py ├── eigenvalue.py ├── parametrisations.ipynb └── sequential_mnist.py ├── geotorch ├── __init__.py ├── almostorthogonal.py ├── constraints.py ├── exceptions.py ├── fixedrank.py ├── glp.py ├── grassmannian.py ├── lowrank.py ├── parametrize.py ├── product.py ├── psd.py ├── pssd.py ├── pssdfixedrank.py ├── pssdlowrank.py ├── reals.py ├── skew.py ├── sl.py ├── so.py ├── sphere.py ├── stiefel.py ├── symmetric.py └── utils.py ├── setup.cfg ├── setup.py └── test ├── __init__.py ├── test_almostorthogonal.py ├── test_glp.py ├── test_integration.py ├── test_lowrank.py ├── test_orthogonal.py ├── test_positive_semidefinite.py ├── test_product.py ├── test_skew.py ├── test_sl.py ├── test_sphere.py └── test_symmetric.py /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ '*' ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | python-version: [3.6, 3.7, 3.8] 16 | os: [ubuntu-latest, macos-latest, windows-latest] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install dependencies (non-Windows) 26 | if: ${{ matrix.os != 'windows-latest' }} 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install .[dev] 30 | 31 | # Windows is treated differently, as PyTorch is not uploaded to PyPI at the moment 32 | - name: Install dependencies (Windows) 33 | if: ${{ matrix.os == 'windows-latest' }} 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install torch===1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 37 | pip install .[dev] 38 | 39 | - name: Lint with flake8 40 | run: | 41 | # stop the build if there are Python syntax errors or undefined names 42 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 43 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 44 | flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 --statistics 45 | 46 | - name: Lint with Black 47 | run: | 48 | black --check --diff . 
49 | 50 | - name: Test with pytest 51 | run: | 52 | pytest test 53 | -------------------------------------------------------------------------------- /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: Coverage 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | # Set up the latest Python version 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 3.8 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install codecov pytest-cov 26 | pip install .[dev] 27 | 28 | - name: Coverage tests 29 | run: pytest --cov=geotorch test/ --cov-report term-missing 30 | 31 | - name: Codecov 32 | if: success() 33 | env: 34 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 35 | run: codecov 36 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.x 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .cache/ 2 | .git/ 3 | .vscode/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other info into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | 141 | # Cython debug symbols 142 | cython_debug/ 143 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | python: 4 | version: 3.7 5 | install: 6 | - method: pip 7 | path: . 8 | extra_requirements: 9 | - dev 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020- Mario Lezcano-Casado 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | GeoTorch 2 | ======== 3 | 4 | |Build| |Docs| |Codecov| |Codestyle Black| |License| 5 | 6 | A library for constrained optimization and manifold optimization for deep learning in PyTorch 7 | 8 | Overview 9 | -------- 10 | 11 | GeoTorch provides a simple way to perform constrained optimization and optimization on manifolds in PyTorch. 12 | It is compatible out of the box with any optimizer, layer, and model implemented in PyTorch without any boilerplate in the training code. Just state the constraints when you construct the model and you are ready to go! 13 | 14 | .. code:: python 15 | 16 | import torch 17 | import torch.nn as nn 18 | import geotorch 19 | 20 | class Model(nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | # One line suffices: Instantiate a linear layer with orthonormal columns 24 | self.linear = nn.Linear(64, 128) 25 | geotorch.orthogonal(self.linear, "weight") 26 | 27 | # Works with tensors: Instantiate a CNN with kernels of rank 1 28 | self.cnn = nn.Conv2d(16, 32, 3) 29 | geotorch.low_rank(self.cnn, "weight", rank=1) 30 | 31 | # Weights are initialized to a random value when you apply the constraints, but 32 | # you may re-initialize them to a different value by assigning to them 33 | self.linear.weight = torch.eye(128, 64) 34 | # And that's all you need to do. The rest is regular PyTorch code 35 | 36 | def forward(self, x): 37 | # self.linear is orthogonal and every 3x3 kernel in self.cnn is of rank 1 38 | ... 39 | # Use the model as you would normally do. Everything just works 40 | model = Model().cuda() 41 | 42 | # Use your optimizer of choice. Any optimizer works out of the box with any parametrization 43 | optim = torch.optim.Adam(model.parameters(), lr=lr) 44 | 45 | Constraints 46 | ----------- 47 | 48 | The following constraints are implemented and may be used as in the example above: 49 | 50 | - |symmetric|_. Symmetric matrices 51 | - |skew_constr|_. Skew-symmetric matrices 52 | - |sphere_constr|_. Vectors of norm ``1`` 53 | - |orthogonal|_. Matrices with orthogonal columns 54 | - |grassmannian|_. Subspaces of dimension ``k`` in ``Rⁿ`` 55 | - |almost_orthogonal|_. Matrices with singular values in the interval ``[1-λ, 1+λ]`` 56 | - |invertible|_. Invertible matrices with positive determinant 57 | - |sln|_. Matrices of determinant equal to ``1`` 58 | - |low_rank|_. Matrices of rank at most ``r`` 59 | - |fixed_rank|_. Matrices of rank ``r`` 60 | - |positive_definite|_. Positive definite matrices 61 | - |positive_semidefinite|_. Positive semidefinite matrices 62 | - |positive_semidefinite_low_rank|_. Positive semidefinite matrices of rank at most ``r`` 63 | - |positive_semidefinite_fixed_rank|_. Positive semidefinite matrices of rank ``r`` 64 |
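For instance, a minimal sketch of applying two of the constraints above (the ``rank`` keyword mirrors the ``low_rank`` call in the overview example; see the documentation of each function for its exact signature):

.. code:: python

    import torch.nn as nn
    import geotorch

    layer = nn.Linear(64, 64)
    # The square weight of this layer is now symmetric positive semidefinite
    geotorch.positive_semidefinite(layer, "weight")

    cnn = nn.Conv2d(16, 32, 3)
    # Every 3x3 kernel of this CNN now has rank exactly 2
    geotorch.fixed_rank(cnn, "weight", rank=2)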
65 | .. |symmetric| replace:: ``geotorch.symmetric`` 66 | .. _symmetric: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.symmetric 67 | .. |skew_constr| replace:: ``geotorch.skew`` 68 | .. _skew_constr: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.skew 69 | .. |sphere_constr| replace:: ``geotorch.sphere`` 70 | .. _sphere_constr: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.sphere 71 | .. |orthogonal| replace:: ``geotorch.orthogonal`` 72 | .. _orthogonal: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.orthogonal 73 | .. |grassmannian| replace:: ``geotorch.grassmannian`` 74 | .. _grassmannian: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.grassmannian 75 | .. |almost_orthogonal| replace:: ``geotorch.almost_orthogonal(λ)`` 76 | .. _almost_orthogonal: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.almost_orthogonal 77 | .. |invertible| replace:: ``geotorch.invertible`` 78 | .. _invertible: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.invertible 79 | .. |sln| replace:: ``geotorch.sln`` 80 | .. _sln: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.sln 81 | .. |low_rank| replace:: ``geotorch.low_rank(r)`` 82 | .. _low_rank: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.low_rank 83 | .. |fixed_rank| replace:: ``geotorch.fixed_rank(r)`` 84 | .. _fixed_rank: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.fixed_rank 85 | .. |positive_definite| replace:: ``geotorch.positive_definite`` 86 | .. _positive_definite: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.positive_definite 87 | .. |positive_semidefinite| replace:: ``geotorch.positive_semidefinite`` 88 | .. _positive_semidefinite: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.positive_semidefinite 89 | .. |positive_semidefinite_low_rank| replace:: ``geotorch.positive_semidefinite_low_rank(r)`` 90 | .. _positive_semidefinite_low_rank: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.positive_semidefinite_low_rank 91 | .. |positive_semidefinite_fixed_rank| replace:: ``geotorch.positive_semidefinite_fixed_rank(r)`` 92 | .. _positive_semidefinite_fixed_rank: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.positive_semidefinite_fixed_rank 93 | 94 | Each of these constraints has some extra parameters that can be used to tailor the 95 | behavior of each constraint to the problem at hand. For more on this, see the documentation. 96 | 97 | These constraints are a front end for the families of spaces listed below. 98 | 99 | Supported Spaces 100 | ---------------- 101 | 102 | Each constraint in GeoTorch is implemented as a manifold. These give the user more flexibility 103 | in the options they choose for each parametrization. All these spaces support Riemannian Gradient 104 | Descent (more on this `here`_), but they also support optimization via any other PyTorch optimizer. 105 |
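For instance, Riemannian gradient descent amounts to taking a step with plain SGD and then moving the base point of the parametrization, as done in `examples/copying_problem.py`_. A minimal sketch, where ``model``, ``loader``, ``loss_fn``, and the constrained ``layer`` are placeholders:

.. code:: python

    optim = torch.optim.SGD(model.parameters(), lr=1e-3)
    for x, y in loader:
        loss = loss_fn(model(x), y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        # Updating the base after each SGD step turns the scheme into
        # Riemannian gradient descent
        geotorch.update_base(layer, "weight")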
106 | GeoTorch currently supports the following spaces: 107 | 108 | - |reals|_: ``Rⁿ``. Unrestricted optimization 109 | - |sym|_: Vector space of symmetric matrices 110 | - |skew|_: Vector space of skew-symmetric matrices 111 | - |sphere|_: Sphere in ``Rⁿ``. ``{ x ∈ Rⁿ | ||x|| = 1 } ⊂ Rⁿ`` 112 | - |so|_: Manifold of ``n×n`` orthogonal matrices 113 | - |st|_: Manifold of ``n×k`` matrices with orthonormal columns 114 | - |almost|_: Manifold of ``n×k`` matrices with singular values in the interval ``[1-λ, 1+λ]`` 115 | - |grass|_: Manifold of ``k``-dimensional subspaces in ``Rⁿ`` 116 | - |glp|_: Manifold of invertible ``n×n`` matrices with positive determinant 117 | - |sl|_: Manifold of ``n×n`` matrices with determinant equal to `1` 118 | - |low|_: Variety of ``n×k`` matrices of rank ``r`` or less 119 | - |fixed|_: Manifold of ``n×k`` matrices of rank ``r`` 120 | - |psd|_: Cone of ``n×n`` symmetric positive definite matrices 121 | - |pssd|_: Cone of ``n×n`` symmetric positive semi-definite matrices 122 | - |pssdlow|_: Variety of ``n×n`` symmetric positive semi-definite matrices of rank ``r`` or less 123 | - |pssdfixed|_: Manifold of ``n×n`` symmetric positive semi-definite matrices of rank ``r`` 124 | - |product|_: Product of manifolds ``M₁ × ... × Mₖ`` 125 | 126 | Every space of dimension ``(n, k)`` can be applied to tensors of shape ``(*, n, k)``, so we also get efficient parallel implementations of product spaces such as 127 | 128 | - ``ObliqueManifold(n,k)``: Matrices with unit-length columns, ``Sⁿ⁻¹ × ...ᵏ⁾ × Sⁿ⁻¹`` 129 | 130 | Using GeoTorch in your Code 131 | --------------------------- 132 | 133 | The files `examples/copying_problem.py`_ and `examples/sequential_mnist.py`_ serve as tutorials on how to handle the initialization and usage of GeoTorch in real code. They also show how to implement Riemannian Gradient Descent and some other tricks. For an introduction to how the library is actually implemented, see the Jupyter Notebook `examples/parametrisations.ipynb`_. 134 | 135 | You may try GeoTorch by installing it with 136 | 137 | .. code:: bash 138 | 139 | pip install git+https://github.com/Lezcano/geotorch/ 140 | 141 | GeoTorch is tested on Linux, macOS, and Windows for Python >= 3.6 and supports PyTorch >= 1.9 142 | 143 | Sharing Weights, Parametrizations, and Normalizing Flows 144 | -------------------------------------------------------- 145 | 146 | If one wants to use a parametrized tensor in different places in their model, or to use one parametrized layer many times, for example in an RNN, it is recommended to wrap the forward pass as follows to avoid computing each parametrization many times: 147 | 148 | .. code:: python 149 | 150 | with geotorch.parametrize.cached(): 151 | logits = model(input_) 152 | 153 | Of course, this ``with`` statement may also be used inside the forward function in which the parametrized layer is used several times. 154 |
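For instance, a minimal sketch of a recurrent model that reuses an orthogonal layer at every time step, computing its parametrization just once per forward pass (the class and sizes are illustrative; `examples/copying_problem.py`_ does the same in full):

.. code:: python

    import torch
    import torch.nn as nn
    import geotorch

    class RNN(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.recurrent = nn.Linear(hidden_size, hidden_size, bias=False)
            geotorch.orthogonal(self.recurrent, "weight")

        def forward(self, xs):
            # xs has shape (seq_len, batch, hidden_size)
            h = xs.new_zeros(xs.size(1), xs.size(2))
            # The orthogonal weight is computed once and cached for the whole loop
            with geotorch.parametrize.cached():
                for x in xs.unbind(dim=0):
                    h = torch.tanh(self.recurrent(h) + x)
            return h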
155 | These ideas fall within the context of parametrized optimization, where one wraps a tensor ``X`` with a function ``f`` and, rather than using ``X``, uses ``f(X)``. Particular examples of this idea are pruning, weight normalization, and spectral normalization, among others. This repository implements a framework to approach this kind of problem. This framework was accepted into core PyTorch 1.8. It can be found under `torch.nn.utils.parametrize`_ and `torch.nn.utils.parametrizations`_. When using PyTorch 1.10 or higher, the native PyTorch functions are used within GeoTorch. In this case, the user can interact with the parametrizations in GeoTorch through the PyTorch functions. 156 | 157 | As every space in GeoTorch is, at its core, a map from a flat space into a manifold, the tools implemented here also serve as a building block in normalizing flows. Using a factorized space such as |low|_, it is straightforward to compute the determinant of the transformation it defines, as we have direct access to the singular values of the layer. 158 | 159 | .. |reals| replace:: ``Rn(n)`` 160 | .. _reals: https://geotorch.readthedocs.io/en/latest/vector_spaces/reals.html 161 | .. |sym| replace:: ``Sym(n)`` 162 | .. _sym: https://geotorch.readthedocs.io/en/latest/vector_spaces/symmetric.html 163 | .. |skew| replace:: ``Skew(n)`` 164 | .. _skew: https://geotorch.readthedocs.io/en/latest/vector_spaces/skew.html 165 | .. |sphere| replace:: ``Sphere(n)`` 166 | .. _sphere: https://geotorch.readthedocs.io/en/latest/orthogonal/sphere.html 167 | .. |so| replace:: ``SO(n)`` 168 | .. _so: https://geotorch.readthedocs.io/en/latest/orthogonal/so.html 169 | .. |st| replace:: ``St(n,k)`` 170 | .. _st: https://geotorch.readthedocs.io/en/latest/orthogonal/stiefel.html 171 | .. |almost| replace:: ``AlmostOrthogonal(n,k,λ)`` 172 | .. _almost: https://geotorch.readthedocs.io/en/latest/orthogonal/almostorthogonal.html 173 | .. |grass| replace:: ``Gr(n,k)`` 174 | .. _grass: https://geotorch.readthedocs.io/en/latest/orthogonal/grassmannian.html 175 | .. |glp| replace:: ``GLp(n)`` 176 | .. _glp: https://geotorch.readthedocs.io/en/latest/invertibility/glp.html 177 | .. |sl| replace:: ``SL(n)`` 178 | .. _sl: https://geotorch.readthedocs.io/en/latest/invertibility/sl.html 179 | .. |low| replace:: ``LowRank(n,k,r)`` 180 | .. _low: https://geotorch.readthedocs.io/en/latest/lowrank/lowrank.html 181 | .. |fixed| replace:: ``FixedRank(n,k,r)`` 182 | .. _fixed: https://geotorch.readthedocs.io/en/latest/lowrank/fixedrank.html 183 | .. |psd| replace:: ``PSD(n)`` 184 | .. _psd: https://geotorch.readthedocs.io/en/latest/psd/psd.html 185 | .. |pssd| replace:: ``PSSD(n)`` 186 | .. _pssd: https://geotorch.readthedocs.io/en/latest/psd/pssd.html 187 | .. |pssdlow| replace:: ``PSSDLowRank(n,r)`` 188 | .. _pssdlow: https://geotorch.readthedocs.io/en/latest/psd/pssdlowrank.html 189 | .. |pssdfixed| replace:: ``PSSDFixedRank(n,r)`` 190 | .. _pssdfixed: https://geotorch.readthedocs.io/en/latest/psd/pssdfixedrank.html 191 | .. |product| replace:: ``ProductManifold(M₁, ..., Mₖ)`` 192 | .. _product: https://geotorch.readthedocs.io/en/latest/product.html 193 | 194 | 195 | Bibliography 196 | ------------ 197 | 198 | Please cite the following work if you found GeoTorch useful. This paper presents a simplified mathematical explanation of part of the inner workings of GeoTorch. 199 | 200 | .. code:: bibtex 201 | 202 | @inproceedings{lezcano2019trivializations, 203 | title = {Trivializations for gradient-based optimization on manifolds}, 204 | author = {Lezcano-Casado, Mario}, 205 | booktitle = {Advances in Neural Information Processing Systems, NeurIPS}, 206 | pages = {9154--9164}, 207 | year = {2019}, 208 | } 209 | 210 | 211 | .. |Build| image:: https://github.com/lezcano/geotorch/workflows/Build/badge.svg 212 | :target: https://github.com/lezcano/geotorch/workflows/Build/badge.svg 213 | :alt: Build 214 | .. |Docs| image:: https://readthedocs.org/projects/geotorch/badge/?version=latest 215 | :target: https://geotorch.readthedocs.io/en/latest/?badge=latest 216 | .. |Codecov| image:: https://codecov.io/gh/Lezcano/geotorch/branch/master/graph/badge.svg?token=1AKM2EQ7RT 217 | :target: https://codecov.io/gh/Lezcano/geotorch/branch/master/graph/badge.svg?token=1AKM2EQ7RT 218 | :alt: Code coverage 219 | .. 
|Codestyle Black| image:: https://img.shields.io/badge/code%20style-black-000000.svg 220 | :target: https://github.com/ambv/black 221 | :alt: Codestyle Black 222 | .. |License| image:: https://img.shields.io/badge/license-MIT-green.svg 223 | :target: https://github.com/Lezcano/geotorch/blob/master/LICENSE 224 | :alt: License 225 | 226 | .. _here: https://github.com/Lezcano/geotorch/blob/master/examples/copying_problem.py#L16 227 | .. _torch.nn.utils.parametrize: https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrize.register_parametrization.html 228 | .. _torch.nn.utils.parametrizations: https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrizations.orthogonal.html 229 | .. _geotorch/parametrize.py: https://github.com/Lezcano/geotorch/blob/master/geotorch/parametrize.py 230 | .. _examples/sequential_mnist.py: https://github.com/Lezcano/geotorch/blob/master/examples/sequential_mnist.py 231 | .. _examples/copying_problem.py: https://github.com/Lezcano/geotorch/blob/master/examples/copying_problem.py 232 | .. _examples/parametrisations.ipynb: https://github.com/Lezcano/geotorch/blob/master/examples/parametrisations.ipynb 233 | 234 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the documentation of GeoTorch. 2 | 3 | To build the documentation, run 4 | 5 | ``` 6 | make html 7 | ``` 8 | 9 | To generate the docs from the GeoTorch source, excluding the `parametrize.py` file, run 10 | 11 | ``` 12 | SPHINX_APIDOC_OPTIONS=members sphinx-apidoc -o ./source ../geotorch ../geotorch/parametrize.py 13 | ``` 14 | 15 | To check the spelling, run 16 | ``` 17 | sphinx-build -b spelling docs/source docs/build 18 | ``` 19 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | import geotorch 17 | 18 | sys.path.insert(0, os.path.abspath("../../geotorch/")) 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "geotorch" 24 | copyright = "2020-, Mario Lezcano-Casado" 25 | author = "Mario Lezcano-Casado" 26 | 27 | # The short X.Y version. 28 | version = geotorch.__version__ 29 | # The full version, including alpha/beta/rc tags 30 | release = geotorch.__version__ 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # Add any Sphinx extension module names here, as strings. They can be 36 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 37 | # ones. 38 | extensions = [ 39 | "sphinx.ext.autodoc", 40 | "sphinx.ext.coverage", 41 | "sphinx.ext.mathjax", 42 | "sphinx.ext.viewcode", 43 | "sphinx.ext.autosectionlabel", 44 | "sphinx.ext.napoleon", 45 | "sphinxcontrib.spelling", 46 | ] 47 | 48 | # Spelling options 49 | spelling_word_list_filename = "spelling_wordlist.txt" 50 | 51 | # Add any paths that contain templates here, relative to this directory. 52 | templates_path = ["_templates"] 53 | 54 | # List of patterns, relative to source directory, that match files and 55 | # directories to ignore when looking for source files. 56 | # This pattern also affects html_static_path and html_extra_path. 57 | exclude_patterns = [] 58 | 59 | # Add __init__ documentation to classes 60 | autoclass_content = "both" 61 | 62 | # -- Options for HTML output ------------------------------------------------- 63 | 64 | # The theme to use for HTML and HTML Help pages. See the documentation for 65 | # a list of builtin themes. 66 | # 67 | html_theme = "sphinx_rtd_theme" 68 | 69 | # Add any paths that contain custom static files (such as style sheets) here, 70 | # relative to this directory. They are copied after the builtin static files, 71 | # so a file named "default.css" will overwrite the builtin "default.css". 
72 | html_static_path = [] 73 | 74 | html_css_files = [] 75 | 76 | html_theme_options = { 77 | "collapse_navigation": False, 78 | } 79 | 80 | 81 | # mathjax_config = { 82 | # 'jax': ['output/CommonHTML'], 83 | # } 84 | 85 | # Don't show the source of each page 86 | html_show_sourcelink = False 87 | -------------------------------------------------------------------------------- /docs/source/constraints.rst: -------------------------------------------------------------------------------- 1 | Constraints API 2 | =============== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | These are the functions that form the basic interface of `GeoTorch`. They all provide a common 7 | interface to the spaces implemented in `GeoTorch`. For a finer control over the underlying 8 | implementation of each space, take a look at the different manifolds. 9 | 10 | .. autofunction:: symmetric 11 | .. autofunction:: skew 12 | .. autofunction:: sphere 13 | .. autofunction:: orthogonal 14 | .. autofunction:: almost_orthogonal 15 | .. autofunction:: grassmannian 16 | .. autofunction:: low_rank 17 | .. autofunction:: fixed_rank 18 | .. autofunction:: invertible 19 | .. autofunction:: sln 20 | .. autofunction:: positive_definite 21 | .. autofunction:: positive_semidefinite 22 | .. autofunction:: positive_semidefinite_low_rank 23 | .. autofunction:: positive_semidefinite_fixed_rank 24 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to GeoTorch's documentation! 2 | ==================================== 3 | 4 | .. toctree:: 5 | :caption: GeoTorch 6 | :hidden: 7 | 8 | Introduction 9 | 10 | 11 | .. include:: ../../README.rst 12 | :start-line: 2 13 | 14 | 15 | .. toctree:: 16 | :maxdepth: 2 17 | :caption: Constrained Optimization 18 | :hidden: 19 | 20 | constraints 21 | 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Manifolds 26 | :hidden: 27 | 28 | vector_spaces/index 29 | orthogonal/index 30 | invertibility/index 31 | lowrank/index 32 | psd/index 33 | product 34 | 35 | .. Indices and tables 36 | .. ================== 37 | .. 38 | .. * :ref:`genindex` 39 | .. * :ref:`modindex` 40 | .. * :ref:`search` 41 | -------------------------------------------------------------------------------- /docs/source/invertibility/glp.rst: -------------------------------------------------------------------------------- 1 | General Linear Group 2 | ==================== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{GL^+}(n)` is the manifold of invertible matrices of positive determinant 7 | 8 | .. math:: 9 | 10 | \operatorname{GL^+}(n) = \{X \in \mathbb{R}^{n\times n}\:\mid\:\det(X) > 0\} 11 | 12 | It is realized via an SVD-like factorization: 13 | 14 | .. math:: 15 | 16 | \begin{align*} 17 | \pi \colon \operatorname{SO}(n) \times \mathbb{R}^n \times \operatorname{SO}(n) 18 | &\to \operatorname{GL^+}(n) \\ 19 | (U, \Sigma, V) &\mapsto Uf(\Sigma)V^\intercal 20 | \end{align*} 21 | 22 | where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (0, \infty)` is applied element-wise to the diagonal. By default, the `softplus` function is used 23 | 24 | .. 
math:: 25 | 26 | \begin{align*} 27 | \operatorname{softplus} \colon \mathbb{R} &\to (0, \infty) \\ 28 | x &\mapsto \log(1+\exp(x)) + \varepsilon 29 | \end{align*} 30 | 31 | where we use a small :math:`\varepsilon > 0` for numerical stability. 32 | 33 | .. autoclass:: GLp 34 | 35 | .. automethod:: sample 36 | .. automethod:: in_manifold 37 | -------------------------------------------------------------------------------- /docs/source/invertibility/index.rst: -------------------------------------------------------------------------------- 1 | Invertibility Constraints 2 | ------------------------- 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | glp 8 | -------------------------------------------------------------------------------- /docs/source/invertibility/sl.rst: -------------------------------------------------------------------------------- 1 | Special Linear Group 2 | ==================== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{SL}(n)` is the manifold of matrices of determinant equal to 1. 7 | 8 | .. math:: 9 | 10 | \operatorname{SL}(n) = \{X \in \mathbb{R}^{n\times n}\:\mid\:\det(X) = 1\} 11 | 12 | It is realized via an SVD-like factorization: 13 | 14 | .. math:: 15 | 16 | \begin{align*} 17 | \pi \colon \operatorname{SO}(n) \times \mathbb{R}^n \times \operatorname{SO}(n) 18 | &\to \operatorname{SL}(n) \\ 19 | (U, \Sigma, V) &\mapsto Uf(\Sigma)V^\intercal 20 | \end{align*} 21 | 22 | where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (\varepsilon, \infty)` is applied element-wise to the diagonal for a small :math:`\varepsilon > 0`. By default, the `softplus` function 23 | 24 | .. math:: 25 | 26 | \begin{align*} 27 | \operatorname{softplus} \colon \mathbb{R} &\to (\varepsilon, \infty) \\ 28 | x &\mapsto \log(1+\exp(x)) + \varepsilon 29 | \end{align*} 30 | 31 | composed with the normalization function 32 | 33 | .. math:: 34 | 35 | \begin{align*} 36 | \operatorname{g} \colon \mathbb{R}^n &\to (\varepsilon, \infty)^n \\ 37 | (x_1, \dots, x_n) &\mapsto \left(\frac{x_i}{\sqrt[\leftroot{-2}\uproot{2}n]{\prod_i x_i}}\right)_i 38 | \end{align*} 39 | 40 | is used to ensure that the product of all the singular values is equal to 1.
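As a sketch, this normalization may be written as follows, assuming a vector of already-positive singular values:

.. code:: python

    import torch

    def normalize_det(sigma):
        # Divide by the geometric mean so that the product of the singular
        # values, and hence the determinant, is exactly one
        return sigma / sigma.prod().pow(1.0 / sigma.numel())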
41 | 42 | .. autoclass:: SL 43 | 44 | .. automethod:: sample 45 | .. automethod:: in_manifold 46 | -------------------------------------------------------------------------------- /docs/source/lowrank/fixedrank.rst: -------------------------------------------------------------------------------- 1 | Fixed Rank Matrices 2 | =================== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{FixedRank}(n,k,r)` is the manifold of matrices of rank equal 7 | to :math:`r`, for a given :math:`r \leq \min\{n, k\}`: 8 | 9 | .. math:: 10 | 11 | \operatorname{FixedRank}(n,k,r) = \{X \in \mathbb{R}^{n\times k}\:\mid\:\operatorname{rank}(X) = r\} 12 | 13 | It is realized via an SVD-like factorization: 14 | 15 | .. math:: 16 | 17 | \begin{align*} 18 | \pi \colon \operatorname{St}(n,r) \times \mathbb{R}^r \times \operatorname{St}(k, r) 19 | &\to \operatorname{FixedRank}(n,k,r) \\ 20 | (U, \Sigma, V) &\mapsto Uf(\Sigma)V^\intercal 21 | \end{align*} 22 | 23 | where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{r \times r}`. The function :math:`f\colon \mathbb{R} \to (0, \infty)` is applied element-wise to the diagonal. By default, the `softplus` function is used 24 | 25 | .. math:: 26 | 27 | \begin{align*} 28 | \operatorname{softplus} \colon \mathbb{R} &\to (0, \infty) \\ 29 | x &\mapsto \log(1+\exp(x)) + \varepsilon 30 | \end{align*} 31 | 32 | where we use a small :math:`\varepsilon > 0` for numerical stability. 33 | 34 | .. note:: 35 | 36 | For practical applications, it will almost always be more convenient to use the class :class:`LowRank`, as it is less restrictive, and most of the time it will converge to a max-rank solution anyway. 37 | 38 | .. autoclass:: FixedRank 39 | 40 | .. automethod:: sample 41 | .. automethod:: in_manifold 42 | -------------------------------------------------------------------------------- /docs/source/lowrank/index.rst: -------------------------------------------------------------------------------- 1 | Low Rank Constraints 2 | -------------------- 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | lowrank 8 | fixedrank 9 | -------------------------------------------------------------------------------- /docs/source/lowrank/lowrank.rst: -------------------------------------------------------------------------------- 1 | Low Rank Matrices 2 | ================= 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{LowRank}(n,k,r)` is the algebraic variety of matrices of rank less than or equal 7 | to :math:`r`, for a given :math:`r \leq \min\{n, k\}`: 8 | 9 | .. math:: 10 | 11 | \operatorname{LowRank}(n,k,r) = \{X \in \mathbb{R}^{n\times k}\:\mid\:\operatorname{rank}(X) \leq r\} 12 | 13 | It is realized via an SVD-like factorization: 14 | 15 | .. math:: 16 | 17 | \begin{align*} 18 | \pi \colon \operatorname{St}(n,r) \times \mathbb{R}^r \times \operatorname{St}(k, r) 19 | &\to \operatorname{LowRank}(n,k,r) \\ 20 | (U, \Sigma, V) &\mapsto U\Sigma V^\intercal 21 | \end{align*} 22 | 23 | where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{r \times r}`. 24 | 25 | .. autoclass:: LowRank 26 | 27 | .. automethod:: sample 28 | .. automethod:: in_manifold 29 | -------------------------------------------------------------------------------- /docs/source/orthogonal/almostorthogonal.rst: -------------------------------------------------------------------------------- 1 | Almost Orthogonal Matrices 2 | ========================== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{AlmostOrthogonal}(n,k,\lambda)` is the manifold of matrices with singular values in the interval :math:`(1-\lambda, 1+\lambda)` for a :math:`\lambda \in [0,1]`. 7 | 8 | .. math:: 9 | 10 | \operatorname{AlmostOrthogonal}(n,k,\lambda) = \{X \in \mathbb{R}^{n\times k}\:\mid\:\left|1-\sigma_i(X)\right| < \lambda,\ i=1, \dots, k\} 11 | 12 | It is realized via an SVD-like factorization: 13 | 14 | .. math:: 15 | 16 | \begin{align*} 17 | \pi \colon \operatorname{St}(n,k) \times \mathbb{R}^k \times \operatorname{SO}(k) 18 | &\to \operatorname{AlmostOrthogonal}(n,k,\lambda) \\ 19 | (U, \Sigma, V) &\mapsto Uf_\lambda(\Sigma) V^\intercal 20 | \end{align*} 21 | 22 | where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{k \times k}`. 23 | The function :math:`f_\lambda\colon \mathbb{R} \to (1-\lambda, 1+\lambda)` takes a function :math:`f\colon \mathbb{R} \to (-1, +1)` and rescales it to be a function on :math:`(1-\lambda, 1+\lambda)` as 24 | 25 | .. math:: 26 | 27 | f_\lambda(x) = 1+\lambda f(x). 28 | 29 | The function :math:`f_\lambda` is then applied element-wise to the diagonal of :math:`\Sigma`. 
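As a sketch, taking :math:`f = \sin` as one possible choice of function with image in :math:`[-1, 1]`:

.. code:: python

    import torch

    def f_lambda(x, lam=0.5, f=torch.sin):
        # Rescale a function with image in [-1, 1] so that its values
        # lie in [1 - lam, 1 + lam]
        return 1.0 + lam * f(x)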
30 | 31 | If :math:`\lambda = 1` is chosen, the resulting space is not a manifold, although this should not hurt optimization in practice. 32 | 33 | .. warning:: 34 | 35 | In the limit :math:`\lambda = 0`, the resulting manifold is exactly :ref:`sec-so`. For this reason, we discourage the use of small values of :math:`\lambda`, as the algorithm in this class becomes numerically unstable for very small :math:`\lambda`. We recommend using :class:`geotorch.SO` rather than this one in this scenario. 36 | 37 | .. note:: 38 | 39 | There are no restrictions in place for the image of the function :math:`f`. For a function :math:`f` with image :math:`[a,b]`, the function :math:`f_\lambda` will take values in :math:`[1+\lambda a, 1+\lambda b]`. As such, by rescaling the function :math:`f`, one may use this class to perform optimization with singular values constrained to any prescribed interval of :math:`\mathbb{R}_{\geq 0}`. 40 | 41 | 42 | .. autoclass:: AlmostOrthogonal 43 | 44 | .. automethod:: sample 45 | .. automethod:: in_manifold 46 | -------------------------------------------------------------------------------- /docs/source/orthogonal/grassmannian.rst: -------------------------------------------------------------------------------- 1 | Grassmannian Manifold 2 | ===================== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{Gr}(n,k)` is the Grassmannian manifold, that is, the subspaces of dimension 7 | :math:`k` in :math:`\mathbb{R}^n`. A subspace is represented by an element of the :ref:`Stiefel Manifold `, whose columns form a basis spanning the subspace. 8 | 9 | 10 | .. autoclass:: Grassmannian 11 | 12 | .. automethod:: sample 13 | .. automethod:: in_manifold 14 | -------------------------------------------------------------------------------- /docs/source/orthogonal/index.rst: -------------------------------------------------------------------------------- 1 | Orthogonal Constraints 2 | ---------------------- 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | sphere 8 | so 9 | stiefel 10 | almostorthogonal 11 | grassmannian 12 | -------------------------------------------------------------------------------- /docs/source/orthogonal/so.rst: -------------------------------------------------------------------------------- 1 | .. _sec-so: 2 | 3 | Special Orthogonal Group 4 | ======================== 5 | 6 | .. currentmodule:: geotorch 7 | 8 | :math:`\operatorname{SO}(n)` is the special orthogonal group, that is, the square matrices with 9 | orthonormal columns and positive determinant: 10 | 11 | .. math:: 12 | 13 | \operatorname{SO}(n) = \{X \in \mathbb{R}^{n\times n}\:\mid\:X^\intercal X = \mathrm{I}_n,\,\det(X) = 1\} 14 | 15 | 16 | .. autoclass:: SO 17 | 18 | .. automethod:: sample 19 | .. automethod:: in_manifold 20 | 21 | .. autofunction:: geotorch.so.uniform_init_ 22 | .. autofunction:: geotorch.so.torus_init_ 23 | -------------------------------------------------------------------------------- /docs/source/orthogonal/sphere.rst: -------------------------------------------------------------------------------- 1 | Sphere 2 | ====== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{Sphere}(n, r)` is the sphere in :math:`\mathbb{R}^n` 7 | with radius :math:`r`: 8 | 9 | .. math:: 10 | 11 | \operatorname{Sphere}(n,r) = \{x \in \mathbb{R}^n\:\mid\:\lVert x \rVert = r\} 12 | 13 | .. warning:: 14 | 15 | In mathematics, :math:`\mathbb{S}^n` represents the :math:`n`-dimensional sphere. 16 | With this notation, :math:`\operatorname{Sphere}(n, 1.) 
= \mathbb{S}^{n-1}`. 17 | 18 | .. autoclass:: Sphere 19 | 20 | .. automethod:: sample 21 | .. automethod:: in_manifold 22 | 23 | .. autoclass:: SphereEmbedded 24 | 25 | .. automethod:: sample 26 | .. automethod:: in_manifold 27 | 28 | .. autofunction:: geotorch.sphere.uniform_init_sphere_ 29 | -------------------------------------------------------------------------------- /docs/source/orthogonal/stiefel.rst: -------------------------------------------------------------------------------- 1 | .. _RST Stiefel: 2 | 3 | Stiefel Manifold 4 | ================ 5 | 6 | .. currentmodule:: geotorch 7 | 8 | :math:`\operatorname{St}(n,k)` is the Stiefel manifold, that is, the rectangular matrices with 9 | orthonormal columns for :math:`n \geq k`: 10 | 11 | .. math:: 12 | 13 | \operatorname{St}(n,k) = \{X \in \mathbb{R}^{n\times k}\:\mid\:X^\intercal X = \mathrm{I}_k\} 14 | 15 | If :math:`n < k`, then we consider the space of matrices with orthonormal rows, that is, 16 | :math:`X^\intercal \in \operatorname{St}(k,n)`. 17 | 18 | 19 | 20 | .. autoclass:: Stiefel 21 | 22 | .. automethod:: sample 23 | .. automethod:: in_manifold 24 | -------------------------------------------------------------------------------- /docs/source/product.rst: -------------------------------------------------------------------------------- 1 | Product Manifold 2 | ================ 3 | 4 | .. currentmodule:: geotorch.product 5 | 6 | 7 | This class implements a product of manifolds :math:`M_1\times \cdots \times M_k`. The individual manifolds are 8 | stored as a list. 9 | 10 | .. autoclass:: ProductManifold 11 | -------------------------------------------------------------------------------- /docs/source/psd/index.rst: -------------------------------------------------------------------------------- 1 | Positive Definiteness Constraints 2 | --------------------------------- 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | psd 8 | pssd 9 | pssdlowrank 10 | pssdfixedrank 11 | -------------------------------------------------------------------------------- /docs/source/psd/psd.rst: -------------------------------------------------------------------------------- 1 | Positive Definite Matrices 2 | ========================== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{PSD}(n)` is the manifold of positive definite matrices. 7 | 8 | .. math:: 9 | 10 | \operatorname{PSD}(n) = \{X \in \mathbb{R}^{n\times n}\:\mid\:X \succ 0\}. 11 | 12 | It is realized via an eigenvalue-like factorization: 13 | 14 | .. math:: 15 | 16 | \begin{align*} 17 | \pi \colon \operatorname{SO}(n) \times \mathbb{R}^n 18 | &\to \operatorname{PSD}(n) \\ 19 | (Q, \Lambda) &\mapsto Qf(\Lambda)Q^\intercal 20 | \end{align*} 21 | 22 | where we have identified the vector :math:`\Lambda` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (0, \infty)` is applied element-wise to the diagonal. By default, the `softplus` function is used 23 | 24 | .. math:: 25 | 26 | \begin{align*} 27 | \operatorname{softplus} \colon \mathbb{R} &\to (0, \infty) \\ 28 | x &\mapsto \log(1+\exp(x)) + \varepsilon 29 | \end{align*} 30 | 31 | where we use a small :math:`\varepsilon > 0` for numerical stability. 32 |
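As a sketch of this factorization, where ``Q`` is an orthogonal matrix and ``lam`` an unconstrained vector of eigenvalue parameters:

.. code:: python

    import torch
    import torch.nn.functional as F

    def psd_from_factors(Q, lam, eps=1e-6):
        # softplus keeps every eigenvalue strictly positive, so that
        # Q f(Λ) Qᵀ is symmetric positive definite
        f = F.softplus(lam) + eps
        return Q @ torch.diag(f) @ Q.T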
33 | .. note:: 34 | 35 | For practical applications, it is more convenient to use the class :class:`geotorch.PSSD`, unless the positive definiteness condition is essential. This is because :class:`geotorch.PSSD` is less restrictive, and most of the time it will converge to a max-rank solution anyway, although at some points of the optimization the matrix may become almost singular. 36 | 37 | .. autoclass:: PSD 38 | 39 | .. automethod:: sample 40 | .. automethod:: in_manifold 41 | -------------------------------------------------------------------------------- /docs/source/psd/pssd.rst: -------------------------------------------------------------------------------- 1 | Positive Semidefinite Matrices 2 | ============================== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{PSSD}(n)` is the algebraic variety of positive semidefinite matrices. 7 | 8 | .. math:: 9 | 10 | \operatorname{PSSD}(n) = \{X \in \mathbb{R}^{n\times n}\:\mid\:X \succeq 0\} 11 | 12 | It is realized via an eigenvalue-like factorization: 13 | 14 | .. math:: 15 | 16 | \begin{align*} 17 | \pi \colon \operatorname{SO}(n) \times \mathbb{R}^n 18 | &\to \operatorname{PSSD}(n) \\ 19 | (Q, \Lambda) &\mapsto Q\left|\Lambda\right|Q^\intercal 20 | \end{align*} 21 | 22 | where we have identified the vector :math:`\Lambda` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}` and :math:`\left|\Lambda\right|` denotes the absolute value of the diagonal entries. 23 | 24 | .. autoclass:: PSSD 25 | 26 | .. automethod:: sample 27 | .. automethod:: in_manifold 28 | -------------------------------------------------------------------------------- /docs/source/psd/pssdfixedrank.rst: -------------------------------------------------------------------------------- 1 | Positive Semidefinite Fixed Rank Matrices 2 | ========================================= 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{PSSDFixedRank}(n,r)` is the manifold of positive semidefinite matrices with rank equal 7 | to :math:`r`, for a given :math:`r \leq n`: 8 | 9 | .. math:: 10 | 11 | \operatorname{PSSDFixedRank}(n,r) = \{X \in \mathbb{R}^{n\times n}\:\mid\:X \succeq 0,\,\operatorname{rank}(X) = r\} 12 | 13 | It is realized via an eigenvalue-like factorization: 14 | 15 | .. math:: 16 | 17 | \begin{align*} 18 | \pi \colon \operatorname{St}(n,r) \times \mathbb{R}^r 19 | &\to \operatorname{PSSDFixedRank}(n,r) \\ 20 | (Q, \Lambda) &\mapsto Qf(\Lambda)Q^\intercal 21 | \end{align*} 22 | 23 | where we have identified the vector :math:`\Lambda` with a diagonal matrix in :math:`\mathbb{R}^{r \times r}`. The function :math:`f\colon \mathbb{R} \to (0, \infty)` is applied element-wise to the diagonal. By default, the `softplus` function is used 24 | 25 | .. math:: 26 | 27 | \begin{align*} 28 | \operatorname{softplus} \colon \mathbb{R} &\to (0, \infty) \\ 29 | x &\mapsto \log(1+\exp(x)) + \varepsilon 30 | \end{align*} 31 | 32 | where we use a small :math:`\varepsilon > 0` for numerical stability. 33 | 34 | .. note:: 35 | 36 | For practical applications, it will almost always be more convenient to use the class :class:`geotorch.PSSDLowRank`, as it is less restrictive, and most of the time it will converge to a max-rank solution anyway. 37 | 38 | .. autoclass:: PSSDFixedRank 39 | 40 | .. automethod:: sample 41 | .. automethod:: in_manifold 42 | -------------------------------------------------------------------------------- /docs/source/psd/pssdlowrank.rst: -------------------------------------------------------------------------------- 1 | Positive Semidefinite Low Rank Matrices 2 | ======================================== 3 | 4 | .. 
currentmodule:: geotorch 5 | 6 | :math:`\operatorname{PSSDLowRank}(n,r)` is the algebraic variety of positive semidefinite matrices 7 | of rank less than or equal to :math:`r`, for a given :math:`r \leq n`: 8 | 9 | .. math:: 10 | 11 | \operatorname{PSSDLowRank}(n,r) = \{X \in \mathbb{R}^{n\times n}\:\mid\:X \succeq 0,\,\operatorname{rank}(X) \leq r\} 12 | 13 | It is realized via an eigenvalue-like factorization: 14 | 15 | .. math:: 16 | 17 | \begin{align*} 18 | \pi \colon \operatorname{St}(n,r) \times \mathbb{R}^r 19 | &\to \operatorname{PSSDLowRank}(n,r) \\ 20 | (Q, \Lambda) &\mapsto Q\left|\Lambda\right| Q^\intercal 21 | \end{align*} 22 | 23 | where we have identified the vector :math:`\Lambda` with a diagonal matrix in :math:`\mathbb{R}^{r \times r}` and :math:`\left|\Lambda\right|` denotes the absolute value of the diagonal entries. 24 | 25 | .. autoclass:: PSSDLowRank 26 | 27 | .. automethod:: sample 28 | .. automethod:: in_manifold 29 | -------------------------------------------------------------------------------- /docs/source/spelling_wordlist.txt: -------------------------------------------------------------------------------- 1 | surjective 2 | surjectively 3 | semidefinite 4 | eigen 5 | Stiefel 6 | Grassmannian 7 | SVD 8 | overparametrized 9 | invertibility 10 | Riemannian 11 | reimplement 12 | cayley 13 | baseclasses 14 | functors 15 | parametrizing 16 | Parametrizations 17 | Rⁿ 18 | ᵏ 19 | Sⁿ 20 | dπ 21 | π 22 | Mₖ 23 | callables 24 | -------------------------------------------------------------------------------- /docs/source/vector_spaces/index.rst: -------------------------------------------------------------------------------- 1 | Vector Spaces 2 | ------------- 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | reals 8 | symmetric 9 | skew 10 | -------------------------------------------------------------------------------- /docs/source/vector_spaces/reals.rst: -------------------------------------------------------------------------------- 1 | Real Vector Space 2 | ================= 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\mathbb{R}^n` is the vector space of unconstrained vectors. 7 | This class is useful when composed with other manifolds to 8 | form more interesting spaces using the classes from :ref:`Product Manifold`. 9 | 10 | .. autoclass:: Rn 11 | 12 | .. automethod:: in_manifold 13 | -------------------------------------------------------------------------------- /docs/source/vector_spaces/skew.rst: -------------------------------------------------------------------------------- 1 | Skew-symmetric Matrices 2 | ======================= 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{Skew}(n)` is the vector space of skew-symmetric matrices: 7 | 8 | .. math:: 9 | 10 | \operatorname{Skew}(n) = \{X \in \mathbb{R}^{n \times n}\:\mid\: X^\intercal = -X \} 11 | 12 | .. autoclass:: Skew 13 | 14 | .. automethod:: in_manifold 15 | -------------------------------------------------------------------------------- /docs/source/vector_spaces/symmetric.rst: -------------------------------------------------------------------------------- 1 | Symmetric Matrices 2 | ================== 3 | 4 | .. currentmodule:: geotorch 5 | 6 | :math:`\operatorname{Sym}(n)` is the vector space of symmetric matrices: 7 | 8 | .. math:: 9 | 10 | \operatorname{Sym}(n) = \{X \in \mathbb{R}^{n \times n}\:\mid\: X^\intercal = X \} 11 | 12 | .. autoclass:: Symmetric 13 | 14 | .. 
automethod:: in_manifold 15 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezcano/geotorch/ba38d406c245d609fee4b4dac3f6427bf6d73a8e/examples/__init__.py -------------------------------------------------------------------------------- /examples/copying_problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic example of usage of GeoTorch 3 | 4 | Implements a constrained RNN to learn a synthetic memorization problem which asks the model to recall 5 | some inputs and output them later. From a dictionary of 9 numbers, the input-output pairs look as follows: 6 | 7 | Input: 14221----------:---- 8 | Output: ---------------14221 9 | 10 | This model should converge to 0% error. Once at 0% error, some instabilities may occasionally appear. 11 | 12 | The GeoTorch-specific code is in `ExpRNNCell.__init__`, `ExpRNNCell.reset_parameters`, and line 107. 13 | The rest of the code is normal PyTorch. 14 | Lines 146-167 show how to assign different learning rates to parametrized weights. 15 | 16 | This file also implements Riemannian Gradient Descent (RGD) in lines 152 and 180. As shown, RGD 17 | amounts to using SGD as the optimizer and calling `update_base()` after every optimization step. 18 | """ 19 | 20 | import torch 21 | from torch import nn 22 | import torch.nn.functional as F 23 | 24 | import geotorch 25 | 26 | batch_size = 128 27 | hidden_size = 190 28 | iterations = 4000  # Training iterations 29 | L = 1000  # Length of sequence before asking to remember 30 | S = 10  # Length of sequence to remember 31 | alphabet_size = 8 32 | lr = 1e-3 33 | lr_orth = 2e-4 34 | device = torch.device("cuda") 35 | # When RGD == True we perform Riemannian gradient descent 36 | # This is to demonstrate how one may implement RGD with 37 | # just one extra line of code. 38 | # RGD does not perform very well in these problems though. 39 | RGD = False 40 | if RGD: 41 | print( 42 | "Optimizing using RGD. The performance will be _much_ worse than with Adam or RMSprop." 
43 | ) 44 | 45 | 46 | class modrelu(nn.Module): 47 | def __init__(self, features): 48 | super(modrelu, self).__init__() 49 | self.features = features 50 | self.b = nn.Parameter(torch.Tensor(self.features)) 51 | self.reset_parameters() 52 | 53 | def reset_parameters(self): 54 | self.b.data.uniform_(-0.01, 0.01) 55 | 56 | def forward(self, inputs): 57 | norm = torch.abs(inputs) 58 | biased_norm = norm + self.b 59 | magnitude = nn.functional.relu(biased_norm) 60 | phase = torch.sign(inputs) 61 | 62 | return phase * magnitude 63 | 64 | 65 | class ExpRNNCell(nn.Module): 66 | def __init__(self, input_size, hidden_size): 67 | super(ExpRNNCell, self).__init__() 68 | self.input_size = input_size 69 | self.hidden_size = hidden_size 70 | self.recurrent_kernel = nn.Linear(hidden_size, hidden_size, bias=False) 71 | self.input_kernel = nn.Linear(input_size, hidden_size) 72 | self.nonlinearity = modrelu(hidden_size) 73 | 74 | # Make recurrent_kernel orthogonal 75 | geotorch.orthogonal(self.recurrent_kernel, "weight") 76 | 77 | self.reset_parameters() 78 | 79 | def reset_parameters(self): 80 | nn.init.kaiming_normal_(self.input_kernel.weight.data, nonlinearity="relu") 81 | # The manifold class is under `layer.parametrizations.tensor_name[0]` 82 | M = self.recurrent_kernel.parametrizations.weight[0] 83 | # Every manifold has a convenience sample method, but you can use your own initializer 84 | self.recurrent_kernel.weight = M.sample("torus") 85 | 86 | def default_hidden(self, input_): 87 | return input_.new_zeros(input_.size(0), self.hidden_size, requires_grad=False) 88 | 89 | def forward(self, input_, hidden): 90 | input_ = self.input_kernel(input_) 91 | hidden = self.recurrent_kernel(hidden) 92 | out = input_ + hidden 93 | return self.nonlinearity(out) 94 | 95 | 96 | class Model(nn.Module): 97 | def __init__(self, alphabet_size, hidden_size): 98 | super(Model, self).__init__() 99 | self.hidden_size = hidden_size 100 | self.rnn = ExpRNNCell(alphabet_size + 2, hidden_size) 101 | self.lin = nn.Linear(hidden_size, alphabet_size + 1) 102 | self.loss_func = nn.CrossEntropyLoss() 103 | self.reset_parameters() 104 | 105 | def reset_parameters(self): 106 | nn.init.kaiming_normal_(self.lin.weight.data, nonlinearity="relu") 107 | nn.init.constant_(self.lin.bias.data, 0) 108 | 109 | def forward(self, inputs): 110 | out_rnn = self.rnn.default_hidden(inputs[:, 0, ...]) 111 | outputs = [] 112 | with geotorch.parametrize.cached(): 113 | for input in torch.unbind(inputs, dim=1): 114 | out_rnn = self.rnn(input, out_rnn) 115 | outputs.append(self.lin(out_rnn)) 116 | return torch.stack(outputs, dim=1) 117 | 118 | def loss(self, logits, y): 119 | return self.loss_func(logits.view(-1, 9), y.view(-1)) 120 | 121 | def accuracy(self, logits, y): 122 | return torch.eq(torch.argmax(logits, dim=2), y).float().mean() 123 | 124 | 125 | def copy_data(batch_size): 126 | # Generates some random synthetic data 127 | # Example of input-output sequence 128 | # 14221----------:---- 129 | # ---------------14221 130 | # Numbers go from 1 to 8 131 | # We generate S of them and we have to recall them 132 | # L is the wait between the last number and the 133 | # signal to start outputting the numbers 134 | # We encode `-` as a 0 and `:` as a 9. 
135 | 136 | seq = torch.randint( 137 | 1, alphabet_size + 1, (batch_size, S), dtype=torch.long, device=device 138 | ) 139 | zeros1 = torch.zeros((batch_size, L), dtype=torch.long, device=device) 140 | zeros2 = torch.zeros((batch_size, S - 1), dtype=torch.long, device=device) 141 | zeros3 = torch.zeros((batch_size, S + L), dtype=torch.long, device=device) 142 | marker = torch.full( 143 | (batch_size, 1), alphabet_size + 1, dtype=torch.long, device=device 144 | ) 145 | 146 | x = torch.cat([seq, zeros1, marker, zeros2], dim=1) 147 | y = torch.cat([zeros3, seq], dim=1) 148 | 149 | return x, y 150 | 151 | 152 | def main(): 153 | model = Model(alphabet_size, hidden_size).to(device) 154 | 155 | p_orth = model.rnn.recurrent_kernel 156 | orth_params = p_orth.parameters() 157 | non_orth_params = ( 158 | p for p in model.parameters() if p not in set(p_orth.parameters()) 159 | ) 160 | 161 | if RGD: 162 | # Implement Stochastic Riemannian Gradient Descent via SGD 163 | optim = torch.optim.SGD( 164 | [{"params": non_orth_params}, {"params": orth_params, "lr": lr_orth}], lr=lr 165 | ) 166 | else: 167 | # These recurrent models benefit from slightly larger mixing constants 168 | # on the adaptive term. They also work with beta_2 = 0.999, but they 169 | # give better results with beta_2 \in [0.9, 0.99] 170 | optim = torch.optim.Adam( 171 | [ 172 | {"params": non_orth_params}, 173 | {"params": orth_params, "lr": lr_orth, "betas": (0.9, 0.95)}, 174 | ], 175 | lr=lr, 176 | ) 177 | 178 | model.train() 179 | for step in range(iterations): 180 | batch_x, batch_y = copy_data(batch_size) 181 | x_onehot = F.one_hot(batch_x, num_classes=alphabet_size + 2).float() 182 | logits = model(x_onehot) 183 | loss = model.loss(logits, batch_y) 184 | 185 | optim.zero_grad() 186 | loss.backward() 187 | optim.step() 188 | 189 | if RGD: 190 | # Updating the base after every step and using SGD gives us 191 | # Riemannian Gradient Descent. More on this in Section 5 of 192 | # https://arxiv.org/abs/1909.09501 193 | geotorch.update_base(model.rnn.recurrent_kernel, "weight") 194 | 195 | with torch.no_grad(): 196 | accuracy = model.accuracy(logits, batch_y) 197 | 198 | print("Iter {} Loss: {:.6f}, Accuracy: {:.5f}".format(step, loss, accuracy)) 199 | 200 | # A separate evaluation phase is not really necessary in this model, as we never 201 | # repeat any element of the training batch 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- /examples/eigenvalue.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this program we show how to use GeoTorch to compute the maximum eigenvalue 3 | of a symmetric matrix via the Rayleigh quotient, restricting the optimisation 4 | problem to the Sphere. 5 | """ 6 | import torch 7 | 8 | try: 9 | from torch.linalg import eigvalsh 10 | except ImportError: 11 | from torch import symeig 12 | 13 | def eigvalsh(X): 14 | return symeig(X, eigenvectors=False).eigenvalues 15 | 16 | 17 | from torch import nn 18 | import geotorch 19 | 20 | N = 1000  # matrix size 21 | LR = 1.0 / N  # step-size. 22 | # Obs.
If the distribution of the matrix is changed, this parameter should be tuned 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, n): 27 | super().__init__() 28 | self.x = nn.Parameter(torch.rand(n)) 29 | geotorch.sphere(self, "x") 30 | Sphere = self.parametrizations.x[0] 31 | self.x = Sphere.sample() 32 | 33 | def forward(self, A): 34 | x = self.x 35 | return x.T @ A @ x 36 | 37 | 38 | # Generate matrix 39 | A = torch.rand(N, N)  # Uniform on [0, 1) 40 | A = 0.5 * (A + A.T) 41 | 42 | # Compare against diagonalization (eigenvalues are returned in ascending order) 43 | max_eigenvalue = eigvalsh(A)[-1] 44 | print("Max eigenvalue: {:10.5f}".format(max_eigenvalue)) 45 | 46 | # Instantiate model and optimiser 47 | model = Model(N) 48 | optim = torch.optim.SGD(model.parameters(), lr=LR) 49 | 50 | eigenvalue = float("inf") 51 | i = 0 52 | while (eigenvalue - max_eigenvalue).abs() > 1e-3: 53 | eigenvalue = model(A) 54 | 55 | optim.zero_grad() 56 | (-eigenvalue).backward() 57 | optim.step() 58 | print("{:2}. Best guess: {:10.5f}".format(i, eigenvalue.item())) 59 | i += 1 60 | 61 | print("Final error {:.5f}".format((eigenvalue - max_eigenvalue).abs())) 62 | -------------------------------------------------------------------------------- /examples/parametrisations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Basic imports and some hyperparameters\n", 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "\n", 13 | "# Matrices of size 4 x 4 or 4 x 5\n", 14 | "N = 4\n", 15 | "M = 5\n", 16 | "# Batch size of 3\n", 17 | "B = 3\n", 18 | "x = torch.rand(B, N, N)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "# Parametrisations ([PR #33344](https://github.com/pytorch/pytorch/pull/33344))\n", 26 | "\n", 27 | "This notebook provides an introduction to the design of parametrisations in PyTorch. Parametrisations are the way `geotorch` works behind the scenes, so having some grip on how they work should greatly help in using `geotorch` effectively.\n", 28 | "\n", 29 | "## Motivating Example\n", 30 | "\n", 31 | "Given a function `f` and a `Parameter` `X` which is registered on a module, we would like to be able to use `f(X)` in place of `X`.\n", 32 | "\n", 33 | "This is easier to understand with an example. Suppose that we want to have a linear layer whose matrix is symmetric. We could write:" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "class Symmetric(nn.Module):\n", 43 | " def __init__(self, n_features):\n", 44 | " super().__init__()\n", 45 | " self.weight = nn.Parameter(torch.rand(n_features, n_features))\n", 46 | "\n", 47 | " def forward(self, x):\n", 48 | " A = self.weight.triu()\n", 49 | " A = A + A.T\n", 50 | " #print(A) # A is symmetric\n", 51 | " return x @ A\n", 52 | "layer = Symmetric(N);\n", 53 | "layer(x); # It works as expected" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "This implementation clearly has two components.
A reimplementation of `nn.Linear` and a parametrisation of the symmetric matrices:" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "class SymmetricParametrization(nn.Module):\n", 70 | " def forward(self, X):\n", 71 | " A = X.triu()\n", 72 | " return A + A.T " 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "## Objective\n", 80 | "\n", 81 | "We would like to separate these two, and have a mechanism to be able to inject a parametrisation onto a parameter or a buffer in a neural network. In particular, we would like to be able to do the following:\n", 82 | "\n", 83 | "```python\n", 84 | "layer = nn.Linear(N, N)\n", 85 | "torch.register_parametrization(layer, \"weight\", SymmetricParametrization())\n", 86 | "# layer now behaves as an object from the `Symmetric` class\n", 87 | "print(layer.weight) # Prints the symmetric matrix\n", 88 | "layer(x) # Multiplies the vectors `x` by the symmetric matrix layer.weight \n", 89 | "```" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "## Examples\n", 97 | "\n", 98 | "### Symmetric layers\n", 99 | "\n", 100 | "(see above)\n", 101 | "\n", 102 | "### Pruning\n", 103 | "When doing pruning, one samples a boolean mask of the size of the parameter and does an element-wise multiplication. It turns out that one may train a neural network, make it somewhat sparse in this way, and everything still magically works; this observation is related to the \"lottery ticket hypothesis\". (see `torch.nn.utils.prune`)\n", 104 | "\n", 105 | "A simple pruning method that prunes an entry of the tensor with some given probability could go as:" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "class PruningParametrization(nn.Module):\n", 115 | " def __init__(self, X, p_drop=0.2):\n", " super().__init__()\n", 116 | " # sample zeros with probability p_drop\n", 117 | " mask = torch.full_like(X, 1.0 - p_drop)\n", 118 | " self.mask = torch.bernoulli(mask)\n", 119 | "\n", 120 | " def forward(self, X):\n", 121 | " return X * self.mask" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "We would like to use it as:\n", 129 | "```python\n", 130 | "cnn = nn.Conv2d(8, 16, (3, 3))\n", 131 | "torch.register_parametrization(cnn, \"weight\", PruningParametrization(cnn.weight, p_drop=0.1))\n", 132 | "# 10% of the entries of the tensor cnn.weight have now been zeroed out\n", 133 | "```\n", 134 | "### Other examples:\n", 135 | "- `torch.weight_norm`\n", 136 | "- `torch.spectral_norm` (to regularise the Lipschitz constant of a layer)\n", 137 | "- Optimisation with orthogonal constraints / invertible layers / Symmetric Positive Definite layers...
More on this later\n", 138 | "\n", 139 | "## Implementing `torch.register_parametrization`\n", 140 | "### A first approximation\n", 141 | "A moment's reflection shows that it is possible to implement `Symmetric` without having to reimplement `nn.Linear` by using inheritance and properties.\n", 142 | "\n", 143 | "```python\n", 144 | "class SymmetricRevisited(nn.Linear):\n", 145 | " def __init__(self, n_features):\n", 146 | " super().__init__(n_features, n_features, bias=False)\n", 147 | " # Rename weight attribute to _weight\n", 148 | " self._weight = self.weight\n", 149 | " delattr(self, \"weight\")\n", 150 | " \n", 151 | " @property\n", 152 | " def weight(self):\n", 153 | " A = self._weight.triu()\n", 154 | " return A + A.T\n", 155 | "```\n", 156 | "\n", 157 | "Note: This code does not work! It is possible to make it work using metaclasses (for example), but we will skip that.\n", 158 | "\n", 159 | "### A caching system\n", 160 | "\n", 161 | "Sometimes we use the same layer many times in the forward pass of a neural network (e.g., in the recurrent kernel of an RNN). In those cases, we would not want to recompute `layer.weight` every time we execute it. We would like to compute it at the beginning of the forward pass and cache the result throughout the whole forward pass.\n", 162 | "\n", 163 | "We can achieve that by implementing a caching system as follows:" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "from contextlib import contextmanager\n", 173 | "_cache_enabled = 0\n", 174 | "_cache = {}\n", 175 | "\n", 176 | "@contextmanager\n", 177 | "def cached():\n", 178 | " global _cache\n", 179 | " global _cache_enabled\n", 180 | " _cache_enabled += 1\n", 181 | " try:\n", 182 | " yield\n", 183 | " finally:\n", 184 | " _cache_enabled -= 1\n", 185 | " if not _cache_enabled:\n", 186 | " _cache = {}\n", 187 | "\n", 188 | "class SymmetricCached(nn.Module):\n", 189 | " def __init__(self, n_features):\n", 190 | " super().__init__()\n", 191 | " # Rename weight attribute to _weight\n", 192 | " self._weight = nn.Parameter(torch.rand(n_features, n_features))\n", 193 | " \n", 194 | " def parametrization(self, X):\n", 195 | " print(\"Computing\")\n", 196 | " A = X.triu()\n", 197 | " return A + A.T\n", 198 | "\n", 199 | " @property\n", 200 | " def weight(self):\n", 201 | " global _cache\n", 202 | "\n", " # Outside of a cached() block we always recompute\n", " if not _cache_enabled:\n", " return self.parametrization(self._weight)\n", 203 | " key = (id(self), \"weight\")\n", 204 | " if key not in _cache:\n", 205 | " _cache[key] = self.parametrization(self._weight)\n", 206 | " return _cache[key]\n", 207 | " \n", 208 | " def forward(self, x):\n", 209 | " return x @ self.weight.T\n", 210 | "\n", 211 | "# Usage:\n", 212 | "layer = SymmetricCached(N)\n", 213 | "with cached():\n", 214 | " # Just computes the parametrization once\n", 215 | " print(layer.weight - layer.weight.T)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "\n", 223 | "### A generic implementation\n", 224 | "\n", 225 | "Now, all we need to do is to implement a function that, given a module, a name, and a parametrisation (i.e., another module), injects a property similar to how we did it manually in `SymmetricCached`.
In particular, we have to write a function with signature\n", 226 | "```python\n", 227 | "def register_parametrization(module: Module, tensor_name: str, parametrization: Module) -> None:\n", 228 | "```\n", 229 | "that does the following:\n", 230 | "\n", 231 | "- Renames the tensor from `tensor_name` to `f\"_{tensor_name}\"`\n", 232 | "- Saves `parametrization` within `module` to use it in the forward pass\n", 233 | "- Injects a property with the name `tensor_name` that computes `parametrization(module[tensor_name])` when called\n", 234 | "\n", 235 | "The first two things are direct. To implement the third one, we use the `type` function." 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "def inject_property(module, tensor_name):\n", 245 | " # We create a new class so that we can inject properties in it\n", 246 | " cls_name = \"Parametrized\" + module.__class__.__name__\n", 247 | "\n", 248 | " # Define the getter\n", 249 | " def getter(module):\n", 250 | " global _cache\n", 251 | "\n", 252 | " key = (id(module), tensor_name)\n", 253 | " # If the _cache is not enabled or the caching was not enabled for this\n", 254 | " # tensor, this function just evaluates the parametrization\n", 255 | " if _cache_enabled and key in _cache:\n", 256 | " if _cache[key] is None:\n", 257 | " _cache[key] = module.parametrizations[tensor_name]()\n", 258 | " return _cache[key]\n", 259 | " else:\n", 260 | " return module.parametrizations[tensor_name]()\n", 261 | "\n", 262 | " # Define the setter\n", 263 | " def setter(module, value):\n", 264 | " module.parametrizations[tensor_name].initialize(value)\n", 265 | " \n", 266 | " # Create a new class that inherits from `module.__class__` and has a property called `tensor_name`\n", 267 | " param_cls = type(cls_name, (module.__class__,), {\n", 268 | " tensor_name: property(getter, setter)\n", 269 | " })\n", 270 | " module.__class__ = param_cls\n", 271 | "\n", 272 | "layer = nn.Linear(3, 4)\n", 273 | "inject_property(layer, \"weight\")\n", 274 | "print(type(layer))\n", 275 | "print(type(layer).weight)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "## Other things that `torch.register_parametrization` implements:\n", 283 | "\n", 284 | "- If the module implements a `right_inverse` method (similar to a right-inverse of forward, more on this below), it allows initialising the parametrised buffer/parameter\n", 285 | "- It allows putting several parametrisations on the same buffer/parameter\n", 286 | "- It allows removing the parametrisations and leaving either the original parameter or the parametrised tensor\n", 287 | "- Any combination of the above\n", 288 | "\n", 289 | "## More applications of parametrizations\n", 290 | "\n", 291 | "- Constrained optimisation on manifolds using `geotorch`!\n", 292 | "- Normalising flows. The `right_inverse` method can be implemented as a right-inverse of forward. \n", 293 | " - In the simplest case, if `forward` is a diffeomorphism, then this reduces to the usual normalising flows framework.\n", 294 | " - The general case comes when the forward is a [submersion](https://en.wikipedia.org/wiki/Submersion_(mathematics)) (a function with differentiable local right-inverses). An example of this is a linear layer from `R^n` to `R^k` with `n > k` that is full rank (e.g. a `k x n` matrix with orthogonal rows).
Using a submersion, one may construct a generalisation of normalising flows that allows for dimensionality reduction. The simplest case of this setting comes from projecting a vector in `R^n` onto its first `k` components. This is called in the normalising flows literature the \"multi-scale architecture\", and it was introduced in the model [real NVP](https://arxiv.org/abs/1605.08803).\n", 295 | " \n", 296 | "## Examples of some simple parametrisations, composing them, and initialising them" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "# This part assumes that you have `geotorch` installed. You can install it by running\n", 306 | "# pip install git+https://github.com/Lezcano/geotorch/\n", 307 | "import geotorch.parametrize as P\n", 308 | "\n", 309 | "class Skew(nn.Module):\n", 310 | " def forward(self, X):\n", 311 | " X = X.triu(1)\n", 312 | " return X - X.T\n", 313 | "\n", 314 | " def is_skew(self, X):\n", 315 | " return torch.allclose(X, -X.T)\n", 316 | "\n", 317 | " def right_inverse(self, X):\n", 318 | " if not self.is_skew(X):\n", 319 | " raise ValueError(\"This matrix is not skew-symmetric!\")\n", 320 | " return X.triu(1)\n", 321 | " \n", 322 | "# Skew.forward(Skew.right_inverse(X)) == X\n", 323 | "# In functional notation: Skew.forward o Skew.right_inverse = Id\n", 324 | "# In other words, right_inverse is a right inverse of forward.\n", 325 | "\n", 326 | "model = nn.Linear(5, 5)\n", 327 | "P.register_parametrization(model, \"weight\", Skew())\n", 328 | "# Just computes `model.weight` once\n", 329 | "with P.cached():\n", 330 | " assert(torch.allclose(model.weight, -model.weight.T))\n", 331 | "# Sample a skew matrix X and initialise the parametrised model.weight\n", 332 | "X = torch.rand(5,5)\n", 333 | "X = X - X.T\n", 334 | "model.weight = X\n", 335 | "assert(torch.allclose(model.weight, X))" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "class Orthogonal(nn.Module):\n", 345 | " def __init__(self, n):\n", 346 | " super().__init__()\n", 347 | " self.register_buffer(\"B\", torch.eye(n))\n", 348 | "\n", 349 | " def forward(self, A):\n", 350 | " # Cayley map: (I + A)^{-1}(I - A)\n", 351 | " # This is orthogonal whenever A is skew-symmetric\n", 352 | " Id = torch.eye(A.size(0))\n", 353 | " return self.B @ torch.solve(Id - A, Id + A).solution\n", 354 | "\n", 355 | " def is_orthogonal(self, X):\n", 356 | " Id = torch.eye(X.size(0))\n", 357 | " return torch.allclose(X.T @ X, Id, atol=1e-6)\n", 358 | "\n", 359 | " def right_inverse(self, X):\n", 360 | " if not self.is_orthogonal(X):\n", 361 | " raise ValueError(\"This matrix is not orthogonal!\")\n", 362 | " # cayley(0) == Id, so B @ cayley(0) == B\n", 363 | " self.B = X\n", 364 | " return torch.zeros_like(X)\n", 365 | "\n", 366 | "\n", 367 | "model = nn.Linear(5,5)\n", 368 | "P.register_parametrization(model, \"weight\", Skew())\n", 369 | "P.register_parametrization(model, \"weight\", Orthogonal(5))\n", 370 | "\n", 371 | "# Sample an orthogonal matrix and initialise the layer\n", 372 | "X = torch.empty_like(model.weight)\n", 373 | "nn.init.orthogonal_(X)\n", 374 | "model.weight = X\n", 375 | "\n", 376 | "# model.weight == X\n", 377 | "assert(torch.allclose(model.weight, X))\n", 378 | "\n", 379 | "# A more programmatic way of initialising the weight\n", 380 | "model.weight = nn.init.orthogonal_(model.weight)" 381 | ] 382 | }, 383 | { 384 |
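"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A small sketch of how the two registered parametrisations compose:\n", "# accessing model.weight evaluates Orthogonal(Skew(original)), and assigning\n", "# to model.weight runs the right_inverse methods in the reverse order.\n", "# We can call the registered parametrisations directly to check this:\n", "A = model.parametrizations.weight[0](torch.rand(5, 5)) # Skew: a skew-symmetric matrix\n", "Q = model.parametrizations.weight[1](A) # Orthogonal: the Cayley map of A\n", "assert(torch.allclose(Q.T @ Q, torch.eye(5), atol=1e-5))" ] }, {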
"cell_type": "code", 385 | "execution_count": null, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [] 389 | } 390 | ], 391 | "metadata": { 392 | "kernelspec": { 393 | "display_name": "pytorch", 394 | "language": "python", 395 | "name": "pytorch" 396 | }, 397 | "language_info": { 398 | "codemirror_mode": { 399 | "name": "ipython", 400 | "version": 3 401 | }, 402 | "file_extension": ".py", 403 | "mimetype": "text/x-python", 404 | "name": "python", 405 | "nbconvert_exporter": "python", 406 | "pygments_lexer": "ipython3", 407 | "version": "3.8.5" 408 | } 409 | }, 410 | "nbformat": 4, 411 | "nbformat_minor": 4 412 | } 413 | -------------------------------------------------------------------------------- /examples/sequential_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Slightly more advanced example of usage of GeoTorch 3 | 4 | Implements a constrained RNN to classify MNIST processing the images one pixel at a time 5 | A good result for this task for size 170 would be 98.0% accuracy with orthogonal constraints 6 | and 98.5% for the almostorthogonal. Lowrank is here for demonstration purposes, it should 7 | not perform as well as the other two 8 | 9 | The GeoTorch code happens in `ExpRNNCell.__init__`, `ExpRNNCell.reset_parameters` and line 132. 10 | The rest of the code is normal PyTorch. 11 | Lines 167-176 show how to assign different learning rates to parametrized weights 12 | """ 13 | 14 | import torch 15 | import torch.nn as nn 16 | import math 17 | import argparse 18 | from torchvision import datasets, transforms 19 | 20 | import geotorch 21 | from geotorch.so import torus_init_ 22 | 23 | parser = argparse.ArgumentParser(description="Exponential Layer MNIST Task") 24 | parser.add_argument("--batch_size", type=int, default=128) 25 | parser.add_argument("--hidden_size", type=int, default=170) 26 | parser.add_argument("--epochs", type=int, default=70) 27 | parser.add_argument("--lr", type=float, default=1e-3) 28 | parser.add_argument("--lr_orth", type=float, default=1e-4) 29 | parser.add_argument("--permute", action="store_true") 30 | parser.add_argument( 31 | "--constraints", 32 | choices=["orthogonal", "lowrank", "almostorthogonal"], 33 | default="orthogonal", 34 | type=str, 35 | ) 36 | parser.add_argument( 37 | "--f", 38 | choices=["scaled_sigmoid", "tanh", "sin"], 39 | default="scaled_sigmoid", 40 | type=str, 41 | ) 42 | parser.add_argument("--r", type=float, default=0.1) 43 | 44 | 45 | args = parser.parse_args() 46 | 47 | n_classes = 10 48 | batch_size = args.batch_size 49 | hidden_size = args.hidden_size 50 | epochs = args.epochs 51 | device = torch.device("cuda") 52 | 53 | 54 | class modrelu(nn.Module): 55 | def __init__(self, features): 56 | super(modrelu, self).__init__() 57 | self.features = features 58 | self.b = nn.Parameter(torch.Tensor(self.features)) 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | self.b.data.uniform_(-0.01, 0.01) 63 | 64 | def forward(self, inputs): 65 | norm = torch.abs(inputs) 66 | biased_norm = norm + self.b 67 | magnitude = nn.functional.relu(biased_norm) 68 | phase = torch.sign(inputs) 69 | 70 | return phase * magnitude 71 | 72 | 73 | class ExpRNNCell(nn.Module): 74 | def __init__(self, input_size, hidden_size): 75 | super(ExpRNNCell, self).__init__() 76 | self.input_size = input_size 77 | self.hidden_size = hidden_size 78 | self.recurrent_kernel = nn.Linear(hidden_size, hidden_size, bias=False) 79 | self.input_kernel = nn.Linear(input_size, hidden_size) 80 | 
self.nonlinearity = modrelu(hidden_size) 81 | 82 | # Make recurrent_kernel orthogonal 83 | if args.constraints == "orthogonal": 84 | geotorch.orthogonal(self.recurrent_kernel, "weight") 85 | elif args.constraints == "lowrank": 86 | geotorch.low_rank(self.recurrent_kernel, "weight", hidden_size) 87 | elif args.constraints == "almostorthogonal": 88 | geotorch.almost_orthogonal(self.recurrent_kernel, "weight", args.r, args.f) 89 | else: 90 | raise ValueError("Unexpected constraints. Got {}".format(args.constraints)) 91 | 92 | self.reset_parameters() 93 | 94 | def reset_parameters(self): 95 | nn.init.kaiming_normal_(self.input_kernel.weight.data, nonlinearity="relu") 96 | 97 | # Initialize the recurrent kernel à la Cayley, as having a block-diagonal matrix 98 | # seems to help in classification problems 99 | 100 | def init_(x): 101 | x.uniform_(0.0, math.pi / 2.0) 102 | c = torch.cos(x.data) 103 | x.data = -torch.sqrt((1.0 - c) / (1.0 + c)) 104 | 105 | K = self.recurrent_kernel 106 | # We initialize it by assigning directly to it from a sampler 107 | K.weight = torus_init_(K.weight, init_=init_) 108 | 109 | def default_hidden(self, input_): 110 | return input_.new_zeros(input_.size(0), self.hidden_size, requires_grad=False) 111 | 112 | def forward(self, input_, hidden): 113 | input_ = self.input_kernel(input_) 114 | hidden = self.recurrent_kernel(hidden) 115 | out = input_ + hidden 116 | return self.nonlinearity(out) 117 | 118 | 119 | class Model(nn.Module): 120 | def __init__(self, hidden_size, permute): 121 | super(Model, self).__init__() 122 | self.permute = permute 123 | if self.permute: 124 | self.register_buffer("permutation", torch.randperm(784)) 125 | self.rnn = ExpRNNCell(1, hidden_size) 126 | self.lin = nn.Linear(hidden_size, n_classes) 127 | self.loss_func = nn.CrossEntropyLoss() 128 | 129 | def forward(self, inputs): 130 | if self.permute: 131 | inputs = inputs[:, self.permutation] 132 | out_rnn = self.rnn.default_hidden(inputs[:, 0, ...]) 133 | with geotorch.parametrize.cached(): 134 | for input in torch.unbind(inputs, dim=1): 135 | out_rnn = self.rnn(input.unsqueeze(dim=1), out_rnn) 136 | return self.lin(out_rnn) 137 | 138 | def loss(self, logits, y): 139 | return self.loss_func(logits, y) 140 | 141 | def correct(self, logits, y): 142 | return torch.eq(torch.argmax(logits, dim=1), y).float().sum() 143 | 144 | 145 | def main(): 146 | # Load data 147 | kwargs = { 148 | "batch_size": batch_size, 149 | "num_workers": 1, 150 | "pin_memory": True, 151 | "shuffle": True, 152 | } 153 | train_loader = torch.utils.data.DataLoader( 154 | datasets.MNIST( 155 | "./mnist", train=True, download=True, transform=transforms.ToTensor() 156 | ), 157 | **kwargs 158 | ) 159 | test_loader = torch.utils.data.DataLoader( 160 | datasets.MNIST("./mnist", train=False, transform=transforms.ToTensor()), 161 | **kwargs 162 | ) 163 | 164 | # Model and optimizers 165 | model = Model(hidden_size, args.permute).to(device) 166 | model.train() 167 | 168 | p_orth = model.rnn.recurrent_kernel 169 | orth_params = p_orth.parameters() 170 | non_orth_params = ( 171 | param for param in model.parameters() if param not in set(p_orth.parameters()) 172 | ) 173 | 174 | optim = torch.optim.RMSprop( 175 | [{"params": non_orth_params}, {"params": orth_params, "lr": args.lr_orth}], 176 | lr=args.lr, 177 | ) 178 | 179 | best_test_acc = 0.0 180 | for epoch in range(epochs): 181 | processed = 0 182 | for batch_idx, (batch_x, batch_y) in enumerate(train_loader): 183 | batch_x, batch_y = batch_x.to(device).view(-1, 784), 
batch_y.to(device) 184 | 185 | optim.zero_grad() 186 | logits = model(batch_x) 187 | loss = model.loss(logits, batch_y) 188 | loss.backward() 189 | optim.step() 190 | 191 | with torch.no_grad(): 192 | correct = model.correct(logits, batch_y) 193 | 194 | processed += len(batch_x) 195 | print( 196 | "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.2f}%\tBest: {:.2f}%".format( 197 | epoch, 198 | processed, 199 | len(train_loader.dataset), 200 | 100.0 * batch_idx / len(train_loader), 201 | loss.item(), 202 | 100 * correct / len(batch_x), 203 | best_test_acc, 204 | ) 205 | ) 206 | 207 | model.eval() 208 | with torch.no_grad(): 209 | test_loss = 0.0 210 | correct = 0.0 211 | for batch_x, batch_y in test_loader: 212 | batch_x, batch_y = batch_x.to(device).view(-1, 784), batch_y.to(device) 213 | logits = model(batch_x) 214 | test_loss += model.loss(logits, batch_y).float() 215 | correct += model.correct(logits, batch_y).float() 216 | 217 | test_loss /= len(test_loader) 218 | test_acc = 100 * correct / len(test_loader.dataset) 219 | best_test_acc = max(test_acc, best_test_acc) 220 | print( 221 | "\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%, Best Accuracy: {:.2f}%\n".format( 222 | test_loss, test_acc, best_test_acc 223 | ) 224 | ) 225 | 226 | model.train() 227 | 228 | 229 | if __name__ == "__main__": 230 | main() 231 | -------------------------------------------------------------------------------- /geotorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .constraints import ( 2 | sphere, 3 | skew, 4 | symmetric, 5 | orthogonal, 6 | grassmannian, 7 | almost_orthogonal, 8 | low_rank, 9 | fixed_rank, 10 | invertible, 11 | sln, 12 | positive_definite, 13 | positive_semidefinite, 14 | positive_semidefinite_low_rank, 15 | positive_semidefinite_fixed_rank, 16 | ) 17 | from .product import ProductManifold 18 | from .reals import Rn 19 | from .skew import Skew 20 | from .symmetric import Symmetric 21 | from .so import SO 22 | from .sphere import Sphere, SphereEmbedded 23 | from .stiefel import Stiefel 24 | from .grassmannian import Grassmannian 25 | from .almostorthogonal import AlmostOrthogonal 26 | from .lowrank import LowRank 27 | from .fixedrank import FixedRank 28 | from .glp import GLp 29 | from .sl import SL 30 | from .psd import PSD 31 | from .pssd import PSSD 32 | from .pssdfixedrank import PSSDFixedRank 33 | from .pssdlowrank import PSSDLowRank 34 | from .utils import update_base 35 | 36 | 37 | __version__ = "0.3.0" 38 | 39 | 40 | __all__ = [ 41 | "ProductManifold", 42 | "Grassmannian", 43 | "LowRank", 44 | "Rn", 45 | "Skew", 46 | "Symmetric", 47 | "SO", 48 | "Sphere", 49 | "SphereEmbedded", 50 | "Stiefel", 51 | "AlmostOrthogonal", 52 | "GLp", 53 | "SL", 54 | "FixedRank", 55 | "PSD", 56 | "PSSD", 57 | "PSSDLowRank", 58 | "PSSDFixedRank", 59 | "skew", 60 | "symmetric", 61 | "sphere", 62 | "orthogonal", 63 | "grassmannian", 64 | "low_rank", 65 | "fixed_rank", 66 | "almost_orthogonal", 67 | "invertible", 68 | "sln", 69 | "positive_definite", 70 | "positive_semidefinite", 71 | "positive_semidefinite_low_rank", 72 | "positive_semidefinite_fixed_rank", 73 | "update_base", 74 | ] 75 | -------------------------------------------------------------------------------- /geotorch/almostorthogonal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .lowrank import LowRank 3 | from .exceptions import VectorError, InManifoldError, InverseError 4 | from .utils import _extra_repr 5 
| 6 | 7 | def scaled_sigmoid(t): 8 | return 2.0 * (torch.sigmoid(t) - 0.5) 9 | 10 | 11 | def inv_scaled_sigmoid(t): 12 | y = 0.5 * t + 0.5 13 | return torch.log(y / (1.0 - y)) 14 | 15 | 16 | class AlmostOrthogonal(LowRank): 17 | fs = { 18 | "scaled_sigmoid": (scaled_sigmoid, inv_scaled_sigmoid), 19 | "tanh": (torch.tanh, torch.atanh), 20 | "sin": (torch.sin, torch.asin), 21 | } 22 | 23 | def __init__(self, size, lam, f="sin", triv="expm"): 24 | r"""Manifold of matrices with singular values in the interval 25 | :math:`(1-\lambda, 1+\lambda)`. 26 | 27 | The possible default maps are the :math:`\sin,\,\tanh` functions and a scaled 28 | sigmoid. The sigmoid is scaled as 29 | :math:`\operatorname{scaled\_sigmoid}(x) = 2\sigma(x) - 1` 30 | where :math:`\sigma` is the usual sigmoid function. 31 | This is done so that the image of the scaled sigmoid is :math:`(-1, 1)`. 32 | 33 | Args: 34 | size (torch.size): Size of the tensor to be parametrized 35 | lam (float): Radius of the interval. A float in the interval :math:`[0, 1]` 36 | f (str or callable or pair of callables): Optional. Either: 37 | 38 | - One of ``["scaled_sigmoid", "tanh", "sin"]`` 39 | 40 | - A callable that maps real numbers to the interval :math:`(-1, 1)` 41 | 42 | - A pair of callables such that the first maps the real numbers to 43 | :math:`(-1, 1)` and the second is a (right) inverse of the first 44 | 45 | Default: ``"sin"`` 46 | triv (str or callable): Optional. 47 | A map that maps skew-symmetric matrices onto the orthogonal matrices 48 | surjectively. This is used to optimize the :math:`U` and :math:`V` in 49 | the SVD. It can be one of ``["expm", "cayley"]`` or a custom callable. 50 | Default: ``"expm"`` 51 | 52 | """ 53 | super().__init__(size, AlmostOrthogonal.rank(size), triv=triv) 54 | if lam < 0.0 or lam > 1.0: 55 | raise ValueError("The radius has to be between 0 and 1. Got {}".format(lam)) 56 | self.lam = lam 57 | f, inv = AlmostOrthogonal.parse_f(f) 58 | self.f = f 59 | self.inv = inv 60 | 61 | @staticmethod 62 | def parse_f(f): 63 | if f in AlmostOrthogonal.fs.keys(): 64 | return AlmostOrthogonal.fs[f] 65 | elif callable(f): 66 | return f, None 67 | elif isinstance(f, tuple) and callable(f[0]) and callable(f[1]): 68 | return f 69 | else: 70 | raise ValueError( 71 | "Argument f was not recognized and is " 72 | "not callable or a pair of callables. " 73 | "Should be one of {}. Found {}".format( 74 | list(AlmostOrthogonal.fs.keys()), f 75 | ) 76 | ) 77 | 78 | @classmethod 79 | def rank(cls, size): 80 | if len(size) < 2: 81 | raise VectorError(cls.__name__, size) 82 | return min(*size[-2:]) 83 | 84 | def submersion(self, U, S, V): 85 | S = 1.0 + self.lam * self.f(S) 86 | return super().submersion(U, S, V) 87 | 88 | def submersion_inv(self, X, check_in_manifold=True): 89 | if self.inv is None: 90 | raise InverseError(self) 91 | U, S, V = super().submersion_inv(X) 92 | if check_in_manifold and not self.in_manifold_singular_values(S): 93 | raise InManifoldError(X, self) 94 | # Hardcoded epsilon...
not a good practice 95 | if self.lam < 1e-6: 96 | S = S - 1.0 97 | else: 98 | S = self.inv((S - 1.0) / self.lam) 99 | return U, S, V 100 | 101 | def in_manifold_singular_values(self, S, eps=1e-5): 102 | lam = self.lam 103 | if self.lam <= eps: 104 | lam = eps 105 | return ( 106 | super().in_manifold_singular_values(S, eps) 107 | and ((S - 1.0).abs() <= lam).all().item() 108 | ) 109 | 110 | def sample(self, distribution="uniform", init_=None): 111 | r""" 112 | Returns a randomly sampled orthogonal matrix according to the specified 113 | ``distribution``. The options are: 114 | 115 | - ``"uniform"``: Samples a tensor distributed according to the Haar measure 116 | on :math:`\operatorname{SO}(n)` 117 | 118 | - ``"torus"``: Samples a block-diagonal skew-symmetric matrix. 119 | The blocks are of the form 120 | :math:`\begin{pmatrix} 0 & b \\ -b & 0\end{pmatrix}` where :math:`b` is 121 | distributed according to ``init_``. This matrix will then be projected 122 | onto :math:`\operatorname{SO}(n)` using ``self.triv`` 123 | 124 | .. note:: 125 | 126 | The ``"torus"`` initialization is particularly useful in recurrent kernels 127 | of RNNs 128 | 129 | The output of this method can be used to initialize a parametrized tensor 130 | that has been parametrized with this or any other manifold as:: 131 | 132 | >>> layer = nn.Linear(20, 20) 133 | >>> M = AlmostOrthogonal(layer.weight.size(), lam=0.5) 134 | >>> geotorch.register_parametrization(layer, "weight", M) 135 | >>> layer.weight = M.sample() 136 | 137 | Args: 138 | distribution (string): Optional. One of ``["uniform", "torus"]``. 139 | Default: ``"uniform"`` 140 | init\_ (callable): Optional. To be used with the ``"torus"`` option. 141 | A function that takes a tensor and fills it in place according 142 | to some distribution. See 143 | `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
144 | Default: :math:`\operatorname{Uniform}(-\pi, \pi)` 145 | """ 146 | # Sample an orthogonal matrix as U and return it 147 | return self[0].sample(distribution=distribution, init_=init_) 148 | 149 | def extra_repr(self): 150 | return _extra_repr( 151 | n=self.n, 152 | lam=self.lam, 153 | tensorial_size=self.tensorial_size, 154 | f=self.f, 155 | no_inv=self.inv is None, 156 | ) 157 | -------------------------------------------------------------------------------- /geotorch/constraints.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geotorch.parametrize as P 3 | 4 | from .symmetric import Symmetric 5 | from .skew import Skew 6 | from .sphere import Sphere, SphereEmbedded 7 | from .stiefel import Stiefel 8 | from .grassmannian import Grassmannian 9 | from .almostorthogonal import AlmostOrthogonal 10 | from .lowrank import LowRank 11 | from .fixedrank import FixedRank 12 | from .glp import GLp 13 | from .sl import SL 14 | from .psd import PSD 15 | from .pssd import PSSD 16 | from .pssdlowrank import PSSDLowRank 17 | from .pssdfixedrank import PSSDFixedRank 18 | 19 | 20 | def _register_manifold(module, tensor_name, cls, *args): 21 | tensor = getattr(module, tensor_name) 22 | M = cls(tensor.size(), *args).to(device=tensor.device, dtype=tensor.dtype) 23 | 24 | # Initialize without checking in manifold 25 | X = M.sample() 26 | if not P.is_parametrized(module, tensor_name): 27 | with torch.no_grad(): 28 | tensor.copy_(X) 29 | else: 30 | setattr(module, tensor_name, X) 31 | 32 | P.register_parametrization(module, tensor_name, M, unsafe=True) 33 | 34 | return module 35 | 36 | 37 | def symmetric(module, tensor_name="weight", lower=True): 38 | r"""Adds a symmetric parametrization to the matrix ``module.tensor_name``. 39 | 40 | When accessing ``module.tensor_name``, the module will return the parametrized 41 | version :math:`X` so that :math:`X^\intercal = X`. 42 | 43 | If the tensor has more than two dimensions, the parametrization will be 44 | applied to the last two dimensions. 45 | 46 | Examples:: 47 | 48 | >>> layer = nn.Linear(30, 30) 49 | >>> geotorch.symmetric(layer, "weight") 50 | >>> torch.allclose(layer.weight, layer.weight.T) 51 | True 52 | 53 | Args: 54 | module (nn.Module): module on which to register the parametrization 55 | tensor_name (string): name of the parameter, buffer, or parametrization 56 | on which the parametrization will be applied. Default: ``"weight"`` 57 | lower (bool): Optional. Uses the lower triangular part of the matrix to 58 | parametrize the matrix. Default: ``True`` 59 | """ 60 | P.register_parametrization(module, tensor_name, Symmetric(lower)) 61 | return module 62 | 63 | 64 | def skew(module, tensor_name="weight", lower=True): 65 | r"""Adds a skew-symmetric parametrization to the matrix ``module.tensor_name``. 66 | 67 | When accessing ``module.tensor_name``, the module will return the parametrized 68 | version :math:`X` so that :math:`X^\intercal = -X`. 69 | 70 | If the tensor has more than two dimensions, the parametrization will be 71 | applied to the last two dimensions. 72 | 73 | Examples:: 74 | 75 | >>> layer = nn.Linear(30, 30) 76 | >>> geotorch.skew(layer, "weight") 77 | >>> torch.allclose(layer.weight, -layer.weight.T) 78 | True 79 | 80 | Args: 81 | module (nn.Module): module on which to register the parametrization 82 | tensor_name (string): name of the parameter, buffer, or parametrization 83 | on which the parametrization will be applied. 
Default: ``"weight"`` 84 | lower (bool): Optional. Uses the lower triangular part of the matrix to 85 | parametrize the matrix. Default: ``True`` 86 | """ 87 | P.register_parametrization(module, tensor_name, Skew(lower)) 88 | return module 89 | 90 | 91 | def sphere(module, tensor_name="weight", radius=1.0, embedded=False): 92 | r"""Adds a spherical parametrization to the vector (or tensor) ``module.tensor_name``. 93 | 94 | When accessing ``module.tensor_name``, the module will return the parametrized 95 | version :math:`v` so that :math:`\lVert v \rVert = 1`. 96 | 97 | If the tensor has more than one dimension, the parametrization will be 98 | applied to the last dimension. 99 | 100 | Examples:: 101 | 102 | >>> layer = nn.Linear(20, 30) 103 | >>> geotorch.sphere(layer, "bias") 104 | >>> torch.norm(layer.bias) 105 | tensor(1.) 106 | >>> geotorch.sphere(layer, "weight")  # Make the rows of the weight unit norm 107 | >>> torch.allclose(torch.norm(layer.weight, dim=-1), torch.ones(30)) 108 | True 109 | 110 | Args: 111 | module (nn.Module): module on which to register the parametrization 112 | tensor_name (string): name of the parameter, buffer, or parametrization 113 | on which the parametrization will be applied. Default: ``"weight"`` 114 | radius (float): Optional. 115 | Radius of the sphere. It has to be positive. Default: 1. 116 | embedded (bool): Optional. 117 | Chooses between the implementation of the sphere using the exponential 118 | map (``embedded=False``) and that using the projection from the ambient space (``embedded=True``). 119 | Default: ``False`` 120 | """ 121 | cls = SphereEmbedded if embedded else Sphere 122 | return _register_manifold(module, tensor_name, cls, radius) 123 | 124 | 125 | def orthogonal(module, tensor_name="weight", triv="expm"): 126 | r"""Adds an orthogonal parametrization to the tensor ``module.tensor_name``. 127 | 128 | When accessing ``module.tensor_name``, the module will return the 129 | parametrized version :math:`X` so that :math:`X^\intercal X = \operatorname{I}`. 130 | 131 | If the tensor has more than two dimensions, the parametrization will be 132 | applied to the last two dimensions. 133 | 134 | Examples:: 135 | 136 | >>> layer = nn.Linear(20, 30) 137 | >>> geotorch.orthogonal(layer, "weight") 138 | >>> torch.norm(layer.weight.T @ layer.weight - torch.eye(20,20)) 139 | tensor(4.8488e-05) 140 | 141 | >>> layer = nn.Conv2d(20, 40, 3, 3)  # Make the kernels orthogonal 142 | >>> geotorch.orthogonal(layer, "weight") 143 | >>> torch.norm(layer.weight.transpose(-2, -1) @ layer.weight - torch.eye(3,3)) 144 | tensor(1.2225e-05) 145 | 146 | Args: 147 | module (nn.Module): module on which to register the parametrization 148 | tensor_name (string): name of the parameter, buffer, or parametrization 149 | on which the parametrization will be applied. Default: ``"weight"`` 150 | triv (str or callable): Optional. 151 | A map that maps a skew-symmetric matrix to an orthogonal matrix. 152 | It can be the exponential of matrices or the cayley transform passing 153 | ``["expm", "cayley"]`` or a custom callable. Default: ``"expm"`` 154 | """ 155 | return _register_manifold(module, tensor_name, Stiefel, triv) 156 | 157 | 158 | def almost_orthogonal(module, tensor_name="weight", lam=0.1, f="sin", triv="expm"): 159 | r"""Adds an almost orthogonal parametrization to the tensor ``module.tensor_name``.
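In particular, if ``lam < 1``, the parametrized matrix is always invertible, with condition number at most :math:`\frac{1+\texttt{lam}}{1-\texttt{lam}}`.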
160 | 161 | When accessing ``module.tensor_name``, the module will return the 162 | parametrized version :math:`X` which will have its singular values in 163 | the interval :math:`[1-\texttt{lam}, 1+\texttt{lam}]`. 164 | 165 | If the tensor has more than two dimensions, the parametrization will be 166 | applied to the last two dimensions. 167 | 168 | Examples:: 169 | 170 | >>> layer = nn.Linear(20, 30) 171 | >>> geotorch.almost_orthogonal(layer, "weight", 0.5) 172 | >>> S = torch.linalg.svd(layer.weight).S 173 | >>> ((S >= 0.5) & (S <= 1.5)).all() 174 | tensor(True) 175 | 176 | Args: 177 | module (nn.Module): module on which to register the parametrization 178 | tensor_name (string): name of the parameter, buffer, or parametrization 179 | on which the parametrization will be applied. Default: ``"weight"`` 180 | lam (float): Radius of the interval for the singular values. A float in the interval :math:`[0, 1]`. Default: ``0.1`` 181 | f (str or callable or pair of callables): Optional. Either: 182 | 183 | - One of ``["scaled_sigmoid", "tanh", "sin"]`` 184 | 185 | - A callable that maps real numbers to the interval :math:`[-1, 1]` 186 | 187 | - A pair of callables such that the first maps the real numbers to 188 | :math:`[-1, 1]` and the second is a (right) inverse of the first 189 | 190 | Default: ``"sin"`` 191 | triv (str or callable): Optional. 192 | A map that maps skew-symmetric matrices onto the orthogonal matrices 193 | surjectively. This is used to optimize the :math:`U` and :math:`V` in the 194 | SVD. It can be one of ``["expm", "cayley"]`` or a custom 195 | callable. Default: ``"expm"`` 196 | """ 197 | return _register_manifold(module, tensor_name, AlmostOrthogonal, lam, f, triv) 198 | 199 | 200 | def grassmannian(module, tensor_name="weight", triv="expm"): 201 | r"""Adds a parametrization to the tensor ``module.tensor_name`` so that the 202 | result represents a subspace. If the initial matrix was of size :math:`n \times k` 203 | the parametrized matrix will represent a subspace of dimension :math:`k` of 204 | :math:`\mathbb{R}^n`. 205 | 206 | When accessing ``module.tensor_name``, the module will return the parametrized 207 | version :math:`X` so that :math:`X` represents :math:`k` orthogonal vectors of 208 | :math:`\mathbb{R}^n` that span the subspace. That is, the resulting matrix will 209 | be orthogonal, :math:`X^\intercal X = \operatorname{I}`. 210 | 211 | If the tensor has more than two dimensions, the parametrization will be 212 | applied to the last two dimensions. 213 | 214 | .. note:: 215 | 216 | Even though this space resembles that generated by :func:`geotorch.orthogonal`, 217 | it is actually a quotient of that space, as every subspace can be represented by many 218 | different bases of vectors that span it. 219 | 220 | Examples:: 221 | 222 | >>> layer = nn.Linear(20, 30) 223 | >>> geotorch.grassmannian(layer, "weight") 224 | >>> torch.norm(layer.weight.t() @ layer.weight - torch.eye(20,20)) 225 | tensor(1.8933e-05) 226 | 227 | >>> layer = nn.Conv2d(20, 40, 3, 3)  # Make the kernels represent subspaces 228 | >>> geotorch.grassmannian(layer, "weight") 229 | >>> torch.norm(layer.weight.transpose(-2, -1) @ layer.weight - torch.eye(3,3)) 230 | tensor(8.3796e-06) 231 | 232 | Args: 233 | module (nn.Module): module on which to register the parametrization 234 | tensor_name (string): name of the parameter, buffer, or parametrization 235 | on which the parametrization will be applied. Default: ``"weight"`` 236 | triv (str or callable): Optional.
237 | A map that maps a skew-symmetric matrix to an orthogonal matrix. 238 | It can be the exponential of matrices or the cayley transform passing 239 | ``["expm", "cayley"]`` or a custom callable. Default: ``"expm"`` 240 | """ 241 | return _register_manifold(module, tensor_name, Grassmannian, triv) 242 | 243 | 244 | def low_rank(module, tensor_name, rank, triv="expm"): 245 | r"""Adds a low rank parametrization to the tensor ``module.tensor_name``. 246 | 247 | When accessing ``module.tensor_name``, the module will return the 248 | parametrized version :math:`X` which will have rank at most ``rank``. 249 | 250 | If the tensor has more than two dimensions, the parametrization will be 251 | applied to the last two dimensions. 252 | 253 | Examples:: 254 | 255 | >>> layer = nn.Linear(20, 30) 256 | >>> geotorch.low_rank(layer, "weight", 4) 257 | >>> list(torch.linalg.svd(layer.weight).S > 1e-7).count(True) <= 4 258 | True 259 | 260 | Args: 261 | module (nn.Module): module on which to register the parametrization 262 | tensor_name (string): name of the parameter, buffer, or parametrization 263 | on which the parametrization will be applied 264 | rank (int): Rank of the matrix. 265 | It has to be less than the minimum of the two dimensions of the 266 | matrix 267 | triv (str or callable): Optional. 268 | A map that maps skew-symmetric matrices onto the orthogonal matrices 269 | surjectively. This is used to optimize the :math:`U` and :math:`V` in the 270 | SVD. It can be one of ``["expm", "cayley"]`` or a custom 271 | callable. Default: ``"expm"`` 272 | """ 273 | return _register_manifold(module, tensor_name, LowRank, rank, triv) 274 | 275 | 276 | def fixed_rank(module, tensor_name, rank, f="softplus", triv="expm"): 277 | r"""Adds a fixed rank parametrization to the tensor ``module.tensor_name``. 278 | 279 | When accessing ``module.tensor_name``, the module will return the 280 | parametrized version :math:`X` which will have rank equal to ``rank``. 281 | 282 | If the tensor has more than two dimensions, the parametrization will be 283 | applied to the last two dimensions. 284 | 285 | Examples:: 286 | 287 | >>> layer = nn.Linear(20, 30) 288 | >>> geotorch.fixed_rank(layer, "weight", 5) 289 | >>> list(torch.linalg.svd(layer.weight).S > 1e-7).count(True) 290 | 5 291 | 292 | Args: 293 | module (nn.Module): module on which to register the parametrization 294 | tensor_name (string): name of the parameter, buffer, or parametrization 295 | on which the parametrization will be applied 296 | rank (int): Rank of the matrix. 297 | It has to be less than the minimum of the two dimensions of the 298 | matrix 299 | f (str or callable or pair of callables): Optional. Either: 300 | 301 | - ``"softplus"`` 302 | 303 | - A callable that maps real numbers to the interval :math:`(0, \infty)` 304 | 305 | - A pair of callables such that the first maps the real numbers to 306 | :math:`(0, \infty)` and the second is a (right) inverse of the first 307 | 308 | Default: ``"softplus"`` 309 | triv (str or callable): Optional. 310 | A map that maps skew-symmetric matrices onto the orthogonal matrices 311 | surjectively. This is used to optimize the :math:`U` and :math:`V` in the 312 | SVD. It can be one of ``["expm", "cayley"]`` or a custom 313 | callable. 
Default: ``"expm"`` 314 | """ 315 | return _register_manifold(module, tensor_name, FixedRank, rank, f, triv) 316 | 317 | 318 | def invertible(module, tensor_name="weight", f="softplus", triv="expm"): 319 | r"""Adds an invertibility constraint to the tensor ``module.tensor_name``. 320 | 321 | When accessing ``module.tensor_name``, the module will return the 322 | parametrized version :math:`X` which will have positive determinant and, 323 | in particular, it will be invertible. 324 | 325 | If the tensor has more than two dimensions, the parametrization will be 326 | applied to the last two dimensions. 327 | 328 | Examples:: 329 | 330 | >>> layer = nn.Linear(20, 20) 331 | >>> geotorch.invertible(layer, "weight") 332 | >>> torch.det(layer.weight) > 0.0 333 | True 334 | 335 | Args: 336 | module (nn.Module): module on which to register the parametrization 337 | tensor_name (string): name of the parameter, buffer, or parametrization 338 | on which the parametrization will be applied. Default: ``"weight"`` 339 | f (str or callable or pair of callables): Optional. Either: 340 | 341 | - ``"softplus"`` 342 | 343 | - A callable that maps real numbers to the interval :math:`(0, \infty)` 344 | 345 | - A pair of callables such that the first maps the real numbers to 346 | :math:`(0, \infty)` and the second is a (right) inverse of the first 347 | 348 | Default: ``"softplus"`` 349 | triv (str or callable): Optional. 350 | A map that maps skew-symmetric matrices onto the orthogonal matrices 351 | surjectively. This is used to optimize the :math:`U` and :math:`V` in the 352 | SVD. It can be one of ``["expm", "cayley"]`` or a custom 353 | callable. Default: ``"expm"`` 354 | """ 355 | return _register_manifold(module, tensor_name, GLp, f, triv) 356 | 357 | 358 | def sln(module, tensor_name="weight", f="softplus", triv="expm"): 359 | r"""Adds a constraint of having determinant one to the tensor ``module.tensor_name``. 360 | 361 | When accessing ``module.tensor_name``, the module will return the 362 | parametrized version :math:`X` which will have determinant equal to 1. 363 | 364 | If the tensor has more than two dimensions, the parametrization will be 365 | applied to the last two dimensions. 366 | 367 | Examples:: 368 | 369 | >>> layer = nn.Linear(20, 20) 370 | >>> geotorch.sln(layer, "weight") 371 | >>> torch.det(layer.weight) 372 | tensor(1.0000) 373 | 374 | Args: 375 | module (nn.Module): module on which to register the parametrization 376 | tensor_name (string): name of the parameter, buffer, or parametrization 377 | on which the parametrization will be applied. Default: ``"weight"`` 378 | f (str or callable or pair of callables): Optional. Either: 379 | 380 | - ``"softplus"`` 381 | 382 | - A callable that maps real numbers to the interval :math:`(0, \infty)` 383 | 384 | - A pair of callables such that the first maps the real numbers to 385 | :math:`(0, \infty)` and the second is a (right) inverse of the first 386 | 387 | Default: ``"softplus"`` 388 | triv (str or callable): Optional. 389 | A map that maps skew-symmetric matrices onto the orthogonal matrices 390 | surjectively. This is used to optimize the :math:`U` and :math:`V` in the 391 | SVD. It can be one of ``["expm", "cayley"]`` or a custom 392 | callable. Default: ``"expm"`` 393 | """ 394 | return _register_manifold(module, tensor_name, SL, f, triv) 395 | 396 | 397 | def positive_definite(module, tensor_name="weight", f="softplus", triv="expm"): 398 | r"""Adds a positive definiteness constraint to the tensor ``module.tensor_name``. 
399 | 400 | When accessing ``module.tensor_name``, the module will return the 401 | parametrized version :math:`X` which will be symmetric with positive 402 | eigenvalues. 403 | 404 | If the tensor has more than two dimensions, the parametrization will be 405 | applied to the last two dimensions. 406 | 407 | Examples:: 408 | 409 | >>> layer = nn.Linear(20, 20) 410 | >>> geotorch.positive_definite(layer, "weight") 411 | >>> (torch.linalg.eigvalsh(layer.weight) > 0.0).all() 412 | tensor(True) 413 | 414 | Args: 415 | module (nn.Module): module on which to register the parametrization 416 | tensor_name (string): name of the parameter, buffer, or parametrization 417 | on which the parametrization will be applied. Default: ``"weight"`` 418 | f (str or callable or pair of callables): Optional. Either: 419 | 420 | - ``"softplus"`` 421 | 422 | - A callable that maps real numbers to the interval :math:`(0, \infty)` 423 | 424 | - A pair of callables such that the first maps the real numbers to 425 | :math:`(0, \infty)` and the second is a (right) inverse of the first 426 | 427 | Default: ``"softplus"`` 428 | triv (str or callable): Optional. 429 | A map that maps skew-symmetric matrices onto the orthogonal 430 | matrices surjectively. This is used to optimize the :math:`Q` in the eigenvalue 431 | decomposition. It can be one of ``["expm", "cayley"]`` or a custom 432 | callable. Default: ``"expm"`` 433 | """ 434 | return _register_manifold(module, tensor_name, PSD, f, triv) 435 | 436 | 437 | def positive_semidefinite(module, tensor_name="weight", triv="expm"): 438 | r"""Adds a positive semidefiniteness constraint to the tensor ``module.tensor_name``. 439 | 440 | When accessing ``module.tensor_name``, the module will return the 441 | parametrized version :math:`X` which will be symmetric with 442 | non-negative eigenvalues. 443 | 444 | If the tensor has more than two dimensions, the parametrization will be 445 | applied to the last two dimensions. 446 | 447 | Examples:: 448 | 449 | >>> layer = nn.Linear(20, 20) 450 | >>> geotorch.positive_semidefinite(layer, "weight") 451 | >>> L = torch.linalg.eigvalsh(layer.weight) 452 | >>> L[L.abs() < 1e-7] = 0.0  # Round-off errors 453 | >>> (L >= 0.0).all() 454 | tensor(True) 455 | 456 | Args: 457 | module (nn.Module): module on which to register the parametrization 458 | tensor_name (string): name of the parameter, buffer, or parametrization 459 | on which the parametrization will be applied. Default: ``"weight"`` 460 | triv (str or callable): Optional. 461 | A map that maps skew-symmetric matrices onto the orthogonal 462 | matrices surjectively. This is used to optimize the :math:`Q` in the eigenvalue 463 | decomposition. It can be one of ``["expm", "cayley"]`` or a custom 464 | callable. Default: ``"expm"`` 465 | """ 466 | return _register_manifold(module, tensor_name, PSSD, triv) 467 | 468 | 469 | def positive_semidefinite_low_rank(module, tensor_name, rank, triv="expm"): 470 | r"""Adds a low-rank positive semidefiniteness constraint to the tensor ``module.tensor_name``. 471 | 472 | When accessing ``module.tensor_name``, the module will return the 473 | parametrized version :math:`X` which will be symmetric, with non-negative 474 | eigenvalues, and with at most ``rank`` of them non-zero. 475 | 476 | If the tensor has more than two dimensions, the parametrization will be 477 | applied to the last two dimensions.
478 | 479 | Examples:: 480 | 481 | >>> layer = nn.Linear(20, 20) 482 | >>> geotorch.positive_semidefinite_low_rank(layer, "weight", 5) 483 | >>> L = torch.linalg.eigvalsh(layer.weight) 484 | >>> L[L.abs() < 1e-7] = 0.0  # Round-off errors 485 | >>> (L >= 0.0).all() 486 | tensor(True) 487 | >>> list(L > 0.0).count(True) <= 5 488 | True 489 | 490 | Args: 491 | module (nn.Module): module on which to register the parametrization 492 | tensor_name (string): name of the parameter, buffer, or parametrization 493 | on which the parametrization will be applied 494 | rank (int): Rank of the matrix. 495 | It has to be less than the minimum of the two dimensions of the 496 | matrix 497 | triv (str or callable): Optional. 498 | A map that maps skew-symmetric matrices onto the orthogonal 499 | matrices surjectively. This is used to optimize the :math:`Q` in the eigenvalue 500 | decomposition. It can be one of ``["expm", "cayley"]`` or a custom 501 | callable. Default: ``"expm"`` 502 | """ 503 | return _register_manifold(module, tensor_name, PSSDLowRank, rank, triv) 504 | 505 | 506 | def positive_semidefinite_fixed_rank( 507 | module, tensor_name, rank, f="softplus", triv="expm" 508 | ): 509 | r"""Adds a fixed-rank positive semidefiniteness constraint to the tensor ``module.tensor_name``. 510 | 511 | When accessing ``module.tensor_name``, the module will return the 512 | parametrized version :math:`X` which will be symmetric, with non-negative 513 | eigenvalues, and with exactly ``rank`` of them non-zero. 514 | 515 | If the tensor has more than two dimensions, the parametrization will be 516 | applied to the last two dimensions. 517 | 518 | Examples:: 519 | 520 | >>> layer = nn.Linear(20, 20) 521 | >>> geotorch.positive_semidefinite_fixed_rank(layer, "weight", 5) 522 | >>> L = torch.linalg.eigvalsh(layer.weight) 523 | >>> L[L.abs() < 1e-7] = 0.0  # Round-off errors 524 | >>> (L >= 0.0).all() 525 | tensor(True) 526 | >>> list(L > 0.0).count(True) 527 | 5 528 | 529 | Args: 530 | module (nn.Module): module on which to register the parametrization 531 | tensor_name (string): name of the parameter, buffer, or parametrization 532 | on which the parametrization will be applied 533 | rank (int): Rank of the matrix. 534 | It has to be less than the minimum of the two dimensions of the 535 | matrix 536 | f (str or callable or pair of callables): Optional. Either: 537 | 538 | - ``"softplus"`` 539 | 540 | - A callable that maps real numbers to the interval :math:`(0, \infty)` 541 | 542 | - A pair of callables such that the first maps the real numbers to 543 | :math:`(0, \infty)` and the second is a (right) inverse of the first 544 | 545 | Default: ``"softplus"`` 546 | triv (str or callable): Optional. 547 | A map that maps skew-symmetric matrices onto the orthogonal 548 | matrices surjectively. This is used to optimize the :math:`Q` in the 549 | eigenvalue decomposition. It can be one of ``["expm", "cayley"]`` or 550 | a custom callable. Default: ``"expm"`` 551 | """ 552 | return _register_manifold(module, tensor_name, PSSDFixedRank, rank, f, triv) 553 | -------------------------------------------------------------------------------- /geotorch/exceptions.py: -------------------------------------------------------------------------------- 1 | class VectorError(ValueError): 2 | def __init__(self, name, size): 3 | super().__init__( 4 | "Cannot instantiate {} on a tensor of less than 2 dimensions. 
" 5 | "Got a tensor of size {}".format(name, size) 6 | ) 7 | 8 | 9 | class InverseError(ValueError): 10 | def __init__(self, M): 11 | super().__init__( 12 | "Cannot initialize the parametrization {} as no inverse for the function " 13 | "{} was specified in the constructor".format(M, M.f.__name__) 14 | ) 15 | 16 | 17 | class NonSquareError(ValueError): 18 | def __init__(self, name, size): 19 | super().__init__( 20 | "The {} parametrization can just be applied to square matrices. " 21 | "Got a tensor of size {}".format(name, size) 22 | ) 23 | 24 | 25 | class RankError(ValueError): 26 | def __init__(self, n, k, rank): 27 | super().__init__( 28 | "The rank has to be 1 <= rank <= min({}, {}). Found {}".format(n, k, rank) 29 | ) 30 | 31 | 32 | class InManifoldError(ValueError): 33 | def __init__(self, X, M): 34 | super().__init__("Tensor not contained in {}. Got\n{}".format(M, X)) 35 | -------------------------------------------------------------------------------- /geotorch/fixedrank.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .lowrank import LowRank 3 | from .exceptions import InverseError 4 | 5 | 6 | def softplus_epsilon(x, epsilon=1e-6): 7 | return torch.nn.functional.softplus(x) + epsilon 8 | 9 | 10 | def inv_softplus_epsilon(x, epsilon=1e-6): 11 | y = x - epsilon 12 | return torch.where(y > 20, y, y.expm1().log()) 13 | 14 | 15 | class FixedRank(LowRank): 16 | fs = {"softplus": (softplus_epsilon, inv_softplus_epsilon)} 17 | 18 | def __init__(self, size, rank, f="softplus", triv="expm"): 19 | r""" 20 | Manifold of non-square matrices of rank equal to ``rank`` 21 | 22 | Args: 23 | size (torch.size): Size of the tensor to be parametrized 24 | rank (int): Rank of the matrices. 25 | It has to be less or equal to 26 | :math:`\min(\texttt{size}[-1], \texttt{size}[-2])` 27 | f (str or callable or pair of callables): Optional. Either: 28 | 29 | - ``"softplus"`` 30 | 31 | - A callable that maps real numbers to the interval :math:`(0, \infty)` 32 | 33 | - A pair of callables such that the first maps the real numbers onto 34 | :math:`(0, \infty)` and the second is a (right) inverse of the first 35 | 36 | Default: ``"softplus"`` 37 | triv (str or callable): Optional. 38 | A map that maps skew-symmetric matrices onto the orthogonal matrices 39 | surjectively. This is used to optimize the :math:`U` and :math:`V` in 40 | the SVD. It can be one of ``["expm", "cayley"]`` or a custom callable. 41 | Default: ``"expm"`` 42 | """ 43 | super().__init__(size, rank, triv=triv) 44 | f, inv = FixedRank.parse_f(f) 45 | self.f = f 46 | self.inv = inv 47 | 48 | @staticmethod 49 | def parse_f(f): 50 | if f in FixedRank.fs.keys(): 51 | return FixedRank.fs[f] 52 | elif callable(f): 53 | return f, None 54 | elif isinstance(f, tuple) and callable(f[0]) and callable(f[1]): 55 | return f 56 | else: 57 | raise ValueError( 58 | "Argument f was not recognized and is " 59 | "not callable or a pair of callables. " 60 | "Should be one of {}. Found {}".format(list(FixedRank.fs.keys()), f) 61 | ) 62 | 63 | def submersion(self, U, S, V): 64 | return super().submersion(U, self.f(S), V) 65 | 66 | def submersion_inv(self, X, check_in_manifold=True): 67 | U, S, V = super().submersion_inv(X, check_in_manifold) 68 | if self.inv is None: 69 | raise InverseError(self) 70 | return U, self.inv(S), V 71 | 72 | def in_manifold_singular_values(self, S, eps=1e-5): 73 | r""" 74 | Checks that a vector of singular values is in the manifold. 
/geotorch/glp.py:
--------------------------------------------------------------------------------
1 | from .fixedrank import FixedRank
2 | from .exceptions import VectorError, NonSquareError
3 | from .utils import _extra_repr
4 | 
5 | 
6 | class GLp(FixedRank):
7 |     def __init__(self, size, f="softplus", triv="expm"):
8 |         r"""
9 |         Manifold of invertible matrices
10 | 
11 |         Args:
12 |             size (torch.size): Size of the tensor to be parametrized
13 |             f (str or callable or pair of callables): Optional. Either:
14 | 
15 |                 - ``"softplus"``
16 | 
17 |                 - A callable that maps real numbers to the interval :math:`(0, \infty)`
18 | 
19 |                 - A pair of callables such that the first maps the real numbers to
20 |                   :math:`(0, \infty)` and the second is a (right) inverse of the first
21 | 
22 |                 Default: ``"softplus"``
23 |             triv (str or callable): Optional.
24 |                 A map that maps skew-symmetric matrices onto the orthogonal matrices
25 |                 surjectively. This is used to optimize the :math:`U` and :math:`V` in the
26 |                 SVD. It can be one of ``["expm", "cayley"]`` or a custom
27 |                 callable. Default: ``"expm"``
28 |         """
29 |         super().__init__(size, GLp.rank(size), f, triv)
30 | 
31 |     @classmethod
32 |     def rank(cls, size):
33 |         if len(size) < 2:
34 |             raise VectorError(cls.__name__, size)
35 |         n, k = size[-2:]
36 |         if n != k:
37 |             raise NonSquareError(cls.__name__, size)
38 |         return n
39 | 
40 |     def extra_repr(self):
41 |         return _extra_repr(n=self.n, tensorial_size=self.tensorial_size)
42 | 
--------------------------------------------------------------------------------
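Since ``GLp`` is ``FixedRank`` with the rank pinned to the full dimension of a square matrix, everything it produces is invertible. A quick illustrative check::

    import torch
    from geotorch.glp import GLp

    M = GLp((5, 5))
    A = M.sample()
    b = torch.randn(5)
    x = torch.linalg.solve(A, b)  # succeeds: A has no zero singular values
    print(torch.allclose(A @ x, b, atol=1e-5))  # expected: True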
/geotorch/grassmannian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .stiefel import Stiefel
4 | 
5 | 
6 | class Grassmannian(Stiefel):
7 |     def __init__(self, size, triv="expm"):
8 |         r"""
9 |         Grassmannian manifold as a projection from the Stiefel manifold
10 |         :math:`\operatorname{St}(n,k)`.
11 |         The metric considered is the canonical one.
12 | 
13 |         Args:
14 |             size (torch.size): Size of the tensor to be parametrized
15 |             triv (str or callable): Optional.
16 |                 A map that maps skew-symmetric matrices onto the orthogonal matrices
17 |                 surjectively. It can be one of ``["expm", "cayley"]`` or a custom
18 |                 callable. Default: ``"expm"``
19 |         """
20 |         super().__init__(size=size, triv=triv)
21 | 
22 |     def frame(self, X):
23 |         k = X.size(-1)
24 |         size_z = X.size()[:-2] + (k, k)
25 |         Z = X.new_zeros(*size_z)
26 |         X = torch.cat([Z, X[..., k:, :]], dim=-2)
27 |         return super().frame(X)
28 | 
--------------------------------------------------------------------------------
/geotorch/lowrank.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .product import ProductManifold
4 | from .stiefel import Stiefel
5 | from .reals import Rn
6 | from .exceptions import VectorError, RankError, InManifoldError
7 | from .utils import transpose, _extra_repr
8 | 
9 | 
10 | class LowRank(ProductManifold):
11 |     def __init__(self, size, rank, triv="expm"):
12 |         r"""
13 |         Variety of the matrices of rank :math:`r` or less.
14 | 
15 |         Args:
16 |             size (torch.size): Size of the tensor to be parametrized
17 |             rank (int): Rank of the matrices.
18 |                 It has to be less than or equal to
19 |                 :math:`\min(\texttt{size}[-1], \texttt{size}[-2])`
20 |             triv (str or callable): Optional.
21 |                 A map that maps skew-symmetric matrices onto the orthogonal matrices
22 |                 surjectively. This is used to optimize the :math:`U` and :math:`V` in
23 |                 the SVD. It can be one of ``["expm", "cayley"]`` or a custom callable.
24 | Default: ``"expm"`` 25 | """ 26 | n, k, tensorial_size, transposed = LowRank.parse_size(size) 27 | if rank > min(n, k) or rank < 1: 28 | raise RankError(n, k, rank) 29 | super().__init__(LowRank.manifolds(n, k, rank, tensorial_size, triv)) 30 | self.n = n 31 | self.k = k 32 | self.rank = rank 33 | self.tensorial_size = tensorial_size 34 | self.transposed = transposed 35 | 36 | @classmethod 37 | def parse_size(cls, size): 38 | if len(size) < 2: 39 | raise VectorError(cls.__name__, size) 40 | transposed = size[-2] < size[-1] 41 | n = max(size[-2:]) 42 | k = min(size[-2:]) 43 | tensorial_size = size[:-2] 44 | return n, k, tensorial_size, transposed 45 | 46 | @staticmethod 47 | def manifolds(n, k, rank, tensorial_size, triv): 48 | size_u = tensorial_size + (n, rank) 49 | size_s = tensorial_size + (rank,) 50 | size_v = tensorial_size + (k, rank) 51 | return Stiefel(size_u, triv), Rn(size_s), Stiefel(size_v, triv) 52 | 53 | def frame(self, X): 54 | U = X.tril(-1)[..., : self.rank] 55 | S = X.diagonal(dim1=-2, dim2=-1)[..., : self.rank] 56 | V = X.triu(1).transpose(-2, -1)[..., : self.rank] 57 | return U, S, V 58 | 59 | def submersion(self, U, S, V): 60 | return (U * S.unsqueeze(-2)) @ V.transpose(-2, -1) 61 | 62 | @transpose 63 | def forward(self, X): 64 | X = self.frame(X) 65 | U, S, V = super().forward(X) 66 | return self.submersion(U, S, V) 67 | 68 | def frame_inv(self, X1, X2, X3): 69 | with torch.no_grad(): 70 | # X1 is lower-triangular 71 | # X2 is a vector 72 | # X3 is lower-triangular 73 | size = self.tensorial_size + (self.n, self.k) 74 | ret = torch.zeros(size, dtype=X1.dtype, device=X1.device) 75 | ret[..., : self.rank] += X1 76 | ret[..., : self.rank, : self.rank] += torch.diag_embed(X2) 77 | ret.transpose(-2, -1)[..., : self.rank] += X3 78 | return ret 79 | 80 | def submersion_inv(self, X, check_in_manifold=True): 81 | U, S, Vt = torch.linalg.svd(X, full_matrices=False) 82 | V = Vt.transpose(-2, -1) 83 | if check_in_manifold and not self.in_manifold_singular_values(S): 84 | raise InManifoldError(X, self) 85 | return U[..., : self.rank], S[..., : self.rank], V[..., : self.rank] 86 | 87 | @transpose 88 | def right_inverse(self, X, check_in_manifold=True): 89 | USV = self.submersion_inv(X, check_in_manifold) 90 | X1, X2, X3 = super().right_inverse(USV, check_in_manifold=False) 91 | return self.frame_inv(X1, X2, X3) 92 | 93 | def in_manifold_singular_values(self, S, eps=1e-5): 94 | r""" 95 | Checks that an ordered vector of singular values is in the manifold. 96 | 97 | For tensors with more than 1 dimension the first dimensions are 98 | treated as batch dimensions. 99 | 100 | Args: 101 | S (torch.Tensor): Vector of singular values 102 | eps (float): Optional. Threshold at which the singular values are 103 | considered to be zero 104 | Default: ``1e-5`` 105 | """ 106 | if S.size(-1) <= self.rank: 107 | return True 108 | # We compute the \infty-norm of the remaining dimension 109 | D = S[..., self.rank :] 110 | infty_norm_err = D.abs().max(dim=-1).values 111 | return (infty_norm_err < eps).all() 112 | 113 | def in_manifold(self, X, eps=1e-5): 114 | r""" 115 | Checks that a given matrix is in the manifold. 116 | 117 | Args: 118 | X (torch.Tensor or tuple): The input matrix or matrices of shape ``(*, n, k)``. 119 | eps (float): Optional. 
Threshold at which the singular values are 120 | considered to be zero 121 | Default: ``1e-5`` 122 | """ 123 | if X.size(-1) > X.size(-2): 124 | X = X.transpose(-2, -1) 125 | if X.size() != self.tensorial_size + (self.n, self.k): 126 | return False 127 | S = torch.linalg.svdvals(X) 128 | return self.in_manifold_singular_values(S, eps) 129 | 130 | def sample(self, init_=torch.nn.init.xavier_normal_, factorized=False): 131 | r""" 132 | Returns a randomly sampled matrix on the manifold by sampling a matrix according 133 | to ``init_`` and projecting it onto the manifold. 134 | 135 | The output of this method can be used to initialize a parametrized tensor 136 | that has been parametrized with this or any other manifold as:: 137 | 138 | >>> layer = nn.Linear(20, 20) 139 | >>> M = LowRank(layer.weight.size(), rank=6) 140 | >>> geotorch.register_parametrization(layer, "weight", M) 141 | >>> layer.weight = M.sample() 142 | 143 | Args: 144 | init\_ (callable): Optional. A function that takes a tensor and fills it 145 | in place according to some distribution. See 146 | `torch.init `_. 147 | Default: ``torch.nn.init.xavier_normal_`` 148 | """ 149 | with torch.no_grad(): 150 | device = self[0].base.device 151 | dtype = self[0].base.dtype 152 | X = torch.empty( 153 | *(self.tensorial_size + (self.n, self.k)), device=device, dtype=dtype 154 | ) 155 | init_(X) 156 | U, S, Vt = torch.linalg.svd(X, full_matrices=False) 157 | U, S, Vt = U[..., : self.rank], S[..., : self.rank], Vt[..., : self.rank, :] 158 | if factorized: 159 | if self.transposed: 160 | return Vt.transpose(-2, -1), S, U 161 | else: 162 | return U, S, Vt.transpose(-2, -1) 163 | else: 164 | X = (U * S.unsqueeze(-2)) @ Vt 165 | if self.transposed: 166 | X = X.transpose(-2, -1) 167 | return X 168 | 169 | def extra_repr(self): 170 | return _extra_repr( 171 | n=self.n, 172 | k=self.k, 173 | rank=self.rank, 174 | tensorial_size=self.tensorial_size, 175 | transposed=self.transposed, 176 | ) 177 | -------------------------------------------------------------------------------- /geotorch/product.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ProductManifold(nn.ModuleList): 5 | def __init__(self, manifolds): 6 | r""" 7 | Product manifold :math:`M_1 \times \dots \times M_k`. It can be indexed like a 8 | regular Python list. 9 | 10 | .. note:: 11 | 12 | This is an abstract manifold. It may be used by composing it on the 13 | left and the right by an appropriate linear immersion / submersion. 
14 |             See for example the implementation in :class:`~geotorch.LowRank`
15 | 
16 |         Args:
17 |             manifolds (iterable): An iterable of manifolds
18 |         """
19 |         super().__init__(manifolds)
20 | 
21 |     def forward(self, Xs):
22 |         return tuple(mani(X) for mani, X in zip(self, Xs))
23 | 
24 |     def right_inverse(self, Xs, check_in_manifold=True):
25 |         return tuple(
26 |             mani.right_inverse(X, check_in_manifold) for mani, X in zip(self, Xs)
27 |         )
28 | 
--------------------------------------------------------------------------------
/geotorch/psd.py:
--------------------------------------------------------------------------------
1 | from .pssdfixedrank import PSSDFixedRank
2 | from .exceptions import VectorError, NonSquareError
3 | from .utils import _extra_repr
4 | 
5 | 
6 | class PSD(PSSDFixedRank):
7 |     def __init__(self, size, f="softplus", triv="expm"):
8 |         r"""
9 |         Manifold of symmetric positive definite matrices
10 | 
11 |         Args:
12 |             size (torch.size): Size of the tensor to be parametrized
13 |             f (str or callable or pair of callables): Optional. Either:
14 | 
15 |                 - ``"softplus"``
16 | 
17 |                 - A callable that maps real numbers to the interval :math:`(0, \infty)`
18 | 
19 |                 - A pair of callables such that the first maps the real numbers to
20 |                   :math:`(0, \infty)` and the second is a (right) inverse of the first
21 | 
22 |                 Default: ``"softplus"``
23 |             triv (str or callable): Optional.
24 |                 A map that maps skew-symmetric matrices onto the orthogonal matrices
25 |                 surjectively. This is used to optimize the :math:`Q` in the eigenvalue
26 |                 decomposition. It can be one of ``["expm", "cayley"]`` or a custom
27 |                 callable. Default: ``"expm"``
28 |         """
29 |         super().__init__(size, PSD.rank(size), f, triv=triv)
30 | 
31 |     @classmethod
32 |     def rank(cls, size):
33 |         if len(size) < 2:
34 |             raise VectorError(cls.__name__, size)
35 |         n, k = size[-2:]
36 |         if n != k:
37 |             raise NonSquareError(cls.__name__, size)
38 |         return n
39 | 
40 |     def extra_repr(self):
41 |         return _extra_repr(
42 |             n=self.n,
43 |             tensorial_size=self.tensorial_size,
44 |             f=self.f,
45 |             no_inv=self.inv is None,
46 |         )
47 | 
--------------------------------------------------------------------------------
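``PSD`` builds :math:`X = Q f(\Lambda) Q^\intercal` with :math:`f` mapping into :math:`(0, \infty)`, so every matrix it returns is positive definite; in particular a Cholesky factorization always succeeds. A minimal sketch of that invariant (sizes are arbitrary)::

    import torch
    import torch.nn as nn
    import geotorch

    layer = nn.Linear(10, 10)
    geotorch.positive_definite(layer, "weight")
    L = torch.linalg.cholesky(layer.weight)  # would fail if not positive definite
    print(torch.allclose(L @ L.transpose(-2, -1), layer.weight, atol=1e-5))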
/geotorch/pssd.py:
--------------------------------------------------------------------------------
1 | from .pssdlowrank import PSSDLowRank
2 | from .exceptions import VectorError, NonSquareError
3 | from .utils import _extra_repr
4 | 
5 | 
6 | class PSSD(PSSDLowRank):
7 |     def __init__(self, size, triv="expm"):
8 |         r"""
9 |         Manifold of symmetric positive semidefinite matrices
10 | 
11 |         Args:
12 |             size (torch.size): Size of the tensor to be parametrized
13 |             triv (str or callable): Optional.
14 |                 A map that maps skew-symmetric matrices onto the orthogonal
15 |                 matrices surjectively. This is used to optimize the :math:`Q` in the eigenvalue
16 |                 decomposition. It can be one of ``["expm", "cayley"]`` or a custom
17 |                 callable. Default: ``"expm"``
18 |         """
19 |         super().__init__(size, PSSD.rank(size), triv)
20 | 
21 |     @classmethod
22 |     def rank(cls, size):
23 |         if len(size) < 2:
24 |             raise VectorError(cls.__name__, size)
25 |         n, k = size[-2:]
26 |         if n != k:
27 |             raise NonSquareError(cls.__name__, size)
28 |         return n
29 | 
30 |     def extra_repr(self):
31 |         return _extra_repr(n=self.n, tensorial_size=self.tensorial_size)
32 | 
--------------------------------------------------------------------------------
/geotorch/pssdfixedrank.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .symmetric import SymF
4 | from .fixedrank import softplus_epsilon, inv_softplus_epsilon
5 | 
6 | 
7 | class PSSDFixedRank(SymF):
8 |     fs = {"softplus": (softplus_epsilon, inv_softplus_epsilon)}
9 | 
10 |     def __init__(self, size, rank, f="softplus", triv="expm"):
11 |         r"""
12 |         Manifold of symmetric positive semidefinite matrices of rank :math:`r`.
13 | 
14 |         Args:
15 |             size (torch.size): Size of the tensor to be parametrized
16 |             rank (int): Rank of the matrices.
17 |                 It has to be less than or equal to
18 |                 :math:`\min(\texttt{size}[-1], \texttt{size}[-2])`
19 |             f (str or callable or pair of callables): Optional. Either:
20 | 
21 |                 - ``"softplus"``
22 | 
23 |                 - A callable that maps real numbers to the interval :math:`(0, \infty)`
24 | 
25 |                 - A pair of callables such that the first maps the real numbers to
26 |                   :math:`(0, \infty)` and the second is a (right) inverse of the first
27 | 
28 |                 Default: ``"softplus"``
29 |             triv (str or callable): Optional.
30 |                 A map that maps skew-symmetric matrices onto the orthogonal matrices
31 |                 surjectively. This is used to optimize the :math:`Q` in the eigenvalue
32 |                 decomposition. It can be one of ``["expm", "cayley"]`` or a custom
33 |                 callable. Default: ``"expm"``
34 |         """
35 |         super().__init__(size, rank, PSSDFixedRank.parse_f(f), triv)
36 | 
37 |     @staticmethod
38 |     def parse_f(f):
39 |         if f in PSSDFixedRank.fs.keys():
40 |             return PSSDFixedRank.fs[f]
41 |         elif callable(f):
42 |             return f, None
43 |         elif isinstance(f, tuple) and callable(f[0]) and callable(f[1]):
44 |             return f
45 |         else:
46 |             raise ValueError(
47 |                 "Argument f was not recognized and is "
48 |                 "not callable or a pair of callables. "
49 |                 "Should be one of {}. Found {}".format(list(PSSDFixedRank.fs.keys()), f)
50 |             )
51 | 
52 |     def in_manifold_eigen(self, L, eps=1e-6):
53 |         r"""
54 |         Checks that an ascending ordered vector of eigenvalues is in the manifold.
55 | 
56 |         Args:
57 |             L (torch.Tensor): Vector of eigenvalues of shape `(*, rank)`
58 |             eps (float): Optional. Threshold at which the eigenvalues are
59 |                 considered to be zero
60 |                 Default: ``1e-6``
61 |         """
62 |         return (
63 |             super().in_manifold_eigen(L, eps)
64 |             and (L[..., -self.rank :] >= eps).all().item()
65 |         )
66 | 
67 |     def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6):
68 |         r"""
69 |         Returns a randomly sampled matrix on the manifold as
70 | 
71 |         .. math::
72 | 
73 |             WW^\intercal \qquad W_{i,j} \sim \texttt{init_}
74 | 
75 |         If any of the ``rank`` largest eigenvalues of the sampled matrix are
76 |         smaller than ``eps``, they are clamped to ``eps``.
77 | 
78 | 
79 |         The output of this method can be used to initialize a parametrized tensor as::
80 | 
81 |             >>> layer = nn.Linear(20, 20)
82 |             >>> M = PSSDFixedRank(layer.weight.size(), rank=5)
83 |             >>> geotorch.register_parametrization(layer, "weight", M)
84 |             >>> layer.weight = M.sample()
85 | 
86 |         Args:
87 |             init\_ (callable): Optional.
88 |                 A function that takes a tensor and fills it in place according
89 |                 to some distribution. See
90 |                 `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
91 |                 Default: ``torch.nn.init.xavier_normal_``
92 |             eps (float): Optional. Minimum eigenvalue of the sampled matrix.
93 |                 Default: ``5e-6``
94 |         """
95 |         L, Q = super().sample(factorized=True, init_=init_)
96 |         with torch.no_grad():
97 |             # L >= 0, as given by torch.linalg.eigvalsh()
98 |             L[L < eps] = eps
99 |         return (Q * L.unsqueeze(-2)) @ Q.transpose(-2, -1)
--------------------------------------------------------------------------------
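Any ``(f, f_inv)`` pair with ``f`` mapping the reals into :math:`(0, \infty)` can be passed as ``f``; the inverse is what makes assignment through ``right_inverse`` possible. A small sketch using ``exp``/``log`` as an illustrative pair (note that ``sample`` itself does not depend on ``f``)::

    import torch
    from geotorch.pssdfixedrank import PSSDFixedRank

    M = PSSDFixedRank((6, 6), rank=2, f=(torch.exp, torch.log))
    X = M.sample()
    L = torch.linalg.eigvalsh(X)
    print((L > 1e-4).sum().item())  # expected: 2 non-zero eigenvalues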
/geotorch/pssdlowrank.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .symmetric import SymF
3 | 
4 | 
5 | class PSSDLowRank(SymF):
6 |     def __init__(self, size, rank, triv="expm"):
7 |         r"""
8 |         Variety of the symmetric positive semidefinite matrices of rank
9 |         at most :math:`r`.
10 | 
11 |         Args:
12 |             size (torch.size): Size of the tensor to be parametrized
13 |             rank (int): Rank of the matrices.
14 |                 It has to be less than or equal to
15 |                 :math:`\min(\texttt{size}[-1], \texttt{size}[-2])`
16 |             triv (str or callable): Optional.
17 |                 A map that maps skew-symmetric matrices onto the orthogonal matrices
18 |                 surjectively. This is used to optimize the :math:`Q` in the eigenvalue
19 |                 decomposition. It can be one of ``["expm", "cayley"]`` or a custom
20 |                 callable. Default: ``"expm"``
21 |         """
22 |         super().__init__(size, rank, f=(torch.abs, torch.abs), triv=triv)
23 | 
--------------------------------------------------------------------------------
/geotorch/reals.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from .utils import _extra_repr
3 | from .exceptions import InManifoldError
4 | 
5 | 
6 | class Rn(nn.Module):
7 |     def __init__(self, size):
8 |         r"""
9 |         Vector space of unconstrained vectors.
10 | 
11 |         Args:
12 |             size (torch.size): Size of the tensor to be parametrized
13 |         """
14 |         super().__init__()
15 |         self.n = size[-1]
16 |         self.tensorial_size = size[:-1]
17 | 
18 |     def forward(self, X):
19 |         return X
20 | 
21 |     def right_inverse(self, X, check_in_manifold=True):
22 |         if check_in_manifold and not self.in_manifold(X):
23 |             raise InManifoldError(X, self)
24 |         return X
25 | 
26 |     def in_manifold(self, X):
27 |         return X.size() == self.tensorial_size + (self.n,)
28 | 
29 |     def extra_repr(self):
30 |         return _extra_repr(n=self.n, tensorial_size=self.tensorial_size)
31 | 
--------------------------------------------------------------------------------
/geotorch/skew.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .exceptions import VectorError, NonSquareError
4 | 
5 | 
6 | class Skew(nn.Module):
7 |     def __init__(self, lower=True):
8 |         r"""
9 |         Vector space of skew-symmetric matrices, parametrized in terms of
10 |         the upper or lower triangular part of a matrix.
11 | 
12 |         Args:
13 |             lower (bool): Optional. Uses the lower triangular part of the
14 |                 matrix to parametrize the matrix.
15 |                 Default: ``True``
16 |         """
17 |         super().__init__()
18 |         self.lower = lower
19 | 
20 |     @staticmethod
21 |     def frame(X, lower):
22 |         if lower:
23 |             X = X.tril(-1)
24 |         else:
25 |             X = X.triu(1)
26 |         return X - X.transpose(-2, -1)
27 | 
28 |     def forward(self, X):
29 |         if len(X.size()) < 2:
30 |             raise VectorError(type(self).__name__, X.size())
31 |         if X.size(-2) != X.size(-1):
32 |             raise NonSquareError(type(self).__name__, X.size())
33 |         return self.frame(X, self.lower)
34 | 
35 |     @staticmethod
36 |     def in_manifold(X):
37 |         return (
38 |             X.dim() >= 2
39 |             and X.size(-2) == X.size(-1)
40 |             and torch.allclose(X, -X.transpose(-2, -1))
41 |         )
42 | 
--------------------------------------------------------------------------------
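``Skew.frame`` reflects one triangle of the input onto the other with a sign flip, so its image is skew-symmetric by construction. A quick illustrative check using the static methods above::

    import torch
    from geotorch.skew import Skew

    A = Skew.frame(torch.randn(4, 4), lower=True)
    print(Skew.in_manifold(A))  # expected: True, i.e. A == -A.T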
Found {}".format( 77 | list(SO.trivializations.keys()), triv 78 | ) 79 | ) 80 | 81 | def forward(self, X): 82 | X = Skew.frame(X, self.lower) 83 | return self.base @ self.triv(X) 84 | 85 | def right_inverse(self, X, check_in_manifold=True): 86 | if check_in_manifold and not self.in_manifold(X): 87 | raise InManifoldError(X, self) 88 | with torch.no_grad(): 89 | self.base.copy_(X) 90 | return torch.zeros_like(X) 91 | 92 | def in_manifold(self, X, in_so=False, eps=1e-4): 93 | r""" 94 | Checks that a matrix is in the manifold. 95 | 96 | For tensors with more than 2 dimensions the first dimensions are 97 | treated as batch dimensions. 98 | 99 | Args: 100 | X (torch.Tensor): The matrix to be checked 101 | in_so (bool): Optional. Checks that the matrix is orthogonal and 102 | has positive determinant. Otherwise just orthogonality is checked. 103 | Default: ``False`` 104 | eps (float): Optional. Tolerance to numerical errors. 105 | Default: ``1e-4`` 106 | """ 107 | if X.size() != self.base.size(): 108 | return False 109 | is_orth = _has_orthonormal_columns(X, eps) 110 | X_in_correct_coset = not in_so or (X.det() > 0.0).all().item() 111 | return is_orth and X_in_correct_coset 112 | 113 | def sample(self, distribution="uniform", init_=None): 114 | r""" 115 | Returns a randomly sampled orthogonal matrix according to the specified 116 | ``distribution``. The options are: 117 | 118 | - ``"uniform"``: Samples a tensor distributed according to the Haar measure 119 | on :math:`\operatorname{SO}(n)` 120 | 121 | - ``"torus"``: Samples a block-diagonal skew-symmetric matrix. 122 | The blocks are of the form 123 | :math:`\begin{pmatrix} 0 & b \\ -b & 0\end{pmatrix}` where :math:`b` is 124 | distributed according to ``init_``. This matrix will be then projected 125 | onto :math:`\operatorname{SO}(n)` using ``self.triv`` 126 | 127 | .. note 128 | 129 | The ``"torus"`` initialization is particularly useful in recurrent kernels 130 | of RNNs 131 | 132 | Args: 133 | distribution (string): Optional. One of ``["uniform", "torus"]``. 134 | Default: ``"uniform"`` 135 | init\_ (callable): Optional. To be used with the ``"torus"`` option. 136 | A function that takes a tensor and fills it in place according 137 | to some distribution. See 138 | `torch.init `_. 139 | Default: :math:`\operatorname{Uniform}(-\pi, \pi)` 140 | """ 141 | device = self.base.device 142 | dtype = self.base.dtype 143 | ret = torch.empty( 144 | *(self.tensorial_size + (self.n, self.n)), device=device, dtype=dtype 145 | ) 146 | if distribution == "uniform": 147 | uniform_init_(ret) 148 | elif distribution == "torus": 149 | torus_init_(ret, init_, self.triv) 150 | else: 151 | raise ValueError( 152 | 'The ditribution has to be one of ["uniform", "torus"]. ' 153 | "Got {}".format(distribution) 154 | ) 155 | return ret 156 | 157 | def extra_repr(self): 158 | return _extra_repr(n=self.n, tensorial_size=self.tensorial_size, triv=self.triv) 159 | 160 | 161 | def uniform_init_(tensor): 162 | r"""Fills in the input ``tensor`` in place with an orthogonal matrix. 163 | If square, the matrix will have positive determinant. 164 | The tensor will be distributed according to the Haar measure. 165 | The input tensor must have at least 2 dimensions. 166 | For tensors with more than 2 dimensions the first dimensions are treated as 167 | batch dimensions. 
/geotorch/so.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | 
5 | try:
6 |     from torch.linalg import matrix_exp as expm
7 | except ImportError:
8 |     from torch import matrix_exp as expm
9 | 
10 | from .utils import _extra_repr
11 | from .skew import Skew
12 | from .exceptions import NonSquareError, VectorError, InManifoldError
13 | 
14 | 
15 | def _has_orthonormal_columns(X, eps):
16 |     k = X.size(-1)
17 |     Id = torch.eye(k, dtype=X.dtype, device=X.device)
18 |     return torch.allclose(X.transpose(-2, -1) @ X, Id, atol=eps)
19 | 
20 | 
21 | def cayley_map(X):
22 |     # Compute the Cayley transform (I - X/2)^{-1}(I + X/2)
23 |     n = X.size(-1)
24 |     Id = torch.eye(n, dtype=X.dtype, device=X.device)
25 |     return torch.linalg.solve(Id.add(X, alpha=-0.5), Id.add(X, alpha=0.5))
26 | 
27 | 
28 | class SO(nn.Module):
29 |     trivializations = {"expm": expm, "cayley": cayley_map}
30 | 
31 |     def __init__(self, size, triv="expm", lower=True):
32 |         r"""
33 |         Manifold of square orthogonal matrices with positive determinant parametrized
34 |         in terms of its Lie algebra, the skew-symmetric matrices.
35 | 
36 |         Args:
37 |             size (torch.size): Size of the tensor to be parametrized
38 |             triv (str or callable): Optional.
39 |                 A map that maps skew-symmetric matrices onto :math:`\operatorname{SO}(n)`
40 |                 surjectively. It can be one of ``["expm", "cayley"]`` or a custom
41 |                 callable. Default: ``"expm"``
42 |             lower (bool): Optional. Uses the lower triangular part of the matrix to
43 |                 parametrize the skew-symmetric matrices. Default: ``True``
44 |         """
45 |         super().__init__()
46 |         n, tensorial_size = SO.parse_size(size)
47 |         self.n = n
48 |         self.tensorial_size = tensorial_size
49 |         self.lower = lower
50 |         self.triv = SO.parse_triv(triv)
51 |         self.register_buffer(
52 |             "base", torch.empty(*(self.tensorial_size + (self.n, self.n)))
53 |         )
54 |         uniform_init_(self.base)
55 | 
56 |     @classmethod
57 |     def parse_size(cls, size):
58 |         if len(size) < 2:
59 |             raise VectorError(cls.__name__, size)
60 |         n = max(size[-2:])
61 |         k = min(size[-2:])
62 |         if n != k:
63 |             raise NonSquareError(cls.__name__, size)
64 |         tensorial_size = size[:-2]
65 |         return n, tensorial_size
66 | 
67 |     @staticmethod
68 |     def parse_triv(triv):
69 |         if triv in SO.trivializations.keys():
70 |             return SO.trivializations[triv]
71 |         elif callable(triv):
72 |             return triv
73 |         else:
74 |             raise ValueError(
75 |                 "Argument triv was not recognized and is "
76 |                 "not callable. Should be one of {}. Found {}".format(
77 |                     list(SO.trivializations.keys()), triv
78 |                 )
79 |             )
80 | 
81 |     def forward(self, X):
82 |         X = Skew.frame(X, self.lower)
83 |         return self.base @ self.triv(X)
84 | 
85 |     def right_inverse(self, X, check_in_manifold=True):
86 |         if check_in_manifold and not self.in_manifold(X):
87 |             raise InManifoldError(X, self)
88 |         with torch.no_grad():
89 |             self.base.copy_(X)
90 |         return torch.zeros_like(X)
91 | 
92 |     def in_manifold(self, X, in_so=False, eps=1e-4):
93 |         r"""
94 |         Checks that a matrix is in the manifold.
95 | 
96 |         For tensors with more than 2 dimensions the first dimensions are
97 |         treated as batch dimensions.
98 | 
99 |         Args:
100 |             X (torch.Tensor): The matrix to be checked
101 |             in_so (bool): Optional. Checks that the matrix is orthogonal and
102 |                 has positive determinant. Otherwise just orthogonality is checked.
103 |                 Default: ``False``
104 |             eps (float): Optional. Tolerance to numerical errors.
105 |                 Default: ``1e-4``
106 |         """
107 |         if X.size() != self.base.size():
108 |             return False
109 |         is_orth = _has_orthonormal_columns(X, eps)
110 |         X_in_correct_coset = not in_so or (X.det() > 0.0).all().item()
111 |         return is_orth and X_in_correct_coset
112 | 
113 |     def sample(self, distribution="uniform", init_=None):
114 |         r"""
115 |         Returns a randomly sampled orthogonal matrix according to the specified
116 |         ``distribution``. The options are:
117 | 
118 |         - ``"uniform"``: Samples a tensor distributed according to the Haar measure
119 |           on :math:`\operatorname{SO}(n)`
120 | 
121 |         - ``"torus"``: Samples a block-diagonal skew-symmetric matrix.
122 |           The blocks are of the form
123 |           :math:`\begin{pmatrix} 0 & b \\ -b & 0\end{pmatrix}` where :math:`b` is
124 |           distributed according to ``init_``. This matrix will be then projected
125 |           onto :math:`\operatorname{SO}(n)` using ``self.triv``
126 | 
127 |         .. note::
128 | 
129 |             The ``"torus"`` initialization is particularly useful in recurrent kernels
130 |             of RNNs
131 | 
132 |         Args:
133 |             distribution (string): Optional. One of ``["uniform", "torus"]``.
134 |                 Default: ``"uniform"``
135 |             init\_ (callable): Optional. To be used with the ``"torus"`` option.
136 |                 A function that takes a tensor and fills it in place according
137 |                 to some distribution. See
138 |                 `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
139 |                 Default: :math:`\operatorname{Uniform}(-\pi, \pi)`
140 |         """
141 |         device = self.base.device
142 |         dtype = self.base.dtype
143 |         ret = torch.empty(
144 |             *(self.tensorial_size + (self.n, self.n)), device=device, dtype=dtype
145 |         )
146 |         if distribution == "uniform":
147 |             uniform_init_(ret)
148 |         elif distribution == "torus":
149 |             torus_init_(ret, init_, self.triv)
150 |         else:
151 |             raise ValueError(
152 |                 'The distribution has to be one of ["uniform", "torus"]. '
153 |                 "Got {}".format(distribution)
154 |             )
155 |         return ret
156 | 
157 |     def extra_repr(self):
158 |         return _extra_repr(n=self.n, tensorial_size=self.tensorial_size, triv=self.triv)
159 | 
160 | 
16 | """ 17 | with torch.no_grad(): 18 | x.normal_() 19 | x.data = r * project(x) 20 | return x 21 | 22 | 23 | def _in_sphere(x, r, eps): 24 | norm = x.norm(dim=-1) 25 | rs = torch.full_like(norm, r) 26 | return (torch.norm(norm - rs, p=float("inf")) < eps).all() 27 | 28 | 29 | class sinc_class(torch.autograd.Function): 30 | @staticmethod 31 | def forward(ctx, x): 32 | ctx.save_for_backward(x) 33 | # Hardocoded for float, will do for now 34 | ret = torch.sin(x) / x 35 | ret[x.abs() < 1e-45] = 1.0 36 | return ret 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | (x,) = ctx.saved_tensors 41 | ret = torch.cos(x) / x - torch.sin(x) / (x * x) 42 | ret[x.abs() < 1e-10] = 0.0 43 | return ret * grad_output 44 | 45 | 46 | sinc = sinc_class.apply 47 | 48 | 49 | class SphereEmbedded(nn.Module): 50 | def __init__(self, size, radius=1.0): 51 | r""" 52 | Sphere as the orthogonal projection from 53 | :math:`\mathbb{R}^n` to :math:`\mathbb{S}^{n-1}`, that is, 54 | :math:`x \mapsto \frac{x}{\lVert x \rVert}`. 55 | 56 | Args: 57 | size (torch.size): Size of the tensor to be parametrized 58 | radius (float): Optional. 59 | Radius of the sphere. A positive number. Default: ``1.`` 60 | """ 61 | super().__init__() 62 | self.n = size[-1] 63 | self.tensorial_size = size[:-1] 64 | self.radius = SphereEmbedded.parse_radius(radius) 65 | 66 | @staticmethod 67 | def parse_radius(radius): 68 | if radius <= 0.0: 69 | raise ValueError( 70 | "The radius has to be a positive real number. Got {}".format(radius) 71 | ) 72 | return radius 73 | 74 | def forward(self, x): 75 | return self.radius * project(x) 76 | 77 | def right_inverse(self, x, check_in_manifold=True): 78 | if check_in_manifold and not self.in_manifold(x): 79 | raise InManifoldError(x, self) 80 | return x / self.radius 81 | 82 | def in_manifold(self, x, eps=1e-5): 83 | r""" 84 | Checks that a vector is on the sphere. 85 | 86 | For tensors with more than 2 dimensions the first dimensions are 87 | treated as batch dimensions. 88 | 89 | Args: 90 | X (torch.Tensor): The vector to be checked. 91 | eps (float): Optional. Threshold at which the norm is considered 92 | to be equal to ``1``. Default: ``1e-5`` 93 | """ 94 | return _in_sphere(x, self.radius, eps) 95 | 96 | def sample(self): 97 | r""" 98 | Returns a uniformly sampled vector on the sphere. 99 | """ 100 | x = torch.empty(*(self.tensorial_size) + (self.n,)) 101 | return uniform_init_sphere_(x, r=self.radius) 102 | 103 | def extra_repr(self): 104 | return _extra_repr( 105 | n=self.n, radius=self.radius, tensorial_size=self.tensorial_size 106 | ) 107 | 108 | 109 | class Sphere(nn.Module): 110 | def __init__(self, size, radius=1.0): 111 | r""" 112 | Sphere as a map from the tangent space onto the sphere using the 113 | exponential map. 114 | 115 | Args: 116 | size (torch.size): Size of the tensor to be parametrized 117 | radius (float): Optional. 118 | Radius of the sphere. A positive number. Default: ``1.`` 119 | """ 120 | super().__init__() 121 | self.n = size[-1] 122 | self.tensorial_size = size[:-1] 123 | self.radius = Sphere.parse_radius(radius) 124 | self.register_buffer("base", uniform_init_sphere_(torch.empty(*size))) 125 | 126 | @staticmethod 127 | def parse_radius(radius): 128 | if radius <= 0.0: 129 | raise ValueError( 130 | "The radius has to be a positive real number. 
Got {}".format(radius) 131 | ) 132 | return radius 133 | 134 | def frame(self, x, v): 135 | projection = (v.unsqueeze(-2) @ x.unsqueeze(-1)).squeeze(-1) 136 | v = v - projection * x 137 | return v 138 | 139 | def forward(self, v): 140 | x = self.base 141 | # Project v onto { = 0} 142 | v = self.frame(x, v) 143 | vnorm = v.norm(dim=-1, keepdim=True) 144 | return self.radius * (torch.cos(vnorm) * x + sinc(vnorm) * v) 145 | 146 | def right_inverse(self, x, check_in_manifold=True): 147 | if check_in_manifold and not self.in_manifold(x): 148 | raise InManifoldError(x, self) 149 | with torch.no_grad(): 150 | x = x / self.radius 151 | self.base.copy_(x) 152 | return torch.zeros_like(x) 153 | 154 | def in_manifold(self, x, eps=1e-5): 155 | r""" 156 | Checks that a vector is on the sphere. 157 | 158 | For tensors with more than 2 dimensions the first dimensions are 159 | treated as batch dimensions. 160 | 161 | Args: 162 | X (torch.Tensor): The vector to be checked. 163 | eps (float): Optional. Threshold at which the norm is considered 164 | to be equal to ``1``. Default: ``1e-5`` 165 | """ 166 | return _in_sphere(x, self.radius, eps) 167 | 168 | def sample(self): 169 | r""" 170 | Returns a uniformly sampled vector on the sphere. 171 | """ 172 | device = self.base.device 173 | dtype = self.base.dtype 174 | x = torch.empty(*(self.tensorial_size) + (self.n,), device=device, dtype=dtype) 175 | return uniform_init_sphere_(x, r=self.radius) 176 | 177 | def extra_repr(self): 178 | return _extra_repr( 179 | n=self.n, radius=self.radius, tensorial_size=self.tensorial_size 180 | ) 181 | -------------------------------------------------------------------------------- /geotorch/stiefel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .utils import transpose, _extra_repr 4 | from .so import SO, _has_orthonormal_columns 5 | 6 | from .exceptions import VectorError, InManifoldError 7 | 8 | 9 | class Stiefel(SO): 10 | def __init__(self, size, triv="expm"): 11 | r""" 12 | Manifold of rectangular orthogonal matrices parametrized as a projection 13 | onto the first :math:`k` columns from the space of square orthogonal matrices 14 | :math:`\operatorname{SO}(n)`. The metric considered is the canonical. 15 | 16 | Args: 17 | size (torch.size): Size of the tensor to be parametrized 18 | triv (str or callable): Optional. 19 | A map that maps skew-symmetric matrices onto the orthogonal matrices 20 | surjectively. It can be one of ``["expm", "cayley"]`` or a custom 21 | callable. 
/geotorch/stiefel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .utils import transpose, _extra_repr
4 | from .so import SO, _has_orthonormal_columns
5 | 
6 | from .exceptions import VectorError, InManifoldError
7 | 
8 | 
9 | class Stiefel(SO):
10 |     def __init__(self, size, triv="expm"):
11 |         r"""
12 |         Manifold of rectangular orthogonal matrices parametrized as a projection
13 |         onto the first :math:`k` columns from the space of square orthogonal matrices
14 |         :math:`\operatorname{SO}(n)`. The metric considered is the canonical one.
15 | 
16 |         Args:
17 |             size (torch.size): Size of the tensor to be parametrized
18 |             triv (str or callable): Optional.
19 |                 A map that maps skew-symmetric matrices onto the orthogonal matrices
20 |                 surjectively. It can be one of ``["expm", "cayley"]`` or a custom
21 |                 callable. Default: ``"expm"``
22 |         """
23 |         super().__init__(size=Stiefel.size_so(size), triv=triv, lower=True)
24 |         self.k = min(size[-1], size[-2])
25 |         self.transposed = size[-2] < size[-1]
26 | 
27 |     @classmethod
28 |     def size_so(cls, size):
29 |         if len(size) < 2:
30 |             raise VectorError(cls.__name__, size)
31 |         size_so = list(size)
32 |         size_so[-1] = size_so[-2] = max(size[-1], size[-2])
33 |         return tuple(size_so)
34 | 
35 |     def frame(self, X):
36 |         n, k = X.size(-2), X.size(-1)
37 |         size_z = X.size()[:-2] + (n, n - k)
38 |         return torch.cat([X, X.new_zeros(*size_z)], dim=-1)
39 | 
40 |     @transpose
41 |     def forward(self, X):
42 |         X = self.frame(X)
43 |         X = super().forward(X)
44 |         return X[..., : self.k]
45 | 
46 |     @transpose
47 |     def right_inverse(self, X, check_in_manifold=True):
48 |         if check_in_manifold and not self.in_manifold(X):
49 |             raise InManifoldError(X, self)
50 |         if self.n != self.k:
51 |             # N will be a completion of X to an orthogonal basis of R^n
52 |             N = X.new_empty(*(self.tensorial_size + (self.n, self.n - self.k)))
53 |             with torch.no_grad():
54 |                 N.normal_()
55 |                 # We assume for now that X is orthogonal.
56 |                 # This will be checked in super().right_inverse()
57 |                 # Project N onto the orthogonal complement to X
58 |                 # We iterate this twice for this algorithm to be numerically stable
59 |                 # This is standard, as done in some stochastic SVD algorithms
60 |                 for _ in range(2):
61 |                     N = N - X @ (X.transpose(-2, -1) @ N)
62 |                 # And make it an orthonormal base of the image
63 |                 N = torch.linalg.qr(N).Q
64 |             X = torch.cat([X, N], dim=-1)
65 |         return super().right_inverse(X, check_in_manifold=False)[..., : self.k]
66 | 
67 |     def in_manifold(self, X, eps=1e-4):
68 |         r"""
69 |         Checks that a matrix is in the manifold.
70 | 
71 |         For tensors with more than 2 dimensions the first dimensions are
72 |         treated as batch dimensions.
73 | 
74 |         Args:
75 |             X (torch.Tensor): The matrix to be checked
76 |             eps (float): Optional. Tolerance to numerical errors.
77 |                 Default: ``1e-4``
78 |         """
79 |         if X.size(-1) > X.size(-2):
80 |             X = X.transpose(-2, -1)
81 |         if X.size() != self.tensorial_size + (self.n, self.k):
82 |             return False
83 |         return _has_orthonormal_columns(X, eps)
84 | 
85 |     def sample(self, distribution="uniform", init_=None):
86 |         r"""
87 |         Returns a randomly sampled orthogonal matrix according to the specified
88 |         ``distribution``. The options are:
89 | 
90 |         - ``"uniform"``: Samples a tensor distributed according to the Haar measure
91 |           on :math:`\operatorname{SO}(n)`
92 | 
93 |         - ``"torus"``: Samples a block-diagonal skew-symmetric matrix.
94 |           The blocks are of the form
95 |           :math:`\begin{pmatrix} 0 & b \\ -b & 0\end{pmatrix}` where :math:`b` is
96 |           distributed according to ``init_``. This matrix will be then projected
97 |           onto :math:`\operatorname{SO}(n)` using ``self.triv``
98 | 
99 |         .. note::
100 | 
101 |             The ``"torus"`` initialization is particularly useful in recurrent kernels
102 |             of RNNs
103 | 
104 |         Args:
105 |             distribution (string): Optional. One of ``["uniform", "torus"]``.
106 |                 Default: ``"uniform"``
107 |             init\_ (callable): Optional. To be used with the ``"torus"`` option.
108 |                 A function that takes a tensor and fills it in place according
109 |                 to some distribution. See
110 |                 `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
111 |                 Default: :math:`\operatorname{Uniform}(-\pi, \pi)`
112 |         """
113 |         X = super().sample(distribution, init_)
114 |         if not self.transposed:
115 |             return X[..., : self.k]
116 |         else:
117 |             return X[..., : self.k, :]
118 | 
119 |     def extra_repr(self):
120 |         return _extra_repr(
121 |             n=self.n,
122 |             k=self.k,
123 |             tensorial_size=self.tensorial_size,
124 |             triv=self.triv,
125 |             transposed=self.transposed,
126 |         )
127 | 
--------------------------------------------------------------------------------
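The Stiefel parametrization keeps the first :math:`k` columns of a square orthogonal matrix, so sampled (and parametrized) matrices always have orthonormal columns. A brief illustrative check::

    import torch
    from geotorch.stiefel import Stiefel

    St = Stiefel((8, 3))
    X = St.sample()
    print(torch.allclose(X.transpose(-2, -1) @ X, torch.eye(3), atol=1e-5))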
/geotorch/symmetric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from .product import ProductManifold
5 | from .stiefel import Stiefel
6 | from .reals import Rn
7 | from .exceptions import (
8 |     VectorError,
9 |     NonSquareError,
10 |     RankError,
11 |     InManifoldError,
12 |     InverseError,
13 | )
14 | from .utils import _extra_repr
15 | 
16 | 
17 | class Symmetric(nn.Module):
18 |     def __init__(self, lower=True):
19 |         r"""
20 |         Vector space of symmetric matrices, parametrized in terms of the upper or lower
21 |         triangular part of a matrix.
22 | 
23 |         Args:
24 |             lower (bool): Optional. Uses the lower triangular part of the
25 |                 matrix to parametrize the matrix.
26 |                 Default: ``True``
27 |         """
28 |         super().__init__()
29 |         self.lower = lower
30 | 
31 |     @staticmethod
32 |     def frame(X, lower):
33 |         if lower:
34 |             return X.tril(0) + X.tril(-1).transpose(-2, -1)
35 |         else:
36 |             return X.triu(0) + X.triu(1).transpose(-2, -1)
37 | 
38 |     def forward(self, X):
39 |         if len(X.size()) < 2:
40 |             raise VectorError(type(self).__name__, X.size())
41 |         if X.size(-2) != X.size(-1):
42 |             raise NonSquareError(type(self).__name__, X.size())
43 |         return self.frame(X, self.lower)
44 | 
45 |     @staticmethod
46 |     def in_manifold(X, eps=1e-6):
47 |         return (
48 |             X.dim() >= 2
49 |             and X.size(-2) == X.size(-1)
50 |             and torch.allclose(X, X.transpose(-2, -1), atol=eps)
51 |         )
52 | 
53 | 
54 | class SymF(ProductManifold):
55 |     def __init__(self, size, rank, f, triv="expm"):
56 |         r"""
57 |         Space of the symmetric matrices of rank at most ``rank`` with eigenvalues
58 |         in the image of a given function
59 | 
60 |         Args:
61 |             size (torch.size): Size of the tensor to be parametrized
62 |             rank (int): Rank of the matrices.
63 |                 It has to be less than or equal to
64 |                 :math:`\min(\texttt{size}[-1], \texttt{size}[-2])`
65 |             f (callable or pair of callables): Either:
66 | 
67 |                 - A callable
68 | 
69 |                 - A pair of callables such that the second is a (right)
70 |                   inverse of the first
71 |             triv (str or callable): Optional.
72 |                 A map that maps skew-symmetric matrices onto the orthogonal matrices
73 |                 surjectively. This is used to optimize the :math:`Q` in the eigenvalue
74 |                 decomposition. It can be one of ``["expm", "cayley"]`` or a custom
75 |                 callable.
Default: ``"expm"`` 76 | """ 77 | n, tensorial_size = SymF.parse_size(size) 78 | if rank > n or rank < 1: 79 | raise RankError(n, n, rank) 80 | super().__init__(SymF.manifolds(n, rank, tensorial_size, triv)) 81 | self.n = n 82 | self.tensorial_size = tensorial_size 83 | self.rank = rank 84 | f, inv = SymF.parse_f(f) 85 | self.f = f 86 | self.inv = inv 87 | 88 | @classmethod 89 | def parse_size(cls, size): 90 | if len(size) < 2: 91 | raise VectorError(cls.__name__, size) 92 | n, k = size[-2:] 93 | tensorial_size = size[:-2] 94 | if n != k: 95 | raise NonSquareError(cls.__name__, size) 96 | return n, tensorial_size 97 | 98 | @staticmethod 99 | def parse_f(f): 100 | if callable(f): 101 | return f, None 102 | elif isinstance(f, tuple) and callable(f[0]) and callable(f[1]): 103 | return f 104 | else: 105 | raise ValueError( 106 | "Argument f is not callable nor a pair of callables. " 107 | "Found {}".format(f) 108 | ) 109 | 110 | @staticmethod 111 | def manifolds(n, rank, tensorial_size, triv): 112 | size_q = tensorial_size + (n, rank) 113 | size_l = tensorial_size + (rank,) 114 | return Stiefel(size_q, triv=triv), Rn(size_l) 115 | 116 | def frame(self, X): 117 | L = X.diagonal(dim1=-2, dim2=-1)[..., : self.rank] 118 | X = X[..., : self.rank] 119 | return X, L 120 | 121 | def submersion(self, Q, L): 122 | L = self.f(L) 123 | return (Q * L.unsqueeze(-2)) @ Q.transpose(-2, -1) 124 | 125 | def forward(self, X): 126 | X = self.frame(X) 127 | Q, L = super().forward(X) 128 | return self.submersion(Q, L) 129 | 130 | def frame_inv(self, X1, X2): 131 | size = self.tensorial_size + (self.n, self.n) 132 | ret = torch.zeros(*size, dtype=X1.dtype, device=X1.device) 133 | with torch.no_grad(): 134 | ret[..., : self.rank] += X1 135 | ret[..., : self.rank, : self.rank] += torch.diag_embed(X2) 136 | return ret 137 | 138 | def submersion_inv(self, X, check_in_manifold=True): 139 | with torch.no_grad(): 140 | L, Q = torch.linalg.eigh(X) 141 | if check_in_manifold and not self.in_manifold_eigen(L): 142 | raise InManifoldError(X, self) 143 | if self.inv is None: 144 | raise InverseError(self) 145 | with torch.no_grad(): 146 | Q = Q[..., -self.rank :] 147 | L = L[..., -self.rank :] 148 | L = self.inv(L) 149 | return L, Q 150 | 151 | def right_inverse(self, X, check_in_manifold=True): 152 | L, Q = self.submersion_inv(X, check_in_manifold) 153 | X1, X2 = super().right_inverse([Q, L], check_in_manifold=False) 154 | return self.frame_inv(X1, X2) 155 | 156 | def in_manifold_eigen(self, L, eps=1e-6): 157 | r""" 158 | Checks that an ascending ordered vector of eigenvalues is in the manifold. 159 | 160 | Args: 161 | L (torch.Tensor): Vector of eigenvalues of shape `(*, rank)` 162 | eps (float): Optional. Threshold at which the eigenvalues are 163 | considered to be zero 164 | Default: ``1e-6`` 165 | """ 166 | if L.size()[:-1] != self.tensorial_size: 167 | return False 168 | if L.size(-1) > self.rank: 169 | # We compute the \infty-norm of the remaining dimension 170 | D = L[..., : -self.rank] 171 | infty_norm_err = D.abs().max(dim=-1).values 172 | if (infty_norm_err > 5.0 * eps).any(): 173 | return False 174 | return (L[..., -self.rank :] >= -eps).all().item() 175 | 176 | def in_manifold(self, X, eps=1e-6): 177 | r""" 178 | Checks that a matrix is in the manifold. 179 | 180 | Args: 181 | X (torch.Tensor): The matrix or batch of matrices of shape ``(*, n, n)`` to check. 182 | eps (float): Optional. Threshold at which the singular values are 183 | considered to be zero. 
Default: ``1e-6`` 184 | """ 185 | size = self.tensorial_size + (self.n, self.n) 186 | if X.size() != size or not Symmetric.in_manifold(X, eps): 187 | return False 188 | L = torch.linalg.eigvalsh(X) 189 | return self.in_manifold_eigen(L, eps) 190 | 191 | def sample(self, init_=torch.nn.init.xavier_normal_, factorized=False): 192 | r""" 193 | Returns a randomly sampled matrix on the manifold as 194 | 195 | .. math:: 196 | 197 | WW^\intercal \qquad W_{i,j} \sim \texttt{init_} 198 | 199 | By default ``init\_`` is a (xavier) normal distribution, so that the 200 | returned matrix follows a Wishart distribution. 201 | 202 | The output of this method can be used to initialize a parametrized tensor 203 | that has been parametrized with this or any other manifold as:: 204 | 205 | >>> layer = nn.Linear(20, 20) 206 | >>> M = PSSD(layer.weight.size()) 207 | >>> geotorch.register_parametrization(layer, "weight", M) 208 | >>> layer.weight = M.sample() 209 | 210 | Args: 211 | init\_ (callable): Optional. 212 | A function that takes a tensor and fills it in place according 213 | to some distribution. See 214 | `torch.init `_. 215 | Default: ``torch.nn.init.xavier_normal_`` 216 | """ 217 | with torch.no_grad(): 218 | device = self[0].base.device 219 | dtype = self[0].base.dtype 220 | X = torch.empty( 221 | *(self.tensorial_size + (self.n, self.n)), device=device, dtype=dtype 222 | ) 223 | init_(X) 224 | X = X @ X.transpose(-2, -1) 225 | L, Q = torch.linalg.eigh(X) 226 | L = L[..., -self.rank :] 227 | Q = Q[..., -self.rank :] 228 | if factorized: 229 | return L, Q 230 | else: 231 | return (Q * L.unsqueeze(-2)) @ Q.transpose(-2, -1) 232 | 233 | def extra_repr(self): 234 | return _extra_repr( 235 | n=self.n, 236 | rank=self.rank, 237 | tensorial_size=self.tensorial_size, 238 | f=self.f, 239 | no_inv=self.inv is None, 240 | ) 241 | -------------------------------------------------------------------------------- /geotorch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def update_base(layer, tensor_name): 5 | with torch.no_grad(): 6 | setattr(layer, tensor_name, getattr(layer, tensor_name).data) 7 | 8 | 9 | def transpose(fun): 10 | def new_fun(self, X, *args, **kwargs): 11 | if self.transposed: 12 | X = X.transpose(-2, -1) 13 | X = fun(self, X, *args, **kwargs) 14 | if self.transposed: 15 | X = X.transpose(-2, -1) 16 | return X 17 | 18 | return new_fun 19 | 20 | 21 | def _extra_repr(**kwargs): # noqa: C901 22 | if "n" in kwargs: 23 | ret = "n={}".format(kwargs["n"]) 24 | elif "dim" in kwargs: 25 | ret = "dim={}".format(kwargs["dim"]) 26 | else: 27 | ret = "" 28 | 29 | if "k" in kwargs: 30 | ret += ", k={}".format(kwargs["k"]) 31 | if "rank" in kwargs: 32 | ret += ", rank={}".format(kwargs["rank"]) 33 | if "radius" in kwargs: 34 | ret += ", radius={}".format(kwargs["radius"]) 35 | if "lam" in kwargs: 36 | ret += ", lambda={}".format(kwargs["lam"]) 37 | if "f" in kwargs: 38 | ret += ", f={}".format(kwargs["f"].__name__) 39 | if "tensorial_size" in kwargs: 40 | ts = kwargs["tensorial_size"] 41 | if len(ts) != 0: 42 | ret += ", tensorial_size={}".format(tuple(ts)) 43 | if "triv" in kwargs: 44 | ret += ", triv={}".format(kwargs["triv"].__name__) 45 | if "no_inv" in kwargs: 46 | if kwargs["no_inv"]: 47 | ret += ", no inverse" 48 | if "transposed" in kwargs: 49 | if kwargs["transposed"]: 50 | ret += ", transposed" 51 | return ret 52 | -------------------------------------------------------------------------------- /setup.cfg: 
-------------------------------------------------------------------------------- 1 | [flake8] 2 | # E203: black and flake8 disagree on whitespace before ':' 3 | # W503: black and flake8 disagree on how to place operators 4 | ignore = E203, W503 5 | max-line-length = 88 6 | # Exclude as it will come directly from PyTorch in the future 7 | exclude = geotorch/parametrize.py 8 | 9 | [coverage:report] 10 | omit = 11 | test/* 12 | setup.py 13 | geotorch/parametrize.py 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | import re 4 | 5 | TEST_REQUIRES = ["pytest"] 6 | 7 | DEV_REQUIRES = TEST_REQUIRES + [ 8 | "black", 9 | "flake8", 10 | "sphinx", 11 | "sphinx-autodoc-typehints", 12 | "sphinx-rtd-theme", 13 | "sphinxcontrib-spelling", 14 | "codecov", 15 | ] 16 | 17 | classifiers = [ 18 | "Development Status :: 3 - Alpha", 19 | "Programming Language :: Python", 20 | "Programming Language :: Python :: 3", 21 | "Programming Language :: Python :: 3.5", 22 | "License :: OSI Approved :: MIT License", 23 | "Topic :: Scientific/Engineering", 24 | "Topic :: Scientific/Engineering :: Mathematics", 25 | "Intended Audience :: Science/Research", 26 | "Intended Audience :: Developers", 27 | "Operating System :: OS Independent", 28 | ] 29 | 30 | # Get the long description from the README file 31 | with open("README.rst", "r", encoding="utf8") as fh: 32 | long_description = fh.read() 33 | 34 | # Get version string from module 35 | init_path = os.path.join(os.path.dirname(__file__), "geotorch/__init__.py") 36 | with open(init_path, "r", encoding="utf8") as f: 37 | version = re.search(r"__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M).group(1) 38 | 39 | setup( 40 | name="geotorch", 41 | version=version, 42 | description="Constrained Optimization and Manifold Optimization in PyTorch", 43 | author="Mario Lezcano Casado", 44 | author_email="lezcano-93@hotmail.com", 45 | license="MIT", 46 | long_description=long_description, 47 | long_description_content_type="text/x-rst", 48 | url="https://github.com/Lezcano/geotorch", 49 | classifiers=classifiers, 50 | keywords=["Constrained Optimization", "Optimization on Manifolds", "PyTorch"], 51 | packages=find_packages(), 52 | python_requires=">=3.5", 53 | install_requires=["torch>=1.9"], 54 | extras_require={"dev": DEV_REQUIRES, "test": TEST_REQUIRES}, 55 | ) 56 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezcano/geotorch/ba38d406c245d609fee4b4dac3f6427bf6d73a8e/test/__init__.py -------------------------------------------------------------------------------- /test/test_almostorthogonal.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from geotorch.almostorthogonal import AlmostOrthogonal 4 | 5 | 6 | class TestAlmostOrthogonal(TestCase): 7 | def test_almostorthogonal_errors(self): 8 | with self.assertRaises(ValueError): 9 | AlmostOrthogonal(size=(5,), lam=1.0) 10 | # Not a predefined f 11 | with self.assertRaises(ValueError): 12 | AlmostOrthogonal(size=(5, 4), lam=1, f="fail") 13 | # Not callable 14 | with self.assertRaises(ValueError): 15 | AlmostOrthogonal(size=(5, 4), lam=1, f=3.0) 16 | # But a valid lam should work 17 | AlmostOrthogonal(size=(5, 
4), lam=0.5) 18 | # Too large a lambda 19 | with self.assertRaises(ValueError): 20 | AlmostOrthogonal(size=(5, 4), lam=2) 21 | # Or too small 22 | with self.assertRaises(ValueError): 23 | AlmostOrthogonal(size=(5, 4), lam=-1.0) 24 | -------------------------------------------------------------------------------- /test/test_glp.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from geotorch.glp import GLp 4 | 5 | 6 | class TestGLp(TestCase): 7 | def test_GLp_errors(self): 8 | # Non-square 9 | with self.assertRaises(ValueError): 10 | GLp(size=(4, 3)) 11 | # Try to instantiate it in a vector rather than a matrix 12 | with self.assertRaises(ValueError): 13 | GLp(size=(5,)) 14 | -------------------------------------------------------------------------------- /test/test_integration.py: -------------------------------------------------------------------------------- 1 | # Integration tests for all the manifolds 2 | from unittest import TestCase 3 | import itertools 4 | import types 5 | 6 | import torch 7 | import torch.nn as nn 8 | import geotorch.parametrize as P 9 | 10 | import geotorch 11 | from geotorch.skew import Skew 12 | from geotorch.symmetric import Symmetric 13 | from geotorch.so import SO 14 | from geotorch.stiefel import Stiefel 15 | from geotorch.grassmannian import Grassmannian 16 | from geotorch.lowrank import LowRank 17 | from geotorch.fixedrank import FixedRank 18 | from geotorch.psd import PSD 19 | from geotorch.pssd import PSSD 20 | from geotorch.pssdlowrank import PSSDLowRank 21 | from geotorch.pssdfixedrank import PSSDFixedRank 22 | from geotorch.glp import GLp 23 | from geotorch.sl import SL 24 | from geotorch.almostorthogonal import AlmostOrthogonal 25 | from geotorch.sphere import Sphere, SphereEmbedded 26 | 27 | from geotorch.utils import update_base 28 | 29 | 30 | def dicts_product(**kwargs): 31 | """Yields dicts that pair the keys with each element of the Cartesian product of the value lists""" 32 | keys = kwargs.keys() 33 | vals = kwargs.values() 34 | for instance in itertools.product(*vals): 35 | yield dict(zip(keys, instance)) 36 | 37 | 38 | class TestIntegration(TestCase): 39 | def sizes(self, square): 40 | sizes = [(i, i) for i in range(1, 11)] 41 | if not square: 42 | sizes.extend( 43 | [ 44 | (i, j) 45 | for i, j in itertools.product(range(1, 5), range(1, 5)) 46 | if i != j 47 | ] 48 | ) 49 | sizes.extend( 50 | [(1, 7), (2, 7), (1, 8), (2, 8), (7, 1), (7, 2), (8, 1), (8, 2)] 51 | ) 52 | if torch.cuda.is_available(): 53 | sizes.extend([(256, 256), (512, 512)]) 54 | if not square: 55 | sizes.extend([(256, 128), (128, 512), (1024, 512)]) 56 | return sizes 57 | 58 | def ranks(self): 59 | return [1, 3, 7] 60 | 61 | def lambdas(self): 62 | return [0.01, 0.5, 1.0] 63 | 64 | def radii(self): 65 | return [0.01, 1.0, 2.0, 10.0] 66 | 67 | def devices(self): 68 | if torch.cuda.is_available(): 69 | return [torch.device("cuda")] 70 | else: 71 | return [torch.device("cpu")] 72 | 73 | def test_vector_spaces(self): 74 | self._test_manifolds( 75 | [Skew, Symmetric, geotorch.skew, geotorch.symmetric], 76 | [{}], 77 | dicts_product(lower=[True, False]), 78 | self.devices(), 79 | self.sizes(square=True), 80 | initialize=False, 81 | ) 82 | 83 | def test_so(self): 84 | self._test_manifolds( 85 | [SO], 86 | dicts_product(distribution=["uniform", "torus"]), 87 | [{}], 88 | self.devices(), 89 | self.sizes(square=True), 90 | ) 91 | 92 | def test_orthogonal(self): 93 | self._test_manifolds( 94 | [Stiefel, Grassmannian, geotorch.orthogonal, geotorch.grassmannian], 
95 | dicts_product(distribution=["uniform", "torus"]), 96 | dicts_product(triv=["expm", "cayley"]), 97 | self.devices(), 98 | self.sizes(square=False), 99 | ) 100 | 101 | def test_rank(self): 102 | self._test_manifolds( 103 | [LowRank, FixedRank, geotorch.low_rank, geotorch.fixed_rank], 104 | [{}], 105 | dicts_product(rank=self.ranks()), 106 | self.devices(), 107 | self.sizes(square=False), 108 | ) 109 | 110 | def test_psd_and_glp(self): 111 | self._test_manifolds( 112 | [ 113 | PSD, 114 | PSSD, 115 | GLp, 116 | SL, 117 | geotorch.positive_definite, 118 | geotorch.positive_semidefinite, 119 | geotorch.invertible, 120 | geotorch.sln, 121 | ], 122 | [{}], 123 | [{}], 124 | self.devices(), 125 | self.sizes(square=True), 126 | ) 127 | 128 | def test_pssd_rank(self): 129 | self._test_manifolds( 130 | [ 131 | PSSDLowRank, 132 | PSSDFixedRank, 133 | geotorch.positive_semidefinite_low_rank, 134 | geotorch.positive_semidefinite_fixed_rank, 135 | ], 136 | [{}], 137 | dicts_product(rank=self.ranks()), 138 | self.devices(), 139 | self.sizes(square=True), 140 | ) 141 | 142 | def test_almost_orthogonal(self): 143 | self._test_manifolds( 144 | [AlmostOrthogonal, geotorch.almost_orthogonal], 145 | dicts_product(distribution=["uniform", "torus"]), 146 | dicts_product(lam=self.lambdas(), f=list(AlmostOrthogonal.fs.keys())), 147 | self.devices(), 148 | self.sizes(square=True), 149 | ) 150 | 151 | def test_sphere(self): 152 | self._test_manifolds( 153 | [Sphere, SphereEmbedded, geotorch.sphere], 154 | [{}], 155 | dicts_product(radius=self.radii()), 156 | self.devices(), 157 | self.sizes(square=False), 158 | ) 159 | 160 | def _test_manifolds( 161 | self, Ms, argss_sample, argss_constr, devices, sizes, initialize=False 162 | ): 163 | with torch.random.fork_rng(devices=range(torch.cuda.device_count())): 164 | torch.random.manual_seed(8888) 165 | for M, args_sample, args_constr, device, size in itertools.product( 166 | Ms, argss_sample, argss_constr, devices, sizes 167 | ): 168 | if "rank" in args_constr and args_constr["rank"] > min(size): 169 | continue 170 | self._test_manifold( 171 | M, args_sample, args_constr, device, size, initialize 172 | ) 173 | 174 | def _test_manifold(self, M, args_sample, args_constr, device, size, initialize): 175 | inputs = [torch.rand(3, size[0], device=device)] 176 | layers = [nn.Linear(*size, device=device)] 177 | # Just test on convolution for small layers, otherwise it takes too long 178 | if min(size) < 100: 179 | inputs.append(torch.rand(6, 5, size[0] + 7, size[1] + 3, device=device)) 180 | layers.append(nn.Conv2d(5, 4, size, device=device)) 181 | 182 | for input_, layer in zip(inputs, layers): 183 | old_size = layer.weight.size() 184 | # Somewhat dirty but will do 185 | if isinstance(M, types.FunctionType): 186 | M(layer, "weight", **args_constr) 187 | else: 188 | # initialize the weight first (annoying) 189 | M_ = M(size=layer.weight.size(), **args_constr).to(device) 190 | X = M_.sample(**args_sample) 191 | with torch.no_grad(): 192 | layer.weight.copy_(X) 193 | P.register_parametrization(layer, "weight", M_) 194 | # Check that it does not change the size of the layer 195 | self.assertEqual(old_size, layer.weight.size(), msg=f"{layer}") 196 | self._test_training(layer, args_sample, input_, initialize) 197 | 198 | def _test_training(self, layer, args_sample, input_, initialize): 199 | msg = f"{layer}\n{args_sample}" 200 | M = layer.parametrizations.weight[0] 201 | if initialize: 202 | initial_size = layer.weight.size() 203 | X = M.sample(**args_sample) 204 | 
self.assertTrue(M.in_manifold(X), msg=msg) 205 | layer.weight = X 206 | with P.cached(): 207 | # Compute the product if it is factorized 208 | # The sampled matrix should not have a gradient 209 | self.assertFalse(X.requires_grad) 210 | # Size does not change 211 | self.assertEqual(initial_size, layer.weight.size(), msg=msg) 212 | # The initialisation is equal to what we passed 213 | self.assertTrue(torch.allclose(layer.weight, X, atol=1e-5), msg=msg) 214 | 215 | # Take a couple of SGD steps 216 | optim = torch.optim.SGD(layer.parameters(), lr=1e-3) 217 | for i in range(3): 218 | with P.cached(): 219 | loss = layer(input_).mean() 220 | optim.zero_grad() 221 | loss.backward() 222 | optim.step() 223 | # The layer stays in the manifold while being optimised 224 | self.assertTrue(M.in_manifold(layer.weight), msg=f"i:{i}\n" + msg) 225 | 226 | with P.cached(): 227 | weight_old = layer.weight 228 | update_base(layer, "weight") 229 | # After changing the base, the weight stays the same 230 | self.assertTrue( 231 | torch.allclose(layer.weight, weight_old, atol=1e-6), msg=msg 232 | ) 233 | -------------------------------------------------------------------------------- /test/test_lowrank.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from geotorch.lowrank import LowRank 4 | from geotorch.fixedrank import FixedRank 5 | 6 | 7 | class TestLowRank(TestCase): 8 | def test_lowrank_errors(self): 9 | # rank always has to be <= min(n, k) 10 | for cls in [LowRank, FixedRank]: 11 | with self.assertRaises(ValueError): 12 | cls(size=(4, 3), rank=5) 13 | with self.assertRaises(ValueError): 14 | cls(size=(2, 3), rank=3) 15 | # Try to instantiate it in a vector rather than a matrix 16 | with self.assertRaises(ValueError): 17 | cls(size=(5,), rank=1) 18 | 19 | # On a non-callable 20 | with self.assertRaises(ValueError): 21 | FixedRank(size=(5, 3), rank=2, f=3) 22 | # On the wrong string 23 | with self.assertRaises(ValueError): 24 | FixedRank(size=(5, 3), rank=2, f="wrong") 25 | -------------------------------------------------------------------------------- /test/test_orthogonal.py: -------------------------------------------------------------------------------- 1 | # Tests for the Stiefel manifold, the Grassmannian, and SO(n) 2 | from unittest import TestCase 3 | 4 | from geotorch.so import SO 5 | from geotorch.stiefel import Stiefel 6 | from geotorch.grassmannian import Grassmannian 7 | from geotorch.exceptions import NonSquareError, VectorError 8 | 9 | 10 | class TestOrthogonal(TestCase): 11 | def test_constructor_stiefel(self): 12 | self._test_constructor(Stiefel) 13 | 14 | def test_constructors_grassmannian(self): 15 | self._test_constructor(Grassmannian) 16 | 17 | def _test_constructor(self, cls): 18 | with self.assertRaises(ValueError): 19 | cls(size=(3, 3), triv="wrong") 20 | 21 | with self.assertRaises(ValueError): 22 | SO(size=(3, 3), triv="wrong") 23 | 24 | # Try a custom trivialization (it should break in the forward) 25 | cls(size=(3, 3), triv=lambda: 3) 26 | 27 | # Try to instantiate it in a vector rather than a matrix 28 | with self.assertRaises(VectorError): 29 | cls(size=(7,)) 30 | 31 | with self.assertRaises(VectorError): 32 | SO(size=(7,)) 33 | 34 | # Try to instantiate it on a non-square matrix 35 | with self.assertRaises(NonSquareError): 36 | SO(size=(7, 3, 2)) 37 | -------------------------------------------------------------------------------- /test/test_positive_semidefinite.py: 
-------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from geotorch.pssdlowrank import PSSDLowRank 4 | from geotorch.pssdfixedrank import PSSDFixedRank 5 | from geotorch.pssd import PSSD 6 | from geotorch.psd import PSD 7 | 8 | 9 | class TestPSSDLowRank(TestCase): 10 | def test_positive_semidefinite_errors(self): 11 | for cls in [PSSDLowRank, PSSDFixedRank]: 12 | # rank always has to be 1 <= rank <= n 13 | with self.assertRaises(ValueError): 14 | cls(size=(4, 4), rank=5) 15 | with self.assertRaises(ValueError): 16 | cls(size=(3, 3), rank=0) 17 | # Instantiate it on a non-square matrix 18 | with self.assertRaises(ValueError): 19 | cls(size=(3, 6), rank=2) 20 | # Try to instantiate it in a vector rather than a matrix 21 | with self.assertRaises(ValueError): 22 | cls(size=(5,), rank=1) 23 | 24 | for cls in [PSSD, PSD]: 25 | # Try to instantiate it in a vector rather than a matrix 26 | with self.assertRaises(ValueError): 27 | cls(size=(5,)) 28 | # Or a non-square 29 | with self.assertRaises(ValueError): 30 | cls(size=(5, 3)) 31 | 32 | # Pass a non-callable object 33 | with self.assertRaises(ValueError): 34 | PSSDFixedRank(size=(5, 2), rank=1, f=3) 35 | # Or the wrong string 36 | with self.assertRaises(ValueError): 37 | PSSDFixedRank(size=(5, 3), rank=2, f="fail") 38 | # Same with PSD 39 | with self.assertRaises(ValueError): 40 | PSD(size=(5, 2), f=3) 41 | with self.assertRaises(ValueError): 42 | PSD(size=(5, 3), f="fail") 43 | -------------------------------------------------------------------------------- /test/test_product.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch 4 | from geotorch.product import ProductManifold 5 | from geotorch.so import SO 6 | 7 | 8 | class TestManifold(TestCase): 9 | def test_product_manifold(self): 10 | # Should not throw 11 | SO3SO3 = ProductManifold([SO((3, 3)), SO((3, 3))]) 12 | 13 | # A tuple should work as well 14 | SO3SO3 = ProductManifold((SO((3, 3)), SO((3, 3)))) 15 | 16 | # Forward should work 17 | X = (torch.rand(3, 3), torch.rand(3, 3)) 18 | Y1, Y2 = SO3SO3(X) 19 | -------------------------------------------------------------------------------- /test/test_skew.py: -------------------------------------------------------------------------------- 1 | # Tests for the Skew vector space 2 | from unittest import TestCase 3 | import itertools 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import geotorch.parametrize as P 9 | from geotorch.skew import Skew 10 | 11 | 12 | class TestSkew(TestCase): 13 | def test_backprop(self): 14 | r"""Test that we may instantiate the parametrizations and 15 | register them in modules of several sizes. 
Check that the 16 | results are skew-symmetric 17 | """ 18 | sizes = [1, 2, 3, 8] 19 | 20 | for n, lower in itertools.product(sizes, [True, False]): 21 | layer = nn.Linear(n, n) 22 | P.register_parametrization(layer, "weight", Skew(lower=lower)) 23 | 24 | input_ = torch.rand(5, n) 25 | optim = torch.optim.SGD(layer.parameters(), lr=1.0) 26 | 27 | # Assert that it stays in Skew(n) after some optimiser steps 28 | for _ in range(2): 29 | with P.cached(): 30 | self.assertTrue(Skew.in_manifold(layer.weight)) 31 | loss = layer(input_).sum() 32 | optim.zero_grad() 33 | loss.backward() 34 | optim.step() 35 | 36 | def test_non_square(self): 37 | # Non-square skew 38 | with self.assertRaises(ValueError): 39 | Skew()(torch.rand(3, 2)) 40 | 41 | with self.assertRaises(ValueError): 42 | Skew()(torch.rand(1, 3)) 43 | 44 | # Try to instantiate it in a vector rather than a matrix 45 | with self.assertRaises(ValueError): 46 | Skew()(torch.rand(4)) 47 | 48 | def test_repr(self): 49 | print(Skew()) 50 | -------------------------------------------------------------------------------- /test/test_sl.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from geotorch.sl import SL 4 | 5 | 6 | class TestSL(TestCase): 7 | def test_SL_errors(self): 8 | # Non-square 9 | with self.assertRaises(ValueError): 10 | SL(size=(4, 3)) 11 | # Try to instantiate it in a vector rather than a matrix 12 | with self.assertRaises(ValueError): 13 | SL(size=(5,)) 14 | -------------------------------------------------------------------------------- /test/test_sphere.py: -------------------------------------------------------------------------------- 1 | # Tests for the Sphere 2 | from unittest import TestCase 3 | 4 | from geotorch.sphere import Sphere, SphereEmbedded 5 | 6 | 7 | class TestSphere(TestCase): 8 | def test_construction(self): 9 | # Negative radius 10 | with self.assertRaises(ValueError): 11 | Sphere(size=(5,), radius=-1.0) 12 | with self.assertRaises(ValueError): 13 | SphereEmbedded(size=(4,), radius=-1.0) 14 | 15 | def test_repr(self): 16 | print(SphereEmbedded(size=(3,))) 17 | print(Sphere(size=(3,))) 18 | -------------------------------------------------------------------------------- /test/test_symmetric.py: -------------------------------------------------------------------------------- 1 | # Tests for the Symmetric vector space 2 | from unittest import TestCase 3 | import itertools 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import geotorch.parametrize as P 9 | from geotorch.symmetric import Symmetric, SymF 10 | 11 | 12 | class TestSymmetric(TestCase): 13 | def test_backprop(self): 14 | r"""Test that we may instantiate the parametrizations and 15 | register them in modules of several sizes. 
Check that the 16 | results are symmetric 17 | """ 18 | sizes = [1, 2, 3, 8] 19 | 20 | for n, lower in itertools.product(sizes, [True, False]): 21 | layer = nn.Linear(n, n) 22 | P.register_parametrization(layer, "weight", Symmetric(lower=lower)) 23 | 24 | input_ = torch.rand(5, n) 25 | optim = torch.optim.SGD(layer.parameters(), lr=1.0) 26 | 27 | # Assert that it stays in Sym(n) after some optimiser steps 28 | for _ in range(2): 29 | with P.cached(): 30 | self.assertTrue(Symmetric.in_manifold(layer.weight)) 31 | loss = layer(input_).sum() 32 | optim.zero_grad() 33 | loss.backward() 34 | optim.step() 35 | 36 | def test_construction(self): 37 | # Non-square sym 38 | with self.assertRaises(ValueError): 39 | Symmetric()(torch.rand(3, 2)) 40 | 41 | with self.assertRaises(ValueError): 42 | Symmetric()(torch.rand(1, 3)) 43 | 44 | # Try to instantiate it in a vector rather than a matrix 45 | with self.assertRaises(ValueError): 46 | Symmetric()(torch.rand(4)) 47 | 48 | # Instantiate it with a non-callable object 49 | with self.assertRaises(ValueError): 50 | SymF(size=(4, 4), rank=4, f=3.0) 51 | # Or with the wrong rank 52 | with self.assertRaises(ValueError): 53 | SymF(size=(4, 4), rank=5, f=lambda: None) 54 | with self.assertRaises(ValueError): 55 | SymF(size=(4, 4), rank=0, f=lambda: None) 56 | # Or on vectors 57 | with self.assertRaises(ValueError): 58 | SymF(size=(4,), rank=2, f=lambda: None) 59 | # Or on non-square matrices 60 | with self.assertRaises(ValueError): 61 | SymF(size=(4, 3), rank=2, f=lambda: None) 62 | 63 | def test_repr(self): 64 | print(Symmetric()) 65 | --------------------------------------------------------------------------------
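
The tests above encode the library's intended end-to-end workflow: register a manifold as a parametrization on a layer's weight, then train with a stock optimiser. For reference, here is a minimal sketch of that pattern outside the test harness, using only names that appear in this repository (``geotorch.orthogonal`` is the functional API exercised in test_integration.py); the layer width, batch size, learning rate, and step count are arbitrary placeholder values, not values prescribed by the library.

import torch
import torch.nn as nn

import geotorch

# Constrain the weight of a linear layer to be orthogonal.
layer = nn.Linear(20, 20)
geotorch.orthogonal(layer, "weight")

# The weight is now computed through the parametrization, so a plain
# optimiser keeps it on the manifold at every step.
optim = torch.optim.SGD(layer.parameters(), lr=1e-3)
for _ in range(3):
    loss = layer(torch.rand(5, 20)).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()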