├── docs ├── .nojekyll ├── objects.inv ├── ._index.html ├── _static │ ├── file.png │ ├── plus.png │ ├── minus.png │ ├── fonts │ │ ├── Lato-Bold.ttf │ │ ├── Inconsolata.ttf │ │ ├── Lato-Regular.ttf │ │ ├── Lato │ │ │ ├── lato-bold.eot │ │ │ ├── lato-bold.ttf │ │ │ ├── lato-bold.woff │ │ │ ├── lato-bold.woff2 │ │ │ ├── lato-italic.eot │ │ │ ├── lato-italic.ttf │ │ │ ├── lato-italic.woff │ │ │ ├── lato-italic.woff2 │ │ │ ├── lato-regular.eot │ │ │ ├── lato-regular.ttf │ │ │ ├── lato-regular.woff │ │ │ ├── lato-bolditalic.eot │ │ │ ├── lato-bolditalic.ttf │ │ │ ├── lato-regular.woff2 │ │ │ ├── lato-bolditalic.woff │ │ │ └── lato-bolditalic.woff2 │ │ ├── RobotoSlab-Bold.ttf │ │ ├── Inconsolata-Bold.ttf │ │ ├── RobotoSlab-Regular.ttf │ │ ├── Inconsolata-Regular.ttf │ │ ├── fontawesome-webfont.eot │ │ ├── fontawesome-webfont.ttf │ │ ├── fontawesome-webfont.woff │ │ ├── fontawesome-webfont.woff2 │ │ └── RobotoSlab │ │ │ ├── roboto-slab-v7-bold.eot │ │ │ ├── roboto-slab-v7-bold.ttf │ │ │ ├── roboto-slab-v7-bold.woff │ │ │ ├── roboto-slab-v7-bold.woff2 │ │ │ ├── roboto-slab-v7-regular.eot │ │ │ ├── roboto-slab-v7-regular.ttf │ │ │ ├── roboto-slab-v7-regular.woff │ │ │ └── roboto-slab-v7-regular.woff2 │ ├── documentation_options.js │ ├── css │ │ └── badge_only.css │ ├── js │ │ └── theme.js │ ├── pygments.css │ └── doctools.js ├── _sources │ ├── nn.rst.txt │ ├── pershom.rst.txt │ ├── install │ │ └── index.rst.txt │ ├── index.rst.txt │ └── tutorials │ │ └── SLayer.ipynb.txt ├── _images │ ├── tutorials_InputOptim_4_0.png │ ├── tutorials_InputOptim_6_1.png │ ├── tutorials_InputOptim_8_1.png │ ├── tutorials_ToyDiffVR_10_0.png │ ├── tutorials_ToyDiffVR_15_0.png │ ├── tutorials_ToyDiffVR_6_0.png │ ├── tutorials_InputOptim_10_1.png │ └── tutorials_ComparisonSOTA_7_0.png ├── .buildinfo ├── _modules │ └── index.html ├── search.html ├── py-modindex.html ├── tutorials │ └── SLayer.ipynb ├── searchindex.js └── install │ └── index.html ├── docs_src ├── .nojekyll ├── source │ ├── tutorials 
│ │ ├── .gitignore │ │ ├── shared_code.py │ │ └── SLayer.ipynb │ ├── nn.rst │ ├── _static │ │ └── .gitignore │ ├── pershom.rst │ ├── install │ │ └── index.rst │ ├── index.rst │ └── conf.py └── Makefile ├── tests ├── .gitignore ├── __init__.py └── pershom │ ├── __init__.py │ ├── test_pershom_backend_data │ ├── random_simplicial_complexes │ │ ├── random_sp__args__100_200_300.pickle │ │ ├── random_sp__args__100_200_300_400.pickle │ │ └── random_sp__args__100_100_100_100_100.pickle │ └── random_point_clouds │ │ ├── 09.txt │ │ ├── 06.txt │ │ ├── 00.txt │ │ ├── 03.txt │ │ ├── 04.txt │ │ ├── 07.txt │ │ ├── 01.txt │ │ ├── 02.txt │ │ ├── 05.txt │ │ └── 08.txt │ └── test_pershom_backend.py ├── torchph ├── __init__.py ├── pershom │ ├── __init__.py │ ├── pershom_cpp_src │ │ ├── tensor_utils.cuh │ │ ├── cuda_checks.cuh │ │ ├── makefile │ │ ├── vertex_filtration_comp_cuda.h │ │ ├── calc_pers_cuda.cuh │ │ ├── param_checks_cuda.cuh │ │ ├── tensor_utils.cu │ │ ├── vr_comp_cuda.cuh │ │ ├── py_bindings.cpp │ │ └── vertex_filtration_comp_cuda.cpp │ └── pershom_backend.py └── nn │ └── __init__.py ├── pershom_dev ├── profiling_pershom │ ├── data │ │ ├── max_dimension.txt │ │ ├── boundary_array_size.txt │ │ └── buggy_case_1 │ │ │ ├── max_dimension.txt │ │ │ ├── boundary_array_size.txt │ │ │ ├── column_dimension.txt │ │ │ └── boundary_array.txt │ ├── .gitignore │ ├── profile.sh │ ├── makefile │ └── profile_case.cpp ├── bug_1_sp.pickle ├── high_level_profile.cProfile ├── setup.py ├── calc_pers.py ├── refactor_vc_comp.py ├── generate_test_data.py ├── generate_profile_data.py ├── bug_1.py ├── vc_comp_script.py ├── simplicial_complex.py ├── pershom_script.py └── cpu_sorted_boundary_array_implementation.py ├── requirements.txt ├── run_test.sh ├── .gitignore ├── update_docs.sh ├── setup.py ├── technical_readme.md ├── LICENSE └── README.md /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /docs_src/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchph/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/pershom/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pershom_dev/profiling_pershom/data/max_dimension.txt: -------------------------------------------------------------------------------- 1 | 2 -------------------------------------------------------------------------------- /docs_src/source/tutorials/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /torchph/pershom/__init__.py: -------------------------------------------------------------------------------- 1 | from .pershom_backend import * 2 | -------------------------------------------------------------------------------- /pershom_dev/profiling_pershom/data/boundary_array_size.txt: -------------------------------------------------------------------------------- 1 | 17805 6 -------------------------------------------------------------------------------- 
/pershom_dev/profiling_pershom/data/buggy_case_1/max_dimension.txt: -------------------------------------------------------------------------------- 1 | 2 -------------------------------------------------------------------------------- /pershom_dev/profiling_pershom/.gitignore: -------------------------------------------------------------------------------- 1 | *.google-pprof 2 | profile 3 | -------------------------------------------------------------------------------- /pershom_dev/profiling_pershom/data/buggy_case_1/boundary_array_size.txt: -------------------------------------------------------------------------------- 1 | 93 6 -------------------------------------------------------------------------------- /docs/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/objects.inv -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | docs: 2 | pip install sphinx 3 | pip install sphinx_rtd_theme -------------------------------------------------------------------------------- /docs/._index.html: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/._index.html -------------------------------------------------------------------------------- /docs/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/file.png -------------------------------------------------------------------------------- /docs/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/plus.png 
-------------------------------------------------------------------------------- /docs/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/minus.png -------------------------------------------------------------------------------- /docs/_sources/nn.rst.txt: -------------------------------------------------------------------------------- 1 | ``nn`` 2 | ====== 3 | 4 | .. automodule:: torchph.nn.slayer 5 | :members: -------------------------------------------------------------------------------- /docs_src/source/nn.rst: -------------------------------------------------------------------------------- 1 | ``nn`` 2 | ====== 3 | 4 | .. automodule:: torchph.nn.slayer 5 | :members: -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | find tests/ | grep -E "(__pycache__|\.pyc|\.pyo$)" | xargs rm -rf 2 | pytest tests/ 3 | -------------------------------------------------------------------------------- /torchph/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .slayer import SLayerExponential, SLayerRational, SLayerRationalHat 2 | -------------------------------------------------------------------------------- /pershom_dev/bug_1_sp.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/pershom_dev/bug_1_sp.pickle -------------------------------------------------------------------------------- /docs/_static/fonts/Lato-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato-Bold.ttf -------------------------------------------------------------------------------- 
/docs/_static/fonts/Inconsolata.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Inconsolata.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato-Regular.ttf -------------------------------------------------------------------------------- /docs_src/source/_static/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-bold.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab-Bold.ttf 
-------------------------------------------------------------------------------- /docs/_images/tutorials_InputOptim_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_images/tutorials_InputOptim_4_0.png -------------------------------------------------------------------------------- /docs/_images/tutorials_InputOptim_6_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_images/tutorials_InputOptim_6_1.png -------------------------------------------------------------------------------- /docs/_images/tutorials_InputOptim_8_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_images/tutorials_InputOptim_8_1.png -------------------------------------------------------------------------------- /docs/_images/tutorials_ToyDiffVR_10_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_images/tutorials_ToyDiffVR_10_0.png -------------------------------------------------------------------------------- /docs/_images/tutorials_ToyDiffVR_15_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_images/tutorials_ToyDiffVR_15_0.png -------------------------------------------------------------------------------- /docs/_images/tutorials_ToyDiffVR_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_images/tutorials_ToyDiffVR_6_0.png -------------------------------------------------------------------------------- /docs/_static/fonts/Inconsolata-Bold.ttf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Inconsolata-Bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-regular.eot 
-------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab-Regular.ttf -------------------------------------------------------------------------------- /docs_src/source/pershom.rst: -------------------------------------------------------------------------------- 1 | ``pershom`` 2 | =========== 3 | 4 | .. 
automodule:: torchph.pershom.pershom_backend 5 | :members: 6 | -------------------------------------------------------------------------------- /pershom_dev/high_level_profile.cProfile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/pershom_dev/high_level_profile.cProfile -------------------------------------------------------------------------------- /docs/_images/tutorials_InputOptim_10_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_images/tutorials_InputOptim_10_1.png -------------------------------------------------------------------------------- /docs/_sources/pershom.rst.txt: -------------------------------------------------------------------------------- 1 | ``pershom`` 2 | =========== 3 | 4 | .. automodule:: torchph.pershom.pershom_backend 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/_static/fonts/Inconsolata-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Inconsolata-Regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-bolditalic.ttf -------------------------------------------------------------------------------- 
/docs/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_images/tutorials_ComparisonSOTA_7_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_images/tutorials_ComparisonSOTA_7_0.png -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-bolditalic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.woff2: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/Lato/lato-bolditalic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | develop.py 2 | develop_2.py 3 | .idea 4 | .vscode 5 | __pycache__ 6 | *.pyc 7 | .cache 8 | .pytest_cache 9 | pershom_dev/extensions_sandbox 10 | -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /update_docs.sh: -------------------------------------------------------------------------------- 1 | rm -r docs/* 2 | 3 | cd docs_src 4 | 5 | sphinx-build -b html source ../docs 6 | 7 | cd ../docs 8 | 9 | rm -r .doctrees 10 | 11 | touch .nojekyll -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 
4 | setup( 5 | name="torchph", 6 | version="0.0.0", 7 | packages=setuptools.find_packages(exclude=('tests*',)) 8 | ) 9 | -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/tensor_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | #include 5 | 6 | 7 | using namespace torch; 8 | 9 | 10 | namespace TensorUtils{ 11 | void fill_range_cuda_(Tensor t); 12 | } -------------------------------------------------------------------------------- /pershom_dev/profiling_pershom/data/buggy_case_1/column_dimension.txt: -------------------------------------------------------------------------------- 1 | 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_simplicial_complexes/random_sp__args__100_200_300.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/tests/pershom/test_pershom_backend_data/random_simplicial_complexes/random_sp__args__100_200_300.pickle -------------------------------------------------------------------------------- /docs/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 
3 | config: 54af9f4344c5084308a0935f67140759 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_simplicial_complexes/random_sp__args__100_200_300_400.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/tests/pershom/test_pershom_backend_data/random_simplicial_complexes/random_sp__args__100_200_300_400.pickle -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_simplicial_complexes/random_sp__args__100_100_100_100_100.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/torchph/HEAD/tests/pershom/test_pershom_backend_data/random_simplicial_complexes/random_sp__args__100_100_100_100_100.pickle -------------------------------------------------------------------------------- /technical_readme.md: -------------------------------------------------------------------------------- 1 | # Update docs 2 | 3 | ```bash 4 | git clone https://github.com/c-hofer/torchph.git 5 | git pull 6 | pip install sphinx 7 | pip install nbsphinx 8 | pip install sphinx_rtd_theme 9 | cd torchph 10 | bash update_docs.sh 11 | ``` 12 | 13 | Then open `docs/index.html`. 14 | -------------------------------------------------------------------------------- /pershom_dev/profiling_pershom/profile.sh: -------------------------------------------------------------------------------- 1 | export LD_LIBRARY_PATH="/usr/local/lib:/scratch2/chofer/opt/anaconda3/envs/pyt_gh/lib/:/scratch2/chofer/opt/anaconda3/envs/pyt_gh/lib/python3.6/site-packages/torch/lib" 2 | 3 | rm -f profile 4 | 5 | echo "Compiling ..." 6 | 7 | make 8 | 9 | make clean 10 | 11 | 12 | echo "Compiled! Now executing ... 
" 13 | ./profile 14 | 15 | 16 | -------------------------------------------------------------------------------- /docs/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '0.0.0', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | HAS_SOURCE: true, 9 | SOURCELINK_SUFFIX: '.txt', 10 | NAVIGATION_WITH_KEYS: false 11 | }; -------------------------------------------------------------------------------- /pershom_dev/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='pershom_backend_C', 6 | ext_modules=[ 7 | CUDAExtension('pershom_backend_C_cuda', [ 8 | 'pershom_cpp_src/pershom_cuda.cu', 9 | 'pershom_cpp_src/pershom.cpp', 10 | ]) 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/cuda_checks.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | #include 5 | 6 | 7 | #define cudaCheckError() { \ 8 | cudaError_t e=cudaGetLastError(); \ 9 | if(e!=cudaSuccess) { \ 10 | printf("Cuda failure %s:%d: '%s'\n",__FILE__,__LINE__,cudaGetErrorString(e)); \ 11 | exit(0); \ 12 | } \ 13 | } 14 | -------------------------------------------------------------------------------- /pershom_dev/calc_pers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchph.pershom.pershom_backend as pershom_backend 3 | 4 | device = torch.device('cuda') 5 | dtype = torch.int64 6 | 7 | ba = torch.empty([0, 2], device=device, dtype=dtype) 8 | 
ba_row_i_to_bm_col_i = ba 9 | simplex_dimension = torch.zeros(10, device=device, dtype=dtype) 10 | print(simplex_dimension) 11 | 12 | # ret = pershom_backend.__C.CalcPersCuda__calculate_persistence(ba, ba_row_i_to_bm_col_i, simplex_dimension, 3, -1) 13 | ret = pershom_backend.__C.CalcPersCuda__my_test_f(ba); 14 | 15 | print(ret) -------------------------------------------------------------------------------- /docs_src/source/tutorials/shared_code.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | 5 | def check_torchph_availability(): 6 | try: 7 | import torchph 8 | 9 | except ImportError: 10 | sys.path.append(os.path.dirname(os.getcwd())) 11 | 12 | try: 13 | import torchph 14 | 15 | except ImportError as ex: 16 | raise ImportError( 17 | """ 18 | Could not import torchph. Running your python \ 19 | interpreter in the 'tutorials' sub folder could resolve \ 20 | this issue. 21 | """ 22 | ) from ex 23 | -------------------------------------------------------------------------------- /docs_src/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/makefile: -------------------------------------------------------------------------------- 1 | CC=g++ 2 | 3 | LIB=-L/scratch2/chofer/opt/anaconda3/envs/pyt_gh/lib/python3.6/site-packages/torch/lib #-lATen 4 | INC=-I/scratch2/chofer/opt/anaconda3/envs/pyt_gh/lib/python3.6/site-packages/torch/lib/include \ 5 | -I /scratch2/chofer/opt/anaconda3/envs/pyt_gh/include/python3.6m 6 | FLAGS= -Wall -std=c++11 7 | 8 | 9 | lib: pershom.o pershom_cuda.o 10 | $(CC) -shared $(FLAGS) -o libpershom.so pershom.o pershom_cuda.o $(INC) $(LIB) 11 | 12 | pershom.o: pershom.cpp 13 | $(CC) -fPIC $(FLAGS) -c pershom.cpp $(INC) $(LIB) 14 | 15 | pershom_cuda.o: pershom_cuda.cu 16 | nvcc -shared -Xcompiler -fPIC -std=c++11 -c pershom_cuda.cu $(INC) 17 | 18 | clean: 19 | rm -r *.o 20 | 21 | -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/vertex_filtration_comp_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | #include 5 | 6 | 7 | using namespace torch; 8 | 9 | 10 | namespace VertFiltCompCuda 11 | { 12 | std::vector vert_filt_comp_calculate_persistence_args( 13 | const Tensor & vertex_filtration, 14 | const std::vector & boundary_info); 15 | 16 | std::vector> vert_filt_persistence_single( 17 | const Tensor & vertex_filtration, 18 | const std::vector & boundary_info); 19 | 20 | std::vector>> vert_filt_persistence_batch( 21 | const std::vector>> & batch 22 | ); 23 | } -------------------------------------------------------------------------------- /pershom_dev/refactor_vc_comp.py: -------------------------------------------------------------------------------- 1 | import torchph.pershom.pershom_backend as pershom_backend 2 | import torch 3 | import time 4 | from 
scipy.special import binom 5 | from itertools import combinations 6 | 7 | 8 | from collections import Counter 9 | 10 | 11 | point_cloud = [(0, 0), (1, 0), (0, 0.5), (1, 1.5)] 12 | point_cloud = torch.tensor(point_cloud, device='cuda', dtype=torch.float, requires_grad=True) 13 | 14 | def l1_norm(x, y): 15 | return float((x-y).abs().sum()) 16 | 17 | testee = pershom_backend.__C.VRCompCuda__PointCloud2VR_factory("l1") 18 | 19 | args = testee(point_cloud, 2, 2) 20 | print(args[3]) 21 | args[3].sum().backward() 22 | print(point_cloud.grad) 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | # print(c(torch.rand((10, 3), device='cuda'), 2, 0)) 40 | # print(c.filtration_values_by_dim) -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/calc_pers_cuda.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | #include 5 | 6 | 7 | using namespace torch; 8 | 9 | namespace CalcPersCuda 10 | { 11 | 12 | Tensor find_merge_pairings( 13 | const Tensor & pivots, 14 | int64_t max_pairs = -1); 15 | 16 | void merge_columns( 17 | const Tensor & compr_desc_sort_ba, 18 | const Tensor & merge_pairs); 19 | 20 | std::vector> read_barcodes( 21 | const Tensor & pivots, 22 | const Tensor & simplex_dimension, 23 | int64_t max_dim_to_read_of_reduced_ba); 24 | 25 | std::vector> calculate_persistence( 26 | const Tensor & compr_desc_sort_ba, 27 | const Tensor & ba_row_i_to_bm_col_i, 28 | const Tensor & simplex_dimension, 29 | int64_t max_dim_to_read_of_reduced_ba, 30 | int64_t max_pairs); 31 | 32 | } // namespace CalcPersCuda -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/param_checks_cuda.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | #include 5 | 6 | 7 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " 
must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_TENSOR_CUDA_CONTIGUOUS(x) \ 10 | CHECK_CUDA(x); \ 11 | CHECK_CONTIGUOUS(x) 12 | #define CHECK_TENSOR_INT64(x) AT_ASSERTM(x.type().scalarType() == ScalarType::Long, "expected " #x "to be of scalar type int64") 13 | 14 | #define CHECK_SMALLER_EQ(x, y) AT_ASSERTM(x <= y, "expected " #x "<=" #y) 15 | #define CHECK_EQUAL(x, y) AT_ASSERTM(x == y, "expected " #x "==" #y) 16 | #define CHECK_GREATER_EQ(x, y) AT_ASSERTM(x >= y, "expected " #x ">=" #y) 17 | 18 | #define CHECK_SAME_DEVICE(x, y) AT_ASSERTM(x.device() == y.device(), #x, #y "are not on same device") 19 | 20 | 21 | #define PRINT(x) std::cout << x << std::endl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Christoph Hofer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pershom_dev/profiling_pershom/makefile: -------------------------------------------------------------------------------- 1 | CC=g++ 2 | 3 | LIB=-L/usr/local/lib\ 4 | -L/usr/local/cuda-9.2/targets/x86_64-linux/lib \ 5 | -L/scratch2/chofer/opt/anaconda3/envs/pyt_gh/lib \ 6 | -L/scratch2/chofer/opt/anaconda3/envs/pyt_gh/lib/python3.6/site-packages/torch/lib \ 7 | -ldl -lprofiler\ 8 | -lcudart \ 9 | -liomp5 -lmklml_intel -lpython3.6m \ 10 | -lcaffe2_gpu -lcaffe2 -lshm -lnccl \ 11 | -L/scratch2/chofer/opt/anaconda3/envs/pyt_gh/lib/python3.6/site-packages/torch\ 12 | 13 | 14 | INC=-I/scratch2/chofer/opt/anaconda3/envs/pyt_gh/lib/python3.6/site-packages/torch/lib/include\ 15 | -I/scratch2/chofer/opt/anaconda3/envs/pyt_gh/include/python3.6m\ 16 | -I/usr/local/cuda-9.2/include 17 | FLAGS= -Wall -std=c++11 -DPROFILE 18 | 19 | 20 | profile: profile_case.o pershom.o pershom_cuda.o 21 | $(CC) $(FLAGS) -o profile profile_case.o pershom.o pershom_cuda.o $(INC) $(LIB) 22 | 23 | profile_case.o : 24 | $(CC) $(FLAGS) -c profile_case.cpp $(INC) $(LIB) 25 | 26 | pershom.o: ../../torchph/pershom/pershom_cpp_src/pershom.cpp 27 | $(CC) $(FLAGS) -c ../../torchph/pershom/pershom_cpp_src/pershom.cpp $(INC) $(LIB) 28 | 29 | pershom_cuda.o: ../../torchph/pershom/pershom_cpp_src/pershom_cuda.cu 30 | nvcc -std=c++11 -c ../../torchph/pershom/pershom_cpp_src/pershom_cuda.cu $(INC) $(LIB) 31 | 32 | clean: 33 | rm -r *.o -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/tensor_utils.cu: -------------------------------------------------------------------------------- 1 | 
#include 2 | #include 3 | 4 | #include "cuda_checks.cuh" 5 | #include "param_checks_cuda.cuh" 6 | 7 | 8 | using namespace at; 9 | 10 | 11 | namespace TensorUtils{ 12 | 13 | namespace { 14 | 15 | template 16 | __global__ void fill_range_kernel(scalar_t* out, int64_t out_numel){ 17 | auto index = blockIdx.x * blockDim.x + threadIdx.x; 18 | 19 | if (index < out_numel){ 20 | out[index] = index; 21 | } 22 | } 23 | } 24 | 25 | void fill_range_cuda_(Tensor t) 26 | { 27 | CHECK_CUDA(t); 28 | at::OptionalDeviceGuard guard(device_of(t)); 29 | 30 | const int threads_per_block = 256; 31 | const int blocks = t.numel()/threads_per_block + 1; 32 | 33 | auto scalar_type = t.scalar_type(); 34 | switch(scalar_type) 35 | { 36 | case ScalarType::Int: 37 | fill_range_kernel<<>>(t.data_ptr(), t.numel()); 38 | break; 39 | 40 | case ScalarType::Long: 41 | fill_range_kernel<<>>(t.data_ptr(), t.numel()); 42 | break; 43 | 44 | default: 45 | throw std::invalid_argument("Unrecognized Type!"); 46 | } 47 | 48 | cudaCheckError(); 49 | } 50 | 51 | } // namespace TensorUtils -------------------------------------------------------------------------------- /docs/_sources/install/index.rst.txt: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | The following setup was tested with the following system configuration: 5 | 6 | * Ubuntu 18.04.2 LTS 7 | * CUDA 10.1 (driver version 418.87.00) 8 | * Anaconda (Python 3.7.6) 9 | * PyTorch 1.4 10 | 11 | In the following, we assume that we work in ``/tmp`` (obviously, you have to 12 | change this to reflect your choice and using ``/tmp`` is, of course, not 13 | the best choice :). 14 | 15 | First, get the Anaconda installer and install Anaconda (in ``/tmp/anaconda3``) 16 | using 17 | 18 | .. 
code-block:: bash 19 | 20 | cd /tmp/ 21 | wget https://repo.anaconda.com/archive/Anaconda3-2019.10-Linux-x86_64.sh 22 | bash Anaconda3-2019.10-Linux-x86_64.sh 23 | # specify /tmp/anaconda3 as your installation path 24 | source /tmp/anaconda3/bin/activate 25 | 26 | Second, we install PyTorch (v1.4) using 27 | 28 | .. code-block:: bash 29 | 30 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 31 | 32 | 33 | Third, we clone the ``torchph`` repository from GitHub and make 34 | it available within Anaconda. 35 | 36 | .. code-block:: bash 37 | 38 | cd /tmp/ 39 | git clone https://github.com/c-hofer/torchph.git 40 | conda develop /tmp/torchph 41 | 42 | A quick check if everything works can be done with 43 | 44 | .. code-block:: python 45 | 46 | >>> import torchph 47 | 48 | .. note:: 49 | 50 | At the moment, we only have GPU support available. CPU support 51 | is not planned yet, as many other packages exist which support 52 | PH computation on the CPU. 53 | 54 | -------------------------------------------------------------------------------- /docs_src/source/install/index.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | The following setup was tested with the following system configuration: 5 | 6 | * Ubuntu 18.04.2 LTS 7 | * CUDA 10.1 (driver version 418.87.00) 8 | * Anaconda (Python 3.7.6) 9 | * PyTorch 1.4 10 | 11 | In the following, we assume that we work in ``/tmp`` (obviously, you have to 12 | change this to reflect your choice and using ``/tmp`` is, of course, not 13 | the best choice :). 14 | 15 | First, get the Anaconda installer and install Anaconda (in ``/tmp/anaconda3``) 16 | using 17 | 18 | .. 
code-block:: bash 19 | 20 | cd /tmp/ 21 | wget https://repo.anaconda.com/archive/Anaconda3-2019.10-Linux-x86_64.sh 22 | bash Anaconda3-2019.10-Linux-x86_64.sh 23 | # specify /tmp/anaconda3 as your installation path 24 | source /tmp/anaconda3/bin/activate 25 | 26 | Second, we install PyTorch (v1.4) using 27 | 28 | .. code-block:: bash 29 | 30 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 31 | 32 | 33 | Third, we clone the ``torchph`` repository from GitHub and make 34 | it available within Anaconda. 35 | 36 | .. code-block:: bash 37 | 38 | cd /tmp/ 39 | git clone https://github.com/c-hofer/torchph.git 40 | conda develop /tmp/torchph 41 | 42 | A quick check if everything works can be done with 43 | 44 | .. code-block:: python 45 | 46 | >>> import torchph 47 | 48 | .. note:: 49 | 50 | At the moment, we only have GPU support available. CPU support 51 | is not planned yet, as many other packages exist which support 52 | PH computation on the CPU. 53 | 54 | -------------------------------------------------------------------------------- /pershom_dev/profiling_pershom/data/buggy_case_1/boundary_array.txt: -------------------------------------------------------------------------------- 1 | -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 7 5 -1 -1 -1 -1 8 5 -1 -1 -1 -1 4 1 -1 -1 -1 -1 9 6 -1 -1 -1 -1 6 3 -1 -1 -1 -1 8 7 -1 -1 -1 -1 3 2 -1 -1 -1 -1 2 0 -1 -1 -1 -1 5 3 -1 -1 -1 -1 8 1 -1 -1 -1 -1 9 4 -1 -1 -1 -1 4 0 -1 -1 -1 -1 7 2 -1 -1 -1 -1 7 3 -1 -1 -1 -1 2 1 -1 -1 -1 -1 6 5 -1 -1 -1 -1 9 8 -1 -1 -1 -1 5 2 -1 -1 -1 -1 7 1 -1 -1 -1 -1 8 4 -1 -1 -1 -1 6 0 -1 -1 -1 -1 6 4 -1 -1 -1 -1 8 2 -1 -1 -1 -1 7 6 -1 -1 -1 -1 6 1 -1 -1 -1 -1 5 4 -1 -1 -1 -1 4 2 -1 -1 -1 -1 5 1 -1 -1 -1 -1 7 0 -1 -1 -1 -1 3 0 -1 -1 -1 -1 9 5 -1 -1 -1 -1 1 0 -1 -1 -1 -1 9 2 -1 -1 -1 -1 7 4 -1 -1 -1 -1 8 6 -1 -1 -1 -1 9 7 -1 -1 -1 -1 4 3 -1 -1 -1 -1 8 0 -1 
-1 -1 -1 3 1 -1 -1 -1 -1 9 1 -1 -1 -1 -1 5 0 -1 -1 -1 -1 9 3 -1 -1 -1 -1 6 2 -1 -1 -1 -1 44 31 29 -1 -1 -1 49 34 13 -1 -1 -1 48 24 16 -1 -1 -1 31 20 13 -1 -1 -1 41 24 17 -1 -1 -1 37 27 24 -1 -1 -1 47 38 15 -1 -1 -1 36 32 29 -1 -1 -1 34 33 28 -1 -1 -1 48 37 18 -1 -1 -1 41 21 12 -1 -1 -1 51 45 23 -1 -1 -1 27 22 10 -1 -1 -1 50 35 21 -1 -1 -1 43 35 10 -1 -1 -1 41 38 28 -1 -1 -1 45 26 15 -1 -1 -1 47 41 19 -1 -1 -1 36 21 17 -1 -1 -1 34 31 12 -1 -1 -1 15 11 10 -1 -1 -1 52 33 22 -1 -1 -1 39 30 14 -1 -1 -1 51 14 13 -1 -1 -1 38 22 17 -1 -1 -1 40 25 13 -1 -1 -1 52 42 13 -1 -1 -1 48 28 23 -1 -1 -1 42 40 27 -1 -1 -1 52 34 24 -1 -1 -1 47 29 21 -1 -1 -1 46 36 16 -1 -1 -1 52 27 25 -1 -1 -1 32 24 19 -1 -1 -1 36 35 27 -1 -1 -1 49 45 28 -1 -1 -1 39 17 16 -1 -1 -1 48 34 14 -1 -1 -1 49 40 37 -1 -1 -1 51 49 48 -1 -1 -1 -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_point_clouds/09.txt: -------------------------------------------------------------------------------- 1 | -6.354418992996215820e-01 2.851496124267578125e+01 8.732179641723632812e+00 2 | -3.880949974060058594e+00 1.137017607688903809e+00 2.816700363159179688e+01 3 | -1.733258628845214844e+01 1.361482429504394531e+01 3.732596397399902344e+00 4 | 6.984520435333251953e+00 2.385647964477539062e+01 2.254618263244628906e+01 5 | 1.265965270996093750e+01 -5.491594314575195312e+00 3.671529293060302734e+00 6 | 5.588706493377685547e+00 1.840843200683593750e+01 1.881173324584960938e+01 7 | 1.173019123077392578e+01 3.493070602416992188e+00 -9.680926322937011719e+00 8 | -1.931317806243896484e+00 1.681844711303710938e+00 1.082907009124755859e+01 9 | 1.125778865814208984e+01 -1.035340785980224609e+01 2.278220891952514648e+00 10 | 6.730436325073242188e+00 1.772405952215194702e-02 -3.428169488906860352e+00 11 | 1.663151931762695312e+01 -6.609946727752685547e+00 2.343542480468750000e+01 12 | 5.656902313232421875e+00 -5.299721240997314453e+00 
-2.454715728759765625e+00 13 | -7.677423954010009766e+00 2.668741607666015625e+01 4.579981803894042969e+00 14 | 7.027086257934570312e+00 -6.528122425079345703e+00 9.524543762207031250e+00 15 | -6.829839706420898438e+00 2.094888114929199219e+01 -4.094919204711914062e+00 16 | -4.066775321960449219e+00 4.786375045776367188e+00 1.059037780761718750e+01 17 | 6.120678901672363281e+00 5.419001579284667969e+00 -4.117619514465332031e+00 18 | 2.133938598632812500e+01 1.223642826080322266e+01 2.026872634887695312e+01 19 | -1.755358695983886719e+00 1.888538742065429688e+01 3.669851303100585938e+00 20 | 4.925955772399902344e+00 -9.080457687377929688e+00 1.062245488166809082e+00 21 | -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_point_clouds/06.txt: -------------------------------------------------------------------------------- 1 | 1.770122909545898438e+01 6.914246559143066406e+00 1.933807182312011719e+01 2 | 1.123076725006103516e+01 -2.333509063720703125e+01 -8.656030654907226562e+00 3 | 1.017960071563720703e+01 -3.776865959167480469e+00 -7.845626354217529297e+00 4 | 1.014298915863037109e+01 1.187571525573730469e+01 2.992667198181152344e+00 5 | -2.338955402374267578e-01 5.095698356628417969e+00 -2.353112459182739258e+00 6 | 1.131519889831542969e+01 -1.137460708618164062e+01 6.904096603393554688e-01 7 | 2.586653232574462891e+00 7.005810737609863281e-01 2.440843963623046875e+01 8 | -6.545204639434814453e+00 9.657912254333496094e+00 -1.133399963378906250e+01 9 | 1.820130920410156250e+01 -1.324907970428466797e+01 -1.625947189331054688e+01 10 | 1.159400463104248047e+00 1.088666534423828125e+01 1.807940483093261719e+01 11 | 4.778804779052734375e+00 1.657044601440429688e+01 4.726283073425292969e+00 12 | 2.271530151367187500e+01 1.778459548950195312e+01 6.959418296813964844e+00 13 | 6.825378417968750000e+00 8.128238677978515625e+00 1.170329666137695312e+01 14 | -3.883583307266235352e+00 
-2.231578063964843750e+01 -1.070453643798828125e+01 15 | -3.620388031005859375e+00 -1.458454608917236328e+00 1.515957069396972656e+01 16 | -2.512928390502929688e+01 -5.790386199951171875e-01 -1.131245613098144531e+00 17 | 1.880904960632324219e+01 -2.659870147705078125e+01 -5.910870552062988281e+00 18 | -2.066169261932373047e+00 2.650362730026245117e+00 -1.925411701202392578e+00 19 | -6.992560386657714844e+00 -4.599108219146728516e+00 -4.340693950653076172e-01 20 | 2.972890615463256836e+00 3.573164367675781250e+01 1.819517517089843750e+01 21 | -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_point_clouds/00.txt: -------------------------------------------------------------------------------- 1 | 1.753428840637207031e+01 -4.500185012817382812e+00 4.682235717773437500e-01 2 | 6.025600433349609375e-02 -1.957279205322265625e+01 -2.991626262664794922e+00 3 | 1.968634033203125000e+01 7.508693695068359375e+00 6.393542289733886719e-01 4 | -1.193013668060302734e+00 -2.650102996826171875e+01 1.935586214065551758e+00 5 | 1.506752204895019531e+01 -2.003410911560058594e+01 -8.993811607360839844e-01 6 | 5.611409187316894531e+00 4.044103622436523438e-02 2.952403068542480469e+00 7 | 1.224860382080078125e+01 -6.847229480743408203e+00 1.990456199645996094e+01 8 | 2.904957962036132812e+01 2.324859428405761719e+01 1.325541591644287109e+01 9 | 6.919024944305419922e+00 -8.943810462951660156e-01 6.396169662475585938e+00 10 | -6.702182769775390625e+00 -5.443521499633789062e+00 9.209456443786621094e-01 11 | -2.448676490783691406e+01 -1.512960433959960938e+01 -2.161047697067260742e+00 12 | -1.359420871734619141e+01 -3.092990636825561523e+00 6.911022663116455078e-01 13 | 2.382064104080200195e+00 7.481824874877929688e+00 1.292202281951904297e+01 14 | -1.910583019256591797e+00 3.095754623413085938e+01 1.900699234008789062e+01 15 | 1.375919151306152344e+01 1.100754547119140625e+01 -6.891832828521728516e+00 16 | 
-8.518700599670410156e+00 1.263058280944824219e+01 1.380192089080810547e+01 17 | 1.574968528747558594e+01 -5.353053092956542969e+00 -6.263266086578369141e+00 18 | -6.853740692138671875e+00 -2.236871910095214844e+01 -4.613399505615234375e-03 19 | -1.169267082214355469e+01 -1.769064521789550781e+01 1.068235301971435547e+01 20 | -4.828259944915771484e-01 -3.337502479553222656e+00 2.380093383789062500e+01 21 | -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_point_clouds/03.txt: -------------------------------------------------------------------------------- 1 | 3.343302726745605469e+00 -1.746718597412109375e+01 -1.829320335388183594e+01 2 | -1.935898971557617188e+01 -1.322010898590087891e+01 4.535314083099365234e+00 3 | 2.005539321899414062e+01 -8.997367858886718750e+00 1.660152435302734375e+00 4 | -9.446917533874511719e+00 1.447911262512207031e+01 -1.560514259338378906e+01 5 | -3.147289276123046875e+00 -9.999388694763183594e+00 -1.405495071411132812e+01 6 | 1.553272438049316406e+01 6.232530593872070312e+00 9.322841644287109375e+00 7 | 2.402659606933593750e+01 1.838321924209594727e+00 2.661139011383056641e+00 8 | -1.343440723419189453e+01 -4.685816764831542969e-01 -1.514006042480468750e+01 9 | 9.744992256164550781e+00 -6.428390026092529297e+00 -2.659524202346801758e+00 10 | 1.409182929992675781e+01 8.406430482864379883e-01 8.414267539978027344e+00 11 | 1.277910327911376953e+01 1.410011100769042969e+01 -8.149508476257324219e+00 12 | -4.342250347137451172e+00 9.465246200561523438e+00 3.384090900421142578e+00 13 | 1.977715492248535156e-01 1.909025192260742188e+01 -1.221620368957519531e+01 14 | 1.142005252838134766e+01 -2.150404167175292969e+01 -1.185356497764587402e+00 15 | -6.336628913879394531e+00 -1.719444465637207031e+01 -4.062208175659179688e+00 16 | -7.335078239440917969e+00 -2.014678955078125000e+00 1.270345783233642578e+01 17 | 1.421373081207275391e+01 -1.173992156982421875e+01 
1.328297901153564453e+01 18 | 1.609474945068359375e+01 1.255855464935302734e+01 1.855561256408691406e+00 19 | -1.242047691345214844e+01 -1.334993934631347656e+01 -1.488025379180908203e+01 20 | -5.690674781799316406e-01 -1.352301311492919922e+01 3.821725845336914062e+00 21 | -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_point_clouds/04.txt: -------------------------------------------------------------------------------- 1 | 1.245974445343017578e+01 -5.715072154998779297e+00 -8.644862174987792969e-01 2 | -1.276222038269042969e+01 1.902816772460937500e+00 4.716373085975646973e-01 3 | 1.029317665100097656e+01 2.834968566894531250e+00 1.496321010589599609e+01 4 | -1.174104118347167969e+01 1.695847702026367188e+01 -1.463916492462158203e+01 5 | 5.676942825317382812e+00 -5.096730709075927734e+00 2.632406234741210938e+00 6 | 1.092394828796386719e+00 -1.501656341552734375e+01 -1.031024456024169922e+01 7 | -6.467421054840087891e+00 -1.160396194458007812e+01 1.470543861389160156e+01 8 | -1.762504577636718750e+01 -1.413936710357666016e+01 3.142774581909179688e+01 9 | 1.672815895080566406e+01 1.947975158691406250e+01 -1.495285701751708984e+01 10 | 2.177924537658691406e+01 2.762977218627929688e+01 -2.303795242309570312e+01 11 | -3.762039661407470703e+00 -8.857637405395507812e+00 -3.668912410736083984e+00 12 | 1.469023513793945312e+01 -1.349047470092773438e+01 -1.360225081443786621e+00 13 | -1.613547897338867188e+01 -8.901642799377441406e+00 2.956092834472656250e+00 14 | -9.738894462585449219e+00 -1.355378913879394531e+01 1.227253723144531250e+01 15 | -1.169795989990234375e+01 1.207351875305175781e+01 -9.561182975769042969e+00 16 | 8.361349105834960938e-01 3.609262943267822266e+00 1.438896942138671875e+01 17 | 1.665669631958007812e+01 -1.130194568634033203e+01 1.508170223236083984e+01 18 | -2.297048187255859375e+01 -2.116431713104248047e+00 6.556929111480712891e+00 19 | -1.091447257995605469e+01 
9.879340171813964844e+00 1.361890888214111328e+01 20 | 7.588554382324218750e+00 -2.157068061828613281e+01 4.903929710388183594e+00 21 | -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_point_clouds/07.txt: -------------------------------------------------------------------------------- 1 | 6.636974811553955078e+00 -4.373825073242187500e+00 -2.882398986816406250e+01 2 | 1.689126014709472656e+00 -7.226819515228271484e+00 -5.111994743347167969e+00 3 | -1.397345829010009766e+01 -9.936675071716308594e+00 9.829759597778320312e+00 4 | 1.249024009704589844e+01 1.495247459411621094e+01 -1.787715339660644531e+01 5 | 9.347402572631835938e+00 -5.516538619995117188e+00 6.091839790344238281e+00 6 | 1.465076923370361328e+00 -2.554033088684082031e+01 -2.093047523498535156e+01 7 | 6.774293899536132812e+00 6.370829105377197266e+00 -7.091963291168212891e+00 8 | -3.535380363464355469e+00 1.219524621963500977e+00 -9.942526817321777344e+00 9 | 2.098946762084960938e+01 1.224189996719360352e+00 -1.022173881530761719e+01 10 | -1.153464508056640625e+01 -5.519344329833984375e+00 1.138543701171875000e+01 11 | -2.234888553619384766e+00 -1.513765144348144531e+01 1.711233329772949219e+01 12 | 2.048178482055664062e+01 1.531910514831542969e+01 1.297569274902343750e+00 13 | 2.220903873443603516e+00 2.876883029937744141e+00 1.168313980102539062e+01 14 | 1.488710880279541016e+00 1.013339710235595703e+01 -1.085923004150390625e+01 15 | -6.793376922607421875e+00 5.561182975769042969e+00 2.823381614685058594e+01 16 | -6.240493774414062500e+00 -1.261471748352050781e+00 -2.007337188720703125e+01 17 | 1.941026306152343750e+01 1.594352245330810547e+01 -1.644583642482757568e-01 18 | 2.034305572509765625e+00 9.462517738342285156e+00 -2.156559753417968750e+01 19 | -1.235863685607910156e+01 7.686964988708496094e+00 -2.778312110900878906e+01 20 | 1.254753303527832031e+01 2.037656784057617188e+01 6.217426300048828125e+00 21 | 
-------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_point_clouds/01.txt: -------------------------------------------------------------------------------- 1 | 2.294879436492919922e+00 2.029940366744995117e+00 1.008633804321289062e+01 2 | 5.282232284545898438e+00 -1.392682456970214844e+01 -1.019494342803955078e+01 3 | 1.878868865966796875e+01 -1.135763072967529297e+01 1.084267997741699219e+01 4 | -1.248034572601318359e+01 1.330796527862548828e+01 8.507024765014648438e+00 5 | -3.410379791259765625e+01 -1.945067596435546875e+01 -7.762837409973144531e-01 6 | 9.469700813293457031e+00 -2.092086982727050781e+01 1.532528686523437500e+01 7 | -4.655651092529296875e+00 1.023709297180175781e+01 -7.881093502044677734e+00 8 | -4.548531055450439453e+00 -1.529055976867675781e+01 -1.893261718750000000e+01 9 | 2.177994728088378906e+01 -6.495486259460449219e+00 -2.546433210372924805e+00 10 | -1.506688117980957031e+01 -2.999695301055908203e+00 6.474236011505126953e+00 11 | -2.249460601806640625e+01 1.785397911071777344e+01 -6.464286327362060547e+00 12 | -2.153913879394531250e+01 -3.957053661346435547e+00 -1.881030654907226562e+01 13 | -1.292291450500488281e+01 3.578302145004272461e+00 -1.580009460449218750e-01 14 | 1.953334045410156250e+01 -9.488323211669921875e+00 -1.056732368469238281e+01 15 | -1.952380752563476562e+01 1.605683326721191406e+01 3.020879983901977539e+00 16 | 1.771048545837402344e+01 -1.233715248107910156e+01 -1.554603576660156250e+00 17 | 2.357878303527832031e+01 -2.250756263732910156e+00 9.804723739624023438e+00 18 | 9.919981956481933594e-01 2.368477249145507812e+01 1.610438537597656250e+01 19 | -5.584335327148437500e-02 -8.989120483398437500e+00 -1.685858345031738281e+01 20 | 8.082454681396484375e+00 -7.023594379425048828e-01 -6.416393280029296875e+00 21 | -------------------------------------------------------------------------------- 
/tests/pershom/test_pershom_backend_data/random_point_clouds/02.txt: -------------------------------------------------------------------------------- 1 | -1.569143962860107422e+01 8.328658103942871094e+00 7.621088504791259766e+00 2 | 1.521720886230468750e+00 -2.383153438568115234e+00 -1.007918834686279297e+00 3 | -1.053955650329589844e+01 5.994036674499511719e+00 -1.497441959381103516e+01 4 | -1.045053958892822266e+01 -9.795961380004882812e+00 -8.527483940124511719e+00 5 | 7.234384536743164062e+00 1.309517288208007812e+01 6.125024318695068359e+00 6 | 1.079290676116943359e+01 1.027314472198486328e+01 1.179671287536621094e+01 7 | 1.356038188934326172e+01 -2.782874298095703125e+01 -7.038252353668212891e+00 8 | 6.458787918090820312e+00 -9.498422622680664062e+00 2.343321323394775391e+00 9 | 1.132456493377685547e+01 -8.603633880615234375e+00 8.479830741882324219e+00 10 | 4.921120166778564453e+00 2.547783660888671875e+01 7.478994369506835938e+00 11 | 3.025704860687255859e+00 -4.886765480041503906e+00 -1.085865020751953125e+01 12 | -1.723771286010742188e+01 8.312644958496093750e+00 -1.234071063995361328e+01 13 | -3.242609977722167969e+00 -1.707215690612792969e+01 -1.345014381408691406e+01 14 | -2.772304058074951172e+00 8.703633308410644531e+00 6.666395187377929688e+00 15 | 1.592916297912597656e+01 -1.324088191986083984e+01 2.791919708251953125e+01 16 | 2.864090204238891602e-01 -5.610611915588378906e+00 -2.340017318725585938e+01 17 | -9.287577629089355469e+00 1.476047515869140625e+00 -1.529828834533691406e+01 18 | -1.121659517288208008e+00 -3.132452011108398438e+00 1.739150428771972656e+01 19 | -1.123724555969238281e+01 -8.771486282348632812e+00 -3.216977310180664062e+01 20 | 4.046554088592529297e+00 -1.887501335144042969e+01 -2.078896331787109375e+01 21 | -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_point_clouds/05.txt: 
-------------------------------------------------------------------------------- 1 | 4.200152873992919922e+00 -9.100952148437500000e+00 1.480857086181640625e+01 2 | 1.230931282043457031e+01 4.111227035522460938e+00 -1.112330245971679688e+01 3 | 9.193184375762939453e-01 -1.163776111602783203e+01 -9.874886512756347656e+00 4 | 2.333695030212402344e+01 -7.821397781372070312e+00 -1.233045101165771484e+01 5 | 1.586226558685302734e+01 7.421721458435058594e+00 -3.939793705940246582e-01 6 | 5.050350189208984375e+00 -1.023896598815917969e+01 -3.268229961395263672e+00 7 | 5.758150100708007812e+00 2.761218070983886719e+01 -9.432003021240234375e+00 8 | -1.134756088256835938e+00 -1.546015167236328125e+01 2.770107746124267578e+00 9 | -3.572121143341064453e+00 -6.340959072113037109e+00 -1.236189603805541992e+00 10 | 2.330923080444335938e+01 -9.795475006103515625e-02 3.458006858825683594e+00 11 | 2.055246162414550781e+01 8.026174902915954590e-01 -1.684565925598144531e+01 12 | 1.073609542846679688e+01 4.474770069122314453e+00 1.488601112365722656e+01 13 | -1.541873359680175781e+01 5.138186931610107422e+00 1.527934265136718750e+01 14 | -1.824094009399414062e+01 8.860630989074707031e+00 9.692591667175292969e+00 15 | -1.152042150497436523e+00 -5.440307617187500000e+00 -1.975798130035400391e+00 16 | 3.797050952911376953e+00 -1.583844184875488281e+00 -2.027872276306152344e+01 17 | 1.094208002090454102e+00 -2.214362621307373047e+00 -6.184287071228027344e+00 18 | -1.804293441772460938e+01 -2.615973854064941406e+01 -7.414704322814941406e+00 19 | 1.446764755249023438e+01 1.513642311096191406e+01 -4.236664295196533203e+00 20 | -5.677638530731201172e+00 -1.001023387908935547e+01 8.501178741455078125e+00 21 | -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend_data/random_point_clouds/08.txt: -------------------------------------------------------------------------------- 1 | 4.964171648025512695e-01 -4.169604301452636719e+00 
8.952707290649414062e+00 2 | -3.376441240310668945e+00 -1.501020622253417969e+01 -1.941167831420898438e+01 3 | 9.189126968383789062e+00 9.829032897949218750e+00 2.326453447341918945e+00 4 | 2.415312194824218750e+01 2.093213462829589844e+01 1.829629135131835938e+01 5 | -1.595801115036010742e+00 1.649509811401367188e+01 -1.712705612182617188e+01 6 | -1.094874191284179688e+01 -1.049565792083740234e+01 1.961479187011718750e+00 7 | -5.696784019470214844e+00 1.154114341735839844e+01 -1.104621124267578125e+01 8 | 9.585831642150878906e+00 1.754752922058105469e+01 1.760782051086425781e+01 9 | -8.444528579711914062e-01 -1.675706863403320312e+01 -1.112575626373291016e+01 10 | -2.848413848876953125e+01 1.019102478027343750e+01 1.564753532409667969e+01 11 | -6.987383365631103516e-01 -2.057890319824218750e+01 -1.097996520996093750e+01 12 | -8.175359725952148438e+00 -2.336300849914550781e+00 -1.292065715789794922e+01 13 | 9.478190422058105469e+00 -2.607798576354980469e+00 1.378364372253417969e+01 14 | 6.538515567779541016e+00 -1.332163333892822266e+00 -1.585797786712646484e+01 15 | -8.233261108398437500e-01 5.137370109558105469e+00 -1.182361221313476562e+01 16 | 1.425596237182617188e+00 7.952968597412109375e+00 -3.700152397155761719e+00 17 | 9.599940776824951172e-01 -6.959006786346435547e+00 -1.842673492431640625e+01 18 | -1.717554473876953125e+01 -6.834000110626220703e+00 -2.126903152465820312e+01 19 | 7.328005790710449219e+00 2.408910560607910156e+01 3.372626066207885742e+00 20 | 4.278636932373046875e+00 -1.013033866882324219e+01 -1.626550674438476562e+01 21 | -------------------------------------------------------------------------------- /pershom_dev/generate_test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as pth 3 | import pickle 4 | from simplicial_complex import * 5 | from cpu_sorted_boundary_array_implementation import SortedListBoundaryMatrix 6 | import torch 7 | from pershombox import 
toplex_persistence_diagrams 8 | from collections import Counter 9 | 10 | 11 | def generate(): 12 | 13 | output_path = "/tmp/random_simplicial_complexes" 14 | try: 15 | os.mkdir(output_path) 16 | except: 17 | pass 18 | 19 | args = [(100, 200, 300), 20 | (100, 200, 300, 400), 21 | (100, 100, 100, 100, 100)] 22 | 23 | 24 | for arg in args: 25 | c = random_simplicial_complex(*arg) 26 | file_name_stub = os.path.join(output_path, "random_sp__args__" + "_".join(str(x) for x in arg)) 27 | 28 | bm, col_dim = descending_sorted_boundary_array_from_filtrated_sp(c) 29 | bm, col_dim = bm.to('cuda'), col_dim.to('cuda') 30 | 31 | ind_not_reduced = torch.tensor(list(range(col_dim.size(0)))).to('cuda') 32 | ind_not_reduced = ind_not_reduced.masked_select(bm[:, 0] >= 0).long() 33 | bm = bm.index_select(0, ind_not_reduced) 34 | 35 | barcodes_true = toplex_persistence_diagrams(c, list(range(len(c)))) 36 | dgm_true = [Counter(((float(b), float(d)) for b, d in dgm )) for dgm in barcodes_true] 37 | 38 | with open(file_name_stub + ".pickle", 'wb') as f: 39 | pickle.dump({'calculate_persistence_args': (bm, ind_not_reduced, col_dim, max(col_dim)), 40 | 'expected_result': dgm_true}, 41 | f) 42 | 43 | 44 | 45 | if __name__ == "__main__": 46 | generate() 47 | -------------------------------------------------------------------------------- /pershom_dev/generate_profile_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as pth 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from simplicial_complex import * 8 | from cpu_sorted_boundary_array_implementation import SortedListBoundaryMatrix 9 | 10 | def generate_profile_input_dump(): 11 | c = random_simplicial_complex(1000, 2000, 4000) 12 | max_dimension = 2 13 | 14 | bm, simplex_dim = descending_sorted_boundary_array_from_filtrated_sp(c) 15 | np.savetxt('profiling_pershom/data/boundary_array.np_txt', bm) 16 | np.savetxt('profiling_pershom/data/boundary_array.np_txt', 
simplex_dim) 17 | 18 | ind_not_reduced = torch.tensor(list(range(simplex_dim.size(0)))) 19 | ind_not_reduced = ind_not_reduced.masked_select(bm[:, 0] >= 0) 20 | 21 | bm = bm.index_select(0, ind_not_reduced) 22 | 23 | with open('profiling_pershom/data/boundary_array.txt', 'w') as f: 24 | for row in bm: 25 | for v in row: 26 | f.write(str(int(v))) 27 | f.write(" ") 28 | 29 | with open ('profiling_pershom/data/boundary_array_size.txt', 'w') as f: 30 | for v in bm.size(): 31 | f.write(str(int(v))) 32 | f.write(" ") 33 | 34 | with open ('profiling_pershom/data/ind_not_reduced.txt', 'w') as f: 35 | for v in ind_not_reduced: 36 | f.write(str(int(v))) 37 | f.write(" ") 38 | 39 | with open ('profiling_pershom/data/column_dimension.txt', 'w') as f: 40 | for v in simplex_dim: 41 | f.write(str(int(v))) 42 | f.write(" ") 43 | 44 | with open ('profiling_pershom/data/max_dimension.txt', 'w') as f: 45 | f.write(str(int(max_dimension))) 46 | 47 | 48 | 49 | if __name__ == "__main__": 50 | generate_profile_input_dump() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchph 2 | 3 | This repository contains [PyTorch](http://pytorch.org) extensions to **compute** 4 | persistent homology and to **differentiate** through the persistent homology computation. 5 | The packaging structure is similar to PyTorch's structure to facilitate usage for people 6 | familiar with PyTorch. 7 | 8 | ## Documentation 9 | 10 | [Read the docs!](https://c-hofer.github.io/torchph/) 11 | 12 | The folder *tutorials* (within `docs`) contains some (more or less) minimalistic examples in form of Jupyter notebooks 13 | to demonstrate how to use the `PyTorch` extensions. 
14 | 15 | ## Associated publications 16 | 17 | If you use any of these extensions, please cite the following works (depending on which functionality you use, obviously :) 18 | 19 | ```bash 20 | @inproceedings{Hofer17a, 21 | author = {C.~Hofer, R.~Kwitt, M.~Niethammer and A.~Uhl}, 22 | title = {Deep Learning with Topological Signatures}, 23 | booktitle = {NIPS}, 24 | year = {2017}} 25 | 26 | @inproceedings{Hofer19a, 27 | author = {C.~Hofer, R.~Kwitt, M.~Dixit and M.~Niethammer}, 28 | title = {Connectivity-Optimized Representation Learning via Persistent Homology}, 29 | booktitle = {ICML}, 30 | year = {2019}} 31 | 32 | @article{Hofer19b, 33 | author = {C.~Hofer, R.~Kwitt, and M.~Niethammer}, 34 | title = {Learning Representations of Persistence Barcodes}, 35 | booktitle = {JMLR}, 36 | year = {2019}} 37 | 38 | @inproceedings{Hofer20a}, 39 | author = {C.~Hofer, F.~Graf, R.~Kwitt, B.~Rieck and M.~Niethammer}, 40 | title = {Graph Filtration Learning}, 41 | booktitle = {arXiv}, 42 | year = {2020}} 43 | 44 | @inproceedings{Hofer20a, 45 | author = {C.~Hofer, F.~Graf, M.~Niethammer and R.~Kwitt}, 46 | title = {Topologically Densified Distributions}, 47 | booktitle = {arXiv}, 48 | year = {2020}} 49 | ``` 50 | -------------------------------------------------------------------------------- /pershom_dev/bug_1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as pth 3 | import pickle 4 | from simplicial_complex import * 5 | from cpu_sorted_boundary_array_implementation import SortedListBoundaryMatrix 6 | import torch 7 | from time import time 8 | from pershombox import toplex_persistence_diagrams 9 | import torchph.pershom.pershom_backend as pershom_backend 10 | import yep 11 | from torchph.pershom.calculate_persistence import calculate_persistence 12 | import cProfile 13 | from collections import Counter 14 | 15 | os.environ['CUDA_LAUNCH_BLOCKING'] = str(1) 16 | 17 | 18 | def test(): 19 | c = None 20 | 21 | 
with open('bug_1_sp.pickle', 'br') as f: 22 | c = pickle.load(f) 23 | print('|C| = ', len(c)) 24 | max_red_by_iteration = 10000 25 | 26 | # cpu_impl = SortedListBoundaryMatrix(c) 27 | # cpu_impl.max_pairs = max_red_by_iteration 28 | bm, col_dim = descending_sorted_boundary_array_from_filtrated_sp(c) 29 | bm, col_dim = bm.to('cuda'), col_dim.to('cuda') 30 | 31 | 32 | # barcodes_true = toplex_persistence_diagrams(c, list(range(len(c)))) 33 | # dgm_true = [Counter(((float(b), float(d)) for b, d in dgm )) for dgm in barcodes_true] 34 | 35 | 36 | def my_output_to_dgms(input): 37 | ret = [] 38 | b, b_e = input 39 | 40 | for dim, (b_dim, b_dim_e) in enumerate(zip(b, b_e)): 41 | b_dim, b_dim_e = b_dim.float(), b_dim_e.float() 42 | 43 | tmp = torch.empty_like(b_dim_e) 44 | tmp.fill_(float('inf')) 45 | b_dim_e = torch.cat([b_dim_e, tmp], dim=1) 46 | 47 | 48 | dgm = torch.cat([b_dim, b_dim_e], dim=0) 49 | dgm = dgm.tolist() 50 | dgm = Counter(((float(b), float(d)) for b, d in dgm )) 51 | 52 | ret.append(dgm) 53 | 54 | return ret 55 | 56 | output = calculate_persistence(bm, col_dim, max(col_dim), max_red_by_iteration) 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | test() 65 | -------------------------------------------------------------------------------- /pershom_dev/vc_comp_script.py: -------------------------------------------------------------------------------- 1 | import torchph.pershom.pershom_backend as pershom_backend 2 | import torch 3 | import time 4 | from scipy.special import binom 5 | 6 | 7 | #torch.tensor([[-0.6690, 1.5059], [ 0.4220, 1.2434], [-0.3436, -0.0053], [-0.1569, 0.0627]], device='cuda', requires_grad=True).float() 8 | 9 | 10 | point_cloud = torch.randn(5,3, device='cuda', requires_grad=True).float() 11 | # point_cloud = torch.tensor([(0, 0), (1, 0), (0, 0.5), (1, 1.5)], device='cuda', requires_grad=True) 12 | 13 | # loss = x.sum() 14 | # loss.backward() 15 | # print(point_cloud.grad) 16 | 17 | 18 | print(point_cloud) 19 | # 
pershom_backend.__C.CalcPersCuda__my_test_f(point_cloud).sum().backward(); 20 | 21 | # time_start = time.time() 22 | try: 23 | r = pershom_backend.__C.VRCompCuda__vr_l1_persistence(point_cloud, 0, 0) 24 | 25 | except Exception as ex: 26 | print("=== Error ===") 27 | print(ex) 28 | exit() 29 | 30 | non_essentials = r[0] 31 | essentials = r[1] 32 | 33 | print(len(non_essentials), len(essentials)) 34 | 35 | print("=== non-essentials ===") 36 | for x in non_essentials: print(x) 37 | 38 | print("=== essentials ===") 39 | for x in essentials: print(x) 40 | 41 | print("=== grad ===") 42 | loss = non_essentials[0].sum() 43 | loss.backward() 44 | print(point_cloud.grad) 45 | 46 | # ba = r[0] 47 | # simp_dim = r[2] 48 | # filt_val = r[3] 49 | 50 | # dim = 3 51 | # for simp_id, boundaries in enumerate(ba): 52 | # simp_filt_val = filt_val[simp_id + point_cloud.size(0)] 53 | 54 | # for boundary in boundaries.tolist(): 55 | # if boundary == -1: continue 56 | # cond = True 57 | # cond = cond and (simp_dim[simp_id + point_cloud.size(0)] - 1 == simp_dim[boundary]) 58 | # cond = cond and filt_val[boundary] <= simp_filt_val 59 | 60 | # if not cond: 61 | # print("{}, {}".format(simp_id, boundary)) 62 | # print(ba[simp_id]) 63 | # print(simp_dim[simp_id + point_cloud.size(0)], simp_dim[boundary]) 64 | 65 | # raise Exception() 66 | 67 | 68 | 69 | 70 | 71 | # print(print(r[2][int(5+binom(5,2)):int(5+binom(5,2))+1000])) 72 | 73 | 74 | # loss = t.sum() 75 | # loss.backward() 76 | # 77 | # for i, r in enumerate(ba): 78 | # x = point_cloud[r[0]-point_cloud.size(0)] 79 | # y = point_cloud[r[1]-point_cloud.size(0)] 80 | 81 | # assert torch.equal((x-y).abs().sum(), t[i]) 82 | 83 | 84 | # for row_i in range(t.size(0)): 85 | # for col_i in range(t.size(1)): 86 | # assert int(t[row_i, col_i]) == binom(col_i, row_i+1) 87 | # print(time.time() - time_start) 88 | -------------------------------------------------------------------------------- /pershom_dev/simplicial_complex.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from collections import defaultdict 4 | from itertools import combinations 5 | from scipy.special import binom 6 | 7 | 8 | def boundary_operator(simplex): 9 | s = tuple(simplex) 10 | 11 | if len(simplex) == 1: 12 | return () 13 | 14 | else: 15 | return [s[:i] + s[(i + 1):] for i in range(len(s))] 16 | 17 | 18 | def random_simplicial_complex(*args): 19 | simplices_by_dim = defaultdict(set) 20 | 21 | vertex_ids = np.array((range(args[0]))) 22 | 23 | vertices = [(i,) for i in vertex_ids] 24 | simplices_by_dim[0] = set(vertices) 25 | 26 | for dim_i, n_simplices in enumerate(args[1:], start=1): 27 | 28 | n_simplices = min(n_simplices, int(binom(len(vertices), dim_i+1))) 29 | while len(simplices_by_dim[dim_i]) != n_simplices: 30 | chosen_vertices = np.random.choice(vertex_ids, replace=False, size=dim_i+1) 31 | simplex = tuple(sorted(chosen_vertices)) 32 | 33 | if simplex not in simplices_by_dim[dim_i]: 34 | simplices_by_dim[dim_i].add(simplex) 35 | 36 | for dim_i in sorted(simplices_by_dim, reverse=True): 37 | if dim_i == 0 : 38 | break 39 | 40 | for s in simplices_by_dim[dim_i]: 41 | for boundary_s in boundary_operator(s): 42 | if boundary_s not in simplices_by_dim[dim_i-1]: 43 | simplices_by_dim[dim_i-1].add(boundary_s) 44 | 45 | sp = [] 46 | for dim_i in range(len(args)): 47 | sp += list(simplices_by_dim[dim_i]) 48 | 49 | return sp 50 | 51 | 52 | def descending_sorted_boundary_array_from_filtrated_sp(filtrated_sp, 53 | dtype=torch.int32, 54 | resize_factor=2): 55 | simplex_to_ordering_position = {s: i for i, s in enumerate(filtrated_sp)} 56 | 57 | max_boundary_size = max(len(s) for s in filtrated_sp) 58 | n_cols = len(filtrated_sp) 59 | n_rows = resize_factor*max_boundary_size 60 | 61 | bm = torch.empty(size=(n_rows, n_cols), 62 | dtype=dtype) 63 | bm.fill_(-1) 64 | 65 | col_to_dim = torch.empty(size=(n_cols,), 66 | dtype=dtype) 67 | 68 | for 
col_i, s in enumerate(filtrated_sp): 69 | boundary = boundary_operator(s) 70 | orderings_of_boundaries = sorted((simplex_to_ordering_position[b] for b in boundary), 71 | reverse=True) 72 | 73 | col_to_dim[col_i] = len(s) - 1 74 | 75 | for row_i, entry in enumerate(orderings_of_boundaries): 76 | bm[row_i, col_i] = entry 77 | 78 | # boundary array is delivered in column first order for efficency when merging 79 | bm = bm.transpose_(0, 1) 80 | 81 | 82 | return bm, col_to_dim 83 | -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/vr_comp_cuda.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | #include 5 | 6 | 7 | using namespace torch; 8 | 9 | 10 | namespace VRCompCuda { 11 | std::vector> vr_persistence( 12 | const Tensor& distance_matrix, 13 | int64_t max_dimension, 14 | double max_ball_diameter); 15 | 16 | std::vector> vr_persistence_l1( 17 | const Tensor& point_cloud, 18 | int64_t max_dimension, 19 | double max_ball_diameter); 20 | 21 | void write_combinations_table_to_tensor( 22 | const Tensor& out, 23 | const int64_t out_row_offset, 24 | const int64_t additive_constant, 25 | const int64_t max_n, 26 | const int64_t r); 27 | 28 | Tensor co_faces_from_combinations( 29 | const Tensor & combinations, 30 | const Tensor & faces); 31 | 32 | Tensor l1_norm_distance_matrix(const Tensor & points); 33 | 34 | Tensor l2_norm_distance_matrix(const Tensor & points); 35 | 36 | class VietorisRipsArgsGenerator 37 | { 38 | public: 39 | 40 | at::TensorOptions tensopt_real; 41 | at::TensorOptions tensopt_int; 42 | 43 | Tensor distance_matrix; 44 | int64_t max_dimension; 45 | double max_ball_diameter; 46 | 47 | std::vector boundary_info_non_vertices; 48 | std::vector filtration_values_by_dim; 49 | std::vector n_simplices_by_dim; 50 | 51 | Tensor simplex_dimension_vector; 52 | Tensor filtration_values_vector_without_vertices; 53 | Tensor 
filtration_add_eps_hack_values; 54 | 55 | Tensor sort_indices_without_vertices; 56 | Tensor sort_indices_without_vertices_inverse; 57 | Tensor sorted_filtration_values_vector; 58 | 59 | Tensor boundary_array; 60 | 61 | Tensor ba_row_i_to_bm_col_i_vector; 62 | 63 | std::vector operator()( 64 | const Tensor & distance_matrix, 65 | int64_t max_dimension, 66 | double max_ball_diameter); 67 | 68 | void init_state( 69 | const Tensor & distance_matrix, 70 | int64_t max_dimension, 71 | double max_ball_diameter 72 | ); 73 | 74 | void make_boundary_info_edges(); 75 | void make_boundary_info_non_edges(); 76 | void make_simplex_ids_compatible_within_dimensions(); 77 | void make_simplex_dimension_vector(); 78 | void make_filtration_values_vector_without_vertices(); 79 | void do_filtration_add_eps_hack(); 80 | void make_sorting_infrastructure(); 81 | void undo_filtration_add_eps_hack(); 82 | void make_sorted_filtration_values_vector(); 83 | void make_boundary_array_rows_unsorted(); 84 | void apply_sorting_to_rows(); 85 | void make_ba_row_i_to_bm_col_i_vector(); 86 | }; 87 | 88 | 89 | std::vector> calculate_persistence_output_to_barcode_tensors( 90 | const std::vector>& calculate_persistence_output, 91 | const Tensor & filtration_values); 92 | } 93 | 94 | -------------------------------------------------------------------------------- /docs/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-weight:normal;font-style:normal;src:url("../fonts/fontawesome-webfont.eot");src:url("../fonts/fontawesome-webfont.eot?#iefix") format("embedded-opentype"),url("../fonts/fontawesome-webfont.woff") format("woff"),url("../fonts/fontawesome-webfont.ttf") format("truetype"),url("../fonts/fontawesome-webfont.svg#FontAwesome") 
format("svg")}.fa:before{display:inline-block;font-family:FontAwesome;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa{display:inline-block;text-decoration:inherit}li .fa{display:inline-block}li .fa-large:before,li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-0.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before,ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before{content:""}.icon-book:before{content:""}.fa-caret-down:before{content:""}.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.icon-caret-up:before{content:""}.fa-caret-left:before{content:""}.icon-caret-left:before{content:""}.fa-caret-right:before{content:""}.icon-caret-right:before{content:""}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 
0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} 2 | -------------------------------------------------------------------------------- /pershom_dev/pershom_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as pth 3 | import pickle 4 | from simplicial_complex import * 5 | from cpu_sorted_boundary_array_implementation import SortedListBoundaryMatrix 6 | import torch 7 | from time import time 8 | # from pershombox import toplex_persistence_diagrams 9 | import torchph.pershom.pershom_backend as pershom_backend 10 | import yep 11 | from torchph.pershom.calculate_persistence import calculate_persistence 12 | import cProfile 13 | from collections import Counter 14 | 15 | # os.environ['CUDA_LAUNCH_BLOCKING'] = str(1) 16 | 17 | 18 | def test(): 19 | c = None 20 | 21 | use_cache = False 22 | 23 | if use_cache: 24 | random_simplicial_complex_path = './random_simplicial_complex.pickle' 25 | if pth.exists(random_simplicial_complex_path): 26 | with open(random_simplicial_complex_path, 'br') as f: 27 | c = pickle.load(f) 28 | else: 29 | c = random_simplicial_complex(100, 100, 100, 100, 100, 100) 30 | with 
open(random_simplicial_complex_path, 'bw') as f: 31 | pickle.dump(c, f) 32 | else: 33 | c = random_simplicial_complex(100, 100, 100, 100, 100) 34 | 35 | print('|C| = ', len(c)) 36 | max_red_by_iteration = -1 37 | 38 | # cpu_impl = SortedListBoundaryMatrix(c) 39 | # cpu_impl.max_pairs = max_red_by_iteration 40 | bm, col_dim = descending_sorted_boundary_array_from_filtrated_sp(c) 41 | 42 | print(bm[-1]) 43 | 44 | bm, col_dim = bm.to('cuda'), col_dim.to('cuda') 45 | 46 | 47 | barcodes_true = toplex_persistence_diagrams(c, list(range(len(c)))) 48 | dgm_true = [Counter(((float(b), float(d)) for b, d in dgm )) for dgm in barcodes_true] 49 | 50 | 51 | def my_output_to_dgms(input): 52 | ret = [] 53 | b, b_e = input 54 | 55 | for dim, (b_dim, b_dim_e) in enumerate(zip(b, b_e)): 56 | b_dim, b_dim_e = b_dim.float(), b_dim_e.float() 57 | 58 | tmp = torch.empty_like(b_dim_e) 59 | tmp.fill_(float('inf')) 60 | b_dim_e = torch.cat([b_dim_e, tmp], dim=1) 61 | 62 | 63 | dgm = torch.cat([b_dim, b_dim_e], dim=0) 64 | dgm = dgm.tolist() 65 | dgm = Counter(((float(b), float(d)) for b, d in dgm )) 66 | 67 | ret.append(dgm) 68 | 69 | return ret 70 | 71 | 72 | # pr = cProfile.Profile() 73 | # pr.enable() 74 | 75 | ind_not_reduced = torch.tensor(list(range(col_dim.size(0)))).to('cuda').detach() 76 | ind_not_reduced = ind_not_reduced.masked_select(bm[:, 0] >= 0).long().detach() 77 | bm = bm.index_select(0, ind_not_reduced).detach() 78 | 79 | yep.start('profiling_pershom/profile.google-pprof') 80 | 81 | for i in range(10): 82 | time_start = time() 83 | output = pershom_backend.calculate_persistence(bm.clone(), ind_not_reduced.clone(), col_dim.clone(), max(col_dim)) 84 | print(time() - time_start) 85 | yep.stop() 86 | 87 | # pr.disable() 88 | # pr.dump_stats('high_level_profile.cProfile') 89 | 90 | print([[len(x) for x in y] for y in output ]) 91 | 92 | dgm_test = my_output_to_dgms(output) 93 | 94 | print('dgm_true lengths:', [len(dgm) for dgm in dgm_true]) 95 | print('dgm_test lengths:', 
[len(dgm) for dgm in dgm_test]) 96 | 97 | for dgm_test, dgm_true in zip(dgm_test, dgm_true): 98 | assert(dgm_test == dgm_true) 99 | 100 | 101 | def vr_l1_persistence_performance_test(): 102 | pc = torch.randn(20,3) 103 | pc = torch.tensor(pc, device='cuda', dtype=torch.float) 104 | 105 | max_dimension = 2 106 | max_ball_radius = 0 107 | 108 | yep.start('profiling_pershom/profile.google-pprof') 109 | time_start = time() 110 | res = pershom_backend.__C.VRCompCuda__vr_persistence(pc, max_dimension, max_ball_radius, 'l1') 111 | print(time() - time_start) 112 | yep.stop() 113 | print(res[0][0]) 114 | 115 | 116 | 117 | 118 | vr_l1_persistence_performance_test() 119 | -------------------------------------------------------------------------------- /pershom_dev/cpu_sorted_boundary_array_implementation.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, defaultdict 2 | import numpy as np 3 | from simplicial_complex import boundary_operator 4 | 5 | class SortedListBoundaryMatrix: 6 | 7 | def __init__(self, filtrated_sc): 8 | self._data = None 9 | self._simplex_dims = None 10 | self.stats = defaultdict(list) 11 | self._init_from_filtrated_sp(filtrated_sc) 12 | self.max_pairs = 100 13 | 14 | def _init_from_filtrated_sp(self, filtrated_sp): 15 | simplex_to_ordering_position = {s: i for i, s in enumerate(filtrated_sp)} 16 | 17 | self._data = [[]]*len(filtrated_sp) 18 | self._simplex_dims = {} 19 | 20 | for col_i, s in enumerate(filtrated_sp): 21 | boundary = boundary_operator(s) 22 | orderings_of_boundaries = sorted((simplex_to_ordering_position[b] for b in boundary)) 23 | self._data[col_i] = orderings_of_boundaries 24 | self._simplex_dims[col_i] = len(s) - 1 25 | 26 | 27 | 28 | def add_column_i_to_j(self, i, j): 29 | col_i = self._data[i] 30 | col_j = self._data[j] 31 | 32 | self._data[j] = sorted((k for k, v in Counter(col_i + col_j).items() if v == 1)) 33 | 34 | def get_pivot(self): 35 | return {col_i: 
col_i_entries[-1] for col_i, col_i_entries in enumerate(self._data) 36 | if len(col_i_entries) > 0} 37 | 38 | def pairs_for_reduction(self): 39 | tmp = defaultdict(list) 40 | for column_i, pivot_i in self.get_pivot().items(): 41 | tmp[pivot_i].append(column_i) 42 | 43 | ret = [] 44 | for k, v in tmp.items(): 45 | if len(v) > 1: 46 | for j in v[1:]: 47 | ret.append((v[0], j)) 48 | 49 | if self.max_pairs is not None: 50 | ret = ret[:self.max_pairs] 51 | return ret 52 | 53 | def reduction_step(self): 54 | pairs = self.pairs_for_reduction() 55 | 56 | if len(pairs) == 0: 57 | raise StopIteration() 58 | 59 | self.stats['n_pairs'].append(len(pairs)) 60 | 61 | for i, j in pairs: 62 | self.add_column_i_to_j(i, j) 63 | 64 | self.stats['longest_column'].append(max(len(c) for c in self._data)) 65 | 66 | def reduce(self): 67 | iterations = 0 68 | try: 69 | while True: 70 | self.reduction_step() 71 | iterations += 1 72 | 73 | except StopIteration: 74 | print('Reached end of reduction after ', iterations, ' iterations') 75 | 76 | def row_i_contains_lowest_one(self): 77 | pass 78 | 79 | def read_barcodes(self): 80 | assert len(self.pairs_for_reduction()) == 0, 'Matrix is not reduced' 81 | 82 | barcodes = defaultdict(list) 83 | pivot = self.get_pivot() 84 | 85 | for death, birth in pivot.items(): 86 | assert birth < death 87 | barcodes[self._simplex_dims[birth]].append((birth, death)) 88 | 89 | rows_with_lowest_one = set() 90 | for k, v in pivot.items(): 91 | rows_with_lowest_one.add(v) 92 | 93 | 94 | for i in range(len(self)): 95 | if (i not in pivot) and i not in rows_with_lowest_one: 96 | barcodes[self._simplex_dims[i]].append((i, float('inf'))) 97 | 98 | 99 | return barcodes 100 | 101 | 102 | def __repr__(self): 103 | n = len(self._data) 104 | matrix = np.zeros((n, n)) 105 | 106 | for i, col_i in enumerate(self._data): 107 | for j in col_i: 108 | matrix[j, i] = 1 109 | 110 | 111 | return str(matrix) 112 | 113 | def __len__(self): 114 | return len(self._data) 115 | 
-------------------------------------------------------------------------------- /docs/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | /* sphinx_rtd_theme version 0.4.3 | MIT license */ 2 | /* Built 20190212 16:02 */ 3 | require=function r(s,a,l){function c(e,n){if(!a[e]){if(!s[e]){var i="function"==typeof require&&require;if(!n&&i)return i(e,!0);if(u)return u(e,!0);var t=new Error("Cannot find module '"+e+"'");throw t.code="MODULE_NOT_FOUND",t}var o=a[e]={exports:{}};s[e][0].call(o.exports,function(n){return c(s[e][1][n]||n)},o,o.exports,r,s,a,l)}return a[e].exports}for(var u="function"==typeof require&&require,n=0;n"),i("table.docutils.footnote").wrap("
"),i("table.docutils.citation").wrap("
"),i(".wy-menu-vertical ul").not(".simple").siblings("a").each(function(){var e=i(this);expand=i(''),expand.on("click",function(n){return t.toggleCurrent(e),n.stopPropagation(),!1}),e.prepend(expand)})},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),i=e.find('[href="'+n+'"]');if(0===i.length){var t=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(i=e.find('[href="#'+t.attr("id")+'"]')).length&&(i=e.find('[href="#"]'))}0this.docHeight||(this.navBar.scrollTop(i),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",function(){this.linkScroll=!1})},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current"),e.siblings().find("li.current").removeClass("current"),e.find("> ul li.current").removeClass("current"),e.toggleClass("current")}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:e.exports.ThemeNav,StickyNav:e.exports.ThemeNav}),function(){for(var r=0,n=["ms","moz","webkit","o"],e=0;e` to install ``torchph``. 61 | 62 | Functionality 63 | ============= 64 | * Vietoris-Rips (VR) persistent homology (from point clouds and distance matrices) 65 | * Vertex-based filtrations (e.g., usable for graphs) 66 | * Learnable vectorizations of persistence barcodes 67 | 68 | All of this functionality is available for **GPU** computations and can easily be 69 | used within the PyTorch environment. 70 | 71 | The following **simple example** is a teaser showing how to compute 0-dim. persistent 72 | homology of a (1) Vietoris-Rips filtration which uses the Manhatten distance between 73 | samples and (2) doing the same using a pre-computed distance matrix. 74 | 75 | .. 
code-block:: python 76 | 77 | device = "cuda:0" 78 | 79 | # import numpy 80 | import numpy as np 81 | 82 | # import VR persistence computation functionality 83 | from torchph.pershom import vr_persistence_l1, vr_persistence 84 | 85 | # import scipy methods to compute pairwise distance matrices 86 | from scipy.spatial.distance import pdist 87 | from scipy.spatial.distance import squareform 88 | 89 | # create 10-dim. point cloud with 100 samples 90 | x = np.random.randn(100, 10) 91 | 92 | # compute VR persistent homology (using l1 metric) 93 | X = torch.Tensor(x).to(device) 94 | l_a, _ = vr_persistence_l1(X.contiguous(),0, 0); 95 | 96 | # compute the same using a pre-computed distance matrix 97 | D = torch.tensor( 98 | squareform( 99 | pdist(x, metric='cityblock') 100 | ) 101 | ).to("cuda:0") 102 | l_b, _ = vr_persistence(D, 0, 0) 103 | print("Diff: ", 104 | (l_a[0].float()-l_b[0].float()).abs().sum().item()) 105 | 106 | .. toctree:: 107 | :caption: Modules 108 | :maxdepth: 1 109 | :hidden: 110 | :glob: 111 | 112 | nn 113 | pershom 114 | 115 | .. 
toctree:: 116 | :caption: Notebooks 117 | :maxdepth: 0 118 | :hidden: 119 | :glob: 120 | 121 | tutorials/SLayer.ipynb 122 | tutorials/ToyDiffVR.ipynb 123 | tutorials/ComparisonSOTA.ipynb 124 | tutorials/InputOptim.ipynb 125 | 126 | Indices and tables 127 | ================== 128 | 129 | * :ref:`genindex` 130 | * :ref:`modindex` 131 | * :ref:`search` 132 | -------------------------------------------------------------------------------- /docs/_static/pygments.css: -------------------------------------------------------------------------------- 1 | .highlight .hll { background-color: #ffffcc } 2 | .highlight { background: #f8f8f8; } 3 | .highlight .c { color: #408080; font-style: italic } /* Comment */ 4 | .highlight .err { border: 1px solid #FF0000 } /* Error */ 5 | .highlight .k { color: #008000; font-weight: bold } /* Keyword */ 6 | .highlight .o { color: #666666 } /* Operator */ 7 | .highlight .ch { color: #408080; font-style: italic } /* Comment.Hashbang */ 8 | .highlight .cm { color: #408080; font-style: italic } /* Comment.Multiline */ 9 | .highlight .cp { color: #BC7A00 } /* Comment.Preproc */ 10 | .highlight .cpf { color: #408080; font-style: italic } /* Comment.PreprocFile */ 11 | .highlight .c1 { color: #408080; font-style: italic } /* Comment.Single */ 12 | .highlight .cs { color: #408080; font-style: italic } /* Comment.Special */ 13 | .highlight .gd { color: #A00000 } /* Generic.Deleted */ 14 | .highlight .ge { font-style: italic } /* Generic.Emph */ 15 | .highlight .gr { color: #FF0000 } /* Generic.Error */ 16 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 17 | .highlight .gi { color: #00A000 } /* Generic.Inserted */ 18 | .highlight .go { color: #888888 } /* Generic.Output */ 19 | .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ 20 | .highlight .gs { font-weight: bold } /* Generic.Strong */ 21 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 22 | .highlight .gt { 
color: #0044DD } /* Generic.Traceback */ 23 | .highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ 24 | .highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ 25 | .highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ 26 | .highlight .kp { color: #008000 } /* Keyword.Pseudo */ 27 | .highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ 28 | .highlight .kt { color: #B00040 } /* Keyword.Type */ 29 | .highlight .m { color: #666666 } /* Literal.Number */ 30 | .highlight .s { color: #BA2121 } /* Literal.String */ 31 | .highlight .na { color: #7D9029 } /* Name.Attribute */ 32 | .highlight .nb { color: #008000 } /* Name.Builtin */ 33 | .highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */ 34 | .highlight .no { color: #880000 } /* Name.Constant */ 35 | .highlight .nd { color: #AA22FF } /* Name.Decorator */ 36 | .highlight .ni { color: #999999; font-weight: bold } /* Name.Entity */ 37 | .highlight .ne { color: #D2413A; font-weight: bold } /* Name.Exception */ 38 | .highlight .nf { color: #0000FF } /* Name.Function */ 39 | .highlight .nl { color: #A0A000 } /* Name.Label */ 40 | .highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */ 41 | .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */ 42 | .highlight .nv { color: #19177C } /* Name.Variable */ 43 | .highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */ 44 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */ 45 | .highlight .mb { color: #666666 } /* Literal.Number.Bin */ 46 | .highlight .mf { color: #666666 } /* Literal.Number.Float */ 47 | .highlight .mh { color: #666666 } /* Literal.Number.Hex */ 48 | .highlight .mi { color: #666666 } /* Literal.Number.Integer */ 49 | .highlight .mo { color: #666666 } /* Literal.Number.Oct */ 50 | .highlight .sa { color: #BA2121 } /* Literal.String.Affix */ 51 | .highlight .sb { color: #BA2121 } /* Literal.String.Backtick */ 52 
| .highlight .sc { color: #BA2121 } /* Literal.String.Char */ 53 | .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */ 54 | .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ 55 | .highlight .s2 { color: #BA2121 } /* Literal.String.Double */ 56 | .highlight .se { color: #BB6622; font-weight: bold } /* Literal.String.Escape */ 57 | .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */ 58 | .highlight .si { color: #BB6688; font-weight: bold } /* Literal.String.Interpol */ 59 | .highlight .sx { color: #008000 } /* Literal.String.Other */ 60 | .highlight .sr { color: #BB6688 } /* Literal.String.Regex */ 61 | .highlight .s1 { color: #BA2121 } /* Literal.String.Single */ 62 | .highlight .ss { color: #19177C } /* Literal.String.Symbol */ 63 | .highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */ 64 | .highlight .fm { color: #0000FF } /* Name.Function.Magic */ 65 | .highlight .vc { color: #19177C } /* Name.Variable.Class */ 66 | .highlight .vg { color: #19177C } /* Name.Variable.Global */ 67 | .highlight .vi { color: #19177C } /* Name.Variable.Instance */ 68 | .highlight .vm { color: #19177C } /* Name.Variable.Magic */ 69 | .highlight .il { color: #666666 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /docs_src/source/index.rst: -------------------------------------------------------------------------------- 1 | .. torchph documentation master file, created by 2 | sphinx-quickstart on Mon Feb 4 13:39:08 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | PyTorch extensions for persistent homology 7 | ========================================== 8 | 9 | This package contains the backend methods (to be used within the PyTorch environment) 10 | for multiple works using persistent homology in machine learning problems. 
In particular, 11 | the following publications are most relevant:: 12 | 13 | @inproceedings{Hofer17a, 14 | author = {C.~Hofer, R.~Kwitt, M.~Niethammer and A.~Uhl}, 15 | title = {Deep Learning with Topological Signatures}, 16 | booktitle = {NIPS}, 17 | year = {2017}} 18 | 19 | @inproceedings{Hofer19a, 20 | author = {C.~Hofer, R.~Kwitt, M.~Dixit and M.~Niethammer}, 21 | title = {Connectivity-Optimized Representation Learning via Persistent Homology}, 22 | booktitle = {ICML}, 23 | year = {2019}} 24 | 25 | @article{Hofer19b, 26 | author = {C.~Hofer, R.~Kwitt, and M.~Niethammer}, 27 | title = {Learning Representations of Persistence Barcodes}, 28 | journal = {JMLR}, 29 | year = {2019}} 30 | 31 | @article{Hofer20a, 32 | author = {C.~Hofer, F.~Graf, B.~Rieck, R.~Kwitt and M.~Niethammer}, 33 | title = {Graph Filtration Learning}, 34 | journal = {arXiv}, 35 | note = {\url{https://arxiv.org/abs/1905.10996}}, 36 | year = {2019}} 37 | 38 | @article{Hofer20b, 39 | author = {C.~Hofer, F.~Graf, M.~Niethammer and R.~Kwitt}, 40 | title = {Topologically Densified Distributions}, 41 | journal = {arXiv}, 42 | note = {\url{https://arxiv.org/abs/2002.04805}}, 43 | year = {2020}} 44 | 45 | .. note:: 46 | Note that not all of the available functionality is exposed in the 47 | documentation yet. 48 | 49 | Get started 50 | =========== 51 | 52 | .. toctree:: 53 | :maxdepth: 1 54 | :caption: Get Started 55 | :hidden: 56 | :glob: 57 | 58 | install/index 59 | 60 | Follow the :doc:`instructions <install/index>` to install ``torchph``. 61 | 62 | Functionality 63 | ============= 64 | * Vietoris-Rips (VR) persistent homology (from point clouds and distance matrices) 65 | * Vertex-based filtrations (e.g., usable for graphs) 66 | * Learnable vectorizations of persistence barcodes 67 | 68 | All of this functionality is available for **GPU** computations and can easily be 69 | used within the PyTorch environment. 70 | 71 | The following **simple example** is a teaser showing how to compute 0-dim.
persistent 72 | homology (1) of a Vietoris-Rips filtration which uses the Manhattan distance between 73 | samples and (2) of the same filtration built from a pre-computed distance matrix. 74 | 75 | .. code-block:: python 76 | 77 | import torch 78 | device = "cuda:0" 79 | # import numpy 80 | import numpy as np 81 | 82 | # import VR persistence computation functionality 83 | from torchph.pershom import vr_persistence_l1, vr_persistence 84 | 85 | # import scipy methods to compute pairwise distance matrices 86 | from scipy.spatial.distance import pdist 87 | from scipy.spatial.distance import squareform 88 | 89 | # create 10-dim. point cloud with 100 samples 90 | x = np.random.randn(100, 10) 91 | 92 | # compute VR persistent homology (using l1 metric) 93 | X = torch.Tensor(x).to(device) 94 | l_a, _ = vr_persistence_l1(X.contiguous(), 0, 0) 95 | 96 | # compute the same using a pre-computed distance matrix 97 | D = torch.tensor( 98 | squareform( 99 | pdist(x, metric='cityblock') 100 | ) 101 | ).to(device) 102 | l_b, _ = vr_persistence(D, 0, 0) 103 | print("Diff: ", 104 | (l_a[0].float()-l_b[0].float()).abs().sum().item()) 105 | 106 | .. toctree:: 107 | :caption: Modules 108 | :maxdepth: 1 109 | :hidden: 110 | :glob: 111 | 112 | nn 113 | pershom 114 | 115 | .. toctree:: 116 | :caption: Notebooks 117 | :maxdepth: 0 118 | :hidden: 119 | :glob: 120 | 121 | tutorials/SLayer.ipynb 122 | tutorials/ToyDiffVR.ipynb 123 | tutorials/ComparisonSOTA.ipynb 124 | tutorials/InputOptim.ipynb 125 | 126 | Indices and tables 127 | ================== 128 | 129 | * :ref:`genindex` 130 | * :ref:`modindex` 131 | * :ref:`search` 132 | -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/py_bindings.cpp: -------------------------------------------------------------------------------- 1 | #ifndef PROFILE //TODO remove this?
2 | 3 | #include 4 | 5 | #include "calc_pers_cuda.cuh" 6 | #include "vr_comp_cuda.cuh" 7 | #include "vertex_filtration_comp_cuda.h" 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 10 | { 11 | m.def("CalcPersCuda__find_merge_pairings", &CalcPersCuda::find_merge_pairings, "find_merge_pairings (CUDA)"); 12 | m.def("CalcPersCuda__merge_columns", &CalcPersCuda::merge_columns, "merge_columns (CUDA)"); 13 | m.def("CalcPersCuda__read_barcodes", &CalcPersCuda::read_barcodes, "read_barcodes (CUDA)"); 14 | m.def("CalcPersCuda__calculate_persistence", &CalcPersCuda::calculate_persistence, "calculate_persistence (CUDA)"); 15 | 16 | m.def("VRCompCuda__vr_persistence", &VRCompCuda::vr_persistence, ""); 17 | m.def("VRCompCuda__vr_persistence_l1", &VRCompCuda::vr_persistence_l1, ""); 18 | m.def("VRCompCuda__write_combinations_table_to_tensor", &VRCompCuda::write_combinations_table_to_tensor, ""), 19 | m.def("VRCompCuda__co_faces_from_combinations", &VRCompCuda::co_faces_from_combinations, ""); 20 | 21 | m.def("VRCompCuda__l1_norm_distance_matrix", &VRCompCuda::l1_norm_distance_matrix, ""); 22 | m.def("VRCompCuda__l2_norm_distance_matrix", &VRCompCuda::l2_norm_distance_matrix, ""); 23 | 24 | pybind11::class_(m, "VRCompCuda__VietorisRipsArgsGenerator") 25 | .def(pybind11::init<>()) 26 | .def_readwrite("boundary_info_non_vertices", &VRCompCuda::VietorisRipsArgsGenerator::boundary_info_non_vertices) 27 | .def_readwrite("filtration_values_by_dim", &VRCompCuda::VietorisRipsArgsGenerator::filtration_values_by_dim) 28 | .def_readwrite("n_simplices_by_dim", &VRCompCuda::VietorisRipsArgsGenerator::n_simplices_by_dim) 29 | 30 | .def_readwrite("simplex_dimension_vector", &VRCompCuda::VietorisRipsArgsGenerator::simplex_dimension_vector) 31 | .def_readwrite("filtration_values_vector_without_vertices", &VRCompCuda::VietorisRipsArgsGenerator::filtration_values_vector_without_vertices) 32 | .def_readwrite("filtration_add_eps_hack_values", 
&VRCompCuda::VietorisRipsArgsGenerator::filtration_add_eps_hack_values) 33 | 34 | .def_readwrite("sort_indices_without_vertices", &VRCompCuda::VietorisRipsArgsGenerator::sort_indices_without_vertices) 35 | .def_readwrite("sort_indices_without_vertices_inverse", &VRCompCuda::VietorisRipsArgsGenerator::sort_indices_without_vertices_inverse) 36 | 37 | .def_readwrite("sorted_filtration_values_vector", &VRCompCuda::VietorisRipsArgsGenerator::sorted_filtration_values_vector) 38 | 39 | .def_readwrite("boundary_array", &VRCompCuda::VietorisRipsArgsGenerator::boundary_array) 40 | .def_readwrite("ba_row_i_to_bm_col_i_vector", &VRCompCuda::VietorisRipsArgsGenerator::ba_row_i_to_bm_col_i_vector) 41 | 42 | .def("__call__", &VRCompCuda::VietorisRipsArgsGenerator::operator()) 43 | 44 | .def("init_state", &VRCompCuda::VietorisRipsArgsGenerator::init_state, "") 45 | .def("make_boundary_info_edges", &VRCompCuda::VietorisRipsArgsGenerator::make_boundary_info_edges, "") 46 | .def("make_boundary_info_non_edges", &VRCompCuda::VietorisRipsArgsGenerator::make_boundary_info_non_edges, "") 47 | .def("make_simplex_ids_compatible_within_dimensions", &VRCompCuda::VietorisRipsArgsGenerator::make_simplex_ids_compatible_within_dimensions, "") 48 | .def("make_simplex_dimension_vector", &VRCompCuda::VietorisRipsArgsGenerator::make_simplex_dimension_vector, "") 49 | .def("make_filtration_values_vector_without_vertices", &VRCompCuda::VietorisRipsArgsGenerator::make_filtration_values_vector_without_vertices, "") 50 | .def("do_filtration_add_eps_hack", &VRCompCuda::VietorisRipsArgsGenerator::do_filtration_add_eps_hack, "") 51 | .def("make_sorting_infrastructure", &VRCompCuda::VietorisRipsArgsGenerator::make_sorting_infrastructure, "") 52 | .def("undo_filtration_add_eps_hack", &VRCompCuda::VietorisRipsArgsGenerator::undo_filtration_add_eps_hack, "") 53 | .def("make_sorted_filtration_values_vector", &VRCompCuda::VietorisRipsArgsGenerator::make_sorted_filtration_values_vector, "") 54 | 
.def("make_boundary_array_rows_unsorted", &VRCompCuda::VietorisRipsArgsGenerator::make_boundary_array_rows_unsorted, "") 55 | .def("apply_sorting_to_rows", &VRCompCuda::VietorisRipsArgsGenerator::apply_sorting_to_rows, "") 56 | .def("make_ba_row_i_to_bm_col_i_vector", &VRCompCuda::VietorisRipsArgsGenerator::make_ba_row_i_to_bm_col_i_vector, "") 57 | // .def("", &VRCompCuda::VietorisRipsArgsGenerator::, "") 58 | ; 59 | 60 | m.def("VertFiltCompCuda__vert_filt_comp_calculate_persistence_args", &VertFiltCompCuda::vert_filt_comp_calculate_persistence_args, "compute args for calculate_persistence from simplicial complex definition(CUDA)"); 61 | m.def("VertFiltCompCuda__vert_filt_persistence_single", &VertFiltCompCuda::vert_filt_persistence_single, ""); 62 | m.def("VertFiltCompCuda__vert_filt_persistence_batch", &VertFiltCompCuda::vert_filt_persistence_batch, ""); 63 | 64 | } 65 | 66 | #endif -------------------------------------------------------------------------------- /pershom_dev/profiling_pershom/profile_case.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | // #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | using namespace at; 15 | 16 | 17 | std::vector> calculate_persistence( 18 | Tensor descending_sorted_boundary_array, 19 | Tensor ind_not_reduced, 20 | Tensor column_dimension, 21 | int max_pairs, 22 | int max_dimension=-1); 23 | 24 | Tensor find_merge_pairings(Tensor pivots, int max_pairs); 25 | 26 | std::tuple read_profiling_data(){ 27 | 28 | std::tuple ret; 29 | std::ifstream inputFile("data/boundary_array.txt"); 30 | 31 | std::vector boundary_array_data; 32 | 33 | if (inputFile.good()) { 34 | 35 | int32_t current_number = 0; 36 | while (inputFile >> current_number){ 37 | boundary_array_data.push_back(current_number); 38 | } 39 | 40 | inputFile.close(); 41 | } 42 | 43 | 44 | inputFile = 
std::ifstream("data/boundary_array_size.txt"); 45 | std::vector boundary_array_size; 46 | 47 | if (inputFile.good()) { 48 | 49 | int current_number = 0; 50 | while (inputFile >> current_number){ 51 | boundary_array_size.push_back(current_number); 52 | } 53 | 54 | inputFile.close(); 55 | } 56 | 57 | int32_t *d_boundary_array_data; 58 | 59 | // Sends data to device 60 | auto size = boundary_array_size[0]*boundary_array_size[1]*sizeof(int32_t); 61 | cudaMalloc((void**) &d_boundary_array_data, size); 62 | cudaMemcpy(d_boundary_array_data, &boundary_array_data[0], size, cudaMemcpyHostToDevice); 63 | 64 | auto boundary_array = CUDA(kInt).tensorFromBlob(d_boundary_array_data, {boundary_array_size[0], boundary_array_size[1]}); 65 | boundary_array = boundary_array.clone(); 66 | cudaFree(d_boundary_array_data); 67 | // auto boundary_array = CPU(kInt).tensorFromBlob(&boundary_array_data[0], {boundary_array_size[0], boundary_array_size[1]}); 68 | 69 | 70 | inputFile = std::ifstream("data/ind_not_reduced.txt"); 71 | std::vector ind_not_reduced_data; 72 | 73 | if (inputFile.good()) { 74 | 75 | int64_t current_number = 0; 76 | while (inputFile >> current_number){ 77 | ind_not_reduced_data.push_back(current_number); 78 | } 79 | 80 | inputFile.close(); 81 | } 82 | 83 | int64_t *d_ind_not_reduced_data; 84 | size = ind_not_reduced_data.size()*sizeof(int64_t); 85 | cudaMalloc((void**) &d_ind_not_reduced_data, size); 86 | cudaMemcpy(d_ind_not_reduced_data, &ind_not_reduced_data[0], size, cudaMemcpyHostToDevice); 87 | 88 | auto ind_not_reduced = CUDA(kLong).tensorFromBlob(d_ind_not_reduced_data, {ind_not_reduced_data.size()}); 89 | ind_not_reduced = ind_not_reduced.clone(); 90 | cudaFree(d_ind_not_reduced_data); 91 | 92 | 93 | inputFile = std::ifstream("data/column_dimension.txt"); 94 | std::vector column_dimension_data; 95 | 96 | if (inputFile.good()) { 97 | 98 | int current_number = 0; 99 | while (inputFile >> current_number){ 100 | column_dimension_data.push_back(current_number); 101 
| } 102 | 103 | inputFile.close(); 104 | } 105 | 106 | int32_t *d_column_dimension_data; 107 | size = column_dimension_data.size()*sizeof(int32_t); 108 | cudaMalloc((void**) &d_column_dimension_data, size); 109 | cudaMemcpy(d_column_dimension_data, &column_dimension_data[0], size, cudaMemcpyHostToDevice); 110 | 111 | auto column_dimension = CUDA(kInt).tensorFromBlob(d_column_dimension_data, {column_dimension_data.size()}); 112 | column_dimension = column_dimension.clone(); 113 | cudaFree(d_column_dimension_data); 114 | // auto column_dimension = CPU(kInt).tensorFromBlob(&column_dimension_data[0], {column_dimension_data.size()}); 115 | 116 | 117 | inputFile = std::ifstream("data/max_dimension.txt"); 118 | int max_dimension; 119 | 120 | if (inputFile.good()) { 121 | 122 | inputFile >> max_dimension; 123 | 124 | inputFile.close(); 125 | } 126 | 127 | return std::make_tuple(boundary_array, ind_not_reduced, column_dimension, max_dimension); 128 | 129 | } 130 | 131 | void sorting(Tensor pivots){ 132 | pivots.sort(0); 133 | } 134 | 135 | int main() 136 | { 137 | dlopen("libcaffe2_gpu.so", RTLD_NOW); 138 | 139 | auto bm_ind_dim_maxdim = read_profiling_data(); 140 | 141 | auto bm = std::get<0>(bm_ind_dim_maxdim); 142 | auto ind_not_reduced = std::get<1>(bm_ind_dim_maxdim); 143 | auto col_dim = std::get<2>(bm_ind_dim_maxdim); 144 | auto max_dim = std::get<3>(bm_ind_dim_maxdim); 145 | 146 | // auto pivots = bm.slice(1, 0, 1).contiguous(); 147 | 148 | ProfilerStart("profile.google-pprof"); 149 | for (int i=0; i<10; i++){ 150 | // auto pairs = find_merge_pairings(pivots, 100000); 151 | auto ret = calculate_persistence(bm.clone(), ind_not_reduced.clone(), col_dim.clone(), max_dim); 152 | } 153 | ProfilerStop(); 154 | 155 | // pivots = pivots.toBackend(Backend::CPU); 156 | // std::cout << Scalar(pivots[0][0]).to() << std::endl; 157 | 158 | return 0; 159 | } -------------------------------------------------------------------------------- /tests/pershom/test_pershom_backend.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import glob 4 | import pickle 5 | import torchph.pershom.pershom_backend as pershom_backend 6 | from collections import Counter 7 | 8 | 9 | class Test_find_merge_pairings: 10 | @pytest.mark.parametrize("device, dtype", [ 11 | (torch.device('cuda'), torch.int64) 12 | ]) 13 | def test_return_value_dtype(self, device, dtype): 14 | pivots = torch.tensor([1, 1], device=device, dtype=dtype) 15 | 16 | result = pershom_backend.find_merge_pairings(pivots) 17 | 18 | assert result.dtype == torch.int64 19 | 20 | @pytest.mark.parametrize("device, dtype", [ 21 | (torch.device('cuda'), torch.int64) 22 | ]) 23 | def test_parameter_max_pairs(self, device, dtype): 24 | pivots = torch.tensor([1]*1000, device=device, dtype=dtype).unsqueeze(1) 25 | 26 | result = pershom_backend.find_merge_pairings(pivots, max_pairs=100) 27 | 28 | assert result.size(0) == 100 29 | 30 | @pytest.mark.parametrize("device, dtype", [ 31 | (torch.device('cuda'), torch.int64) 32 | ]) 33 | def test_no_merge_pairs(self, device, dtype): 34 | pivots = torch.tensor(list(range(100)), device=device, dtype=dtype).unsqueeze(1) 35 | assert pershom_backend.find_merge_pairings(pivots).numel() == 0 36 | 37 | @pytest.mark.parametrize("device, dtype", [ 38 | (torch.device('cuda'), torch.int64) 39 | ]) 40 | def test_result_1(self, device, dtype): 41 | pivots = [6, 3, 3, 3, 5, 6, 6, 0, 5, 5] 42 | pivots = torch.tensor(pivots, device=device, dtype=dtype).unsqueeze(1) 43 | 44 | result = pershom_backend.find_merge_pairings(pivots) 45 | 46 | assert result.dtype == torch.int64 47 | 48 | expected_result = set([(0, 5), (0, 6), 49 | (1, 2), (1, 3), 50 | (4, 8), (4, 9)]) 51 | 52 | result = set(tuple(x) for x in result.tolist()) 53 | 54 | assert result == (expected_result) 55 | 56 | @pytest.mark.parametrize("device, dtype", [ 57 | (torch.device('cuda'), torch.int64) 58 | ]) 59 | def test_result_2(self, device, dtype): 
60 | pivots = sum([100*[i] for i in range(100)], []) 61 | pivots = torch.tensor(pivots, device=device, dtype=dtype).unsqueeze(1) 62 | 63 | expected_result = torch.tensor([(int(i/100) * 100, i) for i in range(100 * 100) if i % 100 != 0]) 64 | expected_result = expected_result.long().to('cuda') 65 | 66 | result = pershom_backend.find_merge_pairings(pivots) 67 | 68 | assert expected_result.equal(result) 69 | 70 | 71 | class Test_calculate_persistence: 72 | 73 | @staticmethod 74 | def calculate_persistence_output_to_barcode_list(input): 75 | ret = [] 76 | b, b_e = input 77 | 78 | for dim, (b_dim, b_dim_e) in enumerate(zip(b, b_e)): 79 | b_dim, b_dim_e = b_dim.float(), b_dim_e.float() 80 | 81 | tmp = torch.empty_like(b_dim_e) 82 | tmp.fill_(float('inf')) 83 | b_dim_e = torch.cat([b_dim_e, tmp], dim=1) 84 | 85 | dgm = torch.cat([b_dim, b_dim_e], dim=0) 86 | dgm = dgm.tolist() 87 | dgm = Counter(((float(b), float(d)) for b, d in dgm)) 88 | 89 | ret.append(dgm) 90 | 91 | return ret 92 | 93 | @pytest.mark.parametrize("device, dtype", [ 94 | (torch.device('cuda'), torch.int64) 95 | ]) 96 | def test_empty_input(self, device, dtype): 97 | ba = torch.empty([0], device=device, dtype=dtype) 98 | ba_row_i_to_bm_col_i = ba 99 | simplex_dimension = torch.zeros(10, device=device, dtype=dtype) 100 | 101 | not_ess, ess = pershom_backend.calculate_persistence( 102 | ba, 103 | ba_row_i_to_bm_col_i, 104 | simplex_dimension, 105 | 2, 106 | -1) 107 | 108 | for pairings in not_ess: 109 | assert len(pairings) == 0 110 | 111 | assert len(ess[0]) == 10 112 | 113 | for birth_i in ess[1:]: 114 | assert len(birth_i) == 0 115 | 116 | def test_simple_1(self): 117 | device = torch.device('cuda') 118 | dtype = torch.int64 119 | 120 | ba = torch.empty((1, 4)) 121 | ba.fill_(-1) 122 | ba[0, 0:2] = torch.tensor([1, 0]) 123 | ind_not_reduced = torch.tensor([2]) 124 | 125 | row_dim = torch.tensor([0, 0, 1]) 126 | 127 | max_dim_to_read_of_reduced_ba = 1 128 | 129 | ba = ba.to(device).type(dtype) 130 | 
ind_not_reduced = ind_not_reduced.to(device).long() 131 | row_dim = row_dim.to(device).type(dtype) 132 | 133 | out = pershom_backend.calculate_persistence( 134 | ba, 135 | ind_not_reduced, 136 | row_dim, 137 | max_dim_to_read_of_reduced_ba, 138 | 100) 139 | 140 | barcodes = Test_calculate_persistence.calculate_persistence_output_to_barcode_list(out) 141 | 142 | assert barcodes[0] == Counter([(1.0, 2.0), (0.0, float('inf'))]) 143 | 144 | def test_random_simplicial_complexes(self): 145 | device = torch.device('cuda') 146 | dtype = torch.int64 147 | 148 | for sp_path in glob.glob('test_pershom_backend_data/random_simplicial_complexes/*'): 149 | 150 | with open(sp_path, 'br') as f: 151 | data = pickle.load(f) 152 | 153 | assert len(data) == 2 154 | 155 | ba, ind_not_reduced, row_dim, max_dim_to_read_of_reduced_ba = data['calculate_persistence_args'] 156 | expected_result = data['expected_result'] 157 | 158 | ba, ind_not_reduced, row_dim = ba.to(device).type(dtype), ind_not_reduced.to(device).long(), row_dim.to(device).type(dtype) 159 | 160 | result = pershom_backend.calculate_persistence( 161 | ba, 162 | ind_not_reduced, 163 | row_dim, 164 | max_dim_to_read_of_reduced_ba, 165 | 10000) 166 | result = Test_calculate_persistence.calculate_persistence_output_to_barcode_list(result) 167 | 168 | for dgm, dgm_exp in zip(result, expected_result): 169 | assert dgm == dgm_exp 170 | -------------------------------------------------------------------------------- /docs/_modules/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Overview: module code — torchph 0.0.0 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 |
47 | 48 | 106 | 107 |
108 | 109 | 110 | 116 | 117 | 118 |
119 | 120 |
121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 |
139 | 140 |
    141 | 142 |
  • Docs »
  • 143 | 144 |
  • Overview: module code
  • 145 | 146 | 147 |
  • 148 | 149 |
  • 150 | 151 |
152 | 153 | 154 |
155 |
156 |
157 |
158 | 159 |

All modules for which code is available

160 | 163 | 164 |
165 | 166 |
167 |
168 | 169 | 170 |
171 | 172 |
173 |

174 | © Copyright 2020, Christoph D. Hofer, Roland Kwitt 175 | 176 |

177 |
178 | Built with Sphinx using a theme provided by Read the Docs. 179 | 180 |
181 | 182 |
183 |
184 | 185 |
186 | 187 |
188 | 189 | 190 | 191 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | -------------------------------------------------------------------------------- /docs/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Search — torchph 0.0.0 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
48 | 49 | 107 | 108 |
109 | 110 | 111 | 117 | 118 | 119 |
120 | 121 |
122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 |
140 | 141 |
    142 | 143 |
  • Docs »
  • 144 | 145 |
  • Search
  • 146 | 147 | 148 |
  • 149 | 150 | 151 | 152 |
  • 153 | 154 |
155 | 156 | 157 |
158 |
159 |
160 |
161 | 162 | 170 | 171 | 172 |
173 | 174 |
175 | 176 |
177 | 178 |
179 |
180 | 181 | 182 |
183 | 184 |
185 |

186 | © Copyright 2020, Christoph D. Hofer, Roland Kwitt 187 | 188 |

189 |
190 | Built with Sphinx using a theme provided by Read the Docs. 191 | 192 |
193 | 194 |
195 |
196 | 197 |
198 | 199 |
200 | 201 | 202 | 203 | 208 | 209 | 210 | 211 | 212 | 213 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /docs_src/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | # import os 16 | # import sys 17 | # sys.path.insert(0, os.path.abspath('.')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'torchph' 23 | copyright = '2020, Christoph D. Hofer, Roland Kwitt' 24 | author = 'Christoph D. Hofer, Roland Kwitt' 25 | 26 | # The short X.Y version 27 | version = '' 28 | # The full version, including alpha/beta/rc tags 29 | release = '0.0.0' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 
41 | extensions = [ 42 | 'nbsphinx', 43 | 'sphinx.ext.autosummary', 44 | 'sphinx.ext.coverage', 45 | 'sphinx.ext.napoleon', 46 | 'sphinx.ext.autodoc', 47 | 'sphinx.ext.intersphinx', 48 | 'sphinx.ext.mathjax', 49 | 'sphinx.ext.viewcode', 50 | 'sphinx.ext.githubpages', 51 | ] 52 | 53 | # Add any paths that contain templates here, relative to this directory. 54 | templates_path = ['_templates'] 55 | 56 | # The suffix(es) of source filenames. 57 | # You can specify multiple suffix as a list of string: 58 | # 59 | # source_suffix = ['.rst', '.md'] 60 | source_suffix = '.rst' 61 | 62 | # The master toctree document. 63 | master_doc = 'index' 64 | 65 | # The language for content autogenerated by Sphinx. Refer to documentation 66 | # for a list of supported languages. 67 | # 68 | # This is also used if you do content translation via gettext catalogs. 69 | # Usually you set "language" from the command line for these cases. 70 | language = None 71 | 72 | # List of patterns, relative to source directory, that match files and 73 | # directories to ignore when looking for source files. 74 | # This pattern also affects html_static_path and html_extra_path. 75 | exclude_patterns = ['_build', '**.ipynb_checkpoints'] 76 | 77 | 78 | # The name of the Pygments (syntax highlighting) style to use. 79 | pygments_style = None 80 | 81 | # -- Options for HTML output ------------------------------------------------- 82 | 83 | # The theme to use for HTML and HTML Help pages. See the documentation for 84 | # a list of builtin themes. 85 | # 86 | html_theme = 'sphinx_rtd_theme' # 'alabaster' 87 | 88 | # Theme options are theme-specific and customize the look and feel of a theme 89 | # further. For a list of options available for each theme, see the 90 | # documentation. 91 | # 92 | # html_theme_options = {} 93 | 94 | # Add any paths that contain custom static files (such as style sheets) here, 95 | # relative to this directory. 
They are copied after the builtin static files, 96 | # so a file named "default.css" will overwrite the builtin "default.css". 97 | html_static_path = ['_static'] 98 | 99 | # Custom sidebar templates, must be a dictionary that maps document names 100 | # to template names. 101 | # 102 | # The default sidebars (for documents that don't match any pattern) are 103 | # defined by theme itself. Builtin themes are using these templates by 104 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 105 | # 'searchbox.html']``. 106 | # 107 | # html_sidebars = {} 108 | 109 | 110 | # -- Options for HTMLHelp output --------------------------------------------- 111 | 112 | # Output file base name for HTML help builder. 113 | htmlhelp_basename = 'torchphdoc' 114 | 115 | 116 | # -- Options for LaTeX output ------------------------------------------------ 117 | 118 | latex_elements = { 119 | # The paper size ('letterpaper' or 'a4paper'). 120 | # 121 | # 'papersize': 'letterpaper', 122 | 123 | # The font size ('10pt', '11pt' or '12pt'). 124 | # 125 | # 'pointsize': '10pt', 126 | 127 | # Additional stuff for the LaTeX preamble. 128 | # 129 | # 'preamble': '', 130 | 131 | # Latex figure (float) alignment 132 | # 133 | # 'figure_align': 'htbp', 134 | } 135 | 136 | # Grouping the document tree into LaTeX files. List of tuples 137 | # (source start file, target name, title, 138 | # author, documentclass [howto, manual, or own class]). 139 | latex_documents = [ 140 | (master_doc, 'torchph.tex', 'torchph Documentation', 141 | 'Christoph D. Hofer', 'manual'), 142 | ] 143 | 144 | 145 | # -- Options for manual page output ------------------------------------------ 146 | 147 | # One entry per manual page. List of tuples 148 | # (source start file, name, description, authors, manual section). 
149 | man_pages = [ 150 | (master_doc, 'torchph', 'torchph Documentation', 151 | [author], 1) 152 | ] 153 | 154 | 155 | # -- Options for Texinfo output ---------------------------------------------- 156 | 157 | # Grouping the document tree into Texinfo files. List of tuples 158 | # (source start file, target name, title, author, 159 | # dir menu entry, description, category) 160 | texinfo_documents = [ 161 | (master_doc, 'torchph', 'torchph Documentation', 162 | author, 'torchph', 'One line description of project.', 163 | 'Miscellaneous'), 164 | ] 165 | 166 | 167 | # -- Options for Epub output ------------------------------------------------- 168 | 169 | # Bibliographic Dublin Core info. 170 | epub_title = project 171 | 172 | # The unique identifier of the text. This can be a ISBN number 173 | # or the project homepage. 174 | # 175 | # epub_identifier = '' 176 | 177 | # A unique identification for the text. 178 | # 179 | # epub_uid = '' 180 | 181 | # A list of files that should not be packed into the epub file. 182 | epub_exclude_files = ['search.html'] 183 | 184 | 185 | # -- Extension configuration ------------------------------------------------- 186 | 187 | # -- Options for intersphinx extension --------------------------------------- 188 | 189 | # Example configuration for intersphinx: refer to the Python standard library. 
190 | intersphinx_mapping = {'https://docs.python.org/': None} 191 | 192 | 193 | # -- Customization chofer ---------------------------------------------------- 194 | 195 | def skip(app, what, name, obj, would_skip, options): 196 | if name == "__init__": 197 | return False 198 | return would_skip 199 | 200 | def setup(app): 201 | app.connect("autodoc-skip-member", skip) 202 | -------------------------------------------------------------------------------- /docs/py-modindex.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Python Module Index — torchph 0.0.0 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 |
50 | 51 | 109 | 110 |
111 | 112 | 113 | 119 | 120 | 121 |
122 | 123 |
124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 |
142 | 143 |
    144 | 145 |
  • Docs »
  • 146 | 147 |
  • Python Module Index
  • 148 | 149 | 150 |
  • 151 | 152 |
  • 153 | 154 |
155 | 156 | 157 |
158 |
159 |
160 |
161 | 162 | 163 |

Python Module Index

164 | 165 |
166 | t 167 |
168 | 169 | 170 | 171 | 173 | 174 | 176 | 179 | 180 | 181 | 184 | 185 | 186 | 189 |
 
172 | t
177 | torchph 178 |
    182 | torchph.nn.slayer 183 |
    187 | torchph.pershom.pershom_backend 188 |
190 | 191 | 192 |
193 | 194 |
195 |
196 | 197 | 198 |
199 | 200 |
201 |

202 | © Copyright 2020, Christoph D. Hofer, Roland Kwitt 203 | 204 |

205 |
206 | Built with Sphinx using a theme provided by Read the Docs. 207 | 208 |
209 | 210 |
211 |
212 | 213 |
214 | 215 |
216 | 217 | 218 | 219 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | -------------------------------------------------------------------------------- /torchph/pershom/pershom_backend.py: -------------------------------------------------------------------------------- 1 | r""" 2 | This module exposes the C++/CUDA backend functionality for Python. 3 | 4 | Terminology 5 | ----------- 6 | 7 | Descending sorted boundary array: 8 | Boundary array which encodes the boundary matrix (BM) for a given 9 | filtration in column-first order. 10 | Let BA be the descending sorted boundary array of BM, then 11 | ``BA[i, :]`` is the i-th column of BM. 12 | Each column is encoded as a decreasingly sorted list, embedded into the array 13 | with -1 padding from the right. 14 | 15 | Example: 16 | ``BA[3, :] = [2, 0, -1, -1]`` 17 | then :math:`\partial(v_3) = v_0 + v_2` 18 | 19 | ``BA[6, :] = [5, 4, 3, -1]`` 20 | then :math:`\partial(v_6) = v_3 + v_4 + v_5` 21 | 22 | 23 | Compressed descending sorted boundary array: 24 | Same as *descending sorted boundary array* but rows consisting only of -1 25 | are omitted. 26 | This is sometimes used for efficiency purposes and is usually accompanied 27 | by a vector, ``v``, telling which row of the compressed BA corresponds to 28 | which row of the uncompressed BA, i.e., ``v[3] = 5`` means that the 3rd 29 | row of the compressed BA corresponds to the 5th row in the uncompressed 30 | version. 31 | 32 | Birth/Death-time: 33 | Index of the corresponding birth/death event in the filtration. 34 | This is always an *integer*. 35 | 36 | Birth/Death-value: 37 | If a filtration is induced by a real-valued function, this corresponds 38 | to the value of this function at the birth/death event. 39 | This is always *real*-valued. 40 | 41 | Limitations 42 | ----------- 43 | 44 | Currently all ``cuda`` backend functionality **only** supports ``int64_t`` and 45 | ``float32_t`` typing.
46 | 47 | """ 48 | import warnings 49 | import os.path as pth 50 | from typing import List 51 | from torch import Tensor 52 | from glob import glob 53 | 54 | 55 | from torch.utils.cpp_extension import load 56 | 57 | 58 | __module_file_dir = pth.dirname(pth.realpath(__file__)) 59 | __cpp_src_dir = pth.join(__module_file_dir, 'pershom_cpp_src') 60 | src_files = [] 61 | 62 | for extension in ['*.cpp', '*.cu']: 63 | src_files += glob(pth.join(__cpp_src_dir, extension)) 64 | 65 | # jit compiling the c++ extension 66 | 67 | _failed_compilation_msg = \ 68 | """ 69 | Failed jit compilation in {}. 70 | Error was `{}`. 71 | The error will be re-raised calling any function in this module. 72 | """ 73 | 74 | __C = None 75 | try: 76 | __C = load( 77 | 'pershom_cuda_ext', 78 | src_files, 79 | verbose=False) 80 | 81 | except Exception as ex: 82 | warnings.warn(_failed_compilation_msg.format(__file__, ex)) 83 | 84 | class ErrorThrower(object): 85 | ex = ex 86 | 87 | def __getattr__(self, name): 88 | raise self.ex 89 | 90 | __C = ErrorThrower() 91 | 92 | 93 | def find_merge_pairings( 94 | pivots: Tensor, 95 | max_pairs: int = -1 96 | ) -> Tensor: 97 | """Finds the pairs which have to be merged in the current iteration of the 98 | matrix reduction. 99 | 100 | Args: 101 | pivots: 102 | The pivots of a descending sorted boundary array. 103 | Expected size is ``Nx1``, where N is the number of columns of the 104 | underlying descending sorted boundary array. 105 | 106 | max_pairs: 107 | The output is at most a ``max_pairs x 2`` Tensor. If set to 108 | default all possible merge pairs are returned. 109 | 110 | Returns: 111 | The merge pairs, ``p``, for the current iteration of the reduction. 112 | ``p[i]`` is a merge pair. 113 | In boundary matrix notation this would mean column ``p[i][0]`` has to 114 | be merged into column ``p[i][1]``. 
115 | """ 116 | return __C.CalcPersCuda__find_merge_pairings(pivots, max_pairs) 117 | 118 | 119 | def merge_columns_( 120 | compr_desc_sort_ba: Tensor, 121 | merge_pairs: Tensor 122 | ) -> None: 123 | r"""Executes the given merging operations inplace on the descending 124 | sorted boundary array. 125 | 126 | Args: 127 | compr_desc_sort_ba: 128 | see module description top. 129 | 130 | merge_pairs: 131 | output of a ``find_merge_pairings`` call. 132 | 133 | Returns: 134 | None 135 | """ 136 | __C.CalcPersCuda__merge_columns_(compr_desc_sort_ba, merge_pairs) 137 | 138 | 139 | def read_barcodes( 140 | pivots: Tensor, 141 | simplex_dimension: Tensor, 142 | max_dim_to_read_of_reduced_ba: int 143 | ) -> List[List[Tensor]]: 144 | """Reads the barcodes using the pivot of a reduced boundary array. 145 | 146 | Arguments: 147 | pivots: 148 | pivots is the first column of a compr_desc_sort_ba 149 | 150 | simplex_dimension: 151 | Vector whose i-th entry is the dimension if the i-th simplex in 152 | the given filtration. 153 | 154 | max_dim_to_read_of_reduced_ba: 155 | features up to max_dim_to_read_of_reduced_ba are read from the 156 | reduced boundary array 157 | 158 | Returns: 159 | List of birth/death times. 160 | 161 | ``ret[0][n]`` are non essential birth/death-times of dimension ``n``. 162 | 163 | ``ret[1][n]`` are birth-times of essential classes of dimension ``n``. 164 | """ 165 | return __C.CalcPersCuda__read_barcodes( 166 | pivots, 167 | simplex_dimension, 168 | max_dim_to_read_of_reduced_ba) 169 | 170 | 171 | def calculate_persistence( 172 | compr_desc_sort_ba: Tensor, 173 | ba_row_i_to_bm_col_i: Tensor, 174 | simplex_dimension: Tensor, 175 | max_dim_to_read_of_reduced_ba: int, 176 | max_pairs: int = -1 177 | ) -> List[List[Tensor]]: 178 | """Returns the barcodes of the given encoded boundary array. 179 | 180 | Arguments: 181 | compr_desc_sort_ba: 182 | A `compressed descending sorted boundary array`, 183 | see readme section top. 
184 | 185 | ba_row_i_to_bm_col_i: 186 | Vector whose i-th entry is the column index of the boundary 187 | matrix the i-th row in ``compr_desc_sort_ba corresponds`` to. 188 | 189 | simplex_dimension: 190 | Vector whose i-th entry is the dimension if the i-th simplex in 191 | the given filtration 192 | 193 | max_pairs: see ``find_merge_pairings``. 194 | 195 | max_dim_to_read_of_reduced_ba: 196 | features up to max_dim_to_read_of_reduced_ba are read from the 197 | reduced boundary array. 198 | 199 | Returns: 200 | List of birth/death times. 201 | 202 | ``ret[0][n]`` are non essential birth/death-times of dimension ``n``. 203 | 204 | ``ret[1][n]`` are birth-times of essential classes of dimension ``n``. 205 | """ 206 | return __C.CalcPersCuda__calculate_persistence( 207 | compr_desc_sort_ba, 208 | ba_row_i_to_bm_col_i, 209 | simplex_dimension, 210 | max_dim_to_read_of_reduced_ba, 211 | max_pairs) 212 | 213 | 214 | def vr_persistence_l1( 215 | point_cloud: Tensor, 216 | max_dimension: int, 217 | max_ball_diameter: float = 0.0 218 | ) -> List[List[Tensor]]: 219 | """Returns the barcodes of the Vietoris-Rips complex of a given point cloud 220 | w.r.t. the l1 (manhatten) distance. 221 | 222 | Args: 223 | point_cloud: 224 | Point cloud from which the Vietoris-Rips complex is generated. 225 | 226 | max_dimension: 227 | The dimension of the used Vietoris-Rips complex. 228 | 229 | max_ball_diameter: 230 | If not 0, edges whose two defining vertices' distance is greater 231 | than ``max_ball_diameter`` are ignored. 232 | 233 | Returns: 234 | List of birth/death times. 235 | 236 | ``ret[0][n]`` are non essential birth/death-*values* of dimension ``n``. 237 | 238 | ``ret[1][n]`` are birth-*values* of essential classes of 239 | dimension ``n``. 
240 | """ 241 | return __C.VRCompCuda__vr_persistence_l1( 242 | point_cloud, 243 | max_dimension, 244 | max_ball_diameter) 245 | 246 | 247 | def vr_persistence( 248 | distance_matrix: Tensor, 249 | max_dimension: int, 250 | max_ball_diameter: float = 0.0 251 | ) -> List[List[Tensor]]: 252 | """Returns the barcodes of the Vietoris-Rips complex of a given distance 253 | matrix. 254 | 255 | **Note**: ``distance_matrix`` is assumed to be a square matrix. 256 | Practically, symmetry is *not* required and the upper diagonal part is 257 | *ignored*. For the computation, just the *lower* diagonal part is used. 258 | 259 | Args: 260 | distance_matrix: 261 | Distance matrix the Vietoris-Rips complex is based on. 262 | 263 | max_dimension: 264 | The dimension of the used Vietoris-Rips complex. 265 | 266 | max_ball_diameter: 267 | If not 0, edges whose two defining vertices' distance is greater 268 | than ``max_ball_diameter`` are ignored. 269 | 270 | Returns: 271 | List of birth/death times. 272 | 273 | ``ret[0][n]`` are non essential birth/death-*values* of dimension ``n``. 274 | 275 | ``ret[1][n]`` are birth-*values* of essential classes of 276 | dimension ``n``. 277 | """ 278 | return __C.VRCompCuda__vr_persistence( 279 | distance_matrix, 280 | max_dimension, 281 | max_ball_diameter) 282 | -------------------------------------------------------------------------------- /docs/tutorials/SLayer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Differentiable barcode vectorization\n", 8 | "\n", 9 | "This tutorial gives you a brief insight in the functionalities offered by the `torchph.nn.SLayerExponential` \n", 10 | "module. It assumes familarity with standard `PyTorch` functionality. \n", 11 | "\n", 12 | "Also, `torchph.nn.SLayerExponential` is just one *structure element* and others are available as well (see documentation)." 
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from shared_code import check_torchph_availability\n", 22 | "check_torchph_availability()" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from torchph.nn import SLayerExponential\n", 32 | "\n", 33 | "# create an instance with 3 structure elements over \\R^2\n", 34 | "sl = SLayerExponential(3, 2)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "`nn.SLayerExponential` is a `torch.nn.Module` ... " 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "data": { 51 | "text/plain": [ 52 | "True" 53 | ] 54 | }, 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "output_type": "execute_result" 58 | } 59 | ], 60 | "source": [ 61 | "import torch\n", 62 | "isinstance(sl, torch.nn.Module)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "... now we can do all the beautiful stuff which is inherited from `torch.nn.Module`, e.g.," 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "Parameter containing:\n", 82 | "tensor([[0.6355, 0.3604],\n", 83 | " [0.3162, 0.9167],\n", 84 | " [0.4922, 0.9822]], requires_grad=True)\n", 85 | "Parameter containing:\n", 86 | "tensor([[3., 3.],\n", 87 | " [3., 3.],\n", 88 | " [3., 3.]], requires_grad=True)\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "for p in sl.parameters():\n", 94 | " print(p)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "The module has **two** parameters: \n", 102 | "1. `centers` : controls the centers of the structure elements. 
\n", 103 | "2. `sharpness`: controls how tight the used Gaussians are. The higher the value, the tighter. \n", 104 | "\n", 105 | "Both can be initialized using the `centers_init` and `sharpness_init` keyword arguments, respectively." 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "Parameter containing:\n", 118 | "tensor([[0.0000, 0.0000],\n", 119 | " [0.5000, 0.5000],\n", 120 | " [1.0000, 1.0000]], requires_grad=True)\n", 121 | "Parameter containing:\n", 122 | "tensor([[1., 1.],\n", 123 | " [2., 2.],\n", 124 | " [3., 3.]], requires_grad=True)\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "# here is an initialization example\n", 130 | "centers_init = torch.Tensor(\n", 131 | " [\n", 132 | " [0,0], \n", 133 | " [0.5, 0.5], \n", 134 | " [1,1]\n", 135 | " ]\n", 136 | ")\n", 137 | "\n", 138 | "sharpness_init = torch.Tensor(\n", 139 | " [\n", 140 | " [1,1], \n", 141 | " [2,2], \n", 142 | " [3,3]\n", 143 | " ]\n", 144 | ")\n", 145 | "\n", 146 | "sl = SLayerExponential(3, 2, \n", 147 | " centers_init=centers_init, \n", 148 | " sharpness_init=sharpness_init)\n", 149 | "\n", 150 | "print(sl.centers)\n", 151 | "print(sl.sharpness)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "The simplest input form for `nn.SLayerExponential` is a `list` of `torch.Tensor` objects which are treated as a *batch*. 
" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 7, 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "torch.Size([4, 3])\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "# As an example, we create a batch of multisets\n", 176 | "mset_1 = [[0, 0]]\n", 177 | "mset_2 = [[0, 0], [0, 0]]\n", 178 | "mset_3 = [[1, 1], [0, 0]]\n", 179 | "mset_4 = [[0, 0], [1, 1]]\n", 180 | "batch = [mset_1, mset_2, mset_3, mset_4]\n", 181 | "batch = [torch.Tensor(x) for x in batch]\n", 182 | "output = sl(batch)\n", 183 | "print(output.size())" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "As we see the output dimensionality is `(4, 3)` since\n", 191 | "we have a batch of size `4` and `3` structure elements. \n", 192 | "\n", 193 | "In other words, \n", 194 | "`output[i, j] =` \"evaluation of structure element j on mset_i\"\n", 195 | "\n", 196 | "Lets take a look ... " 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 8, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "tensor([[1.0000e+00, 1.3534e-01, 1.5230e-08],\n", 209 | " [2.0000e+00, 2.7067e-01, 3.0460e-08],\n", 210 | " [1.1353e+00, 2.7067e-01, 1.0000e+00],\n", 211 | " [1.1353e+00, 2.7067e-01, 1.0000e+00]], grad_fn=)\n" 212 | ] 213 | } 214 | ], 215 | "source": [ 216 | "print(output)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "**We observe the following:**\n", 224 | "\n", 225 | "1. The j-th stucture element approximates the multiplicity function of the given input at point `sl.centers[j]`. E.g., the output of mset_1, `output[0, :]`, is approx. `(1, 0, 0)`. \n", 226 | "2. 
`sl.sharpness[j]` controls the amount of contribution of points not exactly on `sl.centers[j]` with respect to their distance to `sl.centers[j]`. \n", 227 | "3. The input is interpreted as a multiset, i.e., the output is permutation invariant: mset_3 and mset_4 do not differ as multisets, and accordingly `output[2,:] == output[3, :]`. \n", 228 | "\n", 229 | "This becomes clearer if we increase the sharpness of our structure elements a \"little\" ..." 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 9, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "tensor([[1., 0., 0.],\n", 242 | " [2., 0., 0.],\n", 243 | " [1., 0., 1.],\n", 244 | " [1., 0., 1.]], grad_fn=)\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "sl = SLayerExponential(3, 2, \n", 250 | " centers_init=centers_init, \n", 251 | " sharpness_init=10*sharpness_init)\n", 252 | "print(sl(batch))" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "Below is a small toy model to illustrate the application of `SLayerExponential`:" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 10, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "class ToyModel(torch.nn.Module):\n", 269 | " def __init__(self):\n", 270 | " super().__init__() \n", 271 | " self.slayer = SLayerExponential(50, 2)\n", 272 | " self.linear = torch.nn.Linear(50, 10)\n", 273 | " \n", 274 | " def forward(self, inp):\n", 275 | " x = self.slayer(inp)\n", 276 | " x = self.linear(x)\n", 277 | " return x " 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 11, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "torch.Size([3, 10])\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "model = ToyModel()\n", 295 | "inp = [torch.rand(10,2), torch.rand(20,2),
torch.rand(30,2)]\n", 296 | "out = model(inp)\n", 297 | "print(out.size())" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "For more information about alternative structure elements, e.g., \n", 305 | "\n", 306 | "- `torchph.nn.SLayerRational` \n", 307 | "- `torchph.nn.SLayerRationalHat`\n", 308 | "\n", 309 | "see the documentation." 310 | ] 311 | } 312 | ], 313 | "metadata": { 314 | "kernelspec": { 315 | "display_name": "Python 3", 316 | "language": "python", 317 | "name": "python3" 318 | }, 319 | "language_info": { 320 | "codemirror_mode": { 321 | "name": "ipython", 322 | "version": 3 323 | }, 324 | "file_extension": ".py", 325 | "mimetype": "text/x-python", 326 | "name": "python", 327 | "nbconvert_exporter": "python", 328 | "pygments_lexer": "ipython3", 329 | "version": "3.7.6" 330 | } 331 | }, 332 | "nbformat": 4, 333 | "nbformat_minor": 2 334 | } 335 | -------------------------------------------------------------------------------- /docs_src/source/tutorials/SLayer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Differentiable barcode vectorization\n", 8 | "\n", 9 | "This tutorial gives you a brief insight into the functionality offered by the `torchph.nn.SLayerExponential` \n", 10 | "module. It assumes familiarity with standard `PyTorch` functionality. \n", 11 | "\n", 12 | "Also, `torchph.nn.SLayerExponential` is just one *structure element*; others are available as well (see the documentation)."
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from shared_code import check_torchph_availability\n", 22 | "check_torchph_availability()" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from torchph.nn import SLayerExponential\n", 32 | "\n", 33 | "# create an instance with 3 structure elements over \\R^2\n", 34 | "sl = SLayerExponential(3, 2)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "`nn.SLayerExponential` is a `torch.nn.Module` ... " 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "data": { 51 | "text/plain": [ 52 | "True" 53 | ] 54 | }, 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "output_type": "execute_result" 58 | } 59 | ], 60 | "source": [ 61 | "import torch\n", 62 | "isinstance(sl, torch.nn.Module)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "... now we can do all the beautiful stuff which is inherited from `torch.nn.Module`, e.g.," 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "Parameter containing:\n", 82 | "tensor([[0.6355, 0.3604],\n", 83 | " [0.3162, 0.9167],\n", 84 | " [0.4922, 0.9822]], requires_grad=True)\n", 85 | "Parameter containing:\n", 86 | "tensor([[3., 3.],\n", 87 | " [3., 3.],\n", 88 | " [3., 3.]], requires_grad=True)\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "for p in sl.parameters():\n", 94 | " print(p)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "The module has **two** parameters: \n", 102 | "1. `centers` : controls the centers of the structure elements. 
\n", 103 | "2. `sharpness`: controls how tight the used Gaussians are. The higher the value, the tighter. \n", 104 | "\n", 105 | "Both can be initialized using the `centers_init` and `sharpness_init` keyword arguments, respectively." 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "Parameter containing:\n", 118 | "tensor([[0.0000, 0.0000],\n", 119 | " [0.5000, 0.5000],\n", 120 | " [1.0000, 1.0000]], requires_grad=True)\n", 121 | "Parameter containing:\n", 122 | "tensor([[1., 1.],\n", 123 | " [2., 2.],\n", 124 | " [3., 3.]], requires_grad=True)\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "# here is an initialization example\n", 130 | "centers_init = torch.Tensor(\n", 131 | " [\n", 132 | " [0,0], \n", 133 | " [0.5, 0.5], \n", 134 | " [1,1]\n", 135 | " ]\n", 136 | ")\n", 137 | "\n", 138 | "sharpness_init = torch.Tensor(\n", 139 | " [\n", 140 | " [1,1], \n", 141 | " [2,2], \n", 142 | " [3,3]\n", 143 | " ]\n", 144 | ")\n", 145 | "\n", 146 | "sl = SLayerExponential(3, 2, \n", 147 | " centers_init=centers_init, \n", 148 | " sharpness_init=sharpness_init)\n", 149 | "\n", 150 | "print(sl.centers)\n", 151 | "print(sl.sharpness)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "The simplest input form for `nn.SLayerExponential` is a `list` of `torch.Tensor` objects which are treated as a *batch*. 
" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 7, 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "torch.Size([4, 3])\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "# As an example, we create a batch of multisets\n", 176 | "mset_1 = [[0, 0]]\n", 177 | "mset_2 = [[0, 0], [0, 0]]\n", 178 | "mset_3 = [[1, 1], [0, 0]]\n", 179 | "mset_4 = [[0, 0], [1, 1]]\n", 180 | "batch = [mset_1, mset_2, mset_3, mset_4]\n", 181 | "batch = [torch.Tensor(x) for x in batch]\n", 182 | "output = sl(batch)\n", 183 | "print(output.size())" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "As we see the output dimensionality is `(4, 3)` since\n", 191 | "we have a batch of size `4` and `3` structure elements. \n", 192 | "\n", 193 | "In other words, \n", 194 | "`output[i, j] =` \"evaluation of structure element j on mset_i\"\n", 195 | "\n", 196 | "Lets take a look ... " 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 8, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "tensor([[1.0000e+00, 1.3534e-01, 1.5230e-08],\n", 209 | " [2.0000e+00, 2.7067e-01, 3.0460e-08],\n", 210 | " [1.1353e+00, 2.7067e-01, 1.0000e+00],\n", 211 | " [1.1353e+00, 2.7067e-01, 1.0000e+00]], grad_fn=)\n" 212 | ] 213 | } 214 | ], 215 | "source": [ 216 | "print(output)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "**We observe the following:**\n", 224 | "\n", 225 | "1. The j-th stucture element approximates the multiplicity function of the given input at point `sl.centers[j]`. E.g., the output of mset_1, `output[0, :]`, is approx. `(1, 0, 0)`. \n", 226 | "2. 
`sl.sharpness[j]` controls the amount of contribution of points not exactly on `sl.centers[j]` with respect to their distance to `sl.centers[j]`. \n", 227 | "3. The input is interpreted as a multiset, i.e., the output is permutation invariant: mset_3 and mset_4 do not differ as multisets, and accordingly `output[2,:] == output[3, :]`. \n", 228 | "\n", 229 | "This becomes clearer if we increase the sharpness of our structure elements a \"little\" ..." 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 9, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "tensor([[1., 0., 0.],\n", 242 | " [2., 0., 0.],\n", 243 | " [1., 0., 1.],\n", 244 | " [1., 0., 1.]], grad_fn=)\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "sl = SLayerExponential(3, 2, \n", 250 | " centers_init=centers_init, \n", 251 | " sharpness_init=10*sharpness_init)\n", 252 | "print(sl(batch))" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "Below is a small toy model to illustrate the application of `SLayerExponential`:" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 10, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "class ToyModel(torch.nn.Module):\n", 269 | " def __init__(self):\n", 270 | " super().__init__() \n", 271 | " self.slayer = SLayerExponential(50, 2)\n", 272 | " self.linear = torch.nn.Linear(50, 10)\n", 273 | " \n", 274 | " def forward(self, inp):\n", 275 | " x = self.slayer(inp)\n", 276 | " x = self.linear(x)\n", 277 | " return x " 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 11, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "torch.Size([3, 10])\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "model = ToyModel()\n", 295 | "inp = [torch.rand(10,2), torch.rand(20,2),
torch.rand(30,2)]\n", 296 | "out = model(inp)\n", 297 | "print(out.size())" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "For more information about alternative structure elements, e.g., \n", 305 | "\n", 306 | "- `torchph.nn.SLayerRational` \n", 307 | "- `torchph.nn.SLayerRationalHat`\n", 308 | "\n", 309 | "see the documentation." 310 | ] 311 | } 312 | ], 313 | "metadata": { 314 | "kernelspec": { 315 | "display_name": "Python 3", 316 | "language": "python", 317 | "name": "python3" 318 | }, 319 | "language_info": { 320 | "codemirror_mode": { 321 | "name": "ipython", 322 | "version": 3 323 | }, 324 | "file_extension": ".py", 325 | "mimetype": "text/x-python", 326 | "name": "python", 327 | "nbconvert_exporter": "python", 328 | "pygments_lexer": "ipython3", 329 | "version": "3.7.6" 330 | } 331 | }, 332 | "nbformat": 4, 333 | "nbformat_minor": 2 334 | } 335 | -------------------------------------------------------------------------------- /docs/_sources/tutorials/SLayer.ipynb.txt: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Differentiable barcode vectorization\n", 8 | "\n", 9 | "This tutorial gives you a brief insight into the functionality offered by the `torchph.nn.SLayerExponential` \n", 10 | "module. It assumes familiarity with standard `PyTorch` functionality. \n", 11 | "\n", 12 | "Also, `torchph.nn.SLayerExponential` is just one *structure element*; others are available as well (see the documentation)."
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from shared_code import check_torchph_availability\n", 22 | "check_torchph_availability()" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from torchph.nn import SLayerExponential\n", 32 | "\n", 33 | "# create an instance with 3 structure elements over \\R^2\n", 34 | "sl = SLayerExponential(3, 2)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "`nn.SLayerExponential` is a `torch.nn.Module` ... " 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "data": { 51 | "text/plain": [ 52 | "True" 53 | ] 54 | }, 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "output_type": "execute_result" 58 | } 59 | ], 60 | "source": [ 61 | "import torch\n", 62 | "isinstance(sl, torch.nn.Module)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "... now we can do all the beautiful stuff which is inherited from `torch.nn.Module`, e.g.," 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "Parameter containing:\n", 82 | "tensor([[0.6355, 0.3604],\n", 83 | " [0.3162, 0.9167],\n", 84 | " [0.4922, 0.9822]], requires_grad=True)\n", 85 | "Parameter containing:\n", 86 | "tensor([[3., 3.],\n", 87 | " [3., 3.],\n", 88 | " [3., 3.]], requires_grad=True)\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "for p in sl.parameters():\n", 94 | " print(p)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "The module has **two** parameters: \n", 102 | "1. `centers` : controls the centers of the structure elements. 
\n", 103 | "2. `sharpness`: controls how tight the used Gaussians are. The higher the value, the tighter. \n", 104 | "\n", 105 | "Both can be initialized using the `centers_init` and `sharpness_init` keyword arguments, respectively." 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "Parameter containing:\n", 118 | "tensor([[0.0000, 0.0000],\n", 119 | " [0.5000, 0.5000],\n", 120 | " [1.0000, 1.0000]], requires_grad=True)\n", 121 | "Parameter containing:\n", 122 | "tensor([[1., 1.],\n", 123 | " [2., 2.],\n", 124 | " [3., 3.]], requires_grad=True)\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "# here is an initialization example\n", 130 | "centers_init = torch.Tensor(\n", 131 | " [\n", 132 | " [0,0], \n", 133 | " [0.5, 0.5], \n", 134 | " [1,1]\n", 135 | " ]\n", 136 | ")\n", 137 | "\n", 138 | "sharpness_init = torch.Tensor(\n", 139 | " [\n", 140 | " [1,1], \n", 141 | " [2,2], \n", 142 | " [3,3]\n", 143 | " ]\n", 144 | ")\n", 145 | "\n", 146 | "sl = SLayerExponential(3, 2, \n", 147 | " centers_init=centers_init, \n", 148 | " sharpness_init=sharpness_init)\n", 149 | "\n", 150 | "print(sl.centers)\n", 151 | "print(sl.sharpness)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "The simplest input form for `nn.SLayerExponential` is a `list` of `torch.Tensor` objects which are treated as a *batch*. 
" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 7, 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "torch.Size([4, 3])\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "# As an example, we create a batch of multisets\n", 176 | "mset_1 = [[0, 0]]\n", 177 | "mset_2 = [[0, 0], [0, 0]]\n", 178 | "mset_3 = [[1, 1], [0, 0]]\n", 179 | "mset_4 = [[0, 0], [1, 1]]\n", 180 | "batch = [mset_1, mset_2, mset_3, mset_4]\n", 181 | "batch = [torch.Tensor(x) for x in batch]\n", 182 | "output = sl(batch)\n", 183 | "print(output.size())" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "As we see the output dimensionality is `(4, 3)` since\n", 191 | "we have a batch of size `4` and `3` structure elements. \n", 192 | "\n", 193 | "In other words, \n", 194 | "`output[i, j] =` \"evaluation of structure element j on mset_i\"\n", 195 | "\n", 196 | "Lets take a look ... " 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 8, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "tensor([[1.0000e+00, 1.3534e-01, 1.5230e-08],\n", 209 | " [2.0000e+00, 2.7067e-01, 3.0460e-08],\n", 210 | " [1.1353e+00, 2.7067e-01, 1.0000e+00],\n", 211 | " [1.1353e+00, 2.7067e-01, 1.0000e+00]], grad_fn=)\n" 212 | ] 213 | } 214 | ], 215 | "source": [ 216 | "print(output)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "**We observe the following:**\n", 224 | "\n", 225 | "1. The j-th stucture element approximates the multiplicity function of the given input at point `sl.centers[j]`. E.g., the output of mset_1, `output[0, :]`, is approx. `(1, 0, 0)`. \n", 226 | "2. 
`sl.sharpness[j]` controls the amount of contribution of points not exactly on `sl.centers[j]` with respect to their distance to `sl.centers[j]`. \n", 227 | "3. The input is interpreted as a multiset, i.e., the output is permutation invariant: mset_3 and mset_4 do not differ as multisets, and accordingly `output[2,:] == output[3, :]`. \n", 228 | "\n", 229 | "This becomes clearer if we increase the sharpness of our structure elements a \"little\" ..." 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 9, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "tensor([[1., 0., 0.],\n", 242 | " [2., 0., 0.],\n", 243 | " [1., 0., 1.],\n", 244 | " [1., 0., 1.]], grad_fn=)\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "sl = SLayerExponential(3, 2, \n", 250 | " centers_init=centers_init, \n", 251 | " sharpness_init=10*sharpness_init)\n", 252 | "print(sl(batch))" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "Below is a small toy model to illustrate the application of `SLayerExponential`:" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 10, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "class ToyModel(torch.nn.Module):\n", 269 | " def __init__(self):\n", 270 | " super().__init__() \n", 271 | " self.slayer = SLayerExponential(50, 2)\n", 272 | " self.linear = torch.nn.Linear(50, 10)\n", 273 | " \n", 274 | " def forward(self, inp):\n", 275 | " x = self.slayer(inp)\n", 276 | " x = self.linear(x)\n", 277 | " return x " 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 11, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "torch.Size([3, 10])\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "model = ToyModel()\n", 295 | "inp = [torch.rand(10,2), torch.rand(20,2),
torch.rand(30,2)]\n", 296 | "out = model(inp)\n", 297 | "print(out.size())" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "More information about alternative structure elements, i.e., \n", 305 | "\n", 306 | "- `torchph.nn.SLayerRational` \n", 307 | "- `torchph.nn.SLayerRationalHat`\n", 308 | "\n", 309 | "see documentation." 310 | ] 311 | } 312 | ], 313 | "metadata": { 314 | "kernelspec": { 315 | "display_name": "Python 3", 316 | "language": "python", 317 | "name": "python3" 318 | }, 319 | "language_info": { 320 | "codemirror_mode": { 321 | "name": "ipython", 322 | "version": 3 323 | }, 324 | "file_extension": ".py", 325 | "mimetype": "text/x-python", 326 | "name": "python", 327 | "nbconvert_exporter": "python", 328 | "pygments_lexer": "ipython3", 329 | "version": "3.7.6" 330 | } 331 | }, 332 | "nbformat": 4, 333 | "nbformat_minor": 2 334 | } 335 | -------------------------------------------------------------------------------- /docs/searchindex.js: -------------------------------------------------------------------------------- 1 | 
Search.setIndex({docnames:["index","install/index","nn","pershom","tutorials/ComparisonSOTA","tutorials/InputOptim","tutorials/SLayer","tutorials/ToyDiffVR"],envversion:{"sphinx.domains.c":1,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":1,"sphinx.domains.index":1,"sphinx.domains.javascript":1,"sphinx.domains.math":2,"sphinx.domains.python":1,"sphinx.domains.rst":1,"sphinx.domains.std":1,"sphinx.ext.intersphinx":1,"sphinx.ext.viewcode":1,nbsphinx:2,sphinx:56},filenames:["index.rst","install/index.rst","nn.rst","pershom.rst","tutorials/ComparisonSOTA.ipynb","tutorials/InputOptim.ipynb","tutorials/SLayer.ipynb","tutorials/ToyDiffVR.ipynb"],objects:{"torchph.nn":{slayer:[2,0,0,"-"]},"torchph.nn.slayer":{SLayerExponential:[2,1,1,""],SLayerRational:[2,1,1,""],SLayerRationalHat:[2,1,1,""],prepare_batch:[2,3,1,""]},"torchph.nn.slayer.SLayerExponential":{__init__:[2,2,1,""],forward:[2,2,1,""]},"torchph.nn.slayer.SLayerRational":{__init__:[2,2,1,""],forward:[2,2,1,""]},"torchph.nn.slayer.SLayerRationalHat":{__init__:[2,2,1,""],forward:[2,2,1,""]},"torchph.pershom":{pershom_backend:[3,0,0,"-"]},"torchph.pershom.pershom_backend":{calculate_persistence:[3,3,1,""],find_merge_pairings:[3,3,1,""],merge_columns_:[3,3,1,""],read_barcodes:[3,3,1,""],vr_persistence:[3,3,1,""],vr_persistence_l1:[3,3,1,""]}},objnames:{"0":["py","module","Python module"],"1":["py","class","Python class"],"2":["py","method","Python method"],"3":["py","function","Python 
function"]},objtypes:{"0":"py:module","1":"py:class","2":"py:method","3":"py:function"},terms:{"0000e":6,"0460e":6,"1353e":6,"3534e":6,"3rd":3,"5230e":6,"5th":3,"7067e":6,"br\u00fcel":5,"case":7,"class":[2,3,6,7],"default":3,"float":[0,2,3],"function":[2,3,6,7],"import":[0,1,2,4,5,6,7],"int":[2,3],"public":0,"return":[2,3,4,6,7],"super":[6,7],"true":[2,5,6,7],"while":2,For:[2,3,7],LTS:1,The:[0,1,2,3,5,6],__getitem__:7,__init__:[2,6,7],__iter__:7,__len__:7,about:6,abov:5,abs:[0,4,5,7],accompani:3,activ:[1,7],actual:7,adam:[5,7],addit:7,advanc:2,after:7,afterward:2,all:[0,2,3,4,5,6,7],alpha:[4,7],alreadi:4,also:6,altern:6,although:2,alwai:3,amount:6,anaconda3:1,anaconda:[1,4],append:[4,7],appli:7,applicat:6,apply_model:7,approx:6,approxim:6,arang:4,archiv:1,argument:6,arrai:[3,4,7],arrang:7,articl:0,arxiv:[0,5],assert:4,assum:[1,3,6],author:0,autoencod:7,autoreload:[5,7],avail:[0,1,4,6],avg:[4,7],avgep:7,axes:7,ba_row_i_to_bm_col_i:3,backend:[0,3],backprop:7,backward:[5,7],barcod:[0,2,3,7],base:[0,3],bash:1,basic:2,batch:[2,6,7],batch_siz:[2,7],bbox:7,bbox_inch:4,beauti:6,becom:6,below:6,best:1,beta:7,between:0,bin:1,birth:3,black:7,blue:4,booktitl:0,both:6,boundari:3,brief:6,calculate_persist:3,call:[2,3],can:[0,1,6,7],care:2,carlsson:5,cat:7,center:[2,6],centers_init:[2,6],chang:[1,7],check:[1,4,7],check_torchph_avail:6,chofer:4,choic:1,cityblock:[0,4],clear:6,click:2,clone:1,closer:7,cloud:[0,3,7],cohomology_persist:4,collect:[4,7],color:[4,7],column:3,com:1,combin:7,complex:3,compr_desc_sort_ba:3,compress:3,comput:[0,1,2,3,4,5,7],conda:1,configur:[1,5,7],connect:[0,4,7],consist:[2,3],construct:2,contain:[0,4,6],content:3,contigu:[0,4],contribut:6,control:6,conveni:7,corespond:3,correspond:3,counter:7,cours:1,covari:7,cpu:[1,4,5,7],creat:[0,6,7],cuda:[0,1,3,4,5,7],cudatoolkit:1,current:3,cycler:4,cython:4,d_l1:4,dagger:7,data:4,dataload:7,dataset:[4,7],dateutil:4,death:[3,4,7],decreasingli:3,deep:0,def:[4,6,7],defaultdict:[4,7],defer:6,defin:[2,3],densifi:0,deriv:7
,descend:3,descending_sorted_boundari:3,descript:3,desir:7,detach:5,detail:7,develop:1,devic:[0,4,5,7],dgm:4,dgm_ripser:4,diagon:3,dict:7,diff:0,differ:7,differenti:2,dim:[0,7],dimens:[2,3],dimension:[2,6],dionysu:4,directli:5,distanc:[0,3,4,6,7],distance_matrix:3,distribut:[0,5],dixit:[0,4,7],document:[0,6],doing:0,done:1,driver:1,drop_last:7,dummi:2,dure:5,dwaraknath:5,easili:0,edg:3,effect:7,effici:3,element:[2,6,7],emb:2,embed:3,encod:[3,7],end:7,entri:3,enumer:[4,7],env:4,environ:0,epoch:7,epoch_i:7,epoch_loss:7,epsilon_t:7,essenti:[3,5,7],eta:7,eval:7,evalu:[4,6],event:3,everi:2,everyth:[1,7],exactli:6,exampl:[0,2,3,6,7],execut:[3,4],exist:1,expect:3,experi:[5,7],expon:2,exponent_init:2,expos:[0,3],facecolor:7,factor:4,fals:2,familar:6,featur:3,fig:[5,7],figsiz:[4,7],figur:[4,5,7],fill_between:4,fill_rip:4,filt:4,filtrat:[0,3],find:[3,7],find_merge_pair:3,first:[1,3,7],float32_t:3,follow:[0,1,5,6],fontsiz:[4,7],form:[2,6,7],format:[4,5,7],former:2,forward:[2,6],four:2,freez:7,freeze_expon:2,from:[0,1,2,3,4,5,6,7],gabrielsson:5,gaussian:[4,6,7],gen_circlc:4,gen_random_10d_data:4,gener:[3,4,5],get:[1,4,7],git:1,github:1,give:6,given:[3,6,7],gpu:[0,1,2,4,7],grad:7,grad_fn:6,graf:0,graph:0,greater:3,green:4,grid:[4,7],guiba:5,h_0:5,has:[3,6],hat:[2,7],have:[1,3,6],here:[2,6],higher:6,highest:2,hofer17a:0,hofer19a:0,hofer19b:0,hofer20a:0,hofer20b:0,hofer:[0,1,4,7],homolog:[2,4,5],hook:2,hopcroftkarp:4,horizontalalign:7,how:[0,6,7],howev:7,http:[0,1],icml:[0,4,7],ignor:[2,3],illustr:6,implement:[2,4,7],imposs:7,increas:[5,6],index:[0,3],indic:2,induc:3,inform:6,inherit:6,init_diagram:4,initi:[2,6,7],inlin:4,inp:6,inplac:3,inproceed:0,input:[2,6],insight:6,instal:[0,4],instanc:[2,6],instead:2,instruct:0,int64_t:3,integ:3,interleav:7,interpret:6,interv:7,invari:6,isinst:[6,7],item:[0,4,5,7],iter:[3,5,7],iteration_loss:7,itertool:7,jmlr:0,joblib:4,journal:0,just:[3,6],keyword:6,kiwisolv:4,kwitt:[0,4,7],l_1:5,l_2:5,l_a:0,l_b:0,label:4,later:7,latter:2,layer:[2,5,7],leak
yrelu:7,learn:[0,4,5],learnabl:0,left:[5,7],legend:4,len:[4,7],let:[3,6,7],level:2,lib:4,lifetim:[5,7],linear:[6,7],linearsegmentedcolormap:7,linux:1,list:[2,3,4,6,7],littl:6,load_ext:[5,7],local:4,look:6,loss:[5,7],lower:3,machin:[0,5],make:1,make_circl:4,manhattan:4,manhatten:[0,3],mani:1,manner:7,map:7,markers:[5,7],mathbb:[4,5,7],matplotlib:[4,5,7],matric:0,matrix:[0,3,4],max:[2,4,7],max_ball_diamet:3,max_dim_to_read_of_reduced_ba:3,max_dimens:3,max_pair:3,max_pt:2,maxdim:4,mayb:6,mean:[3,4,7],mention:5,merg:3,merge_columns_:3,merge_pair:3,method:[0,2,4],metric:[0,4],min:7,mini:7,minim:[5,7],misc:7,mlp:7,model:6,modifi:5,modul:[0,2,3,4,6],moment:1,more:[6,7],most:[0,3],motiv:5,mset_1:6,mset_2:6,mset_3:6,mset_4:6,mset_i:6,multi:2,multipl:[0,6],multiset:[2,6],mus:7,n_element:2,n_max_point:2,n_sampl:4,n_samples_by_class:7,need:2,nelson:5,never:7,niethamm:[0,4,7],nip:0,no_grad:7,nois:4,non:[3,5],none:[2,3],norm:5,normal10d_runtim:4,not_dummi:2,notat:3,note:[0,3,5,7],notebook:4,now:6,num_work:7,number:[2,3],numpi:[0,4,5,7],nx1:3,object:6,observ:[6,7],obtain:7,obvious:1,occur:2,offer:6,omit:3,one:[2,6,7],onli:[1,3],onlin:[4,7],oper:3,opt:[4,5,7],optim:[0,4],order:[2,3],org:0,other:[1,6],our:[4,6,7],out:6,output:[3,6,7],over:[6,7],overridden:2,packag:[0,1,4],pad:[2,3],page:0,pair:3,pairwis:[0,4,7],paper:[5,7],param_group:7,paramet:[2,3,6,7],part:3,partial:3,particular:[0,5,7],pass:2,path:1,pdf:[4,5],pdist:[0,4],per:[5,7],perform:2,permut:6,pershom:[0,4,5,7],pershom_backend:3,persim:4,persist:[2,4,5],pip:4,pivot:3,plan:1,plot:[4,5,7],plt:[4,5,7],point:[0,2,3,5,6,7],point_cloud:3,point_dim:2,point_dimens:2,pointwise_activation_threshold:2,polici:2,posit:[2,7],possibl:3,post20200210:4,practic:3,pre:0,prepare_batch:2,print:[0,4,5,6,7],problem:0,process:2,propos:2,pts:7,purpos:3,pypars:4,pyplot:[4,5,7],pyt_1:4,python3:4,python:[1,3,4],pytorch:[1,6,7],quick:1,radiu:2,radius_init:2,rand:[2,5,6],randn:[0,4,7],random:[0,4,5],random_split:7,rang:[5,7],rate:7,ration:2,ravel:4,rea
ch:7,read:3,read_barcod:3,readm:3,real:3,recip:2,red:4,reduc:3,reduct:3,reflect:1,regist:2,relev:0,repo:1,repositori:1,represent:[0,4,7],reproduc:7,requir:[3,4],requires_grad:[5,6],respect:6,result:4,ret:3,rieck:0,right:[3,5,7],rip:[0,3,4,5],ripser:4,row:3,run:[2,4,7],runtim:4,same:[0,3,4],sampl:[0,4,5,7],saniti:4,satisfi:4,savefig:4,scale:4,schedul:7,scikit:4,scipi:[0,4],scratch4:4,search:0,second:1,section:3,see:[3,6,7],seed:5,self:[6,7],sequenti:7,set:[2,3,6,7],set_titl:7,setup:1,setuptool:4,share_expon:2,share_sharp:2,shared_cod:6,sharp:[2,6],sharpness_init:[2,6],should:2,show:0,shuffl:7,sigma:[4,7],signatur:0,silent:2,simpl:[0,4,7],simplest:6,simplex:3,simplex_dimens:3,sinc:[2,6],site:4,six:4,size:[2,3,4,6,7],sklearn:4,skraba:5,slayer:[2,6],slayerexponenti:[2,6],slayerr:[2,6],slayerrationalhat:[2,6],slightli:5,small:[4,6],sometim:3,sort:[3,4],sourc:[1,2,3],spatial:[0,4],specifi:1,squar:3,squareform:0,squeezebackward0:6,standard:6,stat:7,std:4,step:[5,7],structur:[2,6],stuctur:6,studi:[4,7],stuff:6,subclass:2,subplot:7,subset:7,sum:[0,4,5,7],sum_:7,support:[1,3],symmetri:3,sys:4,system:1,take:[2,6,7],task:5,teaser:0,tell:3,tensor:[0,2,3,4,5,6,7],tensordataset:7,term:4,test:1,than:[3,7],them:2,theta:7,thi:[0,1,2,3,4,5,6,7],third:1,thr_l1:4,three:7,thresh:4,threshold:4,throughout:7,tight:[4,6],tighter:[6,7],time:[3,4,7],titl:[0,4,5,7],tmp:[1,4],toi:[5,6],tolist:7,top:[3,5],topolog:[0,5],torch:[0,2,3,4,5,6,7],torchph:[0,1,2,3,4,5,6,7],torchvis:1,toy2ddata:7,toy_data:5,toymodel:6,track:7,track_persistence_info:7,train:7,transform:7,transformed_pt:7,treat:6,trial:4,tupl:[2,7],tutori:[2,6],two:[3,6],type:3,ubuntu:1,uhl:0,uncompress:3,underli:3,uniform:5,unit:4,updat:7,upper:3,url:0,usabl:0,use:[5,7],used:[0,2,3,6],uses:[0,5],using:[0,1,3,6],usual:3,util:7,v_0:3,v_2:3,v_3:3,v_4:3,v_5:3,v_6:3,valu:[3,6,7],varepsilon:7,varepsilon_t:7,vector:[0,2,3],version:[1,3,4,5],vertex:0,vertic:3,via:[0,4,7],vietori:[0,3,4,5],vr_persist:[0,3,4],vr_persistence_l1:[0,3,4,5,7],well:6,wg
et:1,where:[2,3],wherea:[2,5],which:[0,1,2,3,4,6,7],white:7,whose:3,why:7,widehat:7,within:[0,1,2],word:6,work:[0,1],would:3,x86_64:1,x_class:7,x_hat:7,xlabel:[4,7],year:0,yet:[0,1],ylabel:[4,7],you:[1,6],your:1,zero:[2,4,7],zero_grad:[5,7],zip:7},titles:["PyTorch extensions for persistent homology","Installation","nn","pershom","Comparison to SOTA","Input space optimization","Differentiable barcode vectorization","Differentiable Vietoris-Rips persistent homology"],titleterms:{"function":0,barcod:6,comparison:4,data:[5,7],differenti:[6,7],exampl:5,extens:0,get:0,homolog:[0,7],indic:0,input:5,instal:1,learn:7,limit:3,model:7,notat:7,optim:[5,7],pershom:3,persist:[0,7],pytorch:0,rip:7,sota:4,space:5,start:0,tabl:0,task:7,terminolog:3,toi:7,vector:6,vietori:7,visual:7}}) -------------------------------------------------------------------------------- /docs/install/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Installation — torchph 0.0.0 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 |
49 | 50 | 108 | 109 |
110 | 111 | 112 | 118 | 119 | 120 |
121 | 122 |
123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 |
141 | 142 |
    143 | 144 |
  • Docs »
  • 145 | 146 |
  • Installation
  • 147 | 148 | 149 |
  • 150 | 151 | 152 | View page source 153 | 154 | 155 |
  • 156 | 157 |
158 | 159 | 160 |
161 |
162 |
163 |
164 | 165 | 166 | 189 |
190 |

Installation

191 |

The following setup was tested with this system configuration:

192 |
    193 |
  • Ubuntu 18.04.2 LTS

  • 194 |
  • CUDA 10.1 (driver version 418.87.00)

  • 195 |
  • Anaconda (Python 3.7.6)

  • 196 |
  • PyTorch 1.4

  • 197 |
198 |

In the following, we assume that we work in /tmp (obviously, you have to 199 | change this to reflect your choice and using /tmp is, of course, not 200 | the best choice :).

201 |

First, get the Anaconda installer and install Anaconda (in /tmp/anaconda3) 202 | using

203 |
cd /tmp/
204 | wget https://repo.anaconda.com/archive/Anaconda3-2019.10-Linux-x86_64.sh
205 | bash Anaconda3-2019.10-Linux-x86_64.sh
206 | # specify /tmp/anaconda3 as your installation path
207 | source /tmp/anaconda3/bin/activate
208 | 
209 |
210 |

Second, we install PyTorch (v1.4) using

211 |
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
212 | 
213 |
214 |

Third, we clone the torchph repository from GitHub and make 215 | it available within Anaconda.

216 |
cd /tmp/
217 | git clone https://github.com/c-hofer/torchph.git
218 | conda develop /tmp/torchph
219 | 
220 |
221 |

A quick check whether everything works can be done with

222 |
>>> import torchph
223 | 
224 |
225 |
226 |

Note

227 |

At the moment, we only have GPU support available. CPU support 228 | is not planned yet, as many other packages exist which support 229 | PH computation on the CPU.

230 |
231 |
232 | 233 | 234 |
235 | 236 |
237 |
238 | 239 | 247 | 248 | 249 |
250 | 251 |
252 |

253 | © Copyright 2020, Christoph D. Hofer, Roland Kwitt 254 | 255 |

256 |
257 | Built with Sphinx using a theme provided by Read the Docs. 258 | 259 |
260 | 261 |
262 |
263 | 264 |
265 | 266 |
267 | 268 | 269 | 270 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | -------------------------------------------------------------------------------- /docs/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Sphinx JavaScript utilities for all documentation. 6 | * 7 | * :copyright: Copyright 2007-2020 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | /** 13 | * select a different prefix for underscore 14 | */ 15 | $u = _.noConflict(); 16 | 17 | /** 18 | * make the code below compatible with browsers without 19 | * an installed firebug like debugger 20 | if (!window.console || !console.firebug) { 21 | var names = ["log", "debug", "info", "warn", "error", "assert", "dir", 22 | "dirxml", "group", "groupEnd", "time", "timeEnd", "count", "trace", 23 | "profile", "profileEnd"]; 24 | window.console = {}; 25 | for (var i = 0; i < names.length; ++i) 26 | window.console[names[i]] = function() {}; 27 | } 28 | */ 29 | 30 | /** 31 | * small helper function to urldecode strings 32 | */ 33 | jQuery.urldecode = function(x) { 34 | return decodeURIComponent(x).replace(/\+/g, ' '); 35 | }; 36 | 37 | /** 38 | * small helper function to urlencode strings 39 | */ 40 | jQuery.urlencode = encodeURIComponent; 41 | 42 | /** 43 | * This function returns the parsed url parameters of the 44 | * current request. Multiple values per key are supported, 45 | * it will always return arrays of strings for the value parts. 
46 | */ 47 | jQuery.getQueryParameters = function(s) { 48 | if (typeof s === 'undefined') 49 | s = document.location.search; 50 | var parts = s.substr(s.indexOf('?') + 1).split('&'); 51 | var result = {}; 52 | for (var i = 0; i < parts.length; i++) { 53 | var tmp = parts[i].split('=', 2); 54 | var key = jQuery.urldecode(tmp[0]); 55 | var value = jQuery.urldecode(tmp[1]); 56 | if (key in result) 57 | result[key].push(value); 58 | else 59 | result[key] = [value]; 60 | } 61 | return result; 62 | }; 63 | 64 | /** 65 | * highlight a given string on a jquery object by wrapping it in 66 | * span elements with the given class name. 67 | */ 68 | jQuery.fn.highlightText = function(text, className) { 69 | function highlight(node, addItems) { 70 | if (node.nodeType === 3) { 71 | var val = node.nodeValue; 72 | var pos = val.toLowerCase().indexOf(text); 73 | if (pos >= 0 && 74 | !jQuery(node.parentNode).hasClass(className) && 75 | !jQuery(node.parentNode).hasClass("nohighlight")) { 76 | var span; 77 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); 78 | if (isInSVG) { 79 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 80 | } else { 81 | span = document.createElement("span"); 82 | span.className = className; 83 | } 84 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 85 | node.parentNode.insertBefore(span, node.parentNode.insertBefore( 86 | document.createTextNode(val.substr(pos + text.length)), 87 | node.nextSibling)); 88 | node.nodeValue = val.substr(0, pos); 89 | if (isInSVG) { 90 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); 91 | var bbox = node.parentElement.getBBox(); 92 | rect.x.baseVal.value = bbox.x; 93 | rect.y.baseVal.value = bbox.y; 94 | rect.width.baseVal.value = bbox.width; 95 | rect.height.baseVal.value = bbox.height; 96 | rect.setAttribute('class', className); 97 | addItems.push({ 98 | "parent": node.parentNode, 99 | "target": rect}); 100 | } 101 | } 102 | } 
103 | else if (!jQuery(node).is("button, select, textarea")) { 104 | jQuery.each(node.childNodes, function() { 105 | highlight(this, addItems); 106 | }); 107 | } 108 | } 109 | var addItems = []; 110 | var result = this.each(function() { 111 | highlight(this, addItems); 112 | }); 113 | for (var i = 0; i < addItems.length; ++i) { 114 | jQuery(addItems[i].parent).before(addItems[i].target); 115 | } 116 | return result; 117 | }; 118 | 119 | /* 120 | * backward compatibility for jQuery.browser 121 | * This will be supported until firefox bug is fixed. 122 | */ 123 | if (!jQuery.browser) { 124 | jQuery.uaMatch = function(ua) { 125 | ua = ua.toLowerCase(); 126 | 127 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || 128 | /(webkit)[ \/]([\w.]+)/.exec(ua) || 129 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || 130 | /(msie) ([\w.]+)/.exec(ua) || 131 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) || 132 | []; 133 | 134 | return { 135 | browser: match[ 1 ] || "", 136 | version: match[ 2 ] || "0" 137 | }; 138 | }; 139 | jQuery.browser = {}; 140 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; 141 | } 142 | 143 | /** 144 | * Small JavaScript module for the documentation. 145 | */ 146 | var Documentation = { 147 | 148 | init : function() { 149 | this.fixFirefoxAnchorBug(); 150 | this.highlightSearchWords(); 151 | this.initIndexTable(); 152 | if (DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) { 153 | this.initOnKeyListeners(); 154 | } 155 | }, 156 | 157 | /** 158 | * i18n support 159 | */ 160 | TRANSLATIONS : {}, 161 | PLURAL_EXPR : function(n) { return n === 1 ? 
0 : 1; }, 162 | LOCALE : 'unknown', 163 | 164 | // gettext and ngettext don't access this so that the functions 165 | // can safely bound to a different name (_ = Documentation.gettext) 166 | gettext : function(string) { 167 | var translated = Documentation.TRANSLATIONS[string]; 168 | if (typeof translated === 'undefined') 169 | return string; 170 | return (typeof translated === 'string') ? translated : translated[0]; 171 | }, 172 | 173 | ngettext : function(singular, plural, n) { 174 | var translated = Documentation.TRANSLATIONS[singular]; 175 | if (typeof translated === 'undefined') 176 | return (n == 1) ? singular : plural; 177 | return translated[Documentation.PLURALEXPR(n)]; 178 | }, 179 | 180 | addTranslations : function(catalog) { 181 | for (var key in catalog.messages) 182 | this.TRANSLATIONS[key] = catalog.messages[key]; 183 | this.PLURAL_EXPR = new Function('n', 'return +(' + catalog.plural_expr + ')'); 184 | this.LOCALE = catalog.locale; 185 | }, 186 | 187 | /** 188 | * add context elements like header anchor links 189 | */ 190 | addContextElements : function() { 191 | $('div[id] > :header:first').each(function() { 192 | $('\u00B6'). 193 | attr('href', '#' + this.id). 194 | attr('title', _('Permalink to this headline')). 195 | appendTo(this); 196 | }); 197 | $('dt[id]').each(function() { 198 | $('\u00B6'). 199 | attr('href', '#' + this.id). 200 | attr('title', _('Permalink to this definition')). 
201 | appendTo(this); 202 | }); 203 | }, 204 | 205 | /** 206 | * workaround a firefox stupidity 207 | * see: https://bugzilla.mozilla.org/show_bug.cgi?id=645075 208 | */ 209 | fixFirefoxAnchorBug : function() { 210 | if (document.location.hash && $.browser.mozilla) 211 | window.setTimeout(function() { 212 | document.location.href += ''; 213 | }, 10); 214 | }, 215 | 216 | /** 217 | * highlight the search words provided in the url in the text 218 | */ 219 | highlightSearchWords : function() { 220 | var params = $.getQueryParameters(); 221 | var terms = (params.highlight) ? params.highlight[0].split(/\s+/) : []; 222 | if (terms.length) { 223 | var body = $('div.body'); 224 | if (!body.length) { 225 | body = $('body'); 226 | } 227 | window.setTimeout(function() { 228 | $.each(terms, function() { 229 | body.highlightText(this.toLowerCase(), 'highlighted'); 230 | }); 231 | }, 10); 232 | $('') 234 | .appendTo($('#searchbox')); 235 | } 236 | }, 237 | 238 | /** 239 | * init the domain index toggle buttons 240 | */ 241 | initIndexTable : function() { 242 | var togglers = $('img.toggler').click(function() { 243 | var src = $(this).attr('src'); 244 | var idnum = $(this).attr('id').substr(7); 245 | $('tr.cg-' + idnum).toggle(); 246 | if (src.substr(-9) === 'minus.png') 247 | $(this).attr('src', src.substr(0, src.length-9) + 'plus.png'); 248 | else 249 | $(this).attr('src', src.substr(0, src.length-8) + 'minus.png'); 250 | }).css('display', ''); 251 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) { 252 | togglers.click(); 253 | } 254 | }, 255 | 256 | /** 257 | * helper function to hide the search marks again 258 | */ 259 | hideSearchWords : function() { 260 | $('#searchbox .highlight-link').fadeOut(300); 261 | $('span.highlighted').removeClass('highlighted'); 262 | }, 263 | 264 | /** 265 | * make the url absolute 266 | */ 267 | makeURL : function(relativeURL) { 268 | return DOCUMENTATION_OPTIONS.URL_ROOT + '/' + relativeURL; 269 | }, 270 | 271 | /** 272 | * get the current relative 
url 273 | */ 274 | getCurrentURL : function() { 275 | var path = document.location.pathname; 276 | var parts = path.split(/\//); 277 | $.each(DOCUMENTATION_OPTIONS.URL_ROOT.split(/\//), function() { 278 | if (this === '..') 279 | parts.pop(); 280 | }); 281 | var url = parts.join('/'); 282 | return path.substring(url.lastIndexOf('/') + 1, path.length - 1); 283 | }, 284 | 285 | initOnKeyListeners: function() { 286 | $(document).keydown(function(event) { 287 | var activeElementType = document.activeElement.tagName; 288 | // don't navigate when in search box or textarea 289 | if (activeElementType !== 'TEXTAREA' && activeElementType !== 'INPUT' && activeElementType !== 'SELECT' 290 | && !event.altKey && !event.ctrlKey && !event.metaKey && !event.shiftKey) { 291 | switch (event.keyCode) { 292 | case 37: // left 293 | var prevHref = $('link[rel="prev"]').prop('href'); 294 | if (prevHref) { 295 | window.location.href = prevHref; 296 | return false; 297 | } 298 | case 39: // right 299 | var nextHref = $('link[rel="next"]').prop('href'); 300 | if (nextHref) { 301 | window.location.href = nextHref; 302 | return false; 303 | } 304 | } 305 | } 306 | }); 307 | } 308 | }; 309 | 310 | // quick alias for translations 311 | _ = Documentation.gettext; 312 | 313 | $(document).ready(function() { 314 | Documentation.init(); 315 | }); 316 | -------------------------------------------------------------------------------- /torchph/pershom/pershom_cpp_src/vertex_filtration_comp_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | // #include 3 | 4 | // #include 5 | // #include 6 | // #include 7 | // #include 8 | // #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include "vertex_filtration_comp_cuda.h" 15 | #include "vr_comp_cuda.cuh" 16 | #include "calc_pers_cuda.cuh" 17 | 18 | 19 | using namespace torch; 20 | 21 | 22 | namespace VertFiltCompCuda { 23 | std::vector vert_filt_comp_calculate_persistence_args( 24 | const 
Tensor & vertex_filtration, 25 | const std::vector & boundary_info) 26 | { 27 | // initialize helper variables... 28 | auto ret = std::vector(); 29 | 30 | auto tensopt_real = torch::TensorOptions() 31 | .dtype(vertex_filtration.dtype()) 32 | .device(vertex_filtration.device()); 33 | 34 | auto tensopt_int = torch::TensorOptions() 35 | .dtype(torch::kInt64) 36 | .device(vertex_filtration.device()); 37 | 38 | Tensor num_simplices_by_dim; 39 | { 40 | std::vector tmp = {vertex_filtration.size(0)}; 41 | 42 | for (auto const& t: boundary_info){ 43 | tmp.push_back(t.size(0)); 44 | } 45 | 46 | num_simplices_by_dim = torch::tensor(tmp); 47 | } 48 | auto num_simplices_up_to_dim = num_simplices_by_dim.cumsum(0); 49 | auto max_dim = num_simplices_by_dim.size(0)-1; 50 | 51 | // generate simplex dimensionality vector... 52 | auto simplex_dim = torch::zeros({num_simplices_up_to_dim[-1].item()}, tensopt_int); 53 | for (int i = 0; i < max_dim; i++){ 54 | simplex_dim.slice( 55 | 0, 56 | num_simplices_up_to_dim[i].item(), 57 | num_simplices_up_to_dim[i+1].item() 58 | ).fill_(i+1); 59 | 60 | } 61 | 62 | // generate filtration values vector 63 | Tensor filt_val_vec; 64 | { 65 | auto tmp = std::vector(); 66 | tmp.push_back(vertex_filtration); 67 | for (auto const& bi: boundary_info){ 68 | auto v = tmp.back(); 69 | v = v.unsqueeze(0).expand({bi.size(0), -1}); 70 | v = v.gather(1, bi); 71 | 72 | v = std::get<0>(v.max(1)); 73 | 74 | tmp.push_back(v); 75 | } 76 | 77 | filt_val_vec = torch::cat(tmp, 0); 78 | } 79 | 80 | // adapt simplex ids from dimension-wise ids to global ids 81 | std::vector boundary_info_global_id; 82 | { 83 | auto tmp = std::vector(); 84 | tmp.push_back(boundary_info.at(0)); 85 | 86 | for (int i = 1; i < boundary_info.size(); i++){ 87 | tmp.push_back( 88 | boundary_info.at(i) + num_simplices_up_to_dim[i-1] 89 | ); 90 | } 91 | 92 | boundary_info_global_id = tmp; 93 | } 94 | 95 | // do sorting with epsilon hack ... 
96 | Tensor sorted_filt_val_vec, perm_inv, perm; 97 | { 98 | auto hack_add = simplex_dim.to(filt_val_vec.dtype()) * 100 * std::numeric_limits::epsilon(); 99 | auto tmp = filt_val_vec + hack_add; 100 | auto sort_res = tmp.sort(0); 101 | sorted_filt_val_vec = std::get<0>(sort_res); 102 | perm = std::get<1>(sort_res); 103 | 104 | sort_res = perm.sort(); 105 | perm_inv = std::get<1>(sort_res); 106 | 107 | sorted_filt_val_vec = sorted_filt_val_vec - hack_add.index_select(0, perm); 108 | } 109 | 110 | // transfer simplex ids to filtration-based ids ... 111 | { 112 | auto tmp = std::vector(); 113 | for (auto const& bi: boundary_info_global_id){ 114 | auto perm_inv_expanded = perm_inv.unsqueeze(0).expand({bi.size(0), -1}); 115 | auto bi_new = perm_inv_expanded.gather(1, bi); 116 | bi_new = std::get<0>(bi_new.sort(1, true)); 117 | tmp.push_back(bi_new); 118 | } 119 | 120 | boundary_info_global_id = tmp; 121 | } 122 | 123 | // create boundary array ... 124 | auto ba = torch::empty( 125 | {num_simplices_up_to_dim[-1].item(), 126 | (max_dim+1)*2}, 127 | tensopt_int 128 | ); 129 | ba.fill_(-1); 130 | 131 | for (int i = 0; i < boundary_info_global_id.size(); i++){ 132 | ba.slice( 133 | 0, 134 | num_simplices_up_to_dim[i].item(), 135 | num_simplices_up_to_dim[i+1].item() 136 | ).slice( 137 | 1, 0, i+2 138 | ) = boundary_info_global_id.at(i); 139 | } 140 | // final sorting ... 
141 | 142 | ba = ba.index_select(0, perm); 143 | simplex_dim = simplex_dim.index_select(0, perm); 144 | 145 | // compressing 146 | auto i = simplex_dim.gt(0).nonzero().squeeze(); 147 | auto ba_row_i_to_bm_col_i = torch::arange(0, ba.size(0), tensopt_int); 148 | 149 | ba_row_i_to_bm_col_i = ba_row_i_to_bm_col_i.index_select(0, i); 150 | ba = ba.index_select(0, i); 151 | 152 | ret.push_back(ba); 153 | ret.push_back(ba_row_i_to_bm_col_i); 154 | ret.push_back(simplex_dim); 155 | ret.push_back(sorted_filt_val_vec); 156 | 157 | return ret; 158 | } 159 | 160 | 161 | // //TODO compare to the 'calculate_persistence_output_to_barcode_tensors' 162 | // // function in vr_comp_cuda.cu and refactor 163 | // Tensor read_non_essential_barcode( 164 | // const Tensor & barcode, 165 | // const Tensor & sorted_filtration_values) 166 | // { 167 | // Tensor ret; 168 | // if (barcode.size(0) == 0) 169 | // { 170 | // ret = torch::empty({0, 2}, sorted_filtration_values.options()); 171 | // } 172 | // else 173 | // { 174 | // auto v = sorted_filtration_values 175 | // .unsqueeze(0) 176 | // .expand({barcode.size(0), -1}); 177 | // ret = v.gather(1, barcode); 178 | 179 | // auto i = ret.slice(1, 0, 1).ne(ret.slice(1, 1, 2)); 180 | // i = i.nonzero().squeeze(); 181 | 182 | // if (i.size(0) == 0){ 183 | // ret = torch::empty({0, 2}, sorted_filtration_values.options()); 184 | // } 185 | // else 186 | // { 187 | // ret = ret.index_select(0, i); 188 | // } 189 | // } 190 | 191 | // return ret; 192 | // } 193 | 194 | 195 | // Tensor read_essential_barcode( 196 | // const Tensor & barcode, 197 | // const Tensor & sorted_filtration_values) 198 | // { 199 | // Tensor ret; 200 | // if (barcode.size(0) == 0) 201 | // { 202 | // ret = torch::empty({0, 1}, sorted_filtration_values.options()); 203 | // } 204 | // else 205 | // { 206 | // auto v = sorted_filtration_values 207 | // .unsqueeze(0) 208 | // .expand({barcode.size(0), -1}); 209 | // ret = v.gather(1, barcode); 210 | // } 211 | 212 | // return 
ret; 213 | // } 214 | 215 | 216 | // std::vector> read_barcode_from_birth_death_times( 217 | // const std::vector>& calculate_persistence_output, 218 | // const Tensor & sorted_filtration_values) 219 | // { 220 | // auto ret_non_ess = std::vector(); 221 | // auto ret_ess = std::vector(); 222 | // auto ret = std::vector>(); 223 | 224 | 225 | // auto non_ess_barcodes = calculate_persistence_output.at(0); 226 | // auto ess_barcodes = calculate_persistence_output.at(1); 227 | 228 | // for (auto const& b: non_ess_barcodes){ 229 | // ret_non_ess.push_back( 230 | // read_non_essential_barcode( 231 | // b, 232 | // sorted_filtration_values 233 | // ) 234 | // ); 235 | // } 236 | 237 | // for (auto const& b: ess_barcodes){ 238 | // ret_ess.push_back( 239 | // read_essential_barcode( 240 | // b, 241 | // sorted_filtration_values 242 | // ) 243 | // ); 244 | // } 245 | 246 | // ret.push_back(ret_non_ess); 247 | // ret.push_back(ret_ess); 248 | 249 | // return ret; 250 | // } 251 | 252 | 253 | std::vector> vert_filt_persistence_single( 254 | const Tensor & vertex_filtration, 255 | const std::vector & boundary_info) 256 | { 257 | auto r = vert_filt_comp_calculate_persistence_args( 258 | vertex_filtration, 259 | boundary_info 260 | ); 261 | 262 | auto ba = r.at(0); 263 | auto ba_row_i_to_bm_col_i = r.at(1); 264 | auto simplex_dim = r.at(2); 265 | auto sorted_filtration_values = r.at(3); 266 | 267 | auto calc_pers_output = CalcPersCuda::calculate_persistence( 268 | ba, 269 | ba_row_i_to_bm_col_i, 270 | simplex_dim, 271 | boundary_info.size(), 272 | -1 273 | ); 274 | // VRCompCuda::calculate_persistence_output_to_barcode_tensors 275 | // read_barcode_from_birth_death_times 276 | return VRCompCuda::calculate_persistence_output_to_barcode_tensors( 277 | calc_pers_output, 278 | sorted_filtration_values 279 | ); 280 | } 281 | 282 | 283 | std::vector>> vert_filt_persistence_batch( 284 | const std::vector>> & batch 285 | ) 286 | { 287 | auto futures = std::vector>>>(); 288 | for (auto & 
arg: batch){ 289 | 290 | futures.push_back( 291 | std::async( 292 | std::launch::async, 293 | [=]{ 294 | return vert_filt_persistence_single( 295 | std::get<0>(arg), 296 | std::get<1>(arg) 297 | ); 298 | } 299 | ) 300 | ); 301 | } 302 | 303 | auto ret = std::vector>>(); 304 | for (auto & fut: futures){ 305 | ret.push_back( 306 | fut.get() 307 | ); 308 | } 309 | 310 | return ret; 311 | } 312 | } 313 | --------------------------------------------------------------------------------