├── .github └── workflows │ └── run_tests.yaml ├── .gitignore ├── .python-version ├── .readthedocs.yaml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── ROADMAP.md ├── docs ├── .gitignore ├── Makefile ├── make.bat └── source │ ├── comparison.rst │ ├── conf.py │ ├── data.rst │ ├── examples │ ├── autoencoders.rst │ ├── index.rst │ └── summary_statistics.rst │ ├── index.rst │ ├── nn.rst │ ├── nutshell.rst │ ├── usage.rst │ └── utils.rst ├── logos ├── giotto.jpg └── gudhi.png ├── poetry.lock ├── pyproject.toml ├── tests ├── __init__.py ├── test_alpha_complex.py ├── test_cubical_complex.py ├── test_layers.py ├── test_multi_scale_kernel.py ├── test_ot.py ├── test_pytorch_topological.py └── test_vietoris_rips_complex.py ├── torch_topological.svg └── torch_topological ├── __init__.py ├── data ├── __init__.py ├── shapes.py └── utils.py ├── datasets ├── __init__.py ├── shapes.py └── spheres.py ├── examples ├── alpha_complex.py ├── autoencoders.py ├── benchmarking.py ├── classification.py ├── cubical_complex.py ├── distances.py ├── gan.py ├── image_smoothing.py ├── summary_statistics.py └── weighted_euler_characteristic_transform.py ├── nn ├── __init__.py ├── alpha_complex.py ├── cubical_complex.py ├── data.py ├── distances.py ├── graphs.py ├── layers.py ├── loss.py ├── multi_scale_kernel.py ├── sliced_wasserstein_distance.py ├── sliced_wasserstein_kernel.py ├── vietoris_rips_complex.py └── weighted_euler_characteristic_transform.py └── utils ├── __init__.py ├── filters.py ├── general.py └── summary_statistics.py /.github/workflows/run_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | tests: 9 | runs-on: ubuntu-latest 10 | 11 | strategy: 12 | matrix: 13 | python-version: ["3.10"] 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Install `poetry` 18 | run: | 19 | pipx install poetry 20 | poetry --version 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | cache: 'poetry' 26 | - name: Install package 27 | run: | 28 | poetry install 29 | - name: Run linter 30 | run: | 31 | poetry run flake8 torch_topological 32 | - name: Run tests 33 | run: | 34 | poetry run pytest 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | 3 | *.pyc 4 | *.swp 5 | 6 | /data/ 7 | /dist/ 8 | __pycache__/ 9 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.9.9 2 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | # Required to get access to more recent Python versions. 5 | image: testing 6 | 7 | sphinx: 8 | configuration: docs/source/conf.py 9 | 10 | python: 11 | version: 3.9 12 | install: 13 | - method: pip 14 | path: . 
15 | extra_requirements: 16 | - docs 17 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | `pytorch-topological` follows the [semantic versioning](https://semver.org). 4 | This changelog contains all notable changes in the project. 5 | 6 | # v0.1.7 7 | 8 | ## Fixed 9 | 10 | - Fixed bug in `make_tensor` that caused the loss of device information. 11 | Device handling should now be transparent for clients. 12 | 13 | # v0.1.6 14 | 15 | ## Fixed 16 | 17 | - Fixed various documentation typos 18 | - Fixed bug in `make_tensor` creation for single batches 19 | 20 | # v0.1.5 21 | 22 | ## Added 23 | 24 | - A bunch of new test cases 25 | - Alpha complex class (`AlphaComplex`) 26 | - Dimension selector class (`SelectByDimension`) 27 | - Discussing additional packages in documentation 28 | - Linting for pull requests 29 | 30 | ## Fixed 31 | 32 | - Improved contribution guidelines 33 | - Improved documentation of summary statistics loss 34 | - Improved overall maintainability 35 | - Improved test cases 36 | - Simplified multi-scale kernel usage (distance calculations with different exponents) 37 | - Test case for cubical complexes 38 | - Usage of seed parameter for shape generation (following `numpy guidelines`) 39 | 40 | # v0.1.4 41 | 42 | ## Added 43 | 44 | - Batch handler for point cloud complexes 45 | - Sliced Wasserstein distance kernel 46 | - Support for pre-computed distances in Vietoris--Rips complex 47 | 48 | ## Fixed 49 | 50 | - Device compatibility issue (tensor being created on the wrong device) 51 | - Use of `dim` flag for alpha complexes 52 | - Various documentation issues and coding style problems 53 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of conduct 2 | 3 | Developing is a fundamentally collaborative endeavour. Whenever human 4 | beings interact, there is a potential for conflict. At the same time, 5 | there is an even greater potential to build something that is 'bigger 6 | than the sum of its parts.' In this spirit, all contributors shall be 7 | aware of the following guidelines: 8 | 9 | 1. Be tolerant of opposing views. 10 | 2. Be mindful of the way you utter your critiques; ensure that what you 11 | write is objective and does not attack a person or contain any 12 | disparaging remarks. 13 | 3. Be forgiving when interpreting the words and actions of others; 14 | always assume good intentions. 15 | 16 | We want contributors to participate in a respectful, collaborative, and 17 | overall productive space here. Any harassment will **not** be tolerated. 18 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute to `pytorch-topological` 2 | 3 | As the saying goes: 'Many hands make light work.' Thanks for being 4 | willing to contribute to `pytorch-topological`! Here are some resources 5 | to get you started: 6 | 7 | - Check out the [road map](ROADMAP.md) of the project for a high-level 8 | overview of current directions. 9 | - Check out [open issues](/issues) in case you are looking for things 10 | to tackle. 
11 | 12 | **When contributing code, please be aware that your contribution will 13 | fall under the terms of the [license](https://github.com/aidos-lab/pytorch-topological/blob/main/LICENSE.md) 14 | of `pytorch-topological`.** 15 | 16 | ## Pull requests 17 | 18 | If you propose some changes, a pull request is the easiest way to 19 | integrate them. Please be mindful of the coding conventions (see below) 20 | and [write good commit messages](https://cbea.ms/git-commit/). 21 | 22 | ## Coding conventions 23 | 24 | Above all, consider that this is *open source software*. It is meant 25 | to be used and extended by many people. Be mindful of them by making 26 | your code look nice and appealing to them. We cannot build upon some 27 | module no one understands. 28 | 29 | As a way to obtain some consistency in all contributions, your code 30 | should ideally conform to [PEP 8](https://www.python.org/dev/peps/pep-0008/). 31 | You can check this using the great [`flake8`](https://flake8.pycqa.org/) tool: 32 | 33 | ```console 34 | $ flake8 script.py 35 | ``` 36 | 37 | When creating a pull request, your code will be automatically checked 38 | for coding convention conformity. 39 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (C) 2021 Bastian Rieck 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | 1. Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS 19 | IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED 20 | TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 24 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 25 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 26 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 27 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | `pytorch-topological` icon 2 | 3 | # `pytorch-topological`: A topological machine learning framework for `pytorch` 4 | 5 | [![Documentation](https://readthedocs.org/projects/pytorch-topological/badge/?version=latest)](https://pytorch-topological.readthedocs.io/en/latest/?badge=latest) [![Maintainability](https://api.codeclimate.com/v1/badges/397f53d1968f01b86e74/maintainability)](https://codeclimate.com/github/aidos-lab/pytorch-topological/maintainability) ![GitHub contributors](https://img.shields.io/github/contributors/aidos-lab/pytorch-topological) ![PyPI - License](https://img.shields.io/pypi/l/torch_topological) ![PyPI](https://img.shields.io/pypi/v/torch_topological) [![Tests](https://github.com/aidos-lab/pytorch-topological/actions/workflows/run_tests.yaml/badge.svg)](https://github.com/aidos-lab/pytorch-topological/actions/workflows/run_tests.yaml) 6 | 7 | `pytorch-topological` (or `torch_topological`) is a topological machine 8 | learning framework for [PyTorch](https://pytorch.org). It aims to 9 | collect *loss terms* and *neural network layers* in order to simplify 10 | building the next generation of topology-based machine learning tools. 11 | 12 | # Topological machine learning in a nutshell 13 | 14 | *Topological machine learning* refers to a new class of machine learning 15 | algorithms that are able to make use of topological features in data 16 | sets. In contrast to methods based on a purely geometrical point of 17 | view, topological features are capable of focusing on *connectivity 18 | aspects* of a data set. This provides an interesting fresh perspective 19 | that can be used to create powerful hybrid algorithms, capable of 20 | yielding more insights into data. 21 | 22 | This is an *emerging research field*, firmly rooted in computational 23 | topology and topological data analysis. If you want to learn more about 24 | how topology and geometry can work in tandem, here are a few resources 25 | to get you started: 26 | 27 | - Amézquita et al., [*The Shape of Things to Come: Topological Data Analysis and Biology, 28 | from Molecules to Organisms*](https://doi.org/10.1002/dvdy.175), Developmental Dynamics 29 | Volume 249, Issue 7, pp. 816--833, 2020. 30 | 31 | - Hensel et al., [*A Survey of Topological Machine Learning Methods*](https://www.frontiersin.org/articles/10.3389/frai.2021.681108/full), 32 | Frontiers in Artificial Intelligence, 2021. 33 | 34 | # Installation and requirements 35 | 36 | `torch_topological` requires Python 3.9. More recent versions might work 37 | but necessitate building some dependencies by yourself; Python 3.9 38 | currently offers the smoothest experience. 39 | It is recommended to use the excellent [`poetry`](https://python-poetry.org) framework 40 | to install `torch_topological`: 41 | 42 | ``` 43 | poetry add torch-topological 44 | ``` 45 | 46 | Alternatively, use `pip` to install the package: 47 | 48 | ``` 49 | pip install -U torch-topological 50 | ``` 51 | 52 | **A note on older versions.** Older versions of Python are not 53 | explicitly supported, and things may break in unexpected ways. 54 | If you want to use a different version, check `pyproject.toml` 55 | and adjust the Python requirement to your preference. This may 56 | or may not work, good luck! 57 | 58 | # Usage 59 | 60 | `torch_topological` is still a work in progress. 
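In the meantime, the following sketch---condensed from the package's own test suite---shows the general flavour of working with it; treat the exact calls as illustrative rather than as a frozen API:

```
from torch_topological.data import sample_from_unit_cube
from torch_topological.nn import VietorisRipsComplex
from torch_topological.nn import WassersteinDistance

# Create two random point clouds whose topological dissimilarity we
# want to quantify.
X = sample_from_unit_cube(100)
Y = sample_from_unit_cube(100)

# Calculate persistence diagrams up to dimension 1 and compare them
# using a Wasserstein distance, which is differentiable and can thus
# serve as a loss term.
vr = VietorisRipsComplex(dim=1)
pers_info_X, pers_info_Y = vr([X, Y])

distance = WassersteinDistance()(pers_info_X, pers_info_Y)
```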
You can [browse the documentation](https://pytorch-topological.readthedocs.io) 61 | or, if code reading is more your thing, dive directly into [some example 62 | code](./torch_topological/examples). 63 | 64 | Here is a list of *other* projects that are using `torch_topological`: 65 | 66 | - [SHAPR](https://github.com/marrlab/SHAPR_torch), a method for 67 | predicting the 3D shape of individual cells based on 2D 68 | microscopy images 69 | 70 | This list is incomplete---you can help expand it by using 71 | `torch_topological` in your own projects! :innocent: 72 | 73 | # Contributing 74 | 75 | Check out the [contribution guidelines](CONTRIBUTING.md) or the [road 76 | map](ROADMAP.md) of the project. 77 | 78 | # Acknowledgements 79 | 80 | Our software and research do not exist in a vacuum. `pytorch-topological` stands 81 | on the shoulders of proverbial giants. In particular, we want to thank the 82 | following projects for constituting the technical backbone of the 83 | project: 84 | 85 | | [`giotto-tda`](https://github.com/giotto-ai/giotto-tda) | [`gudhi`](https://github.com/GUDHI/gudhi-devel)
| 86 | |---------------------------------------------------------------|-------------------------------------------------------------| 87 | | `giotto` icon | `GUDHI` icon | 88 | 89 | Furthermore, `pytorch-topological` draws inspiration from several 90 | projects that provide a glimpse into the wonderful world of topological 91 | machine learning: 92 | 93 | - [`difftda`](https://github.com/MathieuCarriere/difftda) by [Mathieu Carrière](https://github.com/MathieuCarriere) 94 | 95 | - [`Ripser`](https://github.com/Ripser/ripser) by [Ulrich Bauer](https://github.com/ubauer) 96 | 97 | - [`Teaspoon`](https://lizliz.github.io/teaspoon/) by [Elizabeth Munch](https://elizabethmunch.com/) and her team 98 | 99 | - [`TopologyLayer`](https://github.com/bruel-gabrielsson/TopologyLayer) by [Rickard Brüel Gabrielsson](https://github.com/bruel-gabrielsson) 100 | 101 | - [`topological-autoencoders`](https://github.com/BorgwardtLab/topological-autoencoders) by [Michael Moor](https://github.com/mi92), [Max Horn](https://github.com/ExpectationMax), and [Bastian Rieck](https://github.com/Pseudomanifold) 102 | 103 | - [`torchph`](https://github.com/c-hofer/torchph) by [Christoph Hofer](https://github.com/c-hofer) and [Roland Kwitt](https://github.com/rkwitt) 104 | 105 | Finally, `pytorch-topological` makes heavy use of [`POT`](https://pythonot.github.io), the Python Optimal Transport Library. 106 | We are indebted to the many contributors of all these projects. 107 | -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | # What's the future of `pytorch-topological`? 2 | 3 | --- 4 | 5 | **Vision** `pytorch-topological` aims to be the first stop for building 6 | powerful applications using *topological machine learning* algorithms, 7 | i.e. algorithms that are capable of jointly leveraging geometrical and 8 | topological features in a data set 9 | 10 | --- 11 | 12 | To make this vision a reality, we first and foremost need to rely on 13 | exceptional documentation. It is not enough to write outstanding code; 14 | we have to demonstrate the power of topological algorithms to our users 15 | by writing well-documented code and contributing examples. 16 | 17 | Here are short-term and long-term goals, roughly categorised: 18 | 19 | ## API 20 | 21 | - [ ] Provide consistent way of handling batches or tensor inputs. Most 22 | of the modules rely on *sparse* inputs as lists. 23 | - [ ] Support different backends for calculating persistent homology. At 24 | present, we use [`GUDHI`](https://github.com/GUDHI/gudhi-devel/) for 25 | cubical complexes and [`giotto-ph`](https://github.com/giotto-ai/giotto-ph) 26 | for Vietoris--Rips complexes. It would be nice to be able to swap 27 | implementations easily. 28 | - [ ] Check out the use of sparse tensors; could be a potential way 29 | forward for representing persistence information. The drawback is that 30 | we cannot fill everything with zeroes; there has to be a way to 31 | indicate 'unset' information. 32 | 33 | ## Complexes 34 | 35 | - [x] Add (rudimentary) support for alpha complexes: a basic 36 | implementation is already present. 37 | 38 | ## Distances and kernels 39 | 40 | At present, the module supports Wasserstein distance calculations and 41 | bottleneck distance calculations between persistence diagrams. In 42 | addition to this, several 'pseudo-distances' based on summary statistics 43 | have been implemented. 
There are numerous kernels out there that could 44 | be included: 45 | 46 | - [x] The multi-scale kernel by [Reininghaus et al.](https://openaccess.thecvf.com/content_cvpr_2015/papers/Reininghaus_A_Stable_Multi-Scale_2015_CVPR_paper.pdf) 47 | - [x] The sliced Wasserstein distance and kernel by [Carrière et al.](https://arxiv.org/abs/1706.03358). 48 | 49 | This list is **incomplete**. 50 | 51 | ## Layers 52 | 53 | There are quite a few topology-based layers that have been proposed by 54 | members of the community. We should include all of them to make them 55 | available with a single, consistent interface. 56 | 57 | - [ ] Include [`PersLay`](https://github.com/MathieuCarriere/perslay). 58 | This requires a conversion from TensorFlow code. 59 | - [ ] Include [`PLLay`](https://github.com/jisuk1/pllay). 60 | This requires a conversion from TensorFlow code. 61 | - [ ] Include [`SLayer`](https://github.com/c-hofer/torchph). This is 62 | still an ongoing effort. 63 | - [x] Include [`TopologyLayer`](https://github.com/bruel-gabrielsson/TopologyLayer). 64 | 65 | ## Loss terms 66 | 67 | - [x] Include *signature loss* from [`topological-autoencoders`](https://github.com/BorgwardtLab/topological-autoencoders) 68 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/comparison.rst: -------------------------------------------------------------------------------- 1 | Comparison with other packages 2 | ============================== 3 | 4 | If you are already familiar with certain packages for calculating 5 | topological features, you might be interested in understanding in 6 | what aspects ``torch_topological`` differs from them. This is not 7 | meant to be a comprehensive comparison; we are aiming for a brief 8 | overview to simplify getting acquainted with the project. 9 | 10 | ``giotto-tda`` 11 | -------------- 12 | 13 | `giotto-tda `_ is a flagship 14 | package, developed by numerous members of L2F. Its primary goal is to 15 | provide an interface consistent with ``scikit-learn``, thus facilitating 16 | an integration of topological features into a data science workflow. 17 | 18 | By contrast, ``torch_topological`` is meant to simplify the development 19 | of hybrid algorithms that can be easily integrated into deep learning 20 | architectures. ``giotto-tda`` is developed by a large team with a much 21 | more professional development agenda, whereas ``torch_topological`` is 22 | geared more towards researchers that want to prototype the integration 23 | of topological features. 24 | 25 | ``Teaspoon`` 26 | ------------ 27 | 28 | `Teaspoon `_ is a library that 29 | targets topological signal processing applications, such as the analysis 30 | of time-varying systems or complex networks. ``Teaspoon`` integrates 31 | very nicely with ``scikit-learn`` and targets a different set of 32 | applications than ``torch_topological``. 33 | 34 | ``TopologyLayer`` 35 | ----------------- 36 | 37 | `TopologyLayer `_ is 38 | a library developed by Rickard Brüel Gabrielsson and others, 39 | accompanying their AISTATS publication `A Topology Layer for Machine Learning `_. 40 | 41 | ``torch_topological`` subsumes the functionality of ``TopologyLayer``, 42 | albeit under different names: 43 | 44 | - :py:class:`torch_topological.nn.VietorisRipsComplex` or 45 | :py:class:`torch_topological.nn.CubicalComplex` can be used to extract 46 | topological features from point clouds and images, respectively. 47 | 48 | - The ``BarcodePolyFeature`` and ``SumBarcodeLengths`` classes are 49 | incorporated as summary statistics loss functions instead. See the 50 | following example for more details: :doc:`examples/summary_statistics` 51 | 52 | - The ``PartialSumBarcodeLengths`` function is not implemented, mostly 53 | because a similar effect can be achieved by pruning the persistence 54 | diagram manually. This functionality might be added later on. 55 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | project = 'torch_topological' 20 | copyright = '2022, Bastian Rieck' 21 | author = 'Bastian Rieck' 22 | 23 | # -- General configuration --------------------------------------------------- 24 | 25 | # Add any Sphinx extension module names here, as strings. They can be 26 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 27 | # ones. 28 | extensions = [ 29 | 'sphinx.ext.autodoc', 30 | 'sphinx.ext.napoleon', 31 | 'sphinx.ext.linkcode', 32 | ] 33 | 34 | # Ensure that member functions are documented. These are sane defaults. 35 | autodoc_default_options = { 36 | 'members': True, 37 | 'member-order': 'bysource', 38 | 'special-members': '__init__', 39 | 'undoc-members': True, 40 | 'exclude-members': '__weakref__' 41 | } 42 | 43 | # Tries to assign some semantic meaning to arguments provided with 44 | # single backtics, such as `x`. This way, we can ignore `func` and 45 | # `class` targets etc. (They still work, though!) 46 | default_role = 'obj' 47 | 48 | # Add any paths that contain templates here, relative to this directory. 49 | templates_path = ['_templates'] 50 | 51 | # List of patterns, relative to source directory, that match files and 52 | # directories to ignore when looking for source files. 53 | # This pattern also affects html_static_path and html_extra_path. 54 | exclude_patterns = ['build', 'Thumbs.db', '.DS_Store'] 55 | 56 | # -- Options for HTML output ------------------------------------------------- 57 | 58 | # The theme to use for HTML and HTML Help pages. See the documentation for 59 | # a list of builtin themes. 60 | # 61 | html_theme = 'sphinx_rtd_theme' 62 | html_logo = '../../torch_topological.svg' 63 | html_theme_options = { 64 | 'logo_only': False, 65 | 'display_version': True, 66 | } 67 | 68 | # Add any paths that contain custom static files (such as style sheets) here, 69 | # relative to this directory. They are copied after the builtin static files, 70 | # so a file named "default.css" will overwrite the builtin "default.css". 71 | html_static_path = ['_static'] 72 | 73 | # Ensures that modules are sorted correctly. Since they all pertain to 74 | # the same package, the prefix itself can be ignored. 75 | modindex_common_prefix = ['torch_topological.'] 76 | 77 | # Specifies how to actually find the sources of the modules. Ensures 78 | # that people can jump to files in the repository directly. 79 | def linkcode_resolve(domain, info): 80 | # Let's frown on global imports and do everything locally as much as 81 | # we can. 82 | import sys 83 | import torch_topological 84 | 85 | if domain != 'py': 86 | return None 87 | if not info['module']: 88 | return None 89 | 90 | # Attempt to identify the source file belonging to an `info` object. 91 | # This code is adapted from the Sphinx configuration of `numpy`; see 92 | # https://github.com/numpy/numpy/blob/main/doc/source/conf.py. 
93 | def find_source_file(module): 94 | obj = sys.modules[module] 95 | 96 | for part in info['fullname'].split('.'): 97 | obj = getattr(obj, part) 98 | 99 | import inspect 100 | import os 101 | 102 | fn = inspect.getsourcefile(obj) 103 | fn = os.path.relpath( 104 | fn, 105 | start=os.path.dirname(torch_topological.__file__) 106 | ) 107 | 108 | source, lineno = inspect.getsourcelines(obj) 109 | return fn, lineno, lineno + len(source) - 1 110 | 111 | try: 112 | module = info['module'] 113 | source = find_source_file(module) 114 | except Exception: 115 | source = None 116 | 117 | root = f'https://github.com/aidos-lab/pytorch-topological/tree/main/{project}/' 118 | 119 | if source is not None: 120 | fn, start, end = source 121 | return root + f'{fn}#L{start}-L{end}' 122 | else: 123 | return None 124 | -------------------------------------------------------------------------------- /docs/source/data.rst: -------------------------------------------------------------------------------- 1 | torch_topological.data 2 | ====================== 3 | 4 | .. automodule:: torch_topological.data 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/examples/autoencoders.rst: -------------------------------------------------------------------------------- 1 | Autoencoders with geometrical--topological losses 2 | ================================================= 3 | 4 | In this example, we will create a simple autoencoder based on the 5 | *Topological Signature Loss* introduced by Moor et al. [Moor20a]_. 6 | 7 | A simple autoencoder 8 | -------------------- 9 | 10 | We first define a simple linear autoencoder. The representations 11 | obtained from this autoencoder are very similar to those obtained via 12 | PCA. 13 | 14 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py 15 | :language: python 16 | :pyobject: LinearAutoencoder 17 | 18 | Of particular interest in the code are the `encode` and `decode` 19 | functions. With ``encode``, we *embed* data in a latent space, whereas 20 | with ``decode``, we reconstruct it to its 'original' space. 21 | 22 | This reconstruction is of course never perfect. We therefore measure its 23 | quality using a reconstruction loss. Let's zoom into the specific 24 | function for this: 25 | 26 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py 27 | :language: python 28 | :pyobject: LinearAutoencoder.forward 29 | 30 | The important take-away here is that ``forward`` should return at least 31 | one *loss value*. We will make use of this later on! 32 | 33 | A topological wrapper for autoencoder models 34 | -------------------------------------------- 35 | 36 | Our previous model uses ``encode`` to provide us with a lower-dimensional 37 | representation, the so-called *latent representation*. We can use this 38 | representation in order to calculate a topology-based loss! To this end, 39 | let's write a new ``forward`` function that uses an existing model ``model`` 40 | for the latent space generation: 41 | 42 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py 43 | :language: python 44 | :pyobject: TopologicalAutoencoder.forward 45 | 46 | In the code above, the important things are: 47 | 48 | 1. The use of a Vietoris--Rips complex ``self.vr`` to obtain persistence 49 | information about the input space ``x`` and the latent space ``z``, 50 | respectively. We call this type of data ``pi_x`` and ``pi_z``, 51 | respectively. 52 | 53 | 2.
The call to a topology-based loss function ``self.loss()``, which takes 54 | two spaces ``x`` and ``y``, as well as their corresponding persistence 55 | information, to calculate the *signature loss* from [Moor20a]_. 56 | 57 | Putting this all together, we have the following 'wrapper class' that 58 | makes an existing model topology-aware: 59 | 60 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py 61 | :language: python 62 | :pyobject: TopologicalAutoencoder 63 | 64 | See [Moor20a]_ for more models to extend---being topology-aware can be 65 | crucial for many applications. 66 | 67 | Source code 68 | ----------- 69 | 70 | Here's the full source code of this example. 71 | 72 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py 73 | :language: python 74 | -------------------------------------------------------------------------------- /docs/source/examples/index.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | ``torch_topological`` comes with several documented examples that 5 | showcase different use cases. Please also see the 6 | `examples `_ 7 | directory on GitHub for more details and full code. 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Contents: 12 | 13 | autoencoders 14 | summary_statistics 15 | -------------------------------------------------------------------------------- /docs/source/examples/summary_statistics.rst: -------------------------------------------------------------------------------- 1 | Point cloud optimisation with summary statistics 2 | ================================================ 3 | 4 | One interesting use case of ``torch_topological`` involves changing the 5 | shape of a point cloud using topological summary statistics. Such 6 | summary statistics can *either* be used as simple loss functions, 7 | constituting a computationally cheap way of assessing the topological 8 | similarity of a given point cloud to a target point cloud, *or* 9 | serve to highlight certain topological properties of a single point 10 | cloud. 11 | 12 | In this example, we will consider *both* operations. 13 | 14 | Ingredients 15 | ----------- 16 | 17 | Our main ingredient is the :py:class:`torch_topological.nn.SummaryStatisticLoss` 18 | class. This class bundles different summary statistics on persistence 19 | diagrams and permits their calculation and comparison. 20 | 21 | This class can operate in two modes: 22 | 23 | 1. Calculating the loss for a single input data set. 24 | 2. Calculating the loss difference for two input data sets. 25 | 26 | Our example will showcase both of these modes! 27 | 28 | Optimising all the point clouds 29 | ------------------------------- 30 | 31 | Here's the bulk of the code required to optimise a point cloud. We will 32 | walk through the most important parts! 33 | 34 | .. literalinclude:: ../../../torch_topological/examples/summary_statistics.py 35 | :language: python 36 | :pyobject: main 37 | 38 | Next to creating some test data sets---check out :py:mod:`torch_topological.data` 39 | for more routines---the most important thing is to make sure that ``X``, 40 | our point cloud, is a trainable parameter. 41 | 42 | With that being out of the way, we can set up the summary statistic loss 43 | and start training. The main loop of the training might be familiar to 44 | those of you that already have some experience with ``pytorch``: it 45 | merely evaluates the loss and optimises it, following a general 46 | structure: 47 | 48 | .. 
code-block:: 49 | 50 | # Set up your favourite optimiser 51 | opt = optim.SGD(...) 52 | 53 | for i in range(100): 54 | 55 | # Do some calculations and obtain a loss term. In our specific 56 | # example, we have to get persistence information from data and 57 | # evaluate the loss based on that. 58 | loss = ... 59 | 60 | # This is what you will see in many such examples: we set all 61 | # gradients to zero and do a backwards pass. 62 | opt.zero_grad() 63 | loss.backward() 64 | opt.step() 65 | 66 | The rest of this example just involves some nice plotting. 67 | 68 | Source code 69 | ----------- 70 | 71 | Here's the full source code of this example. 72 | 73 | .. literalinclude:: ../../../torch_topological/examples/summary_statistics.py 74 | :language: python 75 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | ``torch_topological`` -- Topological Machine Learning with ``pytorch`` 2 | ======================================================================= 3 | 4 | ``pytorch-topological``, also known as ``torch_topological``, brings the 5 | power of topological methods to **your** machine learning project. 6 | ``torch_topological`` is specifically geared towards working well with 7 | other ``PyTorch`` projects, so if you are already familiar with this 8 | framework you should feel right at home. 9 | 10 | .. toctree:: 11 | :maxdepth: 2 12 | :caption: Getting Started 13 | 14 | usage 15 | nutshell 16 | comparison 17 | 18 | examples/index 19 | 20 | .. toctree:: 21 | :maxdepth: 2 22 | :caption: Modules 23 | 24 | data 25 | nn 26 | utils 27 | 28 | Indices and tables 29 | ================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /docs/source/nn.rst: -------------------------------------------------------------------------------- 1 | torch_topological.nn 2 | ==================== 3 | 4 | .. automodule:: torch_topological.nn 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/nutshell.rst: -------------------------------------------------------------------------------- 1 | Topological machine learning in a nutshell 2 | ========================================== 3 | 4 | If you are reading this, you are probably wondering what 5 | topological machine learning can do for you and your projects. 6 | The purpose of this page is to provide a pithy and coarse, but 7 | hopefully *useful*, introduction to some concepts. 8 | 9 | .. contents:: 10 | :local: 11 | :depth: 2 12 | 13 | High-level View 14 | --------------- 15 | 16 | If you are an expert in machine learning, you could summarise 17 | topological machine learning as a novel set of inductive biases 18 | for models, making it possible to leverage properties such as 19 | the connectivity of a data set. To some extent, such properties are 20 | already covered by existing methods, but topology can be very 21 | useful as an additional 'lens' through which to view data. In 22 | graph learning tasks, for instance, being able to capture and 23 | use topological features---such as *cycles*---is *crucial* in 24 | order to improve predictive performance.
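To make this more concrete, here is a compact sketch, condensed from
this package's test suite, of how topological information can flow
into a layer with learnable parameters (see the examples section for
fully worked use cases):

.. code-block:: python

    from torch.utils.data import DataLoader

    from torch_topological.datasets import SphereVsTorus
    from torch_topological.nn import VietorisRipsComplex
    from torch_topological.nn.data import make_tensor
    from torch_topological.nn.layers import StructureElementLayer

    loader = DataLoader(SphereVsTorus(n_point_clouds=64), batch_size=32)

    vr = VietorisRipsComplex(dim=1)
    layer = StructureElementLayer(10)

    for x, y in loader:
        # Persistent homology summarises the connectivity of each
        # point cloud in the batch; the layer then converts it into
        # a fixed-size feature vector for a downstream model.
        features = layer(make_tensor(vr(x)))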
25 | 26 | Additional Resources 27 | -------------------- 28 | 29 | Here are some additional resources that might be of interest: 30 | 31 | - Amézquita et al., "The Shape of Things to Come: Topological Data Analysis and Biology, 32 | from Molecules to Organisms", Developmental Dynamics 33 | Volume 249, Issue 7, pp. 816--833, 2020. `doi:10.1002/dvdy.175 <https://doi.org/10.1002/dvdy.175>`_ 34 | 35 | - Hensel et al., "A Survey of Topological Machine Learning Methods", 36 | Frontiers in Artificial Intelligence, 2021. `doi:10.3389/frai.2021.681108 <https://doi.org/10.3389/frai.2021.681108>`_ 37 | -------------------------------------------------------------------------------- /docs/source/usage.rst: -------------------------------------------------------------------------------- 1 | Installing and using ``torch_topological`` 2 | ========================================== 3 | 4 | Requirements 5 | ------------ 6 | 7 | ``torch_topological`` requires at least Python 3.9. Normally, version 8 | resolution should work automatically. The precise mechanism for this 9 | depends on your installation method (see below). 10 | 11 | Installation via ``pip`` 12 | ------------------------ 13 | 14 | We recommend installing ``torch_topological`` using ``pip``. This way, 15 | you will always get a release version with a known set of features. It 16 | is *recommended* to use a virtual environment manager such as `poetry <https://python-poetry.org>`_ 17 | for handling the dependencies of your project. 18 | 19 | .. code-block:: console 20 | 21 | $ pip install torch_topological 22 | 23 | Installation from source 24 | ------------------------ 25 | 26 | Installing the package from source requires a virtual environment 27 | manager capable of parsing ``pyproject.toml`` files. With `poetry <https://python-poetry.org>`_, 28 | for instance, the following steps should be sufficient: 29 | 30 | .. code-block:: console 31 | 32 | $ git clone git@github.com:aidos-lab/pytorch-topological.git 33 | $ cd pytorch-topological 34 | $ poetry install 35 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | torch_topological.utils 2 | ======================= 3 | 4 | .. automodule:: torch_topological.utils 5 | :members: 6 | -------------------------------------------------------------------------------- /logos/giotto.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidos-lab/pytorch-topological/9b22d9ebeda68cab88a9267cea53a76a3f9f0c99/logos/giotto.jpg -------------------------------------------------------------------------------- /logos/gudhi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidos-lab/pytorch-topological/9b22d9ebeda68cab88a9267cea53a76a3f9f0c99/logos/gudhi.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "torch_topological" 3 | version = "0.1.7" 4 | description = "A framework for topological machine learning based on `pytorch`."
5 | authors = [ 6 | {name = "Bastian Rieck", email = "bastian@rieck.me"}, 7 | ] 8 | license = "BSD-3-Clause" 9 | readme = "README.md" 10 | include = ["README.md"] 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.9" 14 | matplotlib = "^3.5.0" 15 | giotto-ph = "^0.2.0" 16 | torch = ">=1.12.0" 17 | gudhi = "^3.4.1" 18 | POT = "^0.9.0" 19 | 20 | [tool.poetry.dev-dependencies] 21 | pytest = ">=5.2" 22 | Sphinx = ">=4.3.1" 23 | sphinx-rtd-theme = "^1.0.0" 24 | tqdm = "^4.62.3" 25 | torchvision = ">=0.11.2" 26 | flake8 = "^4.0.1" 27 | 28 | [tool.black] 29 | line-length = 79 30 | 31 | [build-system] 32 | requires = ["poetry-core>=1.0.0"] 33 | build-backend = "poetry.core.masonry.api" 34 | 35 | [project.urls] 36 | homepage = "https://github.com/aidos-lab/pytorch-topological" 37 | documentation = "https://pytorch-topological.readthedocs.io/" 38 | issues = "https://github.com/aidos-lab/pytorch-topological/issues" 39 | changelog = "https://github.com/aidos-lab/pytorch-topological/blob/main/CHANGELOG.md" 40 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidos-lab/pytorch-topological/9b22d9ebeda68cab88a9267cea53a76a3f9f0c99/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_alpha_complex.py: -------------------------------------------------------------------------------- 1 | from torch_topological.datasets import SphereVsTorus 2 | 3 | from torch_topological.nn.data import make_tensor 4 | 5 | from torch_topological.nn import AlphaComplex 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | batch_size = 64 10 | 11 | 12 | class TestAlphaComplexBatchHandling: 13 | data_set = SphereVsTorus(n_point_clouds=3 * batch_size) 14 | loader = DataLoader( 15 | data_set, 16 | batch_size=batch_size, 17 | shuffle=True, 18 | drop_last=False, 19 | ) 20 | 21 | ac = AlphaComplex() 22 | 23 | def test_processing(self): 24 | for (x, y) in self.loader: 25 | pers_info = self.ac(x) 26 | 27 | assert pers_info is not None 28 | assert len(pers_info) == batch_size 29 | 30 | pers_info_dense = make_tensor(pers_info) 31 | 32 | assert pers_info_dense is not None 33 | -------------------------------------------------------------------------------- /tests/test_cubical_complex.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import MNIST 2 | 3 | from torchvision.transforms import Compose 4 | from torchvision.transforms import Normalize 5 | from torchvision.transforms import ToTensor 6 | 7 | from torch_topological.nn import CubicalComplex 8 | 9 | from torch.utils.data import Dataset 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data import RandomSampler 12 | 13 | from torch_topological.nn.data import batch_iter 14 | from torch_topological.nn.data import make_tensor 15 | from torch_topological.nn.data import PersistenceInformation 16 | 17 | import numpy as np 18 | 19 | import torch 20 | 21 | 22 | class RandomDataSet(Dataset): 23 | def __init__(self, n_samples, dim, side_length, n_channels): 24 | self.dim = dim 25 | self.side_length = side_length 26 | self.n_samples = n_samples 27 | self.n_channels = n_channels 28 | 29 | self.data = np.random.default_rng().normal( 30 | size=(n_samples, n_channels, *([side_length] * dim)) 31 | ) 32 | 33 | def __getitem__(self, index): 34 | return self.data[index] 35 | 36 | def 
__len__(self): 37 | return len(self.data) 38 | 39 | 40 | class TestCubicalComplex: 41 | batch_size = 32 42 | 43 | def test_single_image(self): 44 | x = np.random.default_rng().normal(size=(32, 32)) 45 | x = torch.as_tensor(x) 46 | cc = CubicalComplex() 47 | 48 | pers_info = cc(x) 49 | 50 | assert pers_info is not None 51 | assert len(pers_info) == 2 52 | assert pers_info[0].dimension == 0 53 | assert pers_info[1].dimension == 1 54 | 55 | def test_image_with_channels(self): 56 | x = np.random.default_rng().normal(size=(3, 32, 32)) 57 | x = torch.as_tensor(x) 58 | cc = CubicalComplex() 59 | 60 | pers_info = cc(x) 61 | 62 | assert pers_info is not None 63 | assert len(pers_info) == 3 64 | assert len(pers_info[0]) == 2 65 | assert pers_info[0][0].dimension == 0 66 | assert pers_info[1][1].dimension == 1 67 | 68 | def test_image_with_channels_and_batch(self): 69 | x = np.random.default_rng().normal(size=(self.batch_size, 3, 32, 32)) 70 | x = torch.as_tensor(x) 71 | cc = CubicalComplex() 72 | 73 | pers_info = cc(x) 74 | 75 | assert pers_info is not None 76 | assert len(pers_info) == self.batch_size 77 | assert len(pers_info[0][0]) == 2 78 | assert pers_info[0][0][0].dimension == 0 79 | assert pers_info[1][1][1].dimension == 1 80 | 81 | def test_2d(self): 82 | for n_channels in [1, 3]: 83 | for squeeze in [False, True]: 84 | data_set = RandomDataSet(128, 2, 8, n_channels) 85 | loader = DataLoader( 86 | data_set, 87 | self.batch_size, 88 | shuffle=True, 89 | drop_last=False 90 | ) 91 | 92 | if squeeze: 93 | data_set.data = data_set.data.squeeze() 94 | 95 | cc = CubicalComplex(dim=2) 96 | 97 | for batch in loader: 98 | pers_info = cc(batch) 99 | 100 | assert pers_info is not None 101 | 102 | def test_3d(self): 103 | data_set = RandomDataSet(128, 3, 8, 1) 104 | loader = DataLoader( 105 | data_set, 106 | self.batch_size, 107 | shuffle=True, 108 | drop_last=False 109 | ) 110 | 111 | cc = CubicalComplex(dim=3) 112 | 113 | for batch in loader: 114 | pers_info = cc(batch) 115 | 116 | assert pers_info is not None 117 | assert len(pers_info) == self.batch_size 118 | 119 | 120 | class TestCubicalComplexBatchHandling: 121 | batch_size = 64 122 | 123 | data_set = MNIST( 124 | './data/MNIST', 125 | download=True, 126 | train=False, 127 | transform=Compose( 128 | [ 129 | ToTensor(), 130 | Normalize([0.5], [0.5]) 131 | ] 132 | ), 133 | ) 134 | 135 | loader = DataLoader( 136 | data_set, 137 | batch_size=batch_size, 138 | shuffle=False, 139 | drop_last=True, 140 | sampler=RandomSampler( 141 | data_set, 142 | replacement=True, 143 | num_samples=100 144 | ) 145 | ) 146 | 147 | cc = CubicalComplex() 148 | 149 | def test_processing(self): 150 | for (x, y) in self.loader: 151 | pers_info = self.cc(x) 152 | 153 | assert pers_info is not None 154 | assert len(pers_info) == self.batch_size 155 | 156 | pers_info_dense = make_tensor(pers_info) 157 | 158 | assert pers_info_dense is not None 159 | 160 | def test_batch_iter(self): 161 | for (x, y) in self.loader: 162 | pers_info = self.cc(x) 163 | 164 | assert pers_info is not None 165 | assert len(pers_info) == self.batch_size 166 | 167 | assert sum(1 for x in batch_iter(pers_info)) == self.batch_size 168 | 169 | for x in batch_iter(pers_info): 170 | for y in x: 171 | assert isinstance(y, PersistenceInformation) 172 | 173 | for x in batch_iter(pers_info, dim=0): 174 | 175 | # Make sure that we have something to iterate over. 
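            # (Recall that `dim=0` restricts the iterator to
            # zero-dimensional topological features.)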
176 | assert sum(1 for y in x) != 0 177 | 178 | for y in x: 179 | assert isinstance(y, PersistenceInformation) 180 | assert y.dimension == 0 181 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | from torch_topological.datasets import SphereVsTorus 2 | 3 | from torch_topological.nn.data import make_tensor 4 | 5 | from torch_topological.nn import VietorisRipsComplex 6 | from torch_topological.nn.layers import StructureElementLayer 7 | 8 | from torch.utils.data import DataLoader 9 | 10 | batch_size = 32 11 | 12 | 13 | class TestStructureElementLayer: 14 | data_set = SphereVsTorus(n_point_clouds=2 * batch_size) 15 | loader = DataLoader( 16 | data_set, 17 | batch_size=batch_size, 18 | shuffle=True, 19 | drop_last=False, 20 | ) 21 | 22 | vr = VietorisRipsComplex(dim=1) 23 | 24 | layer = StructureElementLayer(10) 25 | 26 | def test_processing(self): 27 | for (x, y) in self.loader: 28 | pers_info = make_tensor(self.vr(x)) 29 | output = self.layer(pers_info) 30 | 31 | assert pers_info is not None 32 | assert output is not None 33 | -------------------------------------------------------------------------------- /tests/test_multi_scale_kernel.py: -------------------------------------------------------------------------------- 1 | from torch_topological.data import sample_from_unit_cube 2 | 3 | from torch_topological.nn import MultiScaleKernel 4 | from torch_topological.nn import VietorisRipsComplex 5 | 6 | 7 | class TestMultiScaleKernel: 8 | vr = VietorisRipsComplex(dim=1) 9 | kernel = MultiScaleKernel(1.0) 10 | X = sample_from_unit_cube(100) 11 | Y = sample_from_unit_cube(100) 12 | 13 | def test_pseudo_metric(self): 14 | pers_info_X, pers_info_Y = self.vr([self.X, self.Y]) 15 | k_XX = self.kernel(pers_info_X, pers_info_X) 16 | k_YY = self.kernel(pers_info_Y, pers_info_Y) 17 | k_XY = self.kernel(pers_info_X, pers_info_Y) 18 | 19 | assert k_XY > 0 20 | assert k_XX > 0 21 | assert k_YY > 0 22 | 23 | assert k_XX + k_YY - 2 * k_XY >= 0 24 | -------------------------------------------------------------------------------- /tests/test_ot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | from torch_topological.nn import PersistenceInformation 6 | from torch_topological.nn import WassersteinDistance 7 | 8 | 9 | def wrap(diagram): 10 | diagram = torch.as_tensor(diagram, dtype=torch.float) 11 | return [ 12 | PersistenceInformation([], diagram) 13 | ] 14 | 15 | 16 | class TestWassersteinDistance: 17 | 18 | def test_simple(self): 19 | X = [ 20 | (3, 5), 21 | (1, 2), 22 | (3, 4) 23 | ] 24 | 25 | Y = [ 26 | (0, 0), 27 | (3, 4), 28 | (5, 3), 29 | (7, 4) 30 | ] 31 | 32 | X = wrap(X) 33 | Y = wrap(Y) 34 | 35 | dist = WassersteinDistance()(X, Y) 36 | assert dist > 0.0 37 | 38 | def test_random(self): 39 | n_points = 10 40 | n_instances = 10 41 | 42 | for i in range(n_instances): 43 | X = np.random.default_rng().normal(size=(n_points, 2)) 44 | Y = np.random.default_rng().normal(size=(n_points, 2)) 45 | 46 | X = wrap(X) 47 | Y = wrap(Y) 48 | 49 | dist = WassersteinDistance()(X, Y) 50 | assert dist > 0.0 51 | 52 | def test_almost_zero(self): 53 | n_points = 100 54 | 55 | X = np.random.default_rng().uniform(-1e-16, 1e-11, size=(n_points, 2)) 56 | Y = np.random.default_rng().uniform(-1e-16, 1e-11, size=(n_points, 2)) 57 | 58 | X = wrap(X) 59 | Y = wrap(Y) 60 | 61 | dist = WassersteinDistance()(X, Y) 
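        # Even for diagrams whose points are numerically almost zero,
        # the distance between two random diagrams should stay
        # strictly positive.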
62 | assert dist > 0.0 63 | -------------------------------------------------------------------------------- /tests/test_pytorch_topological.py: -------------------------------------------------------------------------------- 1 | from torch_topological import __version__ 2 | 3 | 4 | def test_version(): 5 | assert __version__ == '0.1.0' 6 | -------------------------------------------------------------------------------- /tests/test_vietoris_rips_complex.py: -------------------------------------------------------------------------------- 1 | from torch_topological.datasets import SphereVsTorus 2 | 3 | from torch_topological.nn.data import batch_iter 4 | from torch_topological.nn.data import make_tensor 5 | from torch_topological.nn.data import PersistenceInformation 6 | 7 | from torch_topological.nn import VietorisRipsComplex 8 | from torch_topological.nn import WassersteinDistance 9 | 10 | from torch import cdist 11 | from torch import isinf 12 | from torch.utils.data import DataLoader 13 | 14 | import torch 15 | import pytest 16 | 17 | import numpy as np 18 | 19 | batch_size = 64 20 | 21 | 22 | class TestVietorisRipsComplex: 23 | data_set = SphereVsTorus(n_point_clouds=3 * batch_size) 24 | loader = DataLoader( 25 | data_set, 26 | batch_size=batch_size, 27 | shuffle=True, 28 | drop_last=False, 29 | ) 30 | 31 | vr = VietorisRipsComplex(dim=1, p=1) 32 | 33 | def test_simple(self): 34 | for (x, y) in self.loader: 35 | pers_info = self.vr(x) 36 | 37 | assert pers_info is not None 38 | assert len(pers_info) == batch_size 39 | 40 | def test_predefined_distances(self): 41 | for (x, y) in self.loader: 42 | 43 | distances = cdist(x, x, p=1) 44 | pers_info1 = self.vr(distances, treat_as_distances=True) 45 | pers_info2 = self.vr(x, treat_as_distances=False) 46 | 47 | assert pers_info1 is not None 48 | assert pers_info2 is not None 49 | assert len(pers_info1) == batch_size 50 | assert len(pers_info1) == len(pers_info2) 51 | 52 | # Check that we are getting the same persistence diagrams. 
53 | for pi1, pi2 in zip(pers_info1, pers_info2): 54 | dist = WassersteinDistance()(pi1, pi2) 55 | assert dist == pytest.approx(0.0) 56 | 57 | 58 | class TestVietorisRipsComplexThreshold: 59 | data_set = SphereVsTorus(n_point_clouds=3 * batch_size) 60 | loader = DataLoader( 61 | data_set, 62 | batch_size=batch_size, 63 | shuffle=True, 64 | drop_last=False, 65 | ) 66 | 67 | vr = VietorisRipsComplex( 68 | dim=1, 69 | p=1, 70 | threshold=0.1, 71 | keep_infinite_features=True 72 | ) 73 | 74 | def test_threshold(self): 75 | for (x, y) in self.loader: 76 | pers_info = self.vr(x) 77 | 78 | assert pers_info is not None 79 | assert len(pers_info) == batch_size 80 | 81 | assert(torch.any(isinf(pers_info[0][0].diagram.flatten().sum()))) 82 | 83 | 84 | class TestVietorisRipsComplexBatchHandling: 85 | data_set = SphereVsTorus(n_point_clouds=3 * batch_size) 86 | loader = DataLoader( 87 | data_set, 88 | batch_size=batch_size, 89 | shuffle=True, 90 | drop_last=False, 91 | ) 92 | 93 | vr = VietorisRipsComplex(dim=1) 94 | 95 | def test_processing(self): 96 | for (x, y) in self.loader: 97 | pers_info = self.vr(x) 98 | 99 | assert pers_info is not None 100 | assert len(pers_info) == batch_size 101 | 102 | pers_info_dense = make_tensor(pers_info) 103 | 104 | assert pers_info_dense is not None 105 | 106 | def test_ragged_processing(self): 107 | rng = np.random.default_rng() 108 | 109 | data = [ 110 | np.random.default_rng().uniform(size=(rng.integers(32, 64), 8)) 111 | for _ in range(batch_size) 112 | ] 113 | 114 | pers_info = self.vr(data) 115 | 116 | assert pers_info is not None 117 | assert len(pers_info) == batch_size 118 | 119 | def test_batch_iter(self): 120 | for (x, y) in self.loader: 121 | pers_info = self.vr(x) 122 | 123 | assert pers_info is not None 124 | assert len(pers_info) == batch_size 125 | 126 | # This is just to confirm that we can properly iterate over 127 | # this batch. Here, `batch_iter` is a little bit like `NoP`, 128 | # but in general, more complicated nested structures may be 129 | # present. 130 | assert sum(1 for x in batch_iter(pers_info)) == batch_size 131 | 132 | for x in batch_iter(pers_info): 133 | for y in x: 134 | assert isinstance(y, PersistenceInformation) 135 | 136 | for x in batch_iter(pers_info, dim=0): 137 | 138 | # Make sure that we have something to iterate over. 
assert sum(1 for y in x) != 0 140 | 141 | for y in x: 142 | assert isinstance(y, PersistenceInformation) 143 | assert y.dimension == 0 144 | -------------------------------------------------------------------------------- /torch_topological.svg: -------------------------------------------------------------------------------- (SVG source of the project logo; markup omitted) -------------------------------------------------------------------------------- /torch_topological/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' 2 | -------------------------------------------------------------------------------- /torch_topological/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for various data operations and data set creation strategies.""" 2 | 3 | from .shapes import sample_from_annulus 4 | from .shapes import sample_from_disk 5 | from .shapes import sample_from_sphere 6 | from .shapes import sample_from_torus 7 | from .shapes import sample_from_unit_cube 8 | 9 | __all__ = [ 10 | 'sample_from_annulus', 11 | 'sample_from_disk', 12 | 'sample_from_sphere', 13 | 'sample_from_torus', 14 | 'sample_from_unit_cube', 15 | ] 16 | -------------------------------------------------------------------------------- /torch_topological/data/shapes.py: -------------------------------------------------------------------------------- 1 | """Contains sampling routines for various simple geometric objects.""" 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .utils import embed 7 | 8 | 9 | def sample_from_disk(n=100, r=0.9, R=1.0, seed=None): 10 | """Sample points from a disk. 11 | 12 | Parameters 13 | ---------- 14 | n : int 15 | Number of points to sample. 16 | 17 | r : float 18 | Minimum radius, i.e. the radius of the inner circle of a perfect 19 | sampling. 20 | 21 | R : float 22 | Maximum radius, i.e. the radius of the outer circle of a perfect 23 | sampling. 24 | 25 | seed : int, instance of `np.random.Generator`, or `None` 26 | Seed for the random number generator, or an instance of such 27 | a generator. If set to `None`, the default random number 28 | generator will be used. 29 | 30 | Returns 31 | ------- 32 | torch.tensor of shape `(n, 2)` 33 | Tensor containing the sampled coordinates. 34 | """ 35 | assert r <= R, RuntimeError('r > R') 36 | 37 | rng = np.random.default_rng(seed) 38 | 39 | length = rng.uniform(r, R, size=n) 40 | angle = np.pi * rng.uniform(0, 2, size=n) 41 | 42 | x = np.sqrt(length) * np.cos(angle) 43 | y = np.sqrt(length) * np.sin(angle) 44 | 45 | X = np.vstack((x, y)).T 46 | return torch.as_tensor(X) 47 | 48 | 49 | def sample_from_unit_cube(n, d=3, seed=None): 50 | """Sample points uniformly from unit cube in `d` dimensions. 51 | 52 | Parameters 53 | ---------- 54 | n : int 55 | Number of points to sample 56 | 57 | d : int 58 | Number of dimensions. 59 | 60 | seed : int, instance of `np.random.Generator`, or `None` 61 | Seed for the random number generator, or an instance of such 62 | a generator. If set to `None`, the default random number 63 | generator will be used. 64 | 65 | Returns 66 | ------- 67 | torch.tensor of shape `(n, d)` 68 | Tensor containing the sampled coordinates. 69 |
69 | """ 70 | rng = np.random.default_rng(seed) 71 | X = rng.uniform(size=(n, d)) 72 | 73 | return torch.as_tensor(X) 74 | 75 | 76 | def sample_from_sphere(n=100, d=2, r=1, noise=None, ambient=None, seed=None): 77 | """Sample `n` data points from a `d`-sphere in `d + 1` dimensions. 78 | 79 | Parameters 80 | ----------- 81 | n : int 82 | Number of data points in shape. 83 | 84 | d : int 85 | Dimension of the sphere. 86 | 87 | r : float 88 | Radius of sphere. 89 | 90 | noise : float or None 91 | Optional noise factor. If set, data coordinates will be 92 | perturbed by a standard normal distribution, scaled by 93 | `noise`. 94 | 95 | ambient : int or None 96 | Embed the sphere into a space with ambient dimension equal to 97 | `ambient`. The sphere is randomly rotated into this 98 | high-dimensional space. 99 | 100 | seed : int, instance of `np.random.Generator`, or `None` 101 | Seed for the random number generator, or an instance of such 102 | a generator. If set to `None`, the default random number 103 | generator will be used. 104 | 105 | Returns 106 | ------- 107 | torch.tensor 108 | Tensor of sampled coordinates. If `ambient` is set, array will be 109 | of shape `(n, ambient)`. Else, array will be of shape `(n, d + 1)`. 110 | 111 | Notes 112 | ----- 113 | This function was originally authored by Nathaniel Saul as part of 114 | the `tadasets` package. [tadasets]_ 115 | 116 | References 117 | ---------- 118 | .. [tadasets] https://github.com/scikit-tda/tadasets 119 | 120 | """ 121 | rng = np.random.default_rng(seed) 122 | data = rng.standard_normal((n, d+1)) 123 | 124 | # Normalize points to the sphere 125 | data = r * data / np.sqrt(np.sum(data**2, 1)[:, None]) 126 | 127 | if noise: 128 | data += noise * rng.standard_normal(data.shape) 129 | 130 | if ambient is not None: 131 | assert ambient > d 132 | data = embed(data, ambient) 133 | 134 | return torch.as_tensor(data) 135 | 136 | 137 | def sample_from_torus(n, d=3, r=1.0, R=2.0, seed=None): 138 | """Sample points uniformly from torus and embed it in `d` dimensions. 139 | 140 | Parameters 141 | ---------- 142 | n : int 143 | Number of points to sample 144 | 145 | d : int 146 | Number of dimensions. 147 | 148 | r : float 149 | Radius of the 'tube' of the torus. 150 | 151 | R : float 152 | Radius of the torus, i.e. the distance from the centre of the 153 | 'tube' to the centre of the torus. 154 | 155 | seed : int, instance of `np.random.Generator`, or `None` 156 | Seed for the random number generator, or an instance of such 157 | a generator. If set to `None`, the default random number 158 | generator will be used. 159 | 160 | Returns 161 | ------- 162 | torch.tensor of shape `(n, d)` 163 | Tensor of sampled coordinates. 164 | """ 165 | if r > R: 166 | raise RuntimeError('Radius of the tube must be less than ' + 167 | 'or equal to radius of the torus') 168 | 169 | rng = np.random.default_rng(seed) 170 | angles = [] 171 | 172 | while len(angles) < n: 173 | x = rng.uniform(0, 2 * np.pi) 174 | y = rng.uniform(0, 1 / np.pi) 175 | 176 | f = (1.0 + (r/R) * np.cos(x)) / (2 * np.pi) 177 | 178 | if y < f: 179 | psi = rng.uniform(0, 2 * np.pi) 180 | angles.append((x, psi)) 181 | 182 | X = [] 183 | 184 | for theta, psi in angles: 185 | a = R + r * np.cos(theta) 186 | x = a * np.cos(psi) 187 | y = a * np.sin(psi) 188 | z = r * np.sin(theta) 189 | 190 | X.append((x, y, z)) 191 | 192 | X = np.asarray(X) 193 | return torch.as_tensor(X) 194 | 195 | 196 | def sample_from_annulus(n, r, R, seed=None): 197 | """Sample points from a 2D annulus. 
198 | 199 | This function samples `n` points from an annulus with inner radius `r` 200 | and outer radius `R`. 201 | 202 | Parameters 203 | ---------- 204 | n : int 205 | Number of points to sample 206 | 207 | r : float 208 | Inner radius of annulus 209 | 210 | R : float 211 | Outer radius of annulus 212 | 213 | seed : int, instance of `np.random.Generator`, or `None` 214 | Seed for the random number generator, or an instance of such 215 | a generator. If set to `None`, the default random number 216 | generator will be used. 217 | 218 | Returns 219 | ------- 220 | torch.tensor of shape `(n, 2)` 221 | Tensor containing sampled coordinates. 222 | """ 223 | if r >= R: 224 | raise RuntimeError( 225 | 'Inner radius must be strictly less than outer radius' 226 | ) 227 | 228 | rng = np.random.default_rng(seed) 229 | thetas = rng.uniform(0, 2 * np.pi, n) 230 | 231 | # Need to sample based on squared radii to account for density 232 | # differences. 233 | radii = np.sqrt(rng.uniform(r ** 2, R ** 2, n)) 234 | 235 | X = np.column_stack((radii * np.cos(thetas), radii * np.sin(thetas))) 236 | return torch.as_tensor(X) 237 | -------------------------------------------------------------------------------- /torch_topological/data/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for data set generation.""" 2 | 3 | import numpy as np 4 | 5 | 6 | def embed(data, ambient=50): 7 | """Embed `data` in `ambient` dimensions. 8 | 9 | Parameters 10 | ---------- 11 | data : array_like 12 | Input data set; needs to have shape `(n, d)`, i.e. samples are 13 | in the rows, dimensions are in the columns. 14 | 15 | ambient : int 16 | Dimension of embedding space. Must be greater than 17 | dimensionality of data. 18 | 19 | Returns 20 | ------- 21 | array_like 22 | Embedded array of shape `(n, D)`, with `D = ambient`. 23 | 24 | Notes 25 | ----- 26 | This function was originally authored by Nathaniel Saul as part of 27 | the `tadasets` package. [tadasets]_ 28 | 29 | References 30 | ---------- 31 | .. [tadasets] https://github.com/scikit-tda/tadasets 32 | """ 33 | n, d = data.shape 34 | assert ambient > d 35 | 36 | base = np.zeros((n, ambient)) 37 | base[:, :d] = data 38 | 39 | # Construct a rotation matrix of dimension `ambient`.
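    # The QR decomposition of a random matrix yields an orthogonal
    # factor `q`; applying it below rotates (up to reflection) the
    # zero-padded data into general position in the ambient space.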
40 | random_rotation = np.random.random((ambient, ambient)) 41 | q, r = np.linalg.qr(random_rotation) 42 | 43 | base = np.dot(base, q) 44 | return base 45 | -------------------------------------------------------------------------------- /torch_topological/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Data sets with special topological characteristics.""" 2 | 3 | from .shapes import SphereVsTorus 4 | 5 | from .spheres import Spheres 6 | 7 | __all__ = [ 8 | 'Spheres', 9 | 'SphereVsTorus', 10 | ] 11 | -------------------------------------------------------------------------------- /torch_topological/datasets/shapes.py: -------------------------------------------------------------------------------- 1 | """Data sets based on simple topological shapes.""" 2 | 3 | from torch_topological.data import sample_from_sphere 4 | from torch_topological.data import sample_from_torus 5 | 6 | from torch.utils.data import Dataset 7 | 8 | import torch 9 | 10 | 11 | class SphereVsTorus(Dataset): 12 | """Data set containing point cloud samples from spheres and tori.""" 13 | 14 | def __init__( 15 | self, 16 | n_point_clouds=500, 17 | n_samples=100, 18 | shuffle=True 19 | ): 20 | """Create new instance of the data set. 21 | 22 | Parameters 23 | ---------- 24 | n_point_clouds : int 25 | Number of point clouds to generate. Each point cloud will 26 | consist of `n_samples` samples. 27 | 28 | n_samples : int 29 | Number of samples to use for each of the `n_point_clouds` 30 | point clouds contained in the data set. 31 | 32 | shuffle : bool 33 | If set, shuffles point clouds. Else, point clouds will be 34 | stored in the order of their creation. 35 | """ 36 | self.n_point_clouds = n_point_clouds 37 | self.n_samples = n_samples 38 | 39 | n_spheres = n_point_clouds // 2 40 | n_tori = n_point_clouds - n_spheres 41 | 42 | spheres = torch.stack([ 43 | sample_from_sphere(self.n_samples) for i in range(n_spheres) 44 | ]) 45 | 46 | tori = torch.stack([ 47 | sample_from_torus(self.n_samples) for i in range(n_tori) 48 | ]) 49 | 50 | labels = torch.as_tensor( 51 | [0] * n_spheres + [1] * n_tori, 52 | dtype=torch.long 53 | ) 54 | 55 | self.data = torch.vstack((spheres, tori)) 56 | self.labels = labels 57 | 58 | if shuffle: 59 | perm = torch.randperm(self.n_point_clouds) 60 | 61 | self.data = self.data[perm] 62 | self.labels = self.labels[perm] 63 | 64 | def __getitem__(self, index): 65 | """Get point cloud at `index`. 66 | 67 | Accesses the point cloud stored at `index` and returns it as 68 | well as its corresponding label. 69 | 70 | Parameters 71 | ---------- 72 | index : int 73 | Index of samples to access. 74 | 75 | Returns 76 | ------- 77 | Tuple of torch.tensor, torch.tensor 78 | Point cloud at index `index` and its label. 79 | """ 80 | return self.data[index], self.labels[index] 81 | 82 | def __len__(self): 83 | """Get number of point clouds stored in data set. 84 | 85 | Returns 86 | ------- 87 | int 88 | Number of point clouds stored in this instance of the class. 
89 | """ 90 | return len(self.data) 91 | -------------------------------------------------------------------------------- /torch_topological/datasets/spheres.py: -------------------------------------------------------------------------------- 1 | """Create `SPHERES` data set.""" 2 | 3 | import numpy as np 4 | 5 | import torch 6 | 7 | from torch.utils.data import Dataset 8 | from torch.utils.data import random_split 9 | 10 | from torch_topological.data import sample_from_sphere 11 | 12 | 13 | def create_sphere_dataset(n_samples=500, n_spheres=11, d=100, r=5, seed=None): 14 | """Create data set of high-dimensional spheres. 15 | 16 | Create `SPHERES` data set described in Moor et al. [Moor20a]_. The 17 | data sets consists of `n` spheres, enclosed by a single sphere. It 18 | is a perfect example of simple manifolds, being arranged in simple 19 | pattern, that is nevertheless challenging to embed by algorithms. 20 | 21 | Parameters 22 | ---------- 23 | n_samples : int 24 | Number of points to sample per sphere. 25 | 26 | n_spheres : int 27 | Total number of spheres to create. The algorithm will always 28 | create the *last* sphere to enclose the previous ones. Hence, 29 | if `n_spheres = 3`, two spheres will be enclosed by a larger 30 | one. 31 | 32 | d : int 33 | Dimension of spheres to sample from. A `d`-sphere will be 34 | embedded in `d+1` dimensions. 35 | 36 | r : float 37 | Radius of smaller spheres. The radius of the larger enclosing 38 | sphere will be `5 * r`. 39 | 40 | seed : int, instance of `np.random.Generator`, or `None` 41 | Seed for the random number generator, or an instance of such 42 | a generator. If set to `None`, the default random number 43 | generator will be used. 44 | 45 | Returns 46 | ------- 47 | Tuple of `np.array`, `np.array` 48 | Array containing the coordinates of the spheres. The second 49 | array contains the respective labels, ranging from `0` to 50 | `n_spheres - 1`. This array can be used for visualisation 51 | purposes. 52 | 53 | Notes 54 | ----- 55 | The original version of this code was authored by Michael Moor. 56 | 57 | References 58 | ---------- 59 | .. [Moor20a] M. Moor et al., "Topological Autoencoders", 60 | *Proceedings of the 37th International Conference on Machine 61 | Learning*, PMLR 119, pp. 7045--7054, 2020. 62 | """ 63 | rng = np.random.default_rng(seed) 64 | 65 | variance = 10 / np.sqrt(d) 66 | shift_matrix = rng.normal(0, variance, [n_spheres, d+1]) 67 | 68 | spheres = [] 69 | n_datapoints = 0 70 | for i in np.arange(n_spheres - 1): 71 | sphere = sample_from_sphere(n=n_samples, d=d, r=r) 72 | spheres.append(sphere + shift_matrix[i, :]) 73 | n_datapoints += n_samples 74 | 75 | # Build additional large surrounding sphere: 76 | n_samples_big = 10 * n_samples 77 | big = sample_from_sphere(n=n_samples_big, d=d, r=r*5) 78 | spheres.append(big) 79 | n_datapoints += n_samples_big 80 | 81 | X = np.concatenate(spheres, axis=0) 82 | y = np.zeros(n_datapoints) 83 | 84 | label_index = 0 85 | 86 | for index, data in enumerate(spheres): 87 | n_sphere_samples = data.shape[0] 88 | y[label_index:label_index + n_sphere_samples] = index 89 | label_index += n_sphere_samples 90 | 91 | return X, y 92 | 93 | 94 | class Spheres(Dataset): 95 | """Data set containing multiple spheres, enclosed by a larger one. 96 | 97 | This is a `Dataset` instance of the `SPHERES` data set described by 98 | Moor et al. [Moor20a]_. The class provides a convenient interface 99 | for making use of that data set in machine learning tasks. 
100 | """ 101 | 102 | def __init__( 103 | self, 104 | train=True, 105 | test_fraction=0.1, 106 | n_samples=100, 107 | n_spheres=11, 108 | r=5, 109 | ): 110 | """Create new instance of `SPHERES` data set. 111 | 112 | This class wraps the `SPHERES` data set for subsequent use in 113 | machine learning tasks. See :func:`create_sphere_dataset` for 114 | more details on the supported parameters. 115 | 116 | Parameters 117 | ---------- 118 | train : bool 119 | If set, create and store training portion of data set. 120 | 121 | test_fraction : float 122 | Fraction of generated samples to be used as the test portion 123 | of the data set. 124 | """ 125 | X, y = create_sphere_dataset( 126 | n_samples=n_samples, 127 | n_spheres=n_spheres, 128 | r=r) 129 | 130 | X = torch.as_tensor(X, dtype=torch.float) 131 | 132 | test_size = int(test_fraction * len(X)) 133 | train_size = len(X) - test_size 134 | 135 | indices = torch.as_tensor(np.arange(len(X)), dtype=torch.long) 136 | 137 | train_indices, test_indices = random_split( 138 | indices, [train_size, test_size] 139 | ) 140 | 141 | X_train, X_test = X[train_indices], X[test_indices] 142 | y_train, y_test = y[train_indices], y[test_indices] 143 | 144 | self.data = X_train if train else X_test 145 | self.labels = y_train if train else y_test 146 | 147 | self.dimension = X.shape[1] 148 | 149 | def __getitem__(self, index): 150 | """Return data point and label of a specific point. 151 | 152 | Parameters 153 | ---------- 154 | index : int 155 | Index of point in data set. Invalid indices will raise an 156 | exception. 157 | 158 | Returns 159 | ------- 160 | tuple of `torch.tensor` and `torch.tensor` 161 | The point at the specific index and its label, indicating 162 | the sphere it belongs to. See :func:`create_sphere_dataset` 163 | for more details on the specifics of label assignment. 164 | """ 165 | return self.data[index], self.labels[index] 166 | 167 | def __len__(self): 168 | """Return number of points in data set. 169 | 170 | Returns 171 | ------- 172 | int 173 | Number of samples in data set. 174 | """ 175 | return len(self.data) 176 | -------------------------------------------------------------------------------- /torch_topological/examples/alpha_complex.py: -------------------------------------------------------------------------------- 1 | """Example demonstrating the computation of alpha complexes. 2 | 3 | This simple example demonstrates how to use alpha complexes to change 4 | the appearance of a point cloud, following the `TopologyLayer 5 | `_ package. 6 | 7 | This example is still a **work in progress**. 8 | """ 9 | 10 | from torch_topological.nn import AlphaComplex 11 | from torch_topological.nn import SummaryStatisticLoss 12 | 13 | from torch_topological.utils import SelectByDimension 14 | 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | 18 | import torch 19 | 20 | if __name__ == '__main__': 21 | np.random.seed(42) 22 | data = np.random.rand(100, 2) 23 | 24 | alpha_complex = AlphaComplex() 25 | 26 | loss_fn = SummaryStatisticLoss( 27 | summary_statistic='polynomial_function', 28 | p=2, 29 | q=0 30 | ) 31 | 32 | X = torch.nn.Parameter(torch.as_tensor(data), requires_grad=True) 33 | opt = torch.optim.Adam([X], lr=1e-2) 34 | 35 | for i in range(100): 36 | # We are only interested in working with persistence diagrams of 37 | # dimension 1. 38 | selector = SelectByDimension(1) 39 | 40 | # Let's think step by step; apparently, AIs like that! So let's 41 | # first get the persistence information of our complex. 
We pass 42 | # it through the selector to remove diagrams we do not need. 43 | pers_info = alpha_complex(X) 44 | pers_info = selector(pers_info) 45 | 46 | # Evaluate the loss; notice that we want to *maximise* it in 47 | # order to improve the holes in the data. 48 | loss = -loss_fn(pers_info) 49 | 50 | opt.zero_grad() 51 | loss.backward() 52 | opt.step() 53 | 54 | X = X.detach().numpy() 55 | 56 | plt.scatter(X[:, 0], X[:, 1]) 57 | plt.show() 58 | -------------------------------------------------------------------------------- /torch_topological/examples/autoencoders.py: -------------------------------------------------------------------------------- 1 | """Demo for topology-regularised autoencoders. 2 | 3 | This example demonstrates how to use `pytorch-topological` to create an 4 | additional differentiable loss term that makes autoencoders aware of 5 | topological features. See [Moor20a]_ for more information. 6 | """ 7 | 8 | import torch 9 | import torch.optim as optim 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | from tqdm import tqdm 14 | 15 | from torch.utils.data import DataLoader 16 | 17 | from torch_topological.datasets import Spheres 18 | 19 | from torch_topological.nn import SignatureLoss 20 | from torch_topological.nn import VietorisRipsComplex 21 | 22 | 23 | class LinearAutoencoder(torch.nn.Module): 24 | """Simple linear autoencoder class. 25 | 26 | This module performs simple embeddings based on an MSE loss. This is 27 | similar to ordinary principal component analysis. Notice that the 28 | class is only meant to provide a simple example that can be run 29 | easily even without the availability of a GPU. In practice, there 30 | are many more architectures with improved expressive power 31 | available. 32 | """ 33 | 34 | def __init__(self, input_dim, latent_dim=2): 35 | """Create new autoencoder with pre-defined latent dimension.""" 36 | super().__init__() 37 | 38 | self.input_dim = input_dim 39 | self.latent_dim = latent_dim 40 | 41 | self.encoder = torch.nn.Sequential( 42 | torch.nn.Linear(self.input_dim, self.latent_dim) 43 | ) 44 | 45 | self.decoder = torch.nn.Sequential( 46 | torch.nn.Linear(self.latent_dim, self.input_dim) 47 | ) 48 | 49 | self.loss_fn = torch.nn.MSELoss() 50 | 51 | def encode(self, x): 52 | """Embed data in latent space.""" 53 | return self.encoder(x) 54 | 55 | def decode(self, z): 56 | """Decode data from latent space.""" 57 | return self.decoder(z) 58 | 59 | def forward(self, x): 60 | """Embeds and reconstructs data, returning a loss.""" 61 | z = self.encode(x) 62 | x_hat = self.decode(z) 63 | 64 | # The loss can of course be changed. If this is your first time 65 | # working with autoencoders, a good exercise would be to 'grok' 66 | # the meaning of different losses. 67 | reconstruction_error = self.loss_fn(x, x_hat) 68 | return reconstruction_error 69 | 70 | 71 | class TopologicalAutoencoder(torch.nn.Module): 72 | """Wrapper for a topologically-regularised autoencoder. 73 | 74 | This class uses another autoencoder model and imbues it with an 75 | additional topology-based loss term. 
76 | """ 77 | def __init__(self, model, lam=1.0): 78 | super().__init__() 79 | 80 | self.lam = lam 81 | self.model = model 82 | self.loss = SignatureLoss(p=2) 83 | 84 | # TODO: Make dimensionality configurable 85 | self.vr = VietorisRipsComplex(dim=0) 86 | 87 | def forward(self, x): 88 | z = self.model.encode(x) 89 | 90 | pi_x = self.vr(x) 91 | pi_z = self.vr(z) 92 | 93 | geom_loss = self.model(x) 94 | topo_loss = self.loss([x, pi_x], [z, pi_z]) 95 | 96 | loss = geom_loss + self.lam * topo_loss 97 | return loss 98 | 99 | 100 | if __name__ == '__main__': 101 | # We first have to create a data set. This follows the original 102 | # publication by Moor et al. by introducing a simple 'manifold' 103 | # data set consisting of multiple spheres. 104 | n_spheres = 11 105 | data_set = Spheres(n_spheres=n_spheres) 106 | 107 | train_loader = DataLoader( 108 | data_set, 109 | batch_size=32, 110 | shuffle=True, 111 | drop_last=True 112 | ) 113 | 114 | # Let's set up the two models that we are training. Note that in 115 | # a real application, you would have a more complicated training 116 | # setup, potentially with early stopping etc. This training loop 117 | # is merely to be seen as a proof of concept. 118 | model = LinearAutoencoder(input_dim=data_set.dimension) 119 | topo_model = TopologicalAutoencoder(model, lam=10) 120 | 121 | optimizer = optim.Adam(topo_model.parameters(), lr=1e-3) 122 | 123 | n_epochs = 5 124 | 125 | progress = tqdm(range(n_epochs)) 126 | 127 | for i in progress: 128 | topo_model.train() 129 | 130 | for batch, (x, y) in enumerate(train_loader): 131 | loss = topo_model(x) 132 | 133 | optimizer.zero_grad() 134 | loss.backward() 135 | optimizer.step() 136 | 137 | progress.set_postfix(loss=loss.item()) 138 | 139 | # Evaluate the autoencoder on a new instance of the data set. 
140 | data_set = Spheres( 141 | train=False, 142 | n_samples=2000, 143 | n_spheres=n_spheres, 144 | ) 145 | 146 | test_loader = DataLoader( 147 | data_set, 148 | shuffle=False, 149 | batch_size=len(data_set) 150 | ) 151 | 152 | X, y = next(iter(test_loader)) 153 | Z = topo_model.model.encode(X).detach().numpy() 154 | 155 | plt.scatter( 156 | Z[:, 0], Z[:, 1], 157 | c=y, 158 | cmap='Set1', 159 | marker='o', 160 | alpha=0.9, 161 | ) 162 | plt.show() 163 | -------------------------------------------------------------------------------- /torch_topological/examples/benchmarking.py: -------------------------------------------------------------------------------- 1 | """Debug script for benchmarking some calculations.""" 2 | 3 | import time 4 | import torch 5 | import sys 6 | 7 | from torch_topological.nn import WassersteinDistance 8 | from torch_topological.nn import VietorisRipsComplex 9 | 10 | 11 | def run_test(X, vr, name, dist=False): 12 | W1 = WassersteinDistance() 13 | 14 | pre = time.perf_counter() 15 | 16 | X_pi = vr(X, treat_as_distances=dist) 17 | Y_pi = vr(X, treat_as_distances=dist) 18 | 19 | dists = torch.stack([W1(x_pi, y_pi) for x_pi, y_pi in zip(X_pi, Y_pi)]) 20 | dist = dists.mean() 21 | 22 | cur = time.perf_counter() 23 | print(f"{name}: {cur - pre:.4f}s") 24 | 25 | 26 | if __name__ == "__main__": 27 | X = torch.load(sys.argv[1]) 28 | 29 | print("Calculating everything ourselves") 30 | 31 | run_test(X, VietorisRipsComplex(dim=0), "raw") 32 | run_test(X, VietorisRipsComplex(dim=0, threshold=1.0), "thresholded") 33 | 34 | print("\nPre-defined distances") 35 | 36 | D = torch.cdist(X, X) 37 | 38 | run_test(D, VietorisRipsComplex(dim=0), "raw", dist=True) 39 | run_test( 40 | D, VietorisRipsComplex(dim=0, threshold=1.0), "thresholded", dist=True 41 | ) 42 | -------------------------------------------------------------------------------- /torch_topological/examples/classification.py: -------------------------------------------------------------------------------- 1 | """Example of classifying data using topological layers.""" 2 | 3 | from torch_topological.datasets import SphereVsTorus 4 | 5 | from torch_topological.nn.data import make_tensor 6 | 7 | from torch_topological.nn import VietorisRipsComplex 8 | from torch_topological.nn.layers import StructureElementLayer 9 | 10 | from torch.utils.data import DataLoader 11 | 12 | from tqdm import tqdm 13 | 14 | import torch 15 | 16 | 17 | class TopologicalModel(torch.nn.Module): 18 | def __init__(self, n_elements, latent_dim=64, output_dim=2): 19 | super().__init__() 20 | 21 | self.n_elements = n_elements 22 | self.latent_dim = latent_dim 23 | 24 | self.model = torch.nn.Sequential( 25 | StructureElementLayer(self.n_elements), 26 | torch.nn.Linear(self.n_elements, self.latent_dim), 27 | torch.nn.ReLU(), 28 | torch.nn.Linear(self.latent_dim, output_dim), 29 | ) 30 | 31 | self.vr = VietorisRipsComplex(dim=0) 32 | 33 | def forward(self, x): 34 | pers_info = self.vr(x) 35 | pers_info = make_tensor(pers_info) 36 | 37 | return self.model(pers_info) 38 | 39 | 40 | if __name__ == "__main__": 41 | batch_size = 32 42 | n_epochs = 50 43 | n_elements = 10 44 | 45 | data_set = SphereVsTorus(n_point_clouds=2 * batch_size) 46 | 47 | loader = DataLoader( 48 | data_set, 49 | batch_size=batch_size, 50 | shuffle=True, 51 | drop_last=False, 52 | ) 53 | 54 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 55 | 56 | model = TopologicalModel(n_elements).to(device) 57 | loss_fn = torch.nn.CrossEntropyLoss() 58 | opt = 
torch.optim.Adam(model.parameters(), lr=1e-4) 59 | 60 | progress = tqdm(range(n_epochs)) 61 | 62 | for epoch in progress: 63 | for batch, (x, y) in enumerate(loader): 64 | x = x.to(device) 65 | y = y.to(device) 66 | 67 | output = model(x) 68 | loss = loss_fn(output, y) 69 | 70 | opt.zero_grad() 71 | loss.backward() 72 | opt.step() 73 | 74 | pred = torch.argmax(output, dim=1) 75 | acc = (pred == y).sum() / len(y) 76 | 77 | progress.set_postfix(loss=f"{loss.item():.08f}", acc=f"{acc:.02f}") 78 | -------------------------------------------------------------------------------- /torch_topological/examples/cubical_complex.py: -------------------------------------------------------------------------------- 1 | """Demo for calculating cubical complexes. 2 | 3 | This example demonstrates how to perform topological operations on 4 | a structured array, such as a grey-scale image. 5 | """ 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from torch_topological.nn import CubicalComplex 11 | from torch_topological.nn import WassersteinDistance 12 | 13 | from tqdm import tqdm 14 | 15 | import torch 16 | 17 | 18 | def sample_circles(n_cells, n_samples=1000): 19 | """Sample two nested circles and bin them. 20 | 21 | Parameters 22 | ---------- 23 | n_cells : int 24 | Number of cells for the 2D histogram, i.e. the 'resolution' of 25 | the histogram. 26 | 27 | n_samples : int 28 | Number of samples to use for creating the nested circles 29 | coordinates. 30 | 31 | Returns 32 | ------- 33 | np.ndarray of shape ``(n_cells, n_cells)`` 34 | Structured array containing intensity values for the data set. 35 | """ 36 | from sklearn.datasets import make_circles 37 | X = make_circles(n_samples, shuffle=True, noise=0.01)[0] 38 | 39 | heatmap, *_ = np.histogram2d(X[:, 0], X[:, 1], bins=n_cells) 40 | heatmap -= heatmap.mean() 41 | heatmap /= heatmap.max() 42 | 43 | return heatmap 44 | 45 | 46 | if __name__ == '__main__': 47 | 48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 49 | print('Device', device) 50 | 51 | np.random.seed(23) 52 | 53 | Y = sample_circles(50) 54 | Y = torch.as_tensor(Y, dtype=torch.float) 55 | X = torch.as_tensor( 56 | Y + np.random.normal(scale=0.20, size=Y.shape), 57 | dtype=torch.float, 58 | device=device, 59 | ) 60 | Y = Y.to(device) 61 | X = torch.nn.Parameter(X, requires_grad=True).to(device) 62 | 63 | source = X.clone() 64 | 65 | optimizer = torch.optim.Adam([X], lr=1e-3) 66 | loss_fn = WassersteinDistance(q=2) 67 | 68 | cubical_complex = CubicalComplex() 69 | 70 | persistence_information_target = cubical_complex(Y) 71 | persistence_information_target = persistence_information_target[0] 72 | 73 | n_iter = 500 74 | progress = tqdm(range(n_iter)) 75 | 76 | for i in progress: 77 | persistence_information = cubical_complex(X) 78 | persistence_information = persistence_information[0] 79 | 80 | optimizer.zero_grad() 81 | 82 | loss = loss_fn( 83 | persistence_information, 84 | persistence_information_target 85 | ) 86 | 87 | loss.backward() 88 | optimizer.step() 89 | 90 | progress.set_postfix(loss=loss.item()) 91 | 92 | source = source.cpu().detach().numpy() 93 | target = Y.cpu().detach().numpy() 94 | result = X.cpu().detach().numpy() 95 | 96 | fig, ax = plt.subplots(ncols=3) 97 | 98 | ax[0].imshow(source) 99 | ax[0].set_title('Source') 100 | 101 | ax[1].imshow(target) 102 | ax[1].set_title('Target') 103 | 104 | ax[2].imshow(result) 105 | ax[2].set_title('Result') 106 | 107 | plt.show() 108 |
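# Optional sanity check (a sketch using only the objects defined above,
# and therefore commented out): re-evaluating the loss after the
# optimisation should report a value close to the final one shown by
# the progress bar.
#
#     with torch.no_grad():
#         pi_final = cubical_complex(X)[0]
#         print(loss_fn(pi_final, persistence_information_target).item())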
-------------------------------------------------------------------------------- /torch_topological/examples/distances.py: -------------------------------------------------------------------------------- 1 | """Demo for distance minimisations of a point cloud. 2 | 3 | Note 4 | ---- 5 | This demonstration is a work in progress. It is not fully documented and 6 | tested yet. 7 | """ 8 | 9 | from tqdm import tqdm 10 | 11 | from torch_topological.data import sample_from_disk 12 | 13 | from torch_topological.nn import VietorisRipsComplex 14 | from torch_topological.nn import WassersteinDistance 15 | 16 | import torch 17 | import torch.optim as optim 18 | 19 | 20 | import matplotlib.pyplot as plt 21 | 22 | 23 | if __name__ == '__main__': 24 | n = 100 25 | 26 | X = sample_from_disk(r=0.5, R=0.6, n=n) 27 | Y = sample_from_disk(r=0.9, R=1.0, n=n) 28 | 29 | X = torch.nn.Parameter(torch.as_tensor(X), requires_grad=True) 30 | 31 | vr = VietorisRipsComplex(dim=1) 32 | 33 | pi_target = vr(Y) 34 | loss_fn = WassersteinDistance(q=2) 35 | 36 | opt = optim.SGD([X], lr=0.1) 37 | 38 | n_iterations = 500 39 | progress = tqdm(range(n_iterations)) 40 | 41 | for i in progress: 42 | 43 | opt.zero_grad() 44 | 45 | pi_source = vr(X) 46 | loss = loss_fn(pi_source, pi_target) 47 | 48 | loss.backward() 49 | opt.step() 50 | 51 | progress.set_postfix(loss=loss.item()) 52 | 53 | X = X.detach().numpy() 54 | 55 | plt.scatter(X[:, 0], X[:, 1], label='Source') 56 | plt.scatter(Y[:, 0], Y[:, 1], label='Target') 57 | 58 | plt.legend() 59 | plt.show() 60 | -------------------------------------------------------------------------------- /torch_topological/examples/gan.py: -------------------------------------------------------------------------------- 1 | """Example of topology-based GANs. 2 | 3 | This is a work in progress, demonstrating how to add a simple 4 | adversarial loss into a GAN architecture. 5 | """ 6 | 7 | import torch 8 | import torchvision 9 | 10 | from torch_topological.nn import CubicalComplex 11 | from torch_topological.nn import WassersteinDistance 12 | 13 | from tqdm import tqdm 14 | 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | class Generator(torch.nn.Module): 20 | """Simple generator module.""" 21 | 22 | def __init__(self, latent_dim, shape): 23 | super().__init__() 24 | 25 | self.latent_dim = latent_dim 26 | self.output_dim = np.prod(shape) 27 | 28 | self.shape = shape 29 | 30 | def _make_layer(input_dim, output_dim): 31 | layers = [ 32 | torch.nn.Linear(input_dim, output_dim), 33 | torch.nn.BatchNorm1d(output_dim, 0.8), 34 | torch.nn.LeakyReLU(0.2), 35 | ] 36 | return layers 37 | 38 | self.model = torch.nn.Sequential( 39 | *_make_layer(self.latent_dim, 32), 40 | *_make_layer(32, 64), 41 | *_make_layer(64, 128), 42 | torch.nn.Linear(128, self.output_dim), 43 | torch.nn.Sigmoid() 44 | ) 45 | 46 | def forward(self, z): 47 | point_cloud = self.model(z) 48 | point_cloud = point_cloud.view(point_cloud.size(0), *self.shape) 49 | 50 | return point_cloud 51 | 52 | 53 | class Discriminator(torch.nn.Module): 54 | """Simple discriminator module.""" 55 | 56 | def __init__(self, shape): 57 | super().__init__() 58 | 59 | input_dim = np.prod(shape) 60 | 61 | # Inspired by the original GAN. THERE CAN ONLY BE ONE! 
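        # The discriminator is a plain MLP: it flattens each image and
        # maps it through two hidden layers to a single probability of
        # the input being real.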
62 | self.model = torch.nn.Sequential( 63 | torch.nn.Linear(input_dim, 256), 64 | torch.nn.LeakyReLU(0.2), 65 | torch.nn.Linear(256, 128), 66 | torch.nn.LeakyReLU(0.2), 67 | torch.nn.Linear(128, 1), 68 | torch.nn.Sigmoid(), 69 | ) 70 | 71 | def forward(self, x): 72 | # Flatten image 73 | x = x.view(x.size(0), -1) 74 | return self.model(x) 75 | 76 | 77 | class TopologicalAdversarialLoss(torch.nn.Module): 78 | """Loss term incorporating topological features.""" 79 | 80 | def __init__(self): 81 | super().__init__() 82 | 83 | self.cubical = CubicalComplex() 84 | self.loss = WassersteinDistance(q=2) 85 | 86 | def forward(self, real, synthetic): 87 | """Calculate loss between real and synthetic images.""" 88 | 89 | loss = 0.0 90 | 91 | # This could potentially be solved by stacking as well, since 92 | # the interface of `torch_topological` also permits batched 93 | # constructions of this kind. 94 | for x, y in zip(real, synthetic): 95 | # Remove single-channel information. For multiple channels, 96 | # this will have to be adapted. 97 | x = x.squeeze() 98 | y = y.squeeze() 99 | 100 | pi_real = self.cubical(x)[0] 101 | pi_synthetic = self.cubical(y)[0] 102 | 103 | loss += self.loss(pi_real, pi_synthetic) 104 | 105 | 106 | return loss 107 | 108 | 109 | def show(imgs): 110 | if not isinstance(imgs, list): 111 | imgs = [imgs] 112 | fig, axs = plt.subplots(ncols=len(imgs), squeeze=False) 113 | for i, img in enumerate(imgs): 114 | img = img.detach() 115 | img = torchvision.transforms.functional.to_pil_image(img) 116 | axs[0, i].imshow(np.asarray(img)) 117 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) 118 | 119 | 120 | if __name__ == "__main__": 121 | n_epochs = 5 122 | 123 | img_size = 16 124 | shape = (1, img_size, img_size) 125 | 126 | batch_size = 32 127 | latent_dim = 200 128 | 129 | data_loader = torch.utils.data.DataLoader( 130 | torchvision.datasets.MNIST( 131 | "./data/MNIST", 132 | train=False, 133 | download=True, 134 | transform=torchvision.transforms.Compose( 135 | [ 136 | torchvision.transforms.Resize(img_size), 137 | torchvision.transforms.ToTensor(), 138 | torchvision.transforms.Normalize([0.5], [0.5]), 139 | ] 140 | ), 141 | ), 142 | batch_size=batch_size, 143 | shuffle=True, 144 | ) 145 | 146 | generator = Generator(shape=shape, latent_dim=latent_dim) 147 | discriminator = Discriminator(shape=shape) 148 | adversarial_loss = torch.nn.BCELoss() 149 | topo_loss = TopologicalAdversarialLoss() 150 | 151 | opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4) 152 | opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4) 153 | 154 | for epoch in range(n_epochs): 155 | for batch, (imgs, _) in tqdm(enumerate(data_loader), desc="Batch"): 156 | z = torch.as_tensor( 157 | np.random.normal(0, 1, (imgs.shape[0], latent_dim)), 158 | dtype=torch.float, 159 | ) 160 | 161 | 162 | real_labels = torch.as_tensor([1.0] * len(imgs)).view(-1, 1) 163 | fake_labels = torch.as_tensor([0.0] * len(imgs)).view(-1, 1) 164 | 165 | opt_g.zero_grad() 166 | 167 | imgs_synthetic = generator(z) 168 | 169 | generator_loss = adversarial_loss( 170 | discriminator(imgs_synthetic), real_labels 171 | ) + 0.01 * topo_loss(imgs, imgs_synthetic) 172 | 173 | generator_loss.backward() 174 | 175 | opt_g.step() 176 | 177 | opt_d.zero_grad() 178 | 179 | real_loss = adversarial_loss(discriminator(imgs), real_labels) 180 | fake_loss = adversarial_loss( 181 | discriminator(imgs_synthetic.detach()), fake_labels 182 | ) 183 | 184 | discriminator_loss = 0.5 *
(real_loss + fake_loss) 185 | discriminator_loss.backward() 186 | 187 | opt_d.step() 188 | 189 | output = imgs_synthetic.detach() 190 | grid = torchvision.utils.make_grid(output) 191 | 192 | show(grid) 193 | 194 | plt.show() 195 | -------------------------------------------------------------------------------- /torch_topological/examples/image_smoothing.py: -------------------------------------------------------------------------------- 1 | """Demo for image smoothing based on topology. 2 | 3 | This is a work in progress, which at the moment merely showcases 4 | a simple topology-based loss applied to an image. There's *much* 5 | more to be done here. 6 | """ 7 | 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | from torch_topological.nn import CubicalComplex 12 | from torch_topological.nn import SummaryStatisticLoss 13 | 14 | from sklearn.datasets import make_circles 15 | 16 | import torch 17 | 18 | 19 | def _make_data(n_cells, n_samples=1000): 20 | X = make_circles(n_samples, shuffle=True, noise=0.05)[0] 21 | 22 | heatmap, *_ = np.histogram2d(X[:, 0], X[:, 1], bins=n_cells) 23 | heatmap -= heatmap.mean() 24 | heatmap /= heatmap.max() 25 | 26 | return heatmap 27 | 28 | 29 | class TopologicalSimplification(torch.nn.Module): 30 | def __init__(self, theta): 31 | super().__init__() 32 | 33 | self.theta = theta 34 | 35 | def forward(self, x): 36 | persistence_information = cubical(x) 37 | persistence_information = persistence_information[0] 38 | 39 | gens, pd = ( 40 | persistence_information.pairing, 41 | persistence_information.diagram, 42 | ) 43 | 44 | persistence = (pd[:, 1] - pd[:, 0]).abs() 45 | indices = persistence <= self.theta 46 | 47 | gens = gens[indices] 48 | 49 | indices = torch.vstack((gens[:, 0:2], gens[:, 2:])) 50 | 51 | indices = np.ravel_multi_index((indices[:, 0], indices[:, 1]), x.shape) 52 | 53 | x.ravel()[indices] = 0.0 54 | 55 | persistence_information = cubical(x) 56 | persistence_information = [persistence_information[0]] 57 | 58 | return x, persistence_information 59 | 60 | 61 | if __name__ == "__main__": 62 | 63 | np.random.seed(23) 64 | 65 | Y = _make_data(50) 66 | Y = torch.as_tensor(Y, dtype=torch.float) 67 | X = torch.as_tensor( 68 | Y + np.random.normal(scale=0.05, size=Y.shape), dtype=torch.float 69 | ) 70 | 71 | theta = torch.nn.Parameter( 72 | torch.as_tensor(1.0), 73 | requires_grad=True, 74 | ) 75 | 76 | topological_simplification = TopologicalSimplification(theta) 77 | 78 | optimizer = torch.optim.Adam([theta], lr=1e-2) 79 | loss_fn = SummaryStatisticLoss("total_persistence", p=1) 80 | 81 | cubical = CubicalComplex() 82 | 83 | persistence_information_target = cubical(Y) 84 | persistence_information_target = [persistence_information_target[0]] 85 | 86 | for i in range(500): 87 | X, persistence_information = topological_simplification(X) 88 | 89 | optimizer.zero_grad() 90 | 91 | loss = loss_fn(persistence_information, persistence_information_target) 92 | 93 | print(loss.item(), theta.item()) 94 | 95 | theta.backward() 96 | optimizer.step() 97 | 98 | X = X.detach().numpy() 99 | 100 | plt.imshow(X) 101 | plt.show() 102 | -------------------------------------------------------------------------------- /torch_topological/examples/summary_statistics.py: -------------------------------------------------------------------------------- 1 | """Demo for summary statistics minimisation of a point cloud. 2 | 3 | This example demonstrates how to use various topological summary 4 | statistics in order to change the shape of an input point cloud. 
5 | The script can either demonstrate how to adjust the shape of two 6 | point clouds, i.e. using a summary statistic as a loss function, 7 | or how to change the shape of a *single* point cloud. By default 8 | two point clouds will be used. 9 | """ 10 | 11 | import argparse 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | from torch_topological.data import sample_from_disk 16 | from torch_topological.data import sample_from_unit_cube 17 | 18 | from torch_topological.nn import SummaryStatisticLoss 19 | from torch_topological.nn import VietorisRipsComplex 20 | 21 | from tqdm import tqdm 22 | 23 | import torch 24 | import torch.optim as optim 25 | 26 | 27 | def create_data_set(args): 28 | """Create data set based on user-provided arguments.""" 29 | n = args.n_samples 30 | if args.single: 31 | X = sample_from_unit_cube(n=n, d=2) 32 | Y = X.clone() 33 | else: 34 | X = sample_from_disk(n=n, r=0.5, R=0.6) 35 | Y = sample_from_disk(n=n, r=0.9, R=1.0) 36 | 37 | # Make source point cloud adjustable by treating it as a parameter. 38 | # This enables topological loss functions to influence the shape of 39 | # `X`. 40 | X = torch.nn.Parameter(torch.as_tensor(X), requires_grad=True) 41 | return X, Y 42 | 43 | 44 | def main(args): 45 | """Run example.""" 46 | n_iterations = args.n_iterations 47 | statistic = args.statistic 48 | p = args.p 49 | q = args.q 50 | 51 | X, Y = create_data_set(args) 52 | vr = VietorisRipsComplex(dim=2) 53 | 54 | if not args.single: 55 | pi_target = vr(Y) 56 | 57 | loss_fn = SummaryStatisticLoss( 58 | summary_statistic=statistic, 59 | p=p, 60 | q=q 61 | ) 62 | 63 | opt = optim.SGD([X], lr=0.05) 64 | progress = tqdm(range(n_iterations)) 65 | 66 | for i in progress: 67 | pi_source = vr(X) 68 | 69 | if not args.single: 70 | loss = loss_fn(pi_source, pi_target) 71 | else: 72 | loss = loss_fn(pi_source) 73 | 74 | opt.zero_grad() 75 | loss.backward() 76 | opt.step() 77 | 78 | progress.set_postfix(loss=f'{loss.item():.08f}') 79 | 80 | X = X.detach().numpy() 81 | 82 | if args.single: 83 | plt.scatter(X[:, 0], X[:, 1], label='Result') 84 | plt.scatter(Y[:, 0], Y[:, 1], label='Initial') 85 | else: 86 | plt.scatter(X[:, 0], X[:, 1], label='Source') 87 | plt.scatter(Y[:, 0], Y[:, 1], label='Target') 88 | 89 | plt.legend() 90 | plt.show() 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser() 95 | 96 | parser.add_argument( 97 | '-i', '--n-iterations', 98 | default=250, 99 | type=int, 100 | help='Number of iterations' 101 | ) 102 | 103 | parser.add_argument( 104 | '-n', '--n-samples', 105 | default=100, 106 | type=int, 107 | help='Number of samples in point clouds' 108 | ) 109 | 110 | parser.add_argument( 111 | '-s', '--statistic', 112 | choices=[ 113 | 'persistent_entropy', 114 | 'polynomial_function', 115 | 'total_persistence', 116 | ], 117 | default='polynomial_function', 118 | help='Name of summary statistic to use for the loss' 119 | ) 120 | 121 | parser.add_argument( 122 | '-S', '--single', 123 | action='store_true', 124 | help='If set, uses only a single point cloud' 125 | ) 126 | 127 | parser.add_argument( 128 | '-p', 129 | type=float, 130 | default=2.0, 131 | help='Outer exponent for summary statistic loss calculation' 132 | ) 133 | 134 | parser.add_argument( 135 | '-q', 136 | type=float, 137 | default=2.0, 138 | help='Inner exponent for summary statistic loss calculation. Will ' 139 | 'only be used for certain summary statistics.' 
140 | ) 141 | 142 | args = parser.parse_args() 143 | main(args) 144 | -------------------------------------------------------------------------------- /torch_topological/examples/weighted_euler_characteristic_transform.py: -------------------------------------------------------------------------------- 1 | """Demo for using the Weighted Euler Characteristic Transform 2 | in an optimisation routine. 3 | 4 | This example demonstrates how the WECT can be used to optimise 5 | a neural network's predictions to match the topological signature 6 | of a target. 7 | """ 8 | from torch import nn 9 | import torch 10 | from torch_topological.nn import EulerDistance, WeightedEulerCurve 11 | import torch.optim as optim 12 | 13 | 14 | class NN(nn.Module): 15 | def __init__(self, inp_dim, hidden_dim, out_dim): 16 | super().__init__() 17 | self.fc1 = torch.nn.Linear(inp_dim, hidden_dim) 18 | self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim) 19 | self.fc3 = torch.nn.Linear(hidden_dim, out_dim) 20 | self.out_dim = out_dim 21 | 22 | def forward(self, x_): 23 | x = x_.clone() 24 | x = torch.nn.functional.relu(self.fc1(x)) 25 | x = torch.nn.functional.relu(self.fc2(x)) 26 | x = self.fc3(x) 27 | x = torch.sigmoid(x) 28 | out = int(self.out_dim ** (1 / 3)) 29 | return x.reshape([out, out, out]) 30 | 31 | 32 | if __name__ == "__main__": 33 | torch.manual_seed(4) 34 | z = 3 35 | arr = torch.ones([z, z, z], requires_grad=False) 36 | model = NN(z * z * z, 100, z * z * z) 37 | arr2 = torch.rand([z, z, z], requires_grad=False) 38 | arr2[arr2 > 0.5] = 1 39 | arr2[arr2 <= 0.5] = 0 40 | ec = WeightedEulerCurve(prod=True) 41 | dist = EulerDistance() 42 | optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) 43 | ans = 100 44 | while ans > 0.05: 45 | optimizer.zero_grad() 46 | ans = dist(ec(model(arr.flatten())), ec(arr2)) 47 | ans.backward() 48 | optimizer.step() 49 | with torch.no_grad(): 50 | print( 51 | "L2 distance:", 52 | dist(model(arr.flatten()), arr2), 53 | " Euler Distance:", 54 | ans, 55 | ) 56 | -------------------------------------------------------------------------------- /torch_topological/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """Layers and loss terms for persistence-based optimisation.""" 2 | 3 | from .data import PersistenceInformation 4 | 5 | from .distances import WassersteinDistance 6 | from .sliced_wasserstein_distance import SlicedWassersteinDistance 7 | from .sliced_wasserstein_kernel import SlicedWassersteinKernel 8 | from .multi_scale_kernel import MultiScaleKernel 9 | 10 | from .loss import SignatureLoss 11 | from .loss import SummaryStatisticLoss 12 | 13 | from .alpha_complex import AlphaComplex 14 | from .cubical_complex import CubicalComplex 15 | from .vietoris_rips_complex import VietorisRipsComplex 16 | 17 | from .weighted_euler_characteristic_transform import WeightedEulerCurve 18 | from .weighted_euler_characteristic_transform import EulerDistance 19 | 20 | 21 | __all__ = [ 22 | "AlphaComplex", 23 | "CubicalComplex", 24 | "EulerDistance", 25 | "MultiScaleKernel", 26 | "PersistenceInformation", 27 | "SignatureLoss", 28 | "SlicedWassersteinDistance", 29 | "SlicedWassersteinKernel", 30 | "SummaryStatisticLoss", 31 | "VietorisRipsComplex", 32 | "WassersteinDistance", 33 | "WeightedEulerCurve" 34 | ] 35 | -------------------------------------------------------------------------------- /torch_topological/nn/alpha_complex.py:
-------------------------------------------------------------------------------- 1 | """Alpha complex calculation module(s).""" 2 | 3 | from torch import nn 4 | 5 | from torch_topological.nn import PersistenceInformation 6 | from torch_topological.nn.data import batch_handler 7 | 8 | import gudhi 9 | import itertools 10 | import torch 11 | 12 | 13 | class AlphaComplex(nn.Module): 14 | """Calculate persistence diagrams of an alpha complex. 15 | 16 | This module calculates persistence diagrams of an alpha complex, 17 | i.e. a subcomplex of the Delaunay triangulation, which is sparse 18 | and thus often substantially smaller than other complex. 19 | 20 | It was first described in [Edelsbrunner94]_ and is particularly 21 | useful when analysing low-dimensional data. 22 | 23 | Notes 24 | ----- 25 | At the moment, this alpha complex implementation, following other 26 | implementations, provides *distance-based filtrations* only. This 27 | means that the resulting persistence diagrams do *not* correspond 28 | to the circumradius of a simplex. 29 | 30 | In addition, this implementation is **work in progress**. Some of 31 | the core features, such as handling of infinite features, are not 32 | available at the moment. 33 | 34 | References 35 | ---------- 36 | .. [Edelsbrunner94] H. Edelsbrunner and E.P. Mücke, 37 | "Three-dimensional alpha shapes", *ACM Transactions on Graphics*, 38 | Volume 13, Number 1, pp. 43--72, 1994. 39 | """ 40 | 41 | def __init__(self, p=2): 42 | """Initialise new alpha complex calculation module. 43 | 44 | Parameters 45 | ---------- 46 | p : float 47 | Exponent for the `p`-norm calculation of distances. 48 | 49 | Notes 50 | ----- 51 | This module currently only supports Minkowski norms. It does not 52 | yet support other metrics. 53 | """ 54 | super().__init__() 55 | 56 | self.p = p 57 | 58 | def forward(self, x): 59 | """Implement forward pass for persistence diagram calculation. 60 | 61 | The forward pass entails calculating persistent homology on 62 | a point cloud and returning a set of persistence diagrams. 63 | 64 | Parameters 65 | ---------- 66 | x : array_like 67 | Input point cloud(s). `x` can either be a 2D array of shape 68 | `(n, d)`, which is treated as a single point cloud, or a 3D 69 | array/tensor of the form `(b, n, d)`, with `b` representing 70 | the batch size. Alternatively, you may also specify a list, 71 | possibly containing point clouds of non-uniform sizes. 72 | 73 | Returns 74 | ------- 75 | list of :class:`PersistenceInformation` 76 | List of :class:`PersistenceInformation`, containing both the 77 | persistence diagrams and the generators, i.e. the 78 | *pairings*, of a certain dimension of topological features. 79 | If `x` is a 3D array, returns a list of lists, in which the 80 | first dimension denotes the batch and the second dimension 81 | refers to the individual instances of 82 | :class:`PersistenceInformation` elements. 83 | 84 | Generators will be represented in the persistence pairing 85 | based on proper creator--destroyer pairs of simplices. In 86 | dimension `k`, for instance, every generator is stored as 87 | a `k`-simplex followed by a `k+1` simplex. 
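        For instance, a single 2D point cloud yields one instance of
        :class:`PersistenceInformation` per dimension:

        >>> alpha_complex = AlphaComplex()
        >>> pers_info = alpha_complex(torch.rand(100, 2))
        >>> [pi.dimension for pi in pers_info]
        [0, 1]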
88 | """ 89 | return batch_handler(x, self._forward) 90 | 91 | def _forward(self, x): 92 | alpha_complex = gudhi.AlphaComplex( 93 | x.cpu().detach(), 94 | precision='fast', 95 | ) 96 | 97 | st = alpha_complex.create_simplex_tree() 98 | st.persistence() 99 | persistence_pairs = st.persistence_pairs() 100 | 101 | max_dim = x.shape[-1] 102 | dist = torch.cdist(x.contiguous(), x.contiguous(), p=self.p) 103 | 104 | return [ 105 | self._extract_generators_and_diagrams( 106 | dist, 107 | persistence_pairs, 108 | dim 109 | ) 110 | for dim in range(0, max_dim) 111 | ] 112 | 113 | def _extract_generators_and_diagrams(self, dist, persistence_pairs, dim): 114 | pairs = [ 115 | torch.cat((torch.as_tensor(p[0]), torch.as_tensor(p[1])), 0) 116 | for p in persistence_pairs if len(p[0]) == (dim + 1) 117 | ] 118 | 119 | # TODO: Ignore infinite features for now. Will require different 120 | # handling in the future. 121 | pairs = [p for p in pairs if len(p) == 2 * dim + 3] 122 | 123 | if not pairs: 124 | return PersistenceInformation( 125 | pairing=[], 126 | diagram=[], 127 | dimension=dim 128 | ) 129 | 130 | # Create tensor of shape `(n, 2 * dim + 3)`, with `n` being the 131 | # number of finite persistence pairs. 132 | pairs = torch.stack(pairs) 133 | 134 | # We have to branch here because the creation of 135 | # zero-dimensional persistence diagrams is easy, 136 | # whereas higher-dimensional diagrams require an 137 | # involved lookup strategy. 138 | if dim == 0: 139 | creators = torch.zeros_like( 140 | torch.as_tensor(pairs)[:, 0], 141 | device=dist.device 142 | ) 143 | 144 | # Iterate over the flag complex in order to get (a) the distance 145 | # of the creator simplex, and (b) the distance of the destroyer. 146 | # We *cannot* look this information up in the filtration itself, 147 | # because we have to use gradient-imbued information such as the 148 | # set of pairwise distances. 149 | else: 150 | creators = torch.stack( 151 | [ 152 | self._get_filtration_weight(creator, dist) 153 | for creator in pairs[:, :dim+1] 154 | ] 155 | ) 156 | 157 | # For the destroyers, we can always rely on the same 158 | # construction, regardless of dimensionality. 159 | destroyers = torch.stack( 160 | [ 161 | self._get_filtration_weight(destroyer, dist) 162 | for destroyer in pairs[:, dim+1:] 163 | ] 164 | ) 165 | 166 | # Create the persistence diagram from creator and destroyer 167 | # information. This step is the same for all dimensions. 168 | persistence_diagram = torch.stack( 169 | (creators, destroyers), 1 170 | ) 171 | 172 | return PersistenceInformation( 173 | pairing=pairs, 174 | diagram=persistence_diagram, 175 | dimension=dim 176 | ) 177 | 178 | def _get_filtration_weight(self, simplex, dist): 179 | """Auxiliary function for querying simplex weights. 180 | 181 | This function returns the filtration weight of an arbitrary 182 | simplex under a distance-based filtration, i.e. the maximum 183 | weight of its cofaces. This function is crucial for getting 184 | persistence diagrams that are differentiable. 185 | 186 | Parameters 187 | ---------- 188 | simplex : torch.tensor 189 | Simplex to query; must be a sequence of numbers, i.e. of 190 | shape `(n, 1)` or of shape `(n, )`. 191 | 192 | dist : torch.tensor 193 | Matrix of pairwise distances between points. 194 | 195 | Returns 196 | ------- 197 | torch.tensor 198 | Scalar tensor containing the filtration weight. 
199 | """ 200 | weights = torch.stack([ 201 | dist[edge] for edge in itertools.combinations(simplex, 2) 202 | ]) 203 | 204 | # TODO: Might have to be adjusted depending on filtration 205 | # ordering? 206 | return torch.max(weights) 207 | -------------------------------------------------------------------------------- /torch_topological/nn/cubical_complex.py: -------------------------------------------------------------------------------- 1 | """Cubical complex calculation module.""" 2 | 3 | from torch import nn 4 | 5 | from torch_topological.nn import PersistenceInformation 6 | 7 | import gudhi 8 | import torch 9 | 10 | import numpy as np 11 | 12 | 13 | class CubicalComplex(nn.Module): 14 | """Calculate cubical complex persistence diagrams. 15 | 16 | This module calculates 'differentiable' persistence diagrams for 17 | structured data, such as images. This is achieved by calculating 18 | a *cubical complex*. 19 | 20 | Cubical complexes are the natural choice for calculating topological 21 | features of highly-structured inputs. See [Rieck20a]_ for an example 22 | of how to apply such topological features in practice. 23 | 24 | References 25 | ---------- 26 | .. [Rieck20a] B. Rieck et al., "Uncovering the Topology of 27 | Time-Varying fMRI Data Using Cubical Complex", *Advances in 28 | Neural Information Processing Systems 33*, pp. 6900--6912, 2020. 29 | """ 30 | 31 | def __init__(self, superlevel=False, dim=None): 32 | """Initialise new module. 33 | 34 | Parameters 35 | ---------- 36 | superlevel : bool 37 | Indicates whether to calculate topological features based on 38 | superlevel sets. By default, *sublevel set filtrations* are 39 | used. 40 | 41 | dim : int or `None` 42 | If set, describes dimension of input data. This is meant to 43 | be the dimension of an individual image **without** channel 44 | information, if any. The value of `dim` will change the way 45 | an input tensor is being handled: additional dimensions, if 46 | present, will be treated as batches or channels. If not set 47 | to an integer value, :func:`forward` will just *guess* what 48 | to do with an input (which should work in most cases). 49 | 50 | For example, when dealing with volume data, i.e. 3D tensors, 51 | set `dim=3` when instantiating the class. This will permit a 52 | seamless user experience with *both* batched and non-batched 53 | input data sets. 54 | """ 55 | super().__init__() 56 | 57 | # TODO: This is handled somewhat inelegantly below. Might be 58 | # smarter to update. 59 | self.superlevel = superlevel 60 | self.dim = dim 61 | 62 | def forward(self, x): 63 | """Implement forward pass for persistence diagram calculation. 64 | 65 | The forward pass entails calculating persistent homology on a 66 | cubical complex and returning a set of persistence diagrams. 67 | The way the input will be interpreted depends on the presence 68 | of the `dim` attribute of this class. If `dim` is set, the 69 | *last* `dim` dimensions of an input tensor will be considered to 70 | contain the image data. If `dim` is not set, image dimensions 71 | will be guessed as follows: 72 | 73 | 1. Tensor of dimension 2: a single image 74 | 2. Tensor of dimension 3: a single 2D image with channels 75 | 3. Tensor of dimension 4: a batch of 2D images with channels 76 | 77 | This is a conservative way of handling the data, ensuring that 78 | by default, 2D tensors with channel information and a potential 79 | batch information can be handled, since this is the default for 80 | many applications. 
120 | """ 121 | # Dimension was provided; this makes calculating the *effective* 122 | # dimension of the tensor much easier: take everything but the 123 | # last `self.dim` dimensions. 124 | if self.dim is not None: 125 | shape = x.shape[:-self.dim] 126 | dims = len(shape) 127 | 128 | # No dimension was provided; just use the shape provided by the 129 | # client. 130 | else: 131 | dims = len(x.shape) - 2 132 | 133 | # No additional dimensions present: a single image 134 | if dims == 0: 135 | return self._forward(x) 136 | 137 | # Handle image with channels, such as a tensor of the form `(C, H, W)` 138 | elif dims == 1: 139 | return [ 140 | self._forward(x_) for x_ in x 141 | ] 142 | 143 | # Handle image with channels and batch index, such as a tensor of 144 | # the form `(B, C, H, W)`. 145 | elif dims == 2: 146 | return [ 147 | [self._forward(x__) for x__ in x_] for x_ in x 148 | ] 149 | 150 | def _forward(self, x): 151 | """Handle a single-channel image. 152 | 153 | This internal function handles the calculation of topological 154 | features for a single-channel image, i.e. an `array_like`. 155 | 156 | Parameters 157 | ---------- 158 | x : array_like of shape `(d_1, d_2, ..., d_d)` 159 | Single-channel input image of arbitrary dimensions. Batch 160 | dimensions and channel dimensions have to be handled by 161 | the calling function explicitly. This function interprets 162 | its input as a high-dimensional image. 163 | 164 | Returns 165 | ------- 166 | list of :class:`PersistenceInformation` 167 | List of persistence information data structures, containing 168 | the persistence diagram and the persistence pairing of some 169 | dimension in the input data set.
170 | """ 171 | if self.superlevel: 172 | x = -x 173 | 174 | cubical_complex = gudhi.CubicalComplex( 175 | dimensions=x.shape, 176 | top_dimensional_cells=x.flatten() 177 | ) 178 | 179 | # We need the persistence pairs first, even though we are *not* 180 | # using them directly here. 181 | cubical_complex.persistence() 182 | cofaces = cubical_complex.cofaces_of_persistence_pairs() 183 | 184 | max_dim = len(x.shape) 185 | 186 | # TODO: Make this configurable; is it possible that users only 187 | # want to return a *part* of the data? 188 | persistence_information = [ 189 | self._extract_generators_and_diagrams( 190 | x, 191 | cofaces, 192 | dim 193 | ) for dim in range(0, max_dim) 194 | ] 195 | 196 | return persistence_information 197 | 198 | def _extract_generators_and_diagrams(self, x, cofaces, dim): 199 | pairs = torch.empty((0, 2), dtype=torch.long) 200 | 201 | try: 202 | regular_pairs = torch.as_tensor( 203 | cofaces[0][dim], dtype=torch.long 204 | ) 205 | pairs = torch.cat( 206 | (pairs, regular_pairs) 207 | ) 208 | except IndexError: 209 | pass 210 | 211 | try: 212 | infinite_pairs = torch.as_tensor( 213 | cofaces[1][dim], dtype=torch.long 214 | ) 215 | except IndexError: 216 | infinite_pairs = None 217 | 218 | if infinite_pairs is not None: 219 | # 'Pair off' all the indices 220 | max_index = torch.argmax(x) 221 | fake_destroyers = torch.empty_like(infinite_pairs).fill_(max_index) 222 | 223 | infinite_pairs = torch.stack( 224 | (infinite_pairs, fake_destroyers), 1 225 | ) 226 | 227 | pairs = torch.cat( 228 | (pairs, infinite_pairs) 229 | ) 230 | 231 | return self._create_tensors_from_pairs(x, pairs, dim) 232 | 233 | # Internal utility function to handle the 'heavy lifting:' 234 | # creates tensors from sets of persistence pairs. 235 | def _create_tensors_from_pairs(self, x, pairs, dim): 236 | 237 | xs = x.shape 238 | 239 | # Notice that `creators` and `destroyers` refer to pixel 240 | # coordinates in the image. 241 | creators = torch.as_tensor( 242 | np.column_stack( 243 | np.unravel_index(pairs[:, 0], xs) 244 | ), 245 | dtype=torch.long 246 | ) 247 | destroyers = torch.as_tensor( 248 | np.column_stack( 249 | np.unravel_index(pairs[:, 1], xs) 250 | ), 251 | dtype=torch.long 252 | ) 253 | gens = torch.as_tensor(torch.hstack((creators, destroyers))) 254 | 255 | # TODO: Most efficient way to generate diagram again? 256 | persistence_diagram = torch.stack(( 257 | x.ravel()[pairs[:, 0]], 258 | x.ravel()[pairs[:, 1]] 259 | ), 1) 260 | 261 | return PersistenceInformation( 262 | pairing=gens, 263 | diagram=persistence_diagram, 264 | dimension=dim 265 | ) 266 | -------------------------------------------------------------------------------- /torch_topological/nn/data.py: -------------------------------------------------------------------------------- 1 | """Data structures for persistent homology calculations.""" 2 | 3 | from collections import namedtuple 4 | 5 | from itertools import chain 6 | 7 | from operator import itemgetter 8 | 9 | from torch_topological.utils import nesting_level 10 | 11 | import torch 12 | 13 | 14 | class PersistenceInformation( 15 | namedtuple( 16 | "PersistenceInformation", 17 | [ 18 | "pairing", 19 | "diagram", 20 | "dimension", 21 | ], 22 | # Ensures that there is always a dimension specified, albeit an 23 | # 'incorrect' one. 24 | defaults=[None], 25 | ) 26 | ): 27 | """Persistence information data structure. 28 | 29 | This is a light-weight data structure for carrying information about 30 | the calculation of persistent homology. 
It consists of the following 31 | components: 32 | 33 | - A *persistence pairing* 34 | - A *persistence diagram* 35 | - An (optional) *dimension* 36 | 37 | Due to its lightweight nature, no validity checks are performed, but 38 | all calculation modules should return a sequence of instances of the 39 | :class:`PersistenceInformation` class. 40 | 41 | Since this data class is shared with modules that are capable of 42 | calculating persistent homology, the exact form of the persistence 43 | pairing might change. Please refer to the respective classes for 44 | more documentation. 45 | """ 46 | 47 | __slots__ = () 48 | 49 | # Disable iterating over the class since it collates heterogeneous 50 | # information and should rather be treated as a building block. 51 | __iter__ = None 52 | 53 | 54 | def make_tensor(x): 55 | """Create dense tensor representation from sparse inputs. 56 | 57 | This function turns sparse inputs of :class:`PersistenceInformation` 58 | objects into 'dense' tensor representations, thus providing a useful 59 | integration into differentiable layers. 60 | 61 | The dimension of the resulting tensor depends on maximum number of 62 | topological features, summed over all dimensions in the data. This 63 | is similar to the format in `giotto-ph`. 64 | 65 | Parameters 66 | ---------- 67 | x : list of (list of ...) :class:`PersistenceInformation` 68 | Input, consisting of a (potentially) nested list of 69 | :class:`PersistenceInformation` objects as obtained 70 | from a persistent homology calculation module, such 71 | as :class:`VietorisRipsComplex`. 72 | 73 | Returns 74 | ------- 75 | torch.tensor 76 | Dense tensor representation of `x`. The output is best 77 | understood by considering some examples: given a batch 78 | obtained from :class:`VietorisRipsComplex`, our tensor 79 | will have shape `(B, N, 3)`. `B` is the batch size and 80 | `N` is the sum of maximum lengths of diagrams relative 81 | to this batch. Each entry will consist of a creator, a 82 | destroyer, and a dimension. Dummy entries, used to pad 83 | the batch, can be detected as `torch.nan`. 84 | """ 85 | level = nesting_level(x) 86 | 87 | # Internal utility function for calculating the length of the output 88 | # tensor. This is required to ensure that all inputs can be *merged* 89 | # into a single output tensor. 90 | def _calculate_length(x, level): 91 | 92 | # Simple base case; should never occur in practice but let's be 93 | # consistent here. 94 | if len(x) == 0: 95 | return 0 96 | 97 | # Each `chain.from_iterable()` removes an additional layer of 98 | # nesting. We only have to start from level 2; we get level 1 99 | # for free because we can always iterate over a list. 100 | for i in range(2, level + 1): 101 | x = chain.from_iterable(x) 102 | 103 | # Collect information that we need to create the full tensor. An 104 | # entry of the resulting list contains the length of the diagram 105 | # and the dimension, making it possible to derive padding values 106 | # for all entries. 107 | M = list(map(lambda a: (len(a.diagram), a.dimension), x)) 108 | 109 | # Get maximum dimension 110 | dim = max(M, key=itemgetter(1))[1] 111 | 112 | # Get *sum* of maximum number of entries for each dimension. 113 | # This is calculated over all batches. 114 | N = sum( 115 | [ 116 | max([L for L in M if L[1] == d], key=itemgetter(0))[0] 117 | for d in range(dim + 1) 118 | ] 119 | ) 120 | 121 | return N 122 | 123 | # Auxiliary function for padding tensors with `torch.nan` to 124 | # a specific dimension. 
Will always return a `list`; we turn 125 | # it into a tensor depending on the call level. 126 | def _pad_tensors(tensors, N, value=torch.nan): 127 | return list( 128 | map( 129 | lambda t: torch.nn.functional.pad( 130 | t, (0, 0, N - len(t), 0), mode="constant", value=value 131 | ), 132 | tensors, 133 | ) 134 | ) 135 | 136 | N = _calculate_length(x, level) 137 | 138 | # List of lists: the first axis is treated as the batch axis, while 139 | # the second axis is treated as the dimension of diagrams or pairs. 140 | # This also handles ordinary lists, which will result in a batch of 141 | # size 1. 142 | if level <= 2: 143 | tensors = [ 144 | make_tensor_from_persistence_information(pers_infos) 145 | for pers_infos in x 146 | ] 147 | 148 | # Pad all tensors to length N in the first dimension, then turn 149 | # them into a batch. 150 | result = torch.stack(_pad_tensors(tensors, N)) 151 | return result 152 | 153 | # List of lists of lists: this indicates image-based data, where we 154 | # also have a set of tensors for each channel. The internal layout, 155 | # i.e. our input, has the following structure: 156 | # 157 | # B x C x D 158 | # 159 | # Each variable being the length of the respective list. We want an 160 | # output of the following shape: 161 | # 162 | # B x C x N x 3 163 | # 164 | # Here, `N` is the maximum length of an individual persistence 165 | # information object. 166 | else: 167 | tensors = [ 168 | [ 169 | make_tensor_from_persistence_information(pers_infos) 170 | for pers_infos in batch 171 | ] 172 | for batch in x 173 | ] 174 | 175 | # Pad all tensors to length N in the first dimension, then turn 176 | # them into a batch. We first stack over channels (inner), then 177 | # over the batch (outer). 178 | result = torch.stack( 179 | [ 180 | torch.stack(_pad_tensors(batch_tensors, N)) 181 | for batch_tensors in tensors 182 | ] 183 | ) 184 | 185 | return result 186 | 187 | 188 | def make_tensor_from_persistence_information( 189 | pers_info, extract_generators=False 190 | ): 191 | """Convert (sequence) of persistence information entries to tensor. 192 | 193 | This function converts instance(s) of :class:`PersistenceInformation` 194 | objects into a single tensor. No padding will be performed. A client 195 | may specify what type of information to extract from the object. For 196 | instance, by default, the function will extract persistence diagrams 197 | but this behaviour can be changed by setting `extract_generators` to 198 | `true`. 199 | 200 | Parameters 201 | ---------- 202 | pers_info : :class:`PersistenceInformation` or iterable thereof 203 | Input persistence information object(s). The function is able to 204 | handle both single objects and sequences. This has no bearing on 205 | the length of the returned tensor. 206 | 207 | extract_generators : bool 208 | If set, extracts generators instead of persistence diagram from 209 | `pers_info`. 210 | 211 | Returns 212 | ------- 213 | torch.tensor 214 | Tensor of shape `(n, 3)`, where `n` is the sum of all features, 215 | over all dimensions in the input `pers_info`. Each triple shall 216 | be of the form `(creation, destruction, dim)` for a persistence 217 | diagram. If the client requested generators to be returned, the 218 | first two entries of the triple refer to *indices* with respect 219 | to the input data set. Depending on the algorithm employed, the 220 | meaning of these indices can change. Please refer to the module 221 | used to calculate persistent homology for more details. 
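
    A minimal sketch with a single hand-crafted object; all values
    are made up:

    >>> import torch
    >>> pers_info = PersistenceInformation(
    ...     pairing=torch.tensor([[0, 1]]),
    ...     diagram=torch.tensor([[0.0, 0.5]]),
    ...     dimension=0,
    ... )
    >>> make_tensor_from_persistence_information([pers_info]).shape
    torch.Size([1, 3])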
222 | """ 223 | # Looks a little bit cumbersome, but since `namedtuple` is iterable 224 | # as well, we need to ensure that we are actually dealing with more 225 | # than one instance here. 226 | if len(pers_info) > 1 and not isinstance( 227 | pers_info[0], PersistenceInformation 228 | ): 229 | pers_info = [pers_info] 230 | 231 | # TODO: This might not always work since the size of generators 232 | # changes in different dimensions. 233 | if extract_generators: 234 | pairs = torch.cat( 235 | [torch.as_tensor(x.pairing, dtype=torch.float) for x in pers_info], 236 | ).long() 237 | else: 238 | pairs = torch.cat( 239 | [torch.as_tensor(x.diagram, dtype=torch.float) for x in pers_info], 240 | ).float() 241 | 242 | dimensions = torch.cat( 243 | [ 244 | torch.as_tensor( 245 | [x.dimension] * len(x.diagram), 246 | dtype=torch.long, 247 | device=x.diagram.device, 248 | ) 249 | for x in pers_info 250 | ] 251 | ) 252 | 253 | result = torch.column_stack((pairs, dimensions)) 254 | return result 255 | 256 | 257 | def batch_handler(x, handler_fn, **kwargs): 258 | """Light-weight batch handling function. 259 | 260 | The purpose of this function is to simplify the handling of batches 261 | of input data, in particular for modules that deal with point cloud 262 | data. The handler essentially checks whether a 2D array (matrix) or 263 | a 3D array (tensor) was provided, and calls a handler function. The 264 | idea of the handler function is to handle an individual 2D array. 265 | 266 | Parameters 267 | ---------- 268 | x : array_like 269 | Input point cloud(s). Can be either 2D array, indicating 270 | a single point cloud, or a 3D array, or even a *list* of 271 | point clouds (of potentially different cardinalities). 272 | 273 | handler_fn : callable 274 | Function to call for handling a 2D array. 275 | 276 | **kwargs 277 | Additional arguments to provide to `handler_fn`. 278 | 279 | Returns 280 | ------- 281 | list or individual value 282 | Depending on whether `x` needs to be unwrapped, this function 283 | returns either a single value or a list of values, resulting 284 | from calling `handler_fn` on individual parts of `x`. 285 | """ 286 | # Check whether individual batches need to be handled (3D array) 287 | # or not (2D array). We default to this type of processing for a 288 | # list as well. 289 | if isinstance(x, list) or len(x.shape) == 3: 290 | # TODO: This type of batch handling is rather ugly and 291 | # inefficient but at the same time, it is the easiest 292 | # workaround for now, permitting even 'ragged' inputs of 293 | # different lengths. 294 | return [handler_fn(torch.as_tensor(x_), **kwargs) for x_ in x] 295 | else: 296 | return handler_fn(torch.as_tensor(x), **kwargs) 297 | 298 | 299 | def batch_iter(x, dim=None): 300 | """Iterate over batches from input data. 301 | 302 | This utility function simplifies working with 'sparse' data sets 303 | consisting of :class:`PersistenceInformation` instances. It will 304 | present inputs in the order in which they appear in a batch such 305 | that instances belonging to the same data set are kept together. 306 | 307 | Parameters 308 | ---------- 309 | x : recursively-nested list of :class:`PersistenceInformation` 310 | Input in sparse form, i.e. a nested structure containing 311 | persistence information about a data set. 312 | 313 | dim : int or `None` 314 | If set, only iterates over persistence information instances of 315 | the specified dimension. Else, will iterate over all instances. 
316 | 
317 |     Returns
318 |     -------
319 |     A generator (iterable) that will either yield direct instances of
320 |     :class:`PersistenceInformation` objects or further iterators into
321 |     them. This ensures that it is possible to always iterate over the
322 |     individual batches, without having to know internal details about
323 |     the structure of `x`.
324 |     """
325 |     level = nesting_level(x)
326 | 
327 |     # Nothing to do for non-nested data structures, i.e. a single batch
328 |     # that has been squeezed (for instance). Wrapping the input enables
329 |     # us to treat it like a regular input again.
330 |     if level == 1:
331 |         x = [x]
332 | 
333 |     if level <= 2:
334 | 
335 |         def handler(x):
336 |             return x
337 | 
338 |     # Remove the first dimension but also the subsequent one so that
339 |     # only iterables containing persistence information about a specific
340 |     # data set are being returned.
341 |     #
342 |     # TODO: Generalise recursively? Do we want to support that?
343 |     else:
344 | 
345 |         def handler(x):
346 |             return chain.from_iterable(x)
347 | 
348 |     if dim is not None:
349 |         for x_ in x:
350 |             yield list(filter(lambda x: x.dimension == dim, handler(x_)))
351 |     else:
352 |         for x_ in x:
353 |             yield handler(x_)
--------------------------------------------------------------------------------
/torch_topological/nn/distances.py:
--------------------------------------------------------------------------------
1 | """Distance calculation modules between topological descriptors."""
2 | 
3 | import ot
4 | import torch
5 | 
6 | from torch_topological.utils import wrap_if_not_iterable
7 | 
8 | 
9 | class WassersteinDistance(torch.nn.Module):
10 |     """Implement Wasserstein distance between persistence diagrams.
11 | 
12 |     This module calculates the Wasserstein distance between two
13 |     persistence diagrams. The Wasserstein distance is arguably the
14 |     most common metric that is applied when dealing with such
15 |     diagrams. Notice that calculating the metric involves solving
16 |     optimal transport problems, which are known to scale poorly with
17 |     diagram size. When dealing with large persistence diagrams, other
18 |     losses may be more appropriate.
19 |     """
20 | 
21 |     def __init__(self, p=torch.inf, q=1):
22 |         """Create new Wasserstein distance calculation module.
23 | 
24 |         Parameters
25 |         ----------
26 |         p : float or `inf`
27 |             Specifies the exponent of the norm to calculate. By default,
28 |             `p = torch.inf`, corresponding to the *maximum norm*.
29 | 
30 |         q : float
31 |             Specifies the order of Wasserstein metric to calculate. This
32 |             raises all internal matching costs to the power of `q`, hence
33 |             subsequently returning the `q`-th root of the total cost.
34 |         """
35 |         super().__init__()
36 | 
37 |         self.p = p
38 |         self.q = q
39 | 
40 |     def _project_to_diagonal(self, diagram):
41 |         x = diagram[:, 0]
42 |         y = diagram[:, 1]
43 | 
44 |         # TODO: Is this the closest point in all p-norms?
45 |         return 0.5 * torch.stack(((x + y), (x + y)), 1)
46 | 
47 |     def _distance_to_diagonal(self, diagram):
48 |         return torch.linalg.vector_norm(
49 |             diagram - self._project_to_diagonal(diagram),
50 |             self.p,
51 |             dim=1
52 |         )
53 | 
54 |     def _make_distance_matrix(self, D1, D2):
55 |         dist_D11 = self._distance_to_diagonal(D1)
56 |         dist_D22 = self._distance_to_diagonal(D2)
57 | 
58 |         # n x m matrix containing the distances between 'regular'
59 |         # persistence pairs of both persistence diagrams.
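        # The ground distance uses the same `p`-norm as the diagonal
        # distances above, so that regular matching costs and diagonal
        # projection costs are directly comparable.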
60 |         dist = torch.cdist(D1, D2, p=self.p)
61 | 
62 |         # Extend the matrix with a column of distances of samples in D1
63 |         # to their respective projection on the diagonal.
64 |         upper_blocks = torch.hstack((dist, dist_D11[:, None]))
65 | 
66 |         # Create a lower row of distances of samples in D2 to their
67 |         # respective projection on the diagonal. The ordering needs
68 |         # to follow the ordering of samples in D2. Note how one `0`
69 |         # needs to be added to the row in order to balance it. The
70 |         # entry intuitively describes the cost between *projected*
71 |         # points, so it has to be zero.
72 |         lower_blocks = torch.cat(
73 |             (dist_D22, torch.tensor(0, device=dist_D22.device).unsqueeze(0))
74 |         )
75 | 
76 |         # Full (n + 1) x (m + 1) matrix containing *all* distances. By
77 |         # construction, M[i, n] contains distances to projected points
78 |         # in D1, whereas M[m, j] does the same for points in D2. Only a
79 |         # cell M[i, j] with 0 <= i < n and 0 <= j < m contains a proper
80 |         # distance.
81 |         M = torch.vstack((upper_blocks, lower_blocks))
82 |         M = M.pow(self.q)
83 | 
84 |         return M
85 | 
86 |     def forward(self, X, Y):
87 |         """Calculate Wasserstein metric based on input tensors.
88 | 
89 |         Parameters
90 |         ----------
91 |         X : list or instance of :class:`PersistenceInformation`
92 |             Topological features of the first space. Supposed to contain
93 |             persistence diagrams and persistence pairings.
94 | 
95 |         Y : list or instance of :class:`PersistenceInformation`
96 |             Topological features of the second space. Supposed to
97 |             contain persistence diagrams and persistence pairings.
98 | 
99 |         Returns
100 |         -------
101 |         torch.tensor
102 |             A single scalar tensor containing the distance between the
103 |             persistence diagram(s) contained in `X` and `Y`.
104 |         """
105 |         total_cost = 0.0
106 | 
107 |         X = wrap_if_not_iterable(X)
108 |         Y = wrap_if_not_iterable(Y)
109 | 
110 |         for pers_info in zip(X, Y):
111 |             D1 = pers_info[0].diagram
112 |             D2 = pers_info[1].diagram
113 | 
114 |             n = len(D1)
115 |             m = len(D2)
116 | 
117 |             dist = self._make_distance_matrix(D1, D2)
118 | 
119 |             # Create weight vectors. Since the last entry of `a` has to
120 |             # account for the m points coming from D2 (and vice versa
121 |             # for `b`), we have to set the last entries accordingly.
122 | 
123 |             a = torch.ones(n + 1, device=dist.device)
124 |             b = torch.ones(m + 1, device=dist.device)
125 | 
126 |             a[-1] = m
127 |             b[-1] = n
128 | 
129 |             # TODO: Make settings configurable?
130 |             total_cost += ot.emd2(a, b, dist)
131 | 
132 |         return total_cost.pow(1.0 / self.q)
--------------------------------------------------------------------------------
/torch_topological/nn/graphs.py:
--------------------------------------------------------------------------------
1 | """Layers for topological data analysis based on graphs.
2 | 
3 | This is a work-in-progress module. The goal is to provide
4 | functionality similar to what is described in TOGL; see
5 | https://github.com/BorgwardtLab/TOGL.
6 | 
7 | At the moment, the following functionality is present:
8 | - simple deep set layer
9 | - simple TOGL implementation with deep set functions
10 | - basic GCN with TOGL
11 | 
12 | The following aspects are currently ignored:
13 | - handling higher-order information properly
14 | - expanding simplicial complexes
15 | - making use of the dimension of features
16 | 
17 | On the other hand, the current implementation is simplified in the sense
18 | of showing that *lower star filtrations* (and their corresponding
19 | generators) may be useful.
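
A lower-star filtration on a toy graph can be sketched as follows;
all filtration values are made up, and the edge receives the maximum
of its endpoint values:

>>> import gudhi as gd
>>> st = gd.SimplexTree()
>>> _ = st.insert([0], filtration=0.0)
>>> _ = st.insert([1], filtration=1.0)
>>> _ = st.insert([0, 1], filtration=1.0)
>>> _ = st.persistence()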
20 | """ 21 | 22 | import torch 23 | import torch.nn as nn 24 | 25 | import gudhi as gd 26 | 27 | from torch_geometric.data import Data 28 | 29 | from torch_geometric.loader import DataLoader 30 | 31 | from torch_geometric.utils import erdos_renyi_graph 32 | 33 | from torch_geometric.nn import GCNConv 34 | from torch_geometric.nn import global_mean_pool 35 | 36 | from torch_scatter import scatter 37 | 38 | from torch_topological.utils import pairwise 39 | 40 | 41 | class DeepSetLayer(nn.Module): 42 | """Simple equivariant deep set layer.""" 43 | 44 | def __init__(self, dim_in, dim_out, aggregation_fn): 45 | """Create new deep set layer. 46 | 47 | Parameters 48 | ---------- 49 | dim_in : int 50 | Input dimension 51 | 52 | dim_out : int 53 | Output dimension 54 | 55 | aggregation_fn : str 56 | Aggregation to use for the reduction step. Must be valid for 57 | the ``torch_scatter.scatter()`` function, i.e. one of "sum", 58 | "mul", "mean", "min" or "max". 59 | """ 60 | super().__init__() 61 | 62 | self.Gamma = nn.Linear(dim_in, dim_out) 63 | self.Lambda = nn.Linear(dim_in, dim_out, bias=False) 64 | 65 | self.aggregation_fn = aggregation_fn 66 | 67 | def forward(self, x, batch): 68 | """Implement forward pass through layer.""" 69 | xm = scatter(x, batch, dim=0, reduce=self.aggregation_fn) 70 | xm = self.Lambda(xm) 71 | 72 | x = self.Gamma(x) 73 | x = x - xm[batch, :] 74 | return x 75 | 76 | 77 | class TOGL(nn.Module): 78 | """Implementation of TOGL, a topological graph layer. 79 | 80 | Some caveats: this implementation only focuses on a set function 81 | aggregation of topological features. At the moment, it is not as 82 | powerful and feature-complete as the original implementation. 83 | """ 84 | 85 | def __init__( 86 | self, 87 | n_features, 88 | n_filtrations, 89 | hidden_dim, 90 | out_dim, 91 | aggregation_fn, 92 | ): 93 | super().__init__() 94 | 95 | self.n_filtrations = n_filtrations 96 | 97 | self.filtrations = nn.Sequential( 98 | nn.Linear(n_features, hidden_dim), 99 | nn.ReLU(), 100 | nn.Linear(hidden_dim, n_filtrations), 101 | ) 102 | 103 | self.set_fn = nn.ModuleList( 104 | [ 105 | nn.Linear(n_filtrations * 2, out_dim), 106 | nn.ReLU(), 107 | DeepSetLayer(out_dim, out_dim, aggregation_fn), 108 | nn.ReLU(), 109 | DeepSetLayer( 110 | out_dim, 111 | n_features, 112 | aggregation_fn, 113 | ), 114 | ] 115 | ) 116 | 117 | self.batch_norm = nn.BatchNorm1d(n_features) 118 | 119 | def compute_persistent_homology( 120 | self, 121 | x, 122 | edge_index, 123 | vertex_slices, 124 | edge_slices, 125 | batch, 126 | n_nodes, 127 | return_filtration=False, 128 | ): 129 | """Return persistence pairs (i.e. generators).""" 130 | # Apply filtrations to node attributes. For the edge values, we 131 | # use a sublevel set filtration. 132 | # 133 | # TODO: Support different ways of filtering? 134 | filtered_v = self.filtrations(x) 135 | filtered_e, _ = torch.max( 136 | torch.stack( 137 | (filtered_v[edge_index[0]], filtered_v[edge_index[1]]) 138 | ), 139 | axis=0, 140 | ) 141 | 142 | filtered_v = filtered_v.transpose(1, 0).cpu().contiguous() 143 | filtered_e = filtered_e.transpose(1, 0).cpu().contiguous() 144 | edge_index = edge_index.cpu().transpose(1, 0).contiguous() 145 | 146 | # TODO: Do we have to enforce contiguous indices here? 147 | vertex_index = torch.arange(end=n_nodes, dtype=torch.int) 148 | 149 | # Fill all persistence information at the same time. 
150 | persistence_diagrams = torch.empty( 151 | (self.n_filtrations, n_nodes, 2), 152 | dtype=torch.float, 153 | ) 154 | 155 | for filt_index in range(self.n_filtrations): 156 | for (vi, vj), (ei, ej) in zip( 157 | pairwise(vertex_slices), pairwise(edge_slices) 158 | ): 159 | vertices = vertex_index[vi:vj] 160 | edges = edge_index[ei:ej] 161 | 162 | offset = vi 163 | 164 | f_vertices = filtered_v[filt_index][vi:vj] 165 | f_edges = filtered_e[filt_index][ei:ej] 166 | 167 | persistence_diagram = self._compute_persistent_homology( 168 | vertices, f_vertices, edges, f_edges, offset 169 | ) 170 | 171 | persistence_diagrams[filt_index, vi:vj] = persistence_diagram 172 | 173 | # Make sure that the tensor is living on the proper device here; 174 | # all subsequent operations can happen either on the CPU *or* on 175 | # the GPU. 176 | persistence_diagrams = persistence_diagrams.to(x.device) 177 | return persistence_diagrams 178 | 179 | # Helper function for doing the actual calculation of topological 180 | # features of a graph. 181 | def _compute_persistent_homology( 182 | self, vertices, f_vertices, edges, f_edges, offset 183 | ): 184 | assert len(vertices) == len(f_vertices) 185 | assert len(edges) == len(f_edges) 186 | 187 | st = gd.SimplexTree() 188 | 189 | for v, f in zip(vertices, f_vertices): 190 | st.insert([v], filtration=f) 191 | 192 | for (u, v), f in zip(edges, f_edges): 193 | st.insert([u, v], filtration=f) 194 | 195 | st.make_filtration_non_decreasing() 196 | st.expansion(2) 197 | st.persistence() 198 | 199 | # The generators are split into "regular" and "essential" 200 | # vertices, sorted by dimension. 201 | generators = st.lower_star_persistence_generators() 202 | generators_regular, generators_essential = generators 203 | 204 | # TODO: Let's think about how to leverage *all* generators in 205 | # *all* dimensions. 206 | generators_regular = torch.as_tensor(generators_regular[0]) 207 | generators_regular = generators_regular - offset 208 | generators_regular = generators_regular.sort(dim=0, stable=True)[0] 209 | 210 | # By default, every vertex is paired with itself, so we just 211 | # duplicate the information here. 212 | persistence_diagram = torch.stack((f_vertices, f_vertices), dim=1) 213 | 214 | # Map generators back to filtration values, thus adding non-trivial 215 | # tuples to the persistence diagram while preserving gradients. 216 | if len(generators_regular) > 0: 217 | persistence_diagram[generators_regular[:, 0], 1] = f_vertices[ 218 | generators_regular[:, 1] 219 | ] 220 | 221 | return persistence_diagram 222 | 223 | def forward(self, x, data): 224 | """Implement forward pass through data.""" 225 | # TODO: Is this the best signature? `data` is following directly 226 | # the convention of `PyG`. 227 | # 228 | # x : current node attributes of layer; we should not use the 229 | # original attributes here because they are not informed by a 230 | # previous layer. 231 | # 232 | # data : edge slice information etc. 
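        #
        # A minimal sketch of the intended call pattern (assuming a PyG
        # batch `data` whose `_slice_dict` has been populated by the
        # data loader):
        #
        #     x = gcn_layer(data.x, data.edge_index)
        #     x = togl(x, data)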
233 | 
234 |         edge_index = data.edge_index
235 | 
236 |         vertex_slices = torch.Tensor(data._slice_dict["x"]).long()
237 |         edge_slices = torch.Tensor(data._slice_dict["edge_index"]).long()
238 |         batch = data.batch
239 | 
240 |         persistence_pairs = self.compute_persistent_homology(
241 |             x,
242 |             edge_index,
243 |             vertex_slices,
244 |             edge_slices,
245 |             batch,
246 |             n_nodes=data.num_nodes,
247 |         )
248 | 
249 |         x0 = persistence_pairs.permute(1, 0, 2).reshape(
250 |             persistence_pairs.shape[1], -1
251 |         )
252 | 
253 |         for layer in self.set_fn:
254 |             # Preserve batch information for our set function layer
255 |             # instead of treating all inputs the same.
256 |             if isinstance(layer, DeepSetLayer):
257 |                 x0 = layer(x0, batch)
258 |             else:
259 |                 x0 = layer(x0)
260 | 
261 |         # TODO: Residual step; could be made optional. Plus, the optimal
262 |         # order of operations is not clear.
263 |         x = x + self.batch_norm(nn.functional.relu(x0))
264 |         return x
265 | 
266 | 
267 | class TopoGCN(torch.nn.Module):
268 |     """Simple GCN that uses TOGL as an additional layer."""
269 | 
270 |     def __init__(self):
271 |         super().__init__()
272 | 
273 |         self.layers = nn.ModuleList([GCNConv(1, 8), GCNConv(8, 2)])
274 | 
275 |         self.pooling_fn = global_mean_pool
276 |         self.togl = TOGL(8, 16, 32, 16, "mean")
277 | 
278 |     def forward(self, data):
279 |         x, edge_index = data.x, data.edge_index
280 | 
281 |         for layer in self.layers[:1]:
282 |             x = layer(x, edge_index)
283 | 
284 |         x = self.togl(x, data)
285 | 
286 |         for layer in self.layers[1:]:
287 |             x = layer(x, edge_index)
288 | 
289 |         x = self.pooling_fn(x, data.batch)
290 |         return x
291 | 
292 | 
293 | # Small smoke test on random graphs; only runs when the module is
294 | # executed directly instead of being imported.
295 | if __name__ == "__main__":
296 |     B = 64
297 |     N = 100
298 |     p = 0.2
299 | 
300 |     if torch.cuda.is_available():
301 |         dev = "cuda:0"
302 |     else:
303 |         dev = "cpu"
304 | 
305 |     print("Selected device:", dev)
306 | 
307 |     data_list = [
308 |         Data(
309 |             x=torch.rand(N, 1),
310 |             edge_index=erdos_renyi_graph(N, p),
311 |             num_nodes=N
312 |         )
313 |         for i in range(B)
314 |     ]
315 | 
316 |     loader = DataLoader(data_list, batch_size=8)
317 | 
318 |     model = TopoGCN().to(dev)
319 | 
320 |     for index, batch in enumerate(loader):
321 |         print(batch)
322 |         batch = batch.to(dev)
323 | 
324 |         # Slice information is recomputed inside the model; shown here
325 |         # only for illustration.
326 |         vertex_slices = torch.Tensor(batch._slice_dict["x"]).long()
327 |         edge_slices = torch.Tensor(batch._slice_dict["edge_index"]).long()
328 | 
329 |         model(batch)
--------------------------------------------------------------------------------
/torch_topological/nn/layers.py:
--------------------------------------------------------------------------------
1 | """Layers for processing persistence diagrams."""
2 | 
3 | import torch
4 | 
5 | 
6 | class StructureElementLayer(torch.nn.Module):
7 |     def __init__(
8 |         self,
9 |         n_elements
10 |     ):
11 |         super().__init__()
12 | 
13 |         self.n_elements = n_elements
14 |         self.dim = 2  # TODO: Make configurable
15 | 
16 |         size = (self.n_elements, self.dim)
17 | 
18 |         self.centres = torch.nn.Parameter(
19 |             torch.rand(*size)
20 |         )
21 | 
22 |         self.sharpness = torch.nn.Parameter(
23 |             torch.ones(*size) * 3
24 |         )
25 | 
26 |     def forward(self, x):
27 |         batch = torch.cat([x] * self.n_elements, 1)
28 | 
29 |         B, N, D = x.shape
30 | 
31 |         # This is a 'butchered' variant of the much nicer `SLayerExponential`
32 |         # class by C. Hofer and R. Kwitt.
33 | # 34 | # https://c-hofer.github.io/torchph/_modules/torchph/nn/slayer.html#SLayerExponential 35 | 36 | centres = torch.cat([self.centres] * N, 1) 37 | centres = centres.view(-1, self.dim) 38 | centres = torch.stack([centres] * B, 0) 39 | centres = torch.cat((centres, 2 * batch[..., -1].unsqueeze(-1)), 2) 40 | 41 | sharpness = torch.pow(self.sharpness, 2) 42 | sharpness = torch.cat([sharpness] * N, 1) 43 | sharpness = sharpness.view(-1, self.dim) 44 | sharpness = torch.stack([sharpness] * B, 0) 45 | sharpness = torch.cat( 46 | ( 47 | sharpness, 48 | torch.ones_like(batch[..., -1].unsqueeze(-1)) 49 | ), 50 | 2 51 | ) 52 | 53 | x = centres - batch 54 | x = x.pow(2) 55 | x = torch.mul(x, sharpness) 56 | x = torch.nansum(x, 2) 57 | x = torch.exp(-x) 58 | x = x.view(B, self.n_elements, -1) 59 | x = torch.sum(x, 2) 60 | x = x.squeeze() 61 | 62 | return x 63 | -------------------------------------------------------------------------------- /torch_topological/nn/loss.py: -------------------------------------------------------------------------------- 1 | """Loss terms for various optimisation objectives.""" 2 | 3 | import torch 4 | 5 | from torch_topological.utils import is_iterable 6 | 7 | 8 | class SummaryStatisticLoss(torch.nn.Module): 9 | r"""Implement loss based on summary statistic. 10 | 11 | This is a generic loss function based on topological summary 12 | statistics. It implements a loss of the following form: 13 | 14 | .. math:: \|s(X) - s(Y)\|^p 15 | 16 | In the preceding equation, `s` refers to a function that results in 17 | a scalar-valued summary of a persistence diagram. 18 | """ 19 | 20 | def __init__(self, summary_statistic="total_persistence", **kwargs): 21 | """Create new loss function based on summary statistic. 22 | 23 | Parameters 24 | ---------- 25 | summary_statistic : str 26 | Indicates which summary statistic function to use. Must be 27 | a summary statistics function that exists in the utilities 28 | module, i.e. :mod:`torch_topological.utils`. 29 | 30 | At present, the following choices are valid: 31 | 32 | - `torch_topological.utils.persistent_entropy` 33 | - `torch_topological.utils.polynomial_function` 34 | - `torch_topological.utils.total_persistence` 35 | - `torch_topological.utils.p_norm` 36 | 37 | **kwargs 38 | Optional keyword arguments, to be passed to the 39 | summary statistic function. 40 | """ 41 | super().__init__() 42 | 43 | self.p = kwargs.get("p", 1.0) 44 | self.kwargs = kwargs 45 | 46 | import torch_topological.utils.summary_statistics as stat 47 | 48 | self.stat_fn = getattr(stat, summary_statistic, None) 49 | 50 | def forward(self, X, Y=None): 51 | r"""Calculate loss based on input tensor(s). 52 | 53 | Parameters 54 | ---------- 55 | X : list of :class:`PersistenceInformation` 56 | Source information. Supposed to contain persistence diagrams 57 | and persistence pairings. 58 | 59 | Y : list of :class:`PersistenceInformation` or `None` 60 | Optional target information. If set, evaluates a difference 61 | in loss functions as shown in the introduction. If `None`, 62 | a simpler variant of the loss will be evaluated. 63 | 64 | Returns 65 | ------- 66 | torch.tensor 67 | Loss based on the summary statistic selected by the client. 68 | Given a statistic :math:`s`, the function returns the 69 | following expression: 70 | 71 | .. math:: \|s(X) - s(Y)\|^p 72 | 73 | In case no target tensor `Y` has been provided, the latter part 74 | of the expression amounts to `0`. 
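
        A minimal usage sketch (the inputs are random point clouds,
        hence the `+SKIP` directive):

        >>> import torch
        >>> from torch_topological.nn import VietorisRipsComplex
        >>> vr = VietorisRipsComplex(dim=1)
        >>> loss_fn = SummaryStatisticLoss('total_persistence', p=2)
        >>> loss = loss_fn(vr(torch.rand(100, 2)), vr(torch.rand(100, 2)))  # doctest: +SKIP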
75 | """ 76 | stat_src = self._evaluate_stat_fn(X) 77 | 78 | if Y is not None: 79 | stat_target = self._evaluate_stat_fn(Y) 80 | return (stat_target - stat_src).abs().pow(self.p) 81 | else: 82 | return stat_src.abs().pow(self.p) 83 | 84 | def _evaluate_stat_fn(self, X): 85 | """Evaluate statistic function for a given tensor.""" 86 | return torch.sum( 87 | torch.stack( 88 | [ 89 | self.stat_fn(pers_info.diagram, **self.kwargs) 90 | for pers_info in X 91 | ] 92 | ) 93 | ) 94 | 95 | 96 | class SignatureLoss(torch.nn.Module): 97 | """Implement topological signature loss. 98 | 99 | This module implements the topological signature loss first 100 | described in [Moor20a]_. In contrast to the original code provided 101 | by the authors, this module also provides extensions to 102 | higher-dimensional generators if desired. 103 | 104 | The module can be used in conjunction with any set of generators and 105 | persistence diagrams, i.e. with any set of persistence pairings and 106 | persistence diagrams. At the moment, it is restricted to calculating 107 | a Minkowski distances for the loss calculation. 108 | 109 | References 110 | ---------- 111 | .. [Moor20a] M. Moor et al., "Topological Autoencoders", 112 | *Proceedings of the 37th International Conference on Machine 113 | Learning*, PMLR 119, pp. 7045--7054, 2020. 114 | """ 115 | 116 | def __init__(self, p=2, normalise=True, dimensions=0): 117 | """Create new loss instance. 118 | 119 | Parameters 120 | ---------- 121 | p : float 122 | Exponent for the `p`-norm calculation of distances. 123 | 124 | normalise : bool 125 | If set, normalises distances for each point cloud. This can 126 | be useful when working with batches. 127 | 128 | dimensions : int or tuple of int 129 | Dimensions to use in the signature calculation. Following 130 | [Moor20a]_, this is set by default to `0`. 131 | """ 132 | super().__init__() 133 | 134 | self.p = p 135 | self.normalise = normalise 136 | self.dimensions = dimensions 137 | 138 | # Ensure that we can iterate over the dimensions later on, as 139 | # this simplifies the code. 140 | if not is_iterable(self.dimensions): 141 | self.dimensions = [self.dimensions] 142 | 143 | def forward(self, X, Y): 144 | """Calculate the signature loss between two point clouds. 145 | 146 | This loss function uses the persistent homology from each point 147 | cloud in order to retrieve the topologically relevant distances 148 | from a distance matrix calculated from the point clouds. For 149 | more information, see [Moor20a]_. 150 | 151 | Parameters 152 | ---------- 153 | X: Tuple[torch.tensor, PersistenceInformation] 154 | A tuple consisting of the point cloud and the persistence 155 | information of the point cloud. The persistent information 156 | is calculated by performing persistent homology calculation 157 | to retrieve a list of topologically relevant edges. 158 | 159 | Y: Tuple[torch.tensor, PersistenceInformation] 160 | A tuple consisting of the point cloud and the persistence 161 | information of the point cloud. The persistent information 162 | is calculated by performing persistent homology calculation 163 | to retrieve a list of topologically relevant edges. 164 | 165 | Returns 166 | ------- 167 | torch.tensor 168 | A scalar representing the topological loss term for the two 169 | data sets. 170 | """ 171 | X_point_cloud, X_persistence_info = X 172 | Y_point_cloud, Y_persistence_info = Y 173 | 174 | # Calculate the pairwise distance matrix between points in the 175 | # point cloud. 
Distances are calculated using the p-norm. 176 | X_pairwise_dist = torch.cdist(X_point_cloud, X_point_cloud, self.p) 177 | Y_pairwise_dist = torch.cdist(Y_point_cloud, Y_point_cloud, self.p) 178 | 179 | if self.normalise: 180 | X_pairwise_dist = X_pairwise_dist / X_pairwise_dist.max() 181 | Y_pairwise_dist = Y_pairwise_dist / Y_pairwise_dist.max() 182 | 183 | # Using the topologically relevant edges from point cloud X, 184 | # retrieve the corresponding distances from the pairwise 185 | # distance matrix of X. 186 | X_sig_X = [ 187 | self._select_distances( 188 | X_pairwise_dist, X_persistence_info[dim].pairing 189 | ) 190 | for dim in self.dimensions 191 | ] 192 | 193 | # Using the topologically relevant edges from point cloud Y, 194 | # retrieve the corresponding distances from the pairwise 195 | # distance matrix of X. 196 | X_sig_Y = [ 197 | self._select_distances( 198 | X_pairwise_dist, Y_persistence_info[dim].pairing 199 | ) 200 | for dim in self.dimensions 201 | ] 202 | 203 | # Using the topologically relevant edges from point cloud X, 204 | # retrieve the corresponding distances from the pairwise 205 | # distance matrix of Y. 206 | Y_sig_X = [ 207 | self._select_distances( 208 | Y_pairwise_dist, X_persistence_info[dim].pairing 209 | ) 210 | for dim in self.dimensions 211 | ] 212 | 213 | # Using the topologically relevant edges from point cloud Y, 214 | # retrieve the corresponding distances from the pairwise 215 | # distance matrix of Y. 216 | Y_sig_Y = [ 217 | self._select_distances( 218 | Y_pairwise_dist, Y_persistence_info[dim].pairing 219 | ) 220 | for dim in self.dimensions 221 | ] 222 | 223 | XY_dist = self._partial_distance(X_sig_X, Y_sig_X) 224 | YX_dist = self._partial_distance(Y_sig_Y, X_sig_Y) 225 | 226 | return torch.stack(XY_dist).sum() + torch.stack(YX_dist).sum() 227 | 228 | def _select_distances(self, pairwise_distance_matrix, generators): 229 | """Select topologically relevant edges from a pairwise distance matrix. 230 | 231 | Parameters 232 | ---------- 233 | pairwise_distance_matrix: torch.tensor 234 | NxN pairwise distance matrix of a point cloud. 235 | 236 | generators: np.ndarray 237 | A 2D array consisting of indices corresponding to edges that 238 | correspond to the birth/destruction of some topological 239 | feature during persistent homology calculation. If the 240 | generator corresponds to topological features in 241 | 0-dimension, i.e. connected components, we only consider the 242 | edges that destroy connected components (we do not consider 243 | vertices). If the generator corresponds to topological 244 | features in > 0 dimensions, e.g holes or voids, we consider 245 | edges that create/destroy such topological features. 246 | 247 | Returns 248 | ------- 249 | torch.tensor 250 | A vector that contains all of the topologically relevant 251 | distances. 252 | """ 253 | # Dimension 0: only a mapping of vertices--edges is present, and 254 | # we must *only* access the edges. 255 | if generators.shape[1] == 3: 256 | selected_distances = pairwise_distance_matrix[ 257 | generators[:, 1], generators[:, 2] 258 | ] 259 | 260 | # Dimension > 0: we can access all distances 261 | else: 262 | creator_distances = pairwise_distance_matrix[ 263 | generators[:, 0], generators[:, 1] 264 | ] 265 | destroyer_distances = pairwise_distance_matrix[ 266 | generators[:, 2], generators[:, 3] 267 | ] 268 | 269 | # Need to use `torch.abs` here because of the way the 270 | # signature lookup works. 
We are *not* guaranteed to
271 |             # get 'valid' persistence values when using a pairing
272 |             # from space X to access distances from space Y, for
273 |             # instance, hence some of the values could be *negative*.
274 |             selected_distances = torch.abs(
275 |                 destroyer_distances - creator_distances
276 |             )
277 | 
278 |         return selected_distances
279 | 
280 |     def _partial_distance(self, A, B):
281 |         """
282 |         Calculate partial distances between pairings.
283 | 
284 |         The purpose of this function is to calculate a partial distance
285 |         for the loss, depending on distances selected from the pairing.
286 |         """
287 |         dist = [
288 |             0.5 * torch.linalg.vector_norm(a - b, ord=self.p)
289 |             for a, b in zip(A, B)
290 |         ]
291 | 
292 |         return dist
293 | 
--------------------------------------------------------------------------------
/torch_topological/nn/multi_scale_kernel.py:
--------------------------------------------------------------------------------
1 | """Contains the multi-scale (scale space) kernel module."""
2 | 
3 | import torch
4 | 
5 | from torch_topological.utils import wrap_if_not_iterable
6 | 
7 | 
8 | class MultiScaleKernel(torch.nn.Module):
9 |     # TODO: more detailed description
10 |     r"""Implement the multi-scale kernel between two persistence diagrams.
11 | 
12 |     This class implements the multi-scale kernel between two persistence
13 |     diagrams (also known as the scale space kernel) as defined by
14 |     Reininghaus et al. [Reininghaus15a]_ as
15 | 
16 |     .. math::
17 |         k_\sigma(F,G) = \frac{1}{8 \pi \sigma}
18 |         \sum_{\substack{p \in F\\q \in G}} \exp\left(-\frac{\|p-q\|^2}{8\sigma}\right)
19 |         - \exp\left(-\frac{\|p-\overline{q}\|^2}{8\sigma}\right)
20 | 
21 |     where :math:`z=(z_1, z_2)` and :math:`\overline{z}=(z_2, z_1)`.
22 | 
23 |     References
24 |     ----------
25 |     .. [Reininghaus15a] J. Reininghaus, U. Bauer and R. Kwitt, "A Stable
26 |        Multi-Scale Kernel for Topological Machine Learning", *Proceedings
27 |        of the IEEE Conference on Computer Vision and Pattern Recognition*,
28 |        pp. 4741--4748, 2015.
29 |     """
30 | 
31 |     def __init__(self, sigma):
32 |         """Create new instance of the kernel.
33 | 
34 |         Parameters
35 |         ----------
36 |         sigma : float
37 |             Scale parameter of the kernel.
38 |         """
39 |         super().__init__()
40 | 
41 |         self.sigma = sigma
42 | 
43 |     @staticmethod
44 |     def _check_upper(d):
45 |         # Check if all points in the diagram are above the diagonal.
46 |         # All points below the diagonal are 'swapped'.
47 |         is_upper = d[:, 0] < d[:, 1]
48 |         if not torch.all(is_upper):
49 |             # Swap both coordinates at once to avoid overwriting values.
50 |             d[~is_upper] = d[~is_upper][:, [1, 0]]
51 |         return d
52 | 
53 |     @staticmethod
54 |     def _mirror(x):
55 |         # Mirror one or multiple points of a persistence
56 |         # diagram at the diagonal
57 |         if len(x.shape) > 1:
58 |             return x[:, [1, 0]]
59 |         # only a single point in the diagram
60 |         return x[[1, 0]]
61 | 
62 |     @staticmethod
63 |     def _dist(x, y, p):
64 |         # Compute the point-wise squared lp-distance between
65 |         # two persistence diagrams
66 |         dist = torch.cdist(x, y, p=p)
67 |         return dist.pow(2)
68 | 
69 |     def forward(self, X, Y, p=2.):
70 |         """Calculate the kernel value between two persistence diagrams.
71 | 
72 |         The kernel value is computed for each dimension of the persistence
73 |         diagram individually, according to Equation 10 from Reininghaus et al.
74 |         The final kernel value is computed as the sum of kernel values over
75 |         all dimensions.
76 | 
77 |         Parameters
78 |         ----------
79 |         X : list or instance of :class:`PersistenceInformation`
80 |             Topological features of the first space. Supposed to
Supposed to 81 | contain persistence diagrams and persistence pairings. 82 | 83 | Y : list or instance of :class:`PersistenceInformation` 84 | Topological features of the second space. Supposed to 85 | contain persistence diagrams and persistence pairings. 86 | 87 | p : float or inf, default 2. 88 | Specify which p-norm to use for distance calculation. 89 | For infinity/maximum norm pass p=float('inf'). 90 | Please note that using norms other than the 2-norm 91 | (Euclidean norm) are not guaranteed to give positive 92 | definite results. 93 | 94 | Returns 95 | ------- 96 | torch.tensor 97 | A single scalar tensor containing the kernel value between the 98 | persistence diagram(s) contained in `X` and `Y`. 99 | 100 | Examples 101 | -------- 102 | >>> from torch_topological.data.shapes import sample_from_disk 103 | >>> from torch_topological.nn import VietorisRipsComplex 104 | >>> # sample randomly from two disks 105 | >>> x = sample_from_disk(r=0.5, R=0.6, n=100) 106 | >>> y = sample_from_disk(r=0.9, R=1.0, n=100) 107 | >>> # compute vietoris rips filtration for both point clouds 108 | >>> vr = VietorisRipsComplex(dim=1) 109 | >>> vr_x = vr(x) 110 | >>> vr_y = vr(y) 111 | >>> # compute kernel value between persistence 112 | >>> # diagrams with sigma set to 1 113 | >>> msk = MultiScaleKernel(1.) 114 | >>> msk_value = msk(vr_x, vr_y) 115 | """ 116 | X_ = wrap_if_not_iterable(X) 117 | Y_ = wrap_if_not_iterable(Y) 118 | 119 | k_sigma = 0.0 120 | 121 | for pers_info in zip(X_, Y_): 122 | # ensure that all points in the diagram are 123 | # above the diagonal 124 | D1 = self._check_upper(pers_info[0].diagram) 125 | D2 = self._check_upper(pers_info[1].diagram) 126 | 127 | # compute the pairwise distances between the 128 | # two diagrams 129 | nom = self._dist(D1, D2, p) 130 | # distance between diagram 1 and mirrored 131 | # diagram 2 132 | denom = self._dist(D1, self._mirror(D2), p) 133 | 134 | M = torch.exp(-nom / (8 * self.sigma)) 135 | M -= torch.exp(-denom / (8 * self.sigma)) 136 | 137 | # sum over all points 138 | k_sigma += M.sum() / (8. * self.sigma * torch.pi) 139 | 140 | return k_sigma 141 | -------------------------------------------------------------------------------- /torch_topological/nn/sliced_wasserstein_distance.py: -------------------------------------------------------------------------------- 1 | """Sliced Wasserstein distance implementation.""" 2 | 3 | 4 | import torch 5 | 6 | import numpy as np 7 | 8 | from torch_topological.utils import wrap_if_not_iterable 9 | 10 | 11 | class SlicedWassersteinDistance(torch.nn.Module): 12 | """Calculate sliced Wasserstein distance between persistence diagrams. 13 | 14 | This is an implementation of the sliced Wasserstein distance between 15 | persistence diagrams, following [Carriere17a]_. 16 | 17 | This module calculates the sliced Wasserstein distance between two 18 | persistence diagrams. It is an efficient variant of the Wasserstein 19 | distance, and it is commonly used in the Sliced Wasserstein Kernel. 20 | It computes the expected value of the Wasserstein distance when the 21 | persistence diagram is projected on a random line passing through 22 | the origin. 23 | """ 24 | 25 | def __init__(self, num_directions=10): 26 | """Create new sliced Wasserstein distance calculation module. 27 | 28 | Parameters 29 | ---------- 30 | num_directions : int 31 | Specifies the number of random directions to be sampled for 32 | computation of the sliced Wasserstein distance. 
33 | """ 34 | super().__init__() 35 | 36 | # Generates num_directions number of lines with slopes randomly sampled 37 | # between -pi/2 and pi/2. 38 | self.num_directions = num_directions 39 | thetas = torch.linspace(-np.pi/2, np.pi/2, steps=self.num_directions+1) 40 | thetas = thetas[:-1] 41 | self.lines = torch.vstack([torch.tensor([torch.cos(i), torch.sin(i)], 42 | dtype=torch.float32) for i in thetas]) 43 | 44 | def _emd1d(self, X, Y): 45 | # Compute Wasserstein Distance between two 1d-distributions. 46 | X, ind = torch.sort(X, dim=0) 47 | Y, ind = torch.sort(Y, dim=0) 48 | return torch.sum(torch.abs(torch.sub(X, Y))) 49 | 50 | def _project_diagram(self, D1, L): 51 | # Project persistence diagram D1 onto a given line L. 52 | return torch.stack([torch.dot(x, L)/torch.dot(L, L) for x in D1]) 53 | 54 | def forward(self, X, Y): 55 | """Calculate sliced Wasserstein metric based on input tensors. 56 | 57 | Parameters 58 | ---------- 59 | X : list or instance of :class:`PersistenceInformation` 60 | Topological features of the first space. Supposed to contain 61 | persistence diagrams and persistence pairings. 62 | 63 | Y : list or instance of :class:`PersistenceInformation` 64 | Topological features of the second space. Supposed to 65 | contain persistence diagrams and persistence pairings. 66 | 67 | Returns 68 | ------- 69 | torch.tensor 70 | A single scalar tensor containing the sliced Wasserstein distance 71 | between the persistence diagram(s) contained in `X` and `Y`. 72 | """ 73 | total_cost = 0.0 74 | 75 | X = wrap_if_not_iterable(X) 76 | Y = wrap_if_not_iterable(Y) 77 | 78 | for pers_info in zip(X, Y): 79 | D1 = pers_info[0].diagram.float() 80 | D2 = pers_info[1].diagram.float() 81 | 82 | # Auxiliary array to project onto diagonal. 83 | diag = torch.tensor([0.5, 0.5], dtype=torch.float32) 84 | 85 | # Project both the diagrams onto the diagonals. 86 | D1_diag = torch.sum(D1, dim=1, keepdim=True) * diag 87 | D2_diag = torch.sum(D2, dim=1, keepdim=True) * diag 88 | 89 | cost = 0.0 90 | 91 | for line in self.lines: 92 | proj_d1 = self._project_diagram(D1, line) 93 | proj_d2 = self._project_diagram(D2, line) 94 | 95 | proj_diag_d1 = self._project_diagram(D1_diag, line) 96 | proj_diag_d2 = self._project_diagram(D2_diag, line) 97 | 98 | cost += self._emd1d(torch.cat([proj_d1, proj_diag_d2]), 99 | torch.cat([proj_d2, proj_diag_d1])) 100 | 101 | cost /= self.num_directions 102 | 103 | total_cost += cost 104 | 105 | return total_cost 106 | -------------------------------------------------------------------------------- /torch_topological/nn/sliced_wasserstein_kernel.py: -------------------------------------------------------------------------------- 1 | """Sliced Wasserstein kernel implementation.""" 2 | 3 | 4 | import torch 5 | 6 | from torch_topological.nn import SlicedWassersteinDistance 7 | 8 | from torch_topological.utils import wrap_if_not_iterable 9 | 10 | 11 | class SlicedWassersteinKernel(torch.nn.Module): 12 | """Calculate sliced Wasserstein kernel between persistence diagrams. 13 | 14 | This is an implementation of the sliced Wasserstein kernel between 15 | persistence diagrams, following [Carriere17a]_. 16 | 17 | References 18 | ---------- 19 | .. [Carriere17a] M. Carrière et al., "Sliced Wasserstein Kernel for 20 | Persistence Diagrams", *Proceedings of the 34th International 21 | Conference on Machine Learning*, PMLR 70, pp. 664--673, 2017. 22 | """ 23 | 24 | def __init__(self, num_directions=10, sigma=1.0): 25 | """Create new sliced Wasserstein kernel module. 
26 | 27 | Parameters 28 | ---------- 29 | num_directions : int 30 | Specifies the number of random directions to be sampled for 31 | computation of the sliced Wasserstein distance. 32 | 33 | sigma : int 34 | Variance term of the sliced Wasserstein kernel expression. 35 | """ 36 | super().__init__() 37 | 38 | self.num_directions = num_directions 39 | self.sigma = sigma 40 | 41 | def forward(self, X, Y): 42 | """Calculate sliced Wasserstein kernel based on input tensors. 43 | 44 | Parameters 45 | ---------- 46 | X : list or instance of :class:`PersistenceInformation` 47 | Topological features of the first space. Supposed to contain 48 | persistence diagrams and persistence pairings. 49 | 50 | Y : list or instance of :class:`PersistenceInformation` 51 | Topological features of the second space. Supposed to 52 | contain persistence diagrams and persistence pairings. 53 | 54 | Returns 55 | ------- 56 | torch.tensor 57 | A single scalar tensor containing the sliced Wasserstein kernel 58 | between the persistence diagram(s) contained in `X` and `Y`. 59 | """ 60 | total_cost = 0.0 61 | 62 | X = wrap_if_not_iterable(X) 63 | Y = wrap_if_not_iterable(Y) 64 | 65 | swd = SlicedWassersteinDistance(num_directions=self.num_directions) 66 | 67 | for pers_info in zip(X, Y): 68 | D1 = pers_info[0] 69 | D2 = pers_info[1] 70 | 71 | total_cost += torch.exp(-swd(D1, D2)) / self.sigma 72 | 73 | return total_cost 74 | -------------------------------------------------------------------------------- /torch_topological/nn/vietoris_rips_complex.py: -------------------------------------------------------------------------------- 1 | """Vietoris--Rips complex calculation module(s).""" 2 | 3 | from itertools import starmap 4 | 5 | from gph import ripser_parallel 6 | from torch import nn 7 | 8 | from torch_topological.nn import PersistenceInformation 9 | from torch_topological.nn.data import batch_handler 10 | 11 | import numpy 12 | import torch 13 | 14 | 15 | class VietorisRipsComplex(nn.Module): 16 | """Calculate Vietoris--Rips complex of a data set. 17 | 18 | This module calculates 'differentiable' persistence diagrams for 19 | point clouds. The underlying topological approximations are done 20 | by calculating a Vietoris--Rips complex of the data. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | dim=1, 26 | p=2, 27 | threshold=numpy.inf, 28 | keep_infinite_features=False, 29 | **kwargs 30 | ): 31 | """Initialise new module. 32 | 33 | Parameters 34 | ---------- 35 | dim : int 36 | Calculates persistent homology up to (and including) the 37 | prescribed dimension. 38 | 39 | p : float 40 | Exponent indicating which Minkowski `p`-norm to use for the 41 | calculation of pairwise distances between points. Note that 42 | if `treat_as_distances` is supplied to :func:`forward`, the 43 | parameter is ignored and will have no effect. The rationale 44 | is to permit clients to use a pre-computed distance matrix, 45 | while always falling back to Minkowski norms. 46 | 47 | threshold : float 48 | If set to a finite number, only calculates topological 49 | features up to the specified distance threshold. Thus, 50 | any persistence pairings may contain infinite features 51 | as well. 52 | 53 | keep_infinite_features : bool 54 | If set, keeps infinite features. This flag is disabled by 55 | default. The rationale for this is that infinite features 56 | require more deliberate handling and, in case `threshold` 57 | is not changed, only a *single* infinite feature will not 58 | be considered in subsequent calculations. 
59 | 
60 |         **kwargs
61 |             Additional arguments to be provided to ``ripser``, i.e. the
62 |             backend for calculating persistent homology. The `n_threads`
63 |             parameter, which controls parallelisation, is probably the
64 |             most relevant parameter to be adjusted.
65 |             Please refer to the `giotto-ph documentation
66 |             `_
67 |             for more details on admissible parameters.
68 | 
69 |         Notes
70 |         -----
71 |         This module currently only supports Minkowski norms. It does not
72 |         yet support other metrics internally. To use custom metrics, you
73 |         need to set `treat_as_distances` in the :func:`forward` function
74 |         instead.
75 |         """
76 |         super().__init__()
77 | 
78 |         self.dim = dim
79 |         self.p = p
80 |         self.threshold = threshold
81 |         self.keep_infinite_features = keep_infinite_features
82 | 
83 |         # Ensures that the same parameters are used whenever calling
84 |         # `ripser`.
85 |         self.ripser_params = {
86 |             'return_generators': True,
87 |             'maxdim': self.dim,
88 |             'thresh': self.threshold
89 |         }
90 | 
91 |         self.ripser_params.update(kwargs)
92 | 
93 |     def forward(self, x, treat_as_distances=False):
94 |         """Implement forward pass for persistence diagram calculation.
95 | 
96 |         The forward pass entails calculating persistent homology on
97 |         a point cloud and returning a set of persistence diagrams.
98 | 
99 |         Parameters
100 |         ----------
101 |         x : array_like
102 |             Input point cloud(s). `x` can either be a 2D array of shape
103 |             `(n, d)`, which is treated as a single point cloud, or a 3D
104 |             array/tensor of the form `(b, n, d)`, with `b` representing
105 |             the batch size. Alternatively, you may also specify a list,
106 |             possibly containing point clouds of non-uniform sizes.
107 | 
108 |         treat_as_distances : bool
109 |             If set, treats `x` as containing pre-computed distances
110 |             between points. The semantics of how `x` is handled are
111 |             not changed; the only difference is that when `x` has a
112 |             shape of `(n, d)`, the values of `n` and `d` need to be
113 |             the same.
114 | 
115 |         Returns
116 |         -------
117 |         list of :class:`PersistenceInformation`
118 |             List of :class:`PersistenceInformation`, containing both the
119 |             persistence diagrams and the generators, i.e. the
120 |             *pairings*, of a certain dimension of topological features.
121 |             If `x` is a 3D array, returns a list of lists, in which the
122 |             first dimension denotes the batch and the second dimension
123 |             refers to the individual instances of
124 |             :class:`PersistenceInformation` elements.
125 | 
126 |             Generators will be represented in the persistence pairing
127 |             based on vertex--edge pairs (dimension 0) or edge--edge
128 |             pairs. Thus, the persistence pairing in dimension zero will
129 |             have three components, corresponding to a vertex and an
130 |             edge, respectively, while the persistence pairing for higher
131 |             dimensions will have four components.
132 |         """
133 |         return batch_handler(
134 |             x,
135 |             self._forward,
136 |             treat_as_distances=treat_as_distances
137 |         )
138 | 
139 |     def _forward(self, x, treat_as_distances=False):
140 |         """Handle a *single* point cloud.
141 | 
142 |         This internal function handles the calculation of topological
143 |         features for a single point cloud, i.e. an `array_like` of 2D
144 |         shape.
145 | 
146 |         Parameters
147 |         ----------
148 |         x : array_like of shape `(n, d)`
149 |             Single input point cloud.
150 | 
151 |         treat_as_distances : bool
152 |             Flag indicating whether `x` should be treated as a distance
153 |             matrix. See :func:`forward` for more information.
        """
        return batch_handler(
            x,
            self._forward,
            treat_as_distances=treat_as_distances
        )

    def _forward(self, x, treat_as_distances=False):
        """Handle a *single* point cloud.

        This internal function handles the calculation of topological
        features for a single point cloud, i.e. an `array_like` of 2D
        shape.

        Parameters
        ----------
        x : array_like of shape `(n, d)`
            Single input point cloud.

        treat_as_distances : bool
            Flag indicating whether `x` should be treated as a distance
            matrix. See :func:`forward` for more information.

        Returns
        -------
        list of :class:`PersistenceInformation`
            List of persistence information data structures, containing
            the persistence diagram and the persistence pairing of some
            dimension in the input data set.
        """
        if treat_as_distances:
            distances = x
        else:
            distances = torch.cdist(x, x, p=self.p)

        generators = ripser_parallel(
            distances.cpu().detach().numpy(),
            metric='precomputed',
            **self.ripser_params
        )['gens']

        # We always have 0D information.
        persistence_information = \
            self._extract_generators_and_diagrams(
                distances,
                generators,
                dim0=True,
            )

        if self.keep_infinite_features:
            persistence_information_inf = \
                self._extract_generators_and_diagrams(
                    distances,
                    generators,
                    finite=False,
                    dim0=True,
                )

        # Check whether we have any higher-dimensional information that
        # we should return.
        if self.dim >= 1:
            persistence_information.extend(
                self._extract_generators_and_diagrams(
                    distances,
                    generators,
                    dim0=False,
                )
            )

            if self.keep_infinite_features:
                persistence_information_inf.extend(
                    self._extract_generators_and_diagrams(
                        distances,
                        generators,
                        finite=False,
                        dim0=False,
                    )
                )

        # Concatenation is only necessary if we want to keep infinite
        # features.
        if self.keep_infinite_features:
            persistence_information = self._concatenate_features(
                persistence_information, persistence_information_inf
            )

        return persistence_information

    def _extract_generators_and_diagrams(
        self,
        dist,
        gens,
        finite=True,
        dim0=False
    ):
        """Extract generators and persistence diagrams from raw data.

        This convenience function translates between the output of
        `ripser_parallel` and the output format required by this
        module.
        """
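        # Layout of `gens`, as returned by `ripser_parallel` when
        # `return_generators` is set (inferred from the indexing
        # below): index 0 stores finite pairs in dimension 0, index 1
        # stores a list of finite pairs for each higher dimension, and
        # indices 2 and 3 store the corresponding infinite features.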
        index = 1 if not dim0 else 0

        # Perform index shift to find infinite features in the tensor.
        if not finite:
            index += 2

        gens = gens[index]

        if dim0:
            if finite:
                # In a Vietoris--Rips complex, all vertices are
                # created at time zero.
                creators = torch.zeros_like(
                    torch.as_tensor(gens)[:, 0],
                    device=dist.device
                )

                destroyers = dist[gens[:, 1], gens[:, 2]]
            else:
                creators = torch.zeros_like(
                    torch.as_tensor(gens)[:],
                    device=dist.device
                )

                destroyers = torch.full_like(
                    torch.as_tensor(gens)[:],
                    torch.inf,
                    dtype=torch.float,
                    device=dist.device
                )

                inf_pairs = numpy.full(
                    shape=(gens.shape[0], 2), fill_value=-1
                )
                gens = numpy.column_stack((gens, inf_pairs))

            persistence_diagram = torch.stack(
                (creators, destroyers), 1
            )

            return [PersistenceInformation(gens, persistence_diagram, 0)]
        else:
            result = []

            for index, gens_ in enumerate(gens):
                # Dimension zero is handled differently, so we need to
                # use an offset here. Note that this is not used as an
                # index into the `gens` array any more.
                dimension = index + 1

                if finite:
                    creators = dist[gens_[:, 0], gens_[:, 1]]
                    destroyers = dist[gens_[:, 2], gens_[:, 3]]

                    persistence_diagram = torch.stack(
                        (creators, destroyers), 1
                    )
                else:
                    creators = dist[gens_[:, 0], gens_[:, 1]]

                    destroyers = torch.full_like(
                        torch.as_tensor(gens_)[:, 0],
                        torch.inf,
                        dtype=torch.float,
                        device=dist.device
                    )

                    # Create special infinite pairs; we pretend that we
                    # are concatenating with unknown edges here.
                    inf_pairs = numpy.full(
                        shape=(gens_.shape[0], 2), fill_value=-1
                    )
                    gens_ = numpy.column_stack((gens_, inf_pairs))

                    persistence_diagram = torch.stack(
                        (creators, destroyers), 1
                    )

                result.append(
                    PersistenceInformation(
                        gens_,
                        persistence_diagram,
                        dimension
                    )
                )

            return result

    def _concatenate_features(self, pers_info_finite, pers_info_infinite):
        """Concatenate finite and infinite features."""
        def _apply(fin, inf):
            assert fin.dimension == inf.dimension

            diagram = torch.concat((fin.diagram, inf.diagram))
            pairing = numpy.concatenate((fin.pairing, inf.pairing), axis=0)
            dimension = fin.dimension

            return PersistenceInformation(
                pairing=pairing,
                diagram=diagram,
                dimension=dimension
            )

        return list(starmap(_apply, zip(pers_info_finite, pers_info_infinite)))
--------------------------------------------------------------------------------
/torch_topological/utils/__init__.py:
--------------------------------------------------------------------------------
"""Utilities module."""

from .general import is_iterable
from .general import nesting_level
from .general import pairwise
from .general import wrap_if_not_iterable

from .filters import SelectByDimension

from .summary_statistics import total_persistence
from .summary_statistics import persistent_entropy
from .summary_statistics import polynomial_function

__all__ = [
    "is_iterable",
    "nesting_level",
    "pairwise",
    "persistent_entropy",
    "polynomial_function",
    "total_persistence",
    "wrap_if_not_iterable",
    "SelectByDimension",
]
--------------------------------------------------------------------------------
/torch_topological/utils/filters.py:
--------------------------------------------------------------------------------
"""Filter functions for algorithmic outputs."""

import torch


class SelectByDimension(torch.nn.Module):
    """Select persistence diagrams by dimension.

    This is a simple selector that enables filtering the outputs of
    persistent homology calculation algorithms by dimension: often,
    one is really only interested in a *specific* dimension but the
    corresponding algorithm yields more diagrams. To this end, this
    module can be applied as a lightweight filter.
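
    Examples
    --------
    A minimal sketch; `pers_info` is a hypothetical stand-in for the
    output of a persistent homology module such as
    `VietorisRipsComplex`:

    >>> selector = SelectByDimension(min_dim=1)
    >>> pers_info = selector(pers_info)  # keeps only dimension 1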
    """

    def __init__(self, min_dim, max_dim=None):
        """Prepare filter for subsequent usage.

        Provides the filter with the required parameters. A minimum
        dimension must be provided. There is also the option to use
        a maximum dimension, thus permitting filtering by ranges.

        Parameters
        ----------
        min_dim : int
            Minimum dimension to allow through the filter. If this is
            the sole provided parameter, only diagrams satisfying the
            dimension requirement will be selected.

        max_dim : int
            Optional upper dimension. If set, the selection returns a
            diagram whose dimension is within `[min_dim, max_dim]`.
        """
        super().__init__()

        self.min_dim = min_dim
        self.max_dim = max_dim

    def forward(self, X):
        """Apply selection parameters to input.

        Iterate over input and select diagrams according to the
        pre-defined parameters.

        Parameters
        ----------
        X : iterable of `PersistenceInformation`
            An iterable containing `PersistenceInformation` objects at
            its lowest level.

        Returns
        -------
        iterable
            Input, i.e. `X`, but with all non-matching persistence
            diagrams removed.
        """
        return [
            pers_info for pers_info in X if self._is_valid(pers_info)
        ]

    def _is_valid(self, pers_info):
        if self.max_dim is not None:
            return self.min_dim <= pers_info.dimension <= self.max_dim
        else:
            return self.min_dim == pers_info.dimension
--------------------------------------------------------------------------------
/torch_topological/utils/general.py:
--------------------------------------------------------------------------------
"""General utility functions.

This module contains generic utility functions that are used throughout
the code base. If a function is 'relatively small' and somewhat generic
it should be put here.
"""

import itertools


def pairwise(iterable):
    """Return iterator to iterate over consecutive pairs.

    Parameters
    ----------
    iterable : iterable
        Input iterable.

    Returns
    -------
    iterator
        An iterator to iterate over consecutive pairs of the input data.

    Notes
    -----
    A similar function, ``itertools.pairwise``, appears in more recent
    versions of Python. For compatibility with old Python versions, we
    provide our own implementation.

    Examples
    --------
    >>> list(pairwise("ABC"))
    [('A', 'B'), ('B', 'C')]
    """
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)


def is_iterable(x):
    """Check whether variable is iterable.

    Parameters
    ----------
    x : any
        Input object.

    Returns
    -------
    bool
        `True` if `x` is iterable.
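
    Examples
    --------
    >>> is_iterable([1, 2, 3])
    True
    >>> is_iterable(2.0)
    False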
    """
    result = True

    # This is the most generic way; it also permits objects that only
    # implement the `__getitem__` interface.
    try:
        iter(x)
    except TypeError:
        result = False

    return result


def wrap_if_not_iterable(x):
    """Wrap variable in case it cannot be iterated over.

    This function provides a convenience wrapper for variables that need
    to be iterated over. If the variable is already an `iterable`, there
    is nothing to be done and it will be returned as-is. Otherwise, the
    function will 'wrap' the variable to be the single item of a list.

    The primary purpose of this function is to make it easier for users
    to interact with certain classes: essentially, one does not have to
    think any more about single inputs vs. `iterable` inputs.

    Parameters
    ----------
    x : any
        Input object.

    Returns
    -------
    list or type of x
        If `x` can be iterated over, `x` will be returned as-is. Else,
        will return `[x]`, i.e. a list containing `x`.

    Examples
    --------
    >>> wrap_if_not_iterable(1.0)
    [1.0]
    >>> wrap_if_not_iterable('Hello, World!')
    'Hello, World!'
    """
    if is_iterable(x):
        return x
    else:
        return [x]


def nesting_level(x):
    """Calculate nesting level of a list of objects.

    To convert between sparse and dense representations of topological
    features, we need to determine the nesting level of an input list.
    The nesting level is defined as the maximum number of times we can
    recurse into the object while still obtaining lists.

    Parameters
    ----------
    x : list
        Input list of objects.

    Returns
    -------
    int
        Nesting level of `x`. If `x` has no well-defined nesting level,
        for example because `x` is not a list of something, will return
        `0`.

    Notes
    -----
    This function is implemented recursively. It is therefore a bad idea
    to apply it to objects with an extremely high nesting level.

    Examples
    --------
    >>> nesting_level([1, 2, 3])
    1

    >>> nesting_level([[1, 2], [3, 4]])
    2
    """
    # This is really only supposed to work with lists. Anything fancier,
    # for example a `torch.tensor`, can already be used as a dense data
    # structure.
    if not isinstance(x, list):
        return 0

    # Empty lists have a nesting level of 1.
    if len(x) == 0:
        return 1
    else:
        return max(nesting_level(y) for y in x) + 1
--------------------------------------------------------------------------------
/torch_topological/utils/summary_statistics.py:
--------------------------------------------------------------------------------
"""Summary statistics for persistence diagrams."""

import torch


def persistent_entropy(D, **kwargs):
    """Calculate persistent entropy of a persistence diagram.

    Parameters
    ----------
    D : `torch.tensor`
        Persistence diagram, assumed to be in shape `(n, 2)`, where each
        entry corresponds to a tuple of the form :math:`(x, y)`, with
        :math:`x` denoting the creation of a topological feature and
        :math:`y` denoting its destruction.

    Returns
    -------
    torch.tensor
        Persistent entropy of `D`.
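
    Examples
    --------
    >>> import torch
    >>> D = torch.tensor([[0.0, 1.0], [0.0, 1.0]])
    >>> persistent_entropy(D)
    tensor(1.)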
    """
    persistence = torch.diff(D)
    persistence = persistence[torch.isfinite(persistence)].abs()

    P = persistence.sum()
    probabilities = persistence / P

    # Ensures that a probability of zero will just result in
    # a logarithm of zero as well. This is required whenever
    # one deals with entropy calculations.
    indices = probabilities > 0
    log_prob = torch.zeros_like(probabilities)
    log_prob[indices] = torch.log2(probabilities[indices])

    return torch.sum(-probabilities * log_prob)


def polynomial_function(D, p, q, **kwargs):
    r"""Parametrise polynomial function over persistence diagrams.

    This function follows an approach by Adcock et al. [Adcock16a]_ and
    parametrises a polynomial function over a persistence diagram.

    Parameters
    ----------
    D : `torch.tensor`
        Persistence diagram, assumed to be in shape `(n, 2)`, where each
        entry corresponds to a tuple of the form :math:`(x, y)`, with
        :math:`x` denoting the creation of a topological feature and
        :math:`y` denoting its destruction.

    p : float
        Exponent for persistence differences in the diagram.

    q : float
        Exponent for mean persistence in the diagram.

    Returns
    -------
    Sum of the form :math:`\sum L^p \cdot \mu^q`, with :math:`L`
    denoting the persistence of an individual feature, and :math:`\mu`
    denoting the average of its creation and destruction values.

    References
    ----------
    .. [Adcock16a] A. Adcock et al., "The Ring of Algebraic Functions on
       Persistence Bar Codes", *Homology, Homotopy and Applications*,
       Volume 18, Issue 1, pp. 381--402, 2016.
    """
    lengths = torch.diff(D)
    means = torch.sum(D, dim=-1, keepdim=True) / 2

    # Filter out non-finite values; the same mask works here because the
    # mean is non-finite if and only if the persistence is.
    mask = torch.isfinite(lengths)
    lengths = lengths[mask]
    means = means[mask]

    return torch.sum(torch.mul(lengths.pow(p), means.pow(q)))


def total_persistence(D, p=2, **kwargs):
    """Calculate total persistence of a persistence diagram.

    This function will calculate the total persistence of a persistence
    diagram. Infinite values will be ignored.

    Parameters
    ----------
    D : `torch.tensor`
        Persistence diagram, assumed to be in shape `(n, 2)`, where each
        entry corresponds to a tuple of the form :math:`(x, y)`, with
        :math:`x` denoting the creation of a topological feature and
        :math:`y` denoting its destruction.

    p : float
        Exponent used to weight persistence values in the calculation.

    Returns
    -------
    float
        Total persistence of `D`.
    """
    persistence = torch.diff(D)
    persistence = persistence[torch.isfinite(persistence)]

    return persistence.abs().pow(p).sum()


def p_norm(D, p=2, **kwargs):
    """Calculate :math:`p`-norm of a persistence diagram.

    This function will calculate the :math:`p`-norm of a persistence
    diagram. Infinite values will be ignored.

    Parameters
    ----------
    D : `torch.tensor`
        Persistence diagram, assumed to be in shape `(n, 2)`, where each
        entry corresponds to a tuple of the form :math:`(x, y)`, with
        :math:`x` denoting the creation of a topological feature and
        :math:`y` denoting its destruction.

    p : float
        Weight parameter for the norm calculation. It must be valid to
        raise a term to the :math:`p`th power, but the function has no
        additional checks for this.

    Returns
    -------
    float
        :math:`p`-norm of `D`.
    """
    # The exponent must be passed through to the total persistence
    # calculation; otherwise, its default value would be used.
    return torch.pow(total_persistence(D, p=p), 1.0 / p)
--------------------------------------------------------------------------------