├── .github └── workflows │ ├── check_code_quality.yml │ ├── deploy.yml │ └── run_unit_tests.yml ├── .gitignore ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── README.md ├── make.bat └── source │ ├── api.rst │ ├── conf.py │ ├── index.rst │ └── install.rst ├── hypll ├── __init__.py ├── manifolds │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ └── manifold.py │ ├── euclidean │ │ ├── __init__.py │ │ └── manifold.py │ └── poincare_ball │ │ ├── __init__.py │ │ ├── curvature.py │ │ ├── manifold.py │ │ └── math │ │ ├── __init__.py │ │ ├── diffgeom.py │ │ ├── linalg.py │ │ └── stats.py ├── nn │ ├── __init__.py │ └── modules │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── batchnorm.py │ │ ├── change_manifold.py │ │ ├── container.py │ │ ├── convolution.py │ │ ├── embedding.py │ │ ├── flatten.py │ │ ├── fold.py │ │ ├── linear.py │ │ └── pooling.py ├── optim │ ├── __init__.py │ ├── adam.py │ └── sgd.py ├── tensors │ ├── __init__.py │ ├── manifold_parameter.py │ ├── manifold_tensor.py │ └── tangent_tensor.py └── utils │ ├── __init__.py │ ├── layer_utils.py │ ├── math.py │ └── tensor_utils.py ├── poetry.lock ├── pyproject.toml ├── tests ├── manifolds │ ├── euclidean │ │ └── test_euclidean.py │ └── poincare_ball │ │ ├── test_curvature.py │ │ └── test_poincare_ball.py ├── nn │ ├── test_change_manifold.py │ ├── test_convolution.py │ └── test_flatten.py └── test_manifold_tensor.py └── tutorials ├── README.txt ├── cifar10_resnet_tutorial.py ├── cifar10_tutorial.py ├── data └── wordnet_mammals.json ├── hyperbolic_vit_tutorial.py └── poincare_embeddings_tutorial.py /.github/workflows/check_code_quality.yml: -------------------------------------------------------------------------------- 1 | name: Check Code Quality 2 | 3 | on: 4 | push: 5 | pull_request: 6 | workflow_dispatch: 7 | 8 | jobs: 9 | test: 10 | name: Check Code Quality 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout Code 14 | uses: actions/checkout@v3 15 | - 
name: Set up Python 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: "3.10" 19 | - uses: actions/cache@v2 20 | with: 21 | path: ${{ env.pythonLocation }} 22 | key: cache_v2_${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }} 23 | - name: Install Poetry 24 | run: | 25 | pip install poetry 26 | echo "$HOME/.local/bin" >> $GITHUB_PATH 27 | - name: Install Package 28 | run: poetry install 29 | - name: Check Code Quality 30 | run: | 31 | poetry run black --check . 32 | poetry run isort --check . 33 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to PyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | version-type: 7 | description: 'Version bump type (major, minor, patch)' 8 | required: true 9 | default: 'patch' 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out repository 16 | uses: actions/checkout@v2 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: '3.10' 22 | 23 | - name: Install Poetry 24 | run: | 25 | pip install poetry 26 | echo "$HOME/.local/bin" >> $GITHUB_PATH 27 | 28 | - name: Bump version 29 | run: poetry version ${{ github.event.inputs.version-type }} 30 | 31 | - name: Push changes 32 | run: | 33 | git config --local user.email "action@github.com" 34 | git config --local user.name "GitHub Action" 35 | git commit -am "Increase version [skip ci]" 36 | git push 37 | 38 | - name: Publish to PyPI 39 | run: poetry publish --build --username ${{ secrets.PYPI_USERNAME }} --password ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.github/workflows/run_unit_tests.yml: -------------------------------------------------------------------------------- 1 | name: Run Unit Tests 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 
workflow_dispatch: 7 | 8 | jobs: 9 | test: 10 | name: Run Unit Tests 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout Code 14 | uses: actions/checkout@v3 15 | - name: Set up Python 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: "3.10" 19 | - uses: actions/cache@v2 20 | with: 21 | path: ${{ env.pythonLocation }} 22 | key: cache_v2_${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }} 23 | - name: Install Poetry 24 | run: | 25 | pip install poetry 26 | echo "$HOME/.local/bin" >> $GITHUB_PATH 27 | - name: Install Package 28 | run: poetry install 29 | - name: Run Unit Tests 30 | run: | 31 | poetry run pytest tests/ 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .env/ 3 | *.egg-info/ 4 | .pytest_cache 5 | /data 6 | /tests/data 7 | .vscode 8 | build 9 | /dist 10 | data 11 | docs/build 12 | docs/source/_autosummary 13 | docs/source/tutorials 14 | tests/data 15 | 16 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | python: 12 | install: 13 | - method: pip 14 | path: . 15 | extra_requirements: 16 | - docs 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to HypLL 2 | 3 | Thank you for considering contributing to HypLL! We are always looking for new contributors to help us implement the constantly improving hyperbolic learning methodology and to maintain what is already here, so we are happy to have you! 
4 | 5 | We are looking for all sorts of contributions such as 6 | 12 | 13 | # Getting started 14 | ### Git workflow 15 | We want any new changes to HypLL to be linked to GitHub issues. If you have something new in mind that is not yet mentioned in an issue, please create an issue detailing the intended change first. Once you have an issue in mind that you would like to contribute to, [fork the repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo). 16 | 17 | After you have finished your contribution, [open a pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) on GitHub to have your contribution merged into our `main` branch. Note that your pull request will automatically be checked to see if it matches some of our standards. To avoid running into errors during your pull request, consider performing these tests locally first. A description on how to run the tests is given down below. 18 | 19 | ### Local setup 20 | In order to get started locally with HypLL development, first clone the repository and make sure to install the repository with the correct optional dependencies from the `pyproject.toml` file. For example, if you only need the base development dependencies, navigate to the root directory (so the one containing this file) of your local version of this repository and install using: 21 | ``` 22 | pip install -e .[dev] 23 | ``` 24 | We recommend using a virtual environment tool such as `conda` or `venv`. 25 | 26 | For further instructions, please read the section below that corresponds to the type of contribution that you intend to make. 27 | 28 | 29 | # Contributing to development 30 | **Optional dependencies:** If you are making development contributions, then you will at least need the `dev` optional dependencies. 
If your contribution warrants additional changes to the documentation, then the `docs` optional dependencies will likely also be of use for testing the documentation build process. 31 | 32 | **Formatting:** For formatting we use [black](https://black.readthedocs.io) and [isort](https://pycqa.github.io/isort/). If you are unfamiliar with these, feel free to check out their documentations. However, it should not be a large problem if you are unfamiliar with their style guides, since they automatically format your code for you (in most cases). `black` is a style formatter that ensures uniformity of the coding style throughout the project, which helps with readability. To use `black`, simply run (inside the root directory) 33 | ``` 34 | black . 35 | ``` 36 | `isort` is a utility that automatically sorts and separates imports to also improve readability. To use `isort`, simply run (inside the root directory) 37 | ``` 38 | isort . 39 | ``` 40 | 41 | **Testing:** When your contribution is a simple bugfix it can be sufficient to use the existing tests. However, in most cases you will have to add additional tests to check your new feature or to ensure that whatever bug you are fixing no longer occurs. These tests can be added to the `/tests` directory. We use [pytest](https://docs.pytest.org) for testing, so if you are unfamiliar with this, please check their documentation or use the other tests as an example. If you think your contribution is ready, you can test it by running (inside the root directory) 42 | ``` 43 | pytest 44 | ``` 45 | If you made any changes to the documentation then you will also need to test the build process of the docs. You can read how to do this down below underneath the "Contributing to documentation" header. 
46 | 47 | 48 | # Contributing to documentation 49 | **Optional dependencies:** When making documentation contributions, you will most likely need the `docs` optional dependencies to test the build process of the docs and the `dev` optional dependencies for formatting. 50 | 51 | **Contributing:** Our documentation is built using the documentation generator [sphinx](https://www.sphinx-doc.org). Most of the documentation is generated automatically from the docstrings, so if you want to contribute to the API documentation, a simple change to the docstring of the relevant part of the code should suffice. For more complicated changes, please take a look at the `sphinx` documentation. 52 | 53 | **Formatting:** For formatting new documentation, please use the existing documentation as examples. Once you are happy with your contribution, use `black` as described under the "Contributing to development" header to ensure that the code satisfies our formatting style. 54 | 55 | **Testing:** To test the build process of the documentation, please follow the instructions within the `README` inside the `docs` directory. 56 | 57 | 58 | # Contributing to new tutorials 59 | **Optional dependencies:** When contributing to the tutorials, you will need the `dev` optional dependencies. 60 | 61 | **Contributing:** We have two types of tutorials: 62 |
63 | 1. Basic tutorials for learning about HypLL
64 | 2. Tutorials showcasing how to implement peer-reviewed papers using HypLL
65 |
66 | If your intended contribution falls under the first category, first make sure that there is a need for your specific tutorial. In other words, make sure that there is an issue asking for your tutorial and make sure that its inclusion is already agreed upon. If it falls under the second category, make sure that the paper that you are implementing is a peer-reviewed publication and, again, that there is an issue asking for this implementation. 67 | 68 | **Formatting:** The formatting for tutorials is not very strict, but please use the existing tutorials as a guideline and ensure that the tutorials are self-contained. 69 | 70 | **Testing:** Once you are ready with your tutorial, please test it by ensuring that it runs to completion. 71 | 72 | 73 | # Code review process 74 | Pull requests have to be reviewed by at least one of the HypLL collaborators. 75 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Hyperbolic Learning Library developers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hyperbolic Learning Library 2 | 3 | [![Documentation Status](https://readthedocs.org/projects/hyperbolic-learning-library/badge/?version=latest)](https://hyperbolic-learning-library.readthedocs.io/en/latest/?badge=latest) 4 | ![Unit Tests](https://github.com/maxvanspengler/hyperbolic_pytorch/workflows/Run%20Unit%20Tests/badge.svg) 5 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 | [![isort: checked](https://img.shields.io/badge/isort-checked-yellow)](https://github.com/PyCQA/isort) 7 | 8 | An extension of the PyTorch library containing various tools for performing deep learning in hyperbolic space. 9 | 10 | Contents: 11 | * [Documentation](#documentation) 12 | * [Installation](#installation) 13 | * [BibTeX](#bibtex) 14 | 15 | 16 | ## Documentation 17 | Visit our [documentation](https://hyperbolic-learning-library.readthedocs.io/en/latest/index.html) for tutorials and more. 18 | 19 | 20 | ## Installation 21 | 22 | The Hyperbolic Learning Library was written for Python 3.10+ and PyTorch 1.11+. 23 | 24 | It's recommended to have a 25 | working PyTorch installation before setting up HypLL: 26 | 27 | * [PyTorch](https://pytorch.org/get-started/locally/) installation instructions. 
28 | 29 | Start by setting up a Python [virtual environment](https://docs.python.org/3/library/venv.html): 30 | 31 | ``` 32 | python -m venv .env 33 | ``` 34 | 35 | Activate the virtual environment on Linux and MacOs: 36 | ``` 37 | source .env/bin/activate 38 | ``` 39 | Or on Windows: 40 | ``` 41 | .env/Scripts/activate 42 | ``` 43 | 44 | Finally, install HypLL from PyPI. 45 | 46 | ``` 47 | pip install hypll 48 | ``` 49 | 50 | ## BibTeX 51 | If you would like to cite this project, please use the following bibtex entry 52 | ``` 53 | @article{spengler2023hypll, 54 | title={HypLL: The Hyperbolic Learning Library}, 55 | author={van Spengler, Max and Wirth, Philipp and Mettes, Pascal}, 56 | journal={arXiv preprint arXiv:2306.06154}, 57 | year={2023} 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Build the Docs 2 | `sphinx` provides a Makefile, so to build the `html` documentation, simply type: 3 | ```make html 4 | ``` 5 | 6 | Or, alternatively 7 | ``` 8 | SPHINXBUILD="python3 -m sphinx.cmd.build" make html 9 | ``` 10 | 11 | You can browse the docs by opening the generated html files in the `docs/build` directory 12 | with a browser of your choice. -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | 2 | .. 
autosummary:: 3 | :toctree: _autosummary 4 | :nosignatures: 5 | :recursive: 6 | 7 | hypll -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import os 10 | import sys 11 | 12 | sys.path.insert(0, os.path.abspath("../..")) # Source code dir relative to this file 13 | 14 | project = "hypll" 15 | copyright = '2023, ""' 16 | author = '""' 17 | 18 | # -- General configuration --------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 20 | 21 | extensions = [ 22 | "sphinx_copybutton", 23 | "sphinx.ext.autodoc", # Core library for html generation from docstrings 24 | "sphinx.ext.autosummary", 25 | "sphinx.ext.napoleon", 26 | "sphinx_gallery.gen_gallery", 27 | "sphinx_tabs.tabs", 28 | ] 29 | 30 | templates_path = ["_templates"] 31 | exclude_patterns = [] 32 | 33 | # Autodoc options: 34 | autodoc_class_signature = "separated" 35 | autodoc_default_options = { 36 | "members": True, 37 | "member-order": "bysource", 38 | "special-members": "__init__", 39 | "undoc-members": True, 40 | "exclude-members": "__weakref__", 41 | } 42 | autodoc_default_flags = [ 43 | "members", 44 | "undoc-members", 45 | "special-members", 46 | "show-inheritance", 47 | ] 48 | 49 | # Napoleon options: 50 | napoleon_google_docstring = True 51 | napoleon_numpy_docstring = False 52 | napoleon_include_init_with_doc = False 53 | napoleon_include_private_with_doc = False 54 | 
napoleon_include_special_with_doc = False 55 | napoleon_use_admonition_for_examples = False 56 | napoleon_use_admonition_for_notes = False 57 | napoleon_use_admonition_for_references = False 58 | napoleon_use_ivar = False 59 | napoleon_use_param = False 60 | napoleon_use_rtype = False 61 | napoleon_type_aliases = None 62 | 63 | # Sphinx gallery config: 64 | sphinx_gallery_conf = { 65 | "examples_dirs": "../../tutorials", 66 | "gallery_dirs": "tutorials/", 67 | "filename_pattern": "", 68 | # TODO(Philipp, 06/23): Figure out how we can build and host tutorials on RTD. 69 | "plot_gallery": "False", 70 | } 71 | 72 | # -- Options for HTML output ------------------------------------------------- 73 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 74 | 75 | html_theme = "alabaster" 76 | html_static_path = ["_static"] 77 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | 2 | Hyperbolic Learning Library 3 | =========================== 4 | 5 | The Hyperbolic Learning Library (HypLL) was designed to bring 6 | the progress on hyperbolic deep learning together. HypLL is built on 7 | top of PyTorch, with an emphasis on ease-of-use. 8 | 9 | All code is open-source. Visit us on `GitHub `_ 10 | to see open issues and share your feedback! 11 | 12 | .. toctree:: 13 | :caption: Getting Started 14 | :maxdepth: 1 15 | 16 | install.rst 17 | 18 | 19 | .. toctree:: 20 | :caption: Tutorials 21 | :name: tutorials 22 | :maxdepth: 1 23 | 24 | tutorials/cifar10_tutorial.rst 25 | tutorials/cifar10_resnet_tutorial.rst 26 | tutorials/poincare_embeddings_tutorial.rst 27 | 28 | 29 | .. 
toctree:: 30 | :caption: API 31 | :maxdepth: 2 32 | 33 | api.rst 34 | 35 | 36 | 37 | Indices and tables 38 | ------------------ 39 | 40 | * :ref:`genindex` 41 | * :ref:`modindex` 42 | * :ref:`search` 43 | -------------------------------------------------------------------------------- /docs/source/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | HypLL was written for Python 3.10+ and PyTorch 1.11+. 5 | 6 | It's recommended to have a 7 | working PyTorch installation before setting up HypLL: 8 | 9 | * `PyTorch `_ installation instructions. 10 | 11 | 12 | Install with pip 13 | ---------------- 14 | 15 | Start by setting up a Python `virtual environment `_: 16 | 17 | .. code:: 18 | 19 | python -m venv .env 20 | 21 | Activate the virtual environment. 22 | 23 | .. tabs:: 24 | 25 | .. tab:: Linux and MacOs 26 | 27 | .. code:: 28 | 29 | source .env/bin/activate 30 | 31 | .. tab:: Windows 32 | 33 | .. code:: 34 | 35 | .env/Scripts/activate 36 | 37 | 38 | Install HypLL from PyPI. 39 | 40 | .. code:: 41 | 42 | pip install hypll 43 | 44 | 45 | Congratulations, you are ready to do machine learning in hyperbolic space. 46 | Check out our :ref:`tutorials` next to get started! 
47 | 48 | -------------------------------------------------------------------------------- /hypll/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxvanspengler/hyperbolic_learning_library/62ea7882848d878459eba4100aeb6b8ad64cc38f/hypll/__init__.py -------------------------------------------------------------------------------- /hypll/manifolds/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Manifold 2 | -------------------------------------------------------------------------------- /hypll/manifolds/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .manifold import Manifold 2 | -------------------------------------------------------------------------------- /hypll/manifolds/base/manifold.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import TYPE_CHECKING, List, Optional, Tuple, Union 5 | 6 | from torch import Tensor 7 | from torch.nn import Module, Parameter 8 | from torch.nn.common_types import _size_2_t 9 | 10 | # TODO: find a less hacky solution for this 11 | if TYPE_CHECKING: 12 | from hypll.tensors import ManifoldParameter, ManifoldTensor, TangentTensor 13 | 14 | 15 | class Manifold(Module, ABC): 16 | def __init__(self) -> None: 17 | super(Manifold, self).__init__() 18 | 19 | @abstractmethod 20 | def project(self, x: ManifoldTensor, eps: float = -1.0) -> ManifoldTensor: 21 | raise NotImplementedError 22 | 23 | @abstractmethod 24 | def expmap(self, v: TangentTensor) -> ManifoldTensor: 25 | raise NotImplementedError 26 | 27 | @abstractmethod 28 | def logmap(self, x: Optional[ManifoldTensor], y: ManifoldTensor) -> TangentTensor: 29 | raise NotImplementedError 30 | 31 | @abstractmethod 32 | def transp(self, v: TangentTensor, y: 
ManifoldTensor) -> TangentTensor: 33 | raise NotImplementedError 34 | 35 | @abstractmethod 36 | def dist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor: 37 | raise NotImplementedError 38 | 39 | @abstractmethod 40 | def cdist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor: 41 | raise NotImplementedError 42 | 43 | @abstractmethod 44 | def euc_to_tangent(self, x: ManifoldTensor, u: ManifoldTensor) -> TangentTensor: 45 | raise NotImplementedError 46 | 47 | @abstractmethod 48 | def hyperplane_dists(self, x: ManifoldTensor, z: ManifoldTensor, r: Optional[Tensor]) -> Tensor: 49 | raise NotImplementedError 50 | 51 | @abstractmethod 52 | def fully_connected( 53 | self, x: ManifoldTensor, z: ManifoldTensor, bias: Optional[Tensor] 54 | ) -> ManifoldTensor: 55 | raise NotImplementedError 56 | 57 | @abstractmethod 58 | def frechet_mean( 59 | self, 60 | x: ManifoldTensor, 61 | batch_dim: Union[int, list[int]] = 0, 62 | keepdim: bool = False, 63 | ) -> ManifoldTensor: 64 | raise NotImplementedError 65 | 66 | @abstractmethod 67 | def midpoint( 68 | self, 69 | x: ManifoldTensor, 70 | batch_dim: Union[int, list[int]] = 0, 71 | w: Optional[Tensor] = None, 72 | keepdim: bool = False, 73 | ) -> ManifoldTensor: 74 | raise NotImplementedError 75 | 76 | @abstractmethod 77 | def frechet_variance( 78 | self, 79 | x: ManifoldTensor, 80 | mu: Optional[ManifoldTensor] = None, 81 | batch_dim: Union[int, list[int]] = -1, 82 | keepdim: bool = False, 83 | ) -> Tensor: 84 | raise NotImplementedError 85 | 86 | @abstractmethod 87 | def construct_dl_parameters( 88 | self, in_features: int, out_features: int, bias: bool = True 89 | ) -> Union[ManifoldParameter, tuple[ManifoldParameter, Parameter]]: 90 | # TODO: make an annotation object for the return type of this method 91 | raise NotImplementedError 92 | 93 | @abstractmethod 94 | def reset_parameters(self, weight: ManifoldParameter, bias: Parameter) -> None: 95 | raise NotImplementedError 96 | 97 | @abstractmethod 98 | def flatten(self, 
x: ManifoldTensor, start_dim: int = 1, end_dim: int = -1) -> ManifoldTensor: 99 | raise NotImplementedError 100 | 101 | @abstractmethod 102 | def unfold( 103 | self, 104 | input: ManifoldTensor, 105 | kernel_size: _size_2_t, 106 | dilation: _size_2_t = 1, 107 | padding: _size_2_t = 0, 108 | stride: _size_2_t = 1, 109 | ) -> ManifoldTensor: 110 | raise NotImplementedError 111 | 112 | @abstractmethod 113 | def cat( 114 | self, 115 | manifold_tensors: Union[Tuple[ManifoldTensor, ...], List[ManifoldTensor]], 116 | dim: int = 0, 117 | ) -> ManifoldTensor: 118 | raise NotImplementedError 119 | -------------------------------------------------------------------------------- /hypll/manifolds/euclidean/__init__.py: -------------------------------------------------------------------------------- 1 | from .manifold import Euclidean 2 | -------------------------------------------------------------------------------- /hypll/manifolds/euclidean/manifold.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | from torch import Tensor, broadcast_shapes, empty, matmul, var 6 | from torch.nn import Parameter 7 | from torch.nn.common_types import _size_2_t 8 | from torch.nn.functional import unfold 9 | from torch.nn.init import _calculate_fan_in_and_fan_out, kaiming_uniform_, uniform_ 10 | 11 | from hypll.manifolds.base import Manifold 12 | from hypll.tensors import ManifoldParameter, ManifoldTensor, TangentTensor 13 | from hypll.utils.tensor_utils import ( 14 | check_dims_with_broadcasting, 15 | check_if_man_dims_match, 16 | check_tangent_tensor_positions, 17 | ) 18 | 19 | 20 | class Euclidean(Manifold): 21 | def __init__(self): 22 | super(Euclidean, self).__init__() 23 | 24 | def project(self, x: ManifoldTensor, eps: float = -1.0) -> ManifoldTensor: 25 | return x 26 | 27 | def expmap(self, v: TangentTensor) -> ManifoldTensor: 28 | if v.manifold_points 
is None:
            new_tensor = v.tensor
        else:
            new_tensor = v.manifold_points.tensor + v.tensor
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=v.broadcasted_man_dim)

    def logmap(self, x: Optional[ManifoldTensor], y: ManifoldTensor) -> TangentTensor:
        """Logarithmic map of y at x; in Euclidean space simply y - x.

        When x is None the map is taken at the origin, so the tangent vector is
        y itself.
        """
        if x is None:
            dim = y.man_dim
            new_tensor = y.tensor
        else:
            dim = check_dims_with_broadcasting(x, y)
            new_tensor = y.tensor - x.tensor
        return TangentTensor(data=new_tensor, manifold_points=x, manifold=self, man_dim=dim)

    def transp(self, v: TangentTensor, y: ManifoldTensor) -> TangentTensor:
        """Parallel transport of v to the tangent space at y.

        Euclidean transport is the identity; the tensor is only broadcast to the
        combined shape of v and y.
        """
        dim = check_dims_with_broadcasting(v, y)
        output_shape = broadcast_shapes(v.size(), y.size())
        new_tensor = v.tensor.broadcast_to(output_shape)
        return TangentTensor(data=new_tensor, manifold_points=y, manifold=self, man_dim=dim)

    def dist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
        """Distance between x and y: the L2 norm of y - x along the manifold dim."""
        dim = check_dims_with_broadcasting(x, y)
        return (y.tensor - x.tensor).norm(dim=dim)

    def inner(
        self, u: TangentTensor, v: TangentTensor, keepdim: bool = False, safe_mode: bool = True
    ) -> Tensor:
        """Inner product of tangent vectors u and v (the ordinary dot product).

        When safe_mode is set, first checks that u and v are positioned at
        compatible manifold points.
        """
        dim = check_dims_with_broadcasting(u, v)
        if safe_mode:
            check_tangent_tensor_positions(u, v)

        return (u.tensor * v.tensor).sum(dim=dim, keepdim=keepdim)

    def euc_to_tangent(self, x: ManifoldTensor, u: ManifoldTensor) -> TangentTensor:
        """Converts a Euclidean gradient u at x into a tangent vector.

        On the Euclidean manifold this is the identity; only the wrapper type
        changes.
        """
        dim = check_dims_with_broadcasting(x, u)
        return TangentTensor(
            data=u.tensor,
            manifold_points=x,
            manifold=self,
            man_dim=dim,
        )

    def hyperplane_dists(self, x: ManifoldTensor, z: ManifoldTensor, r: Optional[Tensor]) -> Tensor:
        """Signed distances of inputs x to hyperplanes with orientations z and offsets r.

        In Euclidean space this reduces to a matrix product (plus the offset r,
        when given).
        """
        if x.man_dim != 1 or z.man_dim != 0:
            raise ValueError(
                f"Expected the manifold dimension of the inputs to be 1 and the manifold "
                f"dimension of the hyperplane orientations to be 0, but got {x.man_dim} and "
                f"{z.man_dim}, respectively"
            )
        if r is None:
            return matmul(x.tensor, z.tensor)
        else:
            return matmul(x.tensor, z.tensor) + r

    def fully_connected(
        self, x: ManifoldTensor, z: ManifoldTensor, bias: Optional[Tensor]
    ) -> ManifoldTensor:
        """Applies a linear layer (x @ z + bias) along the manifold dimension of x."""
        if z.man_dim != 0:
            raise ValueError(
                f"Expected the manifold dimension of the hyperplane orientations to be 0, but got "
                f"{z.man_dim} instead"
            )

        # Move the manifold dim last so matmul contracts over it, then move it back.
        dim_shifted_x_tensor = x.tensor.movedim(source=x.man_dim, destination=-1)
        dim_shifted_new_tensor = matmul(dim_shifted_x_tensor, z.tensor)
        if bias is not None:
            dim_shifted_new_tensor = dim_shifted_new_tensor + bias
        new_tensor = dim_shifted_new_tensor.movedim(source=-1, destination=x.man_dim)

        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=x.man_dim)

    def frechet_mean(
        self,
        x: ManifoldTensor,
        batch_dim: Union[int, list[int]] = 0,
        keepdim: bool = False,
    ) -> ManifoldTensor:
        """Frechet mean of x over batch_dim; in Euclidean space the arithmetic mean."""
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]

        if x.man_dim in batch_dim:
            raise ValueError(
                f"Tried to aggregate over dimensions {batch_dim}, but input has manifold "
                f"dimension {x.man_dim} and cannot aggregate over this dimension"
            )

        # Output manifold dimension is shifted left for each batch dim that disappears
        man_dim_shift = sum(bd < x.man_dim for bd in batch_dim)
        new_man_dim = x.man_dim - man_dim_shift if not keepdim else x.man_dim

        mean_tensor = x.tensor.mean(dim=batch_dim, keepdim=keepdim)

        return ManifoldTensor(data=mean_tensor, manifold=self, man_dim=new_man_dim)

    def midpoint(
        self,
        x: ManifoldTensor,
        batch_dim: Union[int, list[int]] = 0,
        w: Optional[Tensor] = None,
        keepdim: bool = False,
    ) -> ManifoldTensor:
        """Midpoint of x over batch_dim.

        NOTE(review): the weights w are currently ignored and the unweighted
        Frechet mean is returned — confirm whether weighted midpoints should be
        supported here, as the Poincare ball implementation does use w.
        """
        return self.frechet_mean(x=x, batch_dim=batch_dim, keepdim=keepdim)

    def frechet_variance(
        self,
        x: ManifoldTensor,
        mu: Optional[ManifoldTensor] = None,
        batch_dim: Union[int, list[int]] = -1,
        keepdim: bool = False,
    ) -> Tensor:
        """Sample variance of x around mu (or around the mean of x when mu is None).

        Uses Bessel's correction (division by n - 1) in the explicit-mu branch.
        """
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]

        if mu is None:
            return var(x.tensor, dim=batch_dim, keepdim=keepdim)
        else:
            if x.dim() != mu.dim():
                # NOTE(review): this mutates the caller's mu (both man_dim and
                # tensor) in place while aligning its rank with x — confirm the
                # side effect is intended.
                for bd in sorted(batch_dim):
                    mu.man_dim += 1 if bd <= mu.man_dim else 0
                    mu.tensor = mu.tensor.unsqueeze(bd)
            if mu.man_dim != x.man_dim:
                raise ValueError("Input tensor and mean do not have matching manifold dimensions")
            # n is the number of samples aggregated over all batch dims.
            n = 1
            for bd in batch_dim:
                n *= x.size(dim=bd)
            return (x.tensor - mu.tensor).pow(2).sum(dim=batch_dim, keepdim=keepdim) / (n - 1)

    def construct_dl_parameters(
        self, in_features: int, out_features: int, bias: bool = True
    ) -> tuple[ManifoldParameter, Optional[Parameter]]:
        """Creates the (weight, bias) parameters for a Euclidean linear layer."""
        weight = ManifoldParameter(
            data=empty(in_features, out_features),
            manifold=self,
            man_dim=0,
        )

        if bias:
            b = Parameter(data=empty(out_features))
        else:
            b = None

        return weight, b

    def reset_parameters(self, weight: ManifoldParameter, bias: Parameter) -> None:
        """Initializes weight (Kaiming uniform) and bias (uniform, fan-in scaled)."""
        kaiming_uniform_(weight.tensor, a=sqrt(5))
        if bias is not None:
            fan_in, _ = _calculate_fan_in_and_fan_out(weight.tensor)
            bound = 1 / sqrt(fan_in) if fan_in > 0 else 0
            uniform_(bias, -bound, bound)

    def unfold(
        self,
        input: ManifoldTensor,
        kernel_size: _size_2_t,
        dilation: _size_2_t = 1,
        padding: _size_2_t = 0,
        stride: _size_2_t = 1,
    ) -> ManifoldTensor:
        """Extracts sliding blocks (im2col); wraps torch.nn.functional.unfold.

        The result carries its manifold dimension on dim 1, matching unfold's
        output layout.
        """
        new_tensor = unfold(
            input=input.tensor,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )
        return ManifoldTensor(data=new_tensor, manifold=input.manifold, man_dim=1)

    def flatten(self, x: ManifoldTensor, start_dim: int = 1, end_dim: int = -1) -> ManifoldTensor:
        """Flattens a manifold tensor by reshaping it. If start_dim or end_dim are passed,
        only dimensions starting with start_dim and ending with end_dim are flattened.

        Updates the manifold dimension if necessary.

        """
        start_dim = x.dim() + start_dim if start_dim < 0 else start_dim
        end_dim = x.dim() + end_dim if end_dim < 0 else end_dim

        # Get the range of dimensions to flatten.
        dimensions_to_flatten = x.shape[start_dim + 1 : end_dim + 1]

        # Get the new manifold dimension.
        if start_dim <= x.man_dim and end_dim >= x.man_dim:
            man_dim = start_dim
        elif end_dim <= x.man_dim:
            man_dim = x.man_dim - len(dimensions_to_flatten)
        else:
            man_dim = x.man_dim

        # Flatten the tensor and return the new instance.
        flattened = torch.flatten(
            input=x.tensor,
            start_dim=start_dim,
            end_dim=end_dim,
        )
        return ManifoldTensor(data=flattened, manifold=x.manifold, man_dim=man_dim)

    def cdist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
        """Pairwise Euclidean distances between the points of x and y."""
        return torch.cdist(x.tensor, y.tensor)

    def cat(
        self,
        manifold_tensors: Union[Tuple[ManifoldTensor, ...], List[ManifoldTensor]],
        dim: int = 0,
    ) -> ManifoldTensor:
        """Concatenates manifold tensors along dim; all inputs must share a man_dim."""
        check_if_man_dims_match(manifold_tensors)
        cat = torch.cat([t.tensor for t in manifold_tensors], dim=dim)
        man_dim = manifold_tensors[0].man_dim
        return ManifoldTensor(data=cat, manifold=self, man_dim=man_dim)


# ------------------------------------------------------------------------------
# /hypll/manifolds/poincare_ball/__init__.py:
# ------------------------------------------------------------------------------
from .curvature import Curvature
from .manifold import PoincareBall

# ------------------------------------------------------------------------------
# /hypll/manifolds/poincare_ball/curvature.py:
# ------------------------------------------------------------------------------
from typing import Callable

import torch
from torch import Tensor, as_tensor
from torch.nn import Module
from torch.nn.functional import softplus
from torch.nn.parameter import Parameter


class Curvature(Module):
    """Class representing curvature of a manifold.

    Attributes:
        value:
            Learnable parameter indicating curvature of the manifold. The actual
            curvature is calculated as constraining_strategy(value).
        constraining_strategy:
            Function applied to the curvature value in order to constrain the
            curvature of the manifold. By default uses softplus to guarantee
            positive curvature.
        requires_grad:
            If the curvature requires gradient. False by default.

    """

    def __init__(
        self,
        value: float = 1.0,
        constraining_strategy: Callable[[Tensor], Tensor] = softplus,
        requires_grad: bool = False,
    ):
        super().__init__()
        self.value = Parameter(
            data=as_tensor(value),
            requires_grad=requires_grad,
        )
        self.constraining_strategy = constraining_strategy

    def forward(self) -> Tensor:
        """Returns curvature calculated as constraining_strategy(value)."""
        return self.constraining_strategy(self.value)


# ------------------------------------------------------------------------------
# /hypll/manifolds/poincare_ball/manifold.py:
# ------------------------------------------------------------------------------
import functools
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor, empty, eye, no_grad
from torch.nn import Parameter
from torch.nn.common_types import _size_2_t
from torch.nn.functional import softplus, unfold
from torch.nn.init import normal_, zeros_

from hypll.manifolds.base import Manifold
from hypll.manifolds.euclidean import Euclidean
from hypll.manifolds.poincare_ball.curvature import Curvature
from hypll.tensors import ManifoldParameter, ManifoldTensor, TangentTensor
from hypll.utils.math import beta_func
from hypll.utils.tensor_utils import (
    check_dims_with_broadcasting,
    check_if_man_dims_match,
    check_tangent_tensor_positions,
)

from .math.diffgeom import (
    cdist,
    dist,
    euc_to_tangent,
    expmap,
    expmap0,
    gyration,
    inner,
    logmap,
    logmap0,
    mobius_add,
    project,
    transp,
)
from .math.linalg import poincare_fully_connected, poincare_hyperplane_dists
from .math.stats import frechet_mean, frechet_variance, midpoint


class PoincareBall(Manifold):
    """Class representing the Poincare ball model of hyperbolic space.

    Implementation based on the geoopt implementation, but changed to use
    hyperbolic torch functions.

    Attributes:
        c:
            Curvature of the manifold.

    """

    def __init__(self, c: Curvature):
        """Initializes an instance of PoincareBall manifold.

        Examples:
            >>> from hypll.manifolds.poincare_ball import PoincareBall, Curvature
            >>> curvature = Curvature(value=1.0)
            >>> manifold = PoincareBall(c=curvature)

        """
        super().__init__()
        self.c = c

    def mobius_add(self, x: ManifoldTensor, y: ManifoldTensor) -> ManifoldTensor:
        """Mobius addition of x and y along their (broadcast) manifold dimension."""
        dim = check_dims_with_broadcasting(x, y)
        new_tensor = mobius_add(x=x.tensor, y=y.tensor, c=self.c(), dim=dim)
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=dim)

    def project(self, x: ManifoldTensor, eps: float = -1.0) -> ManifoldTensor:
        """Projects x back onto the open ball; eps < 0 selects a dtype-based margin."""
        new_tensor = project(x=x.tensor, c=self.c(), dim=x.man_dim, eps=eps)
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=x.man_dim)

    def expmap(self, v: TangentTensor) -> ManifoldTensor:
        """Exponential map of v; taken at the origin when v has no manifold points."""
        dim = v.broadcasted_man_dim
        if v.manifold_points is None:
            new_tensor = expmap0(v=v.tensor, c=self.c(), dim=dim)
        else:
            new_tensor = expmap(x=v.manifold_points.tensor, v=v.tensor, c=self.c(), dim=dim)
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=dim)

    def logmap(self, x: Optional[ManifoldTensor], y: ManifoldTensor) -> TangentTensor:
        """Logarithmic map of y at x; taken at the origin when x is None."""
        if x is None:
            dim = y.man_dim
            new_tensor = logmap0(y=y.tensor, c=self.c(), dim=y.man_dim)
        else:
            dim = check_dims_with_broadcasting(x, y)
            new_tensor = logmap(x=x.tensor, y=y.tensor, c=self.c(), dim=dim)
        return TangentTensor(data=new_tensor, manifold_points=x, manifold=self, man_dim=dim)

    def gyration(self, u: ManifoldTensor, v: ManifoldTensor, w: ManifoldTensor) -> ManifoldTensor:
        """Gyration gyr[u, v]w on the Poincare ball."""
        dim = check_dims_with_broadcasting(u, v, w)
        new_tensor = gyration(u=u.tensor, v=v.tensor, w=w.tensor, c=self.c(), dim=dim)
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=dim)

    def transp(self, v: TangentTensor, y: ManifoldTensor) -> TangentTensor:
        """Parallel transport of v from its manifold points to y."""
        dim = check_dims_with_broadcasting(v, y)
        tangent_vectors = transp(
            x=v.manifold_points.tensor, y=y.tensor, v=v.tensor, c=self.c(), dim=dim
        )
        return TangentTensor(
            data=tangent_vectors,
            manifold_points=y,
            manifold=self,
            man_dim=dim,
        )

    def dist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
        """Geodesic distance between x and y."""
        dim = check_dims_with_broadcasting(x, y)
        return dist(x=x.tensor, y=y.tensor, c=self.c(), dim=dim)

    def inner(
        self, u: TangentTensor, v: TangentTensor, keepdim: bool = False, safe_mode: bool = True
    ) -> Tensor:
        """Riemannian inner product of tangent vectors u and v at shared points."""
        dim = check_dims_with_broadcasting(u, v)
        if safe_mode:
            check_tangent_tensor_positions(u, v)

        return inner(
            x=u.manifold_points.tensor, u=u.tensor, v=v.tensor, c=self.c(), dim=dim, keepdim=keepdim
        )

    def euc_to_tangent(self, x: ManifoldTensor, u: ManifoldTensor) -> TangentTensor:
        """Converts a Euclidean gradient u at x into a Riemannian tangent vector."""
        dim = check_dims_with_broadcasting(x, u)
        # Use the broadcast dim for the math call as well; previously this passed
        # x.man_dim, which disagrees with the man_dim of the returned tangent
        # tensor when x is broadcast against u.
        tangent_vectors = euc_to_tangent(x=x.tensor, u=u.tensor, c=self.c(), dim=dim)
        return TangentTensor(
            data=tangent_vectors,
            manifold_points=x,
            manifold=self,
            man_dim=dim,
        )

    def hyperplane_dists(self, x: ManifoldTensor, z: ManifoldTensor, r: Optional[Tensor]) -> Tensor:
        """Signed hyperbolic distances of inputs x to hyperplanes (orientations z, offsets r)."""
        if x.man_dim != 1 or z.man_dim != 0:
            raise ValueError(
                f"Expected the manifold dimension of the inputs to be 1 and the manifold "
                f"dimension of the hyperplane orientations to be 0, but got {x.man_dim} and "
                f"{z.man_dim}, respectively"
            )
        return poincare_hyperplane_dists(x=x.tensor, z=z.tensor, r=r, c=self.c())

    def fully_connected(
        self, x: ManifoldTensor, z: ManifoldTensor, bias: Optional[Tensor]
    ) -> ManifoldTensor:
        """Poincare fully connected layer; the result is projected back onto the ball."""
        if z.man_dim != 0:
            raise ValueError(
                f"Expected the manifold dimension of the hyperplane orientations to be 0, but got "
                f"{z.man_dim} instead"
            )
        new_tensor = poincare_fully_connected(
            x=x.tensor, z=z.tensor, bias=bias, c=self.c(), dim=x.man_dim
        )
        new_tensor = ManifoldTensor(data=new_tensor, manifold=self, man_dim=x.man_dim)
        return self.project(new_tensor)

    def frechet_mean(
        self,
        x: ManifoldTensor,
        batch_dim: Union[int, list[int]] = 0,
        keepdim: bool = False,
    ) -> ManifoldTensor:
        """Frechet mean of x over batch_dim."""
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]

        if x.man_dim in batch_dim:
            raise ValueError(
                f"Tried to aggregate over dimensions {batch_dim}, but input has manifold "
                f"dimension {x.man_dim} and cannot aggregate over this dimension"
            )

        # With keepdim the aggregated dims are retained by the solver, so the
        # manifold dim keeps its position; otherwise it shifts left for each
        # batch dim that disappears before it.
        man_dim_shift = sum(bd < x.man_dim for bd in batch_dim)
        output_man_dim = x.man_dim if keepdim else x.man_dim - man_dim_shift

        new_tensor = frechet_mean(
            x=x.tensor, c=self.c(), vec_dim=x.man_dim, batch_dim=batch_dim, keepdim=keepdim
        )
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=output_man_dim)

    def midpoint(
        self,
        x: ManifoldTensor,
        batch_dim: Union[int, list[int]] = 0,
        w: Optional[Tensor] = None,
        keepdim: bool = False,
    ) -> ManifoldTensor:
        """(Weighted) gyromidpoint of x over batch_dim."""
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]

        if x.man_dim in batch_dim:
            raise ValueError(
                f"Tried to aggregate over dimensions {batch_dim}, but input has manifold "
                f"dimension {x.man_dim} and cannot aggregate over this dimension"
            )

        # Output manifold dimension is shifted left for each batch dim that disappears
        man_dim_shift = sum(bd < x.man_dim for bd in batch_dim)
        new_man_dim = x.man_dim - man_dim_shift if not keepdim else x.man_dim

        new_tensor = midpoint(
            x=x.tensor, c=self.c(), man_dim=x.man_dim, batch_dim=batch_dim, w=w, keepdim=keepdim
        )
        return ManifoldTensor(data=new_tensor, manifold=self, man_dim=new_man_dim)

    def frechet_variance(
        self,
        x: ManifoldTensor,
        mu: Optional[ManifoldTensor] = None,
        batch_dim: Union[int, list[int]] = -1,
        keepdim: bool = False,
    ) -> Tensor:
        """Frechet variance of x around mu (computed internally when mu is None)."""
        if mu is not None:
            mu = mu.tensor

        # TODO: Check if x and mu have compatible man_dims
        return frechet_variance(
            x=x.tensor,
            c=self.c(),
            mu=mu,
            vec_dim=x.man_dim,
            batch_dim=batch_dim,
            keepdim=keepdim,
        )

    def construct_dl_parameters(
        self, in_features: int, out_features: int, bias: bool = True
    ) -> tuple[ManifoldParameter, Optional[Parameter]]:
        """Creates the (weight, bias) parameters for a Poincare linear layer.

        The weight lives on a Euclidean manifold because the Poincare FC layer
        consumes Euclidean hyperplane orientations.
        """
        weight = ManifoldParameter(
            data=empty(in_features, out_features),
            manifold=Euclidean(),
            man_dim=0,
        )

        if bias:
            b = Parameter(data=empty(out_features))
        else:
            b = None

        return weight, b

    def reset_parameters(self, weight: ManifoldParameter, bias: Optional[Parameter]) -> None:
        """Initializes the layer parameters (identity-like or normal weight, zero bias)."""
        in_features, out_features = weight.size()
        if in_features <= out_features:
            with no_grad():
                weight.tensor.copy_(1 / 2 * eye(in_features, out_features))
        else:
            normal_(
                weight.tensor,
                mean=0,
                std=(2 * in_features * out_features) ** -0.5,
            )
        if bias is not None:
            zeros_(bias)

    def unfold(
        self,
        input: ManifoldTensor,
        kernel_size: _size_2_t,
        dilation: _size_2_t = 1,
        padding: _size_2_t = 0,
        stride: _size_2_t = 1,
    ) -> ManifoldTensor:
        """Hyperbolic im2col: beta-rescaled unfolding of the input feature map.

        The input is mapped to the tangent space at the origin, rescaled by the
        beta-concatenation ratio, unfolded, and mapped back onto the ball.
        """
        # TODO: may have to cache some of this stuff for efficiency.
        in_channels = input.size(1)
        if isinstance(kernel_size, int):
            # A scalar kernel size describes a square kernel. (Checked by type:
            # the previous len(kernel_size) test raised TypeError for ints, so
            # this branch was unreachable.)
            kernel_vol = kernel_size**2
            kernel_size = (kernel_size, kernel_size)
        else:
            kernel_vol = kernel_size[0] * kernel_size[1]

        beta_ni = beta_func(in_channels / 2, 1 / 2)
        beta_n = beta_func(in_channels * kernel_vol / 2, 1 / 2)

        input = self.logmap(x=None, y=input)
        input.tensor = input.tensor * beta_n / beta_ni
        new_tensor = unfold(
            input=input.tensor,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )

        new_tensor = TangentTensor(data=new_tensor, manifold_points=None, manifold=self, man_dim=1)
        return self.expmap(new_tensor)

    def flatten(self, x: ManifoldTensor, start_dim: int = 1, end_dim: int = -1) -> ManifoldTensor:
        """Flattens a manifold tensor by reshaping it. If start_dim or end_dim are passed,
        only dimensions starting with start_dim and ending with end_dim are flattened.

        If the manifold dimension of the input tensor is among the dimensions which
        are flattened, applies beta-concatenation to the points on the manifold.
        Otherwise simply flattens the tensor using torch.flatten.

        Updates the manifold dimension if necessary.

        """
        start_dim = x.dim() + start_dim if start_dim < 0 else start_dim
        end_dim = x.dim() + end_dim if end_dim < 0 else end_dim

        # Get the range of dimensions to flatten.
        dimensions_to_flatten = x.shape[start_dim + 1 : end_dim + 1]

        if start_dim <= x.man_dim and end_dim >= x.man_dim:
            # Use beta concatenation to flatten the manifold dimension of the tensor.
            #
            # Start by applying logmap at the origin and computing the betas.
            tangents = self.logmap(None, x)
            n_i = x.shape[x.man_dim]
            # n is the total size of the flattened dimension. (The previous
            # n_i * reduce(dimensions_to_flatten) double-counted the manifold
            # dim when man_dim != start_dim and crashed on an empty range when
            # start_dim == end_dim.)
            n = math.prod(x.shape[start_dim : end_dim + 1])
            beta_n = beta_func(n / 2, 0.5)
            beta_n_i = beta_func(n_i / 2, 0.5)
            # Flatten the tensor and rescale.
            tangents.tensor = torch.flatten(
                input=tangents.tensor,
                start_dim=start_dim,
                end_dim=end_dim,
            )
            tangents.tensor = tangents.tensor * beta_n / beta_n_i
            # Set the new manifold dimension
            tangents.man_dim = start_dim
            # Apply exponential map at the origin.
            return self.expmap(tangents)
        else:
            flattened = torch.flatten(
                input=x.tensor,
                start_dim=start_dim,
                end_dim=end_dim,
            )
            man_dim = x.man_dim if end_dim > x.man_dim else x.man_dim - len(dimensions_to_flatten)
            return ManifoldTensor(data=flattened, manifold=x.manifold, man_dim=man_dim)

    def cdist(self, x: ManifoldTensor, y: ManifoldTensor) -> Tensor:
        """Pairwise hyperbolic distances between batches of points x and y."""
        return cdist(x=x.tensor, y=y.tensor, c=self.c())

    def cat(
        self,
        manifold_tensors: Union[Tuple[ManifoldTensor, ...], List[ManifoldTensor]],
        dim: int = 0,
    ) -> ManifoldTensor:
        """Concatenates manifold tensors along dim.

        Concatenation along the shared manifold dimension uses beta-concatenation
        (logmap at the origin, rescale, concatenate, expmap); any other dimension
        concatenates the underlying tensors directly.
        """
        check_if_man_dims_match(manifold_tensors)
        if dim == manifold_tensors[0].man_dim:
            tangent_tensors = [self.logmap(None, t) for t in manifold_tensors]
            ns = torch.tensor([t.shape[t.man_dim] for t in manifold_tensors])
            n = ns.sum()
            beta_ns = beta_func(ns / 2, 0.5)
            beta_n = beta_func(n / 2, 0.5)
            cat = torch.cat(
                [(t.tensor * beta_n) / beta_n_i for (t, beta_n_i) in zip(tangent_tensors, beta_ns)],
                dim=dim,
            )
            new_tensor = TangentTensor(data=cat, manifold=self, man_dim=dim)
            return self.expmap(new_tensor)
        else:
            cat = torch.cat([t.tensor for t in manifold_tensors], dim=dim)
            man_dim = manifold_tensors[0].man_dim
            return ManifoldTensor(data=cat, manifold=self, man_dim=man_dim)
-------------------------------------------------------------------------------- /hypll/manifolds/poincare_ball/math/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxvanspengler/hyperbolic_learning_library/62ea7882848d878459eba4100aeb6b8ad64cc38f/hypll/manifolds/poincare_ball/math/__init__.py -------------------------------------------------------------------------------- /hypll/manifolds/poincare_ball/math/diffgeom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mobius_add(x: torch.Tensor, y: torch.Tensor, c: torch.Tensor, dim: int = -1): 5 | broadcasted_dim = max(x.dim(), y.dim()) 6 | dim = dim if dim >= 0 else broadcasted_dim + dim 7 | x2 = x.pow(2).sum( 8 | dim=dim - broadcasted_dim + x.dim(), 9 | keepdim=True, 10 | ) 11 | y2 = y.pow(2).sum( 12 | dim=dim - broadcasted_dim + y.dim(), 13 | keepdim=True, 14 | ) 15 | xy = (x * y).sum(dim=dim, keepdim=True) 16 | num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y 17 | denom = 1 + 2 * c * xy + c**2 * x2 * y2 18 | return num / denom.clamp_min(1e-15) 19 | 20 | 21 | def project(x: torch.Tensor, c: torch.Tensor, dim: int = -1, eps: float = -1.0): 22 | if eps < 0: 23 | if x.dtype == torch.float32: 24 | eps = 4e-3 25 | else: 26 | eps = 1e-5 27 | maxnorm = (1 - eps) / ((c + 1e-15) ** 0.5) 28 | maxnorm = torch.where(c.gt(0), maxnorm, c.new_full((), 1e15)) 29 | norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15) 30 | cond = norm > maxnorm 31 | projected = x / norm * maxnorm 32 | return torch.where(cond, projected, x) 33 | 34 | 35 | def expmap0(v: torch.Tensor, c: torch.Tensor, dim: int = -1): 36 | v_norm_c_sqrt = v.norm(dim=dim, keepdim=True).clamp_min(1e-15) * c.sqrt() 37 | return project(torch.tanh(v_norm_c_sqrt) * v / v_norm_c_sqrt, c, dim=dim) 38 | 39 | 40 | def logmap0(y: torch.Tensor, c: torch.Tensor, dim: int = -1): 41 | y_norm_c_sqrt = y.norm(dim=dim, 
keepdim=True).clamp_min(1e-15) * c.sqrt() 42 | return torch.atanh(y_norm_c_sqrt) * y / y_norm_c_sqrt 43 | 44 | 45 | def expmap(x: torch.Tensor, v: torch.Tensor, c: torch.Tensor, dim: int = -1): 46 | broadcasted_dim = max(x.dim(), v.dim()) 47 | dim = dim if dim >= 0 else broadcasted_dim + dim 48 | v_norm = v.norm(dim=dim - broadcasted_dim + v.dim(), keepdim=True).clamp_min(1e-15) 49 | lambda_x = 2 / ( 50 | 1 - c * x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True) 51 | ).clamp_min(1e-15) 52 | c_sqrt = c.sqrt() 53 | second_term = torch.tanh(c_sqrt * lambda_x * v_norm / 2) * v / (c_sqrt * v_norm) 54 | return project(mobius_add(x, second_term, c, dim=dim), c, dim=dim) 55 | 56 | 57 | def logmap(x: torch.Tensor, y: torch.Tensor, c: torch.Tensor, dim: int = -1): 58 | broadcasted_dim = max(x.dim(), y.dim()) 59 | dim = dim if dim >= 0 else broadcasted_dim + dim 60 | min_x_y = mobius_add(-x, y, c, dim=dim) 61 | min_x_y_norm = min_x_y.norm(dim=dim, keepdim=True).clamp_min(1e-15) 62 | lambda_x = 2 / ( 63 | 1 - c * x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True) 64 | ).clamp_min(1e-15) 65 | c_sqrt = c.sqrt() 66 | return 2 / (c_sqrt * lambda_x) * torch.atanh(c_sqrt * min_x_y_norm) * min_x_y / min_x_y_norm 67 | 68 | 69 | def gyration( 70 | u: torch.Tensor, 71 | v: torch.Tensor, 72 | w: torch.Tensor, 73 | c: torch.Tensor, 74 | dim: int = -1, 75 | ): 76 | broadcasted_dim = max(u.dim(), v.dim(), w.dim()) 77 | dim = dim if dim >= 0 else broadcasted_dim + dim 78 | u2 = u.pow(2).sum(dim=dim - broadcasted_dim + u.dim(), keepdim=True) 79 | v2 = v.pow(2).sum(dim=dim - broadcasted_dim + v.dim(), keepdim=True) 80 | uv = (u * v).sum(dim=dim - broadcasted_dim + max(u.dim(), v.dim()), keepdim=True) 81 | uw = (u * w).sum(dim=dim - broadcasted_dim + max(u.dim(), w.dim()), keepdim=True) 82 | vw = (v * w).sum(dim=dim - broadcasted_dim + max(v.dim(), w.dim()), keepdim=True) 83 | K2 = c**2 84 | a = -K2 * uw * v2 + c * vw + 2 * K2 * uv * vw 85 | b = -K2 * vw * u2 - c * uw 
86 | d = 1 + 2 * c * uv + K2 * u2 * v2 87 | return w + 2 * (a * u + b * v) / d.clamp_min(1e-15) 88 | 89 | 90 | def transp( 91 | x: torch.Tensor, 92 | y: torch.Tensor, 93 | v: torch.Tensor, 94 | c: torch.Tensor, 95 | dim: int = -1, 96 | ): 97 | broadcasted_dim = max(x.dim(), y.dim(), v.dim()) 98 | dim = dim if dim >= 0 else broadcasted_dim + dim 99 | lambda_x = 2 / ( 100 | 1 - c * x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True) 101 | ).clamp_min(1e-15) 102 | lambda_y = 2 / ( 103 | 1 - c * y.pow(2).sum(dim=dim - broadcasted_dim + y.dim(), keepdim=True) 104 | ).clamp_min(1e-15) 105 | return gyration(y, -x, v, c, dim=dim) * lambda_x / lambda_y 106 | 107 | 108 | def dist( 109 | x: torch.Tensor, 110 | y: torch.Tensor, 111 | c: torch.Tensor, 112 | dim: int = -1, 113 | keepdim: bool = False, 114 | ) -> torch.Tensor: 115 | return ( 116 | 2 117 | / c.sqrt() 118 | * (c.sqrt() * mobius_add(-x, y, c, dim=dim).norm(dim=dim, keepdim=keepdim)).atanh() 119 | ) 120 | 121 | 122 | def inner( 123 | x: torch.Tensor, 124 | u: torch.Tensor, 125 | v: torch.Tensor, 126 | c: torch.Tensor, 127 | dim: int = -1, 128 | keepdim: bool = False, 129 | ) -> torch.Tensor: 130 | broadcasted_dim = max(x.dim(), u.dim(), v.dim()) 131 | dim = dim if dim >= 0 else broadcasted_dim + dim 132 | lambda_x = 2 / ( 133 | 1 - c * x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True) 134 | ).clamp_min(1e-15) 135 | dot_prod = (u * v).sum(dim=dim, keepdim=keepdim) 136 | return lambda_x.square() * dot_prod 137 | 138 | 139 | def euc_to_tangent( 140 | x: torch.Tensor, 141 | u: torch.Tensor, 142 | c: torch.Tensor, 143 | dim: int = -1, 144 | ) -> torch.Tensor: 145 | broadcasted_dim = max(x.dim(), u.dim()) 146 | dim = dim if dim >= 0 else broadcasted_dim + dim 147 | lambda_x = 2 / ( 148 | 1 - c * x.pow(2).sum(dim=dim - broadcasted_dim + x.dim(), keepdim=True) 149 | ).clamp_min(1e-15) 150 | return u / lambda_x**2 151 | 152 | 153 | def cdist( 154 | x: torch.Tensor, 155 | y: torch.Tensor, 156 | c: 
torch.Tensor, 157 | ) -> torch.Tensor: 158 | return 2 / c.sqrt() * (c.sqrt() * mobius_add_batch(-x, y, c).norm(dim=-1)).atanh() 159 | 160 | 161 | def mobius_add_batch( 162 | x: torch.Tensor, 163 | y: torch.Tensor, 164 | c: torch.Tensor, 165 | ): 166 | xy = torch.einsum("bij,bkj->bik", (x, y)) 167 | x2 = x.pow(2).sum(dim=-1, keepdim=True) 168 | y2 = y.pow(2).sum(dim=-1, keepdim=True) 169 | num = 1 + 2 * c * xy + c * y2.permute(0, 2, 1) 170 | num = num.unsqueeze(3) * x.unsqueeze(2) 171 | num = num + (1 - c * x2).unsqueeze(3) * y.unsqueeze(1) 172 | denom = 1 + 2 * c * xy + c**2 * x2 * y2.permute(0, 2, 1) 173 | return num / denom.unsqueeze(3).clamp_min(1e-15) 174 | -------------------------------------------------------------------------------- /hypll/manifolds/poincare_ball/math/linalg.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | 6 | def poincare_hyperplane_dists( 7 | x: torch.Tensor, 8 | z: torch.Tensor, 9 | r: Optional[torch.Tensor], 10 | c: torch.Tensor, 11 | dim: int = -1, 12 | ) -> torch.Tensor: 13 | """ 14 | The Poincare signed distance to hyperplanes operation. 15 | 16 | Parameters 17 | ---------- 18 | x : tensor 19 | contains input values 20 | z : tensor 21 | contains the hyperbolic vectors describing the hyperplane orientations 22 | r : tensor 23 | contains the hyperplane offsets 24 | c : tensor 25 | curvature of the Poincare disk 26 | 27 | Returns 28 | ------- 29 | tensor 30 | signed distances of input w.r.t. 
the hyperplanes, denoted by v_k(x) in 31 | the HNN++ paper 32 | """ 33 | dim_shifted_x = x.movedim(source=dim, destination=-1) 34 | 35 | c_sqrt = c.sqrt() 36 | lam = 2 / (1 - c * dim_shifted_x.pow(2).sum(dim=-1, keepdim=True)) 37 | z_norm = z.norm(dim=0).clamp_min(1e-15) 38 | 39 | # Computation can be simplified if there is no offset 40 | if r is None: 41 | dim_shifted_output = ( 42 | 2 43 | * z_norm 44 | / c_sqrt 45 | * torch.asinh(c_sqrt * lam / z_norm * torch.matmul(dim_shifted_x, z)) 46 | ) 47 | else: 48 | two_csqrt_r = 2.0 * c_sqrt * r 49 | dim_shifted_output = ( 50 | 2 51 | * z_norm 52 | / c_sqrt 53 | * torch.asinh( 54 | c_sqrt * lam / z_norm * torch.matmul(dim_shifted_x, z) * two_csqrt_r.cosh() 55 | - (lam - 1) * two_csqrt_r.sinh() 56 | ) 57 | ) 58 | 59 | return dim_shifted_output.movedim(source=-1, destination=dim) 60 | 61 | 62 | def poincare_fully_connected( 63 | x: torch.Tensor, 64 | z: torch.Tensor, 65 | bias: Optional[torch.Tensor], 66 | c: torch.Tensor, 67 | dim: int = -1, 68 | ) -> torch.Tensor: 69 | """ 70 | The Poincare fully connected layer operation. 
71 | 72 | Parameters 73 | ---------- 74 | x : tensor 75 | contains the layer inputs 76 | z : tensor 77 | contains the hyperbolic vectors describing the hyperplane orientations 78 | bias : tensor 79 | contains the biases (hyperplane offsets) 80 | c : tensor 81 | curvature of the Poincare disk 82 | 83 | Returns 84 | ------- 85 | tensor 86 | Poincare FC transformed hyperbolic tensor, commonly denoted by y 87 | """ 88 | c_sqrt = c.sqrt() 89 | x = poincare_hyperplane_dists(x=x, z=z, r=bias, c=c, dim=dim) 90 | x = (c_sqrt * x).sinh() / c_sqrt 91 | return x / (1 + (1 + c * x.pow(2).sum(dim=dim, keepdim=True)).sqrt()) 92 | -------------------------------------------------------------------------------- /hypll/manifolds/poincare_ball/math/stats.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .diffgeom import dist 7 | 8 | _TOLEPS = {torch.float32: 1e-6, torch.float64: 1e-12} 9 | 10 | 11 | class FrechetMean(torch.autograd.Function): 12 | """ 13 | This implementation is copied mostly from: 14 | https://github.com/CUVL/Differentiable-Frechet-Mean.git 15 | 16 | which is itself based on the paper: 17 | https://arxiv.org/abs/2003.00335 18 | 19 | Both by Aaron Lou (et al.) 
20 | """ 21 | 22 | @staticmethod 23 | def forward(ctx, x, c, vec_dim, batch_dim, keepdim): 24 | # Convert input dimensions to positive values 25 | vec_dim = vec_dim if vec_dim > 0 else x.dim() + vec_dim 26 | batch_dim = [bd if bd >= 0 else x.dim() + bd for bd in batch_dim] 27 | 28 | # Compute some dim ids for later 29 | output_vec_dim = vec_dim - sum(bd < vec_dim for bd in batch_dim) 30 | batch_start_id = -len(batch_dim) - 1 31 | batch_stop_id = -2 32 | original_dims = [x.size(bd) for bd in batch_dim] 33 | 34 | # Move dims around and flatten the batch dimensions 35 | x = x.movedim( 36 | source=batch_dim + [vec_dim], 37 | destination=list(range(batch_start_id, 0)), 38 | ) 39 | x = x.flatten( 40 | start_dim=batch_start_id, end_dim=batch_stop_id 41 | ) # ..., prod(batch_dim), vec_dim 42 | 43 | # Compute frechet mean and store variables 44 | mean = frechet_ball_forward(X=x, K=c, rtol=_TOLEPS[x.dtype], atol=_TOLEPS[x.dtype]) 45 | ctx.save_for_backward(x, mean, c) 46 | ctx.vec_dim = vec_dim 47 | ctx.batch_dim = batch_dim 48 | ctx.output_vec_dim = output_vec_dim 49 | ctx.batch_start_id = batch_start_id 50 | ctx.original_dims = original_dims 51 | 52 | # Reorder dimensions to match dimensions of input 53 | mean = mean.movedim( 54 | source=list(range(output_vec_dim - mean.dim(), 0)), 55 | destination=list(range(output_vec_dim + 1, mean.dim())) + [output_vec_dim], 56 | ) 57 | 58 | # Add dimensions back to mean if keepdim 59 | if keepdim: 60 | for bd in sorted(batch_dim): 61 | mean = mean.unsqueeze(bd) 62 | 63 | return mean 64 | 65 | @staticmethod 66 | def backward(ctx, grad_output): 67 | x, mean, c = ctx.saved_tensors 68 | # Reshift dims in grad to match the dims of the mean that was stored in ctx 69 | grad_output = grad_output.movedim( 70 | source=list(range(ctx.output_vec_dim - mean.dim(), 0)), 71 | destination=[-1] + list(range(ctx.output_vec_dim - mean.dim(), -1)), 72 | ) 73 | dx, dc = frechet_ball_backward(X=x, y=mean, grad=grad_output, K=c) 74 | vec_dim = ctx.vec_dim 75 
def frechet_mean(
    x: torch.Tensor,
    c: torch.Tensor,
    vec_dim: int = -1,
    batch_dim: Union[int, list[int]] = 0,
    keepdim: bool = False,
) -> torch.Tensor:
    """Differentiable Fréchet mean of points on the Poincaré ball.

    Thin wrapper around the FrechetMean autograd Function that normalizes
    ``batch_dim`` to a list before dispatching.

    Args:
        x: Points on the ball.
        c: Curvature tensor of the ball.
        vec_dim: Dimension holding the point coordinates.
        batch_dim: Dimension(s) to average over.
        keepdim: Whether to retain the reduced dimensions.
    """
    if isinstance(batch_dim, int):
        batch_dim = [batch_dim]

    return FrechetMean.apply(x, c, vec_dim, batch_dim, keepdim)


def midpoint(
    x: torch.Tensor,
    c: torch.Tensor,
    man_dim: int = -1,
    batch_dim: Union[int, list[int]] = 0,
    w: Optional[torch.Tensor] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    """(Optionally weighted) gyromidpoint of points on the Poincaré ball.

    Args:
        x: Points on the ball.
        c: Curvature tensor of the ball.
        man_dim: Dimension holding the point coordinates.
        batch_dim: Dimension(s) to aggregate over.
        w: Optional per-point weights (broadcast against ``x``).
        keepdim: Whether to retain the reduced batch dimensions.
    """
    # Conformal factor lambda_x = 2 / (1 - c * ||x||^2), clamped away from 0.
    sq_norm = x.pow(2).sum(dim=man_dim, keepdim=True)
    conformal = 2 / (1 - c * sq_norm).clamp_min(1e-15)

    if w is None:
        num = (conformal * x).sum(dim=batch_dim, keepdim=True)
        den = (conformal - 1).sum(dim=batch_dim, keepdim=True)
    else:
        # TODO: test weights
        num = (conformal * w * x).sum(dim=batch_dim, keepdim=True)
        den = ((conformal - 1) * w).sum(dim=batch_dim, keepdim=True)

    frac = num / den.clamp_min(1e-15)
    # Map the Euclidean weighted average back onto the ball.
    scale = 1 / (1 + (1 - c * frac.pow(2).sum(dim=man_dim, keepdim=True)).sqrt())
    result = scale * frac
    return result if keepdim else result.squeeze(dim=batch_dim)


def frechet_variance(
    x: torch.Tensor,
    c: torch.Tensor,
    mu: Optional[torch.Tensor] = None,
    vec_dim: int = -1,
    batch_dim: Union[int, list[int]] = 0,
    keepdim: bool = False,
) -> torch.Tensor:
    """Mean squared geodesic distance of ``x`` to the Fréchet mean ``mu``.

    Args
    ----
    x (tensor): points of shape [..., points, dim]
    mu (tensor): mean of shape [..., dim]; computed via frechet_mean if None

        where the ... of the variables match

    Returns
    -------
    tensor of shape [...]
    """
    if isinstance(batch_dim, int):
        batch_dim = [batch_dim]

    if mu is None:
        mu = frechet_mean(x=x, c=c, vec_dim=vec_dim, batch_dim=batch_dim, keepdim=True)

    # Align mu's rank with x's by reinserting the reduced batch dimensions.
    if x.dim() != mu.dim():
        for bd in sorted(batch_dim):
            mu = mu.unsqueeze(bd)

    distance = dist(x=x, y=mu, c=c, dim=vec_dim, keepdim=keepdim)
    return distance.pow(2).mean(dim=batch_dim, keepdim=keepdim)


def l_prime(y: torch.Tensor) -> torch.Tensor:
    """Derivative weight 2 * arcosh(1 + 2y) / sqrt(y^2 + y), with limit 4 at y = 0.

    NaN-safe: torch.where evaluates (and differentiates) both branches, so the
    input to the analytic branch is clamped away from 0 to avoid inf/NaN
    gradients flowing through the unselected branch when y == 0.
    """
    cond = y < 1e-12
    y_safe = y.clamp_min(1e-12)  # forward result unchanged: cond selects val there
    val = 4 * torch.ones_like(y)
    return torch.where(
        cond, val, 2 * (1 + 2 * y_safe).acosh() / (y_safe.pow(2) + y_safe).sqrt()
    )


def frechet_ball_forward(
    X: torch.Tensor,
    K: torch.Tensor = torch.Tensor([1]),  # NOTE: default tensor is created once at def time
    max_iter: int = 1000,
    rtol: float = 1e-6,
    atol: float = 1e-6,
) -> torch.Tensor:
    """Fixed-point iteration computing the Fréchet mean on the Poincaré ball.

    Args
    ----
    X (tensor): points of shape [..., points, dim]
    K (tensor): curvature (docstring of the original says "must be negative",
        but the default is positive — the formulas use 1 - K * ||x||^2, so K
        plays the role of the ball curvature magnitude c; verify convention)

    Returns
    -------
    frechet mean (tensor): shape [..., dim]
    """
    # Initialize at the first point; iterate the fixed-point update.
    mu = X[..., 0, :].clone()
    x_ss = X.pow(2).sum(dim=-1)

    mu_prev = mu
    for _ in range(max_iter):
        mu_ss = mu.pow(2).sum(dim=-1)
        xmu_ss = (X - mu.unsqueeze(-2)).pow(2).sum(dim=-1)

        alphas = l_prime(K * xmu_ss / ((1 - K * x_ss) * (1 - K * mu_ss.unsqueeze(-1)))) / (
            1 - K * x_ss
        )

        c = (alphas * x_ss).sum(dim=-1)
        b = (alphas.unsqueeze(-1) * X).sum(dim=-2)
        a = alphas.sum(dim=-1)

        b_ss = b.pow(2).sum(dim=-1)

        # Smaller root of the quadratic K*b_ss*eta^2 - (a + K*c)*eta + a = 0.
        eta = (a + K * c - ((a + K * c).pow(2) - 4 * K * b_ss).sqrt()) / (
            2 * K * b_ss
        ).clamp_min(1e-15)

        mu = eta.unsqueeze(-1) * b

        # Renamed from `dist` to avoid shadowing the module-level dist()
        # used by frechet_variance above.
        step_norm = (mu - mu_prev).norm(dim=-1)
        prev_norm = mu_prev.norm(dim=-1)
        if (step_norm < atol).all() or (step_norm / prev_norm < rtol).all():
            break

        mu_prev = mu

    return mu


def darcosh(x):
    """d/dx arcosh(x)^2 = 2 * arcosh(x) / sqrt(x^2 - 1), with limit 2 as x -> 1.

    The first where replaces near/below-1 inputs by 2 so the analytic branch
    is never evaluated in its singular region; the second where then writes
    the analytic value only where it is valid.
    """
    cond = x < 1 + 1e-7
    x = torch.where(cond, 2 * torch.ones_like(x), x)
    x = torch.where(~cond, 2 * (x).acosh() / torch.sqrt(x**2 - 1), x)
    return x


def d2arcosh(x):
    """Second derivative of arcosh(x)^2, with limit -2/3 as x -> 1.

    Same two-pass masking trick as darcosh to keep the singular branch safe.
    """
    cond = x < 1 + 1e-7
    x = torch.where(cond, -2 / 3 * torch.ones_like(x), x)
    x = torch.where(
        ~cond,
        2 / (x**2 - 1) - 2 * x * x.acosh() / ((x**2 - 1) ** (3 / 2)),
        x,
    )
    return x


def grad_var(
    X: torch.Tensor,
    y: torch.Tensor,
    K: torch.Tensor,
) -> torch.Tensor:
    """Gradient of the Fréchet variance objective w.r.t. the candidate mean y.

    Args
    ----
    X (tensor): points of shape [..., points, dim]
    y (tensor): mean point of shape [..., dim]
    K (tensor): curvature

    Returns
    -------
    grad (tensor): gradient of variance [..., dim]
    """
    yl = y.unsqueeze(-2)
    xnorm = 1 - K * X.pow(2).sum(dim=-1)
    ynorm = 1 - K * yl.pow(2).sum(dim=-1)
    xynorm = (X - yl).pow(2).sum(dim=-1)

    # v is the argument of arcosh in the ball distance formula.
    D = xnorm * ynorm
    v = 1 + 2 * K * xynorm / D

    Dl = D.unsqueeze(-1)
    vl = v.unsqueeze(-1)

    first_term = (X - yl) / Dl
    sec_term = -K / Dl.pow(2) * yl * xynorm.unsqueeze(-1) * xnorm.unsqueeze(-1)
    return -(4 * darcosh(vl) * (first_term + sec_term)).sum(dim=-2)


def inverse_hessian(
    X: torch.Tensor,
    y: torch.Tensor,
    K: torch.Tensor,
) -> torch.Tensor:
    """Inverse Hessian of the Fréchet variance objective at y.

    Args
    ----
    X (tensor): points of shape [..., points, dim]
    y (tensor): mean point of shape [..., dim]
    K (tensor): curvature

    Returns
    -------
    inv_hess (tensor): inverse Hessian of shape [..., dim, dim]
    """
    yl = y.unsqueeze(-2)
    xnorm = 1 - K * X.pow(2).sum(dim=-1)
    ynorm = 1 - K * yl.pow(2).sum(dim=-1)
    xynorm = (X - yl).pow(2).sum(dim=-1)

    D = xnorm * ynorm
    v = 1 + 2 * K * xynorm / D

    Dl = D.unsqueeze(-1)
    vl = v.unsqueeze(-1)
    vll = vl.unsqueeze(-1)

    # partial T / partial y  (was a no-op string literal; converted to comment)
    first_const = -8 * (K**2) * xnorm / D.pow(2)
    matrix_val = (first_const.unsqueeze(-1) * yl).unsqueeze(-1) * (X - yl).unsqueeze(-2)
    first_term = matrix_val + matrix_val.transpose(-1, -2)

    sec_const = 16 * (K**3) * xnorm.pow(2) / D.pow(3) * xynorm
    sec_term = (sec_const.unsqueeze(-1) * yl).unsqueeze(-1) * yl.unsqueeze(-2)

    third_const = 4 * K / D + 4 * (K**2) * xnorm / D.pow(2) * xynorm
    third_term = third_const.reshape(*third_const.shape, 1, 1) * torch.eye(y.shape[-1]).to(
        X
    ).reshape((1,) * len(third_const.shape) + (y.shape[-1], y.shape[-1]))

    Ty = first_term + sec_term + third_term

    # T: per-point gradient factor of the distance term.
    first_term = -K / Dl * (X - yl)
    sec_term = K.pow(2) / Dl.pow(2) * yl * xynorm.unsqueeze(-1) * xnorm.unsqueeze(-1)
    T = 4 * (first_term + sec_term)

    # Assemble per-point Hessian contributions [..., points, dim, dim],
    # sum over points, and invert.
    first_term = d2arcosh(vll) * T.unsqueeze(-1) * T.unsqueeze(-2)
    sec_term = darcosh(vll) * Ty
    hessian = (first_term + sec_term).sum(dim=-3) / K
    inv_hess = torch.inverse(hessian)
    return inv_hess


def frechet_ball_backward(
    X: torch.Tensor,
    y: torch.Tensor,
    grad: torch.Tensor,
    K: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Implicit-function backward pass for the Fréchet mean on the ball.

    Args
    ----
    X (tensor): points of shape [..., points, dim]
    y (tensor): mean point of shape [..., dim]
    grad (tensor): incoming gradient w.r.t. y, shape [..., dim]
    K (tensor or float): curvature

    Returns
    -------
    (dx, dK): gradient w.r.t. X of shape [..., points, dim] and the scalar
    gradient w.r.t. the curvature.
    """
    if not torch.is_tensor(K):
        K = torch.tensor(K).to(X)

    with torch.no_grad():
        inv_hess = inverse_hessian(X, y, K=K)

    with torch.enable_grad():
        # Re-wrap as leaf tensors so autograd.grad can target X and K.
        X = nn.Parameter(X.detach())
        y = y.detach()
        K = nn.Parameter(K)

        # squeeze(-1) instead of a bare squeeze(): removes exactly the
        # matmul column dimension, keeping batch dims of size 1 intact so
        # grad_outputs always matches the shape of -gradf below.
        grad = (inv_hess @ grad.unsqueeze(-1)).squeeze(-1)
        gradf = grad_var(X, y, K)
        dx, dK = torch.autograd.grad(-gradf, (X, K), grad)

    return dx, dK
class HReLU(Module):
    """Pointwise ReLU applied in the tangent space at the origin."""

    def __init__(self, manifold: Manifold) -> None:
        super(HReLU, self).__init__()
        self.manifold = manifold

    def forward(self, input: ManifoldTensor) -> ManifoldTensor:
        check_if_manifolds_match(layer=self, input=input)
        return op_in_tangent_space(op=relu, manifold=self.manifold, input=input)


class HBatchNorm(Module):
    """Basic implementation of hyperbolic batch normalization.

    Based on:
        https://arxiv.org/abs/2003.00335
    """

    def __init__(
        self,
        features: int,
        manifold: Manifold,
        use_midpoint: bool = False,
    ) -> None:
        super(HBatchNorm, self).__init__()
        self.features = features
        self.manifold = manifold
        self.use_midpoint = use_midpoint

        # TODO: Store bias on manifold
        self.bias = Parameter(zeros(features))
        self.weight = Parameter(tensor(1.0))

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        check_if_manifolds_match(layer=self, input=x)
        # Map the Euclidean bias parameter onto the manifold.
        bias_point = self.manifold.expmap(
            v=TangentTensor(data=self.bias, manifold_points=None, manifold=self.manifold)
        )

        # Batch statistics: mean (midpoint or Fréchet) and Fréchet variance.
        if self.use_midpoint:
            batch_mean = self.manifold.midpoint(x=x, batch_dim=0)
        else:
            batch_mean = self.manifold.frechet_mean(x=x, batch_dim=0)
        batch_var = self.manifold.frechet_variance(x=x, mu=batch_mean, batch_dim=0)

        # Log-map around the batch mean, transport to the bias point,
        # rescale in the tangent space, and map back onto the manifold.
        centered = self.manifold.transp(
            v=self.manifold.logmap(batch_mean, x),
            y=bias_point,
        )
        centered.tensor = (self.weight / (batch_var + 1e-6)).sqrt() * centered.tensor
        return self.manifold.expmap(centered)


class HBatchNorm2d(Module):
    """2D implementation of hyperbolic batch normalization.

    Based on:
        https://arxiv.org/abs/2003.00335
    """

    def __init__(
        self,
        features: int,
        manifold: Manifold,
        use_midpoint: bool = False,
    ) -> None:
        super(HBatchNorm2d, self).__init__()
        self.features = features
        self.manifold = manifold
        self.use_midpoint = use_midpoint

        self.norm = HBatchNorm(
            features=features,
            manifold=manifold,
            use_midpoint=use_midpoint,
        )

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        check_if_manifolds_match(layer=self, input=x)
        b, h, w = x.size(0), x.size(2), x.size(3)
        # Move channels last and collapse (B, H, W) so the 1D norm sees one
        # channel vector per spatial location.
        flat = ManifoldTensor(
            data=x.tensor.permute(0, 2, 3, 1).flatten(start_dim=0, end_dim=2),
            manifold=x.manifold,
            man_dim=-1,
        )
        flat = self.norm(flat)
        restored = flat.tensor.reshape(b, h, w, self.features).permute(0, 3, 1, 2)
        return ManifoldTensor(data=restored, manifold=x.manifold, man_dim=1)
14 | 15 | """ 16 | 17 | def __init__(self, target_manifold: Manifold): 18 | super(ChangeManifold, self).__init__() 19 | self.target_manifold = target_manifold 20 | 21 | def forward(self, x: ManifoldTensor) -> ManifoldTensor: 22 | """Changes the manifold of the input tensor to self.target_manifold. 23 | 24 | By default, applies logmap and expmap at the origin to switch between 25 | manifold. Applies shortcuts if possible. 26 | 27 | Args: 28 | x: 29 | Input manifold tensor. 30 | 31 | Returns: 32 | Tensor on the target manifold. 33 | 34 | """ 35 | 36 | match (x.manifold, self.target_manifold): 37 | # TODO(Philipp, 05/23): Apply shortcuts where possible: For example, 38 | # case (PoincareBall(), PoincareBall()): ... 39 | # 40 | # By default resort to logmap + expmap at the origin. 41 | case _: 42 | tangent_tensor = x.manifold.logmap(None, x) 43 | return self.target_manifold.expmap(tangent_tensor) 44 | -------------------------------------------------------------------------------- /hypll/nn/modules/container.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module, Sequential 2 | 3 | from hypll.manifolds import Manifold 4 | from hypll.tensors import ManifoldTensor 5 | from hypll.utils.layer_utils import check_if_manifolds_match 6 | 7 | 8 | class TangentSequential(Module): 9 | def __init__(self, seq: Sequential, manifold: Manifold) -> None: 10 | super(TangentSequential, self).__init__() 11 | self.seq = seq 12 | self.manifold = manifold 13 | 14 | def forward(self, input: ManifoldTensor) -> ManifoldTensor: 15 | check_if_manifolds_match(layer=self, input=input) 16 | man_dim = input.man_dim 17 | 18 | input = self.manifold.logmap(x=None, y=input) 19 | for module in self.seq: 20 | input.tensor = module(input.tensor) 21 | return self.manifold.expmap(input) 22 | -------------------------------------------------------------------------------- /hypll/nn/modules/convolution.py: 
-------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | from torch.nn.common_types import _size_1_t, _size_2_t 3 | 4 | from hypll.manifolds import Manifold 5 | from hypll.tensors import ManifoldTensor 6 | from hypll.utils.layer_utils import check_if_man_dims_match, check_if_manifolds_match 7 | 8 | 9 | class HConvolution2d(Module): 10 | """Applies a 2D convolution over a hyperbolic input signal. 11 | 12 | Attributes: 13 | in_channels: 14 | Number of channels in the input image. 15 | out_channels: 16 | Number of channels produced by the convolution. 17 | kernel_size: 18 | Size of the convolving kernel. 19 | manifold: 20 | Hyperbolic manifold of the tensors. 21 | bias: 22 | If True, adds a learnable bias to the output. Default: True 23 | stride: 24 | Stride of the convolution. Default: 1 25 | padding: 26 | Padding added to all four sides of the input. Default: 0 27 | id_init: 28 | Use identity initialization (True) if appropriate or use HNN++ initialization (False). 
29 | 30 | """ 31 | 32 | def __init__( 33 | self, 34 | in_channels: int, 35 | out_channels: int, 36 | kernel_size: _size_2_t, 37 | manifold: Manifold, 38 | bias: bool = True, 39 | stride: int = 1, 40 | padding: int = 0, 41 | id_init: bool = True, 42 | ) -> None: 43 | super(HConvolution2d, self).__init__() 44 | self.in_channels = in_channels 45 | self.out_channels = out_channels 46 | self.kernel_size = ( 47 | kernel_size 48 | if isinstance(kernel_size, tuple) and len(kernel_size) == 2 49 | else (kernel_size, kernel_size) 50 | ) 51 | self.kernel_vol = self.kernel_size[0] * self.kernel_size[1] 52 | self.manifold = manifold 53 | self.stride = stride 54 | self.padding = padding 55 | self.id_init = id_init 56 | self.has_bias = bias 57 | 58 | self.weights, self.bias = self.manifold.construct_dl_parameters( 59 | in_features=self.kernel_vol * in_channels, 60 | out_features=out_channels, 61 | bias=self.has_bias, 62 | ) 63 | 64 | self.reset_parameters() 65 | 66 | def reset_parameters(self) -> None: 67 | """Resets parameter weights based on the manifold.""" 68 | self.manifold.reset_parameters(weight=self.weights, bias=self.bias) 69 | 70 | def forward(self, x: ManifoldTensor) -> ManifoldTensor: 71 | """Does a forward pass of the 2D convolutional layer. 72 | 73 | Args: 74 | x: 75 | Manifold tensor of shape (B, C_in, H, W) with manifold dimension 1. 76 | 77 | Returns: 78 | Manifold tensor of shape (B, C_in, H_out, W_out) with manifold dimension 1. 79 | 80 | Raises: 81 | ValueError: If the manifolds or manifold dimensions don't match. 
82 | 83 | """ 84 | check_if_manifolds_match(layer=self, input=x) 85 | check_if_man_dims_match(layer=self, man_dim=1, input=x) 86 | 87 | batch_size, height, width = x.size(0), x.size(2), x.size(3) 88 | out_height = _output_side_length( 89 | input_side_length=height, 90 | kernel_size=self.kernel_size[0], 91 | padding=self.padding, 92 | stride=self.stride, 93 | ) 94 | out_width = _output_side_length( 95 | input_side_length=width, 96 | kernel_size=self.kernel_size[1], 97 | padding=self.padding, 98 | stride=self.stride, 99 | ) 100 | 101 | x = self.manifold.unfold( 102 | input=x, 103 | kernel_size=self.kernel_size, 104 | padding=self.padding, 105 | stride=self.stride, 106 | ) 107 | x = self.manifold.fully_connected(x=x, z=self.weights, bias=self.bias) 108 | x = ManifoldTensor( 109 | data=x.tensor.reshape(batch_size, self.out_channels, out_height, out_width), 110 | manifold=x.manifold, 111 | man_dim=1, 112 | ) 113 | return x 114 | 115 | 116 | def _output_side_length( 117 | input_side_length: int, kernel_size: _size_1_t, padding: int, stride: int 118 | ) -> int: 119 | """Calculates the output side length of the kernel. 120 | 121 | Based on https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html. 
122 | 123 | """ 124 | if kernel_size > input_side_length: 125 | raise RuntimeError( 126 | f"Encountered invalid kernel size {kernel_size} " 127 | f"larger than input side length {input_side_length}" 128 | ) 129 | if stride > input_side_length: 130 | raise RuntimeError( 131 | f"Encountered invalid stride {stride} " 132 | f"larger than input side length {input_side_length}" 133 | ) 134 | return 1 + (input_side_length + 2 * padding - (kernel_size - 1) - 1) // stride 135 | -------------------------------------------------------------------------------- /hypll/nn/modules/embedding.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, empty, no_grad, normal, ones_like, zeros_like 2 | from torch.nn import Module 3 | 4 | from hypll.manifolds import Manifold 5 | from hypll.tensors import ManifoldParameter, ManifoldTensor, TangentTensor 6 | 7 | 8 | class HEmbedding(Module): 9 | def __init__( 10 | self, 11 | num_embeddings: int, 12 | embedding_dim: int, 13 | manifold: Manifold, 14 | ) -> None: 15 | super(HEmbedding, self).__init__() 16 | self.num_embeddings = num_embeddings 17 | self.embedding_dim = embedding_dim 18 | self.manifold = manifold 19 | 20 | self.weight = ManifoldParameter( 21 | data=empty((num_embeddings, embedding_dim)), manifold=manifold, man_dim=-1 22 | ) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | with no_grad(): 27 | new_weight = normal( 28 | mean=zeros_like(self.weight.tensor), 29 | std=ones_like(self.weight.tensor), 30 | ) 31 | new_weight = TangentTensor( 32 | data=new_weight, 33 | manifold_points=None, 34 | manifold=self.manifold, 35 | man_dim=-1, 36 | ) 37 | self.weight.copy_(self.manifold.expmap(new_weight).tensor) 38 | 39 | def forward(self, input: Tensor) -> ManifoldTensor: 40 | return self.weight[input] 41 | -------------------------------------------------------------------------------- /hypll/nn/modules/flatten.py: 
-------------------------------------------------------------------------------- 1 | from torch.nn import Module, functional 2 | 3 | from hypll.tensors import ManifoldTensor 4 | 5 | 6 | class HFlatten(Module): 7 | """Flattens a contiguous range of dims into a tensor. 8 | 9 | Attributes: 10 | start_dim: 11 | First dimension to flatten (default = 1). 12 | end_dim: 13 | Last dimension to flatten (default = -1). 14 | 15 | """ 16 | 17 | def __init__(self, start_dim: int = 1, end_dim: int = -1): 18 | super(HFlatten, self).__init__() 19 | self.start_dim = start_dim 20 | self.end_dim = end_dim 21 | 22 | def forward(self, x: ManifoldTensor) -> ManifoldTensor: 23 | """Flattens the manifold input tensor.""" 24 | return x.flatten(start_dim=self.start_dim, end_dim=self.end_dim) 25 | -------------------------------------------------------------------------------- /hypll/nn/modules/fold.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | from torch.nn.common_types import _size_2_t 3 | 4 | from hypll.manifolds import Manifold 5 | from hypll.tensors import ManifoldTensor 6 | from hypll.utils.layer_utils import check_if_man_dims_match, check_if_manifolds_match 7 | 8 | 9 | class HUnfold(Module): 10 | def __init__( 11 | self, 12 | kernel_size: _size_2_t, 13 | manifold: Manifold, 14 | dilation: _size_2_t = 1, 15 | padding: _size_2_t = 0, 16 | stride: _size_2_t = 1, 17 | ) -> None: 18 | self.kernel_size = kernel_size 19 | self.manifold = manifold 20 | self.dilation = dilation 21 | self.padding = padding 22 | self.stride = stride 23 | 24 | def forward(self, input: ManifoldTensor) -> ManifoldTensor: 25 | check_if_manifolds_match(layer=self, input=input) 26 | check_if_man_dims_match(layer=self, man_dim=1, input=input) 27 | return self.manifold.unfold( 28 | input=input, 29 | kernel_size=self.kernel_size, 30 | dilation=self.dilation, 31 | padding=self.padding, 32 | stride=self.stride, 33 | ) 34 | 
-------------------------------------------------------------------------------- /hypll/nn/modules/linear.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | 3 | from hypll.manifolds import Manifold 4 | from hypll.tensors import ManifoldTensor 5 | from hypll.utils.layer_utils import check_if_man_dims_match, check_if_manifolds_match 6 | 7 | 8 | class HLinear(Module): 9 | """Poincare fully connected linear layer""" 10 | 11 | def __init__( 12 | self, 13 | in_features: int, 14 | out_features: int, 15 | manifold: Manifold, 16 | bias: bool = True, 17 | ) -> None: 18 | super(HLinear, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.manifold = manifold 22 | self.has_bias = bias 23 | 24 | # TODO: torch stores weights transposed supposedly due to efficiency 25 | # https://discuss.pytorch.org/t/why-does-the-linear-module-seems-to-do-unnecessary-transposing/6277/7 26 | # We may want to do the same 27 | self.z, self.bias = self.manifold.construct_dl_parameters( 28 | in_features=in_features, out_features=out_features, bias=self.has_bias 29 | ) 30 | 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self) -> None: 34 | self.manifold.reset_parameters(self.z, self.bias) 35 | 36 | def forward(self, x: ManifoldTensor) -> ManifoldTensor: 37 | check_if_manifolds_match(layer=self, input=x) 38 | check_if_man_dims_match(layer=self, man_dim=-1, input=x) 39 | return self.manifold.fully_connected(x=x, z=self.z, bias=self.bias) 40 | -------------------------------------------------------------------------------- /hypll/nn/modules/pooling.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional 3 | 4 | from torch.nn import Module 5 | from torch.nn.common_types import _size_2_t, _size_any_t 6 | from torch.nn.functional import max_pool2d 7 | 8 | from hypll.manifolds import 
class HAvgPool2d(Module):
    """Hyperbolic 2D average pooling.

    Each kernel window is aggregated with either the midpoint or the Fréchet
    mean on the manifold.
    """

    def __init__(
        self,
        kernel_size: _size_2_t,
        manifold: Manifold,
        stride: Optional[_size_2_t] = None,
        padding: _size_2_t = 0,
        use_midpoint: bool = False,
    ):
        super().__init__()
        self.kernel_size = (
            kernel_size
            if isinstance(kernel_size, tuple) and len(kernel_size) == 2
            else (kernel_size, kernel_size)
        )
        self.manifold = manifold
        # Bug fix: forward indexes self.stride[0]/[1], so an int stride must
        # be normalized to a pair like kernel_size and padding (previously an
        # int stride crashed in forward with a TypeError).
        if stride is None:
            self.stride = self.kernel_size
        elif isinstance(stride, tuple) and len(stride) == 2:
            self.stride = stride
        else:
            self.stride = (stride, stride)
        self.padding = (
            padding if isinstance(padding, tuple) and len(padding) == 2 else (padding, padding)
        )
        self.use_midpoint = use_midpoint

    def forward(self, input: ManifoldTensor) -> ManifoldTensor:
        check_if_manifolds_match(layer=self, input=input)
        check_if_man_dims_match(layer=self, man_dim=1, input=input)

        batch_size, channels, height, width = input.size()
        out_height = int((height + 2 * self.padding[0] - self.kernel_size[0]) / self.stride[0] + 1)
        out_width = int((width + 2 * self.padding[1] - self.kernel_size[1]) / self.stride[1] + 1)

        # Unfold into per-window patches, aggregate each window on the
        # manifold, then restore the spatial layout.
        unfolded_input = self.manifold.unfold(
            input=input,
            kernel_size=self.kernel_size,
            padding=self.padding,
            stride=self.stride,
        )
        per_kernel_view = unfolded_input.tensor.view(
            batch_size,
            channels,
            self.kernel_size[0] * self.kernel_size[1],
            unfolded_input.size(-1),
        )

        x = ManifoldTensor(data=per_kernel_view, manifold=self.manifold, man_dim=1)

        if self.use_midpoint:
            aggregates = self.manifold.midpoint(x=x, batch_dim=2)
        else:
            aggregates = self.manifold.frechet_mean(x=x, batch_dim=2)

        return ManifoldTensor(
            data=aggregates.tensor.reshape(batch_size, channels, out_height, out_width),
            manifold=self.manifold,
            man_dim=1,
        )


class _HMaxPoolNd(Module):
    """Shared constructor/state for hyperbolic max pooling layers."""

    def __init__(
        self,
        kernel_size: _size_any_t,
        manifold: Manifold,
        stride: Optional[_size_any_t] = None,
        padding: _size_any_t = 0,
        dilation: _size_any_t = 1,
        return_indices: bool = False,
        ceil_mode: bool = False,
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.manifold = manifold
        self.stride = stride if (stride is not None) else kernel_size
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode


class HMaxPool2d(_HMaxPoolNd):
    """Hyperbolic 2D max pooling, applied in the tangent space at the origin."""

    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t
    dilation: _size_2_t

    def forward(self, input: ManifoldTensor) -> ManifoldTensor:
        check_if_manifolds_match(layer=self, input=input)

        # TODO: check if defining this partial func each forward pass is slow.
        # If it is, put this stuff inside the init or add kwargs to op_in_tangent_space.
        max_pool2d_partial = partial(
            max_pool2d,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
        return op_in_tangent_space(op=max_pool2d_partial, manifold=self.manifold, input=input)
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super(RiemannianAdam, self).__init__(params=params, defaults=defaults)

    def step(self) -> None:
        """Performs a single Riemannian Adam step over all parameter groups.

        ManifoldParameters are updated with exponential-map steps and their
        first-moment buffer is parallel-transported to the new point; all
        other parameters follow the standard (Euclidean) Adam update.
        """
        # TODO: refactor this and add some comments, because it's currently unreadable
        with no_grad():
            for group in self.param_groups:
                betas = group["betas"]
                weight_decay = group["weight_decay"]
                eps = group["eps"]
                lr = group["lr"]
                amsgrad = group["amsgrad"]
                for param in group["params"]:
                    if isinstance(param, ManifoldParameter):
                        # Riemannian branch: wrap the raw gradient as a
                        # Euclidean manifold tensor so it can be converted to
                        # a tangent vector at `param` below.
                        manifold: Manifold = param.manifold
                        grad = param.grad
                        if grad is None:
                            continue
                        grad = ManifoldTensor(
                            data=grad, manifold=Euclidean(), man_dim=param.man_dim
                        )
                        state = self.state[param]

                        # Lazy state initialization on first use of this param.
                        if len(state) == 0:
                            state["step"] = 0
                            state["exp_avg"] = zeros_like(param.tensor)
                            state["exp_avg_sq"] = zeros_like(param.tensor)
                            if amsgrad:
                                state["max_exp_avg_sq"] = zeros_like(param.tensor)

                        state["step"] += 1
                        exp_avg = state["exp_avg"]
                        exp_avg_sq = state["exp_avg_sq"]

                        # Weight decay applied to the Euclidean gradient.
                        # TODO: check if this next line makes sense, because I don't think so
                        grad.tensor.add_(param.tensor, alpha=weight_decay)
                        grad = manifold.euc_to_tangent(x=param, u=grad)
                        # NOTE(review): the second moment uses the manifold
                        # inner product of the whole tangent vector, not the
                        # elementwise square of standard Adam — confirm intended.
                        exp_avg.mul_(betas[0]).add_(grad.tensor, alpha=1 - betas[0])
                        exp_avg_sq.mul_(betas[1]).add_(
                            manifold.inner(u=grad, v=grad, keepdim=True, safe_mode=False),
                            alpha=1 - betas[1],
                        )
                        bias_correction1 = 1 - betas[0] ** state["step"]
                        bias_correction2 = 1 - betas[1] ** state["step"]

                        if amsgrad:
                            # AMSGrad: normalize by the running max of the
                            # second moment instead of its current value.
                            max_exp_avg_sq = state["max_exp_avg_sq"]
                            max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                            denom = max_exp_avg_sq.div(bias_correction2).sqrt_()
                        else:
                            denom = exp_avg_sq.div(bias_correction2).sqrt_()

                        # Bias-corrected update direction as a tangent vector.
                        direction = -lr * exp_avg.div(bias_correction1) / denom.add_(eps)
                        direction = TangentTensor(
                            data=direction,
                            manifold_points=param,
                            manifold=manifold,
                            man_dim=param.man_dim,
                        )
                        exp_avg_man = TangentTensor(
                            data=exp_avg,
                            manifold_points=param,
                            manifold=manifold,
                            man_dim=param.man_dim,
                        )

                        # Exponential-map step, then parallel-transport the
                        # momentum so it lives in the new tangent space.
                        new_param = manifold.expmap(direction)
                        exp_avg_new = manifold.transp(v=exp_avg_man, y=new_param)

                        param.tensor.copy_(new_param.tensor)
                        exp_avg.copy_(exp_avg_new.tensor)

                    else:
                        # Standard (Euclidean) Adam branch.
                        grad = param.grad
                        if grad is None:
                            continue

                        state = self.state[param]

                        if len(state) == 0:
                            state["step"] = 0
                            state["exp_avg"] = zeros_like(param)
                            state["exp_avg_sq"] = zeros_like(param)
                            if amsgrad:
                                state["max_exp_avg_sq"] = zeros_like(param)
                        state["step"] += 1
                        exp_avg = state["exp_avg"]
                        exp_avg_sq = state["exp_avg_sq"]
                        grad.add_(param, alpha=weight_decay)
                        exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
                        # NOTE(review): sums the squared grad over the last dim
                        # (inner-product style) rather than keeping it
                        # elementwise as in torch.optim.Adam — confirm intended.
                        exp_avg_sq.mul_(betas[1]).add_(
                            grad.square().sum(dim=-1, keepdim=True), alpha=1 - betas[1]
                        )
                        bias_correction1 = 1 - betas[0] ** state["step"]
                        bias_correction2 = 1 - betas[1] ** state["step"]
                        if amsgrad:
                            max_exp_avg_sq = state["max_exp_avg_sq"]
                            max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                            denom = max_exp_avg_sq.div(bias_correction2).sqrt_()
                        else:
                            denom = exp_avg_sq.div(bias_correction2).sqrt_()

                        direction = exp_avg.div(bias_correction1) / denom.add_(eps)

                        new_param = param - lr * direction
                        exp_avg_new = exp_avg

                        param.copy_(new_param)
                        exp_avg.copy_(exp_avg_new)
class RiemannianSGD(Optimizer):
    """Riemannian stochastic gradient descent.

    ManifoldParameters are updated by mapping the Euclidean gradient to a
    tangent vector at the current point and following the exponential map,
    with optional momentum maintained via parallel transport. All other
    parameters receive the usual Euclidean SGD update.
    """

    def __init__(
        self,
        params: Iterable[Union[ManifoldParameter, ManifoldTensor]],
        lr: float,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
    ) -> None:
        """Validates hyperparameters (mirroring torch.optim.SGD) and registers them."""
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

        super(RiemannianSGD, self).__init__(params=params, defaults=defaults)

    def step(self) -> None:
        """Performs a single optimization step."""
        with no_grad():
            for group in self.param_groups:
                # Hyperparameters for this parameter group.
                lr = group["lr"]
                momentum = group["momentum"]
                dampening = group["dampening"]
                weight_decay = group["weight_decay"]
                nestorov = group["nesterov"]  # NOTE(review): local name is a typo of "nesterov"

                for param in group["params"]:
                    if isinstance(param, ManifoldParameter):
                        # --- Riemannian branch ---
                        grad = param.grad
                        if grad is None:
                            continue
                        # Wrap the raw gradient so manifold operations can consume it.
                        grad = ManifoldTensor(
                            data=grad, manifold=Euclidean(), man_dim=param.man_dim
                        )
                        state = self.state[param]

                        # Lazily initialize the momentum buffer from the first gradient.
                        if len(state) == 0:
                            if momentum > 0:
                                state["momentum_buffer"] = ManifoldTensor(
                                    data=grad.tensor.clone(),
                                    manifold=Euclidean(),
                                    man_dim=param.man_dim,
                                )

                        manifold: Manifold = param.manifold

                        # Euclidean-style weight decay applied to the raw gradient.
                        grad.tensor.add_(
                            param.tensor, alpha=weight_decay
                        )  # TODO: check if this makes sense
                        grad = manifold.euc_to_tangent(x=param, u=grad)
                        if momentum > 0:
                            momentum_buffer = manifold.euc_to_tangent(
                                x=param, u=state["momentum_buffer"]
                            )
                            momentum_buffer.tensor.mul_(momentum).add_(
                                grad.tensor, alpha=1 - dampening
                            )
                            if nestorov:
                                grad.tensor.add_(momentum_buffer.tensor, alpha=momentum)
                            else:
                                # NOTE(review): grad now aliases momentum_buffer, so the
                                # -lr scaling below also rescales the buffer before it is
                                # transported — confirm this is intended.
                                grad = momentum_buffer

                            grad.tensor = -lr * grad.tensor

                            new_param = manifold.expmap(grad)

                            momentum_buffer = manifold.transp(v=momentum_buffer, y=new_param)

                            # NOTE(review): the transported buffer is never written back to
                            # state["momentum_buffer"] (the next line is disabled), so
                            # momentum does not persist across steps — confirm.
                            # momentum_buffer.tensor.copy_(new_momentum_buffer.tensor)
                            param.tensor.copy_(new_param.tensor)
                        else:
                            # Plain Riemannian SGD: step along the scaled tangent gradient.
                            grad.tensor = -lr * grad.tensor
                            new_param = manifold.expmap(v=grad)
                            param.tensor.copy_(new_param.tensor)

                    else:
                        # --- Euclidean fallback branch (standard SGD) ---
                        grad = param.grad
                        if grad is None:
                            continue
                        state = self.state[param]

                        if len(state) == 0:
                            if momentum > 0:
                                state["momentum_buffer"] = grad.clone()

                        grad.add_(param, alpha=weight_decay)
                        if momentum > 0:
                            momentum_buffer = state["momentum_buffer"]
                            momentum_buffer.mul_(momentum).add_(grad, alpha=1 - dampening)
                            if nestorov:
                                grad = grad.add_(momentum_buffer, alpha=momentum)
                            else:
                                grad = momentum_buffer

                            new_param = param - lr * grad
                            # The buffer was already updated in place above; this copy
                            # is a self-copy (no-op) kept for symmetry.
                            new_momentum_buffer = momentum_buffer

                            momentum_buffer.copy_(new_momentum_buffer)
                            param.copy_(new_param)
                        else:
                            new_param = param - lr * grad
                            param.copy_(new_param)
class ManifoldParameter(ManifoldTensor, Parameter):
    """A ManifoldTensor whose underlying tensor is a torch.nn.Parameter.

    Behaves like a ManifoldTensor but registers with torch.nn.Module's
    parameter machinery so optimizers and state dicts can find it.
    """

    # Torch functions that are allowed to operate on the wrapped Parameter
    # directly (everything else must go through .tensor explicitly).
    _allowed_methods = [
        torch._has_compatible_shallow_copy_type,  # Required for torch.nn.Parameter
        torch.Tensor.copy_,  # Required to load ManifoldParameters state dicts
    ]

    def __new__(cls, data, manifold, man_dim=-1, requires_grad=True):
        # Skip ManifoldTensor in the MRO so Parameter's allocation runs.
        # BUG FIX: man_dim now defaults to -1, matching __init__; previously it
        # was required here, so ManifoldParameter(data, manifold) raised
        # TypeError before __init__ was ever reached.
        return super(ManifoldTensor, cls).__new__(cls)

    # TODO: Create a mixin class containing the methods for this class and for ManifoldTensor
    # to avoid all the boilerplate stuff.
    def __init__(
        self, data, manifold: Manifold, man_dim: int = -1, requires_grad: bool = True
    ) -> None:
        """Creates an instance of ManifoldParameter.

        Args:
            data:
                Tensor (or tensor-like) of points on the manifold; wrapped in a
                torch.nn.Parameter when it is not one already.
            manifold:
                Manifold instance.
            man_dim:
                Dimension along which points are on the manifold. -1 by default.
            requires_grad:
                Whether gradients should be tracked. True by default.

        Raises:
            ValueError: If a negative man_dim lies outside the tensor's rank.
        """
        super(ManifoldParameter, self).__init__(data=data, manifold=manifold)
        # Reuse an existing Parameter; otherwise wrap (converting raw data first).
        if isinstance(data, Parameter):
            self.tensor = data
        elif isinstance(data, Tensor):
            self.tensor = Parameter(data=data, requires_grad=requires_grad)
        else:
            self.tensor = Parameter(data=tensor(data), requires_grad=requires_grad)

        self.manifold = manifold

        # Normalize man_dim to a nonnegative index.
        if man_dim >= 0:
            self.man_dim = man_dim
        else:
            self.man_dim = self.tensor.dim() + man_dim
            if self.man_dim < 0:
                raise ValueError(
                    f"Dimension out of range (expected to be in range of "
                    f"{[-self.tensor.dim() - 1, self.tensor.dim()]}, but got {man_dim})"
                )

    def __getattr__(self, name: str) -> Any:
        """Forwards non-callable attribute access to the wrapped Parameter.

        Methods are deliberately blocked so manifold bookkeeping cannot be
        bypassed silently; call them on .tensor explicitly instead.
        """
        # TODO: go through https://pytorch.org/docs/stable/tensors.html and check which methods
        # are relevant.
        if hasattr(self.tensor, name):
            torch_attribute = getattr(self.tensor, name)

            if callable(torch_attribute):
                raise AttributeError(
                    f"Attempting to apply the torch.nn.Parameter method {name} on a ManifoldParameter."
                    f"Use ManifoldTensor.tensor.{name} instead."
                )
            else:
                return torch_attribute

        else:
            raise AttributeError(
                f"Neither {self.__class__.__name__}, nor torch.Tensor has attribute {name}"
            )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Permits only whitelisted torch functions, unwrapping ManifoldTensors."""
        if func.__class__.__name__ == "method-wrapper" or func in cls._allowed_methods:
            args = [a.tensor if isinstance(a, ManifoldTensor) else a for a in args]
            if kwargs is None:
                kwargs = {}
            # BUG FIX: iterate over items() — iterating the dict itself yields
            # only keys, so any non-empty kwargs raised "cannot unpack" here.
            kwargs = {
                k: (v.tensor if isinstance(v, ManifoldTensor) else v) for k, v in kwargs.items()
            }
            return func(*args, **kwargs)
        # TODO: check if there are torch functions that should be allowed
        raise TypeError(
            f"Attempting to apply the torch function {func} on a ManifoldParameter. "
            f"Use ManifoldParameter.tensor as argument to {func} instead."
        )
20 | 21 | """ 22 | 23 | def __init__( 24 | self, data: Tensor, manifold: Manifold, man_dim: int = -1, requires_grad: bool = False 25 | ) -> None: 26 | """Creates an instance of ManifoldTensor. 27 | 28 | Args: 29 | data: 30 | Torch tensor of points on the manifold. 31 | manifold: 32 | Manifold instance. 33 | man_dim: 34 | Dimension along which points are on the manifold. -1 by default. 35 | 36 | TODO(Philipp, 05/23): Let's get rid of requires_grad if possible. 37 | 38 | """ 39 | self.tensor = data if isinstance(data, Tensor) else tensor(data, requires_grad=True) 40 | self.manifold = manifold 41 | 42 | if man_dim >= 0: 43 | self.man_dim = man_dim 44 | else: 45 | self.man_dim = self.tensor.dim() + man_dim 46 | if self.man_dim < 0: 47 | raise ValueError( 48 | f"Dimension out of range (expected to be in range of " 49 | f"{[-self.tensor.dim() - 1, self.tensor.dim()]}, but got {man_dim})" 50 | ) 51 | 52 | def __getitem__(self, *args): 53 | # Catch some undefined behaviour by checking if args is a single element 54 | if len(args) != 1: 55 | raise ValueError( 56 | f"No support for slicing with these arguments. If you think there should be " 57 | f"support, please consider opening a issue on GitHub describing your case." 
58 | ) 59 | 60 | # Deal with the case where the argument is a long tensor 61 | if isinstance(args[0], Tensor) and args[0].dtype == long: 62 | if self.man_dim == 0: 63 | raise ValueError( 64 | f"Long tensor indexing is only possible when the manifold dimension " 65 | f"is not 0, but the manifold dimension is {self.man_dim}" 66 | ) 67 | new_tensor = self.tensor.__getitem__(*args) 68 | new_man_dim = self.man_dim + args[0].dim() - 1 69 | return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=new_man_dim) 70 | 71 | # Convert the args to a list and replace Ellipsis by the correct number of full slices 72 | if isinstance(args[0], int): 73 | arg_list = [args[0]] 74 | else: 75 | arg_list = list(args[0]) 76 | 77 | if Ellipsis in arg_list: 78 | ell_id = arg_list.index(Ellipsis) 79 | colon_repeats = self.dim() - sum(1 for a in arg_list if a is not None) + 1 80 | arg_list[ell_id : ell_id + 1] = colon_repeats * [slice(None, None, None)] 81 | 82 | new_tensor = self.tensor.__getitem__(*args) 83 | output_man_dim = self.man_dim 84 | counter = self.man_dim + 1 85 | 86 | # Compute output manifold dimension 87 | for arg in arg_list: 88 | # None values add a dimension 89 | if arg is None: 90 | output_man_dim += 1 91 | continue 92 | # Integers remove a dimension 93 | elif isinstance(arg, int): 94 | output_man_dim -= 1 95 | counter -= 1 96 | # Other values leave the dimension intact 97 | else: 98 | counter -= 1 99 | 100 | # When the counter hits 0 and the next term isn't None, we hit the man_dim term 101 | if counter == 0: 102 | if isinstance(arg, int) or isinstance(arg, list): 103 | raise ValueError( 104 | f"Attempting to slice into the manifold dimension, but this is not a " 105 | "valid operation" 106 | ) 107 | # If we get past the man_dim term, the output man_dim doesn't change anymore 108 | break 109 | 110 | return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=output_man_dim) 111 | 112 | def __hash__(self): 113 | """Returns the Python unique 
identifier of the object. 114 | 115 | Note: This is how PyTorch implements hash of tensors. See also: 116 | https://github.com/pytorch/pytorch/issues/2569. 117 | 118 | """ 119 | return id(self) 120 | 121 | def cpu(self) -> ManifoldTensor: 122 | """Returns a copy of this object with self.tensor in CPU memory.""" 123 | new_tensor = self.tensor.cpu() 124 | return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=self.man_dim) 125 | 126 | def cuda(self, device=None) -> ManifoldTensor: 127 | """Returns a copy of this object with self.tensor in CUDA memory.""" 128 | new_tensor = self.tensor.cuda(device) 129 | return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=self.man_dim) 130 | 131 | def dim(self) -> int: 132 | """Returns the number of dimensions of self.tensor.""" 133 | return self.tensor.dim() 134 | 135 | def detach(self) -> ManifoldTensor: 136 | """Returns a new Tensor, detached from the current graph.""" 137 | detached = self.tensor.detach() 138 | return ManifoldTensor(data=detached, manifold=self.manifold, man_dim=self.man_dim) 139 | 140 | def flatten(self, start_dim: int = 0, end_dim: int = -1) -> ManifoldTensor: 141 | """Flattens tensor by reshaping it. If start_dim or end_dim are passed, 142 | only dimensions starting with start_dim and ending with end_dim are flattend. 
143 | 144 | """ 145 | return self.manifold.flatten(self, start_dim=start_dim, end_dim=end_dim) 146 | 147 | @property 148 | def is_cpu(self): 149 | return self.tensor.is_cpu 150 | 151 | @property 152 | def is_cuda(self): 153 | return self.tensor.is_cuda 154 | 155 | def is_floating_point(self) -> bool: 156 | """Returns true if the tensor is of dtype float.""" 157 | return self.tensor.is_floating_point() 158 | 159 | def project(self) -> ManifoldTensor: 160 | """Projects the tensor to the manifold.""" 161 | return self.manifold.project(x=self) 162 | 163 | @property 164 | def shape(self): 165 | """Alias for size().""" 166 | return self.size() 167 | 168 | def size(self, dim: Optional[int] = None): 169 | """Returns the size of self.tensor.""" 170 | if dim is None: 171 | return self.tensor.size() 172 | else: 173 | return self.tensor.size(dim) 174 | 175 | def squeeze(self, dim=None): 176 | """Returns a squeezed version of the manifold tensor.""" 177 | if dim == self.man_dim or (dim is None and self.size(self.man_dim) == 1): 178 | raise ValueError("Attempting to squeeze the manifold dimension") 179 | 180 | if dim is None: 181 | new_tensor = self.tensor.squeeze() 182 | new_man_dim = self.man_dim - sum(self.size(d) == 1 for d in range(self.man_dim)) 183 | else: 184 | new_tensor = self.tensor.squeeze(dim=dim) 185 | new_man_dim = self.man_dim - (1 if dim < self.man_dim else 0) 186 | 187 | return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=new_man_dim) 188 | 189 | def to(self, *args, **kwargs) -> ManifoldTensor: 190 | """Returns a new tensor with the specified device and (optional) dtype.""" 191 | new_tensor = self.tensor.to(*args, **kwargs) 192 | return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=self.man_dim) 193 | 194 | def transpose(self, dim0: int, dim1: int) -> ManifoldTensor: 195 | """Returns a transposed version of the manifold tensor. The given dimensions 196 | dim0 and dim1 are swapped. 
197 | """ 198 | if self.man_dim == dim0: 199 | new_man_dim = dim1 200 | elif self.man_dim == dim1: 201 | new_man_dim = dim0 202 | new_tensor = self.tensor.transpose(dim0, dim1) 203 | return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=new_man_dim) 204 | 205 | @classmethod 206 | def __torch_function__(cls, func, types, args=(), kwargs=None): 207 | # TODO: check if there are torch functions that should be allowed 208 | raise TypeError( 209 | f"Attempting to apply the torch function {func} on a ManifoldTensor. " 210 | f"Use ManifoldTensor.tensor as argument to {func} instead." 211 | ) 212 | 213 | def unsqueeze(self, dim: int) -> ManifoldTensor: 214 | """Returns a new manifold tensor with a dimension of size one inserted at the specified position.""" 215 | new_tensor = self.tensor.unsqueeze(dim=dim) 216 | new_man_dim = self.man_dim + (1 if dim <= self.man_dim else 0) 217 | return ManifoldTensor(data=new_tensor, manifold=self.manifold, man_dim=new_man_dim) 218 | -------------------------------------------------------------------------------- /hypll/tensors/tangent_tensor.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from torch import Tensor, broadcast_shapes, tensor 4 | 5 | from hypll.manifolds import Manifold 6 | from hypll.tensors.manifold_tensor import ManifoldTensor 7 | 8 | 9 | class TangentTensor: 10 | def __init__( 11 | self, 12 | data, 13 | manifold_points: Optional[ManifoldTensor] = None, 14 | manifold: Optional[Manifold] = None, 15 | man_dim: int = -1, 16 | requires_grad: bool = False, 17 | ) -> None: 18 | # Create tangent vector tensor with correct device and dtype 19 | if isinstance(data, Tensor): 20 | self.tensor = data 21 | else: 22 | self.tensor = tensor(data, requires_grad=requires_grad) 23 | 24 | # Store manifold dimension as a nonnegative integer 25 | if man_dim >= 0: 26 | self.man_dim = man_dim 27 | else: 28 | self.man_dim = self.tensor.dim() + 
man_dim 29 | if self.man_dim < 0: 30 | raise ValueError( 31 | f"Dimension out of range (expected to be in range of " 32 | f"{[-self.tensor.dim() - 1, self.tensor.dim()]}, but got {man_dim})" 33 | ) 34 | 35 | if manifold_points is not None: 36 | # Check if the manifold points and tangent vectors are broadcastable together 37 | try: 38 | broadcasted_size = broadcast_shapes(self.tensor.size(), manifold_points.size()) 39 | except RuntimeError: 40 | raise ValueError( 41 | f"The shapes of the manifold points tensor {manifold_points.size()} and " 42 | f"the tangent vector tensor {self.tensor.size()} are not broadcastable " 43 | f"togther." 44 | ) 45 | 46 | # Check if the manifold dimensions match after broadcasting 47 | dim = len(broadcasted_size) 48 | broadcast_man_dims = [ 49 | manifold_points.man_dim + dim - manifold_points.tensor.dim(), 50 | self.man_dim + dim - self.tensor.dim(), 51 | ] 52 | if broadcast_man_dims[0] != broadcast_man_dims[1]: 53 | raise ValueError( 54 | f"After broadcasting the manifold points with the tangent vectors, the " 55 | f"manifold dimension computed from the manifold points should match the " 56 | f"manifold dimension computed from the supplied man_dim, but these are" 57 | f"{broadcast_man_dims}, respectively." 58 | ) 59 | 60 | # Check if the supplied manifolds match 61 | if manifold_points is not None and manifold is not None: 62 | if manifold_points.manifold != manifold: 63 | raise ValueError( 64 | f"The manifold of the manifold_points and the provided manifold should match, " 65 | f"but are {manifold_points.manifold} and {manifold}, respectively." 66 | ) 67 | 68 | self.manifold_points = manifold_points 69 | self.manifold = manifold or manifold_points.manifold 70 | 71 | def __getattr__(self, name: str) -> Any: 72 | # TODO: go through https://pytorch.org/docs/stable/tensors.html and check which methods 73 | # are relevant. 
74 | if hasattr(self.tensor, name): 75 | torch_attribute = getattr(self.tensor, name) 76 | 77 | if callable(torch_attribute): 78 | raise AttributeError( 79 | f"Attempting to apply the torch.Tensor method {name} on a TangentTensor." 80 | f"Use TangentTensor.tensor.{name} or TangentTensor.manifold_points.tensor " 81 | f"instead." 82 | ) 83 | else: 84 | return torch_attribute 85 | 86 | else: 87 | raise AttributeError( 88 | f"Neither {self.__class__.__name__}, nor torch.Tensor has attribute {name}" 89 | ) 90 | 91 | def __hash__(self): 92 | """Returns the Python unique identifier of the object. 93 | 94 | Note: This is how PyTorch implements hash of tensors. See also: 95 | https://github.com/pytorch/pytorch/issues/2569. 96 | 97 | """ 98 | return id(self) 99 | 100 | def cuda(self, device=None): 101 | new_tensor = self.tensor.cuda(device) 102 | new_manifold_points = self.manifold_points.cuda(device) 103 | return TangentTensor( 104 | data=new_tensor, 105 | manifold_points=new_manifold_points, 106 | manifold=self.manifold, 107 | man_dim=self.man_dim, 108 | ) 109 | 110 | def cpu(self): 111 | new_tensor = self.tensor.cpu() 112 | new_manifold_points = self.manifold_points.cpu() 113 | return TangentTensor( 114 | data=new_tensor, 115 | manifold_points=new_manifold_points, 116 | manifold=self.manifold, 117 | man_dim=self.man_dim, 118 | ) 119 | 120 | def to(self, *args, **kwargs): 121 | new_tensor = self.tensor.to(*args, **kwargs) 122 | new_manifold_points = self.manifold_points(*args, **kwargs) 123 | return TangentTensor( 124 | data=new_tensor, 125 | manifold_points=new_manifold_points, 126 | manifold=self.manifold, 127 | man_dim=self.man_dim, 128 | ) 129 | 130 | def size(self, dim: Optional[int] = None): 131 | if self.manifold_points is None: 132 | manifold_points_size = None 133 | manifold_points_size = ( 134 | self.manifold_points.size() if self.manifold_points is not None else () 135 | ) 136 | broadcasted_size = broadcast_shapes(self.tensor.size(), manifold_points_size) 137 | 
if dim is None: 138 | return broadcasted_size 139 | else: 140 | return broadcasted_size[dim] 141 | 142 | @property 143 | def broadcasted_man_dim(self): 144 | return self.man_dim + self.dim() - self.tensor.dim() 145 | 146 | def dim(self): 147 | return len(self.size()) 148 | 149 | @classmethod 150 | def __torch_function__(cls, func, types, args=(), kwargs=None): 151 | # TODO: check if there are torch functions that should be allowed 152 | raise TypeError( 153 | f"Attempting to apply the torch function {func} on a TangentTensor." 154 | f"Use TangentTensor.tensor as argument to {func} instead." 155 | ) 156 | -------------------------------------------------------------------------------- /hypll/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxvanspengler/hyperbolic_learning_library/62ea7882848d878459eba4100aeb6b8ad64cc38f/hypll/utils/__init__.py -------------------------------------------------------------------------------- /hypll/utils/layer_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from torch.nn import Module 4 | 5 | from hypll.manifolds import Manifold 6 | from hypll.tensors import ManifoldTensor 7 | 8 | 9 | def check_if_manifolds_match(layer: Module, input: ManifoldTensor) -> None: 10 | if layer.manifold != input.manifold: 11 | raise ValueError( 12 | f"Manifold of {layer.__class__.__name__} layer is {layer.manifold} " 13 | f"but input has manifold {input.manifold}" 14 | ) 15 | 16 | 17 | def check_if_man_dims_match(layer: Module, man_dim: int, input: ManifoldTensor) -> None: 18 | if man_dim < 0: 19 | new_man_dim = input.dim() + man_dim 20 | else: 21 | new_man_dim = man_dim 22 | 23 | if input.man_dim != new_man_dim: 24 | raise ValueError( 25 | f"Layer of type {layer.__class__.__name__} expects the manifold dimension to be {man_dim}, " 26 | f"but input has manifold dimension {input.man_dim}" 
27 | ) 28 | 29 | 30 | def op_in_tangent_space(op: Callable, manifold: Manifold, input: ManifoldTensor) -> ManifoldTensor: 31 | input = manifold.logmap(x=None, y=input) 32 | input.tensor = op(input.tensor) 33 | return manifold.expmap(input) 34 | -------------------------------------------------------------------------------- /hypll/utils/math.py: -------------------------------------------------------------------------------- 1 | from math import exp, lgamma 2 | from typing import Union 3 | 4 | import torch 5 | 6 | 7 | def beta_func( 8 | a: Union[float, torch.Tensor], b: Union[float, torch.Tensor] 9 | ) -> Union[float, torch.Tensor]: 10 | if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor): 11 | a = torch.as_tensor(a) if not isinstance(a, torch.Tensor) else a 12 | b = torch.as_tensor(b) if not isinstance(b, torch.Tensor) else b 13 | return torch.exp(torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b)) 14 | return exp(lgamma(a) + lgamma(b) - lgamma(a + b)) 15 | -------------------------------------------------------------------------------- /hypll/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | from torch import broadcast_shapes, equal 4 | 5 | from hypll.tensors import ManifoldTensor, TangentTensor 6 | 7 | 8 | def check_dims_with_broadcasting(*args: ManifoldTensor) -> int: 9 | # Check if shapes can be broadcasted together 10 | shapes = [a.size() for a in args] 11 | try: 12 | broadcasted_shape = broadcast_shapes(*shapes) 13 | except RuntimeError: 14 | raise ValueError(f"Shapes of inputs were {shapes} and cannot be broadcasted together") 15 | 16 | # Find the manifold dimensions after broadcasting 17 | max_dim = len(broadcasted_shape) 18 | man_dims = [] 19 | for a in args: 20 | if isinstance(a, ManifoldTensor): 21 | man_dims.append(a.man_dim + (max_dim - a.dim())) 22 | elif isinstance(a, TangentTensor): 23 | man_dims.append(a.broadcasted_man_dim + 
(max_dim - a.dim())) 24 | 25 | for md in man_dims[1:]: 26 | if man_dims[0] != md: 27 | raise ValueError("Manifold dimensions of inputs after broadcasting do not match.") 28 | 29 | return man_dims[0] 30 | 31 | 32 | def check_tangent_tensor_positions(*args: TangentTensor) -> None: 33 | manifold_points = [a.manifold_points for a in args] 34 | if any([mp is None for mp in manifold_points]): 35 | if not all(mp is None for mp in manifold_points): 36 | raise ValueError(f"Some but not all tangent tensors are located at the origin") 37 | 38 | else: 39 | broadcasted_shape = broadcast_shapes(*[a.size() for a in args]) 40 | broadcasted_manifold_points = [ 41 | mp.tensor.broadcast_to(broadcasted_shape) for mp in manifold_points 42 | ] 43 | for bmp in broadcasted_manifold_points[1:]: 44 | if not equal(broadcasted_manifold_points[0], bmp): 45 | raise ValueError( 46 | f"Tangent tensors are positioned at the different points on the manifold" 47 | ) 48 | 49 | 50 | def check_if_man_dims_match( 51 | manifold_tensors: Union[Tuple[ManifoldTensor, ...], List[ManifoldTensor]] 52 | ) -> None: 53 | iterator = iter(manifold_tensors) 54 | 55 | try: 56 | first_item = next(iterator) 57 | except StopIteration: 58 | return 59 | 60 | for i, x in enumerate(iterator): 61 | if x.man_dim != first_item.man_dim: 62 | raise ValueError( 63 | f"Manifold dimensions of inputs do not match. " 64 | f"Input at index [0] has manifold dimension {first_item.man_dim} " 65 | f"but input at index [{i + 1}] has manifold dimension {x.man_dim}." 
66 | ) 67 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "hypll" 3 | version = "0.1.2" 4 | description = "A framework for hyperbolic learning in PyTorch" 5 | authors = ["Max van Spengler", "Philipp Wirth", "Pascal Mettes"] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.10" 9 | torch = "*" 10 | # Optional dependencies for building the docs. 11 | # Install with poetry install --extras "docs" 12 | sphinx-copybutton = { version = "^0.5.2", optional = true } 13 | sphinx = { version = "^7.2.6", optional = true } 14 | sphinx-gallery = { version = "^0.15.0", optional = true } 15 | sphinx-tabs = { version = "^3.4.5", optional = true } 16 | torchvision = { version = "^0.17.2", optional = true } 17 | matplotlib = { version = "^3.8.4", optional = true } 18 | networkx = { version = "^3.2.1", optional = true } 19 | 20 | [tool.poetry.dev-dependencies] 21 | black = "*" 22 | isort = "*" 23 | pytest = "*" 24 | pytest-mock = "*" 25 | 26 | [tool.poetry.extras] 27 | docs = [ 28 | "sphinx-copybutton", 29 | "sphinx", 30 | "sphinx-gallery", 31 | "sphinx-tabs", 32 | "torchvision", 33 | "matplotlib", 34 | "networkx", 35 | "timm", 36 | "fastai", 37 | ] 38 | 39 | [build-system] 40 | requires = ["poetry-core>=1.0.0"] 41 | build-backend = "poetry.core.masonry.api" 42 | 43 | [tool.black] 44 | line-length = 100 45 | 46 | [tool.isort] 47 | profile = "black" 48 | -------------------------------------------------------------------------------- /tests/manifolds/euclidean/test_euclidean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from hypll.manifolds.euclidean import Euclidean 4 | from hypll.tensors import ManifoldTensor 5 | 6 | 7 | def test_flatten__man_dim_equals_start_dim() -> None: 8 | tensor = torch.randn(2, 2, 2, 2) 9 | manifold = Euclidean() 10 | manifold_tensor = 
def test_flatten__man_dim_equals_start_dim__set_end_dim() -> None:
    """flatten with an explicit end_dim keeps man_dim when man_dim == start_dim."""
    tensor = torch.randn(2, 2, 2, 2)
    manifold = Euclidean()
    manifold_tensor = ManifoldTensor(
        data=tensor,
        manifold=manifold,
        man_dim=1,
        requires_grad=True,
    )
    flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=2)
    assert flattened.shape == (2, 4, 2)
    assert flattened.man_dim == 1


def test_flatten__man_dim_larger_start_dim() -> None:
    """A man_dim inside the flattened range collapses onto start_dim."""
    tensor = torch.randn(2, 2, 2, 2)
    manifold = Euclidean()
    manifold_tensor = ManifoldTensor(
        data=tensor,
        manifold=manifold,
        man_dim=2,
        requires_grad=True,
    )
    flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=-1)
    assert flattened.shape == (2, 8)
    assert flattened.man_dim == 1


def test_flatten__man_dim_larger_start_dim__set_end_dim() -> None:
    """Same as above but with a bounded end_dim."""
    tensor = torch.randn(2, 2, 2, 2)
    manifold = Euclidean()
    manifold_tensor = ManifoldTensor(
        data=tensor,
        manifold=manifold,
        man_dim=2,
        requires_grad=True,
    )
    flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=2)
    assert flattened.shape == (2, 4, 2)
    assert flattened.man_dim == 1


def test_flatten__man_dim_smaller_start_dim() -> None:
    """A man_dim before the flattened range is left untouched."""
    tensor = torch.randn(2, 2, 2, 2)
    manifold = Euclidean()
    manifold_tensor = ManifoldTensor(
        data=tensor,
        manifold=manifold,
        man_dim=0,
        requires_grad=True,
    )
    flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=-1)
    assert flattened.shape == (2, 8)
    assert flattened.man_dim == 0


def test_flatten__man_dim_larger_end_dim() -> None:
    """A man_dim after the flattened range shifts down by the collapsed dims."""
    tensor = torch.randn(2, 2, 2, 2)
    manifold = Euclidean()
    manifold_tensor = ManifoldTensor(
        data=tensor,
        manifold=manifold,
        man_dim=2,
        requires_grad=True,
    )
    flattened = manifold.flatten(manifold_tensor, start_dim=0, end_dim=1)
    assert flattened.shape == (4, 2, 2)
    assert flattened.man_dim == 1


def test_cdist__correct_dist():
    """cdist entries must equal pairwise dist() values."""
    B, P, R, M = 2, 3, 4, 8
    manifold = Euclidean()
    mt1 = ManifoldTensor(torch.randn(B, P, M), manifold=manifold)
    mt2 = ManifoldTensor(torch.randn(B, R, M), manifold=manifold)
    dist_matrix = manifold.cdist(mt1, mt2)
    for b in range(B):
        for p in range(P):
            for r in range(R):
                assert torch.isclose(
                    dist_matrix[b, p, r], manifold.dist(mt1[b, p], mt2[b, r]), equal_nan=True
                )


def test_cdist__correct_dims():
    """cdist of (B, P, M) and (B, R, M) batches must be (B, P, R)."""
    B, P, R, M = 2, 3, 4, 8
    manifold = Euclidean()
    mt1 = ManifoldTensor(torch.randn(B, P, M), manifold=manifold)
    mt2 = ManifoldTensor(torch.randn(B, R, M), manifold=manifold)
    dist_matrix = manifold.cdist(mt1, mt2)
    assert dist_matrix.shape == (B, P, R)


def test_cat__correct_dims():
    """cat must grow exactly the concatenation dimension."""
    N, D1, D2 = 10, 2, 3
    manifold = Euclidean()
    manifold_tensors = [ManifoldTensor(torch.randn(D1, D2), manifold=manifold) for _ in range(N)]
    cat_0 = manifold.cat(manifold_tensors, dim=0)
    cat_1 = manifold.cat(manifold_tensors, dim=1)
    assert cat_0.shape == (D1 * N, D2)
    assert cat_1.shape == (D1, D2 * N)


def test_cat__correct_man_dim():
    """cat must preserve the shared manifold dimension of its inputs."""
    N, D1, D2 = 10, 2, 3
    manifold = Euclidean()
    manifold_tensors = [
        ManifoldTensor(torch.randn(D1, D2), manifold=manifold, man_dim=1) for _ in range(N)
    ]
    cat_0 = manifold.cat(manifold_tensors, dim=0)
    cat_1 = manifold.cat(manifold_tensors, dim=1)
    assert cat_0.man_dim == cat_1.man_dim == 1
-------------------------------------------------------------------------------- /tests/manifolds/poincare_ball/test_curvature.py: -------------------------------------------------------------------------------- 1 | from pytest_mock import MockerFixture 2 | 3 | from hypll.manifolds.poincare_ball import Curvature 4 | 5 | 6 | def test_curvature(mocker: MockerFixture) -> None: 7 | mocked_constraining_strategy = mocker.MagicMock() 8 | mocked_constraining_strategy.return_value = 1.33 # dummy value 9 | curvature = Curvature(value=1.0, constraining_strategy=mocked_constraining_strategy) 10 | assert curvature() == 1.33 11 | mocked_constraining_strategy.assert_called_once_with(1.0) 12 | -------------------------------------------------------------------------------- /tests/manifolds/poincare_ball/test_poincare_ball.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall 4 | from hypll.tensors import ManifoldTensor 5 | 6 | 7 | def test_flatten__man_dim_equals_start_dim() -> None: 8 | tensor = torch.randn(2, 2, 2, 2) 9 | manifold = PoincareBall(c=Curvature()) 10 | manifold_tensor = ManifoldTensor( 11 | data=tensor, 12 | manifold=manifold, 13 | man_dim=1, 14 | requires_grad=True, 15 | ) 16 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=-1) 17 | assert flattened.shape == (2, 8) 18 | assert flattened.man_dim == 1 19 | 20 | 21 | def test_flatten__man_dim_equals_start_dim__set_end_dim() -> None: 22 | tensor = torch.randn(2, 2, 2, 2) 23 | manifold = PoincareBall(c=Curvature()) 24 | manifold_tensor = ManifoldTensor( 25 | data=tensor, 26 | manifold=manifold, 27 | man_dim=1, 28 | requires_grad=True, 29 | ) 30 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=2) 31 | assert flattened.shape == (2, 4, 2) 32 | assert flattened.man_dim == 1 33 | 34 | 35 | def test_flatten__man_dim_larger_start_dim() -> None: 36 | tensor = 
torch.randn(2, 2, 2, 2) 37 | manifold = PoincareBall(c=Curvature()) 38 | manifold_tensor = ManifoldTensor( 39 | data=tensor, 40 | manifold=manifold, 41 | man_dim=2, 42 | requires_grad=True, 43 | ) 44 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=-1) 45 | assert flattened.shape == (2, 8) 46 | assert flattened.man_dim == 1 47 | 48 | 49 | def test_flatten__man_dim_larger_start_dim__set_end_dim() -> None: 50 | tensor = torch.randn(2, 2, 2, 2) 51 | manifold = PoincareBall(c=Curvature()) 52 | manifold_tensor = ManifoldTensor( 53 | data=tensor, 54 | manifold=manifold, 55 | man_dim=2, 56 | requires_grad=True, 57 | ) 58 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=2) 59 | assert flattened.shape == (2, 4, 2) 60 | assert flattened.man_dim == 1 61 | 62 | 63 | def test_flatten__man_dim_smaller_start_dim() -> None: 64 | tensor = torch.randn(2, 2, 2, 2) 65 | manifold = PoincareBall(c=Curvature()) 66 | manifold_tensor = ManifoldTensor( 67 | data=tensor, 68 | manifold=manifold, 69 | man_dim=0, 70 | requires_grad=True, 71 | ) 72 | flattened = manifold.flatten(manifold_tensor, start_dim=1, end_dim=-1) 73 | assert flattened.shape == (2, 8) 74 | assert flattened.man_dim == 0 75 | 76 | 77 | def test_flatten__man_dim_larger_end_dim() -> None: 78 | tensor = torch.randn(2, 2, 2, 2) 79 | manifold = PoincareBall(c=Curvature()) 80 | manifold_tensor = ManifoldTensor( 81 | data=tensor, 82 | manifold=manifold, 83 | man_dim=2, 84 | requires_grad=True, 85 | ) 86 | flattened = manifold.flatten(manifold_tensor, start_dim=0, end_dim=1) 87 | assert flattened.shape == (4, 2, 2) 88 | assert flattened.man_dim == 1 89 | 90 | 91 | def test_cdist__correct_dist(): 92 | B, P, R, M = 2, 3, 4, 8 93 | manifold = PoincareBall(c=Curvature()) 94 | mt1 = ManifoldTensor(torch.randn(B, P, M), manifold=manifold) 95 | mt2 = ManifoldTensor(torch.randn(B, R, M), manifold=manifold) 96 | dist_matrix = manifold.cdist(mt1, mt2) 97 | for b in range(B): 98 | for p in range(P): 99 | 
for r in range(R): 100 | assert torch.isclose( 101 | dist_matrix[b, p, r], manifold.dist(mt1[b, p], mt2[b, r]), equal_nan=True 102 | ) 103 | 104 | 105 | def test_cdist__correct_dims(): 106 | B, P, R, M = 2, 3, 4, 8 107 | manifold = PoincareBall(c=Curvature()) 108 | mt1 = ManifoldTensor(torch.randn(B, P, M), manifold=manifold) 109 | mt2 = ManifoldTensor(torch.randn(B, R, M), manifold=manifold) 110 | dist_matrix = manifold.cdist(mt1, mt2) 111 | assert dist_matrix.shape == (B, P, R) 112 | 113 | 114 | def test_cat__correct_dims(): 115 | N, D1, D2 = 10, 2, 3 116 | manifold = PoincareBall(c=Curvature()) 117 | manifold_tensors = [ManifoldTensor(torch.randn(D1, D2), manifold=manifold) for _ in range(N)] 118 | cat_0 = manifold.cat(manifold_tensors, dim=0) 119 | cat_1 = manifold.cat(manifold_tensors, dim=1) 120 | assert cat_0.shape == (D1 * N, D2) 121 | assert cat_1.shape == (D1, D2 * N) 122 | 123 | 124 | def test_cat__correct_man_dim(): 125 | N, D1, D2 = 10, 2, 3 126 | manifold = PoincareBall(c=Curvature()) 127 | manifold_tensors = [ 128 | ManifoldTensor(torch.randn(D1, D2), manifold=manifold, man_dim=1) for _ in range(N) 129 | ] 130 | cat_0 = manifold.cat(manifold_tensors, dim=0) 131 | cat_1 = manifold.cat(manifold_tensors, dim=1) 132 | assert cat_0.man_dim == cat_1.man_dim == 1 133 | 134 | 135 | def test_cat__beta_concatenation_correct_norm(): 136 | MD = 64 137 | manifold = PoincareBall(c=Curvature(0.1, constraining_strategy=lambda x: x)) 138 | t1 = torch.randn(MD) 139 | t2 = torch.randn(MD) 140 | mt1 = ManifoldTensor(t1 / t1.norm(), manifold=manifold) 141 | mt2 = ManifoldTensor(t2 / t2.norm(), manifold=manifold) 142 | cat = manifold.cat([mt1, mt2]) 143 | assert torch.isclose(cat.tensor.norm(), torch.as_tensor(1.0), atol=1e-2) 144 | -------------------------------------------------------------------------------- /tests/nn/test_change_manifold.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from 
hypll.manifolds.euclidean import Euclidean 4 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall 5 | from hypll.nn import ChangeManifold 6 | from hypll.tensors import ManifoldTensor, TangentTensor 7 | 8 | 9 | def test_change_manifold__euclidean_to_euclidean() -> None: 10 | # Define inputs. 11 | manifold = Euclidean() 12 | inputs = ManifoldTensor(data=torch.randn(10, 2), manifold=manifold, man_dim=1) 13 | # Apply change manifold. 14 | change_manifold = ChangeManifold(target_manifold=Euclidean()) 15 | outputs = change_manifold(inputs) 16 | # Assert outputs are correct. 17 | assert isinstance(outputs, ManifoldTensor) 18 | assert outputs.shape == inputs.shape 19 | assert change_manifold.target_manifold == outputs.manifold 20 | assert outputs.man_dim == 1 21 | assert torch.allclose(inputs.tensor, outputs.tensor) 22 | 23 | 24 | def test_change_manifold__euclidean_to_poincare_ball() -> None: 25 | # Define inputs. 26 | manifold = Euclidean() 27 | inputs = ManifoldTensor(data=torch.randn(10, 2), manifold=manifold, man_dim=1) 28 | # Apply change manifold. 29 | change_manifold = ChangeManifold( 30 | target_manifold=PoincareBall(c=Curvature()), 31 | ) 32 | outputs = change_manifold(inputs) 33 | # Assert outputs are correct. 34 | assert isinstance(outputs, ManifoldTensor) 35 | assert outputs.shape == inputs.shape 36 | assert change_manifold.target_manifold == outputs.manifold 37 | assert outputs.man_dim == 1 38 | 39 | 40 | def test_change_manifold__poincare_ball_to_euclidean() -> None: 41 | # Define inputs. 42 | manifold = PoincareBall(c=Curvature(0.1)) 43 | tangents = TangentTensor(data=torch.randn(10, 2), manifold=manifold, man_dim=1) 44 | inputs = manifold.expmap(tangents) 45 | # Apply change manifold. 46 | change_manifold = ChangeManifold(target_manifold=Euclidean()) 47 | outputs = change_manifold(inputs) 48 | # Assert outputs are correct. 
49 | assert isinstance(outputs, ManifoldTensor) 50 | assert outputs.shape == inputs.shape 51 | assert change_manifold.target_manifold == outputs.manifold 52 | assert outputs.man_dim == 1 53 | 54 | 55 | def test_change_manifold__poincare_ball_to_poincare_ball() -> None: 56 | # Define inputs. 57 | manifold = PoincareBall(c=Curvature(0.1)) 58 | tangents = TangentTensor(data=torch.randn(10, 2), manifold=manifold, man_dim=1) 59 | inputs = manifold.expmap(tangents) 60 | # Apply change manifold. 61 | change_manifold = ChangeManifold( 62 | target_manifold=PoincareBall(c=Curvature(1.0)), 63 | ) 64 | outputs = change_manifold(inputs) 65 | # Assert outputs are correct. 66 | assert isinstance(outputs, ManifoldTensor) 67 | assert outputs.shape == inputs.shape 68 | assert change_manifold.target_manifold == outputs.manifold 69 | assert outputs.man_dim == 1 70 | -------------------------------------------------------------------------------- /tests/nn/test_convolution.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from hypll.nn.modules import convolution 4 | 5 | 6 | def test__output_side_length() -> None: 7 | assert ( 8 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=1, padding=0) 9 | == 30 10 | ) 11 | assert ( 12 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=2, padding=0) 13 | == 15 14 | ) 15 | 16 | 17 | def test__output_side_length__padding() -> None: 18 | assert ( 19 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=1, padding=1) 20 | == 32 21 | ) 22 | assert ( 23 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=2, padding=1) 24 | == 16 25 | ) 26 | # Reproduce https://github.com/maxvanspengler/hyperbolic_learning_library/issues/33 27 | assert ( 28 | convolution._output_side_length(input_side_length=224, kernel_size=11, stride=4, padding=2) 29 | == 55 30 | ) 31 | 32 | 33 | def 
test__output_side_length__raises() -> None: 34 | with pytest.raises(RuntimeError) as e: 35 | convolution._output_side_length(input_side_length=32, kernel_size=33, stride=1, padding=0) 36 | with pytest.raises(RuntimeError) as e: 37 | convolution._output_side_length(input_side_length=32, kernel_size=3, stride=33, padding=0) 38 | -------------------------------------------------------------------------------- /tests/nn/test_flatten.py: -------------------------------------------------------------------------------- 1 | from pytest_mock import MockerFixture 2 | 3 | from hypll.nn import HFlatten 4 | 5 | 6 | def test_hflatten(mocker: MockerFixture) -> None: 7 | hflatten = HFlatten(start_dim=1, end_dim=-1) 8 | mocked_manifold_tensor = mocker.MagicMock() 9 | flattened = hflatten(mocked_manifold_tensor) 10 | mocked_manifold_tensor.flatten.assert_called_once_with(start_dim=1, end_dim=-1) 11 | -------------------------------------------------------------------------------- /tests/test_manifold_tensor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall 5 | from hypll.tensors import ManifoldTensor 6 | 7 | 8 | @pytest.fixture 9 | def manifold_tensor() -> ManifoldTensor: 10 | return ManifoldTensor( 11 | data=[ 12 | [1.0, 2.0], 13 | [3.0, 4.0], 14 | ], 15 | manifold=PoincareBall(c=Curvature()), 16 | man_dim=-1, 17 | requires_grad=True, 18 | ) 19 | 20 | 21 | def test_attributes(manifold_tensor: ManifoldTensor): 22 | # Check if the standard attributes are set correctly 23 | # TODO: fix this once __eq__ has been implemented on manifolds 24 | assert isinstance(manifold_tensor.manifold, PoincareBall) 25 | assert manifold_tensor.man_dim == 1 26 | # Check if non-callable attributes are taken from tensor attribute 27 | assert manifold_tensor.is_cpu 28 | 29 | 30 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") 31 | def 
test_device_methods(manifold_tensor: ManifoldTensor): 32 | # Check if we can move the manifold tensor to the gpu while keeping it intact 33 | manifold_tensor = manifold_tensor.cuda() 34 | assert isinstance(manifold_tensor, ManifoldTensor) 35 | assert manifold_tensor.is_cuda 36 | assert isinstance(manifold_tensor.manifold, PoincareBall) 37 | assert manifold_tensor.man_dim == 1 38 | 39 | # And move it back to the cpu 40 | manifold_tensor = manifold_tensor.cpu() 41 | assert isinstance(manifold_tensor, ManifoldTensor) 42 | assert manifold_tensor.is_cpu 43 | assert isinstance(manifold_tensor.manifold, PoincareBall) 44 | assert manifold_tensor.man_dim == 1 45 | 46 | 47 | def test_slicing(manifold_tensor: ManifoldTensor): 48 | # First test slicing with the usual numpy slicing 49 | manifold_tensor = ManifoldTensor( 50 | data=torch.ones(2, 3, 4, 5, 6, 7, 8, 9), 51 | manifold=PoincareBall(Curvature()), 52 | man_dim=-2, 53 | ) 54 | 55 | # Single index 56 | single_index = manifold_tensor[1] 57 | assert list(single_index.size()) == [3, 4, 5, 6, 7, 8, 9] 58 | assert single_index.man_dim == 5 59 | 60 | # More complicated list of slicing arguments 61 | sliced_tensor = manifold_tensor[1, None, [0, 2], ..., None, 2:5, :, 1] 62 | # Explanation of the output size: the dimension of size 2 disappears because of the integer 63 | # index. Then, a dim of size 1 is added by None. The size 3 dimension reduces to 2 because 64 | # of the list of indices. The Ellipsis skips the dimensions of sizes 4, 5 and 6. Then another 65 | # None leads to an insertion of a dimension of size 1. The dimension with size 7 is reduced 66 | # to 3 because of the 2:5 slice. The manifold dimension of size 8 is left alone and the last 67 | # dimension is removed because of an integer index. 
68 | assert list(sliced_tensor.size()) == [1, 2, 4, 5, 6, 1, 3, 8] 69 | assert sliced_tensor.man_dim == 7 70 | 71 | # Now we try to slice into the manifold dimension, which should raise an error 72 | with pytest.raises(ValueError): 73 | manifold_tensor[1, None, [0, 2], ..., None, 2:5, 5, 1] 74 | 75 | # Next, we try long tensor indexing, which is used in embeddings 76 | embedding_manifold_tensor = ManifoldTensor( 77 | data=torch.ones(10, 3), 78 | manifold=PoincareBall(Curvature()), 79 | man_dim=-1, 80 | ) 81 | indices = torch.Tensor( 82 | [ 83 | [1, 2], 84 | [3, 4], 85 | ] 86 | ).long() 87 | embedding_selection = embedding_manifold_tensor[indices] 88 | assert list(embedding_selection.size()) == [2, 2, 3] 89 | assert embedding_selection.man_dim == 2 90 | 91 | # This should fail if the man_dim is 0 on the embedding tensor though 92 | embedding_manifold_tensor = ManifoldTensor( 93 | data=torch.ones(10, 3), 94 | manifold=PoincareBall(Curvature()), 95 | man_dim=0, 96 | ) 97 | with pytest.raises(ValueError): 98 | embedding_manifold_tensor[indices] 99 | 100 | # Lastly, embeddings with more dimensions should work too 101 | embedding_manifold_tensor = ManifoldTensor( 102 | data=torch.ones(10, 3, 3, 3), 103 | manifold=PoincareBall(Curvature()), 104 | man_dim=2, 105 | ) 106 | embedding_selection = embedding_manifold_tensor[indices] 107 | assert list(embedding_selection.size()) == [2, 2, 3, 3, 3] 108 | assert embedding_selection.man_dim == 3 109 | 110 | 111 | def test_torch_ops(manifold_tensor: ManifoldTensor): 112 | # We want torch functions to raise an error 113 | with pytest.raises(TypeError): 114 | torch.norm(manifold_tensor) 115 | 116 | # Same for torch.Tensor methods (callable attributes) 117 | with pytest.raises(AttributeError): 118 | manifold_tensor.mean() 119 | -------------------------------------------------------------------------------- /tutorials/README.txt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxvanspengler/hyperbolic_learning_library/62ea7882848d878459eba4100aeb6b8ad64cc38f/tutorials/README.txt -------------------------------------------------------------------------------- /tutorials/cifar10_resnet_tutorial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Training a Poincare ResNet 4 | ========================== 5 | 6 | This is an implementation based on the Poincare Resnet paper, which can be found at: 7 | 8 | - https://arxiv.org/abs/2303.14027 9 | 10 | Due to the complexity of hyperbolic operations we strongly advise to only run this tutorial with a 11 | GPU. 12 | 13 | We will perform the following steps in order: 14 | 15 | 1. Define a hyperbolic manifold 16 | 2. Load and normalize the CIFAR10 training and test datasets using ``torchvision`` 17 | 3. Define a Poincare ResNet 18 | 4. Define a loss function and optimizer 19 | 5. Train the network on the training data 20 | 6. Test the network on the test data 21 | 22 | """ 23 | 24 | ############################## 25 | # 0. Grab the available device 26 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 27 | import torch 28 | 29 | device = "cuda" if torch.cuda.is_available() else "cpu" 30 | 31 | 32 | ############################# 33 | # 1. Define the Poincare ball 34 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 35 | 36 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall 37 | 38 | # Making the curvature a learnable parameter is usually suboptimal but can 39 | # make training smoother. An initial curvature of 0.1 has also been shown 40 | # to help during training. 41 | manifold = PoincareBall(c=Curvature(value=0.1, requires_grad=True)) 42 | 43 | 44 | ############################### 45 | # 2. 
Load and normalize CIFAR10 46 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 47 | 48 | import torchvision 49 | import torchvision.transforms as transforms 50 | 51 | ######################################################################## 52 | # .. note:: 53 | # If running on Windows and you get a BrokenPipeError, try setting 54 | # the num_worker of torch.utils.data.DataLoader() to 0. 55 | 56 | transform = transforms.Compose( 57 | [ 58 | transforms.ToTensor(), 59 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 60 | ] 61 | ) 62 | 63 | batch_size = 128 64 | 65 | trainset = torchvision.datasets.CIFAR10( 66 | root="./data", train=True, download=True, transform=transform 67 | ) 68 | trainloader = torch.utils.data.DataLoader( 69 | trainset, batch_size=batch_size, shuffle=True, num_workers=2 70 | ) 71 | 72 | testset = torchvision.datasets.CIFAR10( 73 | root="./data", train=False, download=True, transform=transform 74 | ) 75 | testloader = torch.utils.data.DataLoader( 76 | testset, batch_size=batch_size, shuffle=False, num_workers=2 77 | ) 78 | 79 | 80 | ############################### 81 | # 3. Define a Poincare ResNet 82 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 83 | # This implementation is based on the Poincare ResNet paper, which can 84 | # be found at https://arxiv.org/abs/2303.14027 and which, in turn, is 85 | # based on the original Euclidean implementation described in the paper 86 | # Deep Residual Learning for Image Recognition by He et al. from 2015: 87 | # https://arxiv.org/abs/1512.03385. 
88 | 89 | 90 | from typing import Optional 91 | 92 | from torch import nn 93 | 94 | from hypll import nn as hnn 95 | from hypll.tensors import ManifoldTensor 96 | 97 | 98 | class PoincareResidualBlock(nn.Module): 99 | def __init__( 100 | self, 101 | in_channels: int, 102 | out_channels: int, 103 | manifold: PoincareBall, 104 | stride: int = 1, 105 | downsample: Optional[nn.Sequential] = None, 106 | ): 107 | # We can replace each operation in the usual ResidualBlock by a manifold-agnostic 108 | # operation and supply the PoincareBall object to these operations. 109 | super().__init__() 110 | self.in_channels = in_channels 111 | self.out_channels = out_channels 112 | self.manifold = manifold 113 | self.stride = stride 114 | self.downsample = downsample 115 | 116 | self.conv1 = hnn.HConvolution2d( 117 | in_channels=in_channels, 118 | out_channels=out_channels, 119 | kernel_size=3, 120 | manifold=manifold, 121 | stride=stride, 122 | padding=1, 123 | ) 124 | self.bn1 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold) 125 | self.relu = hnn.HReLU(manifold=self.manifold) 126 | self.conv2 = hnn.HConvolution2d( 127 | in_channels=out_channels, 128 | out_channels=out_channels, 129 | kernel_size=3, 130 | manifold=manifold, 131 | padding=1, 132 | ) 133 | self.bn2 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold) 134 | 135 | def forward(self, x: ManifoldTensor) -> ManifoldTensor: 136 | residual = x 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.conv2(x) 141 | x = self.bn2(x) 142 | 143 | if self.downsample is not None: 144 | residual = self.downsample(residual) 145 | 146 | # We replace the addition operation inside the skip connection by a Mobius addition. 
147 | x = self.manifold.mobius_add(x, residual) 148 | x = self.relu(x) 149 | 150 | return x 151 | 152 | 153 | class PoincareResNet(nn.Module): 154 | def __init__( 155 | self, 156 | channel_sizes: list[int], 157 | group_depths: list[int], 158 | manifold: PoincareBall, 159 | ): 160 | # For the Poincare ResNet itself we again replace each layer by a manifold-agnostic one 161 | # and supply the PoincareBall to each of these. We also replace the ResidualBlocks by 162 | # the manifold-agnostic one defined above. 163 | super().__init__() 164 | self.channel_sizes = channel_sizes 165 | self.group_depths = group_depths 166 | self.manifold = manifold 167 | 168 | self.conv = hnn.HConvolution2d( 169 | in_channels=3, 170 | out_channels=channel_sizes[0], 171 | kernel_size=3, 172 | manifold=manifold, 173 | padding=1, 174 | ) 175 | self.bn = hnn.HBatchNorm2d(features=channel_sizes[0], manifold=manifold) 176 | self.relu = hnn.HReLU(manifold=manifold) 177 | self.group1 = self._make_group( 178 | in_channels=channel_sizes[0], 179 | out_channels=channel_sizes[0], 180 | depth=group_depths[0], 181 | ) 182 | self.group2 = self._make_group( 183 | in_channels=channel_sizes[0], 184 | out_channels=channel_sizes[1], 185 | depth=group_depths[1], 186 | stride=2, 187 | ) 188 | self.group3 = self._make_group( 189 | in_channels=channel_sizes[1], 190 | out_channels=channel_sizes[2], 191 | depth=group_depths[2], 192 | stride=2, 193 | ) 194 | 195 | self.avg_pool = hnn.HAvgPool2d(kernel_size=8, manifold=manifold) 196 | self.fc = hnn.HLinear(in_features=channel_sizes[2], out_features=10, manifold=manifold) 197 | 198 | def forward(self, x: ManifoldTensor) -> ManifoldTensor: 199 | x = self.conv(x) 200 | x = self.bn(x) 201 | x = self.relu(x) 202 | x = self.group1(x) 203 | x = self.group2(x) 204 | x = self.group3(x) 205 | x = self.avg_pool(x) 206 | x = self.fc(x.squeeze()) 207 | return x 208 | 209 | def _make_group( 210 | self, 211 | in_channels: int, 212 | out_channels: int, 213 | depth: int, 214 | stride: 
int = 1, 215 | ) -> nn.Sequential: 216 | if stride == 1: 217 | downsample = None 218 | else: 219 | downsample = hnn.HConvolution2d( 220 | in_channels=in_channels, 221 | out_channels=out_channels, 222 | kernel_size=1, 223 | manifold=self.manifold, 224 | stride=stride, 225 | ) 226 | 227 | layers = [ 228 | PoincareResidualBlock( 229 | in_channels=in_channels, 230 | out_channels=out_channels, 231 | manifold=self.manifold, 232 | stride=stride, 233 | downsample=downsample, 234 | ) 235 | ] 236 | 237 | for _ in range(1, depth): 238 | layers.append( 239 | PoincareResidualBlock( 240 | in_channels=out_channels, 241 | out_channels=out_channels, 242 | manifold=self.manifold, 243 | ) 244 | ) 245 | 246 | return nn.Sequential(*layers) 247 | 248 | 249 | # Now, let's create a thin Poincare ResNet with channel sizes [4, 8, 16] and with a depth of 20 250 | # layers. 251 | net = PoincareResNet( 252 | channel_sizes=[4, 8, 16], 253 | group_depths=[3, 3, 3], 254 | manifold=manifold, 255 | ).to(device) 256 | 257 | 258 | ######################################### 259 | # 4. Define a Loss function and optimizer 260 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 261 | # Let's use a Classification Cross-Entropy loss and RiemannianAdam optimizer. 262 | 263 | criterion = nn.CrossEntropyLoss() 264 | # net.parameters() includes the learnable curvature "c" of the manifold. 265 | from hypll.optim import RiemannianAdam 266 | 267 | optimizer = RiemannianAdam(net.parameters(), lr=0.001) 268 | 269 | 270 | ###################### 271 | # 5. Train the network 272 | # ^^^^^^^^^^^^^^^^^^^^ 273 | # We simply have to loop over our data iterator, project the inputs onto the 274 | # manifold, and feed them to the network and optimize. We will train for a limited 275 | # number of epochs here due to the long training time of this model. 
276 | 277 | from hypll.tensors import TangentTensor 278 | 279 | for epoch in range(2): # Increase this number to at least 100 for good results 280 | running_loss = 0.0 281 | for i, data in enumerate(trainloader, 0): 282 | # get the inputs; data is a list of [inputs, labels] 283 | inputs, labels = data[0].to(device), data[1].to(device) 284 | 285 | # move the inputs to the manifold 286 | tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold) 287 | manifold_inputs = manifold.expmap(tangents) 288 | 289 | # zero the parameter gradients 290 | optimizer.zero_grad() 291 | 292 | # forward + backward + optimize 293 | outputs = net(manifold_inputs) 294 | loss = criterion(outputs.tensor, labels) 295 | loss.backward() 296 | optimizer.step() 297 | 298 | # print statistics 299 | running_loss += loss.item() 300 | print(f"[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}") 301 | running_loss = 0.0 302 | 303 | print("Finished Training") 304 | 305 | 306 | ###################################### 307 | # 6. Test the network on the test data 308 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 309 | # 310 | # Let us look at how the network performs on the whole dataset. 
311 | 312 | correct = 0 313 | total = 0 314 | # since we're not training, we don't need to calculate the gradients for our outputs 315 | with torch.no_grad(): 316 | for data in testloader: 317 | images, labels = data[0].to(device), data[1].to(device) 318 | 319 | # move the images to the manifold 320 | tangents = TangentTensor(data=images, man_dim=1, manifold=manifold) 321 | manifold_images = manifold.expmap(tangents) 322 | 323 | # calculate outputs by running images through the network 324 | outputs = net(manifold_images) 325 | # the class with the highest energy is what we choose as prediction 326 | _, predicted = torch.max(outputs.tensor, 1) 327 | total += labels.size(0) 328 | correct += (predicted == labels).sum().item() 329 | 330 | print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %") 331 | -------------------------------------------------------------------------------- /tutorials/cifar10_tutorial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Training a Hyperbolic Classifier 4 | ================================ 5 | 6 | This is an adaptation of torchvision's tutorial "Training a Classifier" to 7 | hyperbolic space. The original tutorial can be found here: 8 | 9 | - https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html 10 | 11 | Training a Hyperbolic Image Classifier 12 | -------------------------------------- 13 | 14 | We will do the following steps in order: 15 | 16 | 1. Define a hyperbolic manifold 17 | 2. Load and normalize the CIFAR10 training and test datasets using ``torchvision`` 18 | 3. Define a hyperbolic Convolutional Neural Network 19 | 4. Define a loss function and optimizer 20 | 5. Train the network on the training data 21 | 6. Test the network on the test data 22 | 23 | """ 24 | 25 | ######################################################################## 26 | # 1. 
Define a hyperbolic manifold 27 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 28 | # We use the Poincaré ball model for the purposes of this tutorial. 29 | 30 | 31 | from hypll.manifolds.poincare_ball import Curvature, PoincareBall 32 | 33 | # Making the curvature a learnable parameter is usually suboptimal but can 34 | # make training smoother. 35 | manifold = PoincareBall(c=Curvature(requires_grad=True)) 36 | 37 | 38 | ######################################################################## 39 | # 2. Load and normalize CIFAR10 40 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 41 | 42 | import torch 43 | import torchvision 44 | import torchvision.transforms as transforms 45 | 46 | ######################################################################## 47 | # .. note:: 48 | # If running on Windows and you get a BrokenPipeError, try setting 49 | # the num_worker of torch.utils.data.DataLoader() to 0. 50 | 51 | transform = transforms.Compose( 52 | [ 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 55 | ] 56 | ) 57 | 58 | 59 | batch_size = 4 60 | 61 | trainset = torchvision.datasets.CIFAR10( 62 | root="./data", train=True, download=True, transform=transform 63 | ) 64 | trainloader = torch.utils.data.DataLoader( 65 | trainset, batch_size=batch_size, shuffle=True, num_workers=2 66 | ) 67 | 68 | testset = torchvision.datasets.CIFAR10( 69 | root="./data", train=False, download=True, transform=transform 70 | ) 71 | testloader = torch.utils.data.DataLoader( 72 | testset, batch_size=batch_size, shuffle=False, num_workers=2 73 | ) 74 | 75 | classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck") 76 | 77 | 78 | ######################################################################## 79 | # 3. Define a hyperbolic Convolutional Neural Network 80 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 81 | # Let's rebuild the convolutional neural network from torchvision's tutorial 82 | # using hyperbolic modules. 
83 | 84 | from torch import nn 85 | 86 | from hypll import nn as hnn 87 | 88 | 89 | class Net(nn.Module): 90 | def __init__(self): 91 | super().__init__() 92 | self.conv1 = hnn.HConvolution2d( 93 | in_channels=3, out_channels=6, kernel_size=5, manifold=manifold 94 | ) 95 | self.pool = hnn.HMaxPool2d(kernel_size=2, manifold=manifold, stride=2) 96 | self.conv2 = hnn.HConvolution2d( 97 | in_channels=6, out_channels=16, kernel_size=5, manifold=manifold 98 | ) 99 | self.fc1 = hnn.HLinear(in_features=16 * 5 * 5, out_features=120, manifold=manifold) 100 | self.fc2 = hnn.HLinear(in_features=120, out_features=84, manifold=manifold) 101 | self.fc3 = hnn.HLinear(in_features=84, out_features=10, manifold=manifold) 102 | self.relu = hnn.HReLU(manifold=manifold) 103 | 104 | def forward(self, x): 105 | x = self.pool(self.relu(self.conv1(x))) 106 | x = self.pool(self.relu(self.conv2(x))) 107 | x = x.flatten(start_dim=1) 108 | x = self.relu(self.fc1(x)) 109 | x = self.relu(self.fc2(x)) 110 | x = self.fc3(x) 111 | return x 112 | 113 | 114 | net = Net() 115 | 116 | ######################################################################## 117 | # 4. Define a Loss function and optimizer 118 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 119 | # Let's use a Classification Cross-Entropy loss and RiemannianAdam optimizer. 120 | # Adam is preferred because hyperbolic linear layers can sometimes have training 121 | # difficulties early on due to poor initialization. 122 | 123 | criterion = nn.CrossEntropyLoss() 124 | # net.parameters() includes the learnable curvature "c" of the manifold. 125 | from hypll.optim import RiemannianAdam 126 | 127 | optimizer = RiemannianAdam(net.parameters(), lr=0.001) 128 | 129 | 130 | ######################################################################## 131 | # 5. Train the network 132 | # ^^^^^^^^^^^^^^^^^^^^ 133 | # This is when things start to get interesting. 
134 | # We simply have to loop over our data iterator, project the inputs onto the 135 | # manifold, and feed them to the network and optimize. 136 | 137 | from hypll.tensors import TangentTensor 138 | 139 | for epoch in range(2): # loop over the dataset multiple times 140 | running_loss = 0.0 141 | for i, data in enumerate(trainloader, 0): 142 | # get the inputs; data is a list of [inputs, labels] 143 | inputs, labels = data 144 | 145 | # move the inputs to the manifold 146 | tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold) 147 | manifold_inputs = manifold.expmap(tangents) 148 | 149 | # zero the parameter gradients 150 | optimizer.zero_grad() 151 | 152 | # forward + backward + optimize 153 | outputs = net(manifold_inputs) 154 | loss = criterion(outputs.tensor, labels) 155 | loss.backward() 156 | optimizer.step() 157 | 158 | # print statistics 159 | running_loss += loss.item() 160 | if i % 2000 == 1999: # print every 2000 mini-batches 161 | print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") 162 | running_loss = 0.0 163 | 164 | print("Finished Training") 165 | 166 | 167 | ######################################################################## 168 | # Let's quickly save our trained model: 169 | 170 | PATH = "./cifar_net.pth" 171 | torch.save(net.state_dict(), PATH) 172 | 173 | 174 | ######################################################################## 175 | # Next, let's load back in our saved model (note: saving and re-loading the model 176 | # wasn't necessary here, we only did it to illustrate how to do so): 177 | 178 | net = Net() 179 | net.load_state_dict(torch.load(PATH)) 180 | 181 | ######################################################################## 182 | # 6. Test the network on the test data 183 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 184 | # 185 | # Let us look at how the network performs on the whole dataset. 
186 | 187 | correct = 0 188 | total = 0 189 | # since we're not training, we don't need to calculate the gradients for our outputs 190 | with torch.no_grad(): 191 | for data in testloader: 192 | images, labels = data 193 | 194 | # move the images to the manifold 195 | tangents = TangentTensor(data=images, man_dim=1, manifold=manifold) 196 | manifold_images = manifold.expmap(tangents) 197 | 198 | # calculate outputs by running images through the network 199 | outputs = net(manifold_images) 200 | # the class with the highest energy is what we choose as prediction 201 | _, predicted = torch.max(outputs.tensor, 1) 202 | total += labels.size(0) 203 | correct += (predicted == labels).sum().item() 204 | 205 | print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %") 206 | 207 | 208 | ######################################################################## 209 | # That looks way better than chance, which is 10% accuracy (randomly picking 210 | # a class out of 10 classes). 211 | # Seems like the network learnt something. 
#
# Hmmm, what are the classes that performed well, and the classes that did
# not perform well:

# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data

        # move the images to the manifold
        tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
        manifold_images = manifold.expmap(tangents)

        outputs = net(manifold_images)
        _, predictions = torch.max(outputs.tensor, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

# print accuracy for each class; guard against a class that never appears in
# the test split, which would otherwise raise ZeroDivisionError
for classname, correct_count in correct_pred.items():
    if total_pred[classname] == 0:
        print(f"Accuracy for class: {classname:5s} is n/a (no test samples)")
        continue
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")


########################################################################
#
# Training on GPU
# ----------------
# Just like how you transfer a Tensor onto the GPU, you transfer the neural
# net onto the GPU.
#
# Let's first define our device as the first visible cuda device if we have
# CUDA available:
#
# .. code:: python
#
#     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#
#
# Assuming that we are on a CUDA machine, this should print a CUDA device:
#
# .. code:: python
#
#     print(device)
#
#
# The rest of this section assumes that ``device`` is a CUDA device.
266 | # 267 | # Then these methods will recursively go over all modules and convert their 268 | # parameters and buffers to CUDA tensors: 269 | # 270 | # .. code:: python 271 | # 272 | # net.to(device) 273 | # 274 | # 275 | # Remember that you will have to send the inputs and targets at every step 276 | # to the GPU too: 277 | # 278 | # .. code:: python 279 | # 280 | # inputs, labels = data[0].to(device), data[1].to(device) 281 | # 282 | # 283 | # **Goals achieved**: 284 | # 285 | # - Train a small hyperbolic neural network to classify images. 286 | # 287 | -------------------------------------------------------------------------------- /tutorials/hyperbolic_vit_tutorial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Training a Hyperbolic Vision Transformer 4 | ======================================== 5 | 6 | This implementation is based on the "Hyperbolic Vision Transformers: Combining Improvements in Metric Learning" 7 | paper by Aleksandr Ermolov et al., which can be found at: 8 | 9 | - https://arxiv.org/pdf/2203.10833.pdf 10 | 11 | Their implementation can be found at: 12 | 13 | - https://github.com/htdt/hyp_metric 14 | 15 | We will perform the following steps in order: 16 | 17 | 1. Initialize the manifold on which the embeddings will be trained 18 | 2. Load the CUB_200_2011 dataset using fastai 19 | 3. Initialize train and test dataloaders 20 | 4. Define the Hyperbolic Vision Transformer model 21 | 5. Define the pairwise cross-entropy loss function 22 | 6. Define the recall@k evaluation metric 23 | 7. Train the model on the training data 24 | 8. Test the model on the test data 25 | 26 | Please make sure to install the following packages before running this script: 27 | - fastai 28 | - timm 29 | 30 | """ 31 | 32 | ####################################################### 33 | # 0. 
# 0. Set the device and random seed for reproducibility
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import random

import numpy as np
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

seed = 0

# Seed every source of randomness we rely on: torch (CPU + CUDA), numpy,
# and python's built-in random module.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Trade a little speed for deterministic cuDNN kernel selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

####################################################################
# 1. Initialize the manifold on which the embeddings will be trained
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

from hypll.manifolds.euclidean import Euclidean
from hypll.manifolds.poincare_ball import Curvature, PoincareBall

# We can choose between the Poincaré ball model and Euclidean space.
do_hyperbolic = True

if do_hyperbolic:
    # We fix the curvature to 0.1 as in the paper (Section 3.2); the identity
    # constraining strategy and requires_grad=False keep it constant.
    curvature = Curvature(value=0.1, constraining_strategy=lambda x: x, requires_grad=False)
    manifold = PoincareBall(c=curvature)
else:
    manifold = Euclidean()

###############################################
# 2. Load the CUB_200_2011 dataset using fastai
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import pandas as pd
import torchvision.transforms as transforms
from fastai.data.external import URLs, untar_data
from PIL import Image
from torch.utils.data import Dataset

# Download the dataset using fastai.
path = untar_data(URLs.CUB_200_2011) / "CUB_200_2011"


class CUB(Dataset):
    """Handles the downloaded CUB_200_2011 files.

    Parses the metadata text files shipped with the dataset and exposes
    either the train or the test split (depending on ``train``) as
    (transformed image, 0-based class id) pairs.
    """

    def __init__(self, files_path, train, transform):
        self.files_path = files_path
        self.transform = transform

        # image_class_labels.txt holds "<image id> <class id>" per line.
        labels = pd.read_csv(files_path / "image_class_labels.txt", header=None, sep=" ")
        labels.columns = ["id", "target"]
        self.targets = labels["target"].values

        # train_test_split.txt holds "<image id> <is training image>" per line.
        self.train_test = pd.read_csv(files_path / "train_test_split.txt", header=None, sep=" ")
        self.train_test.columns = ["id", "is_train"]

        # images.txt holds "<image id> <relative file name>" per line.
        self.images = pd.read_csv(files_path / "images.txt", header=None, sep=" ")
        self.images.columns = ["id", "name"]

        # Keep only the rows that belong to the requested split.
        mask = self.train_test.is_train.values == int(train)
        self.filenames = self.images.iloc[mask]
        self.targets = self.targets[mask]
        self.num_files = len(self.targets)

    def __len__(self):
        return self.num_files

    def __getitem__(self, index):
        # Class ids in the metadata are 1-based; shift them to 0-based.
        target = self.targets[index] - 1
        image_path = self.files_path / "images" / self.filenames.iloc[index, 1]
        image = Image.open(image_path).convert("RGB")
        return self.transform(image), target


# We use the same resizing and data augmentation as in the paper (Section 3.2).
# Normalization values are taken from the original implementation.

train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            224, scale=(0.9, 1.0), interpolation=transforms.InterpolationMode("bicubic")
        ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
trainset = CUB(path, train=True, transform=train_transform)

test_transform = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode("bicubic")),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
testset = CUB(path, train=False, transform=test_transform)

##########################################
# 3. Initialize train and test dataloaders
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import collections
import multiprocessing

from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler

# We need a dataloader that returns a requested number of images belonging to the same class.
# This will allow us to compute the pairwise cross-entropy loss.


class UniqueClassSampler(Sampler):
    """Custom sampler for creating batches with m_per_class samples per class.

    Each epoch yields m_per_class consecutive indices per class, with the
    class order shuffled deterministically from (seed, epoch).
    """

    def __init__(self, labels, m_per_class, seed):
        self.labels_to_indices = self.get_labels_to_indices(labels)
        self.labels = sorted(list(self.labels_to_indices.keys()))
        self.m_per_class = m_per_class
        self.seed = seed
        self.epoch = 0

    def __len__(self):
        # BUGFIX: this previously read the module-level global ``m_per_class``
        # instead of ``self.m_per_class``; it only worked because both
        # happened to hold the same value.
        return (
            np.max([len(v) for v in self.labels_to_indices.values()])
            * self.m_per_class
            * len(self.labels)
        )

    def set_epoch(self, epoch: int):
        # Update the epoch, so that the random permutation changes.
        self.epoch = epoch

    def get_labels_to_indices(self, labels: list):
        # Create a dictionary mapping each label to the indices in the dataset
        labels_to_indices = collections.defaultdict(list)
        for i, label in enumerate(labels):
            labels_to_indices[label].append(i)
        for k, v in labels_to_indices.items():
            labels_to_indices[k] = np.array(v, dtype=int)
        return labels_to_indices

    def __iter__(self):
        # Generate a random iteration order for the indices.
        # For example, if we have 3 classes (A,B,C) and m_per_class=2,
        # we could get the following indices: [A1, A2, C1, C2, B1, B2].
        # NOTE(review): the per-class sampling below uses the global numpy RNG
        # (np.random.choice), not the seeded torch generator ``g``, so full
        # determinism also relies on np.random.seed having been called.
        idx_list = []
        g = torch.Generator()
        g.manual_seed(self.seed * 10000 + self.epoch)
        idx = torch.randperm(len(self.labels), generator=g).tolist()
        max_indices = np.max([len(v) for v in self.labels_to_indices.values()])
        for i in idx:
            t = self.labels_to_indices[self.labels[i]]
            idx_list.append(np.random.choice(t, size=self.m_per_class * max_indices))
        idx_list = np.stack(idx_list).reshape(len(self.labels), -1, self.m_per_class)
        idx_list = idx_list.transpose(1, 0, 2).reshape(-1).tolist()
        return iter(idx_list)


# First, define how many distinct classes to encounter per batch.
# The dataset has 200 classes in total, but the number we choose depends on the available memory.
n_sampled_classes = 128
# Then, define how many examples per class to encounter per batch.
# This number must be at least 2.
m_per_class = 2
# Finally, the batch size is determined by our two choices above.
batch_size = n_sampled_classes * m_per_class

# Note that this sampler is only used during training.
sampler = UniqueClassSampler(labels=trainset.targets, m_per_class=m_per_class, seed=seed)

trainloader = DataLoader(
    dataset=trainset,
    sampler=sampler,
    batch_size=batch_size,
    pin_memory=True,
    drop_last=True,
)

testloader = DataLoader(
    dataset=testset,
    batch_size=batch_size,
    pin_memory=True,
    drop_last=False,
)

###################################################
# 4. Define the Hyperbolic Vision Transformer model
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

from typing import Union

import timm
import torch.nn.functional as F
from torch import nn

from hypll.tensors import ManifoldTensor, TangentTensor


class HyperbolicViT(nn.Module):
    """ViT backbone + linear projection whose output lives on ``manifold``.

    Args:
        vit_name: timm model name of the ViT backbone.
        vit_dim: output feature dimension of the backbone.
        emb_dim: dimension of the (manifold) embedding space.
        manifold: PoincareBall or Euclidean manifold for the embeddings.
        clip_r: feature-clipping radius (Section 2.4); None disables clipping.
        pretrained: whether to load pretrained backbone weights.
    """

    def __init__(
        self,
        vit_name: str,
        vit_dim: int,
        emb_dim: int,
        manifold: Union[PoincareBall, Euclidean],
        clip_r: float,
        pretrained: bool = True,
    ):
        super().__init__()
        # We use the timm library to load a (pretrained) ViT backbone model.
        self.vit = timm.create_model(vit_name, pretrained=pretrained)
        # We remove the head of the model, since we won't need it.
        self.remove_head()
        self.fc = nn.Linear(vit_dim, emb_dim)
        # We freeze the linear projection for patch embeddings as in the paper (Section 3.2).
        self.freeze_patch_embed()
        self.manifold = manifold
        self.clip_r = clip_r

    def forward(self, x: torch.Tensor) -> ManifoldTensor:
        # We first encode the images using the ViT backbone.
        x = self.vit(x)
        # Some timm models return a tuple (features, aux); keep the features.
        if isinstance(x, tuple):
            x = x[0]
        # Then, we project the features into a lower-dimensional space.
        x = self.fc(x)
        # If we use a Euclidean manifold,
        # we simply normalize the features to lie on the unit sphere.
        if isinstance(self.manifold, Euclidean):
            x = F.normalize(x, p=2, dim=1)
            return ManifoldTensor(data=x, man_dim=1, manifold=self.manifold)
        # If we use a Poincaré ball model,
        # we (optionally) perform feature clipping (Section 2.4),
        if self.clip_r is not None:
            x = self.clip_features(x)
        # and then map the features to the manifold.
        tangents = TangentTensor(data=x, man_dim=1, manifold=self.manifold)
        return self.manifold.expmap(tangents)

    def clip_features(self, x: torch.Tensor) -> torch.Tensor:
        # Rescale feature vectors so their norm is at most clip_r (Section 2.4);
        # the 1e-5 avoids division by zero for all-zero features.
        x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
        fac = torch.minimum(torch.ones_like(x_norm), self.clip_r / x_norm)
        return x * fac

    def remove_head(self):
        # Replace any classification-head submodule with an identity mapping.
        names = set(x[0] for x in self.vit.named_children())
        target = {"head", "fc", "head_dist", "head_drop"}
        for x in names & target:
            self.vit.add_module(x, nn.Identity())

    def freeze_patch_embed(self):
        # Stop gradients for the patch embedding and positional dropout.
        def fr(m):
            for param in m.parameters():
                param.requires_grad = False

        fr(self.vit.patch_embed)
        fr(self.vit.pos_drop)


# Initialize the model using the same hyperparameters as in the paper (Section 3.2).
hvit = HyperbolicViT(
    vit_name="vit_small_patch16_224", vit_dim=384, emb_dim=128, manifold=manifold, clip_r=2.3
).to(device)

####################################################
# 5. Define the pairwise cross-entropy loss function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# torch.nn.functional is imported as F together with the other imports in
# section 4 above, so that HyperbolicViT.forward does not depend on an import
# that only happens further down the file.

# The function is described in Section 2.2 of the paper.


def pairwise_cross_entropy_loss(
    z_0: ManifoldTensor,
    z_1: ManifoldTensor,
    manifold: Union[PoincareBall, Euclidean],
    tau: float,
) -> torch.Tensor:
    """Pairwise cross-entropy loss (Section 2.2 of the paper).

    For every row i, the positive pair is (z_0[i], z_1[i]); every other
    embedding in z_0 and z_1 acts as a negative. Similarities are negated
    manifold distances scaled by the temperature ``tau``.
    """

    def neg_dist(a, b):
        # cdist expects a batch dimension; add it, then strip it again.
        return -manifold.cdist(a.unsqueeze(0), b.unsqueeze(0))[0]

    num_classes = z_0.shape[0]
    # Row i's positive logit sits at column i of the z_0-vs-z_1 block.
    target = torch.arange(num_classes, device=device)
    # Mask out self-similarities within z_0 with a large constant.
    eye_mask = torch.eye(num_classes, device=device) * 1e9
    logits00 = neg_dist(z_0, z_0) / tau - eye_mask
    logits01 = neg_dist(z_0, z_1) / tau
    logits = torch.cat([logits01, logits00], dim=1)
    # Subtract the row-wise maximum for numerical stability.
    logits -= logits.max(1, keepdim=True)[0].detach()
    return F.cross_entropy(logits, target)


# We use the same temperature as in the paper (Section 3.2).
criterion = lambda z_0, z_1: pairwise_cross_entropy_loss(z_0, z_1, manifold=manifold, tau=0.2)

##########################################
# 6. Define the recall@k evaluation metric
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Retrieval is performed by computing the distance between the query and all other embeddings.
# The top-k embeddings with the smallest distance are then used to compute the recall@k metric.


def eval_recall_k(
    k: int, model: nn.Module, dataloader: DataLoader, manifold: Union[PoincareBall, Euclidean]
):
    """Compute recall@k over the embeddings of every sample in ``dataloader``."""

    def get_embeddings():
        # Embed the whole dataset with the model in eval mode.
        model.eval()
        batches = []
        for data in dataloader:
            inputs = data[0].to(device)
            with torch.no_grad():
                batches.append(model(inputs))
        model.train()
        return manifold.cat(batches, dim=0)

    def get_recall_k(embs):
        n = embs.shape[0]
        dist_matrix = torch.zeros(n, n)
        # Fill the (negated) distance matrix one query row at a time to keep
        # memory usage bounded. NOTE(review): manifold.cpu() presumably moves
        # the manifold's parameters (e.g. the curvature) to the CPU to match
        # the CPU copies of the embeddings — confirm against hypll.
        for i in range(n):
            dist_matrix[i] = -manifold.cpu().cdist(
                embs[[i]].unsqueeze(0).cpu(), embs.unsqueeze(0).cpu()
            )[0]
        dist_matrix = torch.nan_to_num(dist_matrix, nan=-torch.inf)
        targets = np.array(dataloader.dataset.targets)
        # topk(1 + k) includes the query itself; drop that first column.
        top_k = targets[dist_matrix.topk(1 + k).indices[:, 1:].numpy()]
        return np.mean([1 if t in top_k[i] else 0 for i, t in enumerate(targets)])

    return get_recall_k(get_embeddings())


#########################################
# 7. Train the model on the training data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch.optim as optim

# We use the AdamW optimizer with the same hyperparameters as in the paper (Section 3.2).
optimizer = optim.AdamW(hvit.parameters(), lr=3e-5, weight_decay=0.01)

# Training for 5 epochs is enough to achieve good results.
num_epochs = 5
for epoch in range(num_epochs):
    # Tell the sampler which epoch this is, so the class shuffling changes.
    sampler.set_epoch(epoch)
    running_loss = 0.0
    num_samples = 0
    for d_i, data in enumerate(trainloader):
        # Get the inputs; data is a list of [inputs, labels].
        inputs = data[0].to(device)

        # Zero out the parameter gradients.
        optimizer.zero_grad()

        # Forward pass: embed the batch, then regroup the embeddings as
        # (classes, m_per_class, embedding_dim).
        bs = len(inputs)
        z = hvit(inputs)
        z = ManifoldTensor(
            data=z.tensor.reshape((bs // m_per_class, m_per_class, -1)),
            man_dim=-1,
            manifold=z.manifold,
        )
        # The loss is computed over every ordered pair of views per class.
        loss = 0
        for i in range(m_per_class):
            for j in range(m_per_class):
                if i == j:
                    continue
                loss += criterion(z[:, i], z[:, j])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(hvit.parameters(), 3)
        optimizer.step()

        # Print statistics every 10 batches.
        running_loss += loss.item() * bs
        num_samples += bs
        if d_i % 10 == 9:
            print(
                f"[Epoch {epoch + 1}, {d_i + 1:5d}/{len(trainloader)}] train loss: {running_loss/num_samples:.3f}"
            )
            running_loss = 0.0
            num_samples = 0

print("Finished training!")

####################################
# 8. Test the model on the test data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Finally, let's see how well the model performs on the test data.

k = 5
recall_k = eval_recall_k(k, hvit, testloader, manifold)
print(f"Test recall@{k}: {recall_k:.3f}")

# Using the default hyperparameters in this tutorial,
# we are able to achieve a recall@5 of 0.946 for the Poincaré ball model,
# and a recall@5 of 0.916 for the Euclidean model.
446 | -------------------------------------------------------------------------------- /tutorials/poincare_embeddings_tutorial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Training Poincare embeddings for the mammals subset from WordNet 4 | ================================================================ 5 | 6 | This implementation is based on the "Poincare Embeddings for Learning Hierarchical Representations" 7 | paper by Maximilian Nickel and Douwe Kiela, which can be found at: 8 | 9 | - https://arxiv.org/pdf/1705.08039.pdf 10 | 11 | Their implementation can be found at: 12 | 13 | - https://github.com/facebookresearch/poincare-embeddings 14 | 15 | We will perform the following steps in order: 16 | 17 | 1. Load the mammals subset from the WordNet hierarchy using NetworkX 18 | 2. Create a dataset containing the graph from which we can sample 19 | 3. Initialize the Poincare ball on which the embeddings will be trained 20 | 4. Define the Poincare embedding model 21 | 5. Define the Poincare embedding loss function 22 | 6. Perform a few "burn-in" training epochs with reduced learning rate 23 | 24 | """ 25 | 26 | ###################################################################### 27 | # 1. Load the mammals subset from the WordNet hierarchy using NetworkX 28 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 29 | 30 | import json 31 | import os 32 | 33 | import networkx as nx 34 | 35 | # If you have stored the wordnet_mammals.json file differently that the definition here, 36 | # adjust the mammals_json_file string to correctly point to this json file. The file itself can 37 | # be found in the repository under tutorials/data/wordnet_mammals.json. 
root = os.path.dirname(os.path.abspath(__file__))
mammals_json_file = os.path.join(root, "data", "wordnet_mammals.json")
with open(mammals_json_file, "r") as json_file:
    graph_dict = json.load(json_file)

mammals_graph = nx.node_link_graph(graph_dict)


###################################################################
# 2. Create a dataset containing the graph from which we can sample
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import random

import torch
from torch.utils.data import DataLoader, Dataset


class MammalsEmbeddingDataset(Dataset):
    """One sample per edge of the mammals graph, plus sampled negative edges.

    Each item consists of 1 true edge followed by ``num_negatives`` corrupted
    edges, together with boolean targets marking which edge is the true one.

    Args:
        mammals: directed mammals subgraph of WordNet.
        num_negatives: number of negative edges per true edge (default 10,
            as in the original tutorial).
    """

    def __init__(
        self,
        mammals: nx.DiGraph,
        num_negatives: int = 10,
    ):
        super().__init__()
        self.mammals = mammals
        self.edges_list = list(mammals.edges())
        self.num_negatives = num_negatives

    # This Dataset object has a sample for each edge in the graph.
    def __len__(self) -> int:
        return len(self.edges_list)

    def __getitem__(self, idx: int):
        # For each existing edge in the graph we build num_negatives fake or
        # negative edges from the idx-th existing edge. So, first we grab this
        # edge from the graph.
        rel = self.edges_list[idx]

        # Next, we take our source node rel[0] and see which nodes in the graph
        # are not a child of this node.
        negative_target_nodes = list(
            self.mammals.nodes() - nx.descendants(self.mammals, rel[0]) - {rel[0]}
        )

        # Then, we sample at most half of the negatives from these nodes...
        # NOTE: this type of sampling is a straightforward, but inefficient method. If your
        # dataset is larger, consider using more efficient sampling methods, as this will
        # otherwise form a bottleneck. See (closed) Issue 59 on GitHub for some more
        # information.
        negative_target_sample_size = min(self.num_negatives // 2, len(negative_target_nodes))
        negative_target_nodes_sample = random.sample(
            negative_target_nodes, negative_target_sample_size
        )

        # and add these to a tensor which will be used as input for our embedding model.
        edges = torch.tensor([rel] + [[rel[0], neg] for neg in negative_target_nodes_sample])

        # Next, we do the same with our target node rel[1], but now where we sample from
        # nodes which aren't a parent of it.
        negative_source_nodes = list(
            self.mammals.nodes() - nx.ancestors(self.mammals, rel[1]) - {rel[1]}
        )

        # We sample from these negative source nodes until we have num_negatives negative
        # edges in total...
        negative_source_sample_size = self.num_negatives - negative_target_sample_size
        negative_source_nodes_sample = random.sample(
            negative_source_nodes, negative_source_sample_size
        )

        # and add these to the tensor that we created above.
        edges = torch.cat(
            tensors=(edges, torch.tensor([[neg, rel[1]] for neg in negative_source_nodes_sample])),
            dim=0,
        )

        # Lastly, we create a tensor containing the labels of the edges, indicating whether
        # it's a True or a False edge.
        edge_label_targets = torch.cat(
            tensors=[torch.ones(1).bool(), torch.zeros(self.num_negatives).bool()]
        )

        return edges, edge_label_targets


# Now, we construct the dataset.
dataset = MammalsEmbeddingDataset(
    mammals=mammals_graph,
)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)


#########################################################################
# 3. Initialize the Poincare ball on which the embeddings will be trained
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

from hypll.manifolds.poincare_ball import Curvature, PoincareBall

# A unit-curvature Poincaré ball, matching the original paper's setup.
poincare_ball = PoincareBall(Curvature(1.0))


########################################
# 4. Define the Poincare embedding model
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import hypll.nn as hnn


class PoincareEmbedding(hnn.HEmbedding):
    """Hyperbolic embedding table that scores edges by geodesic distance."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        manifold: PoincareBall,
    ):
        super().__init__(num_embeddings, embedding_dim, manifold)

    # The model outputs the distances between the nodes involved in the input
    # edges, as these are what the loss function consumes.
    def forward(self, edges: torch.Tensor) -> torch.Tensor:
        embeddings = super().forward(edges)
        # ``embeddings`` indexes the two endpoints of every edge along dim 2;
        # return the manifold distance between those endpoint embeddings.
        sources = embeddings[:, :, 0, :]
        targets = embeddings[:, :, 1, :]
        return self.manifold.dist(x=sources, y=targets)


# We want to embed every node into a 2-dimensional Poincare ball.
model = PoincareEmbedding(
    num_embeddings=len(mammals_graph.nodes()),
    embedding_dim=2,
    manifold=poincare_ball,
)


################################################
# 5. Define the Poincare embedding loss function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


# This function is given in equation (5) of the Poincare Embeddings paper.
def poincare_embeddings_loss(dists: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Poincare embeddings loss, equation (5) of the paper.

    A softmax over negated distances: maximizes the probability of the true
    edge relative to the sampled negative edges.

    Args:
        dists: manifold distances for each candidate edge (true edge included).
        targets: boolean mask that is True exactly for the true edge(s).

    Returns:
        Scalar mean negative log-likelihood of the true edges.
    """
    logits = dists.neg().exp()
    numerator = torch.where(condition=targets, input=logits, other=0).sum(dim=-1)
    denominator = logits.sum(dim=-1)
    loss = (numerator / denominator).log().mean().neg()
    return loss


#######################################################################
# 6. Perform a few "burn-in" training epochs with reduced learning rate
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

from hypll.optim import RiemannianSGD

# The learning rate of 0.3 is divided by 10 during burn-in.
optimizer = RiemannianSGD(
    params=model.parameters(),
    lr=0.3 / 10,
)

# Perform training as we would usually.
for epoch in range(10):
    average_loss = 0.0
    for idx, (edges, edge_label_targets) in enumerate(dataloader):
        optimizer.zero_grad()

        dists = model(edges)
        loss = poincare_embeddings_loss(dists=dists, targets=edge_label_targets)
        loss.backward()
        optimizer.step()

        # Accumulate a plain float (.item()) rather than the loss tensor, so
        # we don't keep every iteration's autograd graph alive in memory.
        average_loss += loss.item()

    average_loss /= len(dataloader)
    print(f"Burn-in epoch {epoch} loss: {average_loss}")


#########################
# 7. Train the embeddings
# ^^^^^^^^^^^^^^^^^^^^^^^

# Now we use the actual learning rate 0.3.
optimizer = RiemannianSGD(
    params=model.parameters(),
    lr=0.3,
)

for epoch in range(300):
    average_loss = 0.0
    for idx, (edges, edge_label_targets) in enumerate(dataloader):
        optimizer.zero_grad()

        dists = model(edges)
        loss = poincare_embeddings_loss(dists=dists, targets=edge_label_targets)
        loss.backward()
        optimizer.step()

        # Accumulate a detached float so each iteration's loss tensor (and
        # its autograd graph) can be freed immediately.
        average_loss += loss.item()

    average_loss /= len(dataloader)
    print(f"Epoch {epoch} loss: {average_loss}")

# You have now trained your own Poincare Embeddings!