├── .github
│   └── workflows
│       └── run_tests.yaml
├── .gitignore
├── .python-version
├── .readthedocs.yaml
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.md
├── README.md
├── ROADMAP.md
├── docs
│   ├── .gitignore
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── comparison.rst
│       ├── conf.py
│       ├── data.rst
│       ├── examples
│       │   ├── autoencoders.rst
│       │   ├── index.rst
│       │   └── summary_statistics.rst
│       ├── index.rst
│       ├── nn.rst
│       ├── nutshell.rst
│       ├── usage.rst
│       └── utils.rst
├── logos
│   ├── giotto.jpg
│   └── gudhi.png
├── poetry.lock
├── pyproject.toml
├── tests
│   ├── __init__.py
│   ├── test_alpha_complex.py
│   ├── test_cubical_complex.py
│   ├── test_layers.py
│   ├── test_multi_scale_kernel.py
│   ├── test_ot.py
│   ├── test_pytorch_topological.py
│   └── test_vietoris_rips_complex.py
├── torch_topological.svg
└── torch_topological
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── shapes.py
    │   └── utils.py
    ├── datasets
    │   ├── __init__.py
    │   ├── shapes.py
    │   └── spheres.py
    ├── examples
    │   ├── alpha_complex.py
    │   ├── autoencoders.py
    │   ├── benchmarking.py
    │   ├── classification.py
    │   ├── cubical_complex.py
    │   ├── distances.py
    │   ├── gan.py
    │   ├── image_smoothing.py
    │   ├── summary_statistics.py
    │   └── weighted_euler_characteristic_transform.py
    ├── nn
    │   ├── __init__.py
    │   ├── alpha_complex.py
    │   ├── cubical_complex.py
    │   ├── data.py
    │   ├── distances.py
    │   ├── graphs.py
    │   ├── layers.py
    │   ├── loss.py
    │   ├── multi_scale_kernel.py
    │   ├── sliced_wasserstein_distance.py
    │   ├── sliced_wasserstein_kernel.py
    │   ├── vietoris_rips_complex.py
    │   └── weighted_euler_characteristic_transform.py
    └── utils
        ├── __init__.py
        ├── filters.py
        ├── general.py
        └── summary_statistics.py
/.github/workflows/run_tests.yaml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | - push
5 | - pull_request
6 |
7 | jobs:
8 | tests:
9 | runs-on: ubuntu-latest
10 |
11 | strategy:
12 | matrix:
13 | python-version: ["3.10"]
14 |
15 | steps:
16 | - uses: actions/checkout@v3
17 | - name: Install `poetry`
18 | run: |
19 | pipx install poetry
20 | poetry --version
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v3
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | cache: 'poetry'
26 | - name: Install package
27 | run: |
28 | poetry install
29 | - name: Run linter
30 | run: |
31 | poetry run flake8 torch_topological
32 | - name: Run tests
33 | run: |
34 | poetry run pytest
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 |
3 | *.pyc
4 | *.swp
5 |
6 | /data/
7 | /dist/
8 | __pycache__/
9 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.9.9
2 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | # Required to get access to more recent Python versions.
5 | image: testing
6 |
7 | sphinx:
8 | configuration: docs/source/conf.py
9 |
10 | python:
11 | version: 3.9
12 | install:
13 | - method: pip
14 | path: .
15 | extra_requirements:
16 | - docs
17 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | `pytorch-topological` follows [semantic versioning](https://semver.org).
4 | This changelog contains all notable changes in the project.
5 |
6 | # v0.1.7
7 |
8 | ## Fixed
9 |
10 | - Fixed bug in `make_tensor` that caused the loss of device information.
11 | Device handling should now be transparent for clients.
12 |
13 | # v0.1.6
14 |
15 | ## Fixed
16 |
17 | - Fixed various documentation typos
18 | - Fixed bug in `make_tensor` creation for single batches
19 |
20 | # v0.1.5
21 |
22 | ## Added
23 |
24 | - A bunch of new test cases
25 | - Alpha complex class (`AlphaComplex`)
26 | - Dimension selector class (`SelectByDimension`)
27 | - Discussing additional packages in documentation
28 | - Linting for pull requests
29 |
30 | ## Fixed
31 |
32 | - Improved contribution guidelines
33 | - Improved documentation of summary statistics loss
34 | - Improved overall maintainability
35 | - Improved test cases
36 | - Simplified multi-scale kernel usage (distance calculations with different exponents)
37 | - Test case for cubical complexes
38 | - Usage of seed parameter for shape generation (following `numpy` guidelines)
39 |
40 | # v0.1.4
41 |
42 | ## Added
43 |
44 | - Batch handler for point cloud complexes
45 | - Sliced Wasserstein distance kernel
46 | - Support for pre-computed distances in Vietoris--Rips complex
47 |
48 | ## Fixed
49 |
50 | - Device compatibility issue (tensor being created on the wrong device)
51 | - Use of `dim` flag for alpha complexes
52 | - Various documentation issues and coding style problems
53 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of conduct
2 |
3 | Developing is a fundamentally collaborative endeavour. Whenever human
4 | beings interact, there is a potential for conflict. At the same time,
5 | there is an even greater potential to build something that is 'bigger
6 | than the sum of its parts.' In this spirit, all contributors shall be
7 | aware of the following guidelines:
8 |
9 | 1. Be tolerant of opposing views.
10 | 2. Be mindful of the way you utter your critiques; ensure that what you
11 | write is objective and does not attack a person or contain any
12 | disparaging remarks.
13 | 3. Be forgiving when interpreting the words and actions of others;
14 | always assume good intentions.
15 |
16 | We want contributors to participate in a respectful, collaborative, and
17 | overall productive space here. Any harassment will **not** be tolerated.
18 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute to `pytorch-topological`
2 |
3 | As the saying goes: 'Many hands make light work.' Thanks for being
4 | willing to contribute to `pytorch-topological`! Here are some resources
5 | to get you started:
6 |
7 | - Check out the [road map](ROADMAP.md) of the project for a high-level
8 | overview of current directions.
9 | - Check out [open issues](/issues) in case you are looking for things
10 | to tackle.
11 |
12 | **When contributing code, please be aware that your contribution will
13 | fall under the terms of the [license](https://github.com/aidos-lab/pytorch-topological/blob/main/LICENSE.md)
14 | of `pytorch-topological`.**
15 |
16 | ## Pull requests
17 |
18 | If you propose some changes, a pull request is the easiest way to
19 | integrate them. Please be mindful of the coding conventions (see below)
20 | and [write good commit messages](https://cbea.ms/git-commit/).
21 |
22 | ## Coding conventions
23 |
24 | Above all, consider that this is *open source software*. It is meant
25 | to be used and extended by many people. Be mindful of them by making
26 | your code look nice and appealing to them. We cannot build upon some
27 | module no one understands.
28 |
29 | As a way to obtain some consistency in all contributions, your code
30 | should ideally conform to [PEP 8](https://www.python.org/dev/peps/pep-0008/).
31 | You can check this using the great [`flake8`](https://flake8.pycqa.org/) tool:
32 |
33 | ```console
34 | $ flake8 script.py
35 | ```
36 |
37 | When creating a pull request, your code will be automatically checked
38 | for coding convention conformity.
39 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright (C) 2021 Bastian Rieck
2 |
3 | Redistribution and use in source and binary forms, with or without
4 | modification, are permitted provided that the following conditions are
5 | met:
6 |
7 | 1. Redistributions of source code must retain the above copyright
8 | notice, this list of conditions and the following disclaimer.
9 |
10 | 2. Redistributions in binary form must reproduce the above copyright
11 | notice, this list of conditions and the following disclaimer in the
12 | documentation and/or other materials provided with the distribution.
13 |
14 | 3. Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
19 | IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
20 | TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
24 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # `pytorch-topological`: A topological machine learning framework for `pytorch`
4 |
5 | [](https://pytorch-topological.readthedocs.io/en/latest/?badge=latest) [](https://codeclimate.com/github/aidos-lab/pytorch-topological/maintainability)    [](https://github.com/aidos-lab/pytorch-topological/actions/workflows/run_tests.yaml)
6 |
7 | `pytorch-topological` (or `torch_topological`) is a topological machine
8 | learning framework for [PyTorch](https://pytorch.org). It aims to
9 | collect *loss terms* and *neural network layers* in order to simplify
10 | building the next generation of topology-based machine learning tools.
11 |
12 | # Topological machine learning in a nutshell
13 |
14 | *Topological machine learning* refers to a new class of machine learning
15 | algorithms that are able to make use of topological features in data
16 | sets. In contrast to methods based on a purely geometrical point of
17 | view, topological features are capable of focusing on *connectivity
18 | aspects* of a data set. This provides an interesting fresh perspective
19 | that can be used to create powerful hybrid algorithms, capable of
20 | yielding more insights into data.
21 |
22 | This is an *emerging research field*, firmly rooted in computational
23 | topology and topological data analysis. If you want to learn more about
24 | how topology and geometry can work in tandem, here are a few resources
25 | to get you started:
26 |
27 | - Amézquita et al., [*The Shape of Things to Come: Topological Data Analysis and Biology,
28 | from Molecules to Organisms*](https://doi.org/10.1002/dvdy.175), Developmental Dynamics
29 | Volume 249, Issue 7, pp. 816--833, 2020.
30 |
31 | - Hensel et al., [*A Survey of Topological Machine Learning Methods*](https://www.frontiersin.org/articles/10.3389/frai.2021.681108/full),
32 | Frontiers in Artificial Intelligence, 2021.
33 |
34 | # Installation and requirements
35 |
36 | `torch_topological` requires Python 3.9. More recent versions might work
37 | but may require you to build some dependencies yourself; Python 3.9
38 | currently offers the smoothest experience.
39 | It is recommended to use the excellent [`poetry`](https://python-poetry.org) framework
40 | to install `torch_topological`:
41 |
42 | ```
43 | poetry add torch-topological
44 | ```
45 |
46 | Alternatively, use `pip` to install the package:
47 |
48 | ```
49 | pip install -U torch-topological
50 | ```
51 |
52 | **A note on older versions.** Older versions of Python are not
53 | explicitly supported, and things may break in unexpected ways.
54 | If you want to use a different version, check `pyproject.toml`
55 | and adjust the Python requirement to your preference. This may
56 | or may not work; good luck!
57 |
58 | # Usage
59 |
60 | `torch_topological` is still a work in progress. You can [browse the documentation](https://pytorch-topological.readthedocs.io)
61 | or, if code reading is more your thing, dive directly into [some example
62 | code](./torch_topological/examples).
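
For a quick first impression, the following snippet compares the
topology of two random point clouds. This is a minimal sketch assembled
from the package's own test suite; all parameter choices are purely
illustrative:

```python
from torch_topological.data import sample_from_unit_cube
from torch_topological.nn import VietorisRipsComplex, WassersteinDistance

# Sample two random point clouds from the unit cube in 3D.
X = sample_from_unit_cube(n=100, d=3)
Y = sample_from_unit_cube(n=100, d=3)

# Calculate persistence diagrams (up to dimension 1) of both point
# clouds in a single batched call.
vr = VietorisRipsComplex(dim=1)
pi_X, pi_Y = vr([X, Y])

# Compare the persistence diagrams. The result is a tensor, so it
# can serve directly as (part of) a loss term.
dist = WassersteinDistance()(pi_X, pi_Y)
print(dist)
```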
63 |
64 | Here is a list of *other* projects that are using `torch_topological`:
65 |
66 | - [SHAPR](https://github.com/marrlab/SHAPR_torch), a method for
67 | predicting the 3D cell shape of individual cells based on 2D
68 | microscopy images
69 |
70 | This list is incomplete---you can help expand it by using
71 | `torch_topological` in your own projects! :innocent:
72 |
73 | # Contributing
74 |
75 | Check out the [contribution guidelines](CONTRIBUTING.md) or the [road
76 | map](ROADMAP.md) of the project.
77 |
78 | # Acknowledgements
79 |
80 | Our software and research do not exist in a vacuum. `pytorch-topological` stands
81 | on the shoulders of proverbial giants. In particular, we want to thank the
82 | following projects for constituting the technical backbone of the
83 | project:
84 |
85 | | [`giotto-tda`](https://github.com/giotto-ai/giotto-tda) | [`gudhi`](https://github.com/GUDHI/gudhi-devel) |
86 | |---------------------------------------------------------------|-------------------------------------------------------------|
87 | | | |
88 |
89 | Furthermore, `pytorch-topological` draws inspiration from several
90 | projects that provide a glimpse into the wonderful world of topological
91 | machine learning:
92 |
93 | - [`difftda`](https://github.com/MathieuCarriere/difftda) by [Mathieu Carrière](https://github.com/MathieuCarriere)
94 |
95 | - [`Ripser`](https://github.com/Ripser/ripser) by [Ulrich Bauer](https://github.com/ubauer)
96 |
97 | - [`Teaspoon`](https://lizliz.github.io/teaspoon/) by [Elizabeth Munch](https://elizabethmunch.com/) and her team
98 |
99 | - [`TopologyLayer`](https://github.com/bruel-gabrielsson/TopologyLayer) by [Rickard Brüel Gabrielsson](https://github.com/bruel-gabrielsson)
100 |
101 | - [`topological-autoencoders`](https://github.com/BorgwardtLab/topological-autoencoders) by [Michael Moor](https://github.com/mi92), [Max Horn](https://github.com/ExpectationMax), and [Bastian Rieck](https://github.com/Pseudomanifold)
102 |
103 | - [`torchph`](https://github.com/c-hofer/torchph) by [Christoph Hofer](https://github.com/c-hofer) and [Roland Kwitt](https://github.com/rkwitt)
104 |
105 | Finally, `pytorch-topological` makes heavy use of [`POT`](https://pythonot.github.io), the Python Optimal Transport Library.
106 | We are indebted to the many contributors of all these projects.
107 |
--------------------------------------------------------------------------------
/ROADMAP.md:
--------------------------------------------------------------------------------
1 | # What's the future of `pytorch-topological`?
2 |
3 | ---
4 |
5 | **Vision** `pytorch-topological` aims to be the first stop for building
6 | powerful applications using *topological machine learning* algorithms,
7 | i.e. algorithms that are capable of jointly leveraging geometrical and
8 | topological features in a data set.
9 |
10 | ---
11 |
12 | To make this vision a reality, we first and foremost need to rely on
13 | exceptional documentation. It is not enough to write outstanding code;
14 | we have to demonstrate the power of topological algorithms to our users
15 | by writing well-documented code and contributing examples.
16 |
17 | Here are short-term and long-term goals, roughly categorised:
18 |
19 | ## API
20 |
21 | - [ ] Provide a consistent way of handling batches or tensor inputs. Most
22 | of the modules rely on *sparse* inputs as lists; see the sketch after this list.
23 | - [ ] Support different backends for calculating persistent homology. At
24 | present, we use [`GUDHI`](https://github.com/GUDHI/gudhi-devel/) for
25 | cubical complexes and [`giotto-ph`](https://github.com/giotto-ai/giotto-ph)
26 | for Vietoris--Rips complexes. It would be nice to be able to swap
27 | implementations easily.
28 | - [ ] Check out the use of sparse tensors; could be a potential way
29 | forward for representing persistence information. The drawback is that
30 | we cannot fill everything with zeroes; there has to be a way to
31 | indicate 'unset' information.
32 |
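As a sketch of the status quo for the first point above (mirroring the
current test suite; the exact shapes are illustrative): batches are
plain Python lists of point clouds, and `make_tensor` converts the
resulting persistence information into a dense representation.

```python
import torch

from torch_topological.nn import VietorisRipsComplex
from torch_topological.nn.data import make_tensor

vr = VietorisRipsComplex(dim=1)

# A 'batch' is currently a plain list of (possibly ragged) point clouds.
batch = [torch.rand(64, 3), torch.rand(48, 3)]

# The result is a list of per-sample persistence information...
pers_info = vr(batch)

# ...which can be converted into a dense tensor representation.
pers_info_dense = make_tensor(pers_info)
```
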
33 | ## Complexes
34 |
35 | - [x] Add (rudimentary) support for alpha complexes: a basic
36 | implementation is already present.
37 |
38 | ## Distances and kernels
39 |
40 | At present, the module supports Wasserstein distance calculations and
41 | bottleneck distance calculations between persistence diagrams (see the usage sketch at the end of this section). In
42 | addition to this, several 'pseudo-distances' based on summary statistics
43 | have been implemented. There are numerous kernels out there that could
44 | be included:
45 |
46 | - [x] The multi-scale kernel by [Reininghaus et al.](https://openaccess.thecvf.com/content_cvpr_2015/papers/Reininghaus_A_Stable_Multi-Scale_2015_CVPR_paper.pdf)
47 | - [x] The sliced Wasserstein distance and kernel by [Carrière et al.](https://arxiv.org/abs/1706.03358).
48 |
49 | This list is **incomplete**.
50 |
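For reference, the existing Wasserstein distance calculations can be
used as follows; this sketch mirrors `tests/test_ot.py`:

```python
import torch

from torch_topological.nn import PersistenceInformation
from torch_topological.nn import WassersteinDistance


def wrap(diagram):
    """Wrap a raw persistence diagram in the expected data structure."""
    diagram = torch.as_tensor(diagram, dtype=torch.float)
    return [PersistenceInformation([], diagram)]


X = wrap([(3, 5), (1, 2), (3, 4)])
Y = wrap([(0, 0), (3, 4), (5, 3), (7, 4)])

dist = WassersteinDistance()(X, Y)
```
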
51 | ## Layers
52 |
53 | There are quite a few topology-based layers that have been proposed by
54 | members of the community. We should include all of them to make them
55 | available with a single, consistent interface.
56 |
57 | - [ ] Include [`PersLay`](https://github.com/MathieuCarriere/perslay).
58 | This requires a conversion from TensorFlow code.
59 | - [ ] Include [`PLLay`](https://github.com/jisuk1/pllay).
60 | This requires a conversion from TensorFlow code.
61 | - [ ] Include [`SLayer`](https://github.com/c-hofer/torchph). This is
62 | still an ongoing effort.
63 | - [x] Include [`TopologyLayer`](https://github.com/bruel-gabrielsson/TopologyLayer).
64 |
65 | ## Loss terms
66 |
67 | - [x] Include *signature loss* from [`topological-autoencoders`](https://github.com/BorgwardtLab/topological-autoencoders)
68 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/comparison.rst:
--------------------------------------------------------------------------------
1 | Comparison with other packages
2 | ==============================
3 |
4 | If you are already familiar with certain packages for calculating
5 | topological features, you might be interested in understanding in
6 | what aspects ``torch_topological`` differs from them. This is not
7 | meant to be a comprehensive comparison; we are aiming for a brief
8 | overview to simplify getting acquainted with the project.
9 |
10 | ``giotto-tda``
11 | --------------
12 |
13 | `giotto-tda <https://github.com/giotto-ai/giotto-tda>`_ is a flagship
14 | package, developed by numerous members of L2F. Its primary goal is to
15 | provide an interface consistent with ``scikit-learn``, thus facilitating
16 | an integration of topological features into a data science workflow.
17 |
18 | By contrast, ``torch_topological`` is meant to simplify the development
19 | of hybrid algorithms that can be easily integrated into deep learning
20 | architectures. ``giotto-tda`` is developed by a large team with a much
21 | more professional development agenda, whereas ``torch_topological`` is
22 | geared more towards researchers that want to prototype the integration
23 | of topological features.
24 |
25 | ``Teaspoon``
26 | ------------
27 |
28 | `Teaspoon <https://lizliz.github.io/teaspoon/>`_ is a library that
29 | targets topological signal processing applications, such as the analysis
30 | of time-varying systems or complex networks. ``Teaspoon`` integrates
31 | very nicely with ``scikit-learn`` and targets a different set of
32 | applications than ``torch_topological``.
33 |
34 | ``TopologyLayer``
35 | -----------------
36 |
37 | `TopologyLayer <https://github.com/bruel-gabrielsson/TopologyLayer>`_ is
38 | a library developed by Rickard Brüel Gabrielsson and others,
39 | accompanying their AISTATS publication `A Topology Layer for Machine Learning <https://arxiv.org/abs/1905.12200>`_.
40 |
41 | ``torch_topological`` subsumes the functionality of ``TopologyLayer``,
42 | albeit under different names:
43 |
44 | - :py:class:`torch_topological.nn.VietorisRipsComplex` or
45 | :py:class:`torch_topological.nn.CubicalComplex` can be used to extract
46 | topological features from point clouds and images, respectively; a short sketch for images follows at the end of this section.
47 |
48 | - The ``BarcodePolyFeature`` and ``SumBarcodeLengths`` classes are
49 | incorporated as summary statistics loss functions instead. See the
50 | following example for more details: :doc:`examples/summary_statistics`
51 |
52 | - The ``PartialSumBarcodeLengths`` function is not implemented, mostly
53 | because a similar effect can be achieved by pruning the persistence
54 | diagram manually. This functionality might be added later on.
55 |
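For instance, extracting topological features from a single image
requires only a few lines; the following sketch mirrors the package's
test suite:

.. code-block:: python

    import torch

    from torch_topological.nn import CubicalComplex

    x = torch.rand(32, 32)  # a single 'image'

    cc = CubicalComplex()
    pers_info = cc(x)

    # One `PersistenceInformation` object per dimension.
    assert pers_info[0].dimension == 0
    assert pers_info[1].dimension == 1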
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 | # -- Project information -----------------------------------------------------
18 |
19 | project = 'torch_topological'
20 | copyright = '2022, Bastian Rieck'
21 | author = 'Bastian Rieck'
22 |
23 | # -- General configuration ---------------------------------------------------
24 |
25 | # Add any Sphinx extension module names here, as strings. They can be
26 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
27 | # ones.
28 | extensions = [
29 | 'sphinx.ext.autodoc',
30 | 'sphinx.ext.napoleon',
31 | 'sphinx.ext.linkcode',
32 | ]
33 |
34 | # Ensure that member functions are documented. These are sane defaults.
35 | autodoc_default_options = {
36 | 'members': True,
37 | 'member-order': 'bysource',
38 | 'special-members': '__init__',
39 | 'undoc-members': True,
40 | 'exclude-members': '__weakref__'
41 | }
42 |
43 | # Tries to assign some semantic meaning to arguments provided with
44 | # single backtics, such as `x`. This way, we can ignore `func` and
45 | # `class` targets etc. (They still work, though!)
46 | default_role = 'obj'
47 |
48 | # Add any paths that contain templates here, relative to this directory.
49 | templates_path = ['_templates']
50 |
51 | # List of patterns, relative to source directory, that match files and
52 | # directories to ignore when looking for source files.
53 | # This pattern also affects html_static_path and html_extra_path.
54 | exclude_patterns = ['build', 'Thumbs.db', '.DS_Store']
55 |
56 | # -- Options for HTML output -------------------------------------------------
57 |
58 | # The theme to use for HTML and HTML Help pages. See the documentation for
59 | # a list of builtin themes.
60 | #
61 | html_theme = 'sphinx_rtd_theme'
62 | html_logo = '../../torch_topological.svg'
63 | html_theme_options = {
64 | 'logo_only': False,
65 | 'display_version': True,
66 | }
67 |
68 | # Add any paths that contain custom static files (such as style sheets) here,
69 | # relative to this directory. They are copied after the builtin static files,
70 | # so a file named "default.css" will overwrite the builtin "default.css".
71 | html_static_path = ['_static']
72 |
73 | # Ensures that modules are sorted correctly. Since they all pertain to
74 | # the same package, the prefix itself can be ignored.
75 | modindex_common_prefix = ['torch_topological.']
76 |
77 | # Specifies how to actually find the sources of the modules. Ensures
78 | # that people can jump to files in the repository directly.
79 | def linkcode_resolve(domain, info):
80 | # Let's frown on global imports and do everything locally as much as
81 | # we can.
82 | import sys
83 | import torch_topological
84 |
85 | if domain != 'py':
86 | return None
87 | if not info['module']:
88 | return None
89 |
90 | # Attempt to identify the source file belonging to an `info` object.
91 | # This code is adapted from the Sphinx configuration of `numpy`; see
92 | # https://github.com/numpy/numpy/blob/main/doc/source/conf.py.
93 | def find_source_file(module):
94 | obj = sys.modules[module]
95 |
96 | for part in info['fullname'].split('.'):
97 | obj = getattr(obj, part)
98 |
99 | import inspect
100 | import os
101 |
102 | fn = inspect.getsourcefile(obj)
103 | fn = os.path.relpath(
104 | fn,
105 | start=os.path.dirname(torch_topological.__file__)
106 | )
107 |
108 | source, lineno = inspect.getsourcelines(obj)
109 | return fn, lineno, lineno + len(source) - 1
110 |
111 | try:
112 | module = info['module']
113 | source = find_source_file(module)
114 | except Exception:
115 | source = None
116 |
117 | root = f'https://github.com/aidos-lab/pytorch-topological/tree/main/{project}/'
118 |
119 | if source is not None:
120 | fn, start, end = source
121 | return root + f'{fn}#L{start}-L{end}'
122 | else:
123 | return None
124 |
--------------------------------------------------------------------------------
/docs/source/data.rst:
--------------------------------------------------------------------------------
1 | torch_topological.data
2 | ======================
3 |
4 | .. automodule:: torch_topological.data
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/source/examples/autoencoders.rst:
--------------------------------------------------------------------------------
1 | Autoencoders with geometrical--topological losses
2 | =================================================
3 |
4 | In this example, we will create a simple autoencoder based on the
5 | *Topological Signature Loss* introduced by Moor et al. [Moor20a]_.
6 |
7 | A simple autoencoder
8 | --------------------
9 |
10 | We first define a simple linear autoencoder. The representations
11 | obtained from this autoencoder are very similar to those obtained via
12 | PCA.
13 |
14 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py
15 | :language: python
16 | :pyobject: LinearAutoencoder
17 |
18 | Of particular interest in the code are the `encode` and `decode`
19 | functions. With ``encode``, we *embed* data in a latent space, whereas
20 | with ``decode``, we reconstruct it to its 'original' space.
21 |
22 | This reconstruction is of course never perfect. We therefore measure its
23 | quality using a reconstruction loss. Let's zoom into the specific
24 | function for this:
25 |
26 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py
27 | :language: python
28 | :pyobject: LinearAutoencoder.forward
29 |
30 | The important take-away here is that ``forward`` should return at
31 | least one *loss value*. We will make use of this later on!
32 |
33 | A topological wrapper for autoencoder models
34 | --------------------------------------------
35 |
36 | Our previous model uses ``encode`` to provide us with a lower-dimensional
37 | representation, the so-called *latent representation*. We can use this
38 | representation in order to calculate a topology-based loss! To this end,
39 | let's write a new ``forward`` function that uses an existing model ``model``
40 | for the latent space generation:
41 |
42 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py
43 | :language: python
44 | :pyobject: TopologicalAutoencoder.forward
45 |
46 | In the code above, the important things are:
47 |
48 | 1. The use of a Vietoris--Rips complex ``self.vr`` to obtain persistence
49 | information about the input space ``x`` and the latent space ``z``,
50 | respectively. We call the resulting persistence information ``pi_x``
51 | and ``pi_z``.
52 |
53 | 2. The call to a topology-based loss function ``self.loss()``, which takes
54 | two spaces ``x`` and ``y``, as well as their corresponding persistence
55 | information, to calculate the *signature loss* from [Moor20a]_.
56 |
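In code, the above might look roughly as follows. Note that this is
only a sketch: the weighting factor ``self.lam`` is a hypothetical
name, and the exact call signature of ``self.loss`` is an assumption
here, so please refer to the full example code for the precise form.

.. code-block:: python

    def forward(self, x):
        # Obtain the latent representation from the wrapped model.
        z = self.model.encode(x)

        # Persistence information of the input space and the latent
        # space, respectively.
        pi_x = self.vr(x)
        pi_z = self.vr(z)

        # Loss of the wrapped model, e.g. a reconstruction loss, plus
        # the topology-based signature loss, weighted by the
        # (hypothetical) factor `self.lam`.
        geom_loss = self.model(x)
        topo_loss = self.loss([x, pi_x], [z, pi_z])

        return geom_loss + self.lam * topo_loss
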
57 | Putting this all together, we have the following 'wrapper class' that
58 | makes an existing model topology-aware:
59 |
60 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py
61 | :language: python
62 | :pyobject: TopologicalAutoencoder
63 |
64 | See [Moor20a]_ for more models to extend---being topology-aware can be
65 | crucial for many applications.
66 |
67 | Source code
68 | -----------
69 |
70 | Here's the full source code of this example.
71 |
72 | .. literalinclude:: ../../../torch_topological/examples/autoencoders.py
73 | :language: python
74 |
--------------------------------------------------------------------------------
/docs/source/examples/index.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
3 |
4 | ``torch_topological`` comes with several documented examples that
5 | showcase different use cases. Please also see the
6 | `examples `_
7 | directory on GitHub for more details and full code.
8 |
9 | .. toctree::
10 | :maxdepth: 1
11 | :caption: Contents:
12 |
13 | autoencoders
14 | summary_statistics
15 |
--------------------------------------------------------------------------------
/docs/source/examples/summary_statistics.rst:
--------------------------------------------------------------------------------
1 | Point cloud optimisation with summary statistics
2 | ================================================
3 |
4 | One interesting use case of ``torch_topological`` involves changing the
5 | shape of a point cloud using topological summary statistics. Such
6 | summary statistics can *either* be used as simple loss functions,
7 | constituting a computationally cheap way of assessing the topological
8 | similarity of a given point cloud to a target point cloud, *or*
9 | serve to highlight certain topological properties of a single point
10 | cloud.
11 |
12 | In this example, we will consider *both* operations.
13 |
14 | Ingredients
15 | -----------
16 |
17 | Our main ingredient is the :py:class:`torch_topological.nn.SummaryStatisticLoss`
18 | class. This class bundles different summary statistics on persistence
19 | diagrams and permits their calculation and comparison.
20 |
21 | This class can operate in two modes:
22 |
23 | 1. Calculating the loss for a single input data set.
24 | 2. Calculating the loss difference for two input data sets.
25 |
26 | Our example will showcase both of these modes!
27 |
28 | Optimising all the point clouds
29 | -------------------------------
30 |
31 | Here's the bulk of the code required to optimise a point cloud. We will
32 | walk through the most important parts!
33 |
34 | .. literalinclude:: ../../../torch_topological/examples/summary_statistics.py
35 | :language: python
36 | :pyobject: main
37 |
38 | Besides creating some test data sets---check out :py:mod:`torch_topological.data`
39 | for more routines---the most important thing is to make sure that ``X``,
40 | our point cloud, is a trainable parameter.
41 |
42 | With that being out of the way, we can set up the summary statistic loss
43 | and start training. The main loop of the training might be familiar to
44 | those of you who already have some experience with ``pytorch``: it
45 | merely evaluates the loss and optimises it, following a general
46 | structure:
47 |
48 | .. code-block:: python
49 |
50 | # Set up your favourite optimiser
51 | opt = optim.SGD(...)
52 |
53 | for i in range(100):
54 |
55 | # Do some calculations and obtain a loss term. In our specific
56 | # example, we have to get persistence information from data and
57 | # evaluate the loss based on that.
58 | loss = ...
59 |
60 | # This is what you will see in many such examples: we set all
61 | # gradients to zero and do a backwards pass.
62 | opt.zero_grad()
63 | loss.backward()
64 | opt.step()
65 |
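To make the elided ``loss = ...`` line more concrete, the topological
part could look roughly like this. Note that this is only a sketch:
the constructor arguments of
:py:class:`torch_topological.nn.SummaryStatisticLoss`, in particular
the name of the summary statistic, are assumptions here, so please
refer to the class documentation.

.. code-block:: python

    from torch_topological.nn import SummaryStatisticLoss
    from torch_topological.nn import VietorisRipsComplex

    vr = VietorisRipsComplex(dim=1)
    loss_fn = SummaryStatisticLoss('total_persistence')

    # Mode 1: evaluate the summary statistic of a single point
    # cloud `X`.
    loss = loss_fn(vr(X))

    # Mode 2: evaluate the difference of summary statistics between
    # a source point cloud `X` and a target point cloud `Y`.
    loss = loss_fn(vr(X), vr(Y))
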
66 | The rest of this example just involves some nice plotting.
67 |
68 | Source code
69 | -----------
70 |
71 | Here's the full source code of this example.
72 |
73 | .. literalinclude:: ../../../torch_topological/examples/summary_statistics.py
74 | :language: python
75 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | ``torch_topological`` -- Topological Machine Learning with ``pytorch``
2 | =======================================================================
3 |
4 | ``pytorch-topological``, also known as ``torch_topological``, brings the
5 | power of topological methods to **your** machine learning project.
6 | ``torch_topological`` is specifically geared towards working well with
7 | other ``PyTorch`` projects, so if you are already familiar with this
8 | framework you should feel right at home.
9 |
10 | .. toctree::
11 | :maxdepth: 2
12 | :caption: Getting Started
13 |
14 | usage
15 | nutshell
16 | comparison
17 |
18 | examples/index
19 |
20 | .. toctree::
21 | :maxdepth: 2
22 | :caption: Modules
23 |
24 | data
25 | nn
26 | utils
27 |
28 | Indices and tables
29 | ==================
30 |
31 | * :ref:`genindex`
32 | * :ref:`modindex`
33 | * :ref:`search`
34 |
--------------------------------------------------------------------------------
/docs/source/nn.rst:
--------------------------------------------------------------------------------
1 | torch_topological.nn
2 | ====================
3 |
4 | .. automodule:: torch_topological.nn
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/source/nutshell.rst:
--------------------------------------------------------------------------------
1 | Topological machine learning in a nutshell
2 | ==========================================
3 |
4 | If you are reading this, you are probably wondering about what
5 | topological machine learning can do for you and your projects.
6 | The purpose of this page is to provide a pithy and coarse, but
7 | hopefully *useful*, introduction to some concepts.
8 |
9 | .. contents::
10 | :local:
11 | :depth: 2
12 |
13 | High-level View
14 | ---------------
15 |
16 | If you are an expert in machine learning, you could summarise
17 | topological machine learning as a novel set of inductive biases
18 | for models, making it possible to leverage properties such as
19 | the connectivity of a data set. To some extent, such properties are
20 | already covered by existing methods, but topology can be very
21 | useful as an additional 'lens' through which to view data. In
22 | graph learning tasks, for instance, being able to capture and
23 | use topological features---such as *cycles*---is *crucial* in
24 | order to improve predictive performance.
25 |
26 | Additional Resources
27 | --------------------
28 |
29 | Here are some additional resources that might be of interest:
30 |
31 | - Amézquita et al., "The Shape of Things to Come: Topological Data Analysis and Biology,
32 | from Molecules to Organisms", Developmental Dynamics
33 | Volume 249, Issue 7, pp. 816--833, 2020. `doi:10.1002/dvdy.175 <https://doi.org/10.1002/dvdy.175>`_
34 |
35 | - Hensel et al., "A Survey of Topological Machine Learning Methods",
36 | Frontiers in Artificial Intelligence, 2021. `doi:10.3389/frai.2021.681108 <https://doi.org/10.3389/frai.2021.681108>`_
37 |
--------------------------------------------------------------------------------
/docs/source/usage.rst:
--------------------------------------------------------------------------------
1 | Installing and using ``torch_topological``
2 | ==========================================
3 |
4 | Requirements
5 | ------------
6 |
7 | ``torch_topological`` requires at least Python 3.9. Normally, version
8 | resolution should work automatically. The precise mechanism for this
9 | depends on your installation method (see below).
10 |
11 | Installation via ``pip``
12 | ------------------------
13 |
14 | We recommend installing ``torch_topological`` using ``pip``. This way,
15 | you will always get a release version with a known set of features. It
16 | is *recommended* to use a virtual environment manager such as `poetry <https://python-poetry.org>`_
17 | for handling the dependencies of your project.
18 |
19 | .. code-block:: console
20 |
21 | $ pip install torch_topological
22 |
23 | Installation from source
24 | ------------------------
25 |
26 | Installing the package from source requires a virtual environment
27 | manager capable of parsing ``pyproject.toml`` files. With `poetry <https://python-poetry.org>`_,
28 | for instance, the following steps should be sufficient:
29 |
30 | .. code-block:: console
31 |
32 | $ git clone git@github.com:aidos-lab/pytorch-topological.git
33 | $ cd pytorch-topological
34 | $ poetry install
35 |
--------------------------------------------------------------------------------
/docs/source/utils.rst:
--------------------------------------------------------------------------------
1 | torch_topological.utils
2 | =======================
3 |
4 | .. automodule:: torch_topological.utils
5 | :members:
6 |
--------------------------------------------------------------------------------
/logos/giotto.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aidos-lab/pytorch-topological/9b22d9ebeda68cab88a9267cea53a76a3f9f0c99/logos/giotto.jpg
--------------------------------------------------------------------------------
/logos/gudhi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aidos-lab/pytorch-topological/9b22d9ebeda68cab88a9267cea53a76a3f9f0c99/logos/gudhi.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "torch_topological"
3 | version = "0.1.7"
4 | description = "A framework for topological machine learning based on `pytorch`."
5 | authors = [
6 | {name = "Bastian Rieck", email = "bastian@rieck.me"},
7 | ]
8 | license = "BSD-3-Clause"
9 | readme = "README.md"
10 | include = ["README.md"]
11 |
12 | [tool.poetry.dependencies]
13 | python = ">=3.9"
14 | matplotlib = "^3.5.0"
15 | giotto-ph = "^0.2.0"
16 | torch = ">=1.12.0"
17 | gudhi = "^3.4.1"
18 | POT = "^0.9.0"
19 |
20 | [tool.poetry.dev-dependencies]
21 | pytest = ">=5.2"
22 | Sphinx = ">=4.3.1"
23 | sphinx-rtd-theme = "^1.0.0"
24 | tqdm = "^4.62.3"
25 | torchvision = ">=0.11.2"
26 | flake8 = "^4.0.1"
27 |
28 | [tool.black]
29 | line-length = 79
30 |
31 | [build-system]
32 | requires = ["poetry-core>=1.0.0"]
33 | build-backend = "poetry.core.masonry.api"
34 |
35 | [project.urls]
36 | homepage = "https://github.com/aidos-lab/pytorch-topological"
37 | documentation = "https://pytorch-topological.readthedocs.io/"
38 | issues = "https://github.com/aidos-lab/pytorch-topological/issues"
39 | changelog = "https://github.com/aidos-lab/pytorch-topological/blob/main/CHANGELOG.md"
40 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aidos-lab/pytorch-topological/9b22d9ebeda68cab88a9267cea53a76a3f9f0c99/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_alpha_complex.py:
--------------------------------------------------------------------------------
1 | from torch_topological.datasets import SphereVsTorus
2 |
3 | from torch_topological.nn.data import make_tensor
4 |
5 | from torch_topological.nn import AlphaComplex
6 |
7 | from torch.utils.data import DataLoader
8 |
9 | batch_size = 64
10 |
11 |
12 | class TestAlphaComplexBatchHandling:
13 | data_set = SphereVsTorus(n_point_clouds=3 * batch_size)
14 | loader = DataLoader(
15 | data_set,
16 | batch_size=batch_size,
17 | shuffle=True,
18 | drop_last=False,
19 | )
20 |
21 | ac = AlphaComplex()
22 |
23 | def test_processing(self):
24 | for (x, y) in self.loader:
25 | pers_info = self.ac(x)
26 |
27 | assert pers_info is not None
28 | assert len(pers_info) == batch_size
29 |
30 | pers_info_dense = make_tensor(pers_info)
31 |
32 | assert pers_info_dense is not None
33 |
--------------------------------------------------------------------------------
/tests/test_cubical_complex.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import MNIST
2 |
3 | from torchvision.transforms import Compose
4 | from torchvision.transforms import Normalize
5 | from torchvision.transforms import ToTensor
6 |
7 | from torch_topological.nn import CubicalComplex
8 |
9 | from torch.utils.data import Dataset
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data import RandomSampler
12 |
13 | from torch_topological.nn.data import batch_iter
14 | from torch_topological.nn.data import make_tensor
15 | from torch_topological.nn.data import PersistenceInformation
16 |
17 | import numpy as np
18 |
19 | import torch
20 |
21 |
22 | class RandomDataSet(Dataset):
23 | def __init__(self, n_samples, dim, side_length, n_channels):
24 | self.dim = dim
25 | self.side_length = side_length
26 | self.n_samples = n_samples
27 | self.n_channels = n_channels
28 |
29 | self.data = np.random.default_rng().normal(
30 | size=(n_samples, n_channels, *([side_length] * dim))
31 | )
32 |
33 | def __getitem__(self, index):
34 | return self.data[index]
35 |
36 | def __len__(self):
37 | return len(self.data)
38 |
39 |
40 | class TestCubicalComplex:
41 | batch_size = 32
42 |
43 | def test_single_image(self):
44 | x = np.random.default_rng().normal(size=(32, 32))
45 | x = torch.as_tensor(x)
46 | cc = CubicalComplex()
47 |
48 | pers_info = cc(x)
49 |
50 | assert pers_info is not None
51 | assert len(pers_info) == 2
52 | assert pers_info[0].dimension == 0
53 | assert pers_info[1].dimension == 1
54 |
55 | def test_image_with_channels(self):
56 | x = np.random.default_rng().normal(size=(3, 32, 32))
57 | x = torch.as_tensor(x)
58 | cc = CubicalComplex()
59 |
60 | pers_info = cc(x)
61 |
62 | assert pers_info is not None
63 | assert len(pers_info) == 3
64 | assert len(pers_info[0]) == 2
65 | assert pers_info[0][0].dimension == 0
66 | assert pers_info[1][1].dimension == 1
67 |
68 | def test_image_with_channels_and_batch(self):
69 | x = np.random.default_rng().normal(size=(self.batch_size, 3, 32, 32))
70 | x = torch.as_tensor(x)
71 | cc = CubicalComplex()
72 |
73 | pers_info = cc(x)
74 |
75 | assert pers_info is not None
76 | assert len(pers_info) == self.batch_size
77 | assert len(pers_info[0][0]) == 2
78 | assert pers_info[0][0][0].dimension == 0
79 | assert pers_info[1][1][1].dimension == 1
80 |
81 | def test_2d(self):
82 | for n_channels in [1, 3]:
83 | for squeeze in [False, True]:
84 | data_set = RandomDataSet(128, 2, 8, n_channels)
85 | loader = DataLoader(
86 | data_set,
87 | self.batch_size,
88 | shuffle=True,
89 | drop_last=False
90 | )
91 |
92 | if squeeze:
93 | data_set.data = data_set.data.squeeze()
94 |
95 | cc = CubicalComplex(dim=2)
96 |
97 | for batch in loader:
98 | pers_info = cc(batch)
99 |
100 | assert pers_info is not None
101 |
102 | def test_3d(self):
103 | data_set = RandomDataSet(128, 3, 8, 1)
104 | loader = DataLoader(
105 | data_set,
106 | self.batch_size,
107 | shuffle=True,
108 | drop_last=False
109 | )
110 |
111 | cc = CubicalComplex(dim=3)
112 |
113 | for batch in loader:
114 | pers_info = cc(batch)
115 |
116 | assert pers_info is not None
117 | assert len(pers_info) == self.batch_size
118 |
119 |
120 | class TestCubicalComplexBatchHandling:
121 | batch_size = 64
122 |
123 | data_set = MNIST(
124 | './data/MNIST',
125 | download=True,
126 | train=False,
127 | transform=Compose(
128 | [
129 | ToTensor(),
130 | Normalize([0.5], [0.5])
131 | ]
132 | ),
133 | )
134 |
135 | loader = DataLoader(
136 | data_set,
137 | batch_size=batch_size,
138 | shuffle=False,
139 | drop_last=True,
140 | sampler=RandomSampler(
141 | data_set,
142 | replacement=True,
143 | num_samples=100
144 | )
145 | )
146 |
147 | cc = CubicalComplex()
148 |
149 | def test_processing(self):
150 | for (x, y) in self.loader:
151 | pers_info = self.cc(x)
152 |
153 | assert pers_info is not None
154 | assert len(pers_info) == self.batch_size
155 |
156 | pers_info_dense = make_tensor(pers_info)
157 |
158 | assert pers_info_dense is not None
159 |
160 | def test_batch_iter(self):
161 | for (x, y) in self.loader:
162 | pers_info = self.cc(x)
163 |
164 | assert pers_info is not None
165 | assert len(pers_info) == self.batch_size
166 |
167 | assert sum(1 for x in batch_iter(pers_info)) == self.batch_size
168 |
169 | for x in batch_iter(pers_info):
170 | for y in x:
171 | assert isinstance(y, PersistenceInformation)
172 |
173 | for x in batch_iter(pers_info, dim=0):
174 |
175 | # Make sure that we have something to iterate over.
176 | assert sum(1 for y in x) != 0
177 |
178 | for y in x:
179 | assert isinstance(y, PersistenceInformation)
180 | assert y.dimension == 0
181 |
--------------------------------------------------------------------------------
/tests/test_layers.py:
--------------------------------------------------------------------------------
1 | from torch_topological.datasets import SphereVsTorus
2 |
3 | from torch_topological.nn.data import make_tensor
4 |
5 | from torch_topological.nn import VietorisRipsComplex
6 | from torch_topological.nn.layers import StructureElementLayer
7 |
8 | from torch.utils.data import DataLoader
9 |
10 | batch_size = 32
11 |
12 |
13 | class TestStructureElementLayer:
14 | data_set = SphereVsTorus(n_point_clouds=2 * batch_size)
15 | loader = DataLoader(
16 | data_set,
17 | batch_size=batch_size,
18 | shuffle=True,
19 | drop_last=False,
20 | )
21 |
22 | vr = VietorisRipsComplex(dim=1)
23 |
24 | layer = StructureElementLayer(10)
25 |
26 | def test_processing(self):
27 | for (x, y) in self.loader:
28 | pers_info = make_tensor(self.vr(x))
29 | output = self.layer(pers_info)
30 |
31 | assert pers_info is not None
32 | assert output is not None
33 |
--------------------------------------------------------------------------------
/tests/test_multi_scale_kernel.py:
--------------------------------------------------------------------------------
1 | from torch_topological.data import sample_from_unit_cube
2 |
3 | from torch_topological.nn import MultiScaleKernel
4 | from torch_topological.nn import VietorisRipsComplex
5 |
6 |
7 | class TestMultiScaleKernel:
8 | vr = VietorisRipsComplex(dim=1)
9 | kernel = MultiScaleKernel(1.0)
10 | X = sample_from_unit_cube(100)
11 | Y = sample_from_unit_cube(100)
12 |
13 | def test_pseudo_metric(self):
14 | pers_info_X, pers_info_Y = self.vr([self.X, self.Y])
15 | k_XX = self.kernel(pers_info_X, pers_info_X)
16 | k_YY = self.kernel(pers_info_Y, pers_info_Y)
17 | k_XY = self.kernel(pers_info_X, pers_info_Y)
18 |
19 | assert k_XY > 0
20 | assert k_XX > 0
21 | assert k_YY > 0
22 |
23 | assert k_XX + k_YY - 2 * k_XY >= 0
24 |
--------------------------------------------------------------------------------
/tests/test_ot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 | from torch_topological.nn import PersistenceInformation
6 | from torch_topological.nn import WassersteinDistance
7 |
8 |
9 | def wrap(diagram):
10 | diagram = torch.as_tensor(diagram, dtype=torch.float)
11 | return [
12 | PersistenceInformation([], diagram)
13 | ]
14 |
15 |
16 | class TestWassersteinDistance:
17 |
18 | def test_simple(self):
19 | X = [
20 | (3, 5),
21 | (1, 2),
22 | (3, 4)
23 | ]
24 |
25 | Y = [
26 | (0, 0),
27 | (3, 4),
28 | (5, 3),
29 | (7, 4)
30 | ]
31 |
32 | X = wrap(X)
33 | Y = wrap(Y)
34 |
35 | dist = WassersteinDistance()(X, Y)
36 | assert dist > 0.0
37 |
38 | def test_random(self):
39 | n_points = 10
40 | n_instances = 10
41 |
42 | for i in range(n_instances):
43 | X = np.random.default_rng().normal(size=(n_points, 2))
44 | Y = np.random.default_rng().normal(size=(n_points, 2))
45 |
46 | X = wrap(X)
47 | Y = wrap(Y)
48 |
49 | dist = WassersteinDistance()(X, Y)
50 | assert dist > 0.0
51 |
52 | def test_almost_zero(self):
53 | n_points = 100
54 |
55 | X = np.random.default_rng().uniform(-1e-16, 1e-11, size=(n_points, 2))
56 | Y = np.random.default_rng().uniform(-1e-16, 1e-11, size=(n_points, 2))
57 |
58 | X = wrap(X)
59 | Y = wrap(Y)
60 |
61 | dist = WassersteinDistance()(X, Y)
62 | assert dist > 0.0
63 |
--------------------------------------------------------------------------------
/tests/test_pytorch_topological.py:
--------------------------------------------------------------------------------
1 | from torch_topological import __version__
2 |
3 |
4 | def test_version():
5 | assert __version__ == '0.1.0'
6 |
--------------------------------------------------------------------------------
/tests/test_vietoris_rips_complex.py:
--------------------------------------------------------------------------------
1 | from torch_topological.datasets import SphereVsTorus
2 |
3 | from torch_topological.nn.data import batch_iter
4 | from torch_topological.nn.data import make_tensor
5 | from torch_topological.nn.data import PersistenceInformation
6 |
7 | from torch_topological.nn import VietorisRipsComplex
8 | from torch_topological.nn import WassersteinDistance
9 |
10 | from torch import cdist
11 | from torch import isinf
12 | from torch.utils.data import DataLoader
13 |
14 | import torch
15 | import pytest
16 |
17 | import numpy as np
18 |
19 | batch_size = 64
20 |
21 |
22 | class TestVietorisRipsComplex:
23 | data_set = SphereVsTorus(n_point_clouds=3 * batch_size)
24 | loader = DataLoader(
25 | data_set,
26 | batch_size=batch_size,
27 | shuffle=True,
28 | drop_last=False,
29 | )
30 |
31 | vr = VietorisRipsComplex(dim=1, p=1)
32 |
33 | def test_simple(self):
34 | for (x, y) in self.loader:
35 | pers_info = self.vr(x)
36 |
37 | assert pers_info is not None
38 | assert len(pers_info) == batch_size
39 |
40 | def test_predefined_distances(self):
41 | for (x, y) in self.loader:
42 |
43 | distances = cdist(x, x, p=1)
44 | pers_info1 = self.vr(distances, treat_as_distances=True)
45 | pers_info2 = self.vr(x, treat_as_distances=False)
46 |
47 | assert pers_info1 is not None
48 | assert pers_info2 is not None
49 | assert len(pers_info1) == batch_size
50 | assert len(pers_info1) == len(pers_info2)
51 |
52 | # Check that we are getting the same persistence diagrams.
53 | for pi1, pi2 in zip(pers_info1, pers_info2):
54 | dist = WassersteinDistance()(pi1, pi2)
55 | assert dist == pytest.approx(0.0)
56 |
57 |
58 | class TestVietorisRipsComplexThreshold:
59 | data_set = SphereVsTorus(n_point_clouds=3 * batch_size)
60 | loader = DataLoader(
61 | data_set,
62 | batch_size=batch_size,
63 | shuffle=True,
64 | drop_last=False,
65 | )
66 |
67 | vr = VietorisRipsComplex(
68 | dim=1,
69 | p=1,
70 | threshold=0.1,
71 | keep_infinite_features=True
72 | )
73 |
74 | def test_threshold(self):
75 | for (x, y) in self.loader:
76 | pers_info = self.vr(x)
77 |
78 | assert pers_info is not None
79 | assert len(pers_info) == batch_size
80 |
81 |         assert torch.any(isinf(pers_info[0][0].diagram.flatten().sum()))
82 |
83 |
84 | class TestVietorisRipsComplexBatchHandling:
85 | data_set = SphereVsTorus(n_point_clouds=3 * batch_size)
86 | loader = DataLoader(
87 | data_set,
88 | batch_size=batch_size,
89 | shuffle=True,
90 | drop_last=False,
91 | )
92 |
93 | vr = VietorisRipsComplex(dim=1)
94 |
95 | def test_processing(self):
96 | for (x, y) in self.loader:
97 | pers_info = self.vr(x)
98 |
99 | assert pers_info is not None
100 | assert len(pers_info) == batch_size
101 |
102 | pers_info_dense = make_tensor(pers_info)
103 |
104 | assert pers_info_dense is not None
105 |
106 | def test_ragged_processing(self):
107 | rng = np.random.default_rng()
108 |
109 | data = [
110 | np.random.default_rng().uniform(size=(rng.integers(32, 64), 8))
111 | for _ in range(batch_size)
112 | ]
113 |
114 | pers_info = self.vr(data)
115 |
116 | assert pers_info is not None
117 | assert len(pers_info) == batch_size
118 |
119 | def test_batch_iter(self):
120 | for (x, y) in self.loader:
121 | pers_info = self.vr(x)
122 |
123 | assert pers_info is not None
124 | assert len(pers_info) == batch_size
125 |
126 | # This is just to confirm that we can properly iterate over
127 | # this batch. Here, `batch_iter` is a little bit like `NoP`,
128 | # but in general, more complicated nested structures may be
129 | # present.
130 | assert sum(1 for x in batch_iter(pers_info)) == batch_size
131 |
132 | for x in batch_iter(pers_info):
133 | for y in x:
134 | assert isinstance(y, PersistenceInformation)
135 |
136 | for x in batch_iter(pers_info, dim=0):
137 |
138 | # Make sure that we have something to iterate over.
139 | assert sum(1 for y in x) != 0
140 |
141 | for y in x:
142 | assert isinstance(y, PersistenceInformation)
143 | assert y.dimension == 0
144 |
--------------------------------------------------------------------------------
/torch_topological.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aidos-lab/pytorch-topological/9b22d9ebeda68cab88a9267cea53a76a3f9f0c99/torch_topological.svg
--------------------------------------------------------------------------------
/torch_topological/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.1.0'
2 |
--------------------------------------------------------------------------------
/torch_topological/data/__init__.py:
--------------------------------------------------------------------------------
1 | """Module for various data operations and data set creation strategies."""
2 |
3 | from .shapes import sample_from_annulus
4 | from .shapes import sample_from_disk
5 | from .shapes import sample_from_sphere
6 | from .shapes import sample_from_torus
7 | from .shapes import sample_from_unit_cube
8 |
9 | __all__ = [
10 | 'sample_from_annulus',
11 | 'sample_from_disk',
12 | 'sample_from_sphere',
13 | 'sample_from_torus',
14 | 'sample_from_unit_cube',
15 | ]
16 |
--------------------------------------------------------------------------------
/torch_topological/data/shapes.py:
--------------------------------------------------------------------------------
1 | """Contains sampling routines for various simple geometric objects."""
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from .utils import embed
7 |
8 |
9 | def sample_from_disk(n=100, r=0.9, R=1.0, seed=None):
10 | """Sample points from disk.
11 |
12 | Parameters
13 | ----------
14 | n : int
15 | Number of points to sample.
16 |
17 | r: float
18 | Minimum radius, i.e. the radius of the inner circle of a perfect
19 | sampling.
20 |
21 | R : float
22 | Maximum radius, i.e. the radius of the outer circle of a perfect
23 | sampling.
24 |
25 | seed : int, instance of `np.random.Generator`, or `None`
26 | Seed for the random number generator, or an instance of such
27 | a generator. If set to `None`, the default random number
28 | generator will be used.
29 |
30 | Returns
31 | -------
32 | torch.tensor of shape `(n, 2)`
33 | Tensor containing the sampled coordinates.
34 | """
35 |     assert r <= R, 'r must be less than or equal to R'
36 |
37 | rng = np.random.default_rng(seed)
38 |
39 |     length = rng.uniform(r ** 2, R ** 2, size=n)  # squared radii
40 | angle = np.pi * rng.uniform(0, 2, size=n)
41 |
42 | x = np.sqrt(length) * np.cos(angle)
43 | y = np.sqrt(length) * np.sin(angle)
44 |
45 | X = np.vstack((x, y)).T
46 | return torch.as_tensor(X)
47 |
48 |
49 | def sample_from_unit_cube(n, d=3, seed=None):
50 | """Sample points uniformly from unit cube in `d` dimensions.
51 |
52 | Parameters
53 | ----------
54 | n : int
55 | Number of points to sample
56 |
57 | d : int
58 | Number of dimensions.
59 |
60 | seed : int, instance of `np.random.Generator`, or `None`
61 | Seed for the random number generator, or an instance of such
62 | a generator. If set to `None`, the default random number
63 | generator will be used.
64 |
65 | Returns
66 | -------
67 | torch.tensor of shape `(n, d)`
68 | Tensor containing the sampled coordinates.
69 | """
70 | rng = np.random.default_rng(seed)
71 | X = rng.uniform(size=(n, d))
72 |
73 | return torch.as_tensor(X)
74 |
75 |
76 | def sample_from_sphere(n=100, d=2, r=1, noise=None, ambient=None, seed=None):
77 | """Sample `n` data points from a `d`-sphere in `d + 1` dimensions.
78 |
79 | Parameters
80 | -----------
81 | n : int
82 | Number of data points in shape.
83 |
84 | d : int
85 | Dimension of the sphere.
86 |
87 | r : float
88 | Radius of sphere.
89 |
90 | noise : float or None
91 | Optional noise factor. If set, data coordinates will be
92 | perturbed by a standard normal distribution, scaled by
93 | `noise`.
94 |
95 | ambient : int or None
96 | Embed the sphere into a space with ambient dimension equal to
97 | `ambient`. The sphere is randomly rotated into this
98 | high-dimensional space.
99 |
100 | seed : int, instance of `np.random.Generator`, or `None`
101 | Seed for the random number generator, or an instance of such
102 | a generator. If set to `None`, the default random number
103 | generator will be used.
104 |
105 | Returns
106 | -------
107 | torch.tensor
108 | Tensor of sampled coordinates. If `ambient` is set, array will be
109 | of shape `(n, ambient)`. Else, array will be of shape `(n, d + 1)`.
110 |
111 | Notes
112 | -----
113 | This function was originally authored by Nathaniel Saul as part of
114 | the `tadasets` package. [tadasets]_
115 |
116 | References
117 | ----------
118 | .. [tadasets] https://github.com/scikit-tda/tadasets
119 |
120 | """
121 | rng = np.random.default_rng(seed)
122 | data = rng.standard_normal((n, d+1))
123 |
124 | # Normalize points to the sphere
125 | data = r * data / np.sqrt(np.sum(data**2, 1)[:, None])
126 |
127 | if noise:
128 | data += noise * rng.standard_normal(data.shape)
129 |
130 | if ambient is not None:
131 | assert ambient > d
132 | data = embed(data, ambient)
133 |
134 | return torch.as_tensor(data)
135 |
136 |
137 | def sample_from_torus(n, d=3, r=1.0, R=2.0, seed=None):
138 | """Sample points uniformly from torus and embed it in `d` dimensions.
139 |
140 | Parameters
141 | ----------
142 | n : int
143 | Number of points to sample
144 |
145 | d : int
146 | Number of dimensions.
147 |
148 | r : float
149 | Radius of the 'tube' of the torus.
150 |
151 | R : float
152 | Radius of the torus, i.e. the distance from the centre of the
153 | 'tube' to the centre of the torus.
154 |
155 | seed : int, instance of `np.random.Generator`, or `None`
156 | Seed for the random number generator, or an instance of such
157 | a generator. If set to `None`, the default random number
158 | generator will be used.
159 |
160 | Returns
161 | -------
162 | torch.tensor of shape `(n, d)`
163 | Tensor of sampled coordinates.
164 | """
165 | if r > R:
166 | raise RuntimeError('Radius of the tube must be less than ' +
167 | 'or equal to radius of the torus')
168 |
169 | rng = np.random.default_rng(seed)
170 | angles = []
171 |
172 | while len(angles) < n:
173 | x = rng.uniform(0, 2 * np.pi)
174 | y = rng.uniform(0, 1 / np.pi)
175 |
176 | f = (1.0 + (r/R) * np.cos(x)) / (2 * np.pi)
177 |
178 | if y < f:
179 | psi = rng.uniform(0, 2 * np.pi)
180 | angles.append((x, psi))
181 |
182 | X = []
183 |
184 | for theta, psi in angles:
185 | a = R + r * np.cos(theta)
186 | x = a * np.cos(psi)
187 | y = a * np.sin(psi)
188 | z = r * np.sin(theta)
189 |
190 | X.append((x, y, z))
191 |
192 | X = np.asarray(X)
193 | return torch.as_tensor(X)
194 |
195 |
196 | def sample_from_annulus(n, r, R, seed=None):
197 | """Sample points from a 2D annulus.
198 |
199 |     This function samples `n` points from an annulus with inner radius `r`
200 | and outer radius `R`.
201 |
202 | Parameters
203 | ----------
204 | n : int
205 | Number of points to sample
206 |
207 | r : float
208 | Inner radius of annulus
209 |
210 | R : float
211 | Outer radius of annulus
212 |
213 | seed : int, instance of `np.random.Generator`, or `None`
214 | Seed for the random number generator, or an instance of such
215 | a generator. If set to `None`, the default random number
216 | generator will be used.
217 |
218 | Returns
219 | -------
220 | torch.tensor of shape `(n, 2)`
221 | Tensor containing sampled coordinates.
222 | """
223 | if r >= R:
224 | raise RuntimeError(
225 |             'Inner radius must be strictly less than outer radius'
226 | )
227 |
228 | rng = np.random.default_rng(seed)
229 | thetas = rng.uniform(0, 2 * np.pi, n)
230 |
231 |     # Sample squared radii: the area element grows linearly with the
232 |     # radius, so this yields a uniform density on the annulus.
233 | radii = np.sqrt(rng.uniform(r ** 2, R ** 2, n))
234 |
235 | X = np.column_stack((radii * np.cos(thetas), radii * np.sin(thetas)))
236 | return torch.as_tensor(X)
237 |
--------------------------------------------------------------------------------
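A minimal usage sketch for the samplers above (not part of the repository;
all return values are `torch.Tensor` instances, as documented)::

    import torch

    from torch_topological.data import sample_from_annulus
    from torch_topological.data import sample_from_sphere
    from torch_topological.data import sample_from_torus

    X = sample_from_sphere(n=200, d=2, seed=42)     # shape (200, 3)
    Y = sample_from_torus(n=200, seed=42)           # shape (200, 3)
    Z = sample_from_annulus(n=200, r=0.5, R=1.0)    # shape (200, 2)

    assert all(isinstance(T, torch.Tensor) for T in (X, Y, Z))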
/torch_topological/data/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for data set generation."""
2 |
3 | import numpy as np
4 |
5 |
6 | def embed(data, ambient=50):
7 | """Embed `data` in `ambient` dimensions.
8 |
9 | Parameters
10 | ----------
11 | data : array_like
12 | Input data set; needs to have shape `(n, d)`, i.e. samples are
13 | in the rows, dimensions are in the columns.
14 |
15 | ambient : int
16 | Dimension of embedding space. Must be greater than
17 | dimensionality of data.
18 |
19 | Returns
20 | -------
21 | array_like
22 |         Embedded array of shape `(n, D)`, with `D = ambient`.
23 |
24 | Notes
25 | -----
26 | This function was originally authored by Nathaniel Saul as part of
27 | the `tadasets` package. [tadasets]_
28 |
29 | References
30 | ----------
31 | .. [tadasets] https://github.com/scikit-tda/tadasets
32 | """
33 | n, d = data.shape
34 | assert ambient > d
35 |
36 | base = np.zeros((n, ambient))
37 | base[:, :d] = data
38 |
39 |     # Construct a random rotation matrix of dimension `ambient` via QR.
40 | random_rotation = np.random.random((ambient, ambient))
41 | q, r = np.linalg.qr(random_rotation)
42 |
43 | base = np.dot(base, q)
44 | return base
45 |
--------------------------------------------------------------------------------
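A small sketch of `embed` in action (not part of the repository). Note that
the rotation inside `embed` is *not* seeded, so repeated calls will yield
different embeddings of the same data::

    import numpy as np

    from torch_topological.data.utils import embed

    data = np.random.default_rng(0).uniform(size=(100, 2))
    data_high = embed(data, ambient=50)

    print(data_high.shape)  # (100, 50)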
/torch_topological/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """Data sets with special topological characteristics."""
2 |
3 | from .shapes import SphereVsTorus
4 |
5 | from .spheres import Spheres
6 |
7 | __all__ = [
8 | 'Spheres',
9 | 'SphereVsTorus',
10 | ]
11 |
--------------------------------------------------------------------------------
/torch_topological/datasets/shapes.py:
--------------------------------------------------------------------------------
1 | """Data sets based on simple topological shapes."""
2 |
3 | from torch_topological.data import sample_from_sphere
4 | from torch_topological.data import sample_from_torus
5 |
6 | from torch.utils.data import Dataset
7 |
8 | import torch
9 |
10 |
11 | class SphereVsTorus(Dataset):
12 | """Data set containing point cloud samples from spheres and tori."""
13 |
14 | def __init__(
15 | self,
16 | n_point_clouds=500,
17 | n_samples=100,
18 | shuffle=True
19 | ):
20 | """Create new instance of the data set.
21 |
22 | Parameters
23 | ----------
24 | n_point_clouds : int
25 | Number of point clouds to generate. Each point cloud will
26 | consist of `n_samples` samples.
27 |
28 | n_samples : int
29 | Number of samples to use for each of the `n_point_clouds`
30 | point clouds contained in the data set.
31 |
32 | shuffle : bool
33 | If set, shuffles point clouds. Else, point clouds will be
34 | stored in the order of their creation.
35 | """
36 | self.n_point_clouds = n_point_clouds
37 | self.n_samples = n_samples
38 |
39 | n_spheres = n_point_clouds // 2
40 | n_tori = n_point_clouds - n_spheres
41 |
42 | spheres = torch.stack([
43 | sample_from_sphere(self.n_samples) for i in range(n_spheres)
44 | ])
45 |
46 | tori = torch.stack([
47 | sample_from_torus(self.n_samples) for i in range(n_tori)
48 | ])
49 |
50 | labels = torch.as_tensor(
51 | [0] * n_spheres + [1] * n_tori,
52 | dtype=torch.long
53 | )
54 |
55 | self.data = torch.vstack((spheres, tori))
56 | self.labels = labels
57 |
58 | if shuffle:
59 | perm = torch.randperm(self.n_point_clouds)
60 |
61 | self.data = self.data[perm]
62 | self.labels = self.labels[perm]
63 |
64 | def __getitem__(self, index):
65 | """Get point cloud at `index`.
66 |
67 | Accesses the point cloud stored at `index` and returns it as
68 | well as its corresponding label.
69 |
70 | Parameters
71 | ----------
72 | index : int
73 | Index of samples to access.
74 |
75 | Returns
76 | -------
77 | Tuple of torch.tensor, torch.tensor
78 | Point cloud at index `index` and its label.
79 | """
80 | return self.data[index], self.labels[index]
81 |
82 | def __len__(self):
83 | """Get number of point clouds stored in data set.
84 |
85 | Returns
86 | -------
87 | int
88 | Number of point clouds stored in this instance of the class.
89 | """
90 | return len(self.data)
91 |
--------------------------------------------------------------------------------
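A minimal sketch (not part of the repository) of how `SphereVsTorus`
combines with a standard `DataLoader`; the shapes follow from the default
sampler dimensions documented above::

    from torch.utils.data import DataLoader

    from torch_topological.datasets import SphereVsTorus

    data_set = SphereVsTorus(n_point_clouds=64, n_samples=100)
    loader = DataLoader(data_set, batch_size=32, shuffle=True)

    for x, y in loader:
        # x: point clouds of shape (32, 100, 3); y: labels in {0, 1}
        print(x.shape, y.shape)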
/torch_topological/datasets/spheres.py:
--------------------------------------------------------------------------------
1 | """Create `SPHERES` data set."""
2 |
3 | import numpy as np
4 |
5 | import torch
6 |
7 | from torch.utils.data import Dataset
8 | from torch.utils.data import random_split
9 |
10 | from torch_topological.data import sample_from_sphere
11 |
12 |
13 | def create_sphere_dataset(n_samples=500, n_spheres=11, d=100, r=5, seed=None):
14 | """Create data set of high-dimensional spheres.
15 |
16 | Create `SPHERES` data set described in Moor et al. [Moor20a]_. The
17 |     data set consists of `n_spheres - 1` small spheres, enclosed by a
18 |     single larger sphere. It is a simple arrangement of simple
19 |     manifolds that is nevertheless challenging for embedding algorithms.
20 |
21 | Parameters
22 | ----------
23 | n_samples : int
24 | Number of points to sample per sphere.
25 |
26 | n_spheres : int
27 | Total number of spheres to create. The algorithm will always
28 | create the *last* sphere to enclose the previous ones. Hence,
29 | if `n_spheres = 3`, two spheres will be enclosed by a larger
30 | one.
31 |
32 | d : int
33 | Dimension of spheres to sample from. A `d`-sphere will be
34 | embedded in `d+1` dimensions.
35 |
36 | r : float
37 | Radius of smaller spheres. The radius of the larger enclosing
38 | sphere will be `5 * r`.
39 |
40 | seed : int, instance of `np.random.Generator`, or `None`
41 | Seed for the random number generator, or an instance of such
42 | a generator. If set to `None`, the default random number
43 | generator will be used.
44 |
45 | Returns
46 | -------
47 | Tuple of `np.array`, `np.array`
48 | Array containing the coordinates of the spheres. The second
49 | array contains the respective labels, ranging from `0` to
50 | `n_spheres - 1`. This array can be used for visualisation
51 | purposes.
52 |
53 | Notes
54 | -----
55 | The original version of this code was authored by Michael Moor.
56 |
57 | References
58 | ----------
59 | .. [Moor20a] M. Moor et al., "Topological Autoencoders",
60 | *Proceedings of the 37th International Conference on Machine
61 | Learning*, PMLR 119, pp. 7045--7054, 2020.
62 | """
63 | rng = np.random.default_rng(seed)
64 |
65 |     std = 10 / np.sqrt(d)  # `scale` parameter, i.e. standard deviation
66 |     shift_matrix = rng.normal(0, std, [n_spheres, d+1])
67 |
68 | spheres = []
69 | n_datapoints = 0
70 | for i in np.arange(n_spheres - 1):
71 |         sphere = sample_from_sphere(n=n_samples, d=d, r=r, seed=rng)
72 | spheres.append(sphere + shift_matrix[i, :])
73 | n_datapoints += n_samples
74 |
75 | # Build additional large surrounding sphere:
76 | n_samples_big = 10 * n_samples
77 |     big = sample_from_sphere(n=n_samples_big, d=d, r=r*5, seed=rng)
78 | spheres.append(big)
79 | n_datapoints += n_samples_big
80 |
81 | X = np.concatenate(spheres, axis=0)
82 | y = np.zeros(n_datapoints)
83 |
84 | label_index = 0
85 |
86 | for index, data in enumerate(spheres):
87 | n_sphere_samples = data.shape[0]
88 | y[label_index:label_index + n_sphere_samples] = index
89 | label_index += n_sphere_samples
90 |
91 | return X, y
92 |
93 |
94 | class Spheres(Dataset):
95 | """Data set containing multiple spheres, enclosed by a larger one.
96 |
97 | This is a `Dataset` instance of the `SPHERES` data set described by
98 | Moor et al. [Moor20a]_. The class provides a convenient interface
99 | for making use of that data set in machine learning tasks.
100 | """
101 |
102 | def __init__(
103 | self,
104 | train=True,
105 | test_fraction=0.1,
106 | n_samples=100,
107 | n_spheres=11,
108 | r=5,
109 | ):
110 | """Create new instance of `SPHERES` data set.
111 |
112 | This class wraps the `SPHERES` data set for subsequent use in
113 | machine learning tasks. See :func:`create_sphere_dataset` for
114 | more details on the supported parameters.
115 |
116 | Parameters
117 | ----------
118 | train : bool
119 | If set, create and store training portion of data set.
120 |
121 | test_fraction : float
122 | Fraction of generated samples to be used as the test portion
123 | of the data set.
124 | """
125 | X, y = create_sphere_dataset(
126 | n_samples=n_samples,
127 | n_spheres=n_spheres,
128 | r=r)
129 |
130 | X = torch.as_tensor(X, dtype=torch.float)
131 |
132 | test_size = int(test_fraction * len(X))
133 | train_size = len(X) - test_size
134 |
135 | indices = torch.as_tensor(np.arange(len(X)), dtype=torch.long)
136 |
137 | train_indices, test_indices = random_split(
138 | indices, [train_size, test_size]
139 | )
140 |
141 | X_train, X_test = X[train_indices], X[test_indices]
142 | y_train, y_test = y[train_indices], y[test_indices]
143 |
144 | self.data = X_train if train else X_test
145 | self.labels = y_train if train else y_test
146 |
147 | self.dimension = X.shape[1]
148 |
149 | def __getitem__(self, index):
150 | """Return data point and label of a specific point.
151 |
152 | Parameters
153 | ----------
154 | index : int
155 | Index of point in data set. Invalid indices will raise an
156 | exception.
157 |
158 | Returns
159 | -------
160 | tuple of `torch.tensor` and `torch.tensor`
161 | The point at the specific index and its label, indicating
162 | the sphere it belongs to. See :func:`create_sphere_dataset`
163 | for more details on the specifics of label assignment.
164 | """
165 | return self.data[index], self.labels[index]
166 |
167 | def __len__(self):
168 | """Return number of points in data set.
169 |
170 | Returns
171 | -------
172 | int
173 | Number of samples in data set.
174 | """
175 | return len(self.data)
176 |
--------------------------------------------------------------------------------
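A minimal sketch (not part of the repository) showing how to obtain the
train and test splits of the `SPHERES` data set. With the default `d = 100`
of :func:`create_sphere_dataset`, individual points live in 101 dimensions::

    from torch_topological.datasets import Spheres

    train_data = Spheres(train=True)
    test_data = Spheres(train=False)

    x, y = train_data[0]
    print(x.shape, train_data.dimension)  # torch.Size([101]), 101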
/torch_topological/examples/alpha_complex.py:
--------------------------------------------------------------------------------
1 | """Example demonstrating the computation of alpha complexes.
2 |
3 | This simple example demonstrates how to use alpha complexes to change
4 | the appearance of a point cloud, following the `TopologyLayer
5 | `_ package.
6 |
7 | This example is still a **work in progress**.
8 | """
9 |
10 | from torch_topological.nn import AlphaComplex
11 | from torch_topological.nn import SummaryStatisticLoss
12 |
13 | from torch_topological.utils import SelectByDimension
14 |
15 | import numpy as np
16 | import matplotlib.pyplot as plt
17 |
18 | import torch
19 |
20 | if __name__ == '__main__':
21 | np.random.seed(42)
22 | data = np.random.rand(100, 2)
23 |
24 | alpha_complex = AlphaComplex()
25 |
26 | loss_fn = SummaryStatisticLoss(
27 | summary_statistic='polynomial_function',
28 | p=2,
29 | q=0
30 | )
31 |
32 | X = torch.nn.Parameter(torch.as_tensor(data), requires_grad=True)
33 | opt = torch.optim.Adam([X], lr=1e-2)
34 |
35 | for i in range(100):
36 | # We are only interested in working with persistence diagrams of
37 | # dimension 1.
38 | selector = SelectByDimension(1)
39 |
40 |         # First, get the persistence information of our complex. We
41 |         # then pass it through the selector to remove diagrams we do
42 |         # not need.
43 | pers_info = alpha_complex(X)
44 | pers_info = selector(pers_info)
45 |
46 | # Evaluate the loss; notice that we want to *maximise* it in
47 |         # order to enlarge the holes in the data.
48 | loss = -loss_fn(pers_info)
49 |
50 | opt.zero_grad()
51 | loss.backward()
52 | opt.step()
53 |
54 | X = X.detach().numpy()
55 |
56 | plt.scatter(X[:, 0], X[:, 1])
57 | plt.show()
58 |
--------------------------------------------------------------------------------
/torch_topological/examples/autoencoders.py:
--------------------------------------------------------------------------------
1 | """Demo for topology-regularised autoencoders.
2 |
3 | This example demonstrates how to use `pytorch-topological` to create an
4 | additional differentiable loss term that makes autoencoders aware of
5 | topological features. See [Moor20a]_ for more information.
6 | """
7 |
8 | import torch
9 | import torch.optim as optim
10 |
11 | import matplotlib.pyplot as plt
12 |
13 | from tqdm import tqdm
14 |
15 | from torch.utils.data import DataLoader
16 |
17 | from torch_topological.datasets import Spheres
18 |
19 | from torch_topological.nn import SignatureLoss
20 | from torch_topological.nn import VietorisRipsComplex
21 |
22 |
23 | class LinearAutoencoder(torch.nn.Module):
24 | """Simple linear autoencoder class.
25 |
26 | This module performs simple embeddings based on an MSE loss. This is
27 | similar to ordinary principal component analysis. Notice that the
28 | class is only meant to provide a simple example that can be run
29 | easily even without the availability of a GPU. In practice, there
30 | are many more architectures with improved expressive power
31 | available.
32 | """
33 |
34 | def __init__(self, input_dim, latent_dim=2):
35 | """Create new autoencoder with pre-defined latent dimension."""
36 | super().__init__()
37 |
38 | self.input_dim = input_dim
39 | self.latent_dim = latent_dim
40 |
41 | self.encoder = torch.nn.Sequential(
42 | torch.nn.Linear(self.input_dim, self.latent_dim)
43 | )
44 |
45 | self.decoder = torch.nn.Sequential(
46 | torch.nn.Linear(self.latent_dim, self.input_dim)
47 | )
48 |
49 | self.loss_fn = torch.nn.MSELoss()
50 |
51 | def encode(self, x):
52 | """Embed data in latent space."""
53 | return self.encoder(x)
54 |
55 | def decode(self, z):
56 | """Decode data from latent space."""
57 | return self.decoder(z)
58 |
59 | def forward(self, x):
60 |         """Embed and reconstruct data, returning the loss."""
61 | z = self.encode(x)
62 | x_hat = self.decode(z)
63 |
64 | # The loss can of course be changed. If this is your first time
65 | # working with autoencoders, a good exercise would be to 'grok'
66 | # the meaning of different losses.
67 | reconstruction_error = self.loss_fn(x, x_hat)
68 | return reconstruction_error
69 |
70 |
71 | class TopologicalAutoencoder(torch.nn.Module):
72 | """Wrapper for a topologically-regularised autoencoder.
73 |
74 | This class uses another autoencoder model and imbues it with an
75 | additional topology-based loss term.
76 | """
77 | def __init__(self, model, lam=1.0):
78 | super().__init__()
79 |
80 | self.lam = lam
81 | self.model = model
82 | self.loss = SignatureLoss(p=2)
83 |
84 | # TODO: Make dimensionality configurable
85 | self.vr = VietorisRipsComplex(dim=0)
86 |
87 | def forward(self, x):
88 | z = self.model.encode(x)
89 |
90 | pi_x = self.vr(x)
91 | pi_z = self.vr(z)
92 |
93 | geom_loss = self.model(x)
94 | topo_loss = self.loss([x, pi_x], [z, pi_z])
95 |
96 | loss = geom_loss + self.lam * topo_loss
97 | return loss
98 |
99 |
100 | if __name__ == '__main__':
101 | # We first have to create a data set. This follows the original
102 | # publication by Moor et al. by introducing a simple 'manifold'
103 | # data set consisting of multiple spheres.
104 | n_spheres = 11
105 | data_set = Spheres(n_spheres=n_spheres)
106 |
107 | train_loader = DataLoader(
108 | data_set,
109 | batch_size=32,
110 | shuffle=True,
111 | drop_last=True
112 | )
113 |
114 | # Let's set up the two models that we are training. Note that in
115 | # a real application, you would have a more complicated training
116 | # setup, potentially with early stopping etc. This training loop
117 | # is merely to be seen as a proof of concept.
118 | model = LinearAutoencoder(input_dim=data_set.dimension)
119 | topo_model = TopologicalAutoencoder(model, lam=10)
120 |
121 | optimizer = optim.Adam(topo_model.parameters(), lr=1e-3)
122 |
123 | n_epochs = 5
124 |
125 | progress = tqdm(range(n_epochs))
126 |
127 | for i in progress:
128 | topo_model.train()
129 |
130 | for batch, (x, y) in enumerate(train_loader):
131 | loss = topo_model(x)
132 |
133 | optimizer.zero_grad()
134 | loss.backward()
135 | optimizer.step()
136 |
137 | progress.set_postfix(loss=loss.item())
138 |
139 | # Evaluate the autoencoder on a new instance of the data set.
140 | data_set = Spheres(
141 | train=False,
142 | n_samples=2000,
143 | n_spheres=n_spheres,
144 | )
145 |
146 | test_loader = DataLoader(
147 | data_set,
148 | shuffle=False,
149 | batch_size=len(data_set)
150 | )
151 |
152 | X, y = next(iter(test_loader))
153 | Z = topo_model.model.encode(X).detach().numpy()
154 |
155 | plt.scatter(
156 | Z[:, 0], Z[:, 1],
157 | c=y,
158 | cmap='Set1',
159 | marker='o',
160 | alpha=0.9,
161 | )
162 | plt.show()
163 |
--------------------------------------------------------------------------------
/torch_topological/examples/benchmarking.py:
--------------------------------------------------------------------------------
1 | """Debug script for benchmarking some calculations."""
2 |
3 | import time
4 | import torch
5 | import sys
6 |
7 | from torch_topological.nn import WassersteinDistance
8 | from torch_topological.nn import VietorisRipsComplex
9 |
10 |
11 | def run_test(X, vr, name, dist=False):
12 | W1 = WassersteinDistance()
13 |
14 | pre = time.perf_counter()
15 |
16 | X_pi = vr(X, treat_as_distances=dist)
17 |     Y_pi = vr(X, treat_as_distances=dist)  # same input; timing only
18 |
19 | dists = torch.stack([W1(x_pi, y_pi) for x_pi, y_pi in zip(X_pi, Y_pi)])
20 |     dists.mean()  # include the reduction in the timed region
21 |
22 | cur = time.perf_counter()
23 | print(f"{name}: {cur - pre:.4f}s")
24 |
25 |
26 | if __name__ == "__main__":
27 | X = torch.load(sys.argv[1])
28 |
29 | print("Calculating everything ourselves")
30 |
31 | run_test(X, VietorisRipsComplex(dim=0), "raw")
32 | run_test(X, VietorisRipsComplex(dim=0, threshold=1.0), "thresholded")
33 |
34 | print("\nPre-defined distances")
35 |
36 | D = torch.cdist(X, X)
37 |
38 | run_test(D, VietorisRipsComplex(dim=0), "raw", dist=True)
39 | run_test(
40 | D, VietorisRipsComplex(dim=0, threshold=1.0), "thresholded", dist=True
41 | )
42 |
--------------------------------------------------------------------------------
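The script above expects a single command-line argument: a file containing
a serialised tensor of point clouds. A minimal sketch for creating such an
input (the file name is arbitrary)::

    import torch

    # Eight point clouds with 100 points in 3D each.
    X = torch.rand(8, 100, 3)
    torch.save(X, "point_clouds.pt")

    # Afterwards: python benchmarking.py point_clouds.pt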
/torch_topological/examples/classification.py:
--------------------------------------------------------------------------------
1 | """Example of classifying data using topological layers."""
2 |
3 | from torch_topological.datasets import SphereVsTorus
4 |
5 | from torch_topological.nn.data import make_tensor
6 |
7 | from torch_topological.nn import VietorisRipsComplex
8 | from torch_topological.nn.layers import StructureElementLayer
9 |
10 | from torch.utils.data import DataLoader
11 |
12 | from tqdm import tqdm
13 |
14 | import torch
15 |
16 |
17 | class TopologicalModel(torch.nn.Module):
18 | def __init__(self, n_elements, latent_dim=64, output_dim=2):
19 | super().__init__()
20 |
21 | self.n_elements = n_elements
22 | self.latent_dim = latent_dim
23 |
24 | self.model = torch.nn.Sequential(
25 | StructureElementLayer(self.n_elements),
26 | torch.nn.Linear(self.n_elements, self.latent_dim),
27 | torch.nn.ReLU(),
28 | torch.nn.Linear(self.latent_dim, output_dim),
29 | )
30 |
31 | self.vr = VietorisRipsComplex(dim=0)
32 |
33 | def forward(self, x):
34 | pers_info = self.vr(x)
35 | pers_info = make_tensor(pers_info)
36 |
37 | return self.model(pers_info)
38 |
39 |
40 | if __name__ == "__main__":
41 | batch_size = 32
42 | n_epochs = 50
43 | n_elements = 10
44 |
45 | data_set = SphereVsTorus(n_point_clouds=2 * batch_size)
46 |
47 | loader = DataLoader(
48 | data_set,
49 | batch_size=batch_size,
50 | shuffle=True,
51 | drop_last=False,
52 | )
53 |
54 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
55 |
56 | model = TopologicalModel(n_elements).to(device)
57 | loss_fn = torch.nn.CrossEntropyLoss()
58 | opt = torch.optim.Adam(model.parameters(), lr=1e-4)
59 |
60 | progress = tqdm(range(n_epochs))
61 |
62 | for epoch in progress:
63 | for batch, (x, y) in enumerate(loader):
64 | x = x.to(device)
65 | y = y.to(device)
66 |
67 | output = model(x)
68 | loss = loss_fn(output, y)
69 |
70 | opt.zero_grad()
71 | loss.backward()
72 | opt.step()
73 |
74 | pred = torch.argmax(output, dim=1)
75 | acc = (pred == y).sum() / len(y)
76 |
77 | progress.set_postfix(loss=f"{loss.item():.08f}", acc=f"{acc:.02f}")
78 |
--------------------------------------------------------------------------------
/torch_topological/examples/cubical_complex.py:
--------------------------------------------------------------------------------
1 | """Demo for calculating cubical complexes.
2 |
3 | This example demonstrates how to perform topological operations on
4 | a structured array, such as a grey-scale image.
5 | """
6 |
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | from torch_topological.nn import CubicalComplex
11 | from torch_topological.nn import WassersteinDistance
12 |
13 | from tqdm import tqdm
14 |
15 | import torch
16 |
17 |
18 | def sample_circles(n_cells, n_samples=1000):
19 | """Sample two nested circles and bin them.
20 |
21 | Parameters
22 | ----------
23 | n_cells : int
24 | Number of cells for the 2D histogram, i.e. the 'resolution' of
25 | the histogram.
26 |
27 | n_samples : int
28 | Number of samples to use for creating the nested circles
29 | coordinates.
30 |
31 | Returns
32 | -------
33 | np.ndarray of shape ``(n_cells, n_cells)``
34 | Structured array containing intensity values for the data set.
35 | """
36 | from sklearn.datasets import make_circles
37 | X = make_circles(n_samples, shuffle=True, noise=0.01)[0]
38 |
39 | heatmap, *_ = np.histogram2d(X[:, 0], X[:, 1], bins=n_cells)
40 | heatmap -= heatmap.mean()
41 | heatmap /= heatmap.max()
42 |
43 | return heatmap
44 |
45 |
46 | if __name__ == '__main__':
47 |
48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49 | print('Device', device)
50 |
51 | np.random.seed(23)
52 |
53 | Y = sample_circles(50)
54 | Y = torch.as_tensor(Y, dtype=torch.float)
55 | X = torch.as_tensor(
56 | Y + np.random.normal(scale=0.20, size=Y.shape),
57 | dtype=torch.float,
58 | device=device,
59 | )
60 | Y = Y.to(device)
61 | X = torch.nn.Parameter(X, requires_grad=True).to(device)
62 |
63 | source = X.clone()
64 |
65 | optimizer = torch.optim.Adam([X], lr=1e-3)
66 | loss_fn = WassersteinDistance(q=2)
67 |
68 | cubical_complex = CubicalComplex()
69 |
70 | persistence_information_target = cubical_complex(Y)
71 | persistence_information_target = persistence_information_target[0]
72 |
73 | n_iter = 500
74 | progress = tqdm(range(n_iter))
75 |
76 | for i in progress:
77 | persistence_information = cubical_complex(X)
78 | persistence_information = persistence_information[0]
79 |
80 | optimizer.zero_grad()
81 |
82 | loss = loss_fn(
83 | persistence_information,
84 | persistence_information_target
85 | )
86 |
87 | loss.backward()
88 | optimizer.step()
89 |
90 | progress.set_postfix(loss=loss.item())
91 |
92 |     source = source.cpu().detach().numpy()
93 | target = Y.cpu().detach().numpy()
94 | result = X.cpu().detach().numpy()
95 |
96 | fig, ax = plt.subplots(ncols=3)
97 |
98 | ax[0].imshow(source)
99 | ax[0].set_title('Source')
100 |
101 | ax[1].imshow(target)
102 | ax[1].set_title('Target')
103 |
104 | ax[2].imshow(result)
105 | ax[2].set_title('Result')
106 |
107 | plt.show()
108 |
--------------------------------------------------------------------------------
/torch_topological/examples/distances.py:
--------------------------------------------------------------------------------
1 | """Demo for distance minimisations of a point cloud.
2 |
3 | Note
4 | ----
5 | This demonstration is a work in progress. It is not fully documented and
6 | tested yet.
7 | """
8 |
9 | from tqdm import tqdm
10 |
11 | from torch_topological.data import sample_from_disk
12 |
13 | from torch_topological.nn import VietorisRipsComplex
14 | from torch_topological.nn import WassersteinDistance
15 |
16 | import torch
17 | import torch.optim as optim
18 |
19 |
20 | import matplotlib.pyplot as plt
21 |
22 |
23 | if __name__ == '__main__':
24 | n = 100
25 |
26 | X = sample_from_disk(r=0.5, R=0.6, n=n)
27 | Y = sample_from_disk(r=0.9, R=1.0, n=n)
28 |
29 | X = torch.nn.Parameter(torch.as_tensor(X), requires_grad=True)
30 |
31 | vr = VietorisRipsComplex(dim=1)
32 |
33 | pi_target = vr(Y)
34 | loss_fn = WassersteinDistance(q=2)
35 |
36 | opt = optim.SGD([X], lr=0.1)
37 |
38 | n_iterations = 500
39 | progress = tqdm(range(n_iterations))
40 |
41 | for i in progress:
42 |
43 | opt.zero_grad()
44 |
45 | pi_source = vr(X)
46 | loss = loss_fn(pi_source, pi_target)
47 |
48 | loss.backward()
49 | opt.step()
50 |
51 | progress.set_postfix(loss=loss.item())
52 |
53 | X = X.detach().numpy()
54 |
55 | plt.scatter(X[:, 0], X[:, 1], label='Source')
56 | plt.scatter(Y[:, 0], Y[:, 1], label='Target')
57 |
58 | plt.legend()
59 | plt.show()
60 |
--------------------------------------------------------------------------------
/torch_topological/examples/gan.py:
--------------------------------------------------------------------------------
1 | """Example of topology-based GANs.
2 |
3 | This is a work in progress, demonstrating how to add a simple
4 | adversarial loss into a GAN architecture.
5 | """
6 |
7 | import torch
8 | import torchvision
9 |
10 | from torch_topological.nn import CubicalComplex
11 | from torch_topological.nn import WassersteinDistance
12 |
13 | from tqdm import tqdm
14 |
15 | import numpy as np
16 | import matplotlib.pyplot as plt
17 |
18 |
19 | class Generator(torch.nn.Module):
20 | """Simple generator module."""
21 |
22 | def __init__(self, latent_dim, shape):
23 | super().__init__()
24 |
25 | self.latent_dim = latent_dim
26 | self.output_dim = np.prod(shape)
27 |
28 | self.shape = shape
29 |
30 | def _make_layer(input_dim, output_dim):
31 | layers = [
32 | torch.nn.Linear(input_dim, output_dim),
33 | torch.nn.BatchNorm1d(output_dim, 0.8),
34 | torch.nn.LeakyReLU(0.2),
35 | ]
36 | return layers
37 |
38 | self.model = torch.nn.Sequential(
39 | *_make_layer(self.latent_dim, 32),
40 | *_make_layer(32, 64),
41 | *_make_layer(64, 128),
42 | torch.nn.Linear(128, self.output_dim),
43 | torch.nn.Sigmoid()
44 | )
45 |
46 | def forward(self, z):
47 | point_cloud = self.model(z)
48 | point_cloud = point_cloud.view(point_cloud.size(0), *self.shape)
49 |
50 | return point_cloud
51 |
52 |
53 | class Discriminator(torch.nn.Module):
54 | """Simple discriminator module."""
55 |
56 | def __init__(self, shape):
57 | super().__init__()
58 |
59 | input_dim = np.prod(shape)
60 |
61 | # Inspired by the original GAN. THERE CAN ONLY BE ONE!
62 | self.model = torch.nn.Sequential(
63 | torch.nn.Linear(input_dim, 256),
64 | torch.nn.LeakyReLU(0.2),
65 | torch.nn.Linear(256, 128),
66 | torch.nn.LeakyReLU(0.2),
67 | torch.nn.Linear(128, 1),
68 | torch.nn.Sigmoid(),
69 | )
70 |
71 | def forward(self, x):
72 | # Flatten point cloud
73 | x = x.view(x.size(0), -1)
74 | return self.model(x)
75 |
76 |
77 | class TopologicalAdversarialLoss(torch.nn.Module):
78 | """Loss term incorporating topological features."""
79 |
80 | def __init__(self):
81 | super().__init__()
82 |
83 | self.cubical = CubicalComplex()
84 | self.loss = WassersteinDistance(q=2)
85 |
86 | def forward(self, real, synthetic):
87 | """Calculate loss between real and synthetic images."""
88 |
89 | loss = 0.0
90 |
91 |         # This could potentially be solved by stacking as well, since
92 |         # the interface of `torch_topological` also permits batched
93 |         # constructions of this kind.
94 | for x, y in zip(real, synthetic):
95 | # Remove single-channel information. For multiple channels,
96 | # this will have to be adapted.
97 | x = x.squeeze()
98 | y = y.squeeze()
99 |
100 | pi_real = self.cubical(x)[0]
101 | pi_synthetic = self.cubical(y)[0]
102 |
103 | loss += self.loss(pi_real, pi_synthetic)
104 |
105 |
106 | return loss
107 |
108 |
109 | def show(imgs):
110 | if not isinstance(imgs, list):
111 | imgs = [imgs]
112 | fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
113 | for i, img in enumerate(imgs):
114 | img = img.detach()
115 | img = torchvision.transforms.functional.to_pil_image(img)
116 | axs[0, i].imshow(np.asarray(img))
117 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
118 |
119 |
120 | if __name__ == "__main__":
121 | n_epochs = 5
122 |
123 | img_size = 16
124 | shape = (1, img_size, img_size)
125 |
126 | batch_size = 32
127 | latent_dim = 200
128 |
129 | data_loader = torch.utils.data.DataLoader(
130 | torchvision.datasets.MNIST(
131 | "./data/MNIST",
132 | train=False,
133 | download=True,
134 | transform=torchvision.transforms.Compose(
135 | [
136 | torchvision.transforms.Resize(img_size),
137 | torchvision.transforms.ToTensor(),
138 | torchvision.transforms.Normalize([0.5], [0.5]),
139 | ]
140 | ),
141 | ),
142 | batch_size=batch_size,
143 | shuffle=True,
144 | )
145 |
146 | generator = Generator(shape=shape, latent_dim=latent_dim)
147 | discriminator = Discriminator(shape=shape)
148 | adversarial_loss = torch.nn.BCELoss()
149 | topo_loss = TopologicalAdversarialLoss()
150 |
151 | opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
152 | opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
153 |
154 | for epoch in range(n_epochs):
155 | for batch, (imgs, _) in tqdm(enumerate(data_loader), desc="Batch"):
156 |             # NB: a plain tensor suffices; `Variable` is deprecated.
157 |             z = torch.as_tensor(
158 |                 np.random.normal(0, 1, (imgs.shape[0], latent_dim)),
159 |                 dtype=torch.float,
160 |             )
161 |
162 | real_labels = torch.as_tensor([1.0] * len(imgs)).view(-1, 1)
163 | fake_labels = torch.as_tensor([0.0] * len(imgs)).view(-1, 1)
164 |
165 | opt_g.zero_grad()
166 |
167 | imgs_synthetic = generator(z)
168 |
169 | generator_loss = adversarial_loss(
170 | discriminator(imgs_synthetic), real_labels
171 | ) + 0.01 * topo_loss(imgs, imgs_synthetic)
172 |
173 | generator_loss.backward()
174 |
175 | opt_g.step()
176 |
177 | opt_d.zero_grad()
178 |
179 | real_loss = adversarial_loss(discriminator(imgs), real_labels)
180 | fake_loss = adversarial_loss(
181 | discriminator(imgs_synthetic).detach(), fake_labels
182 | )
183 |
184 | discriminator_loss = 0.5 * (real_loss + fake_loss)
185 | discriminator_loss.backward()
186 |
187 | opt_d.step()
188 |
189 | output = imgs_synthetic.detach()
190 | grid = torchvision.utils.make_grid(output)
191 |
192 | show(grid)
193 |
194 | plt.show()
195 |
--------------------------------------------------------------------------------
/torch_topological/examples/image_smoothing.py:
--------------------------------------------------------------------------------
1 | """Demo for image smoothing based on topology.
2 |
3 | This is a work in progress, which at the moment merely showcases
4 | a simple topology-based loss applied to an image. There's *much*
5 | more to be done here.
6 | """
7 |
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 |
11 | from torch_topological.nn import CubicalComplex
12 | from torch_topological.nn import SummaryStatisticLoss
13 |
14 | from sklearn.datasets import make_circles
15 |
16 | import torch
17 |
18 |
19 | def _make_data(n_cells, n_samples=1000):
20 | X = make_circles(n_samples, shuffle=True, noise=0.05)[0]
21 |
22 | heatmap, *_ = np.histogram2d(X[:, 0], X[:, 1], bins=n_cells)
23 | heatmap -= heatmap.mean()
24 | heatmap /= heatmap.max()
25 |
26 | return heatmap
27 |
28 |
29 | class TopologicalSimplification(torch.nn.Module):
30 | def __init__(self, theta):
31 | super().__init__()
32 |
33 | self.theta = theta
34 |
35 | def forward(self, x):
36 | persistence_information = cubical(x)
37 | persistence_information = persistence_information[0]
38 |
39 | gens, pd = (
40 | persistence_information.pairing,
41 | persistence_information.diagram,
42 | )
43 |
44 | persistence = (pd[:, 1] - pd[:, 0]).abs()
45 | indices = persistence <= self.theta
46 |
47 | gens = gens[indices]
48 |
49 | indices = torch.vstack((gens[:, 0:2], gens[:, 2:]))
50 |
51 | indices = np.ravel_multi_index((indices[:, 0], indices[:, 1]), x.shape)
52 |
53 | x.ravel()[indices] = 0.0
54 |
55 | persistence_information = cubical(x)
56 | persistence_information = [persistence_information[0]]
57 |
58 | return x, persistence_information
59 |
60 |
61 | if __name__ == "__main__":
62 |
63 | np.random.seed(23)
64 |
65 | Y = _make_data(50)
66 | Y = torch.as_tensor(Y, dtype=torch.float)
67 | X = torch.as_tensor(
68 | Y + np.random.normal(scale=0.05, size=Y.shape), dtype=torch.float
69 | )
70 |
71 | theta = torch.nn.Parameter(
72 | torch.as_tensor(1.0),
73 | requires_grad=True,
74 | )
75 |
76 | topological_simplification = TopologicalSimplification(theta)
77 |
78 | optimizer = torch.optim.Adam([theta], lr=1e-2)
79 | loss_fn = SummaryStatisticLoss("total_persistence", p=1)
80 |
81 | cubical = CubicalComplex()
82 |
83 | persistence_information_target = cubical(Y)
84 | persistence_information_target = [persistence_information_target[0]]
85 |
86 | for i in range(500):
87 | X, persistence_information = topological_simplification(X)
88 |
89 | optimizer.zero_grad()
90 |
91 | loss = loss_fn(persistence_information, persistence_information_target)
92 |
93 | print(loss.item(), theta.item())
94 |
95 | theta.backward()
96 | optimizer.step()
97 |
98 | X = X.detach().numpy()
99 |
100 | plt.imshow(X)
101 | plt.show()
102 |
--------------------------------------------------------------------------------
/torch_topological/examples/summary_statistics.py:
--------------------------------------------------------------------------------
1 | """Demo for summary statistics minimisation of a point cloud.
2 |
3 | This example demonstrates how to use various topological summary
4 | statistics in order to change the shape of an input point cloud.
5 | The script can either demonstrate how to match the shape of one
6 | point cloud to another, using a summary statistic as a loss
7 | function, or how to change the shape of a *single* point cloud.
8 | By default, two point clouds will be used.
9 | """
10 |
11 | import argparse
12 |
13 | import matplotlib.pyplot as plt
14 |
15 | from torch_topological.data import sample_from_disk
16 | from torch_topological.data import sample_from_unit_cube
17 |
18 | from torch_topological.nn import SummaryStatisticLoss
19 | from torch_topological.nn import VietorisRipsComplex
20 |
21 | from tqdm import tqdm
22 |
23 | import torch
24 | import torch.optim as optim
25 |
26 |
27 | def create_data_set(args):
28 | """Create data set based on user-provided arguments."""
29 | n = args.n_samples
30 | if args.single:
31 | X = sample_from_unit_cube(n=n, d=2)
32 | Y = X.clone()
33 | else:
34 | X = sample_from_disk(n=n, r=0.5, R=0.6)
35 | Y = sample_from_disk(n=n, r=0.9, R=1.0)
36 |
37 | # Make source point cloud adjustable by treating it as a parameter.
38 | # This enables topological loss functions to influence the shape of
39 | # `X`.
40 | X = torch.nn.Parameter(torch.as_tensor(X), requires_grad=True)
41 | return X, Y
42 |
43 |
44 | def main(args):
45 | """Run example."""
46 | n_iterations = args.n_iterations
47 | statistic = args.statistic
48 | p = args.p
49 | q = args.q
50 |
51 | X, Y = create_data_set(args)
52 | vr = VietorisRipsComplex(dim=2)
53 |
54 | if not args.single:
55 | pi_target = vr(Y)
56 |
57 | loss_fn = SummaryStatisticLoss(
58 | summary_statistic=statistic,
59 | p=p,
60 | q=q
61 | )
62 |
63 | opt = optim.SGD([X], lr=0.05)
64 | progress = tqdm(range(n_iterations))
65 |
66 | for i in progress:
67 | pi_source = vr(X)
68 |
69 | if not args.single:
70 | loss = loss_fn(pi_source, pi_target)
71 | else:
72 | loss = loss_fn(pi_source)
73 |
74 | opt.zero_grad()
75 | loss.backward()
76 | opt.step()
77 |
78 | progress.set_postfix(loss=f'{loss.item():.08f}')
79 |
80 | X = X.detach().numpy()
81 |
82 | if args.single:
83 | plt.scatter(X[:, 0], X[:, 1], label='Result')
84 | plt.scatter(Y[:, 0], Y[:, 1], label='Initial')
85 | else:
86 | plt.scatter(X[:, 0], X[:, 1], label='Source')
87 | plt.scatter(Y[:, 0], Y[:, 1], label='Target')
88 |
89 | plt.legend()
90 | plt.show()
91 |
92 |
93 | if __name__ == '__main__':
94 | parser = argparse.ArgumentParser()
95 |
96 | parser.add_argument(
97 | '-i', '--n-iterations',
98 | default=250,
99 | type=int,
100 | help='Number of iterations'
101 | )
102 |
103 | parser.add_argument(
104 | '-n', '--n-samples',
105 | default=100,
106 | type=int,
107 | help='Number of samples in point clouds'
108 | )
109 |
110 | parser.add_argument(
111 | '-s', '--statistic',
112 | choices=[
113 | 'persistent_entropy',
114 | 'polynomial_function',
115 | 'total_persistence',
116 | ],
117 | default='polynomial_function',
118 | help='Name of summary statistic to use for the loss'
119 | )
120 |
121 | parser.add_argument(
122 | '-S', '--single',
123 | action='store_true',
124 | help='If set, uses only a single point cloud'
125 | )
126 |
127 | parser.add_argument(
128 | '-p',
129 | type=float,
130 | default=2.0,
131 | help='Outer exponent for summary statistic loss calculation'
132 | )
133 |
134 | parser.add_argument(
135 | '-q',
136 | type=float,
137 | default=2.0,
138 | help='Inner exponent for summary statistic loss calculation. Will '
139 | 'only be used for certain summary statistics.'
140 | )
141 |
142 | args = parser.parse_args()
143 | main(args)
144 |
--------------------------------------------------------------------------------
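Example invocations of the script above (the flags are defined by the
argument parser at the bottom of the file)::

    # Two point clouds, polynomial function statistic (the default):
    python summary_statistics.py

    # Single point cloud, total persistence, 100 iterations:
    python summary_statistics.py --single -s total_persistence -i 100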
/torch_topological/examples/weighted_euler_characteristic_transform.py:
--------------------------------------------------------------------------------
1 | """Demo for using the Weighted Euler Characteristic Transform (WECT)
2 | in an optimisation routine.
3 |
4 | This example demonstrates how the WECT can be used to optimise
5 | a neural network's predictions to match the topological signature
6 | of a target.
7 | """
8 | from torch import nn
9 | import torch
10 | from torch_topological.nn import EulerDistance, WeightedEulerCurve
11 | import torch.optim as optim
12 |
13 |
14 | class NN(nn.Module):
15 | def __init__(self, inp_dim, hidden_dim, out_dim):
16 | super(NN, self).__init__()
17 | self.fc1 = torch.nn.Linear(inp_dim, hidden_dim)
18 | self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
19 | self.fc3 = torch.nn.Linear(hidden_dim, out_dim)
20 | self.out_dim = out_dim
21 |
22 | def forward(self, x_):
23 | x = x_.clone()
24 | x = torch.nn.functional.relu(self.fc1(x))
25 | x = torch.nn.functional.relu(self.fc2(x))
26 | x = self.fc3(x)
27 |         x = torch.sigmoid(x)
28 | out = int(self.out_dim ** (1 / 3))
29 | return x.reshape([out, out, out])
30 |
31 |
32 | if __name__ == "__main__":
33 | torch.manual_seed(4)
34 | z = 3
35 | arr = torch.ones([z, z, z], requires_grad=False)
36 | model = NN(z * z * z, 100, z * z * z)
37 | arr2 = torch.rand([z, z, z], requires_grad=False)
38 | arr2[arr2 > 0.5] = 1
39 | arr2[arr2 <= 0.5] = 0
40 | ec = WeightedEulerCurve(prod=True)
41 | dist = EulerDistance()
42 | optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
43 | ans = 100
44 | while ans > 0.05:
45 | optimizer.zero_grad()
46 | ans = dist(ec(model(arr.flatten())), ec(arr2))
47 | ans.backward()
48 | optimizer.step()
49 | with torch.no_grad():
50 | print(
51 | "L2 distance:",
52 | dist(model(arr.flatten()), arr2),
53 | " Euler Distance:",
54 | ans,
55 | )
56 |
--------------------------------------------------------------------------------
/torch_topological/nn/__init__.py:
--------------------------------------------------------------------------------
1 | """Layers and loss terms for persistence-based optimisation."""
2 |
3 | from .data import PersistenceInformation
4 |
5 | from .distances import WassersteinDistance
6 | from .sliced_wasserstein_distance import SlicedWassersteinDistance
7 | from .sliced_wasserstein_kernel import SlicedWassersteinKernel
8 | from .multi_scale_kernel import MultiScaleKernel
9 |
10 | from .loss import SignatureLoss
11 | from .loss import SummaryStatisticLoss
12 |
13 | from .alpha_complex import AlphaComplex
14 | from .cubical_complex import CubicalComplex
15 | from .vietoris_rips_complex import VietorisRipsComplex
16 |
17 | from .weighted_euler_characteristic_transform import WeightedEulerCurve
18 | from .weighted_euler_characteristic_transform import EulerDistance
19 |
20 |
21 | __all__ = [
22 | "AlphaComplex",
23 | "CubicalComplex",
24 | "EulerDistance",
25 | "MultiScaleKernel",
26 | "PersistenceInformation",
27 | "SignatureLoss",
28 | "SlicedWassersteinDistance",
29 | "SlicedWassersteinKernel",
30 | "SummaryStatisticLoss",
31 | "VietorisRipsComplex",
32 | "WassersteinDistance",
33 | "WeightedEulerCurve"
34 | ]
35 |
--------------------------------------------------------------------------------
/torch_topological/nn/alpha_complex.py:
--------------------------------------------------------------------------------
1 | """Alpha complex calculation module(s)."""
2 |
3 | from torch import nn
4 |
5 | from torch_topological.nn import PersistenceInformation
6 | from torch_topological.nn.data import batch_handler
7 |
8 | import gudhi
9 | import itertools
10 | import torch
11 |
12 |
13 | class AlphaComplex(nn.Module):
14 | """Calculate persistence diagrams of an alpha complex.
15 |
16 | This module calculates persistence diagrams of an alpha complex,
17 | i.e. a subcomplex of the Delaunay triangulation, which is sparse
18 |     and thus often substantially smaller than other complexes.
19 |
20 | It was first described in [Edelsbrunner94]_ and is particularly
21 | useful when analysing low-dimensional data.
22 |
23 | Notes
24 | -----
25 | At the moment, this alpha complex implementation, following other
26 | implementations, provides *distance-based filtrations* only. This
27 | means that the resulting persistence diagrams do *not* correspond
28 | to the circumradius of a simplex.
29 |
30 | In addition, this implementation is **work in progress**. Some of
31 | the core features, such as handling of infinite features, are not
32 | available at the moment.
33 |
34 | References
35 | ----------
36 | .. [Edelsbrunner94] H. Edelsbrunner and E.P. Mücke,
37 | "Three-dimensional alpha shapes", *ACM Transactions on Graphics*,
38 | Volume 13, Number 1, pp. 43--72, 1994.
39 | """
40 |
41 | def __init__(self, p=2):
42 | """Initialise new alpha complex calculation module.
43 |
44 | Parameters
45 | ----------
46 | p : float
47 | Exponent for the `p`-norm calculation of distances.
48 |
49 | Notes
50 | -----
51 | This module currently only supports Minkowski norms. It does not
52 | yet support other metrics.
53 | """
54 | super().__init__()
55 |
56 | self.p = p
57 |
58 | def forward(self, x):
59 | """Implement forward pass for persistence diagram calculation.
60 |
61 | The forward pass entails calculating persistent homology on
62 | a point cloud and returning a set of persistence diagrams.
63 |
64 | Parameters
65 | ----------
66 | x : array_like
67 | Input point cloud(s). `x` can either be a 2D array of shape
68 | `(n, d)`, which is treated as a single point cloud, or a 3D
69 | array/tensor of the form `(b, n, d)`, with `b` representing
70 | the batch size. Alternatively, you may also specify a list,
71 | possibly containing point clouds of non-uniform sizes.
72 |
73 | Returns
74 | -------
75 | list of :class:`PersistenceInformation`
76 | List of :class:`PersistenceInformation`, containing both the
77 | persistence diagrams and the generators, i.e. the
78 | *pairings*, of a certain dimension of topological features.
79 | If `x` is a 3D array, returns a list of lists, in which the
80 | first dimension denotes the batch and the second dimension
81 | refers to the individual instances of
82 | :class:`PersistenceInformation` elements.
83 |
84 | Generators will be represented in the persistence pairing
85 | based on proper creator--destroyer pairs of simplices. In
86 | dimension `k`, for instance, every generator is stored as
87 | a `k`-simplex followed by a `k+1` simplex.
88 | """
89 | return batch_handler(x, self._forward)
90 |
91 | def _forward(self, x):
92 | alpha_complex = gudhi.AlphaComplex(
93 | x.cpu().detach(),
94 | precision='fast',
95 | )
96 |
97 | st = alpha_complex.create_simplex_tree()
98 | st.persistence()
99 | persistence_pairs = st.persistence_pairs()
100 |
101 | max_dim = x.shape[-1]
102 | dist = torch.cdist(x.contiguous(), x.contiguous(), p=self.p)
103 |
104 | return [
105 | self._extract_generators_and_diagrams(
106 | dist,
107 | persistence_pairs,
108 | dim
109 | )
110 | for dim in range(0, max_dim)
111 | ]
112 |
113 | def _extract_generators_and_diagrams(self, dist, persistence_pairs, dim):
114 | pairs = [
115 | torch.cat((torch.as_tensor(p[0]), torch.as_tensor(p[1])), 0)
116 | for p in persistence_pairs if len(p[0]) == (dim + 1)
117 | ]
118 |
119 | # TODO: Ignore infinite features for now. Will require different
120 | # handling in the future.
121 | pairs = [p for p in pairs if len(p) == 2 * dim + 3]
122 |
123 | if not pairs:
124 | return PersistenceInformation(
125 | pairing=[],
126 | diagram=[],
127 | dimension=dim
128 | )
129 |
130 | # Create tensor of shape `(n, 2 * dim + 3)`, with `n` being the
131 | # number of finite persistence pairs.
132 | pairs = torch.stack(pairs)
133 |
134 | # We have to branch here because the creation of
135 | # zero-dimensional persistence diagrams is easy,
136 | # whereas higher-dimensional diagrams require an
137 | # involved lookup strategy.
138 | if dim == 0:
139 | creators = torch.zeros_like(
140 | torch.as_tensor(pairs)[:, 0],
141 | device=dist.device
142 | )
143 |
144 | # Iterate over the flag complex in order to get (a) the distance
145 | # of the creator simplex, and (b) the distance of the destroyer.
146 | # We *cannot* look this information up in the filtration itself,
147 | # because we have to use gradient-imbued information such as the
148 | # set of pairwise distances.
149 | else:
150 | creators = torch.stack(
151 | [
152 | self._get_filtration_weight(creator, dist)
153 | for creator in pairs[:, :dim+1]
154 | ]
155 | )
156 |
157 | # For the destroyers, we can always rely on the same
158 | # construction, regardless of dimensionality.
159 | destroyers = torch.stack(
160 | [
161 | self._get_filtration_weight(destroyer, dist)
162 | for destroyer in pairs[:, dim+1:]
163 | ]
164 | )
165 |
166 | # Create the persistence diagram from creator and destroyer
167 | # information. This step is the same for all dimensions.
168 | persistence_diagram = torch.stack(
169 | (creators, destroyers), 1
170 | )
171 |
172 | return PersistenceInformation(
173 | pairing=pairs,
174 | diagram=persistence_diagram,
175 | dimension=dim
176 | )
177 |
178 | def _get_filtration_weight(self, simplex, dist):
179 | """Auxiliary function for querying simplex weights.
180 |
181 | This function returns the filtration weight of an arbitrary
182 | simplex under a distance-based filtration, i.e. the maximum
183 | weight of its cofaces. This function is crucial for getting
184 | persistence diagrams that are differentiable.
185 |
186 | Parameters
187 | ----------
188 | simplex : torch.tensor
189 | Simplex to query; must be a sequence of numbers, i.e. of
190 | shape `(n, 1)` or of shape `(n, )`.
191 |
192 | dist : torch.tensor
193 | Matrix of pairwise distances between points.
194 |
195 | Returns
196 | -------
197 | torch.tensor
198 | Scalar tensor containing the filtration weight.
199 | """
200 | weights = torch.stack([
201 | dist[edge] for edge in itertools.combinations(simplex, 2)
202 | ])
203 |
204 | # TODO: Might have to be adjusted depending on filtration
205 | # ordering?
206 | return torch.max(weights)
207 |
--------------------------------------------------------------------------------
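A minimal sketch (not part of the repository) of the forward pass of
`AlphaComplex`: for a 2D point cloud, the module returns one
`PersistenceInformation` object per feature dimension below the point-cloud
dimension::

    import torch

    from torch_topological.nn import AlphaComplex

    X = torch.rand(100, 2)
    alpha_complex = AlphaComplex()

    for pers_info in alpha_complex(X):
        print(pers_info.dimension, len(pers_info.diagram))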
/torch_topological/nn/cubical_complex.py:
--------------------------------------------------------------------------------
1 | """Cubical complex calculation module."""
2 |
3 | from torch import nn
4 |
5 | from torch_topological.nn import PersistenceInformation
6 |
7 | import gudhi
8 | import torch
9 |
10 | import numpy as np
11 |
12 |
13 | class CubicalComplex(nn.Module):
14 | """Calculate cubical complex persistence diagrams.
15 |
16 | This module calculates 'differentiable' persistence diagrams for
17 | structured data, such as images. This is achieved by calculating
18 | a *cubical complex*.
19 |
20 | Cubical complexes are the natural choice for calculating topological
21 | features of highly-structured inputs. See [Rieck20a]_ for an example
22 | of how to apply such topological features in practice.
23 |
24 | References
25 | ----------
26 | .. [Rieck20a] B. Rieck et al., "Uncovering the Topology of
27 |     Time-Varying fMRI Data Using Cubical Persistence", *Advances in
28 | Neural Information Processing Systems 33*, pp. 6900--6912, 2020.
29 | """
30 |
31 | def __init__(self, superlevel=False, dim=None):
32 | """Initialise new module.
33 |
34 | Parameters
35 | ----------
36 | superlevel : bool
37 | Indicates whether to calculate topological features based on
38 | superlevel sets. By default, *sublevel set filtrations* are
39 | used.
40 |
41 | dim : int or `None`
42 | If set, describes dimension of input data. This is meant to
43 | be the dimension of an individual image **without** channel
44 | information, if any. The value of `dim` will change the way
45 | an input tensor is being handled: additional dimensions, if
46 | present, will be treated as batches or channels. If not set
47 | to an integer value, :func:`forward` will just *guess* what
48 | to do with an input (which should work in most cases).
49 |
50 | For example, when dealing with volume data, i.e. 3D tensors,
51 | set `dim=3` when instantiating the class. This will permit a
52 | seamless user experience with *both* batched and non-batched
53 | input data sets.
54 | """
55 | super().__init__()
56 |
57 | # TODO: This is handled somewhat inelegantly below. Might be
58 | # smarter to update.
59 | self.superlevel = superlevel
60 | self.dim = dim
61 |
62 | def forward(self, x):
63 | """Implement forward pass for persistence diagram calculation.
64 |
65 | The forward pass entails calculating persistent homology on a
66 | cubical complex and returning a set of persistence diagrams.
67 | The way the input will be interpreted depends on the presence
68 | of the `dim` attribute of this class. If `dim` is set, the
69 | *last* `dim` dimensions of an input tensor will be considered to
70 | contain the image data. If `dim` is not set, image dimensions
71 | will be guessed as follows:
72 |
73 | 1. Tensor of dimension 2: a single image
74 | 2. Tensor of dimension 3: a single 2D image with channels
75 | 3. Tensor of dimension 4: a batch of 2D images with channels
76 |
77 |         This is a conservative way of handling the data, ensuring that
78 |         by default, 2D images with channel information and a potential
79 |         batch dimension can be handled, since this is the most common
80 |         case in applications.
81 |
82 | To ensure that the class can handle e.g. 3D volume data, it is
83 | sufficient to set `dim = 3` when initialising the class. Refer
84 | to the examples and parameters sections for more details.
85 |
86 | Parameters
87 | ----------
88 | x : array_like
89 | Input image(s). If `dim` has not been set, will *guess* how
90 | to handle the input as follows: `x` can either be a 2D array
91 | of shape `(H, W)`, which is treated as a single image, or
92 | a 3D array/tensor of the form `(C, H, W)`, with `C`
93 | representing the number of channels, or a 4D array/tensor of
94 | the form `(B, C, H, W)`, with `B` being the batch size. If
95 | `dim` has been set, the same handling strategy applies, but
96 | the *last* `dim` dimensions of the tensor are being used for
97 | the cubical complex calculation. All subsequent dimensions
98 | will be assumed to represent batches or channels (in this
99 | order). Hence, if `dim` is set, the tensor must at most have
100 | `dim + 2` dimensions.
101 |
102 | Returns
103 | -------
104 | list of :class:`PersistenceInformation`
105 | List of :class:`PersistenceInformation`, containing both the
106 | persistence diagrams and the generators, i.e. the
107 | *pairings*, of a certain dimension of topological features.
108 |             If `x` is a 3D array, returns a list of lists, in which the
109 |             first dimension denotes the channel (or batch) and the second
110 |             dimension refers to the individual instances of
111 |             :class:`PersistenceInformation` elements. Similarly for
112 |             higher-order tensors.
113 |
114 | Examples
115 | --------
116 |         >>> # Handling 3D tensors (volumes), either in batches or
117 |         >>> # presented individually to the function.
118 |         >>> cubical_complex = CubicalComplex(dim=3)
119 |         >>> cubical_complex(x)
120 | """
121 | # Dimension was provided; this makes calculating the *effective*
122 | # dimension of the tensor much easier: take everything but the
123 | # last `self.dim` dimensions.
124 | if self.dim is not None:
125 | shape = x.shape[:-self.dim]
126 | dims = len(shape)
127 |
128 | # No dimension was provided; just use the shape provided by the
129 | # client.
130 | else:
131 | dims = len(x.shape) - 2
132 |
133 | # No additional dimensions present: a single image
134 | if dims == 0:
135 | return self._forward(x)
136 |
137 | # Handle image with channels, such as a tensor of the form `(C, H, W)`
138 | elif dims == 1:
139 | return [
140 | self._forward(x_) for x_ in x
141 | ]
142 |
143 |         # Handle image with channels and batch index, such as a tensor
144 |         # of the form `(B, C, H, W)`. Other shapes raise an error.
145 |         elif dims == 2:
146 |             return [[self._forward(x__) for x__ in x_] for x_ in x]
147 |         else:
148 |             raise ValueError("Unsupported number of input dimensions")
149 |
150 | def _forward(self, x):
151 | """Handle a single-channel image.
152 |
153 | This internal function handles the calculation of topological
154 | features for a single-channel image, i.e. an `array_like`.
155 |
156 | Parameters
157 | ----------
158 | x : array_like of shape `(d_1, d_2, ..., d_d)`
159 | Single-channel input image of arbitrary dimensions. Batch
160 |             dimensions and channel dimensions have to be handled by
161 | the calling function explicitly. This function interprets
162 | its input as a high-dimensional image.
163 |
164 | Returns
165 | -------
166 |         list of :class:`PersistenceInformation`
167 | List of persistence information data structures, containing
168 | the persistence diagram and the persistence pairing of some
169 | dimension in the input data set.
170 | """
171 | if self.superlevel:
172 | x = -x
173 |
174 | cubical_complex = gudhi.CubicalComplex(
175 | dimensions=x.shape,
176 | top_dimensional_cells=x.flatten()
177 | )
178 |
179 | # We need the persistence pairs first, even though we are *not*
180 | # using them directly here.
181 | cubical_complex.persistence()
182 | cofaces = cubical_complex.cofaces_of_persistence_pairs()
183 |
184 | max_dim = len(x.shape)
185 |
186 | # TODO: Make this configurable; is it possible that users only
187 | # want to return a *part* of the data?
188 | persistence_information = [
189 | self._extract_generators_and_diagrams(
190 | x,
191 | cofaces,
192 | dim
193 | ) for dim in range(0, max_dim)
194 | ]
195 |
196 | return persistence_information
197 |
198 | def _extract_generators_and_diagrams(self, x, cofaces, dim):
199 | pairs = torch.empty((0, 2), dtype=torch.long)
200 |
201 | try:
202 | regular_pairs = torch.as_tensor(
203 | cofaces[0][dim], dtype=torch.long
204 | )
205 | pairs = torch.cat(
206 | (pairs, regular_pairs)
207 | )
208 | except IndexError:
209 | pass
210 |
211 | try:
212 | infinite_pairs = torch.as_tensor(
213 | cofaces[1][dim], dtype=torch.long
214 | )
215 | except IndexError:
216 | infinite_pairs = None
217 |
218 | if infinite_pairs is not None:
219 | # 'Pair off' all the indices
220 | max_index = torch.argmax(x)
221 | fake_destroyers = torch.empty_like(infinite_pairs).fill_(max_index)
222 |
223 | infinite_pairs = torch.stack(
224 | (infinite_pairs, fake_destroyers), 1
225 | )
226 |
227 | pairs = torch.cat(
228 | (pairs, infinite_pairs)
229 | )
230 |
231 | return self._create_tensors_from_pairs(x, pairs, dim)
232 |
233 | # Internal utility function to handle the 'heavy lifting:'
234 | # creates tensors from sets of persistence pairs.
235 | def _create_tensors_from_pairs(self, x, pairs, dim):
236 |
237 | xs = x.shape
238 |
239 | # Notice that `creators` and `destroyers` refer to pixel
240 | # coordinates in the image.
241 | creators = torch.as_tensor(
242 | np.column_stack(
243 | np.unravel_index(pairs[:, 0], xs)
244 | ),
245 | dtype=torch.long
246 | )
247 | destroyers = torch.as_tensor(
248 | np.column_stack(
249 | np.unravel_index(pairs[:, 1], xs)
250 | ),
251 | dtype=torch.long
252 | )
253 |         gens = torch.hstack((creators, destroyers))
254 |
255 | # TODO: Most efficient way to generate diagram again?
256 | persistence_diagram = torch.stack((
257 | x.ravel()[pairs[:, 0]],
258 | x.ravel()[pairs[:, 1]]
259 | ), 1)
260 |
261 | return PersistenceInformation(
262 | pairing=gens,
263 | diagram=persistence_diagram,
264 | dimension=dim
265 | )
266 |
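267 | # A minimal usage sketch, assuming a batch of 2D images with a channel
268 | # dimension; shapes follow the `forward` documentation above. All the
269 | # names are illustrative only.
270 | #
271 | # >>> import torch
272 | # >>> x = torch.rand(8, 1, 28, 28)      # batch of shape (B, C, H, W)
273 | # >>> cubical_complex = CubicalComplex()
274 | # >>> pers_info = cubical_complex(x)    # list of lists (B x C)
275 | # >>> pers_info[0][0][0].diagram        # dimension-0 diagram
276 |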
--------------------------------------------------------------------------------
/torch_topological/nn/data.py:
--------------------------------------------------------------------------------
1 | """Data structures for persistent homology calculations."""
2 |
3 | from collections import namedtuple
4 |
5 | from itertools import chain
6 |
7 | from operator import itemgetter
8 |
9 | from torch_topological.utils import nesting_level
10 |
11 | import torch
12 |
13 |
14 | class PersistenceInformation(
15 | namedtuple(
16 | "PersistenceInformation",
17 | [
18 | "pairing",
19 | "diagram",
20 | "dimension",
21 | ],
22 | # Ensures that there is always a dimension specified, albeit an
23 | # 'incorrect' one.
24 | defaults=[None],
25 | )
26 | ):
27 | """Persistence information data structure.
28 |
29 | This is a light-weight data structure for carrying information about
30 | the calculation of persistent homology. It consists of the following
31 | components:
32 |
33 | - A *persistence pairing*
34 | - A *persistence diagram*
35 | - An (optional) *dimension*
36 |
37 | Due to its lightweight nature, no validity checks are performed, but
38 | all calculation modules should return a sequence of instances of the
39 | :class:`PersistenceInformation` class.
40 |
41 | Since this data class is shared with modules that are capable of
42 | calculating persistent homology, the exact form of the persistence
43 | pairing might change. Please refer to the respective classes for
44 | more documentation.
45 | """
46 |
47 | __slots__ = ()
48 |
49 | # Disable iterating over the class since it collates heterogeneous
50 | # information and should rather be treated as a building block.
51 | __iter__ = None
52 |
53 |
54 | def make_tensor(x):
55 | """Create dense tensor representation from sparse inputs.
56 |
57 | This function turns sparse inputs of :class:`PersistenceInformation`
58 | objects into 'dense' tensor representations, thus providing a useful
59 | integration into differentiable layers.
60 |
61 |     The dimension of the resulting tensor depends on the maximum number
62 |     of topological features, summed over all dimensions in the data.
63 |     This is similar to the format used by `giotto-ph`.
64 |
65 | Parameters
66 | ----------
67 | x : list of (list of ...) :class:`PersistenceInformation`
68 | Input, consisting of a (potentially) nested list of
69 | :class:`PersistenceInformation` objects as obtained
70 | from a persistent homology calculation module, such
71 | as :class:`VietorisRipsComplex`.
72 |
73 | Returns
74 | -------
75 | torch.tensor
76 | Dense tensor representation of `x`. The output is best
77 | understood by considering some examples: given a batch
78 | obtained from :class:`VietorisRipsComplex`, our tensor
79 | will have shape `(B, N, 3)`. `B` is the batch size and
80 | `N` is the sum of maximum lengths of diagrams relative
81 | to this batch. Each entry will consist of a creator, a
82 | destroyer, and a dimension. Dummy entries, used to pad
83 | the batch, can be detected as `torch.nan`.
84 | """
85 | level = nesting_level(x)
86 |
87 | # Internal utility function for calculating the length of the output
88 | # tensor. This is required to ensure that all inputs can be *merged*
89 | # into a single output tensor.
90 | def _calculate_length(x, level):
91 |
92 | # Simple base case; should never occur in practice but let's be
93 | # consistent here.
94 | if len(x) == 0:
95 | return 0
96 |
97 | # Each `chain.from_iterable()` removes an additional layer of
98 | # nesting. We only have to start from level 2; we get level 1
99 | # for free because we can always iterate over a list.
100 | for i in range(2, level + 1):
101 | x = chain.from_iterable(x)
102 |
103 | # Collect information that we need to create the full tensor. An
104 | # entry of the resulting list contains the length of the diagram
105 | # and the dimension, making it possible to derive padding values
106 | # for all entries.
107 | M = list(map(lambda a: (len(a.diagram), a.dimension), x))
108 |
109 | # Get maximum dimension
110 | dim = max(M, key=itemgetter(1))[1]
111 |
112 | # Get *sum* of maximum number of entries for each dimension.
113 | # This is calculated over all batches.
114 | N = sum(
115 | [
116 | max([L for L in M if L[1] == d], key=itemgetter(0))[0]
117 | for d in range(dim + 1)
118 | ]
119 | )
120 |
121 | return N
122 |
123 | # Auxiliary function for padding tensors with `torch.nan` to
124 | # a specific dimension. Will always return a `list`; we turn
125 | # it into a tensor depending on the call level.
126 | def _pad_tensors(tensors, N, value=torch.nan):
127 | return list(
128 | map(
129 | lambda t: torch.nn.functional.pad(
130 | t, (0, 0, N - len(t), 0), mode="constant", value=value
131 | ),
132 | tensors,
133 | )
134 | )
135 |
136 | N = _calculate_length(x, level)
137 |
138 | # List of lists: the first axis is treated as the batch axis, while
139 | # the second axis is treated as the dimension of diagrams or pairs.
140 | # This also handles ordinary lists, which will result in a batch of
141 | # size 1.
142 | if level <= 2:
143 | tensors = [
144 | make_tensor_from_persistence_information(pers_infos)
145 | for pers_infos in x
146 | ]
147 |
148 | # Pad all tensors to length N in the first dimension, then turn
149 | # them into a batch.
150 | result = torch.stack(_pad_tensors(tensors, N))
151 | return result
152 |
153 | # List of lists of lists: this indicates image-based data, where we
154 | # also have a set of tensors for each channel. The internal layout,
155 | # i.e. our input, has the following structure:
156 | #
157 | # B x C x D
158 | #
159 | # Each variable being the length of the respective list. We want an
160 | # output of the following shape:
161 | #
162 | # B x C x N x 3
163 | #
164 | # Here, `N` is the maximum length of an individual persistence
165 | # information object.
166 | else:
167 | tensors = [
168 | [
169 | make_tensor_from_persistence_information(pers_infos)
170 | for pers_infos in batch
171 | ]
172 | for batch in x
173 | ]
174 |
175 | # Pad all tensors to length N in the first dimension, then turn
176 | # them into a batch. We first stack over channels (inner), then
177 | # over the batch (outer).
178 | result = torch.stack(
179 | [
180 | torch.stack(_pad_tensors(batch_tensors, N))
181 | for batch_tensors in tensors
182 | ]
183 | )
184 |
185 | return result
186 |
187 |
188 | def make_tensor_from_persistence_information(
189 | pers_info, extract_generators=False
190 | ):
191 | """Convert (sequence) of persistence information entries to tensor.
192 |
193 | This function converts instance(s) of :class:`PersistenceInformation`
194 | objects into a single tensor. No padding will be performed. A client
195 | may specify what type of information to extract from the object. For
196 | instance, by default, the function will extract persistence diagrams
197 | but this behaviour can be changed by setting `extract_generators` to
198 |     `True`.
199 |
200 | Parameters
201 | ----------
202 | pers_info : :class:`PersistenceInformation` or iterable thereof
203 | Input persistence information object(s). The function is able to
204 | handle both single objects and sequences. This has no bearing on
205 | the length of the returned tensor.
206 |
207 | extract_generators : bool
208 | If set, extracts generators instead of persistence diagram from
209 | `pers_info`.
210 |
211 | Returns
212 | -------
213 | torch.tensor
214 | Tensor of shape `(n, 3)`, where `n` is the sum of all features,
215 | over all dimensions in the input `pers_info`. Each triple shall
216 | be of the form `(creation, destruction, dim)` for a persistence
217 | diagram. If the client requested generators to be returned, the
218 | first two entries of the triple refer to *indices* with respect
219 | to the input data set. Depending on the algorithm employed, the
220 | meaning of these indices can change. Please refer to the module
221 | used to calculate persistent homology for more details.
222 | """
223 | # Looks a little bit cumbersome, but since `namedtuple` is iterable
224 | # as well, we need to ensure that we are actually dealing with more
225 | # than one instance here.
226 | if len(pers_info) > 1 and not isinstance(
227 | pers_info[0], PersistenceInformation
228 | ):
229 | pers_info = [pers_info]
230 |
231 | # TODO: This might not always work since the size of generators
232 | # changes in different dimensions.
233 | if extract_generators:
234 | pairs = torch.cat(
235 | [torch.as_tensor(x.pairing, dtype=torch.float) for x in pers_info],
236 | ).long()
237 | else:
238 | pairs = torch.cat(
239 | [torch.as_tensor(x.diagram, dtype=torch.float) for x in pers_info],
240 | ).float()
241 |
242 | dimensions = torch.cat(
243 | [
244 | torch.as_tensor(
245 | [x.dimension] * len(x.diagram),
246 | dtype=torch.long,
247 | device=x.diagram.device,
248 | )
249 | for x in pers_info
250 | ]
251 | )
252 |
253 | result = torch.column_stack((pairs, dimensions))
254 | return result
255 |
256 |
257 | def batch_handler(x, handler_fn, **kwargs):
258 | """Light-weight batch handling function.
259 |
260 | The purpose of this function is to simplify the handling of batches
261 | of input data, in particular for modules that deal with point cloud
262 | data. The handler essentially checks whether a 2D array (matrix) or
263 | a 3D array (tensor) was provided, and calls a handler function. The
264 | idea of the handler function is to handle an individual 2D array.
265 |
266 | Parameters
267 | ----------
268 | x : array_like
269 | Input point cloud(s). Can be either 2D array, indicating
270 | a single point cloud, or a 3D array, or even a *list* of
271 | point clouds (of potentially different cardinalities).
272 |
273 | handler_fn : callable
274 | Function to call for handling a 2D array.
275 |
276 | **kwargs
277 | Additional arguments to provide to `handler_fn`.
278 |
279 | Returns
280 | -------
281 | list or individual value
282 | Depending on whether `x` needs to be unwrapped, this function
283 | returns either a single value or a list of values, resulting
284 | from calling `handler_fn` on individual parts of `x`.
285 | """
286 | # Check whether individual batches need to be handled (3D array)
287 | # or not (2D array). We default to this type of processing for a
288 | # list as well.
289 | if isinstance(x, list) or len(x.shape) == 3:
290 | # TODO: This type of batch handling is rather ugly and
291 | # inefficient but at the same time, it is the easiest
292 | # workaround for now, permitting even 'ragged' inputs of
293 | # different lengths.
294 | return [handler_fn(torch.as_tensor(x_), **kwargs) for x_ in x]
295 | else:
296 | return handler_fn(torch.as_tensor(x), **kwargs)
297 |
298 |
299 | def batch_iter(x, dim=None):
300 | """Iterate over batches from input data.
301 |
302 | This utility function simplifies working with 'sparse' data sets
303 | consisting of :class:`PersistenceInformation` instances. It will
304 | present inputs in the order in which they appear in a batch such
305 | that instances belonging to the same data set are kept together.
306 |
307 | Parameters
308 | ----------
309 | x : recursively-nested list of :class:`PersistenceInformation`
310 | Input in sparse form, i.e. a nested structure containing
311 | persistence information about a data set.
312 |
313 | dim : int or `None`
314 | If set, only iterates over persistence information instances of
315 | the specified dimension. Else, will iterate over all instances.
316 |
317 | Returns
318 | -------
319 | A generator (iterable) that will either yield direct instances of
320 | :class:`PersistenceInformation` objects or further iterators into
321 | them. This ensures that it is possible to always iterate over the
322 | individual batches, without having to know internal details about
323 | the structure of `x`.
324 | """
325 | level = nesting_level(x)
326 |
327 | # Nothing to do for non-nested data structures, i.e. a single batch
328 | # that has been squeezed (for instance). Wrapping the input enables
329 | # us to treat it like a regular input again.
330 | if level == 1:
331 | x = [x]
332 |
333 | if level <= 2:
334 |
335 | def handler(x):
336 | return x
337 |
338 |     # Remove the first dimension but also the subsequent one so that
339 |     # only iterables containing persistence information about a specific
340 |     # data set are being returned.
341 | #
342 | # TODO: Generalise recursively? Do we want to support that?
343 | else:
344 |
345 | def handler(x):
346 | return chain.from_iterable(x)
347 |
348 | if dim is not None:
349 | for x_ in x:
350 | yield list(filter(lambda x: x.dimension == dim, handler(x_)))
351 | else:
352 | for x_ in x:
353 | yield handler(x_)
354 |
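355 | # A minimal usage sketch: convert sparse persistence information into
356 | # a dense tensor and iterate over batches. This assumes the module
357 | # `VietorisRipsComplex` from `torch_topological.nn`.
358 | #
359 | # >>> import torch
360 | # >>> from torch_topological.nn import VietorisRipsComplex
361 | # >>> vr = VietorisRipsComplex(dim=1)
362 | # >>> pers_info = vr(torch.rand(4, 100, 3))   # batch of point clouds
363 | # >>> make_tensor(pers_info).shape            # (4, N, 3)
364 | # >>> for instance in batch_iter(pers_info, dim=0):
365 | # ...     len(instance)                       # features per data set
366 |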
--------------------------------------------------------------------------------
/torch_topological/nn/distances.py:
--------------------------------------------------------------------------------
1 | """Distance calculation modules between topological descriptors."""
2 |
3 | import ot
4 | import torch
5 |
6 | from torch_topological.utils import wrap_if_not_iterable
7 |
8 |
9 | class WassersteinDistance(torch.nn.Module):
10 | """Implement Wasserstein distance between persistence diagrams.
11 |
12 |     This module calculates the Wasserstein distance between two
13 |     persistence diagrams. It is arguably the most common
14 | metric that is applied when dealing with such diagrams. Notice
15 | that calculating the metric involves solving optimal transport
16 | problems, which are known to suffer from scalability problems.
17 | When dealing with large persistence diagrams, other losses may
18 | be more appropriate.
19 | """
20 |
21 | def __init__(self, p=torch.inf, q=1):
22 | """Create new Wasserstein distance calculation module.
23 |
24 | Parameters
25 | ----------
26 | p : float or `inf`
27 | Specifies the exponent of the norm to calculate. By default,
28 | `p = torch.inf`, corresponding to the *maximum norm*.
29 |
30 |         q : float
31 | Specifies the order of Wasserstein metric to calculate. This
32 | raises all internal matching costs to the power of `q`, hence
33 | subsequently returning the `q`-th root of the total cost.
34 | """
35 | super().__init__()
36 |
37 | self.p = p
38 | self.q = q
39 |
40 | def _project_to_diagonal(self, diagram):
41 | x = diagram[:, 0]
42 | y = diagram[:, 1]
43 |
44 | # TODO: Is this the closest point in all p-norms?
45 | return 0.5 * torch.stack(((x + y), (x + y)), 1)
46 |
47 | def _distance_to_diagonal(self, diagram):
48 | return torch.linalg.vector_norm(
49 | diagram - self._project_to_diagonal(diagram),
50 | self.p,
51 | dim=1
52 | )
53 |
54 | def _make_distance_matrix(self, D1, D2):
55 | dist_D11 = self._distance_to_diagonal(D1)
56 | dist_D22 = self._distance_to_diagonal(D2)
57 |
58 | # n x m matrix containing the distances between 'regular'
59 | # persistence pairs of both persistence diagrams.
60 |         dist = torch.cdist(D1, D2, p=self.p)
61 |
62 | # Extend the matrix with a column of distances of samples in D1
63 | # to their respective projection on the diagonal.
64 | upper_blocks = torch.hstack((dist, dist_D11[:, None]))
65 |
66 | # Create a lower row of distances of samples in D2 to their
67 | # respective projection on the diagonal. The ordering needs
68 | # to follow the ordering of samples in D2. Note how one `0`
69 | # needs to be added to the row in order to balance it. The
70 | # entry intuitively describes the cost between *projected*
71 | # points, so it has to be zero.
72 | lower_blocks = torch.cat(
73 | (dist_D22, torch.tensor(0, device=dist_D22.device).unsqueeze(0))
74 | )
75 |
76 |         # Full (n + 1) x (m + 1) matrix containing *all* distances. By
77 |         # construction, M[i, m] contains distances of points in D1 to
78 |         # their projections, whereas M[n, j] does the same for points
79 |         # in D2. Only a cell M[i, j] with 0 <= i < n and 0 <= j < m
80 |         # contains a distance between 'regular' persistence pairs.
81 | M = torch.vstack((upper_blocks, lower_blocks))
82 | M = M.pow(self.q)
83 |
84 | return M
85 |
86 | def forward(self, X, Y):
87 | """Calculate Wasserstein metric based on input tensors.
88 |
89 | Parameters
90 | ----------
91 | X : list or instance of :class:`PersistenceInformation`
92 | Topological features of the first space. Supposed to contain
93 | persistence diagrams and persistence pairings.
94 |
95 | Y : list or instance of :class:`PersistenceInformation`
96 | Topological features of the second space. Supposed to
97 | contain persistence diagrams and persistence pairings.
98 |
99 | Returns
100 | -------
101 | torch.tensor
102 | A single scalar tensor containing the distance between the
103 | persistence diagram(s) contained in `X` and `Y`.
104 | """
105 | total_cost = 0.0
106 |
107 | X = wrap_if_not_iterable(X)
108 | Y = wrap_if_not_iterable(Y)
109 |
110 | for pers_info in zip(X, Y):
111 | D1 = pers_info[0].diagram
112 | D2 = pers_info[1].diagram
113 |
114 | n = len(D1)
115 | m = len(D2)
116 |
117 | dist = self._make_distance_matrix(D1, D2)
118 |
119 |             # Create weight vectors. The last entry of `a` accounts for
120 |             # the projections of all m points of D2 onto the diagonal,
121 |             # and analogously for the last entry of `b`.
122 |
123 | a = torch.ones(n + 1, device=dist.device)
124 | b = torch.ones(m + 1, device=dist.device)
125 |
126 | a[-1] = m
127 | b[-1] = n
128 |
129 | # TODO: Make settings configurable?
130 | total_cost += ot.emd2(a, b, dist)
131 |
132 | return total_cost.pow(1.0 / self.q)
133 |
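134 | # A minimal usage sketch: compare the persistence diagrams of two point
135 | # clouds. This assumes `VietorisRipsComplex` from `torch_topological.nn`;
136 | # the point clouds themselves are illustrative only.
137 | #
138 | # >>> import torch
139 | # >>> from torch_topological.nn import VietorisRipsComplex
140 | # >>> vr = VietorisRipsComplex(dim=1)
141 | # >>> X = vr(torch.rand(100, 2))
142 | # >>> Y = vr(torch.rand(100, 2))
143 | # >>> WassersteinDistance(q=2)(X, Y)   # scalar tensor
144 |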
--------------------------------------------------------------------------------
/torch_topological/nn/graphs.py:
--------------------------------------------------------------------------------
1 | """Layers for topological data analysis based on graphs.
2 |
3 | This is a work-in-progress module. The goal is to provide
4 | functionality similar to what is described in TOGL; see
5 | https://github.com/BorgwardtLab/TOGL.
6 |
7 | At the moment, the following functionality is present:
8 | - simple deep set layer
9 | - simple TOGL implementation with deep set functions
10 | - basic GCN with TOGL
11 |
12 | The following aspects are currently ignored:
13 | - handling higher-order information properly
14 | - expanding simplicial complexes
15 | - making use of the dimension of features
16 |
17 | On the other hand, the current implementation is simplified in the sense
18 | of showing that *lower star filtrations* (and their corresponding
19 | generators) may be useful.
20 | """
21 |
22 | import torch
23 | import torch.nn as nn
24 |
25 | import gudhi as gd
26 |
27 | from torch_geometric.data import Data
28 |
29 | from torch_geometric.loader import DataLoader
30 |
31 | from torch_geometric.utils import erdos_renyi_graph
32 |
33 | from torch_geometric.nn import GCNConv
34 | from torch_geometric.nn import global_mean_pool
35 |
36 | from torch_scatter import scatter
37 |
38 | from torch_topological.utils import pairwise
39 |
40 |
41 | class DeepSetLayer(nn.Module):
42 | """Simple equivariant deep set layer."""
43 |
44 | def __init__(self, dim_in, dim_out, aggregation_fn):
45 | """Create new deep set layer.
46 |
47 | Parameters
48 | ----------
49 | dim_in : int
50 | Input dimension
51 |
52 | dim_out : int
53 | Output dimension
54 |
55 | aggregation_fn : str
56 | Aggregation to use for the reduction step. Must be valid for
57 | the ``torch_scatter.scatter()`` function, i.e. one of "sum",
58 | "mul", "mean", "min" or "max".
59 | """
60 | super().__init__()
61 |
62 | self.Gamma = nn.Linear(dim_in, dim_out)
63 | self.Lambda = nn.Linear(dim_in, dim_out, bias=False)
64 |
65 | self.aggregation_fn = aggregation_fn
66 |
67 | def forward(self, x, batch):
68 | """Implement forward pass through layer."""
69 | xm = scatter(x, batch, dim=0, reduce=self.aggregation_fn)
70 | xm = self.Lambda(xm)
71 |
72 | x = self.Gamma(x)
73 | x = x - xm[batch, :]
74 | return x
75 |
76 |
77 | class TOGL(nn.Module):
78 | """Implementation of TOGL, a topological graph layer.
79 |
80 | Some caveats: this implementation only focuses on a set function
81 | aggregation of topological features. At the moment, it is not as
82 | powerful and feature-complete as the original implementation.
83 | """
84 |
85 | def __init__(
86 | self,
87 | n_features,
88 | n_filtrations,
89 | hidden_dim,
90 | out_dim,
91 | aggregation_fn,
92 | ):
93 | super().__init__()
94 |
95 | self.n_filtrations = n_filtrations
96 |
97 | self.filtrations = nn.Sequential(
98 | nn.Linear(n_features, hidden_dim),
99 | nn.ReLU(),
100 | nn.Linear(hidden_dim, n_filtrations),
101 | )
102 |
103 | self.set_fn = nn.ModuleList(
104 | [
105 | nn.Linear(n_filtrations * 2, out_dim),
106 | nn.ReLU(),
107 | DeepSetLayer(out_dim, out_dim, aggregation_fn),
108 | nn.ReLU(),
109 | DeepSetLayer(
110 | out_dim,
111 | n_features,
112 | aggregation_fn,
113 | ),
114 | ]
115 | )
116 |
117 | self.batch_norm = nn.BatchNorm1d(n_features)
118 |
119 | def compute_persistent_homology(
120 | self,
121 | x,
122 | edge_index,
123 | vertex_slices,
124 | edge_slices,
125 | batch,
126 | n_nodes,
127 | return_filtration=False,
128 | ):
129 | """Return persistence pairs (i.e. generators)."""
130 | # Apply filtrations to node attributes. For the edge values, we
131 | # use a sublevel set filtration.
132 | #
133 | # TODO: Support different ways of filtering?
134 | filtered_v = self.filtrations(x)
135 | filtered_e, _ = torch.max(
136 | torch.stack(
137 | (filtered_v[edge_index[0]], filtered_v[edge_index[1]])
138 | ),
139 | axis=0,
140 | )
141 |
142 | filtered_v = filtered_v.transpose(1, 0).cpu().contiguous()
143 | filtered_e = filtered_e.transpose(1, 0).cpu().contiguous()
144 | edge_index = edge_index.cpu().transpose(1, 0).contiguous()
145 |
146 | # TODO: Do we have to enforce contiguous indices here?
147 | vertex_index = torch.arange(end=n_nodes, dtype=torch.int)
148 |
149 | # Fill all persistence information at the same time.
150 | persistence_diagrams = torch.empty(
151 | (self.n_filtrations, n_nodes, 2),
152 | dtype=torch.float,
153 | )
154 |
155 | for filt_index in range(self.n_filtrations):
156 | for (vi, vj), (ei, ej) in zip(
157 | pairwise(vertex_slices), pairwise(edge_slices)
158 | ):
159 | vertices = vertex_index[vi:vj]
160 | edges = edge_index[ei:ej]
161 |
162 | offset = vi
163 |
164 | f_vertices = filtered_v[filt_index][vi:vj]
165 | f_edges = filtered_e[filt_index][ei:ej]
166 |
167 | persistence_diagram = self._compute_persistent_homology(
168 | vertices, f_vertices, edges, f_edges, offset
169 | )
170 |
171 | persistence_diagrams[filt_index, vi:vj] = persistence_diagram
172 |
173 | # Make sure that the tensor is living on the proper device here;
174 | # all subsequent operations can happen either on the CPU *or* on
175 | # the GPU.
176 | persistence_diagrams = persistence_diagrams.to(x.device)
177 | return persistence_diagrams
178 |
179 | # Helper function for doing the actual calculation of topological
180 | # features of a graph.
181 | def _compute_persistent_homology(
182 | self, vertices, f_vertices, edges, f_edges, offset
183 | ):
184 | assert len(vertices) == len(f_vertices)
185 | assert len(edges) == len(f_edges)
186 |
187 | st = gd.SimplexTree()
188 |
189 | for v, f in zip(vertices, f_vertices):
190 | st.insert([v], filtration=f)
191 |
192 | for (u, v), f in zip(edges, f_edges):
193 | st.insert([u, v], filtration=f)
194 |
195 | st.make_filtration_non_decreasing()
196 | st.expansion(2)
197 | st.persistence()
198 |
199 | # The generators are split into "regular" and "essential"
200 | # vertices, sorted by dimension.
201 | generators = st.lower_star_persistence_generators()
202 | generators_regular, generators_essential = generators
203 |
204 | # TODO: Let's think about how to leverage *all* generators in
205 | # *all* dimensions.
206 | generators_regular = torch.as_tensor(generators_regular[0])
207 | generators_regular = generators_regular - offset
208 | generators_regular = generators_regular.sort(dim=0, stable=True)[0]
209 |
210 | # By default, every vertex is paired with itself, so we just
211 | # duplicate the information here.
212 | persistence_diagram = torch.stack((f_vertices, f_vertices), dim=1)
213 |
214 | # Map generators back to filtration values, thus adding non-trivial
215 | # tuples to the persistence diagram while preserving gradients.
216 | if len(generators_regular) > 0:
217 | persistence_diagram[generators_regular[:, 0], 1] = f_vertices[
218 | generators_regular[:, 1]
219 | ]
220 |
221 | return persistence_diagram
222 |
223 | def forward(self, x, data):
224 | """Implement forward pass through data."""
225 | # TODO: Is this the best signature? `data` is following directly
226 | # the convention of `PyG`.
227 | #
228 | # x : current node attributes of layer; we should not use the
229 | # original attributes here because they are not informed by a
230 | # previous layer.
231 | #
232 | # data : edge slice information etc.
233 |
234 | edge_index = data.edge_index
235 |
236 | vertex_slices = torch.Tensor(data._slice_dict["x"]).long()
237 | edge_slices = torch.Tensor(data._slice_dict["edge_index"]).long()
238 | batch = data.batch
239 |
240 | persistence_pairs = self.compute_persistent_homology(
241 | x,
242 | edge_index,
243 | vertex_slices,
244 | edge_slices,
245 | batch,
246 | n_nodes=data.num_nodes,
247 | )
248 |
249 | x0 = persistence_pairs.permute(1, 0, 2).reshape(
250 | persistence_pairs.shape[1], -1
251 | )
252 |
253 | for layer in self.set_fn:
254 | # Preserve batch information for our set function layer
255 | # instead of treating all inputs the same.
256 | if isinstance(layer, DeepSetLayer):
257 | x0 = layer(x0, batch)
258 | else:
259 | x0 = layer(x0)
260 |
261 | # TODO: Residual step; could be made optional. Plus, the optimal
262 | # order of operations is not clear.
263 | x = x + self.batch_norm(nn.functional.relu(x0))
264 | return x
265 |
266 |
267 | class TopoGCN(torch.nn.Module):
268 | def __init__(self):
269 | super().__init__()
270 |
271 | self.layers = nn.ModuleList([GCNConv(1, 8), GCNConv(8, 2)])
272 |
273 | self.pooling_fn = global_mean_pool
274 | self.togl = TOGL(8, 16, 32, 16, "mean")
275 |
276 | def forward(self, data):
277 | x, edge_index = data.x, data.edge_index
278 |
279 | for layer in self.layers[:1]:
280 | x = layer(x, edge_index)
281 |
282 | x = self.togl(x, data)
283 |
284 | for layer in self.layers[1:]:
285 | x = layer(x, edge_index)
286 |
287 | x = self.pooling_fn(x, data.batch)
288 | return x
289 |
290 |
291 | # Simple smoke test on random graphs. Guarded so that it only runs
292 | # when the module is executed directly instead of being imported.
293 | if __name__ == "__main__":
294 |     B = 64
295 |     N = 100
296 |     p = 0.2
297 |
298 |     dev = "cuda:0" if torch.cuda.is_available() else "cpu"
299 |     print("Selected device:", dev)
300 |
301 |     data_list = [
302 |         Data(
303 |             x=torch.rand(N, 1),
304 |             edge_index=erdos_renyi_graph(N, p),
305 |             num_nodes=N,
306 |         )
307 |         for _ in range(B)
308 |     ]
309 |
310 |     loader = DataLoader(data_list, batch_size=8)
311 |     model = TopoGCN().to(dev)
312 |
313 |     for batch in loader:
314 |         print(batch)
315 |         batch = batch.to(dev)
316 |         model(batch)
317 |
--------------------------------------------------------------------------------
/torch_topological/nn/layers.py:
--------------------------------------------------------------------------------
1 | """Layers for processing persistence diagrams."""
2 |
3 | import torch
4 |
5 |
6 | class StructureElementLayer(torch.nn.Module):
7 | def __init__(
8 | self,
9 | n_elements
10 | ):
11 | super().__init__()
12 |
13 | self.n_elements = n_elements
14 | self.dim = 2 # TODO: Make configurable
15 |
16 | size = (self.n_elements, self.dim)
17 |
18 | self.centres = torch.nn.Parameter(
19 | torch.rand(*size)
20 | )
21 |
22 | self.sharpness = torch.nn.Parameter(
23 | torch.ones(*size) * 3
24 | )
25 |
26 | def forward(self, x):
27 | batch = torch.cat([x] * self.n_elements, 1)
28 |
29 | B, N, D = x.shape
30 |
31 | # This is a 'butchered' variant of the much nicer `SLayerExponential`
32 | # class by C. Hofer and R. Kwitt.
33 | #
34 | # https://c-hofer.github.io/torchph/_modules/torchph/nn/slayer.html#SLayerExponential
35 |
36 | centres = torch.cat([self.centres] * N, 1)
37 | centres = centres.view(-1, self.dim)
38 | centres = torch.stack([centres] * B, 0)
39 | centres = torch.cat((centres, 2 * batch[..., -1].unsqueeze(-1)), 2)
40 |
41 | sharpness = torch.pow(self.sharpness, 2)
42 | sharpness = torch.cat([sharpness] * N, 1)
43 | sharpness = sharpness.view(-1, self.dim)
44 | sharpness = torch.stack([sharpness] * B, 0)
45 | sharpness = torch.cat(
46 | (
47 | sharpness,
48 | torch.ones_like(batch[..., -1].unsqueeze(-1))
49 | ),
50 | 2
51 | )
52 |
53 | x = centres - batch
54 | x = x.pow(2)
55 | x = torch.mul(x, sharpness)
56 | x = torch.nansum(x, 2)
57 | x = torch.exp(-x)
58 | x = x.view(B, self.n_elements, -1)
59 | x = torch.sum(x, 2)
60 | x = x.squeeze()
61 |
62 | return x
63 |
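64 | # A minimal usage sketch for `StructureElementLayer`. The layer expects
65 | # a dense batch of shape `(B, N, 3)`, e.g. as produced by `make_tensor`
66 | # from `torch_topological.nn.data`; all names here are illustrative.
67 | #
68 | # >>> import torch
69 | # >>> from torch_topological.nn import VietorisRipsComplex
70 | # >>> from torch_topological.nn.data import make_tensor
71 | # >>> vr = VietorisRipsComplex(dim=1)
72 | # >>> x = make_tensor(vr(torch.rand(8, 100, 3)))   # shape (B, N, 3)
73 | # >>> StructureElementLayer(n_elements=16)(x).shape   # (8, 16)
74 |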
--------------------------------------------------------------------------------
/torch_topological/nn/loss.py:
--------------------------------------------------------------------------------
1 | """Loss terms for various optimisation objectives."""
2 |
3 | import torch
4 |
5 | from torch_topological.utils import is_iterable
6 |
7 |
8 | class SummaryStatisticLoss(torch.nn.Module):
9 | r"""Implement loss based on summary statistic.
10 |
11 | This is a generic loss function based on topological summary
12 | statistics. It implements a loss of the following form:
13 |
14 | .. math:: \|s(X) - s(Y)\|^p
15 |
16 | In the preceding equation, `s` refers to a function that results in
17 | a scalar-valued summary of a persistence diagram.
18 | """
19 |
20 | def __init__(self, summary_statistic="total_persistence", **kwargs):
21 | """Create new loss function based on summary statistic.
22 |
23 | Parameters
24 | ----------
25 | summary_statistic : str
26 | Indicates which summary statistic function to use. Must be
27 | a summary statistics function that exists in the utilities
28 | module, i.e. :mod:`torch_topological.utils`.
29 |
30 | At present, the following choices are valid:
31 |
32 | - `torch_topological.utils.persistent_entropy`
33 | - `torch_topological.utils.polynomial_function`
34 | - `torch_topological.utils.total_persistence`
35 | - `torch_topological.utils.p_norm`
36 |
37 | **kwargs
38 | Optional keyword arguments, to be passed to the
39 | summary statistic function.
40 | """
41 | super().__init__()
42 |
43 | self.p = kwargs.get("p", 1.0)
44 | self.kwargs = kwargs
45 |
46 | import torch_topological.utils.summary_statistics as stat
47 |
48 | self.stat_fn = getattr(stat, summary_statistic, None)
49 |
50 | def forward(self, X, Y=None):
51 | r"""Calculate loss based on input tensor(s).
52 |
53 | Parameters
54 | ----------
55 | X : list of :class:`PersistenceInformation`
56 | Source information. Supposed to contain persistence diagrams
57 | and persistence pairings.
58 |
59 | Y : list of :class:`PersistenceInformation` or `None`
60 | Optional target information. If set, evaluates a difference
61 | in loss functions as shown in the introduction. If `None`,
62 | a simpler variant of the loss will be evaluated.
63 |
64 | Returns
65 | -------
66 | torch.tensor
67 | Loss based on the summary statistic selected by the client.
68 | Given a statistic :math:`s`, the function returns the
69 | following expression:
70 |
71 | .. math:: \|s(X) - s(Y)\|^p
72 |
73 | In case no target tensor `Y` has been provided, the latter part
74 | of the expression amounts to `0`.
75 | """
76 | stat_src = self._evaluate_stat_fn(X)
77 |
78 | if Y is not None:
79 | stat_target = self._evaluate_stat_fn(Y)
80 | return (stat_target - stat_src).abs().pow(self.p)
81 | else:
82 | return stat_src.abs().pow(self.p)
83 |
84 | def _evaluate_stat_fn(self, X):
85 | """Evaluate statistic function for a given tensor."""
86 | return torch.sum(
87 | torch.stack(
88 | [
89 | self.stat_fn(pers_info.diagram, **self.kwargs)
90 | for pers_info in X
91 | ]
92 | )
93 | )
94 |
95 |
96 | class SignatureLoss(torch.nn.Module):
97 | """Implement topological signature loss.
98 |
99 | This module implements the topological signature loss first
100 | described in [Moor20a]_. In contrast to the original code provided
101 | by the authors, this module also provides extensions to
102 | higher-dimensional generators if desired.
103 |
104 |     The module can be used in conjunction with any persistent homology
105 |     module that provides persistence pairings, i.e. generators, next to
106 |     its persistence diagrams. At the moment, the loss calculation is
107 |     restricted to Minkowski distances.
108 |
109 | References
110 | ----------
111 | .. [Moor20a] M. Moor et al., "Topological Autoencoders",
112 | *Proceedings of the 37th International Conference on Machine
113 | Learning*, PMLR 119, pp. 7045--7054, 2020.
114 | """
115 |
116 | def __init__(self, p=2, normalise=True, dimensions=0):
117 | """Create new loss instance.
118 |
119 | Parameters
120 | ----------
121 | p : float
122 | Exponent for the `p`-norm calculation of distances.
123 |
124 | normalise : bool
125 | If set, normalises distances for each point cloud. This can
126 | be useful when working with batches.
127 |
128 | dimensions : int or tuple of int
129 | Dimensions to use in the signature calculation. Following
130 | [Moor20a]_, this is set by default to `0`.
131 | """
132 | super().__init__()
133 |
134 | self.p = p
135 | self.normalise = normalise
136 | self.dimensions = dimensions
137 |
138 | # Ensure that we can iterate over the dimensions later on, as
139 | # this simplifies the code.
140 | if not is_iterable(self.dimensions):
141 | self.dimensions = [self.dimensions]
142 |
143 | def forward(self, X, Y):
144 | """Calculate the signature loss between two point clouds.
145 |
146 | This loss function uses the persistent homology from each point
147 | cloud in order to retrieve the topologically relevant distances
148 | from a distance matrix calculated from the point clouds. For
149 | more information, see [Moor20a]_.
150 |
151 | Parameters
152 | ----------
153 |         X: Tuple[torch.tensor, PersistenceInformation]
154 |             A tuple consisting of a point cloud and its persistence
155 |             information. The persistence information is obtained from
156 |             a persistent homology calculation, which yields a list of
157 |             topologically relevant edges.
158 |
159 |         Y: Tuple[torch.tensor, PersistenceInformation]
160 |             A tuple consisting of a point cloud and its persistence
161 |             information. The persistence information is obtained from
162 |             a persistent homology calculation, which yields a list of
163 |             topologically relevant edges.
164 |
165 | Returns
166 | -------
167 | torch.tensor
168 | A scalar representing the topological loss term for the two
169 | data sets.
170 | """
171 | X_point_cloud, X_persistence_info = X
172 | Y_point_cloud, Y_persistence_info = Y
173 |
174 | # Calculate the pairwise distance matrix between points in the
175 | # point cloud. Distances are calculated using the p-norm.
176 | X_pairwise_dist = torch.cdist(X_point_cloud, X_point_cloud, self.p)
177 | Y_pairwise_dist = torch.cdist(Y_point_cloud, Y_point_cloud, self.p)
178 |
179 | if self.normalise:
180 | X_pairwise_dist = X_pairwise_dist / X_pairwise_dist.max()
181 | Y_pairwise_dist = Y_pairwise_dist / Y_pairwise_dist.max()
182 |
183 | # Using the topologically relevant edges from point cloud X,
184 | # retrieve the corresponding distances from the pairwise
185 | # distance matrix of X.
186 | X_sig_X = [
187 | self._select_distances(
188 | X_pairwise_dist, X_persistence_info[dim].pairing
189 | )
190 | for dim in self.dimensions
191 | ]
192 |
193 | # Using the topologically relevant edges from point cloud Y,
194 | # retrieve the corresponding distances from the pairwise
195 | # distance matrix of X.
196 | X_sig_Y = [
197 | self._select_distances(
198 | X_pairwise_dist, Y_persistence_info[dim].pairing
199 | )
200 | for dim in self.dimensions
201 | ]
202 |
203 | # Using the topologically relevant edges from point cloud X,
204 | # retrieve the corresponding distances from the pairwise
205 | # distance matrix of Y.
206 | Y_sig_X = [
207 | self._select_distances(
208 | Y_pairwise_dist, X_persistence_info[dim].pairing
209 | )
210 | for dim in self.dimensions
211 | ]
212 |
213 | # Using the topologically relevant edges from point cloud Y,
214 | # retrieve the corresponding distances from the pairwise
215 | # distance matrix of Y.
216 | Y_sig_Y = [
217 | self._select_distances(
218 | Y_pairwise_dist, Y_persistence_info[dim].pairing
219 | )
220 | for dim in self.dimensions
221 | ]
222 |
223 | XY_dist = self._partial_distance(X_sig_X, Y_sig_X)
224 | YX_dist = self._partial_distance(Y_sig_Y, X_sig_Y)
225 |
226 | return torch.stack(XY_dist).sum() + torch.stack(YX_dist).sum()
227 |
228 | def _select_distances(self, pairwise_distance_matrix, generators):
229 |         """Select topologically relevant edges from a distance matrix.
230 |
231 | Parameters
232 | ----------
233 | pairwise_distance_matrix: torch.tensor
234 | NxN pairwise distance matrix of a point cloud.
235 |
236 | generators: np.ndarray
237 |             A 2D array of indices describing the edges that create or
238 |             destroy topological features during the persistent
239 |             homology calculation. For topological features in
240 |             dimension 0, i.e. connected components, we only consider
241 |             the edges that *destroy* connected components, since the
242 |             corresponding creators are vertices. For topological
243 |             features in dimensions > 0, e.g. holes or voids, we
244 |             consider the edges that both create and destroy such
245 |             topological features.
246 |
247 | Returns
248 | -------
249 | torch.tensor
250 | A vector that contains all of the topologically relevant
251 | distances.
252 | """
253 | # Dimension 0: only a mapping of vertices--edges is present, and
254 | # we must *only* access the edges.
255 | if generators.shape[1] == 3:
256 | selected_distances = pairwise_distance_matrix[
257 | generators[:, 1], generators[:, 2]
258 | ]
259 |
260 | # Dimension > 0: we can access all distances
261 | else:
262 | creator_distances = pairwise_distance_matrix[
263 | generators[:, 0], generators[:, 1]
264 | ]
265 | destroyer_distances = pairwise_distance_matrix[
266 | generators[:, 2], generators[:, 3]
267 | ]
268 |
269 | # Need to use `torch.abs` here because of the way the
270 | # signature lookup works. We are *not* guaranteed to
271 | # get 'valid' persistence values when using a pairing
272 | # from space X to access distances from space Y, for
273 |         # instance, hence some of the values could be *negative*.
274 | selected_distances = torch.abs(
275 | destroyer_distances - creator_distances
276 | )
277 |
278 | return selected_distances
279 |
280 | def _partial_distance(self, A, B):
281 | """
282 | Calculate partial distances between pairings.
283 |
284 | The purpose of this function is to calculate a partial distance
285 | for the loss, depending on distances selected from the pairing.
286 | """
287 | dist = [
288 | 0.5 * torch.linalg.vector_norm(a - b, ord=self.p)
289 | for a, b in zip(A, B)
290 | ]
291 |
292 | return dist
293 |
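294 | # A minimal usage sketch for both loss terms, assuming the module
295 | # `VietorisRipsComplex` from `torch_topological.nn`. The point clouds
296 | # are illustrative only.
297 | #
298 | # >>> import torch
299 | # >>> from torch_topological.nn import VietorisRipsComplex
300 | # >>> vr = VietorisRipsComplex(dim=1)
301 | # >>> X, Y = torch.rand(100, 2), torch.rand(100, 2)
302 | # >>> loss_fn = SummaryStatisticLoss("total_persistence")
303 | # >>> loss_fn(vr(X), vr(Y))             # scalar tensor
304 | # >>> sig_loss = SignatureLoss(p=2)
305 | # >>> sig_loss((X, vr(X)), (Y, vr(Y)))  # scalar tensor
306 |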
--------------------------------------------------------------------------------
/torch_topological/nn/multi_scale_kernel.py:
--------------------------------------------------------------------------------
1 | """Contains multi-scale kernel (scale space) kernel module."""
2 |
3 | import torch
4 |
5 | from torch_topological.utils import wrap_if_not_iterable
6 |
7 |
8 | class MultiScaleKernel(torch.nn.Module):
9 | # TODO: more detailed description
10 | r"""Implement the multi-scale kernel between two persistence diagrams.
11 |
12 | This class implements the multi-scale kernel between two persistence
13 | diagrams (also known as the scale space kernel) as defined by
14 | Reininghaus et al. [Reininghaus15a]_ as
15 |
16 | .. math::
17 |         k_\sigma(F,G) = \frac{1}{8\pi\sigma} \sum_{\substack{p \in F\\q \in G}}
18 |         e^{-\frac{\|p-q\|^2}{8\sigma}} -
19 |         e^{-\frac{\|p-\overline{q}\|^2}{8\sigma}}
20 |
21 |     where :math:`z=(z_1, z_2)` and :math:`\overline{z}=(z_2, z_1)`.
22 |
23 | References
24 | ----------
25 | .. [Reininghaus15a] J. Reininghaus, U. Bauer and R. Kwitt, "A Stable
26 | Multi-Scale Kernel for Topological Machine Learning", *Proceedings
27 | of the IEEE Conference on Computer Vision and Pattern Recognition*,
28 | pp. 4741--4748, 2015.
29 | """
30 |
31 | def __init__(self, sigma):
32 | """Create new instance of the kernel.
33 |
34 | Parameters
35 | ----------
36 | sigma : float
37 | scale parameter of the kernel
38 | """
39 | super().__init__()
40 |
41 | self.sigma = sigma
42 |
43 | @staticmethod
44 | def _check_upper(d):
45 |         # Check if all points in the diagram are above the diagonal.
46 |         # Points below the diagonal are swapped. The swap happens out
47 |         # of place so that the original coordinates are not overwritten.
48 |         is_upper = d[:, 0] < d[:, 1]
49 |         if not torch.all(is_upper):
50 |             d = torch.where(is_upper[:, None], d, d[:, [1, 0]])
51 |         return d
52 |
53 | @staticmethod
54 | def _mirror(x):
55 | # Mirror one or multiple points of a persistence
56 | # diagram at the diagonal
57 | if len(x.shape) > 1:
58 | return x[:, [1, 0]]
59 | # only a single point in the diagram
60 | return x[[1, 0]]
61 |
62 | @staticmethod
63 | def _dist(x, y, p):
64 |         # Compute *squared* pairwise lp-distances between the points
65 |         # of two persistence diagrams.
66 | dist = torch.cdist(x, y, p=p)
67 | return dist.pow(2)
68 |
69 | def forward(self, X, Y, p=2.):
70 | """Calculate the kernel value between two persistence diagrams.
71 |
72 | The kernel value is computed for each dimension of the persistence
73 | diagram individually, according to Equation 10 from Reininghaus et al.
74 | The final kernel value is computed as the sum of kernel values over
75 | all dimensions.
76 |
77 | Parameters
78 | ----------
79 | X : list or instance of :class:`PersistenceInformation`
80 | Topological features of the first space. Supposed to
81 | contain persistence diagrams and persistence pairings.
82 |
83 | Y : list or instance of :class:`PersistenceInformation`
84 | Topological features of the second space. Supposed to
85 | contain persistence diagrams and persistence pairings.
86 |
87 | p : float or inf, default 2.
88 | Specify which p-norm to use for distance calculation.
89 | For infinity/maximum norm pass p=float('inf').
90 |             Please note that norms other than the 2-norm (Euclidean
91 |             norm) are not guaranteed to yield positive definite
92 |             results.
93 |
94 | Returns
95 | -------
96 | torch.tensor
97 | A single scalar tensor containing the kernel value between the
98 | persistence diagram(s) contained in `X` and `Y`.
99 |
100 | Examples
101 | --------
102 | >>> from torch_topological.data.shapes import sample_from_disk
103 | >>> from torch_topological.nn import VietorisRipsComplex
104 | >>> # sample randomly from two disks
105 | >>> x = sample_from_disk(r=0.5, R=0.6, n=100)
106 | >>> y = sample_from_disk(r=0.9, R=1.0, n=100)
107 |         >>> # compute Vietoris--Rips persistence for both point clouds
108 | >>> vr = VietorisRipsComplex(dim=1)
109 | >>> vr_x = vr(x)
110 | >>> vr_y = vr(y)
111 | >>> # compute kernel value between persistence
112 | >>> # diagrams with sigma set to 1
113 | >>> msk = MultiScaleKernel(1.)
114 | >>> msk_value = msk(vr_x, vr_y)
115 | """
116 | X_ = wrap_if_not_iterable(X)
117 | Y_ = wrap_if_not_iterable(Y)
118 |
119 | k_sigma = 0.0
120 |
121 | for pers_info in zip(X_, Y_):
122 | # ensure that all points in the diagram are
123 | # above the diagonal
124 | D1 = self._check_upper(pers_info[0].diagram)
125 | D2 = self._check_upper(pers_info[1].diagram)
126 |
127 | # compute the pairwise distances between the
128 | # two diagrams
129 | nom = self._dist(D1, D2, p)
130 | # distance between diagram 1 and mirrored
131 | # diagram 2
132 | denom = self._dist(D1, self._mirror(D2), p)
133 |
134 | M = torch.exp(-nom / (8 * self.sigma))
135 | M -= torch.exp(-denom / (8 * self.sigma))
136 |
137 | # sum over all points
138 | k_sigma += M.sum() / (8. * self.sigma * torch.pi)
139 |
140 | return k_sigma
141 |
--------------------------------------------------------------------------------
/torch_topological/nn/sliced_wasserstein_distance.py:
--------------------------------------------------------------------------------
1 | """Sliced Wasserstein distance implementation."""
2 |
3 |
4 | import torch
5 |
6 | import numpy as np
7 |
8 | from torch_topological.utils import wrap_if_not_iterable
9 |
10 |
11 | class SlicedWassersteinDistance(torch.nn.Module):
12 | """Calculate sliced Wasserstein distance between persistence diagrams.
13 |
14 | This is an implementation of the sliced Wasserstein distance between
15 | persistence diagrams, following [Carriere17a]_.
16 |
17 | This module calculates the sliced Wasserstein distance between two
18 | persistence diagrams. It is an efficient variant of the Wasserstein
19 | distance, and it is commonly used in the Sliced Wasserstein Kernel.
20 | It computes the expected value of the Wasserstein distance when the
21 | persistence diagram is projected on a random line passing through
22 | the origin.
23 | """
24 |
25 | def __init__(self, num_directions=10):
26 | """Create new sliced Wasserstein distance calculation module.
27 |
28 | Parameters
29 | ----------
30 | num_directions : int
31 |             Specifies the number of projection directions, evenly
32 |             spaced in angle, for the sliced Wasserstein computation.
33 | """
34 | super().__init__()
35 |
36 |         # Generate `num_directions` lines with angles spaced *evenly*
37 |         # between -pi/2 and pi/2.
38 |         self.num_directions = num_directions
39 |         thetas = torch.linspace(
40 |             -np.pi / 2, np.pi / 2, steps=self.num_directions + 1
41 |         )[:-1]
42 |         self.lines = torch.stack((torch.cos(thetas), torch.sin(thetas)), 1)
43 |
44 | def _emd1d(self, X, Y):
45 |         # Compute the 1D Wasserstein distance between two point sets.
46 |         X, _ = torch.sort(X, dim=0)
47 |         Y, _ = torch.sort(Y, dim=0)
48 | return torch.sum(torch.abs(torch.sub(X, Y)))
49 |
50 | def _project_diagram(self, D1, L):
51 | # Project persistence diagram D1 onto a given line L.
52 | return torch.stack([torch.dot(x, L)/torch.dot(L, L) for x in D1])
53 |
54 | def forward(self, X, Y):
55 | """Calculate sliced Wasserstein metric based on input tensors.
56 |
57 | Parameters
58 | ----------
59 | X : list or instance of :class:`PersistenceInformation`
60 | Topological features of the first space. Supposed to contain
61 | persistence diagrams and persistence pairings.
62 |
63 | Y : list or instance of :class:`PersistenceInformation`
64 | Topological features of the second space. Supposed to
65 | contain persistence diagrams and persistence pairings.
66 |
67 | Returns
68 | -------
69 | torch.tensor
70 | A single scalar tensor containing the sliced Wasserstein distance
71 | between the persistence diagram(s) contained in `X` and `Y`.
72 | """
73 | total_cost = 0.0
74 |
75 | X = wrap_if_not_iterable(X)
76 | Y = wrap_if_not_iterable(Y)
77 |
78 | for pers_info in zip(X, Y):
79 | D1 = pers_info[0].diagram.float()
80 | D2 = pers_info[1].diagram.float()
81 |
82 | # Auxiliary array to project onto diagonal.
83 | diag = torch.tensor([0.5, 0.5], dtype=torch.float32)
84 |
85 | # Project both the diagrams onto the diagonals.
86 | D1_diag = torch.sum(D1, dim=1, keepdim=True) * diag
87 | D2_diag = torch.sum(D2, dim=1, keepdim=True) * diag
88 |
89 | cost = 0.0
90 |
91 | for line in self.lines:
92 | proj_d1 = self._project_diagram(D1, line)
93 | proj_d2 = self._project_diagram(D2, line)
94 |
95 | proj_diag_d1 = self._project_diagram(D1_diag, line)
96 | proj_diag_d2 = self._project_diagram(D2_diag, line)
97 |
98 | cost += self._emd1d(torch.cat([proj_d1, proj_diag_d2]),
99 | torch.cat([proj_d2, proj_diag_d1]))
100 |
101 | cost /= self.num_directions
102 |
103 | total_cost += cost
104 |
105 | return total_cost
106 |
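107 | # A minimal usage sketch, assuming `VietorisRipsComplex` from
108 | # `torch_topological.nn`; the point clouds are illustrative only.
109 | #
110 | # >>> import torch
111 | # >>> from torch_topological.nn import VietorisRipsComplex
112 | # >>> vr = VietorisRipsComplex(dim=1)
113 | # >>> X = vr(torch.rand(100, 2))
114 | # >>> Y = vr(torch.rand(100, 2))
115 | # >>> swd = SlicedWassersteinDistance(num_directions=20)
116 | # >>> swd(X, Y)   # scalar tensor
117 |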
--------------------------------------------------------------------------------
/torch_topological/nn/sliced_wasserstein_kernel.py:
--------------------------------------------------------------------------------
1 | """Sliced Wasserstein kernel implementation."""
2 |
3 |
4 | import torch
5 |
6 | from torch_topological.nn import SlicedWassersteinDistance
7 |
8 | from torch_topological.utils import wrap_if_not_iterable
9 |
10 |
11 | class SlicedWassersteinKernel(torch.nn.Module):
12 | """Calculate sliced Wasserstein kernel between persistence diagrams.
13 |
14 | This is an implementation of the sliced Wasserstein kernel between
15 | persistence diagrams, following [Carriere17a]_.
16 |
17 | References
18 | ----------
19 | .. [Carriere17a] M. Carrière et al., "Sliced Wasserstein Kernel for
20 | Persistence Diagrams", *Proceedings of the 34th International
21 | Conference on Machine Learning*, PMLR 70, pp. 664--673, 2017.
22 | """
23 |
24 | def __init__(self, num_directions=10, sigma=1.0):
25 | """Create new sliced Wasserstein kernel module.
26 |
27 | Parameters
28 | ----------
29 | num_directions : int
30 |             Specifies the number of projection directions, evenly
31 |             spaced in angle, for the sliced Wasserstein computation.
32 |
33 |         sigma : float
34 | Variance term of the sliced Wasserstein kernel expression.
35 | """
36 | super().__init__()
37 |
38 | self.num_directions = num_directions
39 | self.sigma = sigma
40 |
41 | def forward(self, X, Y):
42 | """Calculate sliced Wasserstein kernel based on input tensors.
43 |
44 | Parameters
45 | ----------
46 | X : list or instance of :class:`PersistenceInformation`
47 | Topological features of the first space. Supposed to contain
48 | persistence diagrams and persistence pairings.
49 |
50 | Y : list or instance of :class:`PersistenceInformation`
51 | Topological features of the second space. Supposed to
52 | contain persistence diagrams and persistence pairings.
53 |
54 | Returns
55 | -------
56 | torch.tensor
57 | A single scalar tensor containing the sliced Wasserstein kernel
58 | between the persistence diagram(s) contained in `X` and `Y`.
59 | """
60 | total_cost = 0.0
61 |
62 | X = wrap_if_not_iterable(X)
63 | Y = wrap_if_not_iterable(Y)
64 |
65 | swd = SlicedWassersteinDistance(num_directions=self.num_directions)
66 |
67 | for pers_info in zip(X, Y):
68 | D1 = pers_info[0]
69 | D2 = pers_info[1]
70 |
71 |             total_cost += torch.exp(-swd(D1, D2) / self.sigma)
72 |
73 | return total_cost
74 |
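75 | # A minimal usage sketch, assuming `VietorisRipsComplex` from
76 | # `torch_topological.nn`; the point clouds are illustrative only.
77 | #
78 | # >>> import torch
79 | # >>> from torch_topological.nn import VietorisRipsComplex
80 | # >>> vr = VietorisRipsComplex(dim=1)
81 | # >>> X = vr(torch.rand(100, 2))
82 | # >>> Y = vr(torch.rand(100, 2))
83 | # >>> swk = SlicedWassersteinKernel(num_directions=10, sigma=1.0)
84 | # >>> swk(X, Y)   # kernel value as scalar tensor
85 |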
--------------------------------------------------------------------------------
/torch_topological/nn/vietoris_rips_complex.py:
--------------------------------------------------------------------------------
1 | """Vietoris--Rips complex calculation module(s)."""
2 |
3 | from itertools import starmap
4 |
5 | from gph import ripser_parallel
6 | from torch import nn
7 |
8 | from torch_topological.nn import PersistenceInformation
9 | from torch_topological.nn.data import batch_handler
10 |
11 | import numpy
12 | import torch
13 |
14 |
15 | class VietorisRipsComplex(nn.Module):
16 | """Calculate Vietoris--Rips complex of a data set.
17 |
18 | This module calculates 'differentiable' persistence diagrams for
19 | point clouds. The underlying topological approximations are done
20 | by calculating a Vietoris--Rips complex of the data.
21 | """
22 |
23 | def __init__(
24 | self,
25 | dim=1,
26 | p=2,
27 | threshold=numpy.inf,
28 | keep_infinite_features=False,
29 | **kwargs
30 | ):
31 | """Initialise new module.
32 |
33 | Parameters
34 | ----------
35 | dim : int
36 | Calculates persistent homology up to (and including) the
37 | prescribed dimension.
38 |
39 | p : float
40 | Exponent indicating which Minkowski `p`-norm to use for the
41 | calculation of pairwise distances between points. Note that
42 | if `treat_as_distances` is supplied to :func:`forward`, the
43 | parameter is ignored and will have no effect. The rationale
44 | is to permit clients to use a pre-computed distance matrix,
45 | while always falling back to Minkowski norms.
46 |
47 | threshold : float
48 | If set to a finite number, only calculates topological
49 |             features up to the specified distance threshold. Features
50 |             persisting beyond this threshold are reported as infinite,
51 |             so persistence pairings may contain infinite features.
52 |
53 | keep_infinite_features : bool
54 | If set, keeps infinite features. This flag is disabled by
55 | default. The rationale for this is that infinite features
56 |             require more deliberate handling and, if `threshold` is
57 |             left at its default, only a *single* infinite feature (in
58 |             dimension 0) arises, so dropping it loses little.
59 |
60 | **kwargs
61 | Additional arguments to be provided to ``ripser``, i.e. the
62 | backend for calculating persistent homology. The `n_threads`
63 | parameter, which controls parallelisation, is probably the
64 | most relevant parameter to be adjusted.
65 |             Please refer to the `giotto-ph documentation
66 | `_
67 | for more details on admissible parameters.
68 |
69 | Notes
70 | -----
71 | This module currently only supports Minkowski norms. It does not
72 | yet support other metrics internally. To use custom metrics, you
73 | need to set `treat_as_distances` in the :func:`forward` function
74 | instead.
75 | """
76 | super().__init__()
77 |
78 | self.dim = dim
79 | self.p = p
80 | self.threshold = threshold
81 | self.keep_infinite_features = keep_infinite_features
82 |
83 | # Ensures that the same parameters are used whenever calling
84 | # `ripser`.
85 | self.ripser_params = {
86 | 'return_generators': True,
87 | 'maxdim': self.dim,
88 | 'thresh': self.threshold
89 | }
90 |
91 | self.ripser_params.update(kwargs)
92 |
93 | def forward(self, x, treat_as_distances=False):
94 | """Implement forward pass for persistence diagram calculation.
95 |
96 | The forward pass entails calculating persistent homology on
97 | a point cloud and returning a set of persistence diagrams.
98 |
99 | Parameters
100 | ----------
101 | x : array_like
102 | Input point cloud(s). `x` can either be a 2D array of shape
103 | `(n, d)`, which is treated as a single point cloud, or a 3D
104 | array/tensor of the form `(b, n, d)`, with `b` representing
105 | the batch size. Alternatively, you may also specify a list,
106 | possibly containing point clouds of non-uniform sizes.
107 |
108 | treat_as_distances : bool
109 | If set, treats `x` as containing pre-computed distances
110 | between points. The semantics of how `x` is handled are
111 |             not changed; the only difference is that an input of shape
112 |             `(n, d)` must satisfy `n == d`, i.e. `x` must be a square
113 |             distance matrix.
114 |
115 | Returns
116 | -------
117 | list of :class:`PersistenceInformation`
118 | List of :class:`PersistenceInformation`, containing both the
119 | persistence diagrams and the generators, i.e. the
120 | *pairings*, of a certain dimension of topological features.
121 | If `x` is a 3D array, returns a list of lists, in which the
122 | first dimension denotes the batch and the second dimension
123 | refers to the individual instances of
124 | :class:`PersistenceInformation` elements.
125 |
126 | Generators will be represented in the persistence pairing
127 | based on vertex--edge pairs (dimension 0) or edge--edge
128 | pairs. Thus, the persistence pairing in dimension zero will
129 |             have three components (one vertex index and the two vertex
130 |             indices of the destroying edge), while the pairing in
131 |             higher dimensions will have four components, i.e. two edges.
132 | """
133 | return batch_handler(
134 | x,
135 | self._forward,
136 | treat_as_distances=treat_as_distances
137 | )
138 |
139 | def _forward(self, x, treat_as_distances=False):
140 | """Handle a *single* point cloud.
141 |
142 | This internal function handles the calculation of topological
143 | features for a single point cloud, i.e. an `array_like` of 2D
144 | shape.
145 |
146 | Parameters
147 | ----------
148 | x : array_like of shape `(n, d)`
149 | Single input point cloud.
150 |
151 | treat_as_distances : bool
152 | Flag indicating whether `x` should be treated as a distance
153 | matrix. See :func:`forward` for more information.
154 |
155 | Returns
156 | -------
157 |         list of :class:`PersistenceInformation`
158 | List of persistence information data structures, containing
159 | the persistence diagram and the persistence pairing of some
160 | dimension in the input data set.
161 | """
162 | if treat_as_distances:
163 | distances = x
164 | else:
165 | distances = torch.cdist(x, x, p=self.p)
166 |
167 | generators = ripser_parallel(
168 | distances.cpu().detach().numpy(),
169 | metric='precomputed',
170 | **self.ripser_params
171 | )['gens']
172 |
173 | # We always have 0D information.
174 | persistence_information = \
175 | self._extract_generators_and_diagrams(
176 | distances,
177 | generators,
178 | dim0=True,
179 | )
180 |
181 | if self.keep_infinite_features:
182 | persistence_information_inf = \
183 | self._extract_generators_and_diagrams(
184 | distances,
185 | generators,
186 | finite=False,
187 | dim0=True,
188 | )
189 |
190 | # Check whether we have any higher-dimensional information that
191 | # we should return.
192 | if self.dim >= 1:
193 | persistence_information.extend(
194 | self._extract_generators_and_diagrams(
195 | distances,
196 | generators,
197 | dim0=False,
198 | )
199 | )
200 |
201 | if self.keep_infinite_features:
202 | persistence_information_inf.extend(
203 | self._extract_generators_and_diagrams(
204 | distances,
205 | generators,
206 | finite=False,
207 | dim0=False,
208 | )
209 | )
210 |
211 | # Concatenation is only necessary if we want to keep infinite
212 | # features.
213 | if self.keep_infinite_features:
214 | persistence_information = self._concatenate_features(
215 | persistence_information, persistence_information_inf
216 | )
217 |
218 | return persistence_information
219 |
220 | def _extract_generators_and_diagrams(
221 | self,
222 | dist,
223 | gens,
224 | finite=True,
225 | dim0=False
226 | ):
227 | """Extract generators and persistence diagrams from raw data.
228 |
229 | This convenience function translates between the output of
230 | `ripser_parallel` and the required output of this function.
231 | """
232 | index = 1 if not dim0 else 0
233 |
234 | # Perform index shift to find infinite features in the tensor.
235 | if not finite:
236 | index += 2
237 |
238 | gens = gens[index]
239 |
240 | if dim0:
241 | if finite:
242 | # In a Vietoris--Rips complex, all vertices are created at
243 | # time zero.
244 | creators = torch.zeros_like(
245 | torch.as_tensor(gens)[:, 0],
246 | device=dist.device
247 | )
248 |
249 | destroyers = dist[gens[:, 1], gens[:, 2]]
250 | else:
251 | creators = torch.zeros_like(
252 | torch.as_tensor(gens)[:],
253 | device=dist.device
254 | )
255 |
256 | destroyers = torch.full_like(
257 | torch.as_tensor(gens)[:],
258 | torch.inf,
259 | dtype=torch.float,
260 | device=dist.device
261 | )
262 |
263 | inf_pairs = numpy.full(
264 | shape=(gens.shape[0], 2), fill_value=-1
265 | )
266 | gens = numpy.column_stack((gens, inf_pairs))
267 |
268 | persistence_diagram = torch.stack(
269 | (creators, destroyers), 1
270 | )
271 |
272 | return [PersistenceInformation(gens, persistence_diagram, 0)]
273 | else:
274 | result = []
275 |
276 | for index, gens_ in enumerate(gens):
277 | # Dimension zero is handled differently, so we need to
278 | # use an offset here. Note that this is not used as an
279 | # index into the `gens` array any more.
280 | dimension = index + 1
281 |
282 | if finite:
283 | creators = dist[gens_[:, 0], gens_[:, 1]]
284 | destroyers = dist[gens_[:, 2], gens_[:, 3]]
285 |
286 | persistence_diagram = torch.stack(
287 | (creators, destroyers), 1
288 | )
289 | else:
290 | creators = dist[gens_[:, 0], gens_[:, 1]]
291 |
292 | destroyers = torch.full_like(
293 | torch.as_tensor(gens_)[:, 0],
294 | torch.inf,
295 | dtype=torch.float,
296 | device=dist.device
297 | )
298 |
299 | # Create special infinite pairs; we pretend that we
300 | # are concatenating with unknown edges here.
301 | inf_pairs = numpy.full(
302 | shape=(gens_.shape[0], 2), fill_value=-1
303 | )
304 | gens_ = numpy.column_stack((gens_, inf_pairs))
305 |
306 | persistence_diagram = torch.stack(
307 | (creators, destroyers), 1
308 | )
309 |
310 | result.append(
311 | PersistenceInformation(
312 | gens_,
313 | persistence_diagram,
314 | dimension)
315 | )
316 |
317 | return result
318 |
319 | def _concatenate_features(self, pers_info_finite, pers_info_infinite):
320 | """Concatenate finite and infinite features."""
321 | def _apply(fin, inf):
322 | assert fin.dimension == inf.dimension
323 |
324 | diagram = torch.concat((fin.diagram, inf.diagram))
325 | pairing = numpy.concatenate((fin.pairing, inf.pairing), axis=0)
326 | dimension = fin.dimension
327 |
328 | return PersistenceInformation(
329 | pairing=pairing,
330 | diagram=diagram,
331 | dimension=dimension
332 | )
333 |
334 | return list(starmap(_apply, zip(pers_info_finite, pers_info_infinite)))
335 |
--------------------------------------------------------------------------------
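
A brief sketch of the input conventions documented in `forward` above,
using random data of illustrative shapes:

    import torch

    from torch_topological.nn import VietorisRipsComplex

    vr = VietorisRipsComplex(dim=1)

    # Single point cloud of shape `(n, d)`: yields a list of
    # `PersistenceInformation` objects, one per dimension.
    x = torch.rand(100, 3)
    pers_info = vr(x)
    print([(pi.dimension, pi.diagram.shape) for pi in pers_info])

    # Batch of shape `(b, n, d)`: yields a list of lists.
    batch = torch.rand(8, 100, 3)
    pers_info_batch = vr(batch)

    # Pre-computed distances require a square matrix.
    distances = torch.cdist(x, x)
    pers_info_dist = vr(distances, treat_as_distances=True)
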
/torch_topological/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utilities module."""
2 |
3 | from .general import is_iterable
4 | from .general import nesting_level
5 | from .general import pairwise
6 | from .general import wrap_if_not_iterable
7 |
8 | from .filters import SelectByDimension
9 |
10 | from .summary_statistics import total_persistence
11 | from .summary_statistics import persistent_entropy
12 | from .summary_statistics import polynomial_function
13 |
14 | __all__ = [
15 | "is_iterable",
16 | "nesting_level",
17 | "pairwise",
18 | "persistent_entropy",
19 | "polynomial_function",
20 | "total_persistence",
21 | "wrap_if_not_iterable",
22 | "SelectByDimension",
23 | ]
24 |
--------------------------------------------------------------------------------
/torch_topological/utils/filters.py:
--------------------------------------------------------------------------------
1 | """Filter functions for algorithmic outputs."""
2 |
3 | import torch
4 |
5 |
6 | class SelectByDimension(torch.nn.Module):
7 | """Select persistence diagrams by dimension.
8 |
9 | This is a simple selector that enables filtering the outputs of
10 | persistent homology calculation algorithms by dimension: often,
11 | one is really only interested in a *specific* dimension but the
12 | corresponding algorithm yields more diagrams. To this end, this
13 | module can be applied as a lightweight filter.
14 | """
15 |
16 | def __init__(self, min_dim, max_dim=None):
17 | """Prepare filter for subsequent usage.
18 |
19 | Provides the filter with the required parameters. A minimum
20 | dimension must be provided. There is also the option to use
21 | a maximum dimension, thus permitting filtering by ranges.
22 |
23 | Parameters
24 | ----------
25 | min_dim : int
26 |             Minimum dimension to allow through the filter. If this is
27 |             the sole provided parameter, only diagrams of exactly this
28 |             dimension will be selected.
29 |
30 | max_dim : int
31 |             Optional upper dimension. If set, the selection returns all
32 |             diagrams whose dimension lies within `[min_dim, max_dim]`.
33 | """
34 | super().__init__()
35 |
36 | self.min_dim = min_dim
37 | self.max_dim = max_dim
38 |
39 | def forward(self, X):
40 | """Apply selection parameters to input.
41 |
42 | Iterate over input and select diagrams according to the
43 | pre-defined parameters.
44 |
45 | Parameters
46 | ----------
47 | X : iterable of `PersistenceInformation`
48 | An iterable containing `PersistenceInformation` objects at
49 | its lowest level.
50 |
51 | Returns
52 | -------
53 | iterable
54 | Input, i.e. `X`, but with all non-matching persistence
55 | diagrams removed.
56 | """
57 | return [
58 | pers_info for pers_info in X if self._is_valid(pers_info)
59 | ]
60 |
61 | def _is_valid(self, pers_info):
62 | if self.max_dim is not None:
63 | return self.min_dim <= pers_info.dimension <= self.max_dim
64 | else:
65 | return self.min_dim == pers_info.dimension
66 |
--------------------------------------------------------------------------------
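
A small sketch of the selector, assuming diagrams produced by
`VietorisRipsComplex` as above:

    import torch

    from torch_topological.nn import VietorisRipsComplex
    from torch_topological.utils import SelectByDimension

    vr = VietorisRipsComplex(dim=2)
    pers_info = vr(torch.rand(64, 3))

    # Keep only the one-dimensional diagrams.
    select = SelectByDimension(1)
    assert all(pi.dimension == 1 for pi in select(pers_info))

    # Alternatively, keep a whole range of dimensions.
    select_range = SelectByDimension(0, max_dim=1)
    assert len(select_range(pers_info)) == 2
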
/torch_topological/utils/general.py:
--------------------------------------------------------------------------------
1 | """General utility functions.
2 |
3 | This module contains generic utility functions that are used throughout
4 | the code base. If a function is 'relatively small' and somewhat generic,
5 | it should be put here.
6 | """
7 |
8 | import itertools
9 |
10 |
11 | def pairwise(iterable):
12 | """Return iterator to iterate over consecutive pairs.
13 |
14 | Parameters
15 | ----------
16 | iterable : iterable
17 |
18 | Returns
19 | -------
20 | iterator
21 | An iterator to iterate over consecutive pairs of the input data.
22 |
23 | Notes
24 | -----
25 |     A similar function, ``itertools.pairwise``, is available in
26 |     Python 3.10 and later. For compatibility with older Python
27 |     versions, we provide our own implementation.
28 |
29 | Examples
30 | --------
31 |     >>> [x for x in pairwise("ABC")]
32 |     [('A', 'B'), ('B', 'C')]
33 | """
34 | a, b = itertools.tee(iterable)
35 | next(b, None)
36 | return zip(a, b)
37 |
38 |
39 | def is_iterable(x):
40 | """Check whether variable is iterable.
41 |
42 | Parameters
43 | ----------
44 | x : any
45 | Input object.
46 |
47 | Returns
48 | -------
49 | bool
50 | `True` if `x` is iterable.
51 | """
52 | result = True
53 |
54 | # This is the most generic way; it also permits objects that only
55 | # implement the `__getitem__` interface.
56 | try:
57 | iter(x)
58 | except TypeError:
59 | result = False
60 |
61 | return result
62 |
63 |
64 | def wrap_if_not_iterable(x):
65 | """Wrap variable in case it cannot be iterated over.
66 |
67 | This function provides a convenience wrapper for variables that need
68 | to be iterated over. If the variable is already an `iterable`, there
69 |     is nothing to be done and it will be returned as-is. Otherwise, the
70 |     variable is 'wrapped' as the single item of a list.
71 |
72 | The primary purpose of this function is to make it easier for users
73 | to interact with certain classes: essentially, one does not have to
74 | think any more about single inputs vs. `iterable` inputs.
75 |
76 | Parameters
77 | ----------
78 | x : any
79 | Input object.
80 |
81 | Returns
82 | -------
83 | list or type of x
84 | If `x` can be iterated over, `x` will be returned as-is. Else,
85 | will return `[x]`, i.e. a list containing `x`.
86 |
87 | Examples
88 | --------
89 | >>> wrap_if_not_iterable(1.0)
90 | [1.0]
91 | >>> wrap_if_not_iterable('Hello, World!')
92 | 'Hello, World!'
93 | """
94 | if is_iterable(x):
95 | return x
96 | else:
97 | return [x]
98 |
99 |
100 | def nesting_level(x):
101 | """Calculate nesting level of a list of objects.
102 |
103 | To convert between sparse and dense representations of topological
104 | features, we need to determine the nesting level of an input list.
105 | The nesting level is defined as the maximum number of times we can
106 | recurse into the object while still obtaining lists.
107 |
108 | Parameters
109 | ----------
110 | x : list
111 | Input list of objects.
112 |
113 | Returns
114 | -------
115 | int
116 | Nesting level of `x`. If `x` has no well-defined nesting level,
117 |         for example because `x` is not a list at all, the function
118 |         returns `0`.
119 |
120 | Notes
121 | -----
122 | This function is implemented recursively. It is therefore a bad idea
123 | to apply it to objects with an extremely high nesting level.
124 |
125 | Examples
126 | --------
127 | >>> nesting_level([1, 2, 3])
128 | 1
129 |
130 | >>> nesting_level([[1, 2], [3, 4]])
131 | 2
132 | """
133 | # This is really only supposed to work with lists. Anything fancier,
134 | # for example a `torch.tensor`, can already be used as a dense data
135 | # structure.
136 | if not isinstance(x, list):
137 | return 0
138 |
139 | # Empty lists have a nesting level of 1.
140 | if len(x) == 0:
141 | return 1
142 | else:
143 | return max(nesting_level(y) for y in x) + 1
144 |
--------------------------------------------------------------------------------
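
A quick sketch of how these helpers behave; `wrap_if_not_iterable` in
particular is what lets modules such as `SlicedWassersteinKernel` accept
both a single `PersistenceInformation` object and a list of them:

    from torch_topological.utils import is_iterable
    from torch_topological.utils import nesting_level
    from torch_topological.utils import pairwise
    from torch_topological.utils import wrap_if_not_iterable

    assert is_iterable([1, 2, 3])
    assert not is_iterable(42)

    assert list(pairwise("ABC")) == [("A", "B"), ("B", "C")]

    assert nesting_level([1, 2, 3]) == 1
    assert nesting_level([[1, 2], [3, 4]]) == 2

    # Single values are wrapped; iterables pass through unchanged.
    assert wrap_if_not_iterable(1.0) == [1.0]
    assert wrap_if_not_iterable("Hello, World!") == "Hello, World!"
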
/torch_topological/utils/summary_statistics.py:
--------------------------------------------------------------------------------
1 | """Summary statistics for persistence diagrams."""
2 |
3 | import torch
4 |
5 |
6 | def persistent_entropy(D, **kwargs):
7 | """Calculate persistent entropy of a persistence diagram.
8 |
9 | Parameters
10 | ----------
11 | D : `torch.tensor`
12 | Persistence diagram, assumed to be in shape `(n, 2)`, where each
13 | entry corresponds to a tuple of the form :math:`(x, y)`, with
14 | :math:`x` denoting the creation of a topological feature and
15 | :math:`y` denoting its destruction.
16 |
17 | Returns
18 | -------
19 |     Persistent entropy of `D` as a scalar tensor.
20 | """
21 | persistence = torch.diff(D)
22 | persistence = persistence[torch.isfinite(persistence)].abs()
23 |
24 | P = persistence.sum()
25 | probabilities = persistence / P
26 |
27 | # Ensures that a probability of zero will just result in
28 | # a logarithm of zero as well. This is required whenever
29 | # one deals with entropy calculations.
30 | indices = probabilities > 0
31 | log_prob = torch.zeros_like(probabilities)
32 | log_prob[indices] = torch.log2(probabilities[indices])
33 |
34 | return torch.sum(-probabilities * log_prob)
35 |
36 |
37 | def polynomial_function(D, p, q, **kwargs):
38 | r"""Parametrise polynomial function over persistence diagrams.
39 |
40 | This function follows an approach by Adcock et al. [Adcock16a]_ and
41 | parametrises a polynomial function over a persistence diagram.
42 |
43 | Parameters
44 | ----------
45 | D : `torch.tensor`
46 | Persistence diagram, assumed to be in shape `(n, 2)`, where each
47 | entry corresponds to a tuple of the form :math:`(x, y)`, with
48 | :math:`x` denoting the creation of a topological feature and
49 | :math:`y` denoting its destruction.
50 |
51 | p : float
52 | Exponent for persistence differences in the diagram.
53 |
54 | q : float
55 | Exponent for mean persistence in the diagram.
56 |
57 | Returns
58 | -------
59 |     Sum of the form :math:`\sum_i L_i^p \mu_i^q`, with :math:`L_i`
60 |     denoting the persistence of the :math:`i`-th feature, and
61 |     :math:`\mu_i` denoting the mean of its creation and destruction.
62 |
63 | References
64 | ----------
65 | .. [Adcock16a] A. Adcock et al., "The Ring of Algebraic Functions on
66 | Persistence Bar Codes", *Homology, Homotopy and Applications*,
67 | Volume 18, Issue 1, pp. 381--402, 2016.
68 | """
69 | lengths = torch.diff(D)
70 | means = torch.sum(D, dim=-1, keepdim=True) / 2
71 |
72 | # Filter out non-finite values; the same mask works here because the
73 | # mean is non-finite if and only if the persistence is.
74 | mask = torch.isfinite(lengths)
75 | lengths = lengths[mask]
76 | means = means[mask]
77 |
78 | return torch.sum(torch.mul(lengths.pow(p), means.pow(q)))
79 |
80 |
81 | def total_persistence(D, p=2, **kwargs):
82 | """Calculate total persistence of a persistence diagram.
83 |
84 | This function will calculate the total persistence of a persistence
85 | diagram. Infinite values will be ignored.
86 |
87 | Parameters
88 | ----------
89 | D : `torch.tensor`
90 | Persistence diagram, assumed to be in shape `(n, 2)`, where each
91 | entry corresponds to a tuple of the form :math:`(x, y)`, with
92 | :math:`x` denoting the creation of a topological feature and
93 | :math:`y` denoting its destruction.
94 |
95 | p : float
96 | Weight parameter for the total persistence calculation.
97 |
98 | Returns
99 | -------
100 | float
101 | Total persistence of `D`.
102 | """
103 | persistence = torch.diff(D)
104 | persistence = persistence[torch.isfinite(persistence)]
105 |
106 | return persistence.abs().pow(p).sum()
107 |
108 |
109 | def p_norm(D, p=2, **kwargs):
110 | """Calculate :math:`p`-norm of a persistence diagram.
111 |
112 | This function will calculate the :math:`p`-norm of a persistence
113 |     diagram. Infinite values will be ignored.
114 |
115 | Parameters
116 | ----------
117 | D : `torch.tensor`
118 | Persistence diagram, assumed to be in shape `(n, 2)`, where each
119 | entry corresponds to a tuple of the form :math:`(x, y)`, with
120 | :math:`x` denoting the creation of a topological feature and
121 | :math:`y` denoting its destruction.
122 |
123 | p : float
124 | Weight parameter for the norm calculation. It must be valid to
125 | raise a term to the :math:`p`th power, but the function has no
126 | additional checks for this.
127 |
128 | Returns
129 | -------
130 | float
131 | :math:`p`-norm of `D`.
132 | """
133 |     return torch.pow(total_persistence(D, p=p), 1.0 / p)
134 |
--------------------------------------------------------------------------------
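
A final sketch evaluating the summary statistics on a hand-crafted toy
diagram. Persistent entropy follows the code above, i.e. it computes
:math:`-\sum_i p_i \log_2 p_i`, with :math:`p_i` denoting the persistence
of the :math:`i`-th feature normalised by the total persistence; infinite
features are masked out by all of the functions. Note that `p_norm` is not
re-exported by `torch_topological.utils`, so it is imported from its
submodule here:

    import torch

    from torch_topological.utils import persistent_entropy
    from torch_topological.utils import polynomial_function
    from torch_topological.utils import total_persistence
    from torch_topological.utils.summary_statistics import p_norm

    D = torch.tensor([
        [0.0, 1.0],
        [0.0, 2.0],
        [0.5, torch.inf],   # Ignored by all summary statistics.
    ])

    print(total_persistence(D, p=2))         # 1**2 + 2**2 = 5
    print(p_norm(D, p=2))                    # 5**0.5
    print(persistent_entropy(D))             # Entropy of [1/3, 2/3]
    print(polynomial_function(D, p=1, q=1))  # 1*0.5 + 2*1.0 = 2.5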