├── .github
└── workflows
│ ├── check_code_quality.yml
│ ├── deploy.yml
│ └── run_unit_tests.yml
├── .gitignore
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
├── Makefile
├── README.md
├── make.bat
└── source
│ ├── api.rst
│ ├── conf.py
│ ├── index.rst
│ └── install.rst
├── hypll
├── __init__.py
├── manifolds
│ ├── __init__.py
│ ├── base
│ │ ├── __init__.py
│ │ └── manifold.py
│ ├── euclidean
│ │ ├── __init__.py
│ │ └── manifold.py
│ └── poincare_ball
│ │ ├── __init__.py
│ │ ├── curvature.py
│ │ ├── manifold.py
│ │ └── math
│ │ ├── __init__.py
│ │ ├── diffgeom.py
│ │ ├── linalg.py
│ │ └── stats.py
├── nn
│ ├── __init__.py
│ └── modules
│ │ ├── __init__.py
│ │ ├── activation.py
│ │ ├── batchnorm.py
│ │ ├── change_manifold.py
│ │ ├── container.py
│ │ ├── convolution.py
│ │ ├── embedding.py
│ │ ├── flatten.py
│ │ ├── fold.py
│ │ ├── linear.py
│ │ └── pooling.py
├── optim
│ ├── __init__.py
│ ├── adam.py
│ └── sgd.py
├── tensors
│ ├── __init__.py
│ ├── manifold_parameter.py
│ ├── manifold_tensor.py
│ └── tangent_tensor.py
└── utils
│ ├── __init__.py
│ ├── layer_utils.py
│ ├── math.py
│ └── tensor_utils.py
├── poetry.lock
├── pyproject.toml
├── tests
├── manifolds
│ ├── euclidean
│ │ └── test_euclidean.py
│ └── poincare_ball
│ │ ├── test_curvature.py
│ │ └── test_poincare_ball.py
├── nn
│ ├── test_change_manifold.py
│ ├── test_convolution.py
│ └── test_flatten.py
└── test_manifold_tensor.py
└── tutorials
├── README.txt
├── cifar10_resnet_tutorial.py
├── cifar10_tutorial.py
├── data
└── wordnet_mammals.json
├── hyperbolic_vit_tutorial.py
└── poincare_embeddings_tutorial.py
/.github/workflows/check_code_quality.yml:
--------------------------------------------------------------------------------
1 | name: Check Code Quality
2 |
3 | on:
4 | push:
5 | pull_request:
6 | workflow_dispatch:
7 |
8 | jobs:
9 | test:
10 | name: Check Code Quality
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout Code
14 | uses: actions/checkout@v3
15 | - name: Set up Python
16 | uses: actions/setup-python@v4
17 | with:
18 | python-version: "3.10"
19 | - uses: actions/cache@v3
20 | with:
21 | path: ${{ env.pythonLocation }}
22 | key: cache_v2_${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}
23 | - name: Install Poetry
24 | run: |
25 | pip install poetry
26 | echo "$HOME/.local/bin" >> $GITHUB_PATH
27 | - name: Install Package
28 | run: poetry install
29 | - name: Check Code Quality
30 | run: |
31 | poetry run black --check .
32 | poetry run isort --check .
33 |
--------------------------------------------------------------------------------
/.github/workflows/deploy.yml:
--------------------------------------------------------------------------------
1 | name: Deploy to PyPI
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | version-type:
7 | description: 'Version bump type (major, minor, patch)'
8 | required: true
9 | default: 'patch'
10 |
11 | jobs:
12 | deploy:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Check out repository
16 | uses: actions/checkout@v3
17 |
18 | - name: Set up Python
19 | uses: actions/setup-python@v4
20 | with:
21 | python-version: '3.10'
22 |
23 | - name: Install Poetry
24 | run: |
25 | pip install poetry
26 | echo "$HOME/.local/bin" >> $GITHUB_PATH
27 |
28 | - name: Bump version
29 | run: poetry version ${{ github.event.inputs.version-type }}
30 |
31 | - name: Push changes
32 | run: |
33 | git config --local user.email "action@github.com"
34 | git config --local user.name "GitHub Action"
35 | git commit -am "Increase version [skip ci]"
36 | git push
37 |
38 | - name: Publish to PyPI
39 | run: poetry publish --build --username ${{ secrets.PYPI_USERNAME }} --password ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.github/workflows/run_unit_tests.yml:
--------------------------------------------------------------------------------
1 | name: Run Unit Tests
2 |
3 | on:
4 | push:
5 | pull_request:
6 | workflow_dispatch:
7 |
8 | jobs:
9 | test:
10 | name: Run Unit Tests
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout Code
14 | uses: actions/checkout@v3
15 | - name: Set up Python
16 | uses: actions/setup-python@v4
17 | with:
18 | python-version: "3.10"
19 | - uses: actions/cache@v3
20 | with:
21 | path: ${{ env.pythonLocation }}
22 | key: cache_v2_${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}
23 | - name: Install Poetry
24 | run: |
25 | pip install poetry
26 | echo "$HOME/.local/bin" >> $GITHUB_PATH
27 | - name: Install Package
28 | run: poetry install
29 | - name: Run Unit Tests
30 | run: |
31 | poetry run pytest tests/
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .env/
3 | *.egg-info/
4 | .pytest_cache
5 | /data
6 | /tests/data
7 | .vscode
8 | build
9 | /dist
10 | data
11 | docs/build
12 | docs/source/_autosummary
13 | docs/source/tutorials
14 | tests/data
15 |
16 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-22.04
5 | tools:
6 | python: "3.11"
7 |
8 | sphinx:
9 | configuration: docs/source/conf.py
10 |
11 | python:
12 | install:
13 | - method: pip
14 | path: .
15 | extra_requirements:
16 | - docs
17 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to HypLL
2 |
3 | Thank you for considering contributing to HypLL! We are always looking for new contributors to help us implement the constantly improving hyperbolic learning methodology and to maintain what is already here, so we are happy to have you!
4 |
5 | We are looking for all sorts of contributions such as
6 |
7 | - Bugfixes
8 | - Documentation additions and improvements
9 | - New features
10 | - New tutorials
11 |
12 |
13 | # Getting started
14 | ### Git workflow
15 | We want any new changes to HypLL to be linked to GitHub issues. If you have something new in mind that is not yet mentioned in an issue, please create an issue detailing the intended change first. Once you have an issue in mind that you would like to contribute to, [fork the repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo).
16 |
17 | After you have finished your contribution, [open a pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) on GitHub to have your contribution merged into our `main` branch. Note that your pull request will automatically be checked to see if it matches some of our standards. To avoid running into errors during your pull request, consider performing these tests locally first. A description on how to run the tests is given down below.
18 |
19 | ### Local setup
20 | In order to get started locally with HypLL development, first clone the repository and make sure to install the repository with the correct optional dependencies from the `pyproject.toml` file. For example, if you only need the base development dependencies, navigate to the root directory (so the one containing this file) of your local version of this repository and install using:
21 | ```
22 | pip install -e .[dev]
23 | ```
24 | We recommend using a virtual environment tool such as `conda` or `venv`.
25 |
26 | For further instructions, please read the section below that corresponds to the type of contribution that you intend to make.
27 |
28 |
29 | # Contributing to development
30 | **Optional dependencies:** If you are making development contributions, then you will at least need the `dev` optional dependencies. If your contribution warrants additional changes to the documentation, then the `docs` optional dependencies will likely also be of use for testing the documentation build process.
31 |
32 | **Formatting:** For formatting we use [black](https://black.readthedocs.io) and [isort](https://pycqa.github.io/isort/). If you are unfamiliar with these, feel free to check out their documentations. However, it should not be a large problem if you are unfamiliar with their style guides, since they automatically format your code for you (in most cases). `black` is a style formatter that ensures uniformity of the coding style throughout the project, which helps with readability. To use `black`, simply run (inside the root directory)
33 | ```
34 | black .
35 | ```
36 | `isort` is a utility that automatically sorts and separates imports to also improve readability. To use `isort`, simply run (inside the root directory)
37 | ```
38 | isort .
39 | ```
40 |
41 | **Testing:** When your contribution is a simple bugfix it can be sufficient to use the existing tests. However, in most cases you will have to add additional tests to check your new feature or to ensure that whatever bug you are fixing no longer occurs. These tests can be added to the `/tests` directory. We use [pytest](https://docs.pytest.org) for testing, so if you are unfamiliar with this, please check their documentation or use the other tests as an example. If you think your contribution is ready, you can test it by running (inside the root directory)
42 | ```
43 | pytest
44 | ```
45 | If you made any changes to the documentation then you will also need to test the build process of the docs. You can read how to do this down below underneath the "Contributing to documentation" header.
46 |
47 |
48 | # Contributing to documentation
49 | **Optional dependencies:** When making documentation contributions, you will most likely need the `docs` optional dependencies to test the build process of the docs and the `dev` optional dependencies for formatting.
50 |
51 | **Contributing:** Our documentation is built using the documentation generator [sphinx](https://www.sphinx-doc.org). Most of the documentation is generated automatically from the docstrings, so if you want to contribute to the API documentation, a simple change to the docstring of the relevant part of the code should suffice. For more complicated changes, please take a look at the `sphinx` documentation.
52 |
53 | **Formatting:** For formatting new documentation, please use the existing documentation as examples. Once you are happy with your contribution, use `black` as described under the "Contributing to development" header to ensure that the code satisfies our formatting style.
54 |
55 | **Testing:** To test the build process of the documentation, please follow the instructions within the `README` inside the `docs` directory.
56 |
57 |
58 | # Contributing to new tutorials
59 | **Optional dependencies:** When contributing to the tutorials, you will need the `dev` optional dependencies.
60 |
61 | **Contributing:** We have two types of tutorials:
62 |
63 | - Basic tutorials for learning about HypLL
64 | - Tutorials showcasing how to implement peer-reviewed papers using HypLL
65 |
66 | If your intended contribution falls under the first category, first make sure that there is a need for your specific tutorial. In other words, make sure that there is an issue asking for your tutorial and make sure that its inclusion is already agreed upon. If it falls under the second category, make sure that the paper that you are implementing is a peer-reviewed publication and, again, that there is an issue asking for this implementation.
67 |
68 | **Formatting:** The formatting for tutorials is not very strict, but please use the existing tutorials as a guideline and ensure that the tutorials are self-contained.
69 |
70 | **Testing:** Once you are ready with your tutorial, please test it by ensuring that it runs to completion.
71 |
72 |
73 | # Code review process
74 | Pull requests have to be reviewed by at least one of the HypLL collaborators.
75 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Hyperbolic Learning Library developers
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hyperbolic Learning Library
2 |
3 | [![Documentation Status](https://readthedocs.org/projects/hyperbolic-learning-library/badge/?version=latest)](https://hyperbolic-learning-library.readthedocs.io/en/latest/?badge=latest)
4 | ![Unit Tests](https://github.com/maxvanspengler/hyperbolic_learning_library/actions/workflows/run_unit_tests.yml/badge.svg)
5 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
6 | [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://github.com/PyCQA/isort)
7 |
8 | An extension of the PyTorch library containing various tools for performing deep learning in hyperbolic space.
9 |
10 | Contents:
11 | * [Documentation](#documentation)
12 | * [Installation](#installation)
13 | * [BibTeX](#bibtex)
14 |
15 |
16 | ## Documentation
17 | Visit our [documentation](https://hyperbolic-learning-library.readthedocs.io/en/latest/index.html) for tutorials and more.
18 |
19 |
20 | ## Installation
21 |
22 | The Hyperbolic Learning Library was written for Python 3.10+ and PyTorch 1.11+.
23 |
24 | It's recommended to have a
25 | working PyTorch installation before setting up HypLL:
26 |
27 | * [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
28 |
29 | Start by setting up a Python [virtual environment](https://docs.python.org/3/library/venv.html):
30 |
31 | ```
32 | python -m venv .env
33 | ```
34 |
35 | Activate the virtual environment on Linux and MacOs:
36 | ```
37 | source .env/bin/activate
38 | ```
39 | Or on Windows:
40 | ```
41 | .env/Scripts/activate
42 | ```
43 |
44 | Finally, install HypLL from PyPI.
45 |
46 | ```
47 | pip install hypll
48 | ```
49 |
50 | ## BibTeX
51 | If you would like to cite this project, please use the following bibtex entry
52 | ```
53 | @article{spengler2023hypll,
54 | title={HypLL: The Hyperbolic Learning Library},
55 | author={van Spengler, Max and Wirth, Philipp and Mettes, Pascal},
56 | journal={arXiv preprint arXiv:2306.06154},
57 | year={2023}
58 | }
59 | ```
60 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Build the Docs
2 | `sphinx` provides a Makefile, so to build the `html` documentation, simply type:
3 | ```
4 | make html
5 | ```
6 |
7 | Or, alternatively
8 | ```
9 | SPHINXBUILD="python3 -m sphinx.cmd.build" make html
10 | ```
11 |
12 | You can browse the docs by opening the generated html files in the `docs/build` directory
13 | with a browser of your choice.
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/api.rst:
--------------------------------------------------------------------------------
1 |
2 | .. autosummary::
3 | :toctree: _autosummary
4 | :nosignatures:
5 | :recursive:
6 |
7 | hypll
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | import os
10 | import sys
11 |
12 | sys.path.insert(0, os.path.abspath("../..")) # Source code dir relative to this file
13 |
14 | project = "hypll"
15 | copyright = '2023, ""'
16 | author = '""'
17 |
18 | # -- General configuration ---------------------------------------------------
19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
20 |
21 | extensions = [
22 | "sphinx_copybutton",
23 | "sphinx.ext.autodoc", # Core library for html generation from docstrings
24 | "sphinx.ext.autosummary",
25 | "sphinx.ext.napoleon",
26 | "sphinx_gallery.gen_gallery",
27 | "sphinx_tabs.tabs",
28 | ]
29 |
30 | templates_path = ["_templates"]
31 | exclude_patterns = []
32 |
33 | # Autodoc options:
34 | autodoc_class_signature = "separated"
35 | autodoc_default_options = {
36 | "members": True,
37 | "member-order": "bysource",
38 | "special-members": "__init__",
39 | "undoc-members": True,
40 | "exclude-members": "__weakref__",
41 | }
42 | autodoc_default_flags = [
43 | "members",
44 | "undoc-members",
45 | "special-members",
46 | "show-inheritance",
47 | ]
48 |
49 | # Napoleon options:
50 | napoleon_google_docstring = True
51 | napoleon_numpy_docstring = False
52 | napoleon_include_init_with_doc = False
53 | napoleon_include_private_with_doc = False
54 | napoleon_include_special_with_doc = False
55 | napoleon_use_admonition_for_examples = False
56 | napoleon_use_admonition_for_notes = False
57 | napoleon_use_admonition_for_references = False
58 | napoleon_use_ivar = False
59 | napoleon_use_param = False
60 | napoleon_use_rtype = False
61 | napoleon_type_aliases = None
62 |
63 | # Sphinx gallery config:
64 | sphinx_gallery_conf = {
65 | "examples_dirs": "../../tutorials",
66 | "gallery_dirs": "tutorials/",
67 | "filename_pattern": "",
68 | # TODO(Philipp, 06/23): Figure out how we can build and host tutorials on RTD.
69 | "plot_gallery": "False",
70 | }
71 |
72 | # -- Options for HTML output -------------------------------------------------
73 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
74 |
75 | html_theme = "alabaster"
76 | html_static_path = ["_static"]
77 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 |
2 | Hyperbolic Learning Library
3 | ===========================
4 |
5 | The Hyperbolic Learning Library (HypLL) was designed to bring
6 | the progress on hyperbolic deep learning together. HypLL is built on
7 | top of PyTorch, with an emphasis on ease-of-use.
8 |
9 | All code is open-source. Visit us on `GitHub <https://github.com/maxvanspengler/hyperbolic_learning_library>`_
10 | to see open issues and share your feedback!
11 |
12 | .. toctree::
13 | :caption: Getting Started
14 | :maxdepth: 1
15 |
16 | install.rst
17 |
18 |
19 | .. toctree::
20 | :caption: Tutorials
21 | :name: tutorials
22 | :maxdepth: 1
23 |
24 | tutorials/cifar10_tutorial.rst
25 | tutorials/cifar10_resnet_tutorial.rst
26 | tutorials/poincare_embeddings_tutorial.rst
27 |
28 |
29 | .. toctree::
30 | :caption: API
31 | :maxdepth: 2
32 |
33 | api.rst
34 |
35 |
36 |
37 | Indices and tables
38 | ------------------
39 |
40 | * :ref:`genindex`
41 | * :ref:`modindex`
42 | * :ref:`search`
43 |
--------------------------------------------------------------------------------
/docs/source/install.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | HypLL was written for Python 3.10+ and PyTorch 1.11+.
5 |
6 | It's recommended to have a
7 | working PyTorch installation before setting up HypLL:
8 |
9 | * `PyTorch <https://pytorch.org/get-started/locally/>`_ installation instructions.
10 |
11 |
12 | Install with pip
13 | ----------------
14 |
15 | Start by setting up a Python `virtual environment <https://docs.python.org/3/library/venv.html>`_:
16 |
17 | .. code::
18 |
19 | python -m venv .env
20 |
21 | Activate the virtual environment.
22 |
23 | .. tabs::
24 |
25 | .. tab:: Linux and MacOs
26 |
27 | .. code::
28 |
29 | source .env/bin/activate
30 |
31 | .. tab:: Windows
32 |
33 | .. code::
34 |
35 | .env/Scripts/activate
36 |
37 |
38 | Install HypLL from PyPI.
39 |
40 | .. code::
41 |
42 | pip install hypll
43 |
44 |
45 | Congratulations, you are ready to do machine learning in hyperbolic space.
46 | Check out our :ref:`tutorials` next to get started!
47 |
48 |
--------------------------------------------------------------------------------
/hypll/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxvanspengler/hyperbolic_learning_library/62ea7882848d878459eba4100aeb6b8ad64cc38f/hypll/__init__.py
--------------------------------------------------------------------------------
/hypll/manifolds/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Manifold
2 |
--------------------------------------------------------------------------------
/hypll/manifolds/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .manifold import Manifold
2 |
--------------------------------------------------------------------------------
/hypll/manifolds/base/manifold.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import TYPE_CHECKING, List, Optional, Tuple, Union
5 |
6 | from torch import Tensor
7 | from torch.nn import Module, Parameter
8 | from torch.nn.common_types import _size_2_t
9 |
10 | # TODO: find a less hacky solution for this
11 | if TYPE_CHECKING:
12 | from hypll.tensors import ManifoldParameter, ManifoldTensor, TangentTensor
13 |
14 |
15 | class Manifold(Module, ABC):
16 | def __init__(self) -> None:
17 | super(Manifold, self).__init__()
18 |
19 | @abstractmethod
20 | def project(self, x: ManifoldTensor, eps: float = -1.0) -> ManifoldTensor:
21 | raise NotImplementedError
22 |
23 | @abstractmethod
24 | def expmap(self, v: TangentTensor) -> ManifoldTensor:
25 | raise NotImplementedError
26 |
27 | @abstractmethod
28 | def logmap(self, x: Optional[ManifoldTensor], y: ManifoldTensor) -> TangentTensor:
29 | raise NotImplementedError
30 |
31 | @abstractmethod
32 | def transp(self, v: TangentTensor, y: ManifoldTensor) -> TangentTensor:
33 | raise NotImplementedError
34 |
35 | @abstractmethod
36 | def dist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
37 | raise NotImplementedError
38 |
39 | @abstractmethod
40 | def cdist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
41 | raise NotImplementedError
42 |
43 | @abstractmethod
44 | def euc_to_tangent(self, x: ManifoldTensor, u: ManifoldTensor) -> TangentTensor:
45 | raise NotImplementedError
46 |
47 | @abstractmethod
48 | def hyperplane_dists(self, x: ManifoldTensor, z: ManifoldTensor, r: Optional[Tensor]) -> Tensor:
49 | raise NotImplementedError
50 |
51 | @abstractmethod
52 | def fully_connected(
53 | self, x: ManifoldTensor, z: ManifoldTensor, bias: Optional[Tensor]
54 | ) -> ManifoldTensor:
55 | raise NotImplementedError
56 |
57 | @abstractmethod
58 | def frechet_mean(
59 | self,
60 | x: ManifoldTensor,
61 | batch_dim: Union[int, list[int]] = 0,
62 | keepdim: bool = False,
63 | ) -> ManifoldTensor:
64 | raise NotImplementedError
65 |
66 | @abstractmethod
67 | def midpoint(
68 | self,
69 | x: ManifoldTensor,
70 | batch_dim: Union[int, list[int]] = 0,
71 | w: Optional[Tensor] = None,
72 | keepdim: bool = False,
73 | ) -> ManifoldTensor:
74 | raise NotImplementedError
75 |
76 | @abstractmethod
77 | def frechet_variance(
78 | self,
79 | x: ManifoldTensor,
80 | mu: Optional[ManifoldTensor] = None,
81 | batch_dim: Union[int, list[int]] = -1,
82 | keepdim: bool = False,
83 | ) -> Tensor:
84 | raise NotImplementedError
85 |
86 | @abstractmethod
87 | def construct_dl_parameters(
88 | self, in_features: int, out_features: int, bias: bool = True
89 | ) -> Union[ManifoldParameter, tuple[ManifoldParameter, Parameter]]:
90 | # TODO: make an annotation object for the return type of this method
91 | raise NotImplementedError
92 |
93 | @abstractmethod
94 | def reset_parameters(self, weight: ManifoldParameter, bias: Parameter) -> None:
95 | raise NotImplementedError
96 |
97 | @abstractmethod
98 | def flatten(self, x: ManifoldTensor, start_dim: int = 1, end_dim: int = -1) -> ManifoldTensor:
99 | raise NotImplementedError
100 |
101 | @abstractmethod
102 | def unfold(
103 | self,
104 | input: ManifoldTensor,
105 | kernel_size: _size_2_t,
106 | dilation: _size_2_t = 1,
107 | padding: _size_2_t = 0,
108 | stride: _size_2_t = 1,
109 | ) -> ManifoldTensor:
110 | raise NotImplementedError
111 |
112 | @abstractmethod
113 | def cat(
114 | self,
115 | manifold_tensors: Union[Tuple[ManifoldTensor, ...], List[ManifoldTensor]],
116 | dim: int = 0,
117 | ) -> ManifoldTensor:
118 | raise NotImplementedError
119 |
--------------------------------------------------------------------------------
/hypll/manifolds/euclidean/__init__.py:
--------------------------------------------------------------------------------
1 | from .manifold import Euclidean
2 |
--------------------------------------------------------------------------------
/hypll/manifolds/euclidean/manifold.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | from typing import List, Optional, Tuple, Union
3 |
4 | import torch
5 | from torch import Tensor, broadcast_shapes, empty, matmul, var
6 | from torch.nn import Parameter
7 | from torch.nn.common_types import _size_2_t
8 | from torch.nn.functional import unfold
9 | from torch.nn.init import _calculate_fan_in_and_fan_out, kaiming_uniform_, uniform_
10 |
11 | from hypll.manifolds.base import Manifold
12 | from hypll.tensors import ManifoldParameter, ManifoldTensor, TangentTensor
13 | from hypll.utils.tensor_utils import (
14 | check_dims_with_broadcasting,
15 | check_if_man_dims_match,
16 | check_tangent_tensor_positions,
17 | )
18 |
19 |
20 | class Euclidean(Manifold):
21 | def __init__(self):
22 | super(Euclidean, self).__init__()
23 |
24 | def project(self, x: ManifoldTensor, eps: float = -1.0) -> ManifoldTensor:
25 | return x
26 |
27 | def expmap(self, v: TangentTensor) -> ManifoldTensor:
28 | if v.manifold_points is None:
29 | new_tensor = v.tensor
30 | else:
31 | new_tensor = v.manifold_points.tensor + v.tensor
32 | return ManifoldTensor(data=new_tensor, manifold=self, man_dim=v.broadcasted_man_dim)
33 |
34 | def logmap(self, x: Optional[ManifoldTensor], y: ManifoldTensor) -> TangentTensor:
35 | if x is None:
36 | dim = y.man_dim
37 | new_tensor = y.tensor
38 | else:
39 | dim = check_dims_with_broadcasting(x, y)
40 | new_tensor = y.tensor - x.tensor
41 | return TangentTensor(data=new_tensor, manifold_points=x, manifold=self, man_dim=dim)
42 |
43 | def transp(self, v: TangentTensor, y: ManifoldTensor) -> TangentTensor:
44 | dim = check_dims_with_broadcasting(v, y)
45 | output_shape = broadcast_shapes(v.size(), y.size())
46 | new_tensor = v.tensor.broadcast_to(output_shape)
47 | return TangentTensor(data=new_tensor, manifold_points=y, manifold=self, man_dim=dim)
48 |
49 | def dist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
50 | dim = check_dims_with_broadcasting(x, y)
51 | return (y.tensor - x.tensor).norm(dim=dim)
52 |
53 | def inner(
54 | self, u: TangentTensor, v: TangentTensor, keepdim: bool = False, safe_mode: bool = True
55 | ) -> Tensor:
56 | dim = check_dims_with_broadcasting(u, v)
57 | if safe_mode:
58 | check_tangent_tensor_positions(u, v)
59 |
60 | return (u.tensor * v.tensor).sum(dim=dim, keepdim=keepdim)
61 |
62 | def euc_to_tangent(self, x: ManifoldTensor, u: ManifoldTensor) -> TangentTensor:
63 | dim = check_dims_with_broadcasting(x, u)
64 | return TangentTensor(
65 | data=u.tensor,
66 | manifold_points=x,
67 | manifold=self,
68 | man_dim=dim,
69 | )
70 |
71 | def hyperplane_dists(self, x: ManifoldTensor, z: ManifoldTensor, r: Optional[Tensor]) -> Tensor:
72 | if x.man_dim != 1 or z.man_dim != 0:
73 | raise ValueError(
74 | f"Expected the manifold dimension of the inputs to be 1 and the manifold "
75 | f"dimension of the hyperplane orientations to be 0, but got {x.man_dim} and "
76 | f"{z.man_dim}, respectively"
77 | )
78 | if r is None:
79 | return matmul(x.tensor, z.tensor)
80 | else:
81 | return matmul(x.tensor, z.tensor) + r
82 |
83 | def fully_connected(
84 | self, x: ManifoldTensor, z: ManifoldTensor, bias: Optional[Tensor]
85 | ) -> ManifoldTensor:
86 | if z.man_dim != 0:
87 | raise ValueError(
88 | f"Expected the manifold dimension of the hyperplane orientations to be 0, but got "
89 | f"{z.man_dim} instead"
90 | )
91 |
92 | dim_shifted_x_tensor = x.tensor.movedim(source=x.man_dim, destination=-1)
93 | dim_shifted_new_tensor = matmul(dim_shifted_x_tensor, z.tensor)
94 | if bias is not None:
95 | dim_shifted_new_tensor = dim_shifted_new_tensor + bias
96 | new_tensor = dim_shifted_new_tensor.movedim(source=-1, destination=x.man_dim)
97 |
98 | return ManifoldTensor(data=new_tensor, manifold=self, man_dim=x.man_dim)
99 |
    def frechet_mean(
        self,
        x: ManifoldTensor,
        batch_dim: Union[int, list[int]] = 0,
        keepdim: bool = False,
    ) -> ManifoldTensor:
        """Frechet mean of x over batch_dim; on the Euclidean manifold this is
        the ordinary arithmetic mean.

        Raises:
            ValueError: If the manifold dim is among the aggregated dims.
        """
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]

        if x.man_dim in batch_dim:
            raise ValueError(
                f"Tried to aggregate over dimensions {batch_dim}, but input has manifold "
                f"dimension {x.man_dim} and cannot aggregate over this dimension"
            )

        # Output manifold dimension is shifted left for each batch dim that disappears
        man_dim_shift = sum(bd < x.man_dim for bd in batch_dim)
        new_man_dim = x.man_dim - man_dim_shift if not keepdim else x.man_dim

        mean_tensor = x.tensor.mean(dim=batch_dim, keepdim=keepdim)

        return ManifoldTensor(data=mean_tensor, manifold=self, man_dim=new_man_dim)
122 |
    def midpoint(
        self,
        x: ManifoldTensor,
        batch_dim: Union[int, list[int]] = 0,
        w: Optional[Tensor] = None,
        keepdim: bool = False,
    ) -> ManifoldTensor:
        """Midpoint of x over batch_dim; coincides with the Frechet mean in
        Euclidean space.

        NOTE(review): the weights w are currently ignored here (unweighted
        mean) — confirm whether weighted midpoints should be supported.
        """
        return self.frechet_mean(x=x, batch_dim=batch_dim, keepdim=keepdim)
131 |
132 | def frechet_variance(
133 | self,
134 | x: ManifoldTensor,
135 | mu: Optional[ManifoldTensor] = None,
136 | batch_dim: Union[int, list[int]] = -1,
137 | keepdim: bool = False,
138 | ) -> Tensor:
139 | if isinstance(batch_dim, int):
140 | batch_dim = [batch_dim]
141 |
142 | if mu is None:
143 | return var(x.tensor, dim=batch_dim, keepdim=keepdim)
144 | else:
145 | if x.dim() != mu.dim():
146 | for bd in sorted(batch_dim):
147 | mu.man_dim += 1 if bd <= mu.man_dim else 0
148 | mu.tensor = mu.tensor.unsqueeze(bd)
149 | if mu.man_dim != x.man_dim:
150 | raise ValueError("Input tensor and mean do not have matching manifold dimensions")
151 | n = 1
152 | for bd in batch_dim:
153 | n *= x.size(dim=bd)
154 | return (x.tensor - mu.tensor).pow(2).sum(dim=batch_dim, keepdim=keepdim) / (n - 1)
155 |
156 | def construct_dl_parameters(
157 | self, in_features: int, out_features: int, bias: bool = True
158 | ) -> tuple[ManifoldParameter, Optional[Parameter]]:
159 | weight = ManifoldParameter(
160 | data=empty(in_features, out_features),
161 | manifold=self,
162 | man_dim=0,
163 | )
164 |
165 | if bias:
166 | b = Parameter(data=empty(out_features))
167 | else:
168 | b = None
169 |
170 | return weight, b
171 |
    def reset_parameters(self, weight: ManifoldParameter, bias: Parameter) -> None:
        """Initializes linear-layer parameters in the standard torch.nn.Linear
        fashion: Kaiming-uniform weight and a fan-in-scaled uniform bias."""
        kaiming_uniform_(weight.tensor, a=sqrt(5))
        if bias is not None:
            fan_in, _ = _calculate_fan_in_and_fan_out(weight.tensor)
            bound = 1 / sqrt(fan_in) if fan_in > 0 else 0
            uniform_(bias, -bound, bound)
178 |
    def unfold(
        self,
        input: ManifoldTensor,
        kernel_size: _size_2_t,
        dilation: _size_2_t = 1,
        padding: _size_2_t = 0,
        stride: _size_2_t = 1,
    ) -> ManifoldTensor:
        """Extracts sliding blocks via torch.nn.functional.unfold.

        The output has unfold's (N, C * prod(kernel_size), L) layout, so its
        manifold dim is 1.
        """
        new_tensor = unfold(
            input=input.tensor,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )
        return ManifoldTensor(data=new_tensor, manifold=input.manifold, man_dim=1)
195 |
    def flatten(self, x: ManifoldTensor, start_dim: int = 1, end_dim: int = -1) -> ManifoldTensor:
        """Flattens a manifold tensor by reshaping it. If start_dim or end_dim are passed,
        only dimensions starting with start_dim and ending with end_dim are flattend.

        Updates the manifold dimension if necessary.

        """
        # Normalize negative dims.
        start_dim = x.dim() + start_dim if start_dim < 0 else start_dim
        end_dim = x.dim() + end_dim if end_dim < 0 else end_dim

        # Get the range of dimensions to flatten.
        dimensions_to_flatten = x.shape[start_dim + 1 : end_dim + 1]

        # Get the new manifold dimension.
        if start_dim <= x.man_dim and end_dim >= x.man_dim:
            # The manifold dim is merged into the flattened dim.
            man_dim = start_dim
        elif end_dim <= x.man_dim:
            # Flattened dims lie before the manifold dim: shift it left.
            man_dim = x.man_dim - len(dimensions_to_flatten)
        else:
            # Flattened dims lie after the manifold dim: unchanged.
            man_dim = x.man_dim

        # Flatten the tensor and return the new instance.
        flattened = torch.flatten(
            input=x.tensor,
            start_dim=start_dim,
            end_dim=end_dim,
        )
        return ManifoldTensor(data=flattened, manifold=x.manifold, man_dim=man_dim)
224 |
225 | def cdist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
226 | return torch.cdist(x.tensor, y.tensor)
227 |
    def cat(
        self,
        manifold_tensors: Union[Tuple[ManifoldTensor, ...], List[ManifoldTensor]],
        dim: int = 0,
    ) -> ManifoldTensor:
        """Concatenates manifold tensors along dim, analogous to torch.cat.

        All inputs must share the same manifold dim; the output keeps it.
        """
        check_if_man_dims_match(manifold_tensors)
        cat = torch.cat([t.tensor for t in manifold_tensors], dim=dim)
        man_dim = manifold_tensors[0].man_dim
        return ManifoldTensor(data=cat, manifold=self, man_dim=man_dim)
237 |
--------------------------------------------------------------------------------
/hypll/manifolds/poincare_ball/__init__.py:
--------------------------------------------------------------------------------
1 | from .curvature import Curvature
2 | from .manifold import PoincareBall
3 |
--------------------------------------------------------------------------------
/hypll/manifolds/poincare_ball/curvature.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | from torch import Tensor, as_tensor
5 | from torch.nn import Module
6 | from torch.nn.functional import softplus
7 | from torch.nn.parameter import Parameter
8 |
9 |
class Curvature(Module):
    """Learnable curvature of a manifold.

    Attributes:
        value:
            Learnable parameter indicating curvature of the manifold. The actual
            curvature is calculated as constraining_strategy(value).
        constraining_strategy:
            Function applied to the curvature value in order to constrain the
            curvature of the manifold. By default uses softplus to guarantee
            positive curvature.
        requires_grad:
            If the curvature requires gradient. False by default.

    """

    def __init__(
        self,
        value: float = 1.0,
        constraining_strategy: Callable[[Tensor], Tensor] = softplus,
        requires_grad: bool = False,
    ):
        super().__init__()
        # Store the raw (unconstrained) value; the constraint is applied
        # lazily in forward().
        self.value = Parameter(data=as_tensor(value), requires_grad=requires_grad)
        self.constraining_strategy = constraining_strategy

    def forward(self) -> Tensor:
        """Returns curvature calculated as constraining_strategy(value)."""
        return self.constraining_strategy(self.value)
42 |
--------------------------------------------------------------------------------
/hypll/manifolds/poincare_ball/manifold.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from typing import List, Optional, Tuple, Union
3 |
4 | import torch
5 | from torch import Tensor, empty, eye, no_grad
6 | from torch.nn import Parameter
7 | from torch.nn.common_types import _size_2_t
8 | from torch.nn.functional import softplus, unfold
9 | from torch.nn.init import normal_, zeros_
10 |
11 | from hypll.manifolds.base import Manifold
12 | from hypll.manifolds.euclidean import Euclidean
13 | from hypll.manifolds.poincare_ball.curvature import Curvature
14 | from hypll.tensors import ManifoldParameter, ManifoldTensor, TangentTensor
15 | from hypll.utils.math import beta_func
16 | from hypll.utils.tensor_utils import (
17 | check_dims_with_broadcasting,
18 | check_if_man_dims_match,
19 | check_tangent_tensor_positions,
20 | )
21 |
22 | from .math.diffgeom import (
23 | cdist,
24 | dist,
25 | euc_to_tangent,
26 | expmap,
27 | expmap0,
28 | gyration,
29 | inner,
30 | logmap,
31 | logmap0,
32 | mobius_add,
33 | project,
34 | transp,
35 | )
36 | from .math.linalg import poincare_fully_connected, poincare_hyperplane_dists
37 | from .math.stats import frechet_mean, frechet_variance, midpoint
38 |
39 |
40 | class PoincareBall(Manifold):
41 | """Class representing the Poincare ball model of hyperbolic space.
42 |
43 | Implementation based on the geoopt implementation, but changed to use
44 | hyperbolic torch functions.
45 |
46 | Attributes:
47 | c:
48 | Curvature of the manifold.
49 |
50 | """
51 |
    def __init__(self, c: Curvature):
        """Initializes an instance of PoincareBall manifold.

        Examples:
            >>> from hypll.manifolds.poincare_ball import PoincareBall, Curvature
            >>> curvature = Curvature(value=1.0)
            >>> manifold = PoincareBall(c=curvature)

        """
        super(PoincareBall, self).__init__()
        # Curvature module; self.c() yields the (constrained) curvature tensor.
        self.c = c
63 |
    def mobius_add(self, x: ManifoldTensor, y: ManifoldTensor) -> ManifoldTensor:
        """Mobius addition of x and y on the Poincare ball."""
        dim = check_dims_with_broadcasting(x, y)
        new_tensor = mobius_add(x=x.tensor, y=y.tensor, c=self.c(), dim=dim)
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=dim)
68 |
    def project(self, x: ManifoldTensor, eps: float = -1.0) -> ManifoldTensor:
        """Projects x back inside the ball; eps < 0 selects a dtype-based margin."""
        new_tensor = project(x=x.tensor, c=self.c(), dim=x.man_dim, eps=eps)
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=x.man_dim)
72 |
    def expmap(self, v: TangentTensor) -> ManifoldTensor:
        """Exponential map of v; mapped at the origin when v has no base points."""
        dim = v.broadcasted_man_dim
        if v.manifold_points is None:
            new_tensor = expmap0(v=v.tensor, c=self.c(), dim=dim)
        else:
            new_tensor = expmap(x=v.manifold_points.tensor, v=v.tensor, c=self.c(), dim=dim)
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=dim)
80 |
    def logmap(self, x: Optional[ManifoldTensor], y: ManifoldTensor) -> TangentTensor:
        """Logarithmic map of y at x; taken at the origin when x is None."""
        if x is None:
            dim = y.man_dim
            new_tensor = logmap0(y=y.tensor, c=self.c(), dim=y.man_dim)
        else:
            dim = check_dims_with_broadcasting(x, y)
            new_tensor = logmap(x=x.tensor, y=y.tensor, c=self.c(), dim=dim)
        return TangentTensor(data=new_tensor, manifold_points=x, manifold=self, man_dim=dim)
89 |
    def gyration(self, u: ManifoldTensor, v: ManifoldTensor, w: ManifoldTensor) -> ManifoldTensor:
        """Gyration gyr[u, v]w of the Mobius gyrogroup."""
        dim = check_dims_with_broadcasting(u, v, w)
        new_tensor = gyration(u=u.tensor, v=v.tensor, w=w.tensor, c=self.c(), dim=dim)
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=dim)
94 |
    def transp(self, v: TangentTensor, y: ManifoldTensor) -> TangentTensor:
        """Parallel transport of v from its base points to y."""
        dim = check_dims_with_broadcasting(v, y)
        tangent_vectors = transp(
            x=v.manifold_points.tensor, y=y.tensor, v=v.tensor, c=self.c(), dim=dim
        )
        return TangentTensor(
            data=tangent_vectors,
            manifold_points=y,
            manifold=self,
            man_dim=dim,
        )
106 |
    def dist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
        """Geodesic distance between x and y on the Poincare ball."""
        dim = check_dims_with_broadcasting(x, y)
        return dist(x=x.tensor, y=y.tensor, c=self.c(), dim=dim)
110 |
    def inner(
        self, u: TangentTensor, v: TangentTensor, keepdim: bool = False, safe_mode: bool = True
    ) -> Tensor:
        """Riemannian inner product of tangent vectors u and v at their base points.

        When safe_mode is set, first checks that u and v are attached to the
        same manifold points.
        """
        dim = check_dims_with_broadcasting(u, v)
        if safe_mode:
            check_tangent_tensor_positions(u, v)

        return inner(
            x=u.manifold_points.tensor, u=u.tensor, v=v.tensor, c=self.c(), dim=dim, keepdim=keepdim
        )
121 |
122 | def euc_to_tangent(self, x: ManifoldTensor, u: ManifoldTensor) -> TangentTensor:
123 | dim = check_dims_with_broadcasting(x, u)
124 | tangent_vectors = euc_to_tangent(x=x.tensor, u=u.tensor, c=self.c(), dim=x.man_dim)
125 | return TangentTensor(
126 | data=tangent_vectors,
127 | manifold_points=x,
128 | manifold=self,
129 | man_dim=dim,
130 | )
131 |
    def hyperplane_dists(self, x: ManifoldTensor, z: ManifoldTensor, r: Optional[Tensor]) -> Tensor:
        """Signed hyperbolic distances of x to hyperplanes (orientations z, offsets r).

        Raises:
            ValueError: If x.man_dim != 1 or z.man_dim != 0.
        """
        if x.man_dim != 1 or z.man_dim != 0:
            raise ValueError(
                f"Expected the manifold dimension of the inputs to be 1 and the manifold "
                f"dimension of the hyperplane orientations to be 0, but got {x.man_dim} and "
                f"{z.man_dim}, respectively"
            )
        return poincare_hyperplane_dists(x=x.tensor, z=z.tensor, r=r, c=self.c())
140 |
    def fully_connected(
        self, x: ManifoldTensor, z: ManifoldTensor, bias: Optional[Tensor]
    ) -> ManifoldTensor:
        """Poincare fully connected layer; the output is re-projected onto the ball.

        Raises:
            ValueError: If z.man_dim != 0.
        """
        if z.man_dim != 0:
            raise ValueError(
                f"Expected the manifold dimension of the hyperplane orientations to be 0, but got "
                f"{z.man_dim} instead"
            )
        new_tensor = poincare_fully_connected(
            x=x.tensor, z=z.tensor, bias=bias, c=self.c(), dim=x.man_dim
        )
        new_tensor = ManifoldTensor(data=new_tensor, manifold=self, man_dim=x.man_dim)
        # Numerical safety: keep the result strictly inside the ball.
        return self.project(new_tensor)
154 |
    def frechet_mean(
        self,
        x: ManifoldTensor,
        batch_dim: Union[int, list[int]] = 0,
        keepdim: bool = False,
    ) -> ManifoldTensor:
        """Differentiable Frechet mean of x over batch_dim.

        NOTE(review): unlike midpoint(), the output man_dim is not kept at
        x.man_dim when keepdim=True here — confirm whether that is intended.
        """
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        # Manifold dim shifts left for every aggregated dim before it.
        output_man_dim = x.man_dim - sum(bd < x.man_dim for bd in batch_dim)
        new_tensor = frechet_mean(
            x=x.tensor, c=self.c(), vec_dim=x.man_dim, batch_dim=batch_dim, keepdim=keepdim
        )
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=output_man_dim)
168 |
    def midpoint(
        self,
        x: ManifoldTensor,
        batch_dim: Union[int, list[int]] = 0,
        w: Optional[Tensor] = None,
        keepdim: bool = False,
    ) -> ManifoldTensor:
        """Weighted gyro-midpoint of x over batch_dim.

        Args:
            x: Points on the ball.
            batch_dim: Dimension(s) to aggregate over.
            w: Optional weights for a weighted midpoint.
            keepdim: Whether aggregated dims are retained (man_dim then stays).

        Raises:
            ValueError: If x.man_dim is among the aggregated dims.
        """
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]

        if x.man_dim in batch_dim:
            raise ValueError(
                f"Tried to aggregate over dimensions {batch_dim}, but input has manifold "
                f"dimension {x.man_dim} and cannot aggregate over this dimension"
            )

        # Output manifold dimension is shifted left for each batch dim that disappears
        man_dim_shift = sum(bd < x.man_dim for bd in batch_dim)
        new_man_dim = x.man_dim - man_dim_shift if not keepdim else x.man_dim

        new_tensor = midpoint(
            x=x.tensor, c=self.c(), man_dim=x.man_dim, batch_dim=batch_dim, w=w, keepdim=keepdim
        )
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=new_man_dim)
193 |
    def frechet_variance(
        self,
        x: ManifoldTensor,
        mu: Optional[ManifoldTensor] = None,
        batch_dim: Union[int, list[int]] = -1,
        keepdim: bool = False,
    ) -> Tensor:
        """Frechet variance of x over batch_dim, optionally around a given mean mu."""
        if mu is not None:
            # Unwrap to a plain tensor for the math routine below.
            mu = mu.tensor
        # TODO: Check if x and mu have compatible man_dims
        return frechet_variance(
            x=x.tensor,
            c=self.c(),
            mu=mu,
            vec_dim=x.man_dim,
            batch_dim=batch_dim,
            keepdim=keepdim,
        )
213 |
    def construct_dl_parameters(
        self, in_features: int, out_features: int, bias: bool = True
    ) -> tuple[ManifoldParameter, Optional[Parameter]]:
        """Allocates (uninitialized) weight and optional bias for a hyperbolic linear layer.

        NOTE: the weight is deliberately a Euclidean parameter (hyperplane
        orientations), not a point on the ball.
        """
        weight = ManifoldParameter(
            data=empty(in_features, out_features),
            manifold=Euclidean(),
            man_dim=0,
        )

        if bias:
            b = Parameter(data=empty(out_features))
        else:
            b = None

        return weight, b
229 |
    def reset_parameters(self, weight: ManifoldParameter, bias: Optional[Parameter]) -> None:
        """Initializes hyperbolic linear-layer parameters.

        Uses a scaled identity when the layer does not shrink the feature
        dimension, a variance-scaled normal otherwise; the bias starts at zero.
        """
        in_features, out_features = weight.size()
        if in_features <= out_features:
            with no_grad():
                weight.tensor.copy_(1 / 2 * eye(in_features, out_features))
        else:
            normal_(
                weight.tensor,
                mean=0,
                std=(2 * in_features * out_features) ** -0.5,
            )
        if bias is not None:
            zeros_(bias)
243 |
244 | def unfold(
245 | self,
246 | input: ManifoldTensor,
247 | kernel_size: _size_2_t,
248 | dilation: _size_2_t = 1,
249 | padding: _size_2_t = 0,
250 | stride: _size_2_t = 1,
251 | ) -> ManifoldTensor:
252 | # TODO: may have to cache some of this stuff for efficiency.
253 | in_channels = input.size(1)
254 | if len(kernel_size) == 2:
255 | kernel_vol = kernel_size[0] * kernel_size[1]
256 | else:
257 | kernel_vol = kernel_size**2
258 | kernel_size = (kernel_size, kernel_size)
259 |
260 | beta_ni = beta_func(in_channels / 2, 1 / 2)
261 | beta_n = beta_func(in_channels * kernel_vol / 2, 1 / 2)
262 |
263 | input = self.logmap(x=None, y=input)
264 | input.tensor = input.tensor * beta_n / beta_ni
265 | new_tensor = unfold(
266 | input=input.tensor,
267 | kernel_size=kernel_size,
268 | dilation=dilation,
269 | padding=padding,
270 | stride=stride,
271 | )
272 |
273 | new_tensor = TangentTensor(data=new_tensor, manifold_points=None, manifold=self, man_dim=1)
274 | return self.expmap(new_tensor)
275 |
276 | def flatten(self, x: ManifoldTensor, start_dim: int = 1, end_dim: int = -1) -> ManifoldTensor:
277 | """Flattens a manifold tensor by reshaping it. If start_dim or end_dim are passed,
278 | only dimensions starting with start_dim and ending with end_dim are flattend.
279 |
280 | If the manifold dimension of the input tensor is among the dimensions which
281 | are flattened, applies beta-concatenation to the points on the manifold.
282 | Otherwise simply flattens the tensor using torch.flatten.
283 |
284 | Updates the manifold dimension if necessary.
285 |
286 | """
287 | start_dim = x.dim() + start_dim if start_dim < 0 else start_dim
288 | end_dim = x.dim() + end_dim if end_dim < 0 else end_dim
289 |
290 | # Get the range of dimensions to flatten.
291 | dimensions_to_flatten = x.shape[start_dim + 1 : end_dim + 1]
292 |
293 | if start_dim <= x.man_dim and end_dim >= x.man_dim:
294 | # Use beta concatenation to flatten the manifold dimension of the tensor.
295 | #
296 | # Start by applying logmap at the origin and computing the betas.
297 | tangents = self.logmap(None, x)
298 | n_i = x.shape[x.man_dim]
299 | n = n_i * functools.reduce(lambda a, b: a * b, dimensions_to_flatten)
300 | beta_n = beta_func(n / 2, 0.5)
301 | beta_n_i = beta_func(n_i / 2, 0.5)
302 | # Flatten the tensor and rescale.
303 | tangents.tensor = torch.flatten(
304 | input=tangents.tensor,
305 | start_dim=start_dim,
306 | end_dim=end_dim,
307 | )
308 | tangents.tensor = tangents.tensor * beta_n / beta_n_i
309 | # Set the new manifold dimension
310 | tangents.man_dim = start_dim
311 | # Apply exponential map at the origin.
312 | return self.expmap(tangents)
313 | else:
314 | flattened = torch.flatten(
315 | input=x.tensor,
316 | start_dim=start_dim,
317 | end_dim=end_dim,
318 | )
319 | man_dim = x.man_dim if end_dim > x.man_dim else x.man_dim - len(dimensions_to_flatten)
320 | return ManifoldTensor(data=flattened, manifold=x.manifold, man_dim=man_dim)
321 |
    def cdist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
        """Pairwise Poincare distances between the point batches x and y."""
        return cdist(x=x.tensor, y=y.tensor, c=self.c())
324 |
    def cat(
        self,
        manifold_tensors: Union[Tuple[ManifoldTensor, ...], List[ManifoldTensor]],
        dim: int = 0,
    ) -> ManifoldTensor:
        """Concatenates manifold tensors along dim.

        Concatenation along the manifold dim uses beta-concatenation: map to
        tangent space at the origin, rescale each part by a beta-function
        ratio, concatenate, and map back. Along any other dim this is a plain
        torch.cat.
        """
        check_if_man_dims_match(manifold_tensors)
        if dim == manifold_tensors[0].man_dim:
            tangent_tensors = [self.logmap(None, t) for t in manifold_tensors]
            # Per-input manifold sizes and the size of the concatenated dim.
            ns = torch.tensor([t.shape[t.man_dim] for t in manifold_tensors])
            n = ns.sum()
            beta_ns = beta_func(ns / 2, 0.5)
            beta_n = beta_func(n / 2, 0.5)
            cat = torch.cat(
                [(t.tensor * beta_n) / beta_n_i for (t, beta_n_i) in zip(tangent_tensors, beta_ns)],
                dim=dim,
            )
            new_tensor = TangentTensor(data=cat, manifold=self, man_dim=dim)
            return self.expmap(new_tensor)
        else:
            cat = torch.cat([t.tensor for t in manifold_tensors], dim=dim)
            man_dim = manifold_tensors[0].man_dim
            return ManifoldTensor(data=cat, manifold=self, man_dim=man_dim)
347 |
--------------------------------------------------------------------------------
/hypll/manifolds/poincare_ball/math/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxvanspengler/hyperbolic_learning_library/62ea7882848d878459eba4100aeb6b8ad64cc38f/hypll/manifolds/poincare_ball/math/__init__.py
--------------------------------------------------------------------------------
/hypll/manifolds/poincare_ball/math/diffgeom.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def mobius_add(x: torch.Tensor, y: torch.Tensor, c: torch.Tensor, dim: int = -1):
    """Mobius addition x (+)_c y on the Poincare ball along dim.

    Supports inputs of different ranks: dim is interpreted relative to the
    broadcasted shape, and the per-tensor reduction dims are shifted into each
    tensor's own index space.
    """
    broadcasted_dim = max(x.dim(), y.dim())
    if dim < 0:
        dim = broadcasted_dim + dim
    # Squared norms, each reduced over that tensor's own version of dim.
    x2 = x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True)
    y2 = y.pow(2).sum(dim=dim - broadcasted_dim + y.dim(), keepdim=True)
    # Inner product over the (broadcasted) manifold dim.
    xy = (x * y).sum(dim=dim, keepdim=True)
    numerator = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    denominator = 1 + 2 * c * xy + c**2 * x2 * y2
    return numerator / denominator.clamp_min(1e-15)
19 |
20 |
def project(x: torch.Tensor, c: torch.Tensor, dim: int = -1, eps: float = -1.0):
    """Projects points back into the open ball of curvature c along dim.

    A negative eps selects a dtype-dependent safety margin. Points whose norm
    exceeds the maximum allowed radius are radially rescaled; for c <= 0 the
    radius is effectively unbounded.
    """
    if eps < 0:
        eps = 4e-3 if x.dtype == torch.float32 else 1e-5
    maxnorm = (1 - eps) / ((c + 1e-15) ** 0.5)
    # Non-positive curvature: no finite ball, use a huge radius instead.
    maxnorm = torch.where(c.gt(0), maxnorm, c.new_full((), 1e15))
    norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
    return torch.where(norm > maxnorm, x / norm * maxnorm, x)
33 |
34 |
def expmap0(v: torch.Tensor, c: torch.Tensor, dim: int = -1):
    """Exponential map at the origin of the ball; the result is re-projected inside."""
    v_norm_c_sqrt = v.norm(dim=dim, keepdim=True).clamp_min(1e-15) * c.sqrt()
    return project(torch.tanh(v_norm_c_sqrt) * v / v_norm_c_sqrt, c, dim=dim)
38 |
39 |
def logmap0(y: torch.Tensor, c: torch.Tensor, dim: int = -1):
    """Logarithmic map at the origin: the tangent vector at 0 pointing to y."""
    scaled_norm = y.norm(dim=dim, keepdim=True).clamp_min(1e-15) * c.sqrt()
    return torch.atanh(scaled_norm) * y / scaled_norm
43 |
44 |
def expmap(x: torch.Tensor, v: torch.Tensor, c: torch.Tensor, dim: int = -1):
    """Exponential map of tangent vector v at base point x.

    dim is relative to the broadcasted shape of x and v; each per-tensor
    reduction shifts dim into that tensor's own index space.
    """
    broadcasted_dim = max(x.dim(), v.dim())
    dim = dim if dim >= 0 else broadcasted_dim + dim
    v_norm = v.norm(dim=dim - broadcasted_dim + v.dim(), keepdim=True).clamp_min(1e-15)
    # Conformal factor lambda_x = 2 / (1 - c ||x||^2).
    lambda_x = 2 / (
        1 - c * x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True)
    ).clamp_min(1e-15)
    c_sqrt = c.sqrt()
    second_term = torch.tanh(c_sqrt * lambda_x * v_norm / 2) * v / (c_sqrt * v_norm)
    return project(mobius_add(x, second_term, c, dim=dim), c, dim=dim)
55 |
56 |
def logmap(x: torch.Tensor, y: torch.Tensor, c: torch.Tensor, dim: int = -1):
    """Logarithmic map of y at base point x (inverse of expmap at x)."""
    broadcasted_dim = max(x.dim(), y.dim())
    dim = dim if dim >= 0 else broadcasted_dim + dim
    # Mobius "difference" (-x) (+)_c y from x to y.
    min_x_y = mobius_add(-x, y, c, dim=dim)
    min_x_y_norm = min_x_y.norm(dim=dim, keepdim=True).clamp_min(1e-15)
    # Conformal factor lambda_x = 2 / (1 - c ||x||^2).
    lambda_x = 2 / (
        1 - c * x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True)
    ).clamp_min(1e-15)
    c_sqrt = c.sqrt()
    return 2 / (c_sqrt * lambda_x) * torch.atanh(c_sqrt * min_x_y_norm) * min_x_y / min_x_y_norm
67 |
68 |
def gyration(
    u: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    c: torch.Tensor,
    dim: int = -1,
):
    """Gyration gyr[u, v]w of the Mobius gyrogroup.

    Computed from a closed-form expression in the norms and pairwise inner
    products of u, v, w; dim is relative to the broadcasted shape.
    """
    broadcasted_dim = max(u.dim(), v.dim(), w.dim())
    dim = dim if dim >= 0 else broadcasted_dim + dim
    # Squared norms and pairwise inner products along each tensor's own dim.
    u2 = u.pow(2).sum(dim=dim - broadcasted_dim + u.dim(), keepdim=True)
    v2 = v.pow(2).sum(dim=dim - broadcasted_dim + v.dim(), keepdim=True)
    uv = (u * v).sum(dim=dim - broadcasted_dim + max(u.dim(), v.dim()), keepdim=True)
    uw = (u * w).sum(dim=dim - broadcasted_dim + max(u.dim(), w.dim()), keepdim=True)
    vw = (v * w).sum(dim=dim - broadcasted_dim + max(v.dim(), w.dim()), keepdim=True)
    K2 = c**2
    a = -K2 * uw * v2 + c * vw + 2 * K2 * uv * vw
    b = -K2 * vw * u2 - c * uw
    d = 1 + 2 * c * uv + K2 * u2 * v2
    return w + 2 * (a * u + b * v) / d.clamp_min(1e-15)
88 |
89 |
def transp(
    x: torch.Tensor,
    y: torch.Tensor,
    v: torch.Tensor,
    c: torch.Tensor,
    dim: int = -1,
):
    """Parallel transport of tangent vector v from x to y.

    Uses the gyration gyr[y, -x]v rescaled by the ratio of conformal factors
    lambda_x / lambda_y.
    """
    broadcasted_dim = max(x.dim(), y.dim(), v.dim())
    dim = dim if dim >= 0 else broadcasted_dim + dim
    lambda_x = 2 / (
        1 - c * x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True)
    ).clamp_min(1e-15)
    lambda_y = 2 / (
        1 - c * y.pow(2).sum(dim=dim - broadcasted_dim + y.dim(), keepdim=True)
    ).clamp_min(1e-15)
    return gyration(y, -x, v, c, dim=dim) * lambda_x / lambda_y
106 |
107 |
def dist(
    x: torch.Tensor,
    y: torch.Tensor,
    c: torch.Tensor,
    dim: int = -1,
    keepdim: bool = False,
) -> torch.Tensor:
    """Geodesic distance on the Poincare ball:
    d(x, y) = 2 / sqrt(c) * atanh(sqrt(c) * ||(-x) (+)_c y||)."""
    return (
        2
        / c.sqrt()
        * (c.sqrt() * mobius_add(-x, y, c, dim=dim).norm(dim=dim, keepdim=keepdim)).atanh()
    )
120 |
121 |
def inner(
    x: torch.Tensor,
    u: torch.Tensor,
    v: torch.Tensor,
    c: torch.Tensor,
    dim: int = -1,
    keepdim: bool = False,
) -> torch.Tensor:
    """Riemannian inner product <u, v>_x on the Poincare ball.

    Equals lambda_x^2 * <u, v>, where lambda_x = 2 / (1 - c ||x||^2) is the
    conformal factor at x.

    Args:
        x: Points at which the tangent vectors live.
        u: First tangent vector.
        v: Second tangent vector.
        c: Curvature of the ball.
        dim: Manifold dimension, relative to the broadcasted shape.
        keepdim: Whether the reduced dimension is retained.
    """
    broadcasted_dim = max(x.dim(), u.dim(), v.dim())
    dim = dim if dim >= 0 else broadcasted_dim + dim
    # Reduce the conformal factor with the same keepdim as the dot product so
    # both factors have matching shapes. Previously lambda_x always used
    # keepdim=True, which broadcast against the reduced dot product and
    # produced wrongly-shaped outputs when keepdim=False (e.g. (B, 1) * (B,)
    # -> (B, B)).
    lambda_x = 2 / (
        1 - c * x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=keepdim)
    ).clamp_min(1e-15)
    dot_prod = (u * v).sum(dim=dim, keepdim=keepdim)
    return lambda_x.square() * dot_prod
137 |
138 |
def euc_to_tangent(
    x: torch.Tensor,
    u: torch.Tensor,
    c: torch.Tensor,
    dim: int = -1,
) -> torch.Tensor:
    """Rescales a Euclidean gradient u into a Riemannian tangent vector at x.

    Divides by the squared conformal factor lambda_x^2, where
    lambda_x = 2 / (1 - c ||x||^2).
    """
    broadcasted_dim = max(x.dim(), u.dim())
    if dim < 0:
        dim = broadcasted_dim + dim
    x_sq_norm = x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True)
    lambda_x = 2 / (1 - c * x_sq_norm).clamp_min(1e-15)
    return u / lambda_x**2
151 |
152 |
def cdist(
    x: torch.Tensor,
    y: torch.Tensor,
    c: torch.Tensor,
) -> torch.Tensor:
    """Pairwise Poincare distances between batched point sets x and y.

    Batched analogue of dist(); see mobius_add_batch for the pairwise Mobius
    sums (inputs batched as (b, i, d) and (b, k, d)).
    """
    return 2 / c.sqrt() * (c.sqrt() * mobius_add_batch(-x, y, c).norm(dim=-1)).atanh()
159 |
160 |
def mobius_add_batch(
    x: torch.Tensor,
    y: torch.Tensor,
    c: torch.Tensor,
):
    """Pairwise Mobius addition of every point of x with every point of y.

    Inputs are batched as (b, i, d) and (b, k, d); the result is (b, i, k, d)
    with out[b, i, k] = x[b, i] (+)_c y[b, k].
    """
    # Pairwise inner products <x_i, y_k>: shape (b, i, k).
    xy = torch.einsum("bij,bkj->bik", (x, y))
    x2 = x.pow(2).sum(dim=-1, keepdim=True)
    y2 = y.pow(2).sum(dim=-1, keepdim=True)
    num = 1 + 2 * c * xy + c * y2.permute(0, 2, 1)
    num = num.unsqueeze(3) * x.unsqueeze(2)
    num = num + (1 - c * x2).unsqueeze(3) * y.unsqueeze(1)
    denom = 1 + 2 * c * xy + c**2 * x2 * y2.permute(0, 2, 1)
    return num / denom.unsqueeze(3).clamp_min(1e-15)
174 |
--------------------------------------------------------------------------------
/hypll/manifolds/poincare_ball/math/linalg.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 |
def poincare_hyperplane_dists(
    x: torch.Tensor,
    z: torch.Tensor,
    r: Optional[torch.Tensor],
    c: torch.Tensor,
    dim: int = -1,
) -> torch.Tensor:
    """
    The Poincare signed distance to hyperplanes operation.

    Parameters
    ----------
    x : tensor
        contains input values
    z : tensor
        contains the hyperbolic vectors describing the hyperplane orientations
    r : tensor
        contains the hyperplane offsets
    c : tensor
        curvature of the Poincare disk

    Returns
    -------
    tensor
        signed distances of input w.r.t. the hyperplanes, denoted by v_k(x) in
        the HNN++ paper
    """
    # Move the manifold dim last so the matmuls below contract over it.
    dim_shifted_x = x.movedim(source=dim, destination=-1)

    c_sqrt = c.sqrt()
    # Conformal factor lambda_x = 2 / (1 - c ||x||^2).
    lam = 2 / (1 - c * dim_shifted_x.pow(2).sum(dim=-1, keepdim=True))
    z_norm = z.norm(dim=0).clamp_min(1e-15)

    # Computation can be simplified if there is no offset
    if r is None:
        dim_shifted_output = (
            2
            * z_norm
            / c_sqrt
            * torch.asinh(c_sqrt * lam / z_norm * torch.matmul(dim_shifted_x, z))
        )
    else:
        two_csqrt_r = 2.0 * c_sqrt * r
        dim_shifted_output = (
            2
            * z_norm
            / c_sqrt
            * torch.asinh(
                c_sqrt * lam / z_norm * torch.matmul(dim_shifted_x, z) * two_csqrt_r.cosh()
                - (lam - 1) * two_csqrt_r.sinh()
            )
        )

    # Restore the original dim ordering.
    return dim_shifted_output.movedim(source=-1, destination=dim)
60 |
61 |
def poincare_fully_connected(
    x: torch.Tensor,
    z: torch.Tensor,
    bias: Optional[torch.Tensor],
    c: torch.Tensor,
    dim: int = -1,
) -> torch.Tensor:
    """
    The Poincare fully connected layer operation.

    Parameters
    ----------
    x : tensor
        contains the layer inputs
    z : tensor
        contains the hyperbolic vectors describing the hyperplane orientations
    bias : tensor
        contains the biases (hyperplane offsets)
    c : tensor
        curvature of the Poincare disk

    Returns
    -------
    tensor
        Poincare FC transformed hyperbolic tensor, commonly denoted by y
    """
    c_sqrt = c.sqrt()
    # Signed distances v_k(x) to the layer's hyperplanes.
    x = poincare_hyperplane_dists(x=x, z=z, r=bias, c=c, dim=dim)
    # Map the distances back onto the ball.
    x = (c_sqrt * x).sinh() / c_sqrt
    return x / (1 + (1 + c * x.pow(2).sum(dim=dim, keepdim=True)).sqrt())
92 |
--------------------------------------------------------------------------------
/hypll/manifolds/poincare_ball/math/stats.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .diffgeom import dist
7 |
8 | _TOLEPS = {torch.float32: 1e-6, torch.float64: 1e-12}
9 |
10 |
class FrechetMean(torch.autograd.Function):
    """
    This implementation is copied mostly from:
    https://github.com/CUVL/Differentiable-Frechet-Mean.git

    which is itself based on the paper:
    https://arxiv.org/abs/2003.00335

    Both by Aaron Lou (et al.)
    """

    @staticmethod
    def forward(ctx, x, c, vec_dim, batch_dim, keepdim):
        # Convert input dimensions to positive values
        vec_dim = vec_dim if vec_dim > 0 else x.dim() + vec_dim
        batch_dim = [bd if bd >= 0 else x.dim() + bd for bd in batch_dim]

        # Compute some dim ids for later
        output_vec_dim = vec_dim - sum(bd < vec_dim for bd in batch_dim)
        batch_start_id = -len(batch_dim) - 1
        batch_stop_id = -2
        original_dims = [x.size(bd) for bd in batch_dim]

        # Move dims around and flatten the batch dimensions
        x = x.movedim(
            source=batch_dim + [vec_dim],
            destination=list(range(batch_start_id, 0)),
        )
        x = x.flatten(
            start_dim=batch_start_id, end_dim=batch_stop_id
        )  # ..., prod(batch_dim), vec_dim

        # Compute frechet mean and store variables
        mean = frechet_ball_forward(X=x, K=c, rtol=_TOLEPS[x.dtype], atol=_TOLEPS[x.dtype])
        # Save everything needed to undo the reshaping in backward().
        ctx.save_for_backward(x, mean, c)
        ctx.vec_dim = vec_dim
        ctx.batch_dim = batch_dim
        ctx.output_vec_dim = output_vec_dim
        ctx.batch_start_id = batch_start_id
        ctx.original_dims = original_dims

        # Reorder dimensions to match dimensions of input
        mean = mean.movedim(
            source=list(range(output_vec_dim - mean.dim(), 0)),
            destination=list(range(output_vec_dim + 1, mean.dim())) + [output_vec_dim],
        )

        # Add dimensions back to mean if keepdim
        if keepdim:
            for bd in sorted(batch_dim):
                mean = mean.unsqueeze(bd)

        return mean

    @staticmethod
    def backward(ctx, grad_output):
        x, mean, c = ctx.saved_tensors
        # Reshift dims in grad to match the dims of the mean that was stored in ctx
        grad_output = grad_output.movedim(
            source=list(range(ctx.output_vec_dim - mean.dim(), 0)),
            destination=[-1] + list(range(ctx.output_vec_dim - mean.dim(), -1)),
        )
        dx, dc = frechet_ball_backward(X=x, y=mean, grad=grad_output, K=c)
        vec_dim = ctx.vec_dim
        batch_dim = ctx.batch_dim
        batch_start_id = ctx.batch_start_id
        original_dims = ctx.original_dims

        # Unflatten the batch dimensions
        dx = dx.unflatten(dim=-2, sizes=original_dims)

        # Move the vector dimension back
        dx = dx.movedim(
            source=list(range(batch_start_id, 0)),
            destination=batch_dim + [vec_dim],
        )

        # Gradients only flow to x and c; vec_dim/batch_dim/keepdim get None.
        return dx, dc, None, None, None
89 |
90 |
def frechet_mean(
    x: torch.Tensor,
    c: torch.Tensor,
    vec_dim: int = -1,
    batch_dim: Union[int, list[int]] = 0,
    keepdim: bool = False,
) -> torch.Tensor:
    """Differentiable Frechet mean of x over batch_dim (see FrechetMean)."""
    # Normalize a single batch dim to the list form FrechetMean expects.
    dims = [batch_dim] if isinstance(batch_dim, int) else batch_dim
    return FrechetMean.apply(x, c, vec_dim, dims, keepdim)
102 |
103 |
def midpoint(
    x: torch.Tensor,
    c: torch.Tensor,
    man_dim: int = -1,
    batch_dim: Union[int, list[int]] = 0,
    w: Optional[torch.Tensor] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    """(Weighted) midpoint of x over batch_dim on the Poincare ball.

    Computes a conformal-factor-weighted average followed by a rescaling back
    onto the ball.

    NOTE(review): Tensor.squeeze with a list of dims requires a recent torch
    version when batch_dim is a list — confirm the supported torch range.
    """
    # Conformal factors lambda_x = 2 / (1 - c ||x||^2).
    lambda_x = 2 / (1 - c * x.pow(2).sum(dim=man_dim, keepdim=True)).clamp_min(1e-15)
    if w is None:
        numerator = (lambda_x * x).sum(dim=batch_dim, keepdim=True)
        denominator = (lambda_x - 1).sum(dim=batch_dim, keepdim=True)
    else:
        # TODO: test weights
        numerator = (lambda_x * w * x).sum(dim=batch_dim, keepdim=True)
        denominator = ((lambda_x - 1) * w).sum(dim=batch_dim, keepdim=True)

    frac = numerator / denominator.clamp_min(1e-15)
    # Rescale the Euclidean average onto the ball.
    midpoint = 1 / (1 + (1 - c * frac.pow(2).sum(dim=man_dim, keepdim=True)).sqrt()) * frac
    if not keepdim:
        midpoint = midpoint.squeeze(dim=batch_dim)

    return midpoint
127 |
128 |
def frechet_variance(
    x: torch.Tensor,
    c: torch.Tensor,
    mu: Optional[torch.Tensor] = None,
    vec_dim: int = -1,
    batch_dim: Union[int, list[int]] = 0,
    keepdim: bool = False,
) -> torch.Tensor:  # TODO
    """
    Mean squared distance of the points in x to the mean mu, computed with
    frechet_mean when mu is omitted.

    Args
    ----
    x (tensor): points of shape [..., points, dim]
    mu (tensor): mean of shape [..., dim]

    where the ... of the three variables match

    Returns
    -------
    tensor of shape [...]
    """
    if isinstance(batch_dim, int):
        batch_dim = [batch_dim]

    if mu is None:
        mu = frechet_mean(x=x, c=c, vec_dim=vec_dim, batch_dim=batch_dim, keepdim=True)

    # A caller-supplied mu may lack the batch dimensions; insert them so it
    # broadcasts against x.
    if x.dim() != mu.dim():
        for bd in sorted(batch_dim):
            mu = mu.unsqueeze(bd)

    sq_dists = dist(x=x, y=mu, c=c, dim=vec_dim, keepdim=keepdim).pow(2)
    return sq_dists.mean(dim=batch_dim, keepdim=keepdim)
162 |
163 |
def l_prime(y: torch.Tensor) -> torch.Tensor:
    """Elementwise derivative factor 2*acosh(1 + 2y) / sqrt(y^2 + y).

    The y -> 0 limit of this expression is 4, which is substituted explicitly
    below the 1e-12 threshold to avoid the 0/0 at y = 0.
    """
    near_zero = y < 1e-12
    limit = 4 * torch.ones_like(y)
    general = 2 * (1 + 2 * y).acosh() / (y.pow(2) + y).sqrt()
    return torch.where(near_zero, limit, general)
169 |
170 |
def frechet_ball_forward(
    X: torch.Tensor,
    K: torch.Tensor = torch.Tensor([1]),
    max_iter: int = 1000,
    rtol: float = 1e-6,
    atol: float = 1e-6,
) -> torch.Tensor:
    """
    Iteratively computes the Fréchet mean of the points in X on the ball.

    Fixes over the previous revision: removed a dead no-op statement
    (`alphas = alphas`), removed the unused `iters` counter, and renamed the
    convergence locals so they no longer shadow the module-level `dist`
    function.

    Args
    ----
    X (tensor): point of shape [..., points, dim]
    K (tensor): curvature parameter of the ball
        NOTE(review): the original docstring said "must be negative" while
        the default is +1 and the code uses terms like (1 - K * ||x||^2);
        K appears to play the role of the ball's c parameter — confirm.
    max_iter (int): maximum number of fixed-point iterations
    rtol (float): relative convergence tolerance
    atol (float): absolute convergence tolerance

    Returns
    -------
    frechet mean (tensor): shape [..., dim]
    """
    # Initialize the mean with the first point of each batch.
    mu = X[..., 0, :].clone()

    x_ss = X.pow(2).sum(dim=-1)

    mu_prev = mu
    for _ in range(max_iter):
        mu_ss = mu.pow(2).sum(dim=-1)
        xmu_ss = (X - mu.unsqueeze(-2)).pow(2).sum(dim=-1)

        # Per-point weights from the derivative of the squared distance.
        alphas = l_prime(K * xmu_ss / ((1 - K * x_ss) * (1 - K * mu_ss.unsqueeze(-1)))) / (
            1 - K * x_ss
        )

        c = (alphas * x_ss).sum(dim=-1)
        b = (alphas.unsqueeze(-1) * X).sum(dim=-2)
        a = alphas.sum(dim=-1)

        b_ss = b.pow(2).sum(dim=-1)

        # Closed-form scale of the weighted combination b for this iteration.
        eta = (a + K * c - ((a + K * c).pow(2) - 4 * K * b_ss).sqrt()) / (2 * K * b_ss).clamp_min(
            1e-15
        )

        mu = eta.unsqueeze(-1) * b

        # Stop once the update is small in absolute or relative terms.
        step_norm = (mu - mu_prev).norm(dim=-1)
        prev_norm = mu_prev.norm(dim=-1)
        if (step_norm < atol).all() or (step_norm / prev_norm < rtol).all():
            break

        mu_prev = mu

    return mu
225 |
226 |
def darcosh(x):
    """Elementwise 2*acosh(x)/sqrt(x^2 - 1), with its x -> 1 limit of 2.

    Entries within 1e-7 of 1 are replaced by the limit value before the
    general formula is evaluated, so no NaNs leak out of the masked branch.
    """
    near_one = x < 1 + 1e-7
    limit = 2 * torch.ones_like(x)
    safe = torch.where(near_one, limit, x)
    general = 2 * safe.acosh() / torch.sqrt(safe**2 - 1)
    return torch.where(near_one, limit, general)
232 |
233 |
def d2arcosh(x):
    """Elementwise derivative of darcosh: 2/(x^2-1) - 2x*acosh(x)/(x^2-1)^1.5.

    The x -> 1 limit is -2/3, substituted explicitly for entries within 1e-7
    of 1; the general formula is only kept where it is well defined.
    """
    near_one = x < 1 + 1e-7
    limit = -2 / 3 * torch.ones_like(x)
    safe = torch.where(near_one, limit, x)
    general = 2 / (safe**2 - 1) - 2 * safe * safe.acosh() / ((safe**2 - 1) ** (3 / 2))
    return torch.where(near_one, limit, general)
243 |
244 |
def grad_var(
    X: torch.Tensor,
    y: torch.Tensor,
    K: torch.Tensor,
) -> torch.Tensor:
    """
    Gradient w.r.t. y of the summed squared distances from y to the points
    in X — the objective whose minimizer is the Fréchet mean. Used by the
    implicit-differentiation backward pass (frechet_ball_backward).

    Args
    ----
    X (tensor): point of shape [..., points, dim]
    y (tensor): mean point of shape [..., dim]
    K (float): curvature (must be negative)

    Returns
    -------
    grad (tensor): gradient of variance [..., dim]
    """
    # Broadcast y against the points dimension of X.
    yl = y.unsqueeze(-2)
    # Conformal-style terms 1 - K * ||.||^2 for the points and the mean.
    xnorm = 1 - K * X.pow(2).sum(dim=-1)
    ynorm = 1 - K * yl.pow(2).sum(dim=-1)
    # Squared Euclidean distances between each point and the mean.
    xynorm = (X - yl).pow(2).sum(dim=-1)

    D = xnorm * ynorm
    # v is the acosh argument of the distance expression (see darcosh usage).
    v = 1 + 2 * K * xynorm / D

    Dl = D.unsqueeze(-1)
    vl = v.unsqueeze(-1)

    # Chain rule: darcosh(v) * dv/dy per point, summed over the points dim.
    first_term = (X - yl) / Dl
    sec_term = -K / Dl.pow(2) * yl * xynorm.unsqueeze(-1) * xnorm.unsqueeze(-1)
    return -(4 * darcosh(vl) * (first_term + sec_term)).sum(dim=-2)
275 |
276 |
def inverse_hessian(
    X: torch.Tensor,
    y: torch.Tensor,
    K: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse Hessian (w.r.t. the mean y) of the squared-distance objective,
    used by the implicit-function-theorem backward pass of the Fréchet mean.

    Args
    ----
    X (tensor): point of shape [..., points, dim]
    y (tensor): mean point of shape [..., dim]
    K (float): curvature (must be negative)

    Returns
    -------
    inv_hess (tensor): inverse hessian of [..., points, dim, dim]
    """
    # Broadcast y against the points dimension of X.
    yl = y.unsqueeze(-2)
    xnorm = 1 - K * X.pow(2).sum(dim=-1)
    ynorm = 1 - K * yl.pow(2).sum(dim=-1)
    xynorm = (X - yl).pow(2).sum(dim=-1)

    D = xnorm * ynorm
    # v is the acosh argument of the distance; darcosh/d2arcosh below are its
    # first/second derivative factors.
    v = 1 + 2 * K * xynorm / D

    Dl = D.unsqueeze(-1)
    vl = v.unsqueeze(-1)
    vll = vl.unsqueeze(-1)

    """
    \partial T/ \partial y
    """
    first_const = -8 * (K**2) * xnorm / D.pow(2)
    matrix_val = (first_const.unsqueeze(-1) * yl).unsqueeze(-1) * (X - yl).unsqueeze(-2)
    # Symmetrize the rank-one outer-product term.
    first_term = matrix_val + matrix_val.transpose(-1, -2)

    sec_const = 16 * (K**3) * xnorm.pow(2) / D.pow(3) * xynorm
    sec_term = (sec_const.unsqueeze(-1) * yl).unsqueeze(-1) * yl.unsqueeze(-2)

    third_const = 4 * K / D + 4 * (K**2) * xnorm / D.pow(2) * xynorm
    # Diagonal contribution: scaled identity broadcast over the batch dims.
    third_term = third_const.reshape(*third_const.shape, 1, 1) * torch.eye(y.shape[-1]).to(
        X
    ).reshape((1,) * len(third_const.shape) + (y.shape[-1], y.shape[-1]))

    Ty = first_term + sec_term + third_term

    """
    T
    """

    first_term = -K / Dl * (X - yl)
    sec_term = K.pow(2) / Dl.pow(2) * yl * xynorm.unsqueeze(-1) * xnorm.unsqueeze(-1)
    T = 4 * (first_term + sec_term)

    """
    inverse of shape [..., points, dim, dim]
    """
    # Product rule: d2arcosh * T T^t (outer product) + darcosh * dT/dy,
    # summed over the points dimension before inversion.
    first_term = d2arcosh(vll) * T.unsqueeze(-1) * T.unsqueeze(-2)
    sec_term = darcosh(vll) * Ty
    hessian = (first_term + sec_term).sum(dim=-3) / K
    inv_hess = torch.inverse(hessian)
    return inv_hess
338 |
339 |
def frechet_ball_backward(
    X: torch.Tensor,
    y: torch.Tensor,
    grad: torch.Tensor,
    K: torch.Tensor,
) -> tuple[torch.Tensor]:
    """
    Backward pass of the Fréchet mean via the implicit function theorem:
    precondition the incoming gradient with the inverse Hessian, then
    differentiate the (negated) gradient of the variance objective w.r.t.
    the inputs with autograd.

    Args
    ----
    X (tensor): point of shape [..., points, dim]
    y (tensor): mean point of shape [..., dim]
    grad (tensor): incoming gradient w.r.t. the mean y
    K (float): curvature (must be negative)

    Returns
    -------
    gradients (tensor, tensor):
        gradient of X [..., points, dim], gradient of curvature []
    """
    # Allow a plain Python number for K; match X's dtype/device.
    if not torch.is_tensor(K):
        K = torch.tensor(K).to(X)

    with torch.no_grad():
        inv_hess = inverse_hessian(X, y, K=K)

    with torch.enable_grad():
        # clone variables
        X = nn.Parameter(X.detach())
        y = y.detach()
        K = nn.Parameter(K)

        # Precondition: H^{-1} @ grad per batch element.
        # NOTE(review): .squeeze() drops ALL size-1 dims, not only the one
        # added by unsqueeze(-1); this looks fragile for size-1 batches or
        # dim == 1 — confirm intended behavior.
        grad = (inv_hess @ grad.unsqueeze(-1)).squeeze()
        gradf = grad_var(X, y, K)
        dx, dK = torch.autograd.grad(-gradf.squeeze(), (X, K), grad)

    return dx, dK
376 |
--------------------------------------------------------------------------------
/hypll/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules.activation import HReLU
2 | from .modules.batchnorm import HBatchNorm, HBatchNorm2d
3 | from .modules.change_manifold import ChangeManifold
4 | from .modules.container import TangentSequential
5 | from .modules.convolution import HConvolution2d
6 | from .modules.embedding import HEmbedding
7 | from .modules.flatten import HFlatten
8 | from .modules.linear import HLinear
9 | from .modules.pooling import HAvgPool2d, HMaxPool2d
10 |
--------------------------------------------------------------------------------
/hypll/nn/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxvanspengler/hyperbolic_learning_library/62ea7882848d878459eba4100aeb6b8ad64cc38f/hypll/nn/modules/__init__.py
--------------------------------------------------------------------------------
/hypll/nn/modules/activation.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | from torch.nn.functional import relu
3 |
4 | from hypll.manifolds import Manifold
5 | from hypll.tensors import ManifoldTensor
6 | from hypll.utils.layer_utils import check_if_manifolds_match, op_in_tangent_space
7 |
8 |
class HReLU(Module):
    """Hyperbolic ReLU activation.

    Applies torch's relu to the input in the tangent space at the origin of
    the layer's manifold.
    """

    def __init__(self, manifold: Manifold) -> None:
        super().__init__()
        self.manifold = manifold

    def forward(self, input: ManifoldTensor) -> ManifoldTensor:
        # The input must live on the same manifold as this layer.
        check_if_manifolds_match(layer=self, input=input)
        return op_in_tangent_space(op=relu, manifold=self.manifold, input=input)
21 |
--------------------------------------------------------------------------------
/hypll/nn/modules/batchnorm.py:
--------------------------------------------------------------------------------
1 | from torch import tensor, zeros
2 | from torch.nn import Module, Parameter
3 |
4 | from hypll.manifolds import Manifold
5 | from hypll.tensors import ManifoldTensor, TangentTensor
6 | from hypll.utils.layer_utils import check_if_manifolds_match
7 |
8 |
class HBatchNorm(Module):
    """
    Basic implementation of hyperbolic batch normalization.

    Normalizes a batch of manifold points by re-centering them at a learned
    bias point and rescaling their spread by a learned scalar weight.

    Based on:
    https://arxiv.org/abs/2003.00335
    """

    def __init__(
        self,
        features: int,
        manifold: Manifold,
        use_midpoint: bool = False,
    ) -> None:
        """
        Args:
            features: Number of features (size of the manifold dimension).
            manifold: Manifold on which the inputs live.
            use_midpoint: Use the cheaper midpoint instead of the Fréchet
                mean as the batch centroid.
        """
        super(HBatchNorm, self).__init__()
        self.features = features
        self.manifold = manifold
        self.use_midpoint = use_midpoint

        # TODO: Store bias on manifold
        # bias: tangent vector at the origin (Euclidean parameter).
        self.bias = Parameter(zeros(features))
        # weight: scalar target variance parameter.
        self.weight = Parameter(tensor(1.0))

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        check_if_manifolds_match(layer=self, input=x)
        # Map the Euclidean bias onto the manifold by expmap at the origin
        # (manifold_points=None).
        bias_on_manifold = self.manifold.expmap(
            v=TangentTensor(data=self.bias, manifold_points=None, manifold=self.manifold)
        )

        # Batch centroid over dim 0: midpoint (faster) or Fréchet mean.
        if self.use_midpoint:
            input_mean = self.manifold.midpoint(x=x, batch_dim=0)
        else:
            input_mean = self.manifold.frechet_mean(x=x, batch_dim=0)

        input_var = self.manifold.frechet_variance(x=x, mu=input_mean, batch_dim=0)

        # Log-map the batch at its mean, then parallel-transport the tangent
        # vectors to the bias point (the new center).
        input_logm = self.manifold.transp(
            v=self.manifold.logmap(input_mean, x),
            y=bias_on_manifold,
        )

        # Rescale tangent vectors toward the target variance; 1e-6 avoids
        # division by zero for degenerate batches.
        input_logm.tensor = (self.weight / (input_var + 1e-6)).sqrt() * input_logm.tensor

        # Map back onto the manifold at the bias point.
        output = self.manifold.expmap(input_logm)

        return output
55 |
56 |
class HBatchNorm2d(Module):
    """
    2D implementation of hyperbolic batch normalization.

    Wraps HBatchNorm by collapsing the batch and spatial axes of an
    (N, C, H, W) input into one batch axis with channels last, then
    restoring the original layout afterwards.

    Based on:
    https://arxiv.org/abs/2003.00335
    """

    def __init__(
        self,
        features: int,
        manifold: Manifold,
        use_midpoint: bool = False,
    ) -> None:
        super().__init__()
        self.features = features
        self.manifold = manifold
        self.use_midpoint = use_midpoint

        # The actual normalization is delegated to the 1D layer.
        self.norm = HBatchNorm(
            features=features,
            manifold=manifold,
            use_midpoint=use_midpoint,
        )

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        check_if_manifolds_match(layer=self, input=x)
        n, h, w = x.size(0), x.size(2), x.size(3)
        # Channels last, then merge (N, H, W) into a single batch dimension.
        channels_last = x.tensor.permute(0, 2, 3, 1).flatten(start_dim=0, end_dim=2)
        normed = self.norm(
            ManifoldTensor(data=channels_last, manifold=x.manifold, man_dim=-1)
        )
        # Restore the (N, C, H, W) layout with the manifold dimension at 1.
        restored = normed.tensor.reshape(n, h, w, self.features).permute(0, 3, 1, 2)
        return ManifoldTensor(data=restored, manifold=x.manifold, man_dim=1)
95 |
--------------------------------------------------------------------------------
/hypll/nn/modules/change_manifold.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | from torch.nn import Module
3 |
4 | from hypll.manifolds import Manifold
5 | from hypll.tensors import ManifoldTensor, TangentTensor
6 |
7 |
8 | class ChangeManifold(Module):
9 | """Changes the manifold of the input manifold tensor to the target manifold.
10 |
11 | Attributes:
12 | target_manifold:
13 | Manifold the output tensor should be on.
14 |
15 | """
16 |
17 | def __init__(self, target_manifold: Manifold):
18 | super(ChangeManifold, self).__init__()
19 | self.target_manifold = target_manifold
20 |
21 | def forward(self, x: ManifoldTensor) -> ManifoldTensor:
22 | """Changes the manifold of the input tensor to self.target_manifold.
23 |
24 | By default, applies logmap and expmap at the origin to switch between
25 | manifold. Applies shortcuts if possible.
26 |
27 | Args:
28 | x:
29 | Input manifold tensor.
30 |
31 | Returns:
32 | Tensor on the target manifold.
33 |
34 | """
35 |
36 | match (x.manifold, self.target_manifold):
37 | # TODO(Philipp, 05/23): Apply shortcuts where possible: For example,
38 | # case (PoincareBall(), PoincareBall()): ...
39 | #
40 | # By default resort to logmap + expmap at the origin.
41 | case _:
42 | tangent_tensor = x.manifold.logmap(None, x)
43 | return self.target_manifold.expmap(tangent_tensor)
44 |
--------------------------------------------------------------------------------
/hypll/nn/modules/container.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module, Sequential
2 |
3 | from hypll.manifolds import Manifold
4 | from hypll.tensors import ManifoldTensor
5 | from hypll.utils.layer_utils import check_if_manifolds_match
6 |
7 |
class TangentSequential(Module):
    """Runs a Euclidean nn.Sequential in the tangent space at the origin.

    The input is log-mapped at the origin, each module of the wrapped
    Sequential is applied to the raw tensor, and the result is exp-mapped
    back onto the manifold.
    """

    def __init__(self, seq: Sequential, manifold: Manifold) -> None:
        super(TangentSequential, self).__init__()
        self.seq = seq
        self.manifold = manifold

    def forward(self, input: ManifoldTensor) -> ManifoldTensor:
        """Applies the wrapped modules in tangent space.

        Fix: removed an unused local (`man_dim = input.man_dim`) left over
        from an earlier revision.
        """
        check_if_manifolds_match(layer=self, input=input)

        input = self.manifold.logmap(x=None, y=input)
        for module in self.seq:
            input.tensor = module(input.tensor)
        return self.manifold.expmap(input)
22 |
--------------------------------------------------------------------------------
/hypll/nn/modules/convolution.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | from torch.nn.common_types import _size_1_t, _size_2_t
3 |
4 | from hypll.manifolds import Manifold
5 | from hypll.tensors import ManifoldTensor
6 | from hypll.utils.layer_utils import check_if_man_dims_match, check_if_manifolds_match
7 |
8 |
class HConvolution2d(Module):
    """Applies a 2D convolution over a hyperbolic input signal.

    Attributes:
        in_channels:
            Number of channels in the input image.
        out_channels:
            Number of channels produced by the convolution.
        kernel_size:
            Size of the convolving kernel.
        manifold:
            Hyperbolic manifold of the tensors.
        bias:
            If True, adds a learnable bias to the output. Default: True
        stride:
            Stride of the convolution. Default: 1
        padding:
            Padding added to all four sides of the input. Default: 0
        id_init:
            Use identity initialization (True) if appropriate or use HNN++ initialization (False).

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        manifold: Manifold,
        bias: bool = True,
        stride: int = 1,
        padding: int = 0,
        id_init: bool = True,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Normalize a scalar kernel size to a (kh, kw) pair.
        if isinstance(kernel_size, tuple) and len(kernel_size) == 2:
            self.kernel_size = kernel_size
        else:
            self.kernel_size = (kernel_size, kernel_size)
        self.kernel_vol = self.kernel_size[0] * self.kernel_size[1]
        self.manifold = manifold
        self.stride = stride
        self.padding = padding
        self.id_init = id_init
        self.has_bias = bias

        # The convolution is realized as a fully-connected map over each
        # unfolded kernel window, so the weight has kernel_vol * C_in inputs.
        self.weights, self.bias = self.manifold.construct_dl_parameters(
            in_features=self.kernel_vol * in_channels,
            out_features=out_channels,
            bias=self.has_bias,
        )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Resets parameter weights based on the manifold."""
        self.manifold.reset_parameters(weight=self.weights, bias=self.bias)

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        """Does a forward pass of the 2D convolutional layer.

        Args:
            x:
                Manifold tensor of shape (B, C_in, H, W) with manifold dimension 1.

        Returns:
            Manifold tensor of shape (B, C_out, H_out, W_out) with manifold dimension 1.

        Raises:
            ValueError: If the manifolds or manifold dimensions don't match.

        """
        check_if_manifolds_match(layer=self, input=x)
        check_if_man_dims_match(layer=self, man_dim=1, input=x)

        batch_size = x.size(0)
        out_height, out_width = (
            _output_side_length(
                input_side_length=side,
                kernel_size=k,
                padding=self.padding,
                stride=self.stride,
            )
            for side, k in zip((x.size(2), x.size(3)), self.kernel_size)
        )

        # Gather kernel windows, apply the manifold FC layer per window, and
        # fold the result back into spatial form.
        unfolded = self.manifold.unfold(
            input=x,
            kernel_size=self.kernel_size,
            padding=self.padding,
            stride=self.stride,
        )
        projected = self.manifold.fully_connected(x=unfolded, z=self.weights, bias=self.bias)
        return ManifoldTensor(
            data=projected.tensor.reshape(batch_size, self.out_channels, out_height, out_width),
            manifold=projected.manifold,
            man_dim=1,
        )
114 |
115 |
116 | def _output_side_length(
117 | input_side_length: int, kernel_size: _size_1_t, padding: int, stride: int
118 | ) -> int:
119 | """Calculates the output side length of the kernel.
120 |
121 | Based on https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html.
122 |
123 | """
124 | if kernel_size > input_side_length:
125 | raise RuntimeError(
126 | f"Encountered invalid kernel size {kernel_size} "
127 | f"larger than input side length {input_side_length}"
128 | )
129 | if stride > input_side_length:
130 | raise RuntimeError(
131 | f"Encountered invalid stride {stride} "
132 | f"larger than input side length {input_side_length}"
133 | )
134 | return 1 + (input_side_length + 2 * padding - (kernel_size - 1) - 1) // stride
135 |
--------------------------------------------------------------------------------
/hypll/nn/modules/embedding.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor, empty, no_grad, normal, ones_like, zeros_like
2 | from torch.nn import Module
3 |
4 | from hypll.manifolds import Manifold
5 | from hypll.tensors import ManifoldParameter, ManifoldTensor, TangentTensor
6 |
7 |
class HEmbedding(Module):
    """Lookup-table embedding whose vectors live on a manifold.

    Rows are stored as a ManifoldParameter and initialized by exponential-
    mapping standard-normal tangent vectors at the origin of the manifold.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        manifold: Manifold,
    ) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.manifold = manifold

        self.weight = ManifoldParameter(
            data=empty((num_embeddings, embedding_dim)), manifold=manifold, man_dim=-1
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-samples the embedding table on the manifold."""
        with no_grad():
            gaussian = normal(
                mean=zeros_like(self.weight.tensor),
                std=ones_like(self.weight.tensor),
            )
            tangent = TangentTensor(
                data=gaussian,
                manifold_points=None,
                manifold=self.manifold,
                man_dim=-1,
            )
            self.weight.copy_(self.manifold.expmap(tangent).tensor)

    def forward(self, input: Tensor) -> ManifoldTensor:
        # Plain row indexing on the manifold-valued weight table.
        return self.weight[input]
41 |
--------------------------------------------------------------------------------
/hypll/nn/modules/flatten.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module, functional
2 |
3 | from hypll.tensors import ManifoldTensor
4 |
5 |
class HFlatten(Module):
    """Flattens a contiguous range of dims of a manifold tensor.

    Attributes:
        start_dim:
            First dimension to flatten (default = 1).
        end_dim:
            Last dimension to flatten (default = -1).

    """

    def __init__(self, start_dim: int = 1, end_dim: int = -1):
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        """Delegates flattening to the manifold tensor itself."""
        return x.flatten(start_dim=self.start_dim, end_dim=self.end_dim)
25 |
--------------------------------------------------------------------------------
/hypll/nn/modules/fold.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | from torch.nn.common_types import _size_2_t
3 |
4 | from hypll.manifolds import Manifold
5 | from hypll.tensors import ManifoldTensor
6 | from hypll.utils.layer_utils import check_if_man_dims_match, check_if_manifolds_match
7 |
8 |
class HUnfold(Module):
    """Extracts sliding local blocks from a manifold tensor.

    Hyperbolic analogue of torch.nn.Unfold: delegates the actual window
    extraction to the manifold's unfold operation.

    Attributes:
        kernel_size: Size of the sliding window.
        manifold: Manifold of the input tensors.
        dilation: Spacing between kernel elements. Default: 1
        padding: Implicit padding on both sides of the input. Default: 0
        stride: Stride of the sliding window. Default: 1
    """

    def __init__(
        self,
        kernel_size: _size_2_t,
        manifold: Manifold,
        dilation: _size_2_t = 1,
        padding: _size_2_t = 0,
        stride: _size_2_t = 1,
    ) -> None:
        # Fix: Module.__init__ was never called, so nn.Module attribute and
        # submodule registration (e.g. of self.manifold) would break.
        super(HUnfold, self).__init__()
        self.kernel_size = kernel_size
        self.manifold = manifold
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: ManifoldTensor) -> ManifoldTensor:
        """Unfolds the input; expects the manifold dimension at dim 1."""
        check_if_manifolds_match(layer=self, input=input)
        check_if_man_dims_match(layer=self, man_dim=1, input=input)
        return self.manifold.unfold(
            input=input,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )
34 |
--------------------------------------------------------------------------------
/hypll/nn/modules/linear.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 |
3 | from hypll.manifolds import Manifold
4 | from hypll.tensors import ManifoldTensor
5 | from hypll.utils.layer_utils import check_if_man_dims_match, check_if_manifolds_match
6 |
7 |
class HLinear(Module):
    """Poincare fully connected linear layer"""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        manifold: Manifold,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.manifold = manifold
        self.has_bias = bias

        # TODO: torch stores weights transposed supposedly due to efficiency
        # https://discuss.pytorch.org/t/why-does-the-linear-module-seems-to-do-unnecessary-transposing/6277/7
        # We may want to do the same
        # The manifold decides how the weight (z) and bias are represented.
        self.z, self.bias = self.manifold.construct_dl_parameters(
            in_features=in_features, out_features=out_features, bias=self.has_bias
        )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Delegates (re)initialization of z and bias to the manifold."""
        self.manifold.reset_parameters(self.z, self.bias)

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        """Applies the manifold's fully-connected map along the last dim."""
        check_if_manifolds_match(layer=self, input=x)
        check_if_man_dims_match(layer=self, man_dim=-1, input=x)
        return self.manifold.fully_connected(x=x, z=self.z, bias=self.bias)
40 |
--------------------------------------------------------------------------------
/hypll/nn/modules/pooling.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional
3 |
4 | from torch.nn import Module
5 | from torch.nn.common_types import _size_2_t, _size_any_t
6 | from torch.nn.functional import max_pool2d
7 |
8 | from hypll.manifolds import Manifold
9 | from hypll.tensors import ManifoldTensor
10 | from hypll.utils.layer_utils import (
11 | check_if_man_dims_match,
12 | check_if_manifolds_match,
13 | op_in_tangent_space,
14 | )
15 |
16 |
class HAvgPool2d(Module):
    """Hyperbolic 2D average pooling.

    Each kernel window is aggregated to a single manifold point, using
    either the (cheaper) midpoint or the Fréchet mean of its elements.

    Attributes:
        kernel_size: Pooling window size (int or 2-tuple).
        manifold: Manifold of the input/output tensors.
        stride: Window stride; defaults to kernel_size.
        padding: Implicit padding added to the input.
        use_midpoint: Aggregate with midpoint (True) or Fréchet mean (False).
    """

    def __init__(
        self,
        kernel_size: _size_2_t,
        manifold: Manifold,
        stride: Optional[_size_2_t] = None,
        padding: _size_2_t = 0,
        use_midpoint: bool = False,
    ):
        super().__init__()
        self.kernel_size = (
            kernel_size
            if isinstance(kernel_size, tuple) and len(kernel_size) == 2
            else (kernel_size, kernel_size)
        )
        self.manifold = manifold
        # Fix: normalize stride to a 2-tuple like kernel_size/padding; an int
        # stride previously crashed in forward(), which indexes
        # self.stride[0] and self.stride[1].
        if stride is None:
            stride = self.kernel_size
        self.stride = (
            stride if isinstance(stride, tuple) and len(stride) == 2 else (stride, stride)
        )
        self.padding = (
            padding if isinstance(padding, tuple) and len(padding) == 2 else (padding, padding)
        )
        self.use_midpoint = use_midpoint

    def forward(self, input: ManifoldTensor) -> ManifoldTensor:
        """Pools an (N, C, H, W) manifold tensor with manifold dim 1."""
        check_if_manifolds_match(layer=self, input=input)
        check_if_man_dims_match(layer=self, man_dim=1, input=input)

        batch_size, channels, height, width = input.size()
        out_height = int((height + 2 * self.padding[0] - self.kernel_size[0]) / self.stride[0] + 1)
        out_width = int((width + 2 * self.padding[1] - self.kernel_size[1]) / self.stride[1] + 1)

        unfolded_input = self.manifold.unfold(
            input=input,
            kernel_size=self.kernel_size,
            padding=self.padding,
            stride=self.stride,
        )
        # Expose each kernel window as its own dimension (dim 2) so it can be
        # aggregated over.
        per_kernel_view = unfolded_input.tensor.view(
            batch_size,
            channels,
            self.kernel_size[0] * self.kernel_size[1],
            unfolded_input.size(-1),
        )

        x = ManifoldTensor(data=per_kernel_view, manifold=self.manifold, man_dim=1)

        if self.use_midpoint:
            aggregates = self.manifold.midpoint(x=x, batch_dim=2)

        else:
            aggregates = self.manifold.frechet_mean(x=x, batch_dim=2)

        # Fold the aggregated windows back into the spatial layout.
        return ManifoldTensor(
            data=aggregates.tensor.reshape(batch_size, channels, out_height, out_width),
            manifold=self.manifold,
            man_dim=1,
        )
73 |
74 |
class _HMaxPoolNd(Module):
    """Shared constructor/state for hyperbolic max-pooling layers."""

    def __init__(
        self,
        kernel_size: _size_any_t,
        manifold: Manifold,
        stride: Optional[_size_any_t] = None,
        padding: _size_any_t = 0,
        dilation: _size_any_t = 1,
        return_indices: bool = False,
        ceil_mode: bool = False,
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.manifold = manifold
        # As in torch.nn, a missing stride defaults to the kernel size.
        if stride is None:
            self.stride = kernel_size
        else:
            self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode
94 |
95 |
class HMaxPool2d(_HMaxPoolNd):
    """2D max pooling applied in the tangent space at the origin."""

    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t
    dilation: _size_2_t

    def forward(self, input: ManifoldTensor) -> ManifoldTensor:
        check_if_manifolds_match(layer=self, input=input)

        # TODO: check if defining this callable each forward pass is slow.
        # If it is, put this stuff inside the init or add kwargs to op_in_tangent_space.
        def pool(tensor):
            return max_pool2d(
                tensor,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                ceil_mode=self.ceil_mode,
                return_indices=self.return_indices,
            )

        return op_in_tangent_space(op=pool, manifold=self.manifold, input=input)
117 |
--------------------------------------------------------------------------------
/hypll/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .adam import RiemannianAdam
2 | from .sgd import RiemannianSGD
3 |
--------------------------------------------------------------------------------
/hypll/optim/adam.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable
2 | from typing import Union
3 |
4 | from torch import max, no_grad, zeros_like
5 | from torch.optim import Optimizer
6 |
7 | from hypll.manifolds import Manifold
8 | from hypll.manifolds.euclidean import Euclidean
9 | from hypll.tensors import ManifoldParameter, ManifoldTensor, TangentTensor
10 |
11 |
class RiemannianAdam(Optimizer):
    """Riemannian variant of the Adam optimizer.

    ManifoldParameters get a Riemannian update: the raw gradient is converted
    to a tangent vector, the step is taken via expmap, and the first-moment
    buffer is parallel-transported to the new point. All other parameters
    fall back to a Euclidean Adam-style update.
    """

    def __init__(
        self,
        params: Iterable[Union[ManifoldParameter, ManifoldTensor]],
        lr: float,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
    ) -> None:
        # Validate hyperparameters up front, mirroring torch.optim.Adam.
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if eps < 0.0:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super(RiemannianAdam, self).__init__(params=params, defaults=defaults)

    def step(self) -> None:
        """Performs a single optimization step over all parameter groups."""
        # TODO: refactor this and add some comments, because it's currently unreadable
        with no_grad():
            for group in self.param_groups:
                betas = group["betas"]
                weight_decay = group["weight_decay"]
                eps = group["eps"]
                lr = group["lr"]
                amsgrad = group["amsgrad"]
                for param in group["params"]:
                    if isinstance(param, ManifoldParameter):
                        # --- Riemannian branch ---
                        manifold: Manifold = param.manifold
                        grad = param.grad
                        if grad is None:
                            continue
                        # Wrap the raw Euclidean gradient so manifold ops apply.
                        grad = ManifoldTensor(
                            data=grad, manifold=Euclidean(), man_dim=param.man_dim
                        )
                        state = self.state[param]

                        # Lazy state initialization on the first step.
                        if len(state) == 0:
                            state["step"] = 0
                            state["exp_avg"] = zeros_like(param.tensor)
                            state["exp_avg_sq"] = zeros_like(param.tensor)
                            if amsgrad:
                                state["max_exp_avg_sq"] = zeros_like(param.tensor)

                        state["step"] += 1
                        exp_avg = state["exp_avg"]
                        exp_avg_sq = state["exp_avg_sq"]

                        # TODO: check if this next line makes sense, because I don't think so
                        grad.tensor.add_(param.tensor, alpha=weight_decay)
                        # Convert the Euclidean gradient to a tangent vector.
                        grad = manifold.euc_to_tangent(x=param, u=grad)
                        exp_avg.mul_(betas[0]).add_(grad.tensor, alpha=1 - betas[0])
                        # Second moment accumulates the manifold inner product.
                        exp_avg_sq.mul_(betas[1]).add_(
                            manifold.inner(u=grad, v=grad, keepdim=True, safe_mode=False),
                            alpha=1 - betas[1],
                        )
                        bias_correction1 = 1 - betas[0] ** state["step"]
                        bias_correction2 = 1 - betas[1] ** state["step"]

                        if amsgrad:
                            # Keep the running maximum of the second moment.
                            max_exp_avg_sq = state["max_exp_avg_sq"]
                            max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                            denom = max_exp_avg_sq.div(bias_correction2).sqrt_()
                        else:
                            denom = exp_avg_sq.div(bias_correction2).sqrt_()

                        # The Adam step, expressed as a tangent vector at the
                        # current point.
                        direction = -lr * exp_avg.div(bias_correction1) / denom.add_(eps)
                        direction = TangentTensor(
                            data=direction,
                            manifold_points=param,
                            manifold=manifold,
                            man_dim=param.man_dim,
                        )
                        exp_avg_man = TangentTensor(
                            data=exp_avg,
                            manifold_points=param,
                            manifold=manifold,
                            man_dim=param.man_dim,
                        )

                        # Step along the manifold and transport the momentum
                        # buffer to the new point.
                        new_param = manifold.expmap(direction)
                        exp_avg_new = manifold.transp(v=exp_avg_man, y=new_param)

                        param.tensor.copy_(new_param.tensor)
                        exp_avg.copy_(exp_avg_new.tensor)

                    else:
                        # --- Euclidean fallback (Adam-style update) ---
                        grad = param.grad
                        if grad is None:
                            continue

                        state = self.state[param]

                        if len(state) == 0:
                            state["step"] = 0
                            state["exp_avg"] = zeros_like(param)
                            state["exp_avg_sq"] = zeros_like(param)
                            if amsgrad:
                                state["max_exp_avg_sq"] = zeros_like(param)
                        state["step"] += 1
                        exp_avg = state["exp_avg"]
                        exp_avg_sq = state["exp_avg_sq"]
                        grad.add_(param, alpha=weight_decay)
                        exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
                        # NOTE(review): the second moment sums squares over
                        # dim -1 instead of elementwise as in torch.optim.Adam
                        # — presumably for symmetry with the manifold branch;
                        # confirm this is intended.
                        exp_avg_sq.mul_(betas[1]).add_(
                            grad.square().sum(dim=-1, keepdim=True), alpha=1 - betas[1]
                        )
                        bias_correction1 = 1 - betas[0] ** state["step"]
                        bias_correction2 = 1 - betas[1] ** state["step"]
                        if amsgrad:
                            max_exp_avg_sq = state["max_exp_avg_sq"]
                            max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                            denom = max_exp_avg_sq.div(bias_correction2).sqrt_()
                        else:
                            denom = exp_avg_sq.div(bias_correction2).sqrt_()

                        direction = exp_avg.div(bias_correction1) / denom.add_(eps)

                        new_param = param - lr * direction
                        exp_avg_new = exp_avg

                        param.copy_(new_param)
                        exp_avg.copy_(exp_avg_new)
--------------------------------------------------------------------------------
/hypll/optim/sgd.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable
2 | from typing import Union
3 |
4 | from torch import no_grad
5 | from torch.optim import Optimizer
6 |
7 | from hypll.manifolds import Manifold
8 | from hypll.manifolds.euclidean import Euclidean
9 | from hypll.tensors import ManifoldParameter, ManifoldTensor
10 |
11 |
class RiemannianSGD(Optimizer):
    """Riemannian stochastic gradient descent.

    Mirrors torch.optim.SGD, but parameters that are ManifoldParameters are
    updated on their manifold: the Euclidean gradient is first converted to a
    tangent vector with euc_to_tangent, the update step is taken with the
    manifold exponential map, and momentum buffers are parallel-transported
    towards the new point. Plain torch parameters fall back to a manual
    Euclidean SGD update.
    """

    def __init__(
        self,
        params: Iterable[Union[ManifoldParameter, ManifoldTensor]],
        lr: float,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
    ) -> None:
        """Initializes the optimizer.

        Args:
            params: Iterable of parameters to optimize; may mix manifold and
                plain parameters.
            lr: Learning rate (non-negative).
            momentum: Momentum factor (non-negative, 0 disables momentum).
            dampening: Dampening applied to new gradients in the momentum update.
            weight_decay: Weight decay coefficient.
            nesterov: Enables Nesterov momentum; requires momentum > 0 and
                dampening == 0.

        Raises:
            ValueError: If lr, momentum or weight_decay is negative, or if
                nesterov is requested without positive momentum and zero
                dampening.
        """
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

        super(RiemannianSGD, self).__init__(params=params, defaults=defaults)

    def step(self) -> None:
        """Performs a single optimization step over all parameter groups."""
        with no_grad():
            for group in self.param_groups:
                lr = group["lr"]
                momentum = group["momentum"]
                dampening = group["dampening"]
                weight_decay = group["weight_decay"]
                # NOTE(review): local name is misspelled ("nestorov"); it only
                # affects readability, not behavior.
                nestorov = group["nesterov"]

                for param in group["params"]:
                    if isinstance(param, ManifoldParameter):
                        grad = param.grad
                        if grad is None:
                            continue
                        # Wrap the raw Euclidean gradient so it can be mapped
                        # onto the parameter's manifold below.
                        grad = ManifoldTensor(
                            data=grad, manifold=Euclidean(), man_dim=param.man_dim
                        )
                        state = self.state[param]

                        # Lazily initialize the momentum buffer from the first
                        # observed gradient.
                        if len(state) == 0:
                            if momentum > 0:
                                state["momentum_buffer"] = ManifoldTensor(
                                    data=grad.tensor.clone(),
                                    manifold=Euclidean(),
                                    man_dim=param.man_dim,
                                )

                        manifold: Manifold = param.manifold

                        # NOTE(review): weight decay is applied to the ambient
                        # (Euclidean) representation of a manifold point; the
                        # original TODO below questions whether that is valid.
                        grad.tensor.add_(
                            param.tensor, alpha=weight_decay
                        )  # TODO: check if this makes sense
                        grad = manifold.euc_to_tangent(x=param, u=grad)
                        if momentum > 0:
                            momentum_buffer = manifold.euc_to_tangent(
                                x=param, u=state["momentum_buffer"]
                            )
                            momentum_buffer.tensor.mul_(momentum).add_(
                                grad.tensor, alpha=1 - dampening
                            )
                            if nestorov:
                                grad.tensor.add_(momentum_buffer.tensor, alpha=momentum)
                            else:
                                grad = momentum_buffer

                            # Negative lr times the tangent vector gives the
                            # descent direction fed to the exponential map.
                            grad.tensor = -lr * grad.tensor

                            new_param = manifold.expmap(grad)

                            # NOTE(review): the transported buffer is only kept
                            # in this local and is never written back into
                            # state["momentum_buffer"] (see the commented-out
                            # line below) — looks like an unfinished update;
                            # confirm against euc_to_tangent/transp semantics.
                            momentum_buffer = manifold.transp(v=momentum_buffer, y=new_param)

                            # momentum_buffer.tensor.copy_(new_momentum_buffer.tensor)
                            param.tensor.copy_(new_param.tensor)
                        else:
                            grad.tensor = -lr * grad.tensor
                            new_param = manifold.expmap(v=grad)
                            param.tensor.copy_(new_param.tensor)

                    else:
                        # Plain (Euclidean) parameter: manual SGD update.
                        grad = param.grad
                        if grad is None:
                            continue
                        state = self.state[param]

                        if len(state) == 0:
                            if momentum > 0:
                                state["momentum_buffer"] = grad.clone()

                        grad.add_(param, alpha=weight_decay)
                        if momentum > 0:
                            momentum_buffer = state["momentum_buffer"]
                            momentum_buffer.mul_(momentum).add_(grad, alpha=1 - dampening)
                            if nestorov:
                                grad = grad.add_(momentum_buffer, alpha=momentum)
                            else:
                                grad = momentum_buffer

                            new_param = param - lr * grad
                            new_momentum_buffer = momentum_buffer

                            # NOTE(review): both names alias the same tensor,
                            # so this copy_ is a no-op.
                            momentum_buffer.copy_(new_momentum_buffer)
                            param.copy_(new_param)
                        else:
                            new_param = param - lr * grad
                            param.copy_(new_param)
126 |
--------------------------------------------------------------------------------
/hypll/tensors/__init__.py:
--------------------------------------------------------------------------------
1 | from .manifold_parameter import ManifoldParameter
2 | from .manifold_tensor import ManifoldTensor
3 | from .tangent_tensor import TangentTensor
4 |
--------------------------------------------------------------------------------
/hypll/tensors/manifold_parameter.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import torch
4 | from torch import Tensor, tensor
5 | from torch.nn import Parameter
6 |
7 | from hypll.manifolds import Manifold
8 | from hypll.tensors.manifold_tensor import ManifoldTensor
9 |
10 |
class ManifoldParameter(ManifoldTensor, Parameter):
    """A ManifoldTensor whose underlying data is a trainable torch.nn.Parameter.

    Combining ManifoldTensor with Parameter lets manifold-valued weights be
    registered by torch.nn.Module and collected by optimizers, while torch
    functions are blocked (see __torch_function__) so operations go through the
    manifold-aware API instead.
    """

    # Torch functions allowed to act directly on the wrapped parameter tensor.
    _allowed_methods = [
        torch._has_compatible_shallow_copy_type,  # Required for torch.nn.Parameter
        torch.Tensor.copy_,  # Required to load ManifoldParameters state dicts
    ]

    def __new__(cls, data, manifold, man_dim, requires_grad=True):
        # Bypass ManifoldTensor in the MRO so that the torch.nn.Parameter
        # machinery constructs the instance.
        return super(ManifoldTensor, cls).__new__(cls)

    # TODO: Create a mixin class containing the methods for this class and for ManifoldTensor
    # to avoid all the boilerplate stuff.
    def __init__(
        self, data, manifold: Manifold, man_dim: int = -1, requires_grad: bool = True
    ) -> None:
        """Creates a ManifoldParameter.

        Args:
            data: Parameter, tensor or array-like holding the initial values.
            manifold: Manifold instance the parameter lives on.
            man_dim: Dimension along which points lie on the manifold; negative
                values count from the end. -1 by default.
            requires_grad: Whether the wrapped Parameter requires gradients.

        Raises:
            ValueError: If man_dim is out of range for the given data.
        """
        super(ManifoldParameter, self).__init__(data=data, manifold=manifold)
        # Wrap the data as a torch.nn.Parameter so modules register it.
        if isinstance(data, Parameter):
            self.tensor = data
        elif isinstance(data, Tensor):
            self.tensor = Parameter(data=data, requires_grad=requires_grad)
        else:
            self.tensor = Parameter(data=tensor(data), requires_grad=requires_grad)

        self.manifold = manifold

        # Store man_dim as a non-negative index.
        if man_dim >= 0:
            self.man_dim = man_dim
        else:
            self.man_dim = self.tensor.dim() + man_dim
            if self.man_dim < 0:
                raise ValueError(
                    f"Dimension out of range (expected to be in range of "
                    f"{[-self.tensor.dim() - 1, self.tensor.dim()]}, but got {man_dim})"
                )

    def __getattr__(self, name: str) -> Any:
        """Exposes non-callable attributes of the wrapped Parameter.

        Callable attributes are blocked with an AttributeError directing the
        caller to use the wrapped tensor explicitly.
        """
        # TODO: go through https://pytorch.org/docs/stable/tensors.html and check which methods
        # are relevant.
        if hasattr(self.tensor, name):
            torch_attribute = getattr(self.tensor, name)

            if callable(torch_attribute):
                raise AttributeError(
                    f"Attempting to apply the torch.nn.Parameter method {name} on a ManifoldParameter."
                    f"Use ManifoldTensor.tensor.{name} instead."
                )
            else:
                return torch_attribute

        else:
            raise AttributeError(
                f"Neither {self.__class__.__name__}, nor torch.Tensor has attribute {name}"
            )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Blocks torch functions except for an explicit allowlist.

        Allowed functions are applied to the underlying tensors after
        unwrapping any ManifoldTensor arguments (positional and keyword).
        """
        if func.__class__.__name__ == "method-wrapper" or func in cls._allowed_methods:
            args = [a.tensor if isinstance(a, ManifoldTensor) else a for a in args]
            if kwargs is None:
                kwargs = {}
            # Fix: iterate over key/value pairs; iterating the dict directly
            # yields only keys and raised a ValueError for any non-empty kwargs.
            kwargs = {
                k: (v.tensor if isinstance(v, ManifoldTensor) else v) for k, v in kwargs.items()
            }
            return func(*args, **kwargs)
        # TODO: check if there are torch functions that should be allowed
        raise TypeError(
            f"Attempting to apply the torch function {func} on a ManifoldParameter. "
            f"Use ManifoldParameter.tensor as argument to {func} instead."
        )
79 |
--------------------------------------------------------------------------------
/hypll/tensors/manifold_tensor.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Optional
4 |
5 | from torch import Tensor, long, tensor
6 |
7 | from hypll.manifolds import Manifold
8 |
9 |
class ManifoldTensor:
    """Represents a tensor on a manifold.

    Attributes:
        tensor:
            Torch tensor of points on the manifold.
        manifold:
            Manifold instance.
        man_dim:
            Dimension along which points are on the manifold (stored as a
            non-negative index).

    """

    def __init__(
        self, data: Tensor, manifold: Manifold, man_dim: int = -1, requires_grad: bool = False
    ) -> None:
        """Creates an instance of ManifoldTensor.

        Args:
            data:
                Torch tensor of points on the manifold, or array-like data
                from which one is constructed.
            manifold:
                Manifold instance.
            man_dim:
                Dimension along which points are on the manifold. -1 by default.
            requires_grad:
                Whether a tensor built from array-like data requires grad.
                Ignored when data already is a Tensor.

        Raises:
            ValueError: If man_dim is out of range for the given data.

        TODO(Philipp, 05/23): Let's get rid of requires_grad if possible.

        """
        if isinstance(data, Tensor):
            self.tensor = data
        else:
            # Fix: honor the requires_grad argument instead of hard-coding True.
            self.tensor = tensor(data, requires_grad=requires_grad)
        self.manifold = manifold

        # Normalize man_dim to a non-negative index.
        if man_dim >= 0:
            self.man_dim = man_dim
        else:
            self.man_dim = self.tensor.dim() + man_dim
            if self.man_dim < 0:
                raise ValueError(
                    f"Dimension out of range (expected to be in range of "
                    f"{[-self.tensor.dim() - 1, self.tensor.dim()]}, but got {man_dim})"
                )

    def __getitem__(self, *args):
        """Indexes the manifold tensor, tracking where the manifold dim moves.

        Slicing into the manifold dimension itself is rejected, since a point
        on the manifold cannot be truncated.
        """
        # Catch some undefined behaviour by checking if args is a single element
        if len(args) != 1:
            raise ValueError(
                f"No support for slicing with these arguments. If you think there should be "
                f"support, please consider opening a issue on GitHub describing your case."
            )

        # Deal with the case where the argument is a long tensor
        if isinstance(args[0], Tensor) and args[0].dtype == long:
            if self.man_dim == 0:
                raise ValueError(
                    f"Long tensor indexing is only possible when the manifold dimension "
                    f"is not 0, but the manifold dimension is {self.man_dim}"
                )
            new_tensor = self.tensor.__getitem__(*args)
            # Index tensor replaces dim 0 by its own dims, shifting man_dim.
            new_man_dim = self.man_dim + args[0].dim() - 1
            return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=new_man_dim)

        # Convert the args to a list and replace Ellipsis by the correct number of full slices
        if isinstance(args[0], int):
            arg_list = [args[0]]
        else:
            arg_list = list(args[0])

        if Ellipsis in arg_list:
            ell_id = arg_list.index(Ellipsis)
            colon_repeats = self.dim() - sum(1 for a in arg_list if a is not None) + 1
            arg_list[ell_id : ell_id + 1] = colon_repeats * [slice(None, None, None)]

        new_tensor = self.tensor.__getitem__(*args)
        output_man_dim = self.man_dim
        counter = self.man_dim + 1

        # Compute output manifold dimension
        for arg in arg_list:
            # None values add a dimension
            if arg is None:
                output_man_dim += 1
                continue
            # Integers remove a dimension
            elif isinstance(arg, int):
                output_man_dim -= 1
                counter -= 1
            # Other values leave the dimension intact
            else:
                counter -= 1

            # When the counter hits 0 and the next term isn't None, we hit the man_dim term
            if counter == 0:
                if isinstance(arg, int) or isinstance(arg, list):
                    raise ValueError(
                        f"Attempting to slice into the manifold dimension, but this is not a "
                        "valid operation"
                    )
                # If we get past the man_dim term, the output man_dim doesn't change anymore
                break

        return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=output_man_dim)

    def __hash__(self):
        """Returns the Python unique identifier of the object.

        Note: This is how PyTorch implements hash of tensors. See also:
        https://github.com/pytorch/pytorch/issues/2569.

        """
        return id(self)

    def cpu(self) -> ManifoldTensor:
        """Returns a copy of this object with self.tensor in CPU memory."""
        new_tensor = self.tensor.cpu()
        return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=self.man_dim)

    def cuda(self, device=None) -> ManifoldTensor:
        """Returns a copy of this object with self.tensor in CUDA memory."""
        new_tensor = self.tensor.cuda(device)
        return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=self.man_dim)

    def dim(self) -> int:
        """Returns the number of dimensions of self.tensor."""
        return self.tensor.dim()

    def detach(self) -> ManifoldTensor:
        """Returns a new Tensor, detached from the current graph."""
        detached = self.tensor.detach()
        return ManifoldTensor(data=detached, manifold=self.manifold, man_dim=self.man_dim)

    def flatten(self, start_dim: int = 0, end_dim: int = -1) -> ManifoldTensor:
        """Flattens tensor by reshaping it. If start_dim or end_dim are passed,
        only dimensions starting with start_dim and ending with end_dim are flattend.

        """
        return self.manifold.flatten(self, start_dim=start_dim, end_dim=end_dim)

    @property
    def is_cpu(self):
        """True if self.tensor is stored in CPU memory."""
        return self.tensor.is_cpu

    @property
    def is_cuda(self):
        """True if self.tensor is stored on a CUDA device."""
        return self.tensor.is_cuda

    def is_floating_point(self) -> bool:
        """Returns true if the tensor is of dtype float."""
        return self.tensor.is_floating_point()

    def project(self) -> ManifoldTensor:
        """Projects the tensor to the manifold."""
        return self.manifold.project(x=self)

    @property
    def shape(self):
        """Alias for size()."""
        return self.size()

    def size(self, dim: Optional[int] = None):
        """Returns the size of self.tensor (optionally of a single dim)."""
        if dim is None:
            return self.tensor.size()
        else:
            return self.tensor.size(dim)

    def squeeze(self, dim=None):
        """Returns a squeezed version of the manifold tensor.

        Raises:
            ValueError: When the squeeze would remove the manifold dimension.
        """
        if dim is not None and dim < 0:
            # Fix: normalize negative dims before comparing against man_dim.
            dim = self.tensor.dim() + dim
        if dim == self.man_dim or (dim is None and self.size(self.man_dim) == 1):
            raise ValueError("Attempting to squeeze the manifold dimension")

        if dim is None:
            new_tensor = self.tensor.squeeze()
            # Every size-1 dim before man_dim disappears, shifting it left.
            new_man_dim = self.man_dim - sum(self.size(d) == 1 for d in range(self.man_dim))
        else:
            new_tensor = self.tensor.squeeze(dim=dim)
            new_man_dim = self.man_dim - (1 if dim < self.man_dim else 0)

        return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=new_man_dim)

    def to(self, *args, **kwargs) -> ManifoldTensor:
        """Returns a new tensor with the specified device and (optional) dtype."""
        new_tensor = self.tensor.to(*args, **kwargs)
        return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=self.man_dim)

    def transpose(self, dim0: int, dim1: int) -> ManifoldTensor:
        """Returns a transposed version of the manifold tensor. The given dimensions
        dim0 and dim1 are swapped.

        NOTE(review): negative dim0/dim1 are passed through to torch but are
        not matched against man_dim — confirm callers use non-negative dims.
        """
        if self.man_dim == dim0:
            new_man_dim = dim1
        elif self.man_dim == dim1:
            new_man_dim = dim0
        else:
            # Fix: swapping two other dims leaves the manifold dim unchanged
            # (previously new_man_dim was unbound here, raising a NameError).
            new_man_dim = self.man_dim
        new_tensor = self.tensor.transpose(dim0, dim1)
        return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=new_man_dim)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Blocks all torch functions on ManifoldTensor to force manifold-aware APIs."""
        # TODO: check if there are torch functions that should be allowed
        raise TypeError(
            f"Attempting to apply the torch function {func} on a ManifoldTensor. "
            f"Use ManifoldTensor.tensor as argument to {func} instead."
        )

    def unsqueeze(self, dim: int) -> ManifoldTensor:
        """Returns a new manifold tensor with a dimension of size one inserted at the specified position."""
        if dim < 0:
            # Fix: normalize negative dims so the man_dim shift is computed
            # against the actual insertion position.
            dim = self.tensor.dim() + dim + 1
        new_tensor = self.tensor.unsqueeze(dim=dim)
        new_man_dim = self.man_dim + (1 if dim <= self.man_dim else 0)
        return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=new_man_dim)
218 |
--------------------------------------------------------------------------------
/hypll/tensors/tangent_tensor.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 |
3 | from torch import Tensor, broadcast_shapes, tensor
4 |
5 | from hypll.manifolds import Manifold
6 | from hypll.tensors.manifold_tensor import ManifoldTensor
7 |
8 |
class TangentTensor:
    """A tangent vector (or batch of them) attached to points on a manifold.

    Attributes:
        tensor: Torch tensor of tangent vectors.
        manifold_points: ManifoldTensor of base points, or None when the
            tangent vectors live at the origin of the manifold.
        manifold: Manifold instance.
        man_dim: Dimension along which the vectors are tangent (stored as a
            non-negative index).
    """

    def __init__(
        self,
        data,
        manifold_points: Optional[ManifoldTensor] = None,
        manifold: Optional[Manifold] = None,
        man_dim: int = -1,
        requires_grad: bool = False,
    ) -> None:
        """Creates a TangentTensor.

        Args:
            data: Torch tensor of tangent vectors, or array-like data.
            manifold_points: Base points of the tangent vectors; None means
                the origin.
            manifold: Manifold instance; may be omitted when manifold_points
                is given.
            man_dim: Dimension along which the vectors are tangent. -1 by
                default.
            requires_grad: Whether a tensor built from array-like data
                requires grad.

        Raises:
            ValueError: If man_dim is out of range, if data and
                manifold_points are not broadcastable or disagree on the
                manifold dimension, if the supplied manifolds conflict, or if
                neither manifold nor manifold_points is given.
        """
        # Create tangent vector tensor with correct device and dtype
        if isinstance(data, Tensor):
            self.tensor = data
        else:
            self.tensor = tensor(data, requires_grad=requires_grad)

        # Store manifold dimension as a nonnegative integer
        if man_dim >= 0:
            self.man_dim = man_dim
        else:
            self.man_dim = self.tensor.dim() + man_dim
            if self.man_dim < 0:
                raise ValueError(
                    f"Dimension out of range (expected to be in range of "
                    f"{[-self.tensor.dim() - 1, self.tensor.dim()]}, but got {man_dim})"
                )

        if manifold_points is not None:
            # Check if the manifold points and tangent vectors are broadcastable together
            try:
                broadcasted_size = broadcast_shapes(self.tensor.size(), manifold_points.size())
            except RuntimeError:
                raise ValueError(
                    f"The shapes of the manifold points tensor {manifold_points.size()} and "
                    f"the tangent vector tensor {self.tensor.size()} are not broadcastable "
                    f"togther."
                )

            # Check if the manifold dimensions match after broadcasting
            dim = len(broadcasted_size)
            broadcast_man_dims = [
                manifold_points.man_dim + dim - manifold_points.tensor.dim(),
                self.man_dim + dim - self.tensor.dim(),
            ]
            if broadcast_man_dims[0] != broadcast_man_dims[1]:
                raise ValueError(
                    f"After broadcasting the manifold points with the tangent vectors, the "
                    f"manifold dimension computed from the manifold points should match the "
                    f"manifold dimension computed from the supplied man_dim, but these are"
                    f"{broadcast_man_dims}, respectively."
                )

        # Check if the supplied manifolds match
        if manifold_points is not None and manifold is not None:
            if manifold_points.manifold != manifold:
                raise ValueError(
                    f"The manifold of the manifold_points and the provided manifold should match, "
                    f"but are {manifold_points.manifold} and {manifold}, respectively."
                )

        # Fix: raise a clear error instead of an AttributeError on None when
        # neither source of a manifold is supplied.
        if manifold is None and manifold_points is None:
            raise ValueError(
                "Either manifold or manifold_points must be provided to determine the manifold"
            )

        self.manifold_points = manifold_points
        self.manifold = manifold if manifold is not None else manifold_points.manifold

    def __getattr__(self, name: str) -> Any:
        """Exposes non-callable attributes of the wrapped tensor.

        Callable attributes are blocked with an AttributeError directing the
        caller to use the wrapped tensors explicitly.
        """
        # TODO: go through https://pytorch.org/docs/stable/tensors.html and check which methods
        # are relevant.
        if hasattr(self.tensor, name):
            torch_attribute = getattr(self.tensor, name)

            if callable(torch_attribute):
                raise AttributeError(
                    f"Attempting to apply the torch.Tensor method {name} on a TangentTensor."
                    f"Use TangentTensor.tensor.{name} or TangentTensor.manifold_points.tensor "
                    f"instead."
                )
            else:
                return torch_attribute

        else:
            raise AttributeError(
                f"Neither {self.__class__.__name__}, nor torch.Tensor has attribute {name}"
            )

    def __hash__(self):
        """Returns the Python unique identifier of the object.

        Note: This is how PyTorch implements hash of tensors. See also:
        https://github.com/pytorch/pytorch/issues/2569.

        """
        return id(self)

    def cuda(self, device=None):
        """Returns a copy of this object with its tensors in CUDA memory."""
        new_tensor = self.tensor.cuda(device)
        # Fix: tangent vectors at the origin have no manifold points to move.
        new_manifold_points = (
            self.manifold_points.cuda(device) if self.manifold_points is not None else None
        )
        return TangentTensor(
            data=new_tensor,
            manifold_points=new_manifold_points,
            manifold=self.manifold,
            man_dim=self.man_dim,
        )

    def cpu(self):
        """Returns a copy of this object with its tensors in CPU memory."""
        new_tensor = self.tensor.cpu()
        # Fix: tangent vectors at the origin have no manifold points to move.
        new_manifold_points = (
            self.manifold_points.cpu() if self.manifold_points is not None else None
        )
        return TangentTensor(
            data=new_tensor,
            manifold_points=new_manifold_points,
            manifold=self.manifold,
            man_dim=self.man_dim,
        )

    def to(self, *args, **kwargs):
        """Returns a copy with the specified device and (optional) dtype."""
        new_tensor = self.tensor.to(*args, **kwargs)
        # Fix: call .to() on the manifold points (previously the points were
        # called like a function, raising a TypeError) and guard against
        # origin tangents with no points.
        new_manifold_points = (
            self.manifold_points.to(*args, **kwargs) if self.manifold_points is not None else None
        )
        return TangentTensor(
            data=new_tensor,
            manifold_points=new_manifold_points,
            manifold=self.manifold,
            man_dim=self.man_dim,
        )

    def size(self, dim: Optional[int] = None):
        """Returns the size after broadcasting the vectors with their points."""
        manifold_points_size = (
            self.manifold_points.size() if self.manifold_points is not None else ()
        )
        broadcasted_size = broadcast_shapes(self.tensor.size(), manifold_points_size)
        if dim is None:
            return broadcasted_size
        else:
            return broadcasted_size[dim]

    @property
    def broadcasted_man_dim(self):
        """Manifold dimension after broadcasting with the manifold points."""
        return self.man_dim + self.dim() - self.tensor.dim()

    def dim(self):
        """Number of dimensions after broadcasting with the manifold points."""
        return len(self.size())

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Blocks all torch functions on TangentTensor to force manifold-aware APIs."""
        # TODO: check if there are torch functions that should be allowed
        raise TypeError(
            f"Attempting to apply the torch function {func} on a TangentTensor."
            f"Use TangentTensor.tensor as argument to {func} instead."
        )
156 |
--------------------------------------------------------------------------------
/hypll/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxvanspengler/hyperbolic_learning_library/62ea7882848d878459eba4100aeb6b8ad64cc38f/hypll/utils/__init__.py
--------------------------------------------------------------------------------
/hypll/utils/layer_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | from torch.nn import Module
4 |
5 | from hypll.manifolds import Manifold
6 | from hypll.tensors import ManifoldTensor
7 |
8 |
def check_if_manifolds_match(layer: Module, input: ManifoldTensor) -> None:
    """Raise a ValueError when the layer and its input live on different manifolds."""
    if layer.manifold == input.manifold:
        return
    raise ValueError(
        f"Manifold of {layer.__class__.__name__} layer is {layer.manifold} "
        f"but input has manifold {input.manifold}"
    )
15 |
16 |
def check_if_man_dims_match(layer: Module, man_dim: int, input: ManifoldTensor) -> None:
    """Raise a ValueError when the input's manifold dim differs from the layer's expectation."""
    # Normalize a negative man_dim against the input rank before comparing.
    expected = input.dim() + man_dim if man_dim < 0 else man_dim

    if input.man_dim != expected:
        raise ValueError(
            f"Layer of type {layer.__class__.__name__} expects the manifold dimension to be {man_dim}, "
            f"but input has manifold dimension {input.man_dim}"
        )
28 |
29 |
def op_in_tangent_space(op: Callable, manifold: Manifold, input: ManifoldTensor) -> ManifoldTensor:
    """Apply a Euclidean operation in the tangent space at the origin and map back."""
    tangents = manifold.logmap(x=None, y=input)
    tangents.tensor = op(tangents.tensor)
    return manifold.expmap(tangents)
34 |
--------------------------------------------------------------------------------
/hypll/utils/math.py:
--------------------------------------------------------------------------------
1 | from math import exp, lgamma
2 | from typing import Union
3 |
4 | import torch
5 |
6 |
def beta_func(
    a: Union[float, torch.Tensor], b: Union[float, torch.Tensor]
) -> Union[float, torch.Tensor]:
    """Evaluate the Euler beta function B(a, b) = Γ(a)Γ(b) / Γ(a + b).

    Computed in log-space via lgamma. Returns a torch.Tensor when either
    argument is a tensor, otherwise a plain float.
    """
    if not (isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor)):
        return exp(lgamma(a) + lgamma(b) - lgamma(a + b))
    # as_tensor is a no-op for inputs that already are tensors.
    ta, tb = torch.as_tensor(a), torch.as_tensor(b)
    return torch.exp(torch.lgamma(ta) + torch.lgamma(tb) - torch.lgamma(ta + tb))
15 |
--------------------------------------------------------------------------------
/hypll/utils/tensor_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union
2 |
3 | from torch import broadcast_shapes, equal
4 |
5 | from hypll.tensors import ManifoldTensor, TangentTensor
6 |
7 |
def check_dims_with_broadcasting(*args: ManifoldTensor) -> int:
    """Return the shared manifold dimension of the inputs after broadcasting.

    Raises:
        ValueError: If the input shapes cannot be broadcast together or their
            manifold dimensions disagree after broadcasting.
    """
    # Check if shapes can be broadcasted together
    shapes = [a.size() for a in args]
    try:
        broadcasted_shape = broadcast_shapes(*shapes)
    except RuntimeError:
        raise ValueError(f"Shapes of inputs were {shapes} and cannot be broadcasted together")

    # Shift each manifold dim to its position in the broadcast result.
    target_dim = len(broadcasted_shape)
    man_dims = []
    for arg in args:
        if isinstance(arg, ManifoldTensor):
            man_dims.append(arg.man_dim + (target_dim - arg.dim()))
        elif isinstance(arg, TangentTensor):
            man_dims.append(arg.broadcasted_man_dim + (target_dim - arg.dim()))

    if any(md != man_dims[0] for md in man_dims[1:]):
        raise ValueError("Manifold dimensions of inputs after broadcasting do not match.")

    return man_dims[0]
30 |
31 |
def check_tangent_tensor_positions(*args: TangentTensor) -> None:
    """Verify that all tangent tensors are attached to the same manifold points.

    Either every tensor lives at the origin (manifold_points is None), or all
    base points must be identical after broadcasting; otherwise a ValueError
    is raised.
    """
    points = [tangent.manifold_points for tangent in args]
    if any(p is None for p in points):
        # Mixing origin-based and point-based tangent tensors is not allowed.
        if not all(p is None for p in points):
            raise ValueError(f"Some but not all tangent tensors are located at the origin")

    else:
        common_shape = broadcast_shapes(*[tangent.size() for tangent in args])
        broadcast_points = [p.tensor.broadcast_to(common_shape) for p in points]
        reference = broadcast_points[0]
        for candidate in broadcast_points[1:]:
            if not equal(reference, candidate):
                raise ValueError(
                    f"Tangent tensors are positioned at the different points on the manifold"
                )
48 |
49 |
def check_if_man_dims_match(
    manifold_tensors: Union[Tuple[ManifoldTensor, ...], List[ManifoldTensor]]
) -> None:
    """Raise a ValueError unless every tensor shares the first tensor's man_dim."""
    tensors = list(manifold_tensors)
    if not tensors:
        # Nothing to compare against.
        return

    reference = tensors[0]
    for offset, candidate in enumerate(tensors[1:]):
        if candidate.man_dim != reference.man_dim:
            raise ValueError(
                f"Manifold dimensions of inputs do not match. "
                f"Input at index [0] has manifold dimension {reference.man_dim} "
                f"but input at index [{offset + 1}] has manifold dimension {candidate.man_dim}."
            )
67 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "hypll"
3 | version = "0.1.2"
4 | description = "A framework for hyperbolic learning in PyTorch"
5 | authors = ["Max van Spengler", "Philipp Wirth", "Pascal Mettes"]
6 |
7 | [tool.poetry.dependencies]
8 | python = "^3.10"
9 | torch = "*"
10 | # Optional dependencies for building the docs.
11 | # Install with poetry install --extras "docs"
12 | sphinx-copybutton = { version = "^0.5.2", optional = true }
13 | sphinx = { version = "^7.2.6", optional = true }
14 | sphinx-gallery = { version = "^0.15.0", optional = true }
15 | sphinx-tabs = { version = "^3.4.5", optional = true }
16 | torchvision = { version = "^0.17.2", optional = true }
17 | matplotlib = { version = "^3.8.4", optional = true }
18 | networkx = { version = "^3.2.1", optional = true }
19 |
20 | [tool.poetry.dev-dependencies]
21 | black = "*"
22 | isort = "*"
23 | pytest = "*"
24 | pytest-mock = "*"
25 |
26 | [tool.poetry.extras]
27 | docs = [
28 | "sphinx-copybutton",
29 | "sphinx",
30 | "sphinx-gallery",
31 | "sphinx-tabs",
32 | "torchvision",
33 | "matplotlib",
34 | "networkx",
    # NOTE(review): "timm" and "fastai" are referenced by this extra but have no
    # matching optional entries under [tool.poetry.dependencies]; Poetry rejects
    # extras that reference undeclared dependencies — declare them or remove.
    "timm",
    "fastai",
37 | ]
38 |
39 | [build-system]
40 | requires = ["poetry-core>=1.0.0"]
41 | build-backend = "poetry.core.masonry.api"
42 |
43 | [tool.black]
44 | line-length = 100
45 |
46 | [tool.isort]
47 | profile = "black"
48 |
--------------------------------------------------------------------------------
/tests/manifolds/euclidean/test_euclidean.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from hypll.manifolds.euclidean import Euclidean
4 | from hypll.tensors import ManifoldTensor
5 |
6 |
def test_flatten__man_dim_equals_start_dim() -> None:
    """Flattening from the manifold dim onward keeps man_dim at start_dim."""
    space = Euclidean()
    mt = ManifoldTensor(
        data=torch.randn(2, 2, 2, 2),
        manifold=space,
        man_dim=1,
        requires_grad=True,
    )
    result = space.flatten(mt, start_dim=1, end_dim=-1)
    assert result.shape == (2, 8)
    assert result.man_dim == 1
19 |
20 |
def test_flatten__man_dim_equals_start_dim__set_end_dim() -> None:
    """A bounded flatten starting at the manifold dim keeps man_dim in place."""
    space = Euclidean()
    mt = ManifoldTensor(
        data=torch.randn(2, 2, 2, 2),
        manifold=space,
        man_dim=1,
        requires_grad=True,
    )
    result = space.flatten(mt, start_dim=1, end_dim=2)
    assert result.shape == (2, 4, 2)
    assert result.man_dim == 1
33 |
34 |
def test_flatten__man_dim_larger_start_dim() -> None:
    """A manifold dim inside the flattened range collapses onto start_dim."""
    space = Euclidean()
    mt = ManifoldTensor(
        data=torch.randn(2, 2, 2, 2),
        manifold=space,
        man_dim=2,
        requires_grad=True,
    )
    result = space.flatten(mt, start_dim=1, end_dim=-1)
    assert result.shape == (2, 8)
    assert result.man_dim == 1
47 |
48 |
def test_flatten__man_dim_larger_start_dim__set_end_dim() -> None:
    """A bounded flatten covering the manifold dim collapses it onto start_dim."""
    space = Euclidean()
    mt = ManifoldTensor(
        data=torch.randn(2, 2, 2, 2),
        manifold=space,
        man_dim=2,
        requires_grad=True,
    )
    result = space.flatten(mt, start_dim=1, end_dim=2)
    assert result.shape == (2, 4, 2)
    assert result.man_dim == 1
61 |
62 |
def test_flatten__man_dim_smaller_start_dim() -> None:
    """A manifold dim before start_dim is unaffected by flattening."""
    space = Euclidean()
    mt = ManifoldTensor(
        data=torch.randn(2, 2, 2, 2),
        manifold=space,
        man_dim=0,
        requires_grad=True,
    )
    result = space.flatten(mt, start_dim=1, end_dim=-1)
    assert result.shape == (2, 8)
    assert result.man_dim == 0
75 |
76 |
def test_flatten__man_dim_larger_end_dim() -> None:
    """A manifold dim after end_dim shifts left by the flattened dims."""
    space = Euclidean()
    mt = ManifoldTensor(
        data=torch.randn(2, 2, 2, 2),
        manifold=space,
        man_dim=2,
        requires_grad=True,
    )
    result = space.flatten(mt, start_dim=0, end_dim=1)
    assert result.shape == (4, 2, 2)
    assert result.man_dim == 1
89 |
90 |
def test_cdist__correct_dist():
    """Every cdist entry agrees with the pairwise dist of the corresponding rows."""
    batch, rows, cols, feat = 2, 3, 4, 8
    space = Euclidean()
    left = ManifoldTensor(torch.randn(batch, rows, feat), manifold=space)
    right = ManifoldTensor(torch.randn(batch, cols, feat), manifold=space)
    dists = space.cdist(left, right)
    for bi in range(batch):
        for ri in range(rows):
            for ci in range(cols):
                assert torch.isclose(
                    dists[bi, ri, ci], space.dist(left[bi, ri], right[bi, ci]), equal_nan=True
                )
103 |
104 |
def test_cdist__correct_dims():
    """cdist of (B, P, M) and (B, R, M) tensors yields a (B, P, R) matrix."""
    batch, rows, cols, feat = 2, 3, 4, 8
    space = Euclidean()
    left = ManifoldTensor(torch.randn(batch, rows, feat), manifold=space)
    right = ManifoldTensor(torch.randn(batch, cols, feat), manifold=space)
    assert space.cdist(left, right).shape == (batch, rows, cols)
112 |
113 |
114 | def test_cat__correct_dims():
115 | N, D1, D2 = 10, 2, 3
116 | manifold = Euclidean()
117 | manifold_tensors = [ManifoldTensor(torch.randn(D1, D2), manifold=manifold) for _ in range(N)]
118 | cat_0 = manifold.cat(manifold_tensors, dim=0)
119 | cat_1 = manifold.cat(manifold_tensors, dim=1)
120 | assert cat_0.shape == (D1 * N, D2)
121 | assert cat_1.shape == (D1, D2 * N)
122 |
123 |
124 | def test_cat__correct_man_dim():
125 | N, D1, D2 = 10, 2, 3
126 | manifold = Euclidean()
127 | manifold_tensors = [
128 | ManifoldTensor(torch.randn(D1, D2), manifold=manifold, man_dim=1) for _ in range(N)
129 | ]
130 | cat_0 = manifold.cat(manifold_tensors, dim=0)
131 | cat_1 = manifold.cat(manifold_tensors, dim=1)
132 | assert cat_0.man_dim == cat_1.man_dim == 1
133 |
--------------------------------------------------------------------------------
/tests/manifolds/poincare_ball/test_curvature.py:
--------------------------------------------------------------------------------
1 | from pytest_mock import MockerFixture
2 |
3 | from hypll.manifolds.poincare_ball import Curvature
4 |
5 |
6 | def test_curvature(mocker: MockerFixture) -> None:
7 | mocked_constraining_strategy = mocker.MagicMock()
8 | mocked_constraining_strategy.return_value = 1.33 # dummy value
9 | curvature = Curvature(value=1.0, constraining_strategy=mocked_constraining_strategy)
10 | assert curvature() == 1.33
11 | mocked_constraining_strategy.assert_called_once_with(1.0)
12 |
--------------------------------------------------------------------------------
/tests/manifolds/poincare_ball/test_poincare_ball.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall
4 | from hypll.tensors import ManifoldTensor
5 |
6 |
7 | def test_flatten__man_dim_equals_start_dim() -> None:
8 | tensor = torch.randn(2, 2, 2, 2)
9 | manifold = PoincareBall(c=Curvature())
10 | manifold_tensor = ManifoldTensor(
11 | data=tensor,
12 | manifold=manifold,
13 | man_dim=1,
14 | requires_grad=True,
15 | )
16 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=-1)
17 | assert flattened.shape == (2, 8)
18 | assert flattened.man_dim == 1
19 |
20 |
21 | def test_flatten__man_dim_equals_start_dim__set_end_dim() -> None:
22 | tensor = torch.randn(2, 2, 2, 2)
23 | manifold = PoincareBall(c=Curvature())
24 | manifold_tensor = ManifoldTensor(
25 | data=tensor,
26 | manifold=manifold,
27 | man_dim=1,
28 | requires_grad=True,
29 | )
30 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=2)
31 | assert flattened.shape == (2, 4, 2)
32 | assert flattened.man_dim == 1
33 |
34 |
35 | def test_flatten__man_dim_larger_start_dim() -> None:
36 | tensor = torch.randn(2, 2, 2, 2)
37 | manifold = PoincareBall(c=Curvature())
38 | manifold_tensor = ManifoldTensor(
39 | data=tensor,
40 | manifold=manifold,
41 | man_dim=2,
42 | requires_grad=True,
43 | )
44 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=-1)
45 | assert flattened.shape == (2, 8)
46 | assert flattened.man_dim == 1
47 |
48 |
49 | def test_flatten__man_dim_larger_start_dim__set_end_dim() -> None:
50 | tensor = torch.randn(2, 2, 2, 2)
51 | manifold = PoincareBall(c=Curvature())
52 | manifold_tensor = ManifoldTensor(
53 | data=tensor,
54 | manifold=manifold,
55 | man_dim=2,
56 | requires_grad=True,
57 | )
58 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=2)
59 | assert flattened.shape == (2, 4, 2)
60 | assert flattened.man_dim == 1
61 |
62 |
63 | def test_flatten__man_dim_smaller_start_dim() -> None:
64 | tensor = torch.randn(2, 2, 2, 2)
65 | manifold = PoincareBall(c=Curvature())
66 | manifold_tensor = ManifoldTensor(
67 | data=tensor,
68 | manifold=manifold,
69 | man_dim=0,
70 | requires_grad=True,
71 | )
72 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=-1)
73 | assert flattened.shape == (2, 8)
74 | assert flattened.man_dim == 0
75 |
76 |
77 | def test_flatten__man_dim_larger_end_dim() -> None:
78 | tensor = torch.randn(2, 2, 2, 2)
79 | manifold = PoincareBall(c=Curvature())
80 | manifold_tensor = ManifoldTensor(
81 | data=tensor,
82 | manifold=manifold,
83 | man_dim=2,
84 | requires_grad=True,
85 | )
86 | flattened = manifold.flatten(manifold_tensor, start_dim=0, end_dim=1)
87 | assert flattened.shape == (4, 2, 2)
88 | assert flattened.man_dim == 1
89 |
90 |
91 | def test_cdist__correct_dist():
92 | B, P, R, M = 2, 3, 4, 8
93 | manifold = PoincareBall(c=Curvature())
94 | mt1 = ManifoldTensor(torch.randn(B, P, M), manifold=manifold)
95 | mt2 = ManifoldTensor(torch.randn(B, R, M), manifold=manifold)
96 | dist_matrix = manifold.cdist(mt1, mt2)
97 | for b in range(B):
98 | for p in range(P):
99 | for r in range(R):
100 | assert torch.isclose(
101 | dist_matrix[b, p, r], manifold.dist(mt1[b, p], mt2[b, r]), equal_nan=True
102 | )
103 |
104 |
105 | def test_cdist__correct_dims():
106 | B, P, R, M = 2, 3, 4, 8
107 | manifold = PoincareBall(c=Curvature())
108 | mt1 = ManifoldTensor(torch.randn(B, P, M), manifold=manifold)
109 | mt2 = ManifoldTensor(torch.randn(B, R, M), manifold=manifold)
110 | dist_matrix = manifold.cdist(mt1, mt2)
111 | assert dist_matrix.shape == (B, P, R)
112 |
113 |
114 | def test_cat__correct_dims():
115 | N, D1, D2 = 10, 2, 3
116 | manifold = PoincareBall(c=Curvature())
117 | manifold_tensors = [ManifoldTensor(torch.randn(D1, D2), manifold=manifold) for _ in range(N)]
118 | cat_0 = manifold.cat(manifold_tensors, dim=0)
119 | cat_1 = manifold.cat(manifold_tensors, dim=1)
120 | assert cat_0.shape == (D1 * N, D2)
121 | assert cat_1.shape == (D1, D2 * N)
122 |
123 |
124 | def test_cat__correct_man_dim():
125 | N, D1, D2 = 10, 2, 3
126 | manifold = PoincareBall(c=Curvature())
127 | manifold_tensors = [
128 | ManifoldTensor(torch.randn(D1, D2), manifold=manifold, man_dim=1) for _ in range(N)
129 | ]
130 | cat_0 = manifold.cat(manifold_tensors, dim=0)
131 | cat_1 = manifold.cat(manifold_tensors, dim=1)
132 | assert cat_0.man_dim == cat_1.man_dim == 1
133 |
134 |
135 | def test_cat__beta_concatenation_correct_norm():
136 | MD = 64
137 | manifold = PoincareBall(c=Curvature(0.1, constraining_strategy=lambda x: x))
138 | t1 = torch.randn(MD)
139 | t2 = torch.randn(MD)
140 | mt1 = ManifoldTensor(t1 / t1.norm(), manifold=manifold)
141 | mt2 = ManifoldTensor(t2 / t2.norm(), manifold=manifold)
142 | cat = manifold.cat([mt1, mt2])
143 | assert torch.isclose(cat.tensor.norm(), torch.as_tensor(1.0), atol=1e-2)
144 |
--------------------------------------------------------------------------------
/tests/nn/test_change_manifold.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from hypll.manifolds.euclidean import Euclidean
4 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall
5 | from hypll.nn import ChangeManifold
6 | from hypll.tensors import ManifoldTensor, TangentTensor
7 |
8 |
9 | def test_change_manifold__euclidean_to_euclidean() -> None:
10 | # Define inputs.
11 | manifold = Euclidean()
12 | inputs = ManifoldTensor(data=torch.randn(10, 2), manifold=manifold, man_dim=1)
13 | # Apply change manifold.
14 | change_manifold = ChangeManifold(target_manifold=Euclidean())
15 | outputs = change_manifold(inputs)
16 | # Assert outputs are correct.
17 | assert isinstance(outputs, ManifoldTensor)
18 | assert outputs.shape == inputs.shape
19 | assert change_manifold.target_manifold == outputs.manifold
20 | assert outputs.man_dim == 1
21 | assert torch.allclose(inputs.tensor, outputs.tensor)
22 |
23 |
24 | def test_change_manifold__euclidean_to_poincare_ball() -> None:
25 | # Define inputs.
26 | manifold = Euclidean()
27 | inputs = ManifoldTensor(data=torch.randn(10, 2), manifold=manifold, man_dim=1)
28 | # Apply change manifold.
29 | change_manifold = ChangeManifold(
30 | target_manifold=PoincareBall(c=Curvature()),
31 | )
32 | outputs = change_manifold(inputs)
33 | # Assert outputs are correct.
34 | assert isinstance(outputs, ManifoldTensor)
35 | assert outputs.shape == inputs.shape
36 | assert change_manifold.target_manifold == outputs.manifold
37 | assert outputs.man_dim == 1
38 |
39 |
40 | def test_change_manifold__poincare_ball_to_euclidean() -> None:
41 | # Define inputs.
42 | manifold = PoincareBall(c=Curvature(0.1))
43 | tangents = TangentTensor(data=torch.randn(10, 2), manifold=manifold, man_dim=1)
44 | inputs = manifold.expmap(tangents)
45 | # Apply change manifold.
46 | change_manifold = ChangeManifold(target_manifold=Euclidean())
47 | outputs = change_manifold(inputs)
48 | # Assert outputs are correct.
49 | assert isinstance(outputs, ManifoldTensor)
50 | assert outputs.shape == inputs.shape
51 | assert change_manifold.target_manifold == outputs.manifold
52 | assert outputs.man_dim == 1
53 |
54 |
55 | def test_change_manifold__poincare_ball_to_poincare_ball() -> None:
56 | # Define inputs.
57 | manifold = PoincareBall(c=Curvature(0.1))
58 | tangents = TangentTensor(data=torch.randn(10, 2), manifold=manifold, man_dim=1)
59 | inputs = manifold.expmap(tangents)
60 | # Apply change manifold.
61 | change_manifold = ChangeManifold(
62 | target_manifold=PoincareBall(c=Curvature(1.0)),
63 | )
64 | outputs = change_manifold(inputs)
65 | # Assert outputs are correct.
66 | assert isinstance(outputs, ManifoldTensor)
67 | assert outputs.shape == inputs.shape
68 | assert change_manifold.target_manifold == outputs.manifold
69 | assert outputs.man_dim == 1
70 |
--------------------------------------------------------------------------------
/tests/nn/test_convolution.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from hypll.nn.modules import convolution
4 |
5 |
6 | def test__output_side_length() -> None:
7 | assert (
8 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=1, padding=0)
9 | == 30
10 | )
11 | assert (
12 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=2, padding=0)
13 | == 15
14 | )
15 |
16 |
17 | def test__output_side_length__padding() -> None:
18 | assert (
19 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=1, padding=1)
20 | == 32
21 | )
22 | assert (
23 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=2, padding=1)
24 | == 16
25 | )
26 | # Reproduce https://github.com/maxvanspengler/hyperbolic_learning_library/issues/33
27 | assert (
28 | convolution._output_side_length(input_side_length=224, kernel_size=11, stride=4, padding=2)
29 | == 55
30 | )
31 |
32 |
33 | def test__output_side_length__raises() -> None:
34 | with pytest.raises(RuntimeError) as e:
35 | convolution._output_side_length(input_side_length=32, kernel_size=33, stride=1, padding=0)
36 | with pytest.raises(RuntimeError) as e:
37 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=33, padding=0)
38 |
--------------------------------------------------------------------------------
/tests/nn/test_flatten.py:
--------------------------------------------------------------------------------
1 | from pytest_mock import MockerFixture
2 |
3 | from hypll.nn import HFlatten
4 |
5 |
6 | def test_hflatten(mocker: MockerFixture) -> None:
7 | hflatten = HFlatten(start_dim=1, end_dim=-1)
8 | mocked_manifold_tensor = mocker.MagicMock()
9 | flattened = hflatten(mocked_manifold_tensor)
10 | mocked_manifold_tensor.flatten.assert_called_once_with(start_dim=1, end_dim=-1)
11 |
--------------------------------------------------------------------------------
/tests/test_manifold_tensor.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall
5 | from hypll.tensors import ManifoldTensor
6 |
7 |
8 | @pytest.fixture
9 | def manifold_tensor() -> ManifoldTensor:
10 | return ManifoldTensor(
11 | data=[
12 | [1.0, 2.0],
13 | [3.0, 4.0],
14 | ],
15 | manifold=PoincareBall(c=Curvature()),
16 | man_dim=-1,
17 | requires_grad=True,
18 | )
19 |
20 |
21 | def test_attributes(manifold_tensor: ManifoldTensor):
22 | # Check if the standard attributes are set correctly
23 | # TODO: fix this once __eq__ has been implemented on manifolds
24 | assert isinstance(manifold_tensor.manifold, PoincareBall)
25 | assert manifold_tensor.man_dim == 1
26 | # Check if non-callable attributes are taken from tensor attribute
27 | assert manifold_tensor.is_cpu
28 |
29 |
30 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
31 | def test_device_methods(manifold_tensor: ManifoldTensor):
32 | # Check if we can move the manifold tensor to the gpu while keeping it intact
33 | manifold_tensor = manifold_tensor.cuda()
34 | assert isinstance(manifold_tensor, ManifoldTensor)
35 | assert manifold_tensor.is_cuda
36 | assert isinstance(manifold_tensor.manifold, PoincareBall)
37 | assert manifold_tensor.man_dim == 1
38 |
39 | # And move it back to the cpu
40 | manifold_tensor = manifold_tensor.cpu()
41 | assert isinstance(manifold_tensor, ManifoldTensor)
42 | assert manifold_tensor.is_cpu
43 | assert isinstance(manifold_tensor.manifold, PoincareBall)
44 | assert manifold_tensor.man_dim == 1
45 |
46 |
47 | def test_slicing(manifold_tensor: ManifoldTensor):
48 | # First test slicing with the usual numpy slicing
49 | manifold_tensor = ManifoldTensor(
50 | data=torch.ones(2, 3, 4, 5, 6, 7, 8, 9),
51 | manifold=PoincareBall(Curvature()),
52 | man_dim=-2,
53 | )
54 |
55 | # Single index
56 | single_index = manifold_tensor[1]
57 | assert list(single_index.size()) == [3, 4, 5, 6, 7, 8, 9]
58 | assert single_index.man_dim == 5
59 |
60 | # More complicated list of slicing arguments
61 | sliced_tensor = manifold_tensor[1, None, [0, 2], ..., None, 2:5, :, 1]
62 |     # Explanation of the output size: the dimension of size 2 disappears because of the integer
63 | # index. Then, a dim of size 1 is added by None. The size 3 dimension reduces to 2 because
64 | # of the list of indices. The Ellipsis skips the dimensions of sizes 4, 5 and 6. Then another
65 | # None leads to an insertion of a dimension of size 1. The dimension with size 7 is reduced
66 | # to 3 because of the 2:5 slice. The manifold dimension of size 8 is left alone and the last
67 | # dimension is removed because of an integer index.
68 | assert list(sliced_tensor.size()) == [1, 2, 4, 5, 6, 1, 3, 8]
69 | assert sliced_tensor.man_dim == 7
70 |
71 | # Now we try to slice into the manifold dimension, which should raise an error
72 | with pytest.raises(ValueError):
73 | manifold_tensor[1, None, [0, 2], ..., None, 2:5, 5, 1]
74 |
75 | # Next, we try long tensor indexing, which is used in embeddings
76 | embedding_manifold_tensor = ManifoldTensor(
77 | data=torch.ones(10, 3),
78 | manifold=PoincareBall(Curvature()),
79 | man_dim=-1,
80 | )
81 | indices = torch.Tensor(
82 | [
83 | [1, 2],
84 | [3, 4],
85 | ]
86 | ).long()
87 | embedding_selection = embedding_manifold_tensor[indices]
88 | assert list(embedding_selection.size()) == [2, 2, 3]
89 | assert embedding_selection.man_dim == 2
90 |
91 | # This should fail if the man_dim is 0 on the embedding tensor though
92 | embedding_manifold_tensor = ManifoldTensor(
93 | data=torch.ones(10, 3),
94 | manifold=PoincareBall(Curvature()),
95 | man_dim=0,
96 | )
97 | with pytest.raises(ValueError):
98 | embedding_manifold_tensor[indices]
99 |
100 | # Lastly, embeddings with more dimensions should work too
101 | embedding_manifold_tensor = ManifoldTensor(
102 | data=torch.ones(10, 3, 3, 3),
103 | manifold=PoincareBall(Curvature()),
104 | man_dim=2,
105 | )
106 | embedding_selection = embedding_manifold_tensor[indices]
107 | assert list(embedding_selection.size()) == [2, 2, 3, 3, 3]
108 | assert embedding_selection.man_dim == 3
109 |
110 |
111 | def test_torch_ops(manifold_tensor: ManifoldTensor):
112 |     # We want torch functions to raise an error
113 | with pytest.raises(TypeError):
114 | torch.norm(manifold_tensor)
115 |
116 | # Same for torch.Tensor methods (callable attributes)
117 | with pytest.raises(AttributeError):
118 | manifold_tensor.mean()
119 |
--------------------------------------------------------------------------------
/tutorials/README.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxvanspengler/hyperbolic_learning_library/62ea7882848d878459eba4100aeb6b8ad64cc38f/tutorials/README.txt
--------------------------------------------------------------------------------
/tutorials/cifar10_resnet_tutorial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Training a Poincare ResNet
4 | ==========================
5 |
6 | This is an implementation based on the Poincare Resnet paper, which can be found at:
7 |
8 | - https://arxiv.org/abs/2303.14027
9 |
10 | Due to the complexity of hyperbolic operations we strongly advise to only run this tutorial with a
11 | GPU.
12 |
13 | We will perform the following steps in order:
14 |
15 | 1. Define a hyperbolic manifold
16 | 2. Load and normalize the CIFAR10 training and test datasets using ``torchvision``
17 | 3. Define a Poincare ResNet
18 | 4. Define a loss function and optimizer
19 | 5. Train the network on the training data
20 | 6. Test the network on the test data
21 |
22 | """
23 |
24 | ##############################
25 | # 0. Grab the available device
26 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
27 | import torch
28 |
29 | device = "cuda" if torch.cuda.is_available() else "cpu"
30 |
31 |
32 | #############################
33 | # 1. Define the Poincare ball
34 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
35 |
36 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall
37 |
38 | # Making the curvature a learnable parameter is usually suboptimal but can
39 | # make training smoother. An initial curvature of 0.1 has also been shown
40 | # to help during training.
41 | manifold = PoincareBall(c=Curvature(value=0.1, requires_grad=True))
42 |
43 |
44 | ###############################
45 | # 2. Load and normalize CIFAR10
46 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
47 |
48 | import torchvision
49 | import torchvision.transforms as transforms
50 |
51 | ########################################################################
52 | # .. note::
53 | # If running on Windows and you get a BrokenPipeError, try setting
54 | # the num_worker of torch.utils.data.DataLoader() to 0.
55 |
56 | transform = transforms.Compose(
57 | [
58 | transforms.ToTensor(),
59 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
60 | ]
61 | )
62 |
63 | batch_size = 128
64 |
65 | trainset = torchvision.datasets.CIFAR10(
66 | root="./data", train=True, download=True, transform=transform
67 | )
68 | trainloader = torch.utils.data.DataLoader(
69 | trainset, batch_size=batch_size, shuffle=True, num_workers=2
70 | )
71 |
72 | testset = torchvision.datasets.CIFAR10(
73 | root="./data", train=False, download=True, transform=transform
74 | )
75 | testloader = torch.utils.data.DataLoader(
76 | testset, batch_size=batch_size, shuffle=False, num_workers=2
77 | )
78 |
79 |
80 | ###############################
81 | # 3. Define a Poincare ResNet
82 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
83 | # This implementation is based on the Poincare ResNet paper, which can
84 | # be found at https://arxiv.org/abs/2303.14027 and which, in turn, is
85 | # based on the original Euclidean implementation described in the paper
86 | # Deep Residual Learning for Image Recognition by He et al. from 2015:
87 | # https://arxiv.org/abs/1512.03385.
88 |
89 |
90 | from typing import Optional
91 |
92 | from torch import nn
93 |
94 | from hypll import nn as hnn
95 | from hypll.tensors import ManifoldTensor
96 |
97 |
98 | class PoincareResidualBlock(nn.Module):
99 | def __init__(
100 | self,
101 | in_channels: int,
102 | out_channels: int,
103 | manifold: PoincareBall,
104 | stride: int = 1,
105 | downsample: Optional[nn.Sequential] = None,
106 | ):
107 | # We can replace each operation in the usual ResidualBlock by a manifold-agnostic
108 | # operation and supply the PoincareBall object to these operations.
109 | super().__init__()
110 | self.in_channels = in_channels
111 | self.out_channels = out_channels
112 | self.manifold = manifold
113 | self.stride = stride
114 | self.downsample = downsample
115 |
116 | self.conv1 = hnn.HConvolution2d(
117 | in_channels=in_channels,
118 | out_channels=out_channels,
119 | kernel_size=3,
120 | manifold=manifold,
121 | stride=stride,
122 | padding=1,
123 | )
124 | self.bn1 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold)
125 | self.relu = hnn.HReLU(manifold=self.manifold)
126 | self.conv2 = hnn.HConvolution2d(
127 | in_channels=out_channels,
128 | out_channels=out_channels,
129 | kernel_size=3,
130 | manifold=manifold,
131 | padding=1,
132 | )
133 | self.bn2 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold)
134 |
135 | def forward(self, x: ManifoldTensor) -> ManifoldTensor:
136 | residual = x
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 | x = self.conv2(x)
141 | x = self.bn2(x)
142 |
143 | if self.downsample is not None:
144 | residual = self.downsample(residual)
145 |
146 | # We replace the addition operation inside the skip connection by a Mobius addition.
147 | x = self.manifold.mobius_add(x, residual)
148 | x = self.relu(x)
149 |
150 | return x
151 |
152 |
153 | class PoincareResNet(nn.Module):
154 | def __init__(
155 | self,
156 | channel_sizes: list[int],
157 | group_depths: list[int],
158 | manifold: PoincareBall,
159 | ):
160 | # For the Poincare ResNet itself we again replace each layer by a manifold-agnostic one
161 | # and supply the PoincareBall to each of these. We also replace the ResidualBlocks by
162 | # the manifold-agnostic one defined above.
163 | super().__init__()
164 | self.channel_sizes = channel_sizes
165 | self.group_depths = group_depths
166 | self.manifold = manifold
167 |
168 | self.conv = hnn.HConvolution2d(
169 | in_channels=3,
170 | out_channels=channel_sizes[0],
171 | kernel_size=3,
172 | manifold=manifold,
173 | padding=1,
174 | )
175 | self.bn = hnn.HBatchNorm2d(features=channel_sizes[0], manifold=manifold)
176 | self.relu = hnn.HReLU(manifold=manifold)
177 | self.group1 = self._make_group(
178 | in_channels=channel_sizes[0],
179 | out_channels=channel_sizes[0],
180 | depth=group_depths[0],
181 | )
182 | self.group2 = self._make_group(
183 | in_channels=channel_sizes[0],
184 | out_channels=channel_sizes[1],
185 | depth=group_depths[1],
186 | stride=2,
187 | )
188 | self.group3 = self._make_group(
189 | in_channels=channel_sizes[1],
190 | out_channels=channel_sizes[2],
191 | depth=group_depths[2],
192 | stride=2,
193 | )
194 |
195 | self.avg_pool = hnn.HAvgPool2d(kernel_size=8, manifold=manifold)
196 | self.fc = hnn.HLinear(in_features=channel_sizes[2], out_features=10, manifold=manifold)
197 |
198 |     def forward(self, x: ManifoldTensor) -> ManifoldTensor:
199 |         # Stem: conv -> batchnorm -> relu, all on the Poincare ball.
200 |         x = self.bn(x) if False else self.conv(x)  # placeholder - see note
201 |         x = self.relu(x)
202 |         x = self.group1(x)
203 |         x = self.group2(x)
204 |         x = self.group3(x)
205 |         x = self.avg_pool(x)
206 |         x = self.fc(x.squeeze())
207 |         return x
208 |
209 | def _make_group(
210 | self,
211 | in_channels: int,
212 | out_channels: int,
213 | depth: int,
214 | stride: int = 1,
215 | ) -> nn.Sequential:
216 | if stride == 1:
217 | downsample = None
218 | else:
219 | downsample = hnn.HConvolution2d(
220 | in_channels=in_channels,
221 | out_channels=out_channels,
222 | kernel_size=1,
223 | manifold=self.manifold,
224 | stride=stride,
225 | )
226 |
227 | layers = [
228 | PoincareResidualBlock(
229 | in_channels=in_channels,
230 | out_channels=out_channels,
231 | manifold=self.manifold,
232 | stride=stride,
233 | downsample=downsample,
234 | )
235 | ]
236 |
237 | for _ in range(1, depth):
238 | layers.append(
239 | PoincareResidualBlock(
240 | in_channels=out_channels,
241 | out_channels=out_channels,
242 | manifold=self.manifold,
243 | )
244 | )
245 |
246 | return nn.Sequential(*layers)
247 |
248 |
249 | # Now, let's create a thin Poincare ResNet with channel sizes [4, 8, 16] and with a depth of 20
250 | # layers.
251 | net = PoincareResNet(
252 | channel_sizes=[4, 8, 16],
253 | group_depths=[3, 3, 3],
254 | manifold=manifold,
255 | ).to(device)
256 |
257 |
258 | #########################################
259 | # 4. Define a Loss function and optimizer
260 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
261 | # Let's use a Classification Cross-Entropy loss and RiemannianAdam optimizer.
262 |
263 | criterion = nn.CrossEntropyLoss()
264 | # net.parameters() includes the learnable curvature "c" of the manifold.
265 | from hypll.optim import RiemannianAdam
266 |
267 | optimizer = RiemannianAdam(net.parameters(), lr=0.001)
268 |
269 |
270 | ######################
271 | # 5. Train the network
272 | # ^^^^^^^^^^^^^^^^^^^^
273 | # We simply have to loop over our data iterator, project the inputs onto the
274 | # manifold, and feed them to the network and optimize. We will train for a limited
275 | # number of epochs here due to the long training time of this model.
276 |
277 | from hypll.tensors import TangentTensor
278 |
279 | for epoch in range(2):  # Increase this number to at least 100 for good results
280 |     running_loss = 0.0
281 |     for i, data in enumerate(trainloader, 0):
282 |         # get the inputs; data is a list of [inputs, labels]
283 |         inputs, labels = data[0].to(device), data[1].to(device)
284 |
285 |         # move the inputs to the manifold
286 |         tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
287 |         manifold_inputs = manifold.expmap(tangents)
288 |
289 |         # zero the parameter gradients
290 |         optimizer.zero_grad()
291 |
292 |         # forward + backward + optimize
293 |         outputs = net(manifold_inputs)
294 |         loss = criterion(outputs.tensor, labels)
295 |         loss.backward()
296 |         optimizer.step()
297 |
298 |         # print statistics
299 |         running_loss += loss.item()  # NOTE(review): accumulated but reset on every iteration below, so never read — dead code; either print a running average periodically (as in cifar10_tutorial.py) or drop running_loss
300 |         print(f"[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}")
301 |         running_loss = 0.0
302 |
303 | print("Finished Training")
304 |
305 |
306 | ######################################
307 | # 6. Test the network on the test data
308 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
309 | #
310 | # Let us look at how the network performs on the whole dataset.
311 |
312 | correct = 0
313 | total = 0
314 | # since we're not training, we don't need to calculate the gradients for our outputs
315 | with torch.no_grad():
316 | for data in testloader:
317 | images, labels = data[0].to(device), data[1].to(device)
318 |
319 | # move the images to the manifold
320 | tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
321 | manifold_images = manifold.expmap(tangents)
322 |
323 | # calculate outputs by running images through the network
324 | outputs = net(manifold_images)
325 | # the class with the highest energy is what we choose as prediction
326 | _, predicted = torch.max(outputs.tensor, 1)
327 | total += labels.size(0)
328 | correct += (predicted == labels).sum().item()
329 |
330 | print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
331 |
--------------------------------------------------------------------------------
/tutorials/cifar10_tutorial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Training a Hyperbolic Classifier
4 | ================================
5 |
6 | This is an adaptation of torchvision's tutorial "Training a Classifier" to
7 | hyperbolic space. The original tutorial can be found here:
8 |
9 | - https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
10 |
11 | Training a Hyperbolic Image Classifier
12 | --------------------------------------
13 |
14 | We will do the following steps in order:
15 |
16 | 1. Define a hyperbolic manifold
17 | 2. Load and normalize the CIFAR10 training and test datasets using ``torchvision``
18 | 3. Define a hyperbolic Convolutional Neural Network
19 | 4. Define a loss function and optimizer
20 | 5. Train the network on the training data
21 | 6. Test the network on the test data
22 |
23 | """
24 |
25 | ########################################################################
26 | # 1. Define a hyperbolic manifold
27 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
28 | # We use the Poincaré ball model for the purposes of this tutorial.
29 |
30 |
31 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall
32 |
33 | # Making the curvature a learnable parameter is usually suboptimal but can
34 | # make training smoother.
35 | manifold = PoincareBall(c=Curvature(requires_grad=True))
36 |
37 |
38 | ########################################################################
39 | # 2. Load and normalize CIFAR10
40 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
41 |
42 | import torch
43 | import torchvision
44 | import torchvision.transforms as transforms
45 |
46 | ########################################################################
47 | # .. note::
48 | # If running on Windows and you get a BrokenPipeError, try setting
49 | # the num_worker of torch.utils.data.DataLoader() to 0.
50 |
51 | transform = transforms.Compose(
52 | [
53 | transforms.ToTensor(),
54 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
55 | ]
56 | )
57 |
58 |
59 | batch_size = 4
60 |
61 | trainset = torchvision.datasets.CIFAR10(
62 | root="./data", train=True, download=True, transform=transform
63 | )
64 | trainloader = torch.utils.data.DataLoader(
65 | trainset, batch_size=batch_size, shuffle=True, num_workers=2
66 | )
67 |
68 | testset = torchvision.datasets.CIFAR10(
69 | root="./data", train=False, download=True, transform=transform
70 | )
71 | testloader = torch.utils.data.DataLoader(
72 | testset, batch_size=batch_size, shuffle=False, num_workers=2
73 | )
74 |
75 | classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")
76 |
77 |
78 | ########################################################################
79 | # 3. Define a hyperbolic Convolutional Neural Network
80 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
81 | # Let's rebuild the convolutional neural network from torchvision's tutorial
82 | # using hyperbolic modules.
83 |
84 | from torch import nn
85 |
86 | from hypll import nn as hnn
87 |
88 |
89 | class Net(nn.Module):
90 | def __init__(self):
91 | super().__init__()
92 | self.conv1 = hnn.HConvolution2d(
93 | in_channels=3, out_channels=6, kernel_size=5, manifold=manifold
94 | )
95 | self.pool = hnn.HMaxPool2d(kernel_size=2, manifold=manifold, stride=2)
96 | self.conv2 = hnn.HConvolution2d(
97 | in_channels=6, out_channels=16, kernel_size=5, manifold=manifold
98 | )
99 | self.fc1 = hnn.HLinear(in_features=16 * 5 * 5, out_features=120, manifold=manifold)
100 | self.fc2 = hnn.HLinear(in_features=120, out_features=84, manifold=manifold)
101 | self.fc3 = hnn.HLinear(in_features=84, out_features=10, manifold=manifold)
102 | self.relu = hnn.HReLU(manifold=manifold)
103 |
104 | def forward(self, x):
105 | x = self.pool(self.relu(self.conv1(x)))
106 | x = self.pool(self.relu(self.conv2(x)))
107 | x = x.flatten(start_dim=1)
108 | x = self.relu(self.fc1(x))
109 | x = self.relu(self.fc2(x))
110 | x = self.fc3(x)
111 | return x
112 |
113 |
114 | net = Net()
115 |
116 | ########################################################################
117 | # 4. Define a Loss function and optimizer
118 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
119 | # Let's use a Classification Cross-Entropy loss and RiemannianAdam optimizer.
120 | # Adam is preferred because hyperbolic linear layers can sometimes have training
121 | # difficulties early on due to poor initialization.
122 |
123 | criterion = nn.CrossEntropyLoss()
124 | # net.parameters() includes the learnable curvature "c" of the manifold.
125 | from hypll.optim import RiemannianAdam
126 |
127 | optimizer = RiemannianAdam(net.parameters(), lr=0.001)
128 |
129 |
130 | ########################################################################
131 | # 5. Train the network
132 | # ^^^^^^^^^^^^^^^^^^^^
133 | # This is when things start to get interesting.
134 | # We simply have to loop over our data iterator, project the inputs onto the
135 | # manifold, and feed them to the network and optimize.
136 |
137 | from hypll.tensors import TangentTensor
138 |
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # move the inputs to the manifold: wrap them as tangent vectors
        # (at the origin) and apply the exponential map
        tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
        manifold_inputs = manifold.expmap(tangents)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(manifold_inputs)
        # the criterion operates on the raw (Euclidean) logits tensor
        loss = criterion(outputs.tensor, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
            running_loss = 0.0

print("Finished Training")
165 |
166 |
167 | ########################################################################
168 | # Let's quickly save our trained model:
169 |
170 | PATH = "./cifar_net.pth"
171 | torch.save(net.state_dict(), PATH)
172 |
173 |
174 | ########################################################################
175 | # Next, let's load back in our saved model (note: saving and re-loading the model
176 | # wasn't necessary here, we only did it to illustrate how to do so):
177 |
178 | net = Net()
179 | net.load_state_dict(torch.load(PATH))
180 |
181 | ########################################################################
182 | # 6. Test the network on the test data
183 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
184 | #
185 | # Let us look at how the network performs on the whole dataset.
186 |
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data

        # move the images to the manifold (tangent vectors + exponential map)
        tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
        manifold_images = manifold.expmap(tangents)

        # calculate outputs by running images through the network
        outputs = net(manifold_images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.tensor, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
206 |
207 |
208 | ########################################################################
209 | # That looks way better than chance, which is 10% accuracy (randomly picking
210 | # a class out of 10 classes).
211 | # Seems like the network learnt something.
212 | #
213 | # Hmmm, what are the classes that performed well, and the classes that did
214 | # not perform well:
215 |
# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data

        # move the images to the manifold
        tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
        manifold_images = manifold.expmap(tangents)

        outputs = net(manifold_images)
        _, predictions = torch.max(outputs.tensor, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

# print accuracy for each class; guard against a class that never appears
# in the test split so we don't divide by zero
for classname, correct_count in correct_pred.items():
    if total_pred[classname] == 0:
        print(f"Accuracy for class: {classname:5s} is n/a (no test samples)")
        continue
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")
241 |
242 |
243 | ########################################################################
244 | #
245 | # Training on GPU
246 | # ----------------
247 | # Just like how you transfer a Tensor onto the GPU, you transfer the neural
248 | # net onto the GPU.
249 | #
250 | # Let's first define our device as the first visible cuda device if we have
251 | # CUDA available:
252 | #
253 | # .. code:: python
254 | #
255 | # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
256 | #
257 | #
258 | # Assuming that we are on a CUDA machine, this should print a CUDA device:
259 | #
260 | # .. code:: python
261 | #
262 | # print(device)
263 | #
264 | #
265 | # The rest of this section assumes that ``device`` is a CUDA device.
266 | #
267 | # Then these methods will recursively go over all modules and convert their
268 | # parameters and buffers to CUDA tensors:
269 | #
270 | # .. code:: python
271 | #
272 | # net.to(device)
273 | #
274 | #
275 | # Remember that you will have to send the inputs and targets at every step
276 | # to the GPU too:
277 | #
278 | # .. code:: python
279 | #
280 | # inputs, labels = data[0].to(device), data[1].to(device)
281 | #
282 | #
283 | # **Goals achieved**:
284 | #
285 | # - Train a small hyperbolic neural network to classify images.
286 | #
287 |
--------------------------------------------------------------------------------
/tutorials/hyperbolic_vit_tutorial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Training a Hyperbolic Vision Transformer
4 | ========================================
5 |
6 | This implementation is based on the "Hyperbolic Vision Transformers: Combining Improvements in Metric Learning"
7 | paper by Aleksandr Ermolov et al., which can be found at:
8 |
9 | - https://arxiv.org/pdf/2203.10833.pdf
10 |
11 | Their implementation can be found at:
12 |
13 | - https://github.com/htdt/hyp_metric
14 |
15 | We will perform the following steps in order:
16 |
17 | 1. Initialize the manifold on which the embeddings will be trained
18 | 2. Load the CUB_200_2011 dataset using fastai
19 | 3. Initialize train and test dataloaders
20 | 4. Define the Hyperbolic Vision Transformer model
21 | 5. Define the pairwise cross-entropy loss function
22 | 6. Define the recall@k evaluation metric
23 | 7. Train the model on the training data
24 | 8. Test the model on the test data
25 |
26 | Please make sure to install the following packages before running this script:
27 | - fastai
28 | - timm
29 |
30 | """
31 |
32 | #######################################################
33 | # 0. Set the device and random seed for reproducibility
34 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
35 |
36 | import random
37 |
38 | import numpy as np
39 | import torch
40 |
41 | device = "cuda" if torch.cuda.is_available() else "cpu"
42 |
43 | seed = 0
44 |
45 | torch.manual_seed(seed)
46 | torch.cuda.manual_seed(seed)
47 | np.random.seed(seed)
48 | random.seed(seed)
49 |
50 | torch.backends.cudnn.deterministic = True
51 | torch.backends.cudnn.benchmark = False
52 |
53 | ####################################################################
54 | # 1. Initialize the manifold on which the embeddings will be trained
55 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
56 |
57 | from hypll.manifolds.euclidean import Euclidean
58 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall
59 |
60 | # We can choose between the Poincaré ball model and Euclidean space.
61 | do_hyperbolic = True
62 |
63 | if do_hyperbolic:
64 | # We fix the curvature to 0.1 as in the paper (Section 3.2).
65 | manifold = PoincareBall(
66 | c=Curvature(value=0.1, constraining_strategy=lambda x: x, requires_grad=False)
67 | )
68 | else:
69 | manifold = Euclidean()
70 |
71 | ###############################################
72 | # 2. Load the CUB_200_2011 dataset using fastai
73 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
74 |
75 | import pandas as pd
76 | import torchvision.transforms as transforms
77 | from fastai.data.external import URLs, untar_data
78 | from PIL import Image
79 | from torch.utils.data import Dataset
80 |
81 | # Download the dataset using fastai.
82 | path = untar_data(URLs.CUB_200_2011) / "CUB_200_2011"
83 |
84 |
class CUB(Dataset):
    """Handles the downloaded CUB_200_2011 files.

    Reads the metadata text files shipped with the dataset
    (image_class_labels.txt, train_test_split.txt, images.txt) and exposes
    the requested train/test split as an image-classification dataset.
    """

    def __init__(self, files_path, train, transform):
        # files_path: root of the extracted CUB_200_2011 folder.
        # train: select the training split if True, the test split otherwise.
        # transform: torchvision transform applied to each loaded PIL image.
        self.files_path = files_path
        self.transform = transform

        self.targets = pd.read_csv(files_path / "image_class_labels.txt", header=None, sep=" ")
        self.targets.columns = ["id", "target"]
        self.targets = self.targets["target"].values

        self.train_test = pd.read_csv(files_path / "train_test_split.txt", header=None, sep=" ")
        self.train_test.columns = ["id", "is_train"]

        self.images = pd.read_csv(files_path / "images.txt", header=None, sep=" ")
        self.images.columns = ["id", "name"]

        # Boolean mask selecting the rows belonging to the requested split.
        mask = self.train_test.is_train.values == int(train)

        self.filenames = self.images.iloc[mask]
        self.targets = self.targets[mask]
        self.num_files = len(self.targets)

    def __len__(self):
        return self.num_files

    def __getitem__(self, index):
        # Labels in the metadata are 1-based; shift to 0-based class ids.
        y = self.targets[index] - 1
        file_name = self.filenames.iloc[index, 1]
        path = self.files_path / "images" / file_name
        x = Image.open(path).convert("RGB")
        x = self.transform(x)
        return x, y
118 |
119 |
120 | # We use the same resizing and data augmentation as in the paper (Section 3.2).
121 | # Normalization values are taken from the original implementation.
122 |
123 | train_transform = transforms.Compose(
124 | [
125 | transforms.RandomResizedCrop(
126 | 224, scale=(0.9, 1.0), interpolation=transforms.InterpolationMode("bicubic")
127 | ),
128 | transforms.RandomHorizontalFlip(),
129 | transforms.ToTensor(),
130 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
131 | ]
132 | )
133 | trainset = CUB(path, train=True, transform=train_transform)
134 |
135 | test_transform = transforms.Compose(
136 | [
137 | transforms.Resize(256, interpolation=transforms.InterpolationMode("bicubic")),
138 | transforms.CenterCrop(224),
139 | transforms.ToTensor(),
140 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
141 | ]
142 | )
143 | testset = CUB(path, train=False, transform=test_transform)
144 |
145 | ##########################################
146 | # 3. Initialize train and test dataloaders
147 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
148 |
149 | import collections
150 | import multiprocessing
151 |
152 | from torch.utils.data import DataLoader
153 | from torch.utils.data.sampler import Sampler
154 |
155 | # We need a dataloader that returns a requested number of images belonging to the same class.
156 | # This will allow us to compute the pairwise cross-entropy loss.
157 |
158 |
class UniqueClassSampler(Sampler):
    """Custom sampler for creating batches with m_per_class samples per class.

    Yields dataset indices so that every run of ``m_per_class`` consecutive
    indices belongs to the same class, as required by the pairwise
    cross-entropy loss.
    """

    def __init__(self, labels, m_per_class, seed):
        # labels: iterable of per-sample class labels.
        # m_per_class: number of samples drawn per class per group (>= 2).
        # seed: base seed, combined with the epoch for the class permutation.
        self.labels_to_indices = self.get_labels_to_indices(labels)
        self.labels = sorted(list(self.labels_to_indices.keys()))
        self.m_per_class = m_per_class
        self.seed = seed
        self.epoch = 0

    def __len__(self):
        # Bug fix: the original read a module-level ``m_per_class`` global
        # here; use the instance attribute instead.
        return (
            np.max([len(v) for v in self.labels_to_indices.values()])
            * self.m_per_class
            * len(self.labels)
        )

    def set_epoch(self, epoch: int):
        # Update the epoch, so that the random permutation changes.
        self.epoch = epoch

    def get_labels_to_indices(self, labels: list):
        # Create a dictionary mapping each label to the indices in the dataset
        labels_to_indices = collections.defaultdict(list)
        for i, label in enumerate(labels):
            labels_to_indices[label].append(i)
        for k, v in labels_to_indices.items():
            labels_to_indices[k] = np.array(v, dtype=int)
        return labels_to_indices

    def __iter__(self):
        # Generate a random iteration order for the indices.
        # For example, if we have 3 classes (A,B,C) and m_per_class=2,
        # we could get the following indices: [A1, A2, C1, C2, B1, B2].
        idx_list = []
        g = torch.Generator()
        # Seed depends on both base seed and epoch, so every epoch sees a
        # different (but reproducible) class permutation.
        g.manual_seed(self.seed * 10000 + self.epoch)
        idx = torch.randperm(len(self.labels), generator=g).tolist()
        max_indices = np.max([len(v) for v in self.labels_to_indices.values()])
        for i in idx:
            t = self.labels_to_indices[self.labels[i]]
            # NOTE(review): np.random.choice uses the global NumPy RNG, not
            # the seeded generator above; reproducibility relies on the
            # np.random.seed call at the top of the script.
            idx_list.append(np.random.choice(t, size=self.m_per_class * max_indices))
        idx_list = np.stack(idx_list).reshape(len(self.labels), -1, self.m_per_class)
        idx_list = idx_list.transpose(1, 0, 2).reshape(-1).tolist()
        return iter(idx_list)
204 |
205 |
206 | # First, define how many distinct classes to encounter per batch.
207 | # The dataset has 200 classes in total, but the number we choose depends on the available memory.
208 | n_sampled_classes = 128
209 | # Then, define how many examples per class to encounter per batch.
210 | # This number must be at least 2.
211 | m_per_class = 2
212 | # Finally, the batch size is determined by our two choices above.
213 | batch_size = n_sampled_classes * m_per_class
214 |
215 | # Note that this sampler is only used during training.
216 | sampler = UniqueClassSampler(labels=trainset.targets, m_per_class=m_per_class, seed=seed)
217 |
218 | trainloader = DataLoader(
219 | dataset=trainset,
220 | sampler=sampler,
221 | batch_size=batch_size,
222 | pin_memory=True,
223 | drop_last=True,
224 | )
225 |
226 | testloader = DataLoader(
227 | dataset=testset,
228 | batch_size=batch_size,
229 | pin_memory=True,
230 | drop_last=False,
231 | )
232 |
233 | ###################################################
234 | # 4. Define the Hyperbolic Vision Transformer model
235 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
236 |
237 | from typing import Union
238 |
239 | import timm
240 | from torch import nn
241 |
242 | from hypll.tensors import ManifoldTensor, TangentTensor
243 |
244 |
class HyperbolicViT(nn.Module):
    """ViT backbone + linear projection whose output lives on a manifold.

    Follows Ermolov et al., "Hyperbolic Vision Transformers": a (pretrained)
    timm ViT encodes the image, a linear layer projects to ``emb_dim``, and
    the embedding is mapped onto the Poincaré ball (with optional feature
    clipping) or L2-normalized in the Euclidean case.
    """

    def __init__(
        self,
        vit_name: str,
        vit_dim: int,
        emb_dim: int,
        manifold: Union[PoincareBall, Euclidean],
        clip_r: float,
        pretrained: bool = True,
    ):
        super().__init__()
        # We use the timm library to load a (pretrained) ViT backbone model.
        self.vit = timm.create_model(vit_name, pretrained=pretrained)
        # We remove the head of the model, since we won't need it.
        self.remove_head()
        self.fc = nn.Linear(vit_dim, emb_dim)
        # We freeze the linear projection for patch embeddings as in the paper (Section 3.2).
        self.freeze_patch_embed()
        self.manifold = manifold
        self.clip_r = clip_r

    def forward(self, x: torch.Tensor) -> ManifoldTensor:
        # We first encode the images using the ViT backbone.
        x = self.vit(x)
        # Some timm models return (features, ...) tuples; keep the features.
        # isinstance replaces the original ``type(x) == tuple`` check.
        if isinstance(x, tuple):
            x = x[0]
        # Then, we project the features into a lower-dimensional space.
        x = self.fc(x)
        # If we use a Euclidean manifold, normalize onto the unit sphere.
        # nn.functional is used directly so this class does not depend on the
        # ``import torch.nn.functional as F`` that only appears later in the
        # script (the original referenced ``F`` before that import ran).
        if isinstance(self.manifold, Euclidean):
            x = nn.functional.normalize(x, p=2, dim=1)
            return ManifoldTensor(data=x, man_dim=1, manifold=self.manifold)
        # If we use a Poincaré ball model,
        # we (optionally) perform feature clipping (Section 2.4),
        if self.clip_r is not None:
            x = self.clip_features(x)
        # and then map the features to the manifold.
        tangents = TangentTensor(data=x, man_dim=1, manifold=self.manifold)
        return self.manifold.expmap(tangents)

    def clip_features(self, x: torch.Tensor) -> torch.Tensor:
        # Rescale rows whose norm exceeds clip_r; the epsilon avoids
        # division by zero for all-zero features.
        x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
        fac = torch.minimum(torch.ones_like(x_norm), self.clip_r / x_norm)
        return x * fac

    def remove_head(self):
        # Replace any classification-head submodules with identity mappings.
        names = set(x[0] for x in self.vit.named_children())
        target = {"head", "fc", "head_dist", "head_drop"}
        for x in names & target:
            self.vit.add_module(x, nn.Identity())

    def freeze_patch_embed(self):
        # Freeze the patch-embedding projection as in the paper (Section 3.2).
        def fr(m):
            for param in m.parameters():
                param.requires_grad = False

        fr(self.vit.patch_embed)
        fr(self.vit.pos_drop)
304 |
305 |
306 | # Initialize the model using the same hyperparameters as in the paper (Section 3.2).
307 | hvit = HyperbolicViT(
308 | vit_name="vit_small_patch16_224", vit_dim=384, emb_dim=128, manifold=manifold, clip_r=2.3
309 | ).to(device)
310 |
311 | ####################################################
312 | # 5. Define the pairwise cross-entropy loss function
313 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
314 |
315 | import torch.nn.functional as F
316 |
317 | # The function is described in Section 2.2 of the paper.
318 |
319 |
def pairwise_cross_entropy_loss(
    z_0: ManifoldTensor,
    z_1: ManifoldTensor,
    manifold: Union[PoincareBall, Euclidean],
    tau: float,
) -> torch.Tensor:
    """Pairwise cross-entropy loss (Section 2.2 of the paper).

    z_0 and z_1 each hold one embedding per class (two views of the same
    classes); (z_0[i], z_1[i]) is the positive pair and all other rows act
    as negatives. Similarity logits are -distance / tau.
    """
    # Negative manifold distance serves as the similarity logit.
    dist_f = lambda x, y: -manifold.cdist(x.unsqueeze(0), y.unsqueeze(0))[0]
    num_classes = z_0.shape[0]
    # Create targets/mask on the inputs' device instead of relying on the
    # module-level ``device`` global (the original broke when inputs lived
    # on a different device than the global said).
    dev = z_0.tensor.device
    target = torch.arange(num_classes, device=dev)
    # Large value masks self-similarity so a sample cannot match itself.
    eye_mask = torch.eye(num_classes, device=dev) * 1e9
    logits00 = dist_f(z_0, z_0) / tau - eye_mask
    logits01 = dist_f(z_0, z_1) / tau
    logits = torch.cat([logits01, logits00], dim=1)
    # Subtract the row max for numerical stability (cross-entropy invariant).
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    return loss
336 |
337 |
338 | # We use the same temperature as in the paper (Section 3.2).
339 | criterion = lambda z_0, z_1: pairwise_cross_entropy_loss(z_0, z_1, manifold=manifold, tau=0.2)
340 |
341 | ##########################################
342 | # 6. Define the recall@k evaluation metric
343 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
344 |
345 | # Retrieval is performed by computing the distance between the query and all other embeddings.
346 | # The top-k embeddings with the smallest distance are then used to compute the recall@k metric.
347 |
348 |
def eval_recall_k(
    k: int, model: nn.Module, dataloader: DataLoader, manifold: Union[PoincareBall, Euclidean]
):
    """Compute recall@k over ``dataloader`` using manifold distances.

    Every embedding queries all others; the query counts as a hit when its
    class appears among the k nearest neighbours (excluding itself).
    """

    def get_embeddings():
        # Embed the whole dataset with the model in eval mode, then restore
        # train mode.
        embs = []
        model.eval()
        for data in dataloader:
            x = data[0].to(device)
            with torch.no_grad():
                z = model(x)
            embs.append(z)
        embs = manifold.cat(embs, dim=0)
        model.train()
        return embs

    def get_recall_k(embs):
        # Hoist the CPU copies out of the loop: the original re-ran
        # ``manifold.cpu()`` and ``embs.unsqueeze(0).cpu()`` for every row
        # of the distance matrix.
        manifold_cpu = manifold.cpu()
        embs_cpu = embs.cpu()
        all_embs = embs_cpu.unsqueeze(0)
        dist_matrix = torch.zeros(embs.shape[0], embs.shape[0])
        for i in range(embs.shape[0]):
            # Negative distance so that topk yields the nearest neighbours.
            dist_matrix[i] = -manifold_cpu.cdist(embs_cpu[[i]].unsqueeze(0), all_embs)[0]
        dist_matrix = torch.nan_to_num(dist_matrix, nan=-torch.inf)
        targets = np.array(dataloader.dataset.targets)
        # Take the 1+k nearest and drop column 0 (the query itself).
        top_k = targets[dist_matrix.topk(1 + k).indices[:, 1:].numpy()]
        recall_k = np.mean([1 if t in top_k[i] else 0 for i, t in enumerate(targets)])
        return recall_k

    return get_recall_k(get_embeddings())
377 |
378 |
379 | #########################################
380 | # 7. Train the model on the training data
381 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
382 |
383 | import torch.optim as optim
384 |
385 | # We use the AdamW optimizer with the same hyperparameters as in the paper (Section 3.2).
386 | optimizer = optim.AdamW(hvit.parameters(), lr=3e-5, weight_decay=0.01)
387 |
# Training for 5 epochs is enough to achieve good results.
num_epochs = 5
for epoch in range(num_epochs):
    # Advance the sampler's epoch so it yields a fresh class permutation.
    sampler.set_epoch(epoch)
    running_loss = 0.0
    num_samples = 0
    for d_i, data in enumerate(trainloader):
        # Get the inputs; data is a list of [inputs, labels]
        inputs = data[0].to(device)

        # Zero out the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        bs = len(inputs)
        z = hvit(inputs)
        # Regroup the flat batch into (num_classes, m_per_class, emb_dim) so
        # each row holds the m_per_class embeddings of one class.
        z = ManifoldTensor(
            data=z.tensor.reshape((bs // m_per_class, m_per_class, -1)),
            man_dim=-1,
            manifold=z.manifold,
        )
        # The loss is computed pair-wise: every ordered view pair (i, j),
        # i != j, within each class contributes one term.
        loss = 0
        for i in range(m_per_class):
            for j in range(m_per_class):
                if i != j:
                    z_i = z[:, i]
                    z_j = z[:, j]
                    loss += criterion(z_i, z_j)
        loss.backward()
        # Clip gradient norm to 3 to stabilize training.
        torch.nn.utils.clip_grad_norm_(hvit.parameters(), 3)
        optimizer.step()

        # Print statistics every 10 batches.
        running_loss += loss.item() * bs
        num_samples += bs
        if d_i % 10 == 9:
            print(
                f"[Epoch {epoch + 1}, {d_i + 1:5d}/{len(trainloader)}] train loss: {running_loss/num_samples:.3f}"
            )
            running_loss = 0.0
            num_samples = 0

print("Finished training!")
432 |
433 | ####################################
434 | # 8. Test the model on the test data
435 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
436 |
437 | # Finally, let's see how well the model performs on the test data.
438 |
439 | k = 5
440 | recall_k = eval_recall_k(k, hvit, testloader, manifold)
441 | print(f"Test recall@{k}: {recall_k:.3f}")
442 |
443 | # Using the default hyperparameters in this tutorial,
444 | # we are able to achieve a recall@5 of 0.946 for the Poincaré ball model,
445 | # and a recall@5 of 0.916 for the Euclidean model.
446 |
--------------------------------------------------------------------------------
/tutorials/poincare_embeddings_tutorial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Training Poincare embeddings for the mammals subset from WordNet
4 | ================================================================
5 |
6 | This implementation is based on the "Poincare Embeddings for Learning Hierarchical Representations"
7 | paper by Maximilian Nickel and Douwe Kiela, which can be found at:
8 |
9 | - https://arxiv.org/pdf/1705.08039.pdf
10 |
11 | Their implementation can be found at:
12 |
13 | - https://github.com/facebookresearch/poincare-embeddings
14 |
15 | We will perform the following steps in order:
16 |
17 | 1. Load the mammals subset from the WordNet hierarchy using NetworkX
18 | 2. Create a dataset containing the graph from which we can sample
19 | 3. Initialize the Poincare ball on which the embeddings will be trained
20 | 4. Define the Poincare embedding model
21 | 5. Define the Poincare embedding loss function
22 | 6. Perform a few "burn-in" training epochs with reduced learning rate
23 | 7. Train the embeddings with the full learning rate
23 |
24 | """
25 |
26 | ######################################################################
27 | # 1. Load the mammals subset from the WordNet hierarchy using NetworkX
28 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
29 |
30 | import json
31 | import os
32 |
33 | import networkx as nx
34 |
35 | # If you have stored the wordnet_mammals.json file differently than described here,
36 | # adjust the mammals_json_file string to correctly point to this json file. The file itself can
37 | # be found in the repository under tutorials/data/wordnet_mammals.json.
38 | root = os.path.dirname(os.path.abspath(__file__))
39 | mammals_json_file = os.path.join(root, "data", "wordnet_mammals.json")
40 | with open(mammals_json_file, "r") as json_file:
41 | graph_dict = json.load(json_file)
42 |
43 | mammals_graph = nx.node_link_graph(graph_dict)
44 |
45 |
46 | ###################################################################
47 | # 2. Create a dataset containing the graph from which we can sample
48 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
49 |
50 | import random
51 |
52 | import torch
53 | from torch.utils.data import DataLoader, Dataset
54 |
55 |
class MammalsEmbeddingDataset(Dataset):
    """Edge dataset over the mammals graph.

    Each item pairs one true edge with 10 sampled negative (non-)edges and
    returns boolean labels marking which row is the true edge.
    """

    def __init__(
        self,
        mammals: nx.DiGraph,
    ):
        super().__init__()
        self.mammals = mammals
        self.edges_list = list(mammals.edges())

    # This Dataset object has a sample for each edge in the graph.
    def __len__(self) -> int:
        return len(self.edges_list)

    def __getitem__(self, idx: int):
        """Return (edges, targets): an (11, 2) tensor of node-index pairs
        (row 0 is the true edge) and an 11-element boolean label tensor."""
        # For each existing edge in the graph we choose 10 fake or negative edges, which we build
        # from the idx-th existing edge. So, first we grab this edge from the graph.
        rel = self.edges_list[idx]

        # Next, we take our source node rel[0] and see which nodes in the graph are not a child of
        # this node.
        negative_target_nodes = list(
            self.mammals.nodes() - nx.descendants(self.mammals, rel[0]) - {rel[0]}
        )

        # Then, we sample at most 5 of these negative target nodes...
        # NOTE: this type of sampling is a straightforward, but inefficient method. If your dataset
        # is larger, consider using more efficient sampling methods, as this will otherwise
        # form a bottleneck. See (closed) Issue 59 on GitHub for some more information.
        negative_target_sample_size = min(5, len(negative_target_nodes))
        negative_target_nodes_sample = random.sample(
            negative_target_nodes, negative_target_sample_size
        )

        # and add these to a tensor which will be used as input for our embedding model.
        edges = torch.tensor([rel] + [[rel[0], neg] for neg in negative_target_nodes_sample])

        # Next, we do the same with our target node rel[1], but now where we sample from nodes
        # which aren't a parent of it.
        negative_source_nodes = list(
            self.mammals.nodes() - nx.ancestors(self.mammals, rel[1]) - {rel[1]}
        )

        # We sample from these negative source nodes until we have a total of 10 negative edges...
        # NOTE(review): random.sample raises ValueError if fewer candidates
        # exist than requested; fine for this graph, verify for other data.
        negative_source_sample_size = 10 - negative_target_sample_size
        negative_source_nodes_sample = random.sample(
            negative_source_nodes, negative_source_sample_size
        )

        # and add these to the tensor that we created above.
        edges = torch.cat(
            tensors=(edges, torch.tensor([[neg, rel[1]] for neg in negative_source_nodes_sample])),
            dim=0,
        )

        # Lastly, we create a tensor containing the labels of the edges, indicating whether it's a
        # True or a False edge.
        edge_label_targets = torch.cat(tensors=[torch.ones(1).bool(), torch.zeros(10).bool()])

        return edges, edge_label_targets
115 |
116 |
117 | # Now, we construct the dataset.
118 | dataset = MammalsEmbeddingDataset(
119 | mammals=mammals_graph,
120 | )
121 | dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
122 |
123 |
124 | #########################################################################
125 | # 3. Initialize the Poincare ball on which the embeddings will be trained
126 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
127 |
128 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall
129 |
130 | poincare_ball = PoincareBall(Curvature(1.0))
131 |
132 |
133 | ########################################
134 | # 4. Define the Poincare embedding model
135 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
136 |
137 | import hypll.nn as hnn
138 |
139 |
class PoincareEmbedding(hnn.HEmbedding):
    """Hyperbolic embedding table that maps edges to endpoint distances."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        manifold: PoincareBall,
    ):
        super().__init__(num_embeddings, embedding_dim, manifold)

    # The model outputs the distances between the nodes involved in the input edges as these are
    # used to compute the loss.
    def forward(self, edges: torch.Tensor) -> torch.Tensor:
        # edges holds (source, target) node-index pairs — as built by
        # MammalsEmbeddingDataset, shape (batch, 11, 2). Look up both
        # endpoints and return their manifold distance.
        embeddings = super().forward(edges)
        edge_distances = self.manifold.dist(x=embeddings[:, :, 0, :], y=embeddings[:, :, 1, :])
        return edge_distances
155 |
156 |
157 | # We want to embed every node into a 2-dimensional Poincare ball.
158 | model = PoincareEmbedding(
159 | num_embeddings=len(mammals_graph.nodes()),
160 | embedding_dim=2,
161 | manifold=poincare_ball,
162 | )
163 |
164 |
165 | ################################################
166 | # 5. Define the Poincare embedding loss function
167 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
168 |
169 |
170 | # This function is given in equation (5) of the Poincare Embeddings paper.
def poincare_embeddings_loss(dists: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Loss from equation (5) of the Poincare Embeddings paper.

    dists: per-sample distances of one positive and several negative edges.
    targets: boolean mask marking the positive edge(s) along the last dim.
    Returns the mean negative log-probability of the positive edge under a
    softmax over exp(-distance).
    """
    exp_neg_dists = torch.exp(-dists)
    positive_mass = torch.where(targets, exp_neg_dists, torch.zeros_like(exp_neg_dists)).sum(dim=-1)
    total_mass = exp_neg_dists.sum(dim=-1)
    return -torch.log(positive_mass / total_mass).mean()
177 |
178 |
179 | #######################################################################
180 | # 6. Perform a few "burn-in" training epochs with reduced learning rate
181 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
182 |
183 | from hypll.optim import RiemannianSGD
184 |
185 | # The learning rate of 0.3 is divided by 10 during burn-in.
optimizer = RiemannianSGD(
    params=model.parameters(),
    lr=0.3 / 10,
)

# Perform training as we would usually.
for epoch in range(10):
    average_loss = 0.0
    for idx, (edges, edge_label_targets) in enumerate(dataloader):
        optimizer.zero_grad()

        dists = model(edges)
        loss = poincare_embeddings_loss(dists=dists, targets=edge_label_targets)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float rather than the loss tensor: keeping the
        # tensor would retain every batch's autograd graph in memory until
        # the end of the epoch.
        average_loss += loss.item()

    average_loss /= len(dataloader)
    print(f"Burn-in epoch {epoch} loss: {average_loss}")
206 |
207 |
208 | #########################
209 | # 7. Train the embeddings
210 | # ^^^^^^^^^^^^^^^^^^^^^^^
211 |
212 | # Now we use the actual learning rate 0.3.
optimizer = RiemannianSGD(
    params=model.parameters(),
    lr=0.3,
)

for epoch in range(300):
    average_loss = 0.0
    for idx, (edges, edge_label_targets) in enumerate(dataloader):
        optimizer.zero_grad()

        dists = model(edges)
        loss = poincare_embeddings_loss(dists=dists, targets=edge_label_targets)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float rather than the loss tensor: keeping the
        # tensor would retain every batch's autograd graph in memory until
        # the end of the epoch.
        average_loss += loss.item()

    average_loss /= len(dataloader)
    print(f"Epoch {epoch} loss: {average_loss}")
232 |
233 | # You have now trained your own Poincare Embeddings!
234 |
--------------------------------------------------------------------------------