├── tests ├── __init__.py ├── test_common.py ├── test_lie_group.py ├── test_manifold.py ├── test_pca.py ├── test_backend_tensorflow.py ├── test_template.py ├── test_connection.py ├── test_visualization.py ├── helper.py ├── test_examples.py ├── test_matrices_space.py ├── test_backend_numpy.py ├── test_general_linear_group.py ├── test_stiefel.py ├── test_minkowski_space.py └── test_spd_matrices_space.py ├── geomstats ├── backend │ ├── common.py │ ├── tensorflow_testing.py │ ├── numpy_testing.py │ ├── pytorch_testing.py │ ├── numpy_random.py │ ├── pytorch_random.py │ ├── tensorflow_random.py │ ├── __init__.py │ ├── pytorch_linalg.py │ ├── numpy_linalg.py │ ├── tensorflow_linalg.py │ ├── tensorflow.py │ ├── numpy.py │ └── pytorch.py ├── learning │ ├── __init__.py │ ├── _template.py │ ├── mean_shift.py │ └── pca.py ├── __init__.py ├── __about__.py ├── embedded_manifold.py ├── manifold.py ├── tests.py ├── matrices_space.py ├── euclidean_space.py ├── general_linear_group.py ├── minkowski_space.py ├── lie_group.py ├── spd_matrices_space.py └── connection.py ├── .gitignore ├── docs ├── troubleshooting.rst ├── requirements.doc.txt ├── changelog.rst ├── install.rst ├── index.rst ├── Makefile ├── conf.py ├── tutorials.rst └── api-reference.rst ├── examples ├── imgs │ ├── h2_grid.png │ ├── gradient_descent.gif │ └── gradient_descent.png ├── plot_geodesics_so3.py ├── plot_geodesics_se3.py ├── plot_geodesics_s2.py ├── plot_quantization_s1.py ├── plot_quantization_s2.py ├── plot_square_h2_poincare_disk.py ├── plot_square_h2_poincare_half_plane.py ├── plot_grid_h2.py ├── plot_square_h2_klein_disk.py ├── tangent_pca_so3.py ├── plot_geodesics_h2.py ├── loss_and_gradient_so3.py ├── gradient_descent_s2.py └── loss_and_gradient_se3.py ├── requirements.txt ├── .travis.yml ├── setup.py ├── LICENSE.md ├── README.md └── CONTRIBUTING.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /geomstats/backend/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | pi = np.pi 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | *.mp4 4 | *egg-info 5 | *.egg 6 | *.pyc 7 | *.png 8 | -------------------------------------------------------------------------------- /docs/troubleshooting.rst: -------------------------------------------------------------------------------- 1 | Troubleshooting 2 | =============== 3 | 4 | Trouble shooting. 5 | -------------------------------------------------------------------------------- /examples/imgs/h2_grid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/geomstats/master/examples/imgs/h2_grid.png -------------------------------------------------------------------------------- /examples/imgs/gradient_descent.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/geomstats/master/examples/imgs/gradient_descent.gif -------------------------------------------------------------------------------- /examples/imgs/gradient_descent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/geomstats/master/examples/imgs/gradient_descent.png -------------------------------------------------------------------------------- /docs/requirements.doc.txt: -------------------------------------------------------------------------------- 1 | numpydoc==0.8 2 | sphinx 3 | dask_sphinx_theme>=1.1.0 4 | sphinx-click 5 | toolz 6 | cloudpickle 7 | pandas>=0.19.0 8 | distributed 9 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autograd 2 | codecov 3 | coverage 4 | h5py==2.8.0 5 | matplotlib 6 | nose2 7 | numpy>=1.14.1 8 | scikit-learn 9 | scipy 10 | tensorflow>=1.12 11 | torch==0.4.0 12 | -------------------------------------------------------------------------------- /geomstats/backend/tensorflow_testing.py: -------------------------------------------------------------------------------- 1 | """Testing backend.""" 2 | 3 | import numpy as np 4 | 5 | 6 | def assert_allclose(*args, **kwargs): 7 | return np.testing.assert_allclose(*args, **kwargs) 8 | -------------------------------------------------------------------------------- /geomstats/backend/numpy_testing.py: -------------------------------------------------------------------------------- 1 | """Numpy based testing backend.""" 2 | 3 | import numpy as np 4 | 5 | 6 | def assert_allclose(*args, **kwargs): 7 | return np.testing.assert_allclose(*args, **kwargs) 8 | -------------------------------------------------------------------------------- /geomstats/backend/pytorch_testing.py: -------------------------------------------------------------------------------- 1 | """Pytorch based testing backend.""" 2 | 3 | import torch 4 | 5 | 6 | def assert_allclose(*args, **kwargs): 7 | return torch.testing.assert_allclose(*args, **kwargs) 8 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | **0.1.7** 5 | 6 | * Bugfixes 7 | * Default arguments on the API 8 | * Better command line argument parsing 9 | 10 | **0.1.6** 11 | 12 | * Cleaner API. 
13 | -------------------------------------------------------------------------------- /geomstats/learning/__init__.py: -------------------------------------------------------------------------------- 1 | from ._template import TemplateEstimator 2 | from ._template import TemplateClassifier 3 | from ._template import TemplateTransformer 4 | 5 | __all__ = ['TemplateEstimator', 'TemplateClassifier', 'TemplateTransformer', 6 | '__version__'] 7 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | Install Geomstats 2 | ================= 3 | 4 | 5 | You can install geomstats with ``pip3`` as follows:: 6 | 7 | pip3 install geomstats 8 | 9 | You should choose your backend by setting the environment variable GEOMSTATS_BACKEND to numpy, tensorflow or pytorch:: 10 | 11 | export GEOMSTATS_BACKEND=numpy 12 | -------------------------------------------------------------------------------- /geomstats/backend/numpy_random.py: -------------------------------------------------------------------------------- 1 | """Numpy based random backend.""" 2 | 3 | import numpy as np 4 | 5 | 6 | def rand(*args, **kwargs): 7 | return np.random.rand(*args, **kwargs) 8 | 9 | 10 | def randint(*args, **kwargs): 11 | return np.random.randint(*args, **kwargs) 12 | 13 | 14 | def seed(*args, **kwargs): 15 | return np.random.seed(*args, **kwargs) 16 | 17 | 18 | def normal(*args, **kwargs): 19 | return np.random.normal(*args, **kwargs) 20 | -------------------------------------------------------------------------------- /geomstats/__init__.py: -------------------------------------------------------------------------------- 1 | from .__about__ import __version__ 2 | 3 | import geomstats.manifold 4 | import geomstats.euclidean_space 5 | import geomstats.hyperbolic_space 6 | import geomstats.hypersphere 7 | import geomstats.invariant_metric 8 | import geomstats.lie_group 9 | 
import geomstats.minkowski_space 10 | import geomstats.spd_matrices_space 11 | import geomstats.special_euclidean_group 12 | import geomstats.special_orthogonal_group 13 | import geomstats.riemannian_metric 14 | -------------------------------------------------------------------------------- /geomstats/backend/pytorch_random.py: -------------------------------------------------------------------------------- 1 | """Torch based random backend.""" 2 | 3 | import torch 4 | 5 | 6 | def rand(*args, **kwargs): 7 | return torch.rand(*args, **kwargs) 8 | 9 | 10 | def randint(*args, **kwargs): 11 | return torch.randint(*args, **kwargs) 12 | 13 | 14 | def seed(*args, **kwargs): 15 | return torch.manual_seed(*args, **kwargs) 16 | 17 | 18 | def normal(loc=0.0, scale=1.0, size=(1, 1)): 19 | return torch.normal(torch.zeros(size), torch.ones(size)) 20 | -------------------------------------------------------------------------------- /geomstats/__about__.py: -------------------------------------------------------------------------------- 1 | # Remove -dev before releasing 2 | __version__ = '1.15' 3 | 4 | from itertools import chain 5 | 6 | install_requires = [ 7 | 'autograd', 8 | 'h5py==2.8.0', 9 | 'matplotlib', 10 | 'numpy>=1.14.1', 11 | 'scipy', 12 | ] 13 | 14 | extras_require = { 15 | 'test': ['codecov', 'coverage', 'nose2'], 16 | 'tf': ['tensorflow>=1.12'], 17 | 'torch': ['torch==0.4.0'], 18 | } 19 | extras_require['all'] = list(chain(*extras_require.values())) 20 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from sklearn.utils.estimator_checks import check_estimator 4 | 5 | from geomstats.learning._template import TemplateEstimator 6 | from geomstats.learning._template import TemplateClassifier 7 | from geomstats.learning._template import TemplateTransformer 8 | 9 | 10 | @pytest.mark.parametrize( 11 | 
"Estimator", [TemplateEstimator, TemplateTransformer, TemplateClassifier] 12 | ) 13 | def test_all_estimators(Estimator): 14 | return check_estimator(Estimator) 15 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Geomstats 2 | ========= 3 | 4 | *Geomstats provides code for computations and statistics on manifolds with geometric structures.* 5 | 6 | **Quick install** 7 | 8 | .. code-block:: bash 9 | 10 | pip3 install geomstats 11 | export GEOMSTATS_BACKEND=numpy 12 | 13 | .. toctree:: 14 | :maxdepth: 1 15 | :caption: Getting Started 16 | 17 | install.rst 18 | api-reference.rst 19 | contributing.rst 20 | tutorials.rst 21 | changelog.rst 22 | 23 | .. toctree:: 24 | :maxdepth: 1 25 | :caption: Support 26 | 27 | troubleshooting.rst 28 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | dist: trusty 3 | cache: pip 4 | language: python 5 | python: 6 | - "3.5" 7 | - "3.5-dev" 8 | - "3.6" 9 | - "3.6-dev" 10 | 11 | install: 12 | - pip install --upgrade pip setuptools wheel 13 | - pip install -q -r requirements.txt --only-binary=numpy,scipy 14 | script: 15 | - nose2 --with-coverage --verbose 16 | env: 17 | - GEOMSTATS_BACKEND=numpy 18 | - GEOMSTATS_BACKEND=pytorch 19 | - GEOMSTATS_BACKEND=tensorflow 20 | 21 | after_success: 22 | - bash <(curl -s https://codecov.io/bash) -c -F $GEOMSTATS_BACKEND 23 | -------------------------------------------------------------------------------- /tests/test_lie_group.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for Lie groups. 
3 | """ 4 | 5 | import geomstats.tests 6 | 7 | from geomstats.lie_group import LieGroup 8 | 9 | 10 | class TestLieGroupMethods(geomstats.tests.TestCase): 11 | _multiprocess_can_split_ = True 12 | 13 | dimension = 4 14 | group = LieGroup(dimension=dimension) 15 | 16 | def test_dimension(self): 17 | result = self.group.dimension 18 | expected = self.dimension 19 | 20 | self.assertAllClose(result, expected) 21 | 22 | 23 | if __name__ == '__main__': 24 | geomstats.tests.main() 25 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import runpy 3 | from setuptools import setup, find_packages 4 | 5 | base_dir = os.path.dirname(os.path.abspath(__file__)) 6 | about = runpy.run_path(os.path.join(base_dir, 'geomstats', '__about__.py')) 7 | 8 | setup(name='geomstats', 9 | version=about['__version__'], 10 | install_requires=about['install_requires'], 11 | extras_require=about['extras_require'], 12 | description='Geometric statistics on manifolds', 13 | url='http://github.com/geomstats/geomstats', 14 | author='Nina Miolane', 15 | author_email='ninamio78@gmail.com', 16 | license='MIT', 17 | packages=find_packages(), 18 | zip_safe=False) 19 | -------------------------------------------------------------------------------- /geomstats/backend/tensorflow_random.py: -------------------------------------------------------------------------------- 1 | """Tensorflow based random backend.""" 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def randint(low, high=None, size=None): 7 | if size is None: 8 | size = (1,) 9 | maxval = high 10 | minval = low 11 | if high is None: 12 | maxval = low - 1 13 | minval = 0 14 | return tf.random_uniform( 15 | shape=size, 16 | minval=minval, 17 | maxval=maxval, dtype=tf.int32, seed=None, name=None) 18 | 19 | 20 | def rand(*args): 21 | return tf.random_uniform(shape=args) 22 | 23 | 24 | def seed(*args): 25 | return tf.set_random_seed(*args) 26 | 27 | 28 | def normal(loc=0.0, scale=1.0, size=(1, 1)): 29 | return tf.random_normal(mean=loc, stddev=scale, shape=size) 30 | -------------------------------------------------------------------------------- /tests/test_manifold.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for manifolds. 
3 | """ 4 | 5 | import geomstats.backend as gs 6 | import geomstats.tests 7 | 8 | from geomstats.manifold import Manifold 9 | 10 | 11 | class TestManifoldMethods(geomstats.tests.TestCase): 12 | _multiprocess_can_split_ = True 13 | 14 | def setUp(self): 15 | self.dimension = 4 16 | self.manifold = Manifold(self.dimension) 17 | 18 | def test_dimension(self): 19 | result = self.manifold.dimension 20 | expected = self.dimension 21 | self.assertAllClose(result, expected) 22 | 23 | def test_belongs(self): 24 | point = gs.array([1., 2., 3.]) 25 | self.assertRaises(NotImplementedError, 26 | lambda: self.manifold.belongs(point)) 27 | 28 | def test_regularize(self): 29 | point = gs.array([1., 2., 3.]) 30 | result = self.manifold.regularize(point) 31 | expected = point 32 | self.assertAllClose(result, expected) 33 | 34 | 35 | if __name__ == '__main__': 36 | geomstats.test.main() 37 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Nina Miolane 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /geomstats/backend/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | _default_backend = 'numpy' 5 | if 'GEOMSTATS_BACKEND' in os.environ: 6 | _backend = os.environ['GEOMSTATS_BACKEND'] 7 | 8 | else: 9 | _backend = _default_backend 10 | 11 | _BACKEND = _backend 12 | 13 | from .common import * # NOQA 14 | 15 | if _BACKEND == 'numpy': 16 | sys.stderr.write('Using numpy backend\n') 17 | from .numpy import * # NOQA 18 | from . import numpy_linalg as linalg 19 | from . import numpy_random as random 20 | from . import numpy_testing as testing 21 | elif _BACKEND == 'pytorch': 22 | sys.stderr.write('Using pytorch backend\n') 23 | from .pytorch import * # NOQA 24 | from . import pytorch_linalg as linalg # NOQA 25 | from . import pytorch_random as random # NOQA 26 | from . import pytorch_testing as testing # NOQA 27 | elif _BACKEND == 'tensorflow': 28 | sys.stderr.write('Using tensorflow backend\n') 29 | from .tensorflow import * # NOQA 30 | from . import tensorflow_linalg as linalg # NOQA 31 | from . import tensorflow_random as random # NOQA 32 | from . import tensorflow_testing as testing # NOQA 33 | 34 | 35 | def backend(): 36 | return _BACKEND 37 | -------------------------------------------------------------------------------- /examples/plot_geodesics_so3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot a geodesic of SO(3) equipped 3 | with its left-invariant canonical METRIC. 
4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import os 9 | 10 | import geomstats.visualization as visualization 11 | 12 | from geomstats.special_orthogonal_group import SpecialOrthogonalGroup 13 | 14 | SO3_GROUP = SpecialOrthogonalGroup(n=3) 15 | METRIC = SO3_GROUP.bi_invariant_metric 16 | 17 | 18 | def main(): 19 | initial_point = SO3_GROUP.identity 20 | initial_tangent_vec = [0.5, 0.5, 0.8] 21 | geodesic = METRIC.geodesic(initial_point=initial_point, 22 | initial_tangent_vec=initial_tangent_vec) 23 | 24 | n_steps = 10 25 | t = np.linspace(0, 1, n_steps) 26 | 27 | points = geodesic(t) 28 | visualization.plot(points, space='SO3_GROUP') 29 | plt.show() 30 | 31 | 32 | if __name__ == "__main__": 33 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 34 | print('Examples with visualizations are only implemented ' 35 | 'with numpy backend.\n' 36 | 'To change backend, write: ' 37 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 38 | else: 39 | main() 40 | -------------------------------------------------------------------------------- /examples/plot_geodesics_se3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot a geodesic of SE(3) equipped 3 | with its left-invariant canonical METRIC. 4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import os 9 | 10 | import geomstats.visualization as visualization 11 | 12 | from geomstats.special_euclidean_group import SpecialEuclideanGroup 13 | 14 | SE3_GROUP = SpecialEuclideanGroup(n=3) 15 | METRIC = SE3_GROUP.left_canonical_metric 16 | 17 | 18 | def main(): 19 | initial_point = SE3_GROUP.identity 20 | initial_tangent_vec = [1.8, 0.2, 0.3, 3., 3., 1.] 
21 | geodesic = METRIC.geodesic(initial_point=initial_point, 22 | initial_tangent_vec=initial_tangent_vec) 23 | 24 | n_steps = 40 25 | t = np.linspace(-3, 3, n_steps) 26 | 27 | points = geodesic(t) 28 | 29 | visualization.plot(points, space='SE3_GROUP') 30 | plt.show() 31 | 32 | 33 | if __name__ == "__main__": 34 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 35 | print('Examples with visualizations are only implemented ' 36 | 'with numpy backend.\n' 37 | 'To change backend, write: ' 38 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 39 | else: 40 | main() 41 | -------------------------------------------------------------------------------- /tests/test_pca.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from sklearn.utils.testing import assert_allclose 5 | 6 | from geomstats.special_orthogonal_group import SpecialOrthogonalGroup 7 | 8 | from geomstats.learning.pca import TangentPCA 9 | 10 | 11 | SO3_GROUP = SpecialOrthogonalGroup(n=3) 12 | METRIC = SO3_GROUP.bi_invariant_metric 13 | N_SAMPLES = 10 14 | N_COMPONENTS = 2 15 | 16 | 17 | @pytest.fixture 18 | def data(): 19 | data = SO3_GROUP.random_uniform(n_samples=N_SAMPLES) 20 | return data 21 | 22 | 23 | def test_tangent_pca_error(data): 24 | X = data 25 | trans = TangentPCA(n_components=N_COMPONENTS) 26 | trans.fit(X) 27 | with pytest.raises(ValueError, match="Shape of input is different"): 28 | X_diff_size = np.ones((10, X.shape[1] + 1)) 29 | trans.transform(X_diff_size) 30 | 31 | 32 | def test_tangent_pca(data): 33 | X = data 34 | trans = TangentPCA(n_components=N_COMPONENTS) 35 | assert trans.demo_param == 'demo' 36 | 37 | trans.fit(X) 38 | assert trans.n_features_ == X.shape[1] 39 | 40 | X_trans = trans.transform(X) 41 | assert_allclose(X_trans, np.sqrt(X)) 42 | 43 | X_trans = trans.fit_transform(X) 44 | assert_allclose(X_trans, np.sqrt(X)) 45 | 
-------------------------------------------------------------------------------- /geomstats/embedded_manifold.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manifold embedded in another manifold. 3 | """ 4 | 5 | import math 6 | 7 | from geomstats.manifold import Manifold 8 | 9 | 10 | class EmbeddedManifold(Manifold): 11 | """ 12 | Class for manifolds embedded in another manifold. 13 | """ 14 | 15 | def __init__(self, dimension, embedding_manifold): 16 | assert isinstance(dimension, int) or dimension == math.inf 17 | assert dimension > 0 18 | super(EmbeddedManifold, self).__init__( 19 | dimension=dimension) 20 | self.embedding_manifold = embedding_manifold 21 | 22 | def intrinsic_to_extrinsic_coords(self, point_intrinsic): 23 | raise NotImplementedError( 24 | 'intrinsic_to_extrinsic_coords is not implemented.') 25 | 26 | def extrinsic_to_intrinsic_coords(self, point_extrinsic): 27 | raise NotImplementedError( 28 | 'extrinsic_to_intrinsic_coords is not implemented.') 29 | 30 | def projection(self, point): 31 | raise NotImplementedError( 32 | 'projection is not implemented.') 33 | 34 | def projection_to_tangent_space(self, vector, base_point): 35 | raise NotImplementedError( 36 | 'projection_to_tangent_space is not implemented.') 37 | -------------------------------------------------------------------------------- /examples/plot_geodesics_s2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot a geodesic on the sphere S2 3 | """ 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import os 8 | 9 | import geomstats.visualization as visualization 10 | 11 | from geomstats.hypersphere import Hypersphere 12 | 13 | SPHERE2 = Hypersphere(dimension=2) 14 | METRIC = SPHERE2.metric 15 | 16 | 17 | def main(): 18 | initial_point = [1., 0., 0.] 
19 | initial_tangent_vec = SPHERE2.projection_to_tangent_space( 20 | vector=[1., 2., 0.8], 21 | base_point=initial_point) 22 | geodesic = METRIC.geodesic(initial_point=initial_point, 23 | initial_tangent_vec=initial_tangent_vec) 24 | 25 | n_steps = 10 26 | t = np.linspace(0, 1, n_steps) 27 | 28 | points = geodesic(t) 29 | visualization.plot(points, space='S2') 30 | plt.show() 31 | 32 | 33 | if __name__ == "__main__": 34 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 35 | print('Examples with visualizations are only implemented ' 36 | 'with numpy backend.\n' 37 | 'To change backend, write: ' 38 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 39 | else: 40 | main() 41 | -------------------------------------------------------------------------------- /geomstats/backend/pytorch_linalg.py: -------------------------------------------------------------------------------- 1 | """Pytorch based linear algebra backend.""" 2 | 3 | import numpy as np 4 | import scipy.linalg 5 | import torch 6 | 7 | 8 | def expm(x): 9 | np_expm = np.vectorize( 10 | scipy.linalg.expm, signature='(n,m)->(n,m)')(x) 11 | return torch.from_numpy(np_expm) 12 | 13 | 14 | def inv(*args, **kwargs): 15 | return torch.from_numpy(np.linalg.inv(*args, **kwargs)) 16 | 17 | 18 | def eigvalsh(*args, **kwargs): 19 | return torch.from_numpy(np.linalg.eigvalsh(*args, **kwargs)) 20 | 21 | 22 | def eigh(*args, **kwargs): 23 | eigs = np.linalg.eigh(*args, **kwargs) 24 | return torch.from_numpy(eigs[0]), torch.from_numpy(eigs[1]) 25 | 26 | 27 | def svd(*args, **kwargs): 28 | svds = np.linalg.svd(*args, **kwargs) 29 | return (torch.from_numpy(svds[0]), 30 | torch.from_numpy(svds[1]), 31 | torch.from_numpy(svds[2])) 32 | 33 | 34 | def det(*args, **kwargs): 35 | return torch.from_numpy(np.linalg.det(*args, **kwargs)) 36 | 37 | 38 | def norm(x, ord=2, axis=None, keepdims=False): 39 | if axis is None: 40 | return torch.norm(x, p=ord) 41 | return torch.norm(x, p=ord, dim=axis) 42 | 43 | 44 | def qr(*args, **kwargs): 45 | 
return torch.from_numpy(np.linalg.qr(*args, **kwargs)) 46 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | project = 'Geomstats' 2 | copyright = '2019, Geomstats, Inc.' 3 | author = 'Geomstats Team' 4 | 5 | version = '0.1' 6 | # The full version, including alpha/beta/rc tags 7 | release = '0.1' 8 | 9 | 10 | extensions = [ 11 | 'sphinx.ext.autodoc', 12 | 'sphinx.ext.doctest', 13 | 'sphinx.ext.coverage', 14 | 'sphinx.ext.mathjax', 15 | 'sphinx.ext.viewcode', 16 | 'sphinx.ext.githubpages', 17 | ] 18 | 19 | templates_path = ['templates'] 20 | 21 | source_suffix = '.rst' 22 | 23 | master_doc = 'index' 24 | 25 | language = None 26 | 27 | exclude_patterns = ['build', 'Thumbs.db', '.DS_Store'] 28 | 29 | pygments_style = None 30 | 31 | html_theme = 'sphinx_rtd_theme' 32 | 33 | html_static_path = ['static'] 34 | 35 | htmlhelp_basename = 'geomstatsdoc' 36 | 37 | latex_elements = { 38 | } 39 | 40 | 41 | latex_documents = [ 42 | (master_doc, 'geomstats.tex', 'geomstats Documentation', 43 | 'Geomstats Team', 'manual'), 44 | ] 45 | 46 | man_pages = [ 47 | (master_doc, 'geomstats', 'geomstats Documentation', 48 | [author], 1) 49 | ] 50 | 51 | texinfo_documents = [ 52 | (master_doc, 'geomstats', 'geomstats Documentation', 53 | author, 'geomstats', 'One line description of project.', 54 | 'Miscellaneous'), 55 | ] 56 | 57 | epub_title = project 58 | epub_exclude_files = ['search.html'] 59 | -------------------------------------------------------------------------------- /geomstats/manifold.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manifold, i.e. a topological space that locally resembles 3 | Euclidean space near each point. 4 | """ 5 | 6 | import math 7 | 8 | 9 | class Manifold(object): 10 | """ 11 | Class for manifolds. 
12 | """ 13 | 14 | def __init__(self, dimension): 15 | 16 | assert isinstance(dimension, int) or dimension == math.inf 17 | assert dimension > 0 18 | 19 | self.dimension = dimension 20 | 21 | def belongs(self, point, point_type=None): 22 | """ 23 | Evaluate if a point belongs to the manifold. 24 | 25 | Parameters 26 | ---------- 27 | point : array-like, shape=[n_samples, dimension] 28 | Input points. 29 | 30 | Returns 31 | ------- 32 | belongs : array-like, shape=[n_samples, 1] 33 | """ 34 | raise NotImplementedError('belongs is not implemented.') 35 | 36 | def regularize(self, point, point_type=None): 37 | """ 38 | Regularize a point to the canonical representation 39 | chosen for the manifold. 40 | 41 | Parameters 42 | ---------- 43 | point : array-like, shape=[n_samples, dimension] 44 | Input points. 45 | 46 | Returns 47 | ------- 48 | regularized_point : array-like, shape=[n_samples, dimension] 49 | """ 50 | regularized_point = point 51 | return regularized_point 52 | -------------------------------------------------------------------------------- /geomstats/backend/numpy_linalg.py: -------------------------------------------------------------------------------- 1 | """Numpy based linear algebra backend.""" 2 | 3 | import numpy as np 4 | import scipy.linalg 5 | 6 | 7 | def expm(x): 8 | return np.vectorize( 9 | scipy.linalg.expm, signature='(n,m)->(n,m)')(x) 10 | 11 | 12 | def logm(x): 13 | return np.vectorize( 14 | scipy.linalg.logm, signature='(n,m)->(n,m)')(x) 15 | 16 | 17 | def sqrtm(x): 18 | return np.vectorize( 19 | scipy.linalg.sqrtm, signature='(n,m)->(n,m)')(x) 20 | 21 | 22 | def det(*args, **kwargs): 23 | return np.linalg.det(*args, **kwargs) 24 | 25 | 26 | def norm(*args, **kwargs): 27 | return np.linalg.norm(*args, **kwargs) 28 | 29 | 30 | def inv(*args, **kwargs): 31 | return np.linalg.inv(*args, **kwargs) 32 | 33 | 34 | def matrix_rank(*args, **kwargs): 35 | return np.linalg.matrix_rank(*args, **kwargs) 36 | 37 | 38 | def eigvalsh(*args,
**kwargs): 39 | return np.linalg.eigvalsh(*args, **kwargs) 40 | 41 | 42 | def svd(*args, **kwargs): 43 | return np.linalg.svd(*args, **kwargs) 44 | 45 | 46 | def eigh(*args, **kwargs): 47 | return np.linalg.eigh(*args, **kwargs) 48 | 49 | 50 | def eig(*args, **kwargs): 51 | return np.linalg.eig(*args, **kwargs) 52 | 53 | 54 | def exp(*args, **kwargs): 55 | return np.exp(*args, **kwargs) 56 | 57 | 58 | def qr(*args, **kwargs): 59 | return np.vectorize( 60 | np.linalg.qr, 61 | signature='(n,m)->(n,k),(k,m)', 62 | excluded=['mode'])(*args, **kwargs) 63 | -------------------------------------------------------------------------------- /examples/plot_quantization_s1.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot the result of optimal quantization of the uniform distribution 3 | on the circle. 4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import os 8 | 9 | import geomstats.visualization as visualization 10 | 11 | from geomstats.hypersphere import Hypersphere 12 | 13 | CIRCLE = Hypersphere(dimension=1) 14 | METRIC = CIRCLE.metric 15 | N_POINTS = 1000 16 | N_CENTERS = 5 17 | N_REPETITIONS = 20 18 | TOLERANCE = 1e-6 19 | 20 | 21 | def main(): 22 | points = CIRCLE.random_uniform(n_samples=N_POINTS, bound=None) 23 | 24 | centers, weights, clusters, n_iterations = METRIC.optimal_quantization( 25 | points=points, n_centers=N_CENTERS, 26 | n_repetitions=N_REPETITIONS, tolerance=TOLERANCE 27 | ) 28 | 29 | plt.figure(0) 30 | visualization.plot(points=centers, space='S1', color='red') 31 | plt.show() 32 | 33 | plt.figure(1) 34 | ax = plt.axes() 35 | circle = visualization.Circle() 36 | circle.draw(ax=ax) 37 | for i in range(N_CENTERS): 38 | circle.draw_points(ax=ax, points=clusters[i]) 39 | plt.show() 40 | 41 | 42 | if __name__ == "__main__": 43 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 44 | print('Examples with visualizations are only implemented ' 45 | 'with numpy backend.\n' 46 | 'To change backend, write: 
' 47 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 48 | else: 49 | main() 50 | -------------------------------------------------------------------------------- /tests/test_backend_tensorflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for tensorflow backend. 3 | """ 4 | 5 | import importlib 6 | import os 7 | import tensorflow as tf 8 | 9 | import geomstats.backend as gs 10 | 11 | 12 | class TestBackendTensorFlow(tf.test.TestCase): 13 | _multiprocess_can_split_ = True 14 | 15 | @classmethod 16 | def setUpClass(cls): 17 | cls.initial_backend = os.environ['GEOMSTATS_BACKEND'] 18 | os.environ['GEOMSTATS_BACKEND'] = 'tensorflow' 19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 20 | importlib.reload(gs) 21 | 22 | @classmethod 23 | def tearDownClass(cls): 24 | os.environ['GEOMSTATS_BACKEND'] = cls.initial_backend 25 | importlib.reload(gs) 26 | 27 | def test_vstack(self): 28 | with self.test_session(): 29 | tensor_1 = tf.convert_to_tensor([[1., 2., 3.], [4., 5., 6.]]) 30 | tensor_2 = tf.convert_to_tensor([[7., 8., 9.]]) 31 | 32 | result = gs.vstack([tensor_1, tensor_2]) 33 | expected = tf.convert_to_tensor([ 34 | [1., 2., 3.], 35 | [4., 5., 6.], 36 | [7., 8., 9.]]) 37 | self.assertAllClose(result, expected) 38 | 39 | def test_tensor_addition(self): 40 | with self.test_session(): 41 | tensor_1 = gs.ones((1, 1)) 42 | tensor_2 = gs.ones((0, 1)) 43 | 44 | result = tensor_1 + tensor_2 45 | 46 | 47 | if __name__ == '__main__': 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /examples/plot_quantization_s2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot the result of optimal quantization of the von Mises Fisher distribution 3 | on the sphere 4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import os 8 | 9 | import geomstats.visualization as visualization 10 | 11 | from geomstats.hypersphere import Hypersphere 
12 | 13 | SPHERE2 = Hypersphere(dimension=2) 14 | METRIC = SPHERE2.metric 15 | N_POINTS = 1000 16 | N_CENTERS = 4 17 | N_REPETITIONS = 20 18 | KAPPA = 10 19 | 20 | 21 | def main(): 22 | points = SPHERE2.random_von_mises_fisher(kappa=KAPPA, n_samples=N_POINTS) 23 | 24 | centers, weights, clusters, n_steps = METRIC.optimal_quantization( 25 | points=points, n_centers=N_CENTERS, 26 | n_repetitions=N_REPETITIONS 27 | ) 28 | 29 | plt.figure(0) 30 | ax = plt.subplot(111, projection="3d") 31 | visualization.plot(points=centers, ax=ax, space='S2', c='r') 32 | plt.show() 33 | 34 | plt.figure(1) 35 | ax = plt.subplot(111, projection="3d") 36 | sphere = visualization.Sphere() 37 | sphere.draw(ax=ax) 38 | for i in range(N_CENTERS): 39 | sphere.draw_points(ax=ax, points=clusters[i]) 40 | plt.show() 41 | 42 | 43 | if __name__ == "__main__": 44 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 45 | print('Examples with visualizations are only implemented ' 46 | 'with numpy backend.\n' 47 | 'To change backend, write: ' 48 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 49 | else: 50 | main() 51 | -------------------------------------------------------------------------------- /docs/tutorials.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | **Choosing the backend.** 5 | 6 | You need to set the environment variable GEOMSTATS_BACKEND to numpy, tensorflow or pytorch. Only use the numpy backend for examples with visualizations. 7 | 8 | .. code-block:: bash 9 | 10 | export GEOMSTATS_BACKEND=numpy 11 | 12 | **A first python example.** 13 | 14 | This example shows how to compute a geodesic on the Lie group SE(3), which is the group of rotations and translations in 3D. 15 | 16 | .. code-block:: python 17 | 18 | """ 19 | Plot a geodesic of SE(3) equipped 20 | with its left-invariant canonical metric. 
21 | """ 22 | 23 | import matplotlib.pyplot as plt 24 | import numpy as np 25 | import os 26 | 27 | import geomstats.visualization as visualization 28 | 29 | from geomstats.special_euclidean_group import SpecialEuclideanGroup 30 | 31 | SE3_GROUP = SpecialEuclideanGroup(n=3) 32 | METRIC = SE3_GROUP.left_canonical_metric 33 | 34 | initial_point = SE3_GROUP.identity 35 | initial_tangent_vec = [1.8, 0.2, 0.3, 3., 3., 1.] 36 | geodesic = METRIC.geodesic(initial_point=initial_point, 37 | initial_tangent_vec=initial_tangent_vec) 38 | 39 | n_steps = 40 40 | t = np.linspace(-3, 3, n_steps) 41 | 42 | points = geodesic(t) 43 | 44 | visualization.plot(points, space='SE3_GROUP') 45 | plt.show() 46 | 47 | **More examples.** 48 | 49 | You can find more examples in the repository "examples" of geomstats. You can run them from the command line as follows. 50 | 51 | .. code-block:: bash 52 | 53 | python3 examples/plot_grid_h2.py 54 | -------------------------------------------------------------------------------- /examples/plot_square_h2_poincare_disk.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot a square on H2 with Poincare Disk visualization. 
3 | """ 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import os 8 | 9 | import geomstats.visualization as visualization 10 | 11 | from geomstats.hyperbolic_space import HyperbolicSpace 12 | 13 | H2 = HyperbolicSpace(dimension=2) 14 | METRIC = H2.metric 15 | 16 | SQUARE_SIZE = 50 17 | 18 | 19 | def main(): 20 | top = SQUARE_SIZE / 2.0 21 | bot = - SQUARE_SIZE / 2.0 22 | left = - SQUARE_SIZE / 2.0 23 | right = SQUARE_SIZE / 2.0 24 | corners_int = [(bot, left), (bot, right), (top, right), (top, left)] 25 | corners_ext = H2.intrinsic_to_extrinsic_coords(corners_int) 26 | n_steps = 20 27 | ax = plt.gca() 28 | for i, src in enumerate(corners_ext): 29 | dst_id = (i+1) % len(corners_ext) 30 | dst = corners_ext[dst_id] 31 | tangent_vec = METRIC.log(point=dst, base_point=src) 32 | geodesic = METRIC.geodesic(initial_point=src, 33 | initial_tangent_vec=tangent_vec) 34 | t = np.linspace(0, 1, n_steps) 35 | edge_points = geodesic(t) 36 | 37 | visualization.plot( 38 | edge_points, 39 | ax=ax, 40 | space='H2_poincare_disk', 41 | marker='.', 42 | color='black') 43 | 44 | plt.show() 45 | 46 | 47 | if __name__ == "__main__": 48 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 49 | print('Examples with visualizations are only implemented ' 50 | 'with numpy backend.\n' 51 | 'To change backend, write: ' 52 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 53 | else: 54 | main() 55 | -------------------------------------------------------------------------------- /examples/plot_square_h2_poincare_half_plane.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot a square on H2 with Poincare Disk visualization. 
3 | """ 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import os 8 | 9 | import geomstats.visualization as visualization 10 | 11 | from geomstats.hyperbolic_space import HyperbolicSpace 12 | 13 | H2 = HyperbolicSpace(dimension=2) 14 | METRIC = H2.metric 15 | 16 | SQUARE_SIZE = 50 17 | 18 | 19 | def main(): 20 | top = SQUARE_SIZE / 2.0 21 | bot = - SQUARE_SIZE / 2.0 22 | left = - SQUARE_SIZE / 2.0 23 | right = SQUARE_SIZE / 2.0 24 | corners_int = [(bot, left), (bot, right), (top, right), (top, left)] 25 | corners_ext = H2.intrinsic_to_extrinsic_coords(corners_int) 26 | n_steps = 20 27 | ax = plt.gca() 28 | for i, src in enumerate(corners_ext): 29 | dst_id = (i+1) % len(corners_ext) 30 | dst = corners_ext[dst_id] 31 | tangent_vec = METRIC.log(point=dst, base_point=src) 32 | geodesic = METRIC.geodesic(initial_point=src, 33 | initial_tangent_vec=tangent_vec) 34 | t = np.linspace(0, 1, n_steps) 35 | edge_points = geodesic(t) 36 | 37 | visualization.plot( 38 | edge_points, 39 | ax=ax, 40 | space='H2_poincare_half_plane', 41 | marker='.', 42 | color='black') 43 | 44 | plt.show() 45 | 46 | 47 | if __name__ == "__main__": 48 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 49 | print('Examples with visualizations are only implemented ' 50 | 'with numpy backend.\n' 51 | 'To change backend, write: ' 52 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 53 | else: 54 | main() 55 | -------------------------------------------------------------------------------- /examples/plot_grid_h2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot a grid on H2 3 | with Poincare Disk visualization. 
4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import os 9 | 10 | import geomstats.visualization as visualization 11 | 12 | from geomstats.hyperbolic_space import HyperbolicSpace 13 | 14 | H2 = HyperbolicSpace(dimension=2) 15 | METRIC = H2.metric 16 | 17 | 18 | def main(left=-128, 19 | right=128, 20 | bottom=-128, 21 | top=128, 22 | grid_size=32, 23 | n_steps=512): 24 | starts = [] 25 | ends = [] 26 | for p in np.linspace(left, right, grid_size): 27 | starts.append(np.array([top, p])) 28 | ends.append(np.array([bottom, p])) 29 | for p in np.linspace(top, bottom, grid_size): 30 | starts.append(np.array([p, left])) 31 | ends.append(np.array([p, right])) 32 | starts = [H2.intrinsic_to_extrinsic_coords(s) for s in starts] 33 | ends = [H2.intrinsic_to_extrinsic_coords(e) for e in ends] 34 | ax = plt.gca() 35 | for start, end in zip(starts, ends): 36 | geodesic = METRIC.geodesic(initial_point=start, 37 | end_point=end) 38 | 39 | t = np.linspace(0, 1, n_steps) 40 | points_to_plot = geodesic(t) 41 | visualization.plot( 42 | points_to_plot, ax=ax, space='H2_poincare_disk', marker='.', s=1) 43 | plt.show() 44 | 45 | 46 | if __name__ == "__main__": 47 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 48 | print('Examples with visualizations are only implemented ' 49 | 'with numpy backend.\n' 50 | 'To change backend, write: ' 51 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 52 | else: 53 | main() 54 | -------------------------------------------------------------------------------- /examples/plot_square_h2_klein_disk.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot a square on H2 with Poincare Disk visualization. 
3 | """ 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import os 8 | 9 | import geomstats.visualization as visualization 10 | 11 | from geomstats.hyperbolic_space import HyperbolicSpace 12 | 13 | H2 = HyperbolicSpace(dimension=2) 14 | METRIC = H2.metric 15 | 16 | SQUARE_SIZE = 50 17 | 18 | 19 | def main(): 20 | top = SQUARE_SIZE / 2.0 21 | bot = - SQUARE_SIZE / 2.0 22 | left = - SQUARE_SIZE / 2.0 23 | right = SQUARE_SIZE / 2.0 24 | corners_int = [(bot, left), (bot, right), (top, right), (top, left)] 25 | corners_ext = H2.intrinsic_to_extrinsic_coords(corners_int) 26 | n_steps = 20 27 | ax = plt.gca() 28 | for i, src in enumerate(corners_ext): 29 | dst_id = (i+1) % len(corners_ext) 30 | dst = corners_ext[dst_id] 31 | tangent_vec = METRIC.log(point=dst, base_point=src) 32 | geodesic = METRIC.geodesic(initial_point=src, 33 | initial_tangent_vec=tangent_vec) 34 | t = np.linspace(0, 1, n_steps) 35 | edge_points = geodesic(t) 36 | 37 | visualization.plot(edge_points, 38 | ax=ax, 39 | space='H2_klein_disk', 40 | marker='.', 41 | color='black') 42 | plt.show() 43 | 44 | 45 | if __name__ == "__main__": 46 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 47 | print('Examples with visualizations are only implemented ' 48 | 'with numpy backend.\n' 49 | 'To change backend, write: ' 50 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 51 | else: 52 | main() 53 | -------------------------------------------------------------------------------- /examples/tangent_pca_so3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute the mean of a data set of 3D rotations. 3 | Performs tangent PCA at the mean. 
4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | 9 | import geomstats.visualization as visualization 10 | 11 | from geomstats.learning.pca import TangentPCA 12 | from geomstats.special_orthogonal_group import SpecialOrthogonalGroup 13 | 14 | SO3_GROUP = SpecialOrthogonalGroup(n=3) 15 | METRIC = SO3_GROUP.bi_invariant_metric 16 | 17 | N_SAMPLES = 10 18 | N_COMPONENTS = 2 19 | 20 | 21 | def main(): 22 | fig = plt.figure(figsize=(15, 5)) 23 | 24 | data = SO3_GROUP.random_uniform(n_samples=N_SAMPLES) 25 | mean = METRIC.mean(data) 26 | 27 | tpca = TangentPCA(metric=METRIC, n_components=N_COMPONENTS) 28 | tpca = tpca.fit(data, base_point=mean) 29 | tangent_projected_data = tpca.transform(data) 30 | print( 31 | 'Coordinates of the Log of the first 5 data points at the mean, ' 32 | 'projected on the principal components:') 33 | print(tangent_projected_data[:5]) 34 | 35 | ax_var = fig.add_subplot(121) 36 | xticks = np.arange(1, N_COMPONENTS+1, 1) 37 | ax_var.xaxis.set_ticks(xticks) 38 | ax_var.set_title('Explained variance') 39 | ax_var.set_xlabel('Number of Principal Components') 40 | ax_var.set_ylim((0, 1)) 41 | ax_var.plot(xticks, tpca.explained_variance_ratio_) 42 | 43 | ax = fig.add_subplot(122, projection="3d") 44 | plt.setp(ax, xlabel="X", ylabel="Y", zlabel="Z") 45 | 46 | ax.set_title('Data in SO3 (black) and Frechet mean (color)') 47 | visualization.plot(data, ax, space='SO3_GROUP', color='black') 48 | visualization.plot(mean, ax, space='SO3_GROUP', linewidth=3) 49 | ax.set_xlim((-2, 2)) 50 | ax.set_ylim((-2, 2)) 51 | ax.set_zlim((-2, 2)) 52 | plt.show() 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /geomstats/backend/tensorflow_linalg.py: -------------------------------------------------------------------------------- 1 | """Tensorflow based linear algebra backend.""" 2 | 3 | import tensorflow as tf 4 | 5 | from geomstats.backend.tensorflow 
import to_ndarray 6 | 7 | 8 | def sqrtm(sym_mat): 9 | sym_mat = to_ndarray(sym_mat, to_ndim=3) 10 | 11 | [eigenvalues, vectors] = tf.linalg.eigh(sym_mat) 12 | 13 | sqrt_eigenvalues = tf.sqrt(eigenvalues) 14 | 15 | aux = tf.einsum('ijk,ik->ijk', vectors, sqrt_eigenvalues) 16 | sqrt_mat = tf.einsum('ijk,ilk->ijl', aux, vectors) 17 | 18 | sqrt_mat = to_ndarray(sqrt_mat, to_ndim=3) 19 | return sqrt_mat 20 | 21 | 22 | def expm(x): 23 | return tf.linalg.expm(x) 24 | 25 | 26 | def logm(x): 27 | return tf.linalg.expm(x) 28 | 29 | 30 | def logm(x): 31 | x = tf.cast(x, tf.complex64) 32 | logm = tf.linalg.logm(x) 33 | logm = tf.cast(logm, tf.float32) 34 | return logm 35 | 36 | 37 | def det(x): 38 | return tf.linalg.det(x) 39 | 40 | 41 | def eigh(x): 42 | return tf.linalg.eigh(x) 43 | 44 | 45 | def eig(x): 46 | return tf.linalg.eig(x) 47 | 48 | 49 | def svd(x): 50 | s, u, v_t = tf.svd(x, full_matrices=True) 51 | return u, s, tf.transpose(v_t, perm=(0, 2, 1)) 52 | 53 | 54 | def norm(x, axis=None): 55 | return tf.linalg.norm(x, axis=axis) 56 | 57 | 58 | def inv(x): 59 | return tf.linalg.inv(x) 60 | 61 | 62 | def matrix_rank(x): 63 | return tf.rank(x) 64 | 65 | 66 | def eigvalsh(x): 67 | return tf.linalg.eigvalsh(x) 68 | 69 | 70 | def qr(*args, mode='reduced'): 71 | def qr_aux(x, mode): 72 | if mode == 'complete': 73 | aux = tf.linalg.qr(x, full_matrices=True) 74 | else: 75 | aux = tf.linalg.qr(x) 76 | 77 | return (aux.q, aux.r) 78 | 79 | qr = tf.map_fn( 80 | lambda x: qr_aux(x, mode), 81 | *args, 82 | dtype=(tf.float32, tf.float32)) 83 | 84 | return qr 85 | -------------------------------------------------------------------------------- /tests/test_template.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from sklearn.datasets import load_iris 5 | from sklearn.utils.testing import assert_array_equal 6 | from sklearn.utils.testing import assert_allclose 7 | 8 | from geomstats.learning._template 
import TemplateEstimator 9 | from geomstats.learning._template import TemplateTransformer 10 | from geomstats.learning._template import TemplateClassifier 11 | 12 | 13 | @pytest.fixture 14 | def data(): 15 | return load_iris(return_X_y=True) 16 | 17 | 18 | def test_template_estimator(data): 19 | est = TemplateEstimator() 20 | assert est.demo_param == 'demo_param' 21 | 22 | est.fit(*data) 23 | assert hasattr(est, 'is_fitted_') 24 | 25 | X = data[0] 26 | y_pred = est.predict(X) 27 | assert_array_equal(y_pred, np.ones(X.shape[0], dtype=np.int64)) 28 | 29 | 30 | def test_template_transformer_error(data): 31 | X, y = data 32 | trans = TemplateTransformer() 33 | trans.fit(X) 34 | with pytest.raises(ValueError, match="Shape of input is different"): 35 | X_diff_size = np.ones((10, X.shape[1] + 1)) 36 | trans.transform(X_diff_size) 37 | 38 | 39 | def test_template_transformer(data): 40 | X, y = data 41 | trans = TemplateTransformer() 42 | assert trans.demo_param == 'demo' 43 | 44 | trans.fit(X) 45 | assert trans.n_features_ == X.shape[1] 46 | 47 | X_trans = trans.transform(X) 48 | assert_allclose(X_trans, np.sqrt(X)) 49 | 50 | X_trans = trans.fit_transform(X) 51 | assert_allclose(X_trans, np.sqrt(X)) 52 | 53 | 54 | def test_template_classifier(data): 55 | X, y = data 56 | clf = TemplateClassifier() 57 | assert clf.demo_param == 'demo' 58 | 59 | clf.fit(X, y) 60 | assert hasattr(clf, 'classes_') 61 | assert hasattr(clf, 'X_') 62 | assert hasattr(clf, 'y_') 63 | 64 | y_pred = clf.predict(X) 65 | assert y_pred.shape == (X.shape[0],) 66 | -------------------------------------------------------------------------------- /tests/test_connection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the affine connections. 
3 | """ 4 | 5 | import geomstats.backend as gs 6 | import geomstats.tests 7 | 8 | from geomstats.connection import LeviCivitaConnection 9 | from geomstats.euclidean_space import EuclideanMetric 10 | 11 | 12 | class TestConnectionMethods(geomstats.tests.TestCase): 13 | _multiprocess_can_split_ = True 14 | 15 | def setUp(self): 16 | self.dimension = 4 17 | self.metric = EuclideanMetric(dimension=self.dimension) 18 | self.connection = LeviCivitaConnection(self.metric) 19 | 20 | def test_metric_matrix(self): 21 | base_point = gs.array([0., 1., 0., 0.]) 22 | 23 | result = self.connection.metric_matrix(base_point) 24 | expected = gs.array([gs.eye(self.dimension)]) 25 | 26 | with self.session(): 27 | self.assertAllClose(result, expected) 28 | 29 | def test_cometric_matrix(self): 30 | base_point = gs.array([0., 1., 0., 0.]) 31 | 32 | result = self.connection.cometric_matrix(base_point) 33 | expected = gs.array([gs.eye(self.dimension)]) 34 | 35 | with self.session(): 36 | self.assertAllClose(result, expected) 37 | 38 | @geomstats.tests.np_only 39 | def test_metric_derivative(self): 40 | base_point = gs.array([0., 1., 0., 0.]) 41 | 42 | result = self.connection.metric_derivative(base_point) 43 | expected = gs.zeros((1,) + (self.dimension, ) * 3) 44 | 45 | gs.testing.assert_allclose(result, expected) 46 | 47 | @geomstats.tests.np_only 48 | def test_christoffel_symbols(self): 49 | base_point = gs.array([0., 1., 0., 0.]) 50 | 51 | result = self.connection.christoffel_symbols(base_point) 52 | expected = gs.zeros((1,) + (self.dimension, ) * 3) 53 | 54 | gs.testing.assert_allclose(result, expected) 55 | 56 | 57 | if __name__ == '__main__': 58 | geomstats.tests.main() 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Geomstats 2 | [![Build 
Status](https://travis-ci.org/geomstats/geomstats.svg?branch=master)](https://travis-ci.org/geomstats/geomstats)[![Coverage Status](https://codecov.io/gh/geomstats/geomstats/branch/master/graph/badge.svg?flag=numpy)](https://codecov.io/gh/geomstats/geomstats)[![Coverage Status](https://codecov.io/gh/geomstats/geomstats/branch/master/graph/badge.svg?flag=tensorflow)](https://codecov.io/gh/geomstats/geomstats)[![Coverage Status](https://codecov.io/gh/geomstats/geomstats/branch/master/graph/badge.svg?flag=pytorch)](https://codecov.io/gh/geomstats/geomstats) (Coverages for: numpy, tensorflow, pytorch) 3 | 4 | 5 | Computations and statistics on manifolds with geometric structures. 6 | 7 | - To get started with ```geomstats```, see the [examples directory](https://github.com/geomstats/geomstats/tree/master/examples). 8 | - For more in-depth applications of ``geomstats``, see the [applications repository](https://github.com/geomstats/applications/). 9 | - The documentation of ```geomstats``` can be found on the [documentation website](https://geomstats.github.io/). 10 | - If you use ``geomstats``, please kindly cite our [paper](https://arxiv.org/abs/1805.08308). 11 | 12 | 13 | 14 | 15 | ## Installation 16 | 17 | OS X & Linux: 18 | 19 | ``` 20 | pip3 install geomstats 21 | ``` 22 | 23 | ## Running tests 24 | 25 | ``` 26 | pip3 install nose2 27 | nose2 28 | ``` 29 | 30 | ## Getting started 31 | 32 | Define your backend by setting the environment variable ```GEOMSTATS_BACKEND``` to ```numpy```, ```tensorflow```, or ```pytorch```: 33 | 34 | ``` 35 | export GEOMSTATS_BACKEND=numpy 36 | ``` 37 | 38 | Then, run example scripts: 39 | 40 | ``` 41 | python3 examples/plot_grid_h2.py 42 | ``` 43 | 44 | ## Contributing 45 | 46 | See our [CONTRIBUTING.md][link_contributing] file!
47 | 48 | ## Authors & Contributors 49 | 50 | * Alice Le Brigant 51 | * Claire Donnat 52 | * Oleg Kachan 53 | * Benjamin Hou 54 | * Johan Mathe 55 | * Nina Miolane 56 | * Xavier Pennec 57 | 58 | ## Acknowledgements 59 | 60 | This work is partially supported by the National Science Foundation, grant NSF DMS RTG 1501767. 61 | 62 | [link_contributing]: https://github.com/geomstats/geomstats/CONTRIBUTING.md 63 | -------------------------------------------------------------------------------- /tests/test_visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for visualization. 3 | """ 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') # NOQA 7 | import matplotlib.pyplot as plt 8 | 9 | import geomstats.tests 10 | import geomstats.visualization as visualization 11 | 12 | from geomstats.hyperbolic_space import HyperbolicSpace 13 | from geomstats.hypersphere import Hypersphere 14 | from geomstats.special_euclidean_group import SpecialEuclideanGroup 15 | from geomstats.special_orthogonal_group import SpecialOrthogonalGroup 16 | 17 | 18 | class TestVisualizationMethods(geomstats.tests.TestCase): 19 | _multiprocess_can_split_ = True 20 | 21 | def setUp(self): 22 | self.n_samples = 10 23 | self.SO3_GROUP = SpecialOrthogonalGroup(n=3) 24 | self.SE3_GROUP = SpecialEuclideanGroup(n=3) 25 | self.S1 = Hypersphere(dimension=1) 26 | self.S2 = Hypersphere(dimension=2) 27 | self.H2 = HyperbolicSpace(dimension=2) 28 | 29 | plt.figure() 30 | 31 | @geomstats.tests.np_only 32 | def test_plot_points_so3(self): 33 | points = self.SO3_GROUP.random_uniform(self.n_samples) 34 | visualization.plot(points, space='SO3_GROUP') 35 | 36 | @geomstats.tests.np_only 37 | def test_plot_points_se3(self): 38 | points = self.SE3_GROUP.random_uniform(self.n_samples) 39 | visualization.plot(points, space='SE3_GROUP') 40 | 41 | @geomstats.tests.np_only 42 | def test_plot_points_s1(self): 43 | points = self.S1.random_uniform(self.n_samples) 44 | 
visualization.plot(points, space='S1') 45 | 46 | @geomstats.tests.np_only 47 | def test_plot_points_s2(self): 48 | points = self.S2.random_uniform(self.n_samples) 49 | visualization.plot(points, space='S2') 50 | 51 | @geomstats.tests.np_only 52 | def test_plot_points_h2_poincare_disk(self): 53 | points = self.H2.random_uniform(self.n_samples) 54 | visualization.plot(points, space='H2_poincare_disk') 55 | 56 | @geomstats.tests.np_only 57 | def test_plot_points_h2_poincare_half_plane(self): 58 | points = self.H2.random_uniform(self.n_samples) 59 | visualization.plot(points, space='H2_poincare_half_plane') 60 | 61 | @geomstats.tests.np_only 62 | def test_plot_points_h2_klein_disk(self): 63 | points = self.H2.random_uniform(self.n_samples) 64 | visualization.plot(points, space='H2_klein_disk') 65 | 66 | 67 | if __name__ == '__main__': 68 | geomstats.tests.main() 69 | -------------------------------------------------------------------------------- /geomstats/tests.py: -------------------------------------------------------------------------------- 1 | """ 2 | Testing class for geomstats. 3 | 4 | This class abstracts the backend type. 
5 | """ 6 | 7 | import os 8 | import tensorflow as tf 9 | import unittest 10 | 11 | import geomstats.backend as gs 12 | 13 | 14 | def pytorch_backend(): 15 | return os.environ['GEOMSTATS_BACKEND'] == 'pytorch' 16 | 17 | 18 | def tf_backend(): 19 | return os.environ['GEOMSTATS_BACKEND'] == 'tensorflow' 20 | 21 | 22 | def np_backend(): 23 | return os.environ['GEOMSTATS_BACKEND'] == 'numpy' 24 | 25 | 26 | test_class = unittest.TestCase 27 | if tf_backend(): 28 | test_class = tf.test.TestCase 29 | 30 | 31 | def np_only(test_item): 32 | """Decorator to filter tests for numpy only.""" 33 | if not np_backend(): 34 | test_item.__unittest_skip__ = True 35 | test_item.__unittest_skip_why__ = ( 36 | 'Test for numpy backend only.') 37 | return test_item 38 | 39 | 40 | def np_and_tf_only(test_item): 41 | """Decorator to filter tests for numpy and tensorflow only.""" 42 | if not (np_backend() or tf_backend()): 43 | test_item.__unittest_skip__ = True 44 | test_item.__unittest_skip_why__ = ( 45 | 'Test for numpy and tensorflow backends only.') 46 | return test_item 47 | 48 | 49 | def np_and_pytorch_only(test_item): 50 | """Decorator to filter tests for numpy and pytorch only.""" 51 | if not (np_backend() or pytorch_backend()): 52 | test_item.__unittest_skip__ = True 53 | test_item.__unittest_skip_why__ = ( 54 | 'Test for numpy and pytorch backends only.') 55 | return test_item 56 | 57 | 58 | class DummySession(): 59 | def __enter__(self): 60 | pass 61 | 62 | def __exit__(self, a, b, c): 63 | pass 64 | 65 | 66 | class TestCase(test_class): 67 | 68 | def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6): 69 | if tf_backend(): 70 | return super().assertAllClose(a, b, rtol=rtol, atol=atol) 71 | return self.assertTrue(gs.allclose(a, b, rtol=rtol, atol=atol)) 72 | 73 | def session(self): 74 | if tf_backend(): 75 | return super().test_session() 76 | return DummySession() 77 | 78 | def assertShapeEqual(self, a, b): 79 | if tf_backend(): 80 | return super().assertShapeEqual(a, b) 81 | 
super().assertEqual(a.shape, b.shape) 82 | 83 | @classmethod 84 | def setUpClass(cls): 85 | if tf_backend(): 86 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 87 | -------------------------------------------------------------------------------- /examples/plot_geodesics_h2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot a geodesic on the Hyperbolic space H2, 3 | with Poincare Disk visualization. 4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import os 9 | 10 | import geomstats.visualization as visualization 11 | 12 | from geomstats.hyperbolic_space import HyperbolicSpace 13 | 14 | H2 = HyperbolicSpace(dimension=2) 15 | METRIC = H2.metric 16 | 17 | 18 | def plot_geodesic_between_two_points(initial_point, 19 | end_point, 20 | n_steps=10, 21 | ax=None): 22 | assert H2.belongs(initial_point) 23 | assert H2.belongs(end_point) 24 | 25 | geodesic = METRIC.geodesic(initial_point=initial_point, 26 | end_point=end_point) 27 | 28 | t = np.linspace(0, 1, n_steps) 29 | points = geodesic(t) 30 | visualization.plot(points, ax=ax, space='H2_poincare_disk') 31 | 32 | 33 | def plot_geodesic_with_initial_tangent_vector(initial_point, 34 | initial_tangent_vec, 35 | n_steps=10, 36 | ax=None): 37 | assert H2.belongs(initial_point) 38 | geodesic = METRIC.geodesic(initial_point=initial_point, 39 | initial_tangent_vec=initial_tangent_vec) 40 | n_steps = 10 41 | t = np.linspace(0, 1, n_steps) 42 | 43 | points = geodesic(t) 44 | visualization.plot(points, ax=ax, space='H2_poincare_disk') 45 | 46 | 47 | def main(): 48 | initial_point = [np.sqrt(2), 1., 0.] 
49 | end_point = H2.intrinsic_to_extrinsic_coords([1.5, 1.5]) 50 | initial_tangent_vec = H2.projection_to_tangent_space( 51 | vector=[3.5, 0.6, 0.8], 52 | base_point=initial_point) 53 | 54 | ax = plt.gca() 55 | plot_geodesic_between_two_points(initial_point, 56 | end_point, 57 | ax=ax) 58 | plot_geodesic_with_initial_tangent_vector(initial_point, 59 | initial_tangent_vec, 60 | ax=ax) 61 | plt.show() 62 | 63 | 64 | if __name__ == "__main__": 65 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 66 | print('Examples with visualizations are only implemented ' 67 | 'with numpy backend.\n' 68 | 'To change backend, write: ' 69 | 'export GEOMSTATS_BACKEND = \'numpy\'.') 70 | else: 71 | main() 72 | -------------------------------------------------------------------------------- /tests/helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions for unit tests. 3 | """ 4 | 5 | import geomstats.backend as gs 6 | 7 | 8 | def to_scalar(expected): 9 | expected = gs.to_ndarray(expected, to_ndim=1) 10 | expected = gs.to_ndarray(expected, to_ndim=2, axis=-1) 11 | return expected 12 | 13 | 14 | def to_vector(expected): 15 | expected = gs.to_ndarray(expected, to_ndim=2) 16 | return expected 17 | 18 | 19 | def to_matrix(expected): 20 | expected = gs.to_ndarray(expected, to_ndim=3) 21 | return expected 22 | 23 | 24 | def left_log_then_exp_from_identity(metric, point): 25 | aux = metric.left_log_from_identity(point=point) 26 | result = metric.left_exp_from_identity(tangent_vec=aux) 27 | return result 28 | 29 | 30 | def left_exp_then_log_from_identity(metric, tangent_vec): 31 | aux = metric.left_exp_from_identity(tangent_vec=tangent_vec) 32 | result = metric.left_log_from_identity(point=aux) 33 | return result 34 | 35 | 36 | def log_then_exp_from_identity(metric, point): 37 | aux = metric.log_from_identity(point=point) 38 | result = metric.exp_from_identity(tangent_vec=aux) 39 | return result 40 | 41 | 42 | def 
exp_then_log_from_identity(metric, tangent_vec): 43 | aux = metric.exp_from_identity(tangent_vec=tangent_vec) 44 | result = metric.log_from_identity(point=aux) 45 | return result 46 | 47 | 48 | def log_then_exp(metric, point, base_point): 49 | aux = metric.log(point=point, 50 | base_point=base_point) 51 | result = metric.exp(tangent_vec=aux, 52 | base_point=base_point) 53 | return result 54 | 55 | 56 | def exp_then_log(metric, tangent_vec, base_point): 57 | aux = metric.exp(tangent_vec=tangent_vec, 58 | base_point=base_point) 59 | result = metric.log(point=aux, 60 | base_point=base_point) 61 | return result 62 | 63 | 64 | def group_log_then_exp_from_identity(group, point): 65 | aux = group.group_log_from_identity(point=point) 66 | result = group.group_exp_from_identity(tangent_vec=aux) 67 | return result 68 | 69 | 70 | def group_exp_then_log_from_identity(group, tangent_vec): 71 | aux = group.group_exp_from_identity(tangent_vec=tangent_vec) 72 | result = group.group_log_from_identity(point=aux) 73 | return result 74 | 75 | 76 | def group_log_then_exp(group, point, base_point): 77 | aux = group.group_log(point=point, 78 | base_point=base_point) 79 | result = group.group_exp(tangent_vec=aux, 80 | base_point=base_point) 81 | return result 82 | 83 | 84 | def group_exp_then_log(group, tangent_vec, base_point): 85 | aux = group.group_exp(tangent_vec=tangent_vec, 86 | base_point=base_point) 87 | result = group.group_log(point=aux, 88 | base_point=base_point) 89 | return result 90 | -------------------------------------------------------------------------------- /docs/api-reference.rst: -------------------------------------------------------------------------------- 1 | ************* 2 | API reference 3 | ************* 4 | 5 | .. automodule:: geomstats 6 | :members: 7 | 8 | The "manifold" module 9 | --------------------- 10 | 11 | .. automodule:: geomstats.manifold 12 | :members: 13 | 14 | The "connection" module 15 | ----------------------- 16 | 17 | .. 
automodule:: geomstats.connection 18 | :members: 19 | 20 | The "riemannian_metric" module 21 | ------------------------------ 22 | 23 | .. automodule:: geomstats.riemannian_metric 24 | :members: 25 | 26 | Spaces of constant curvatures 27 | ============================= 28 | 29 | The "embedded_manifold" module 30 | ------------------------------ 31 | 32 | .. automodule:: geomstats.embedded_manifold 33 | :members: 34 | 35 | The "euclidean_space" module 36 | ---------------------------- 37 | 38 | .. automodule:: geomstats.euclidean_space 39 | :members: 40 | 41 | The "minkowski_space" module 42 | ---------------------------- 43 | 44 | .. automodule:: geomstats.minkowski_space 45 | :members: 46 | 47 | The "hypersphere" module 48 | ------------------------ 49 | 50 | .. automodule:: geomstats.hypersphere 51 | :members: 52 | 53 | The "hyperbolic_space" module 54 | ----------------------------- 55 | 56 | .. automodule:: geomstats.hyperbolic_space 57 | :members: 58 | 59 | Lie groups 60 | ========== 61 | 62 | The "lie_group" module 63 | ---------------------- 64 | 65 | .. automodule:: geomstats.lie_group 66 | :members: 67 | 68 | The "matrices_space" module 69 | --------------------------- 70 | 71 | .. automodule:: geomstats.matrices_space 72 | :members: 73 | 74 | The "general_linear_group" module 75 | --------------------------------- 76 | 77 | .. automodule:: geomstats.general_linear_group 78 | :members: 79 | 80 | The "invariant_metric" module 81 | ----------------------------- 82 | 83 | .. automodule:: geomstats.invariant_metric 84 | :members: 85 | 86 | The "special_euclidean_group" module 87 | ------------------------------------ 88 | 89 | .. automodule:: geomstats.special_euclidean_group 90 | :members: 91 | 92 | The "special_orthogonal_group" module 93 | ------------------------------------- 94 | 95 | .. 
automodule:: geomstats.special_orthogonal_group 96 | :members: 97 | 98 | More manifolds 99 | ============== 100 | 101 | The "discretized_curves_space" module 102 | ------------------------------------- 103 | 104 | .. automodule:: geomstats.discretized_curves_space 105 | :members: 106 | 107 | The "spd_matrices_space" module 108 | ------------------------------- 109 | 110 | .. automodule:: geomstats.spd_matrices_space 111 | :members: 112 | 113 | The "stiefel" module 114 | -------------------- 115 | 116 | .. automodule:: geomstats.stiefel 117 | :members: 118 | 119 | Visualization implementations 120 | ============================= 121 | 122 | .. automodule:: geomstats.visualization 123 | :members: 124 | 125 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the examples. 3 | """ 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') # NOQA 7 | import matplotlib.pyplot as plt 8 | import os 9 | import sys 10 | import warnings 11 | 12 | import examples.gradient_descent_s2 as gradient_descent_s2 13 | import examples.loss_and_gradient_se3 as loss_and_gradient_se3 14 | import examples.loss_and_gradient_so3 as loss_and_gradient_so3 15 | import examples.plot_geodesics_h2 as plot_geodesics_h2 16 | import examples.plot_geodesics_s2 as plot_geodesics_s2 17 | import examples.plot_geodesics_se3 as plot_geodesics_se3 18 | import examples.plot_geodesics_so3 as plot_geodesics_so3 19 | import examples.plot_grid_h2 as plot_grid_h2 20 | import examples.plot_square_h2_poincare_disk as plot_square_h2_poincare_disk 21 | import examples.plot_square_h2_poincare_half_plane as plot_square_h2_poincare_half_plane # NOQA 22 | import examples.plot_square_h2_klein_disk as plot_square_h2_klein_disk 23 | import examples.plot_quantization_s1 as plot_quantization_s1 24 | import examples.plot_quantization_s2 as plot_quantization_s2 25 | import 
examples.tangent_pca_so3 as tangent_pca_so3 26 | import geomstats.tests 27 | 28 | 29 | class TestExamples(geomstats.tests.TestCase): 30 | _multiprocess_can_split_ = True 31 | 32 | @classmethod 33 | def setUpClass(cls): 34 | sys.stdout = open(os.devnull, 'w') 35 | 36 | def setUp(self): 37 | warnings.simplefilter('ignore', category=ImportWarning) 38 | plt.figure() 39 | 40 | @geomstats.tests.np_only 41 | def test_gradient_descent_s2(self): 42 | gradient_descent_s2.main(max_iter=32, output_file=None) 43 | 44 | @geomstats.tests.np_only 45 | def test_loss_and_gradient_so3(self): 46 | loss_and_gradient_so3.main() 47 | 48 | @geomstats.tests.np_only 49 | def test_loss_and_gradient_se3(self): 50 | loss_and_gradient_se3.main() 51 | 52 | @geomstats.tests.np_only 53 | def test_plot_geodesics_h2(self): 54 | plot_geodesics_h2.main() 55 | 56 | @geomstats.tests.np_only 57 | def test_plot_geodesics_s2(self): 58 | plot_geodesics_s2.main() 59 | 60 | @geomstats.tests.np_only 61 | def test_plot_geodesics_se3(self): 62 | plot_geodesics_se3.main() 63 | 64 | @geomstats.tests.np_only 65 | def test_plot_geodesics_so3(self): 66 | plot_geodesics_so3.main() 67 | 68 | @geomstats.tests.np_only 69 | def test_plot_grid_h2(self): 70 | plot_grid_h2.main() 71 | 72 | @geomstats.tests.np_only 73 | def test_plot_square_h2_square_poincare_disk(self): 74 | plot_square_h2_poincare_disk.main() 75 | 76 | @geomstats.tests.np_only 77 | def test_plot_square_h2_square_poincare_half_plane(self): 78 | plot_square_h2_poincare_half_plane.main() 79 | 80 | @geomstats.tests.np_only 81 | def test_plot_square_h2_square_klein_disk(self): 82 | plot_square_h2_klein_disk.main() 83 | 84 | @geomstats.tests.np_only 85 | def test_tangent_pca_so3(self): 86 | tangent_pca_so3.main() 87 | 88 | @geomstats.tests.np_only 89 | def test_plot_quantization_s1(self): 90 | plot_quantization_s1.main() 91 | 92 | @geomstats.tests.np_only 93 | def test_plot_quantization_s2(self): 94 | plot_quantization_s2.main() 95 | 96 | 97 | if __name__ == 
'__main__': 98 | geomstats.tests.main() 99 | -------------------------------------------------------------------------------- /geomstats/matrices_space.py: -------------------------------------------------------------------------------- 1 | """ 2 | The space of matrices (m, n), which is the Euclidean space R^{mn}. 3 | """ 4 | 5 | import geomstats.backend as gs 6 | 7 | from geomstats.euclidean_space import EuclideanSpace 8 | from geomstats.riemannian_metric import RiemannianMetric 9 | 10 | 11 | TOLERANCE = 1e-5 12 | 13 | 14 | class MatricesSpace(EuclideanSpace): 15 | """Class for the space of matrices (m, n).""" 16 | 17 | def __init__(self, m, n): 18 | assert isinstance(m, int) and isinstance(n, int) and m > 0 and n > 0 19 | super(MatricesSpace, self).__init__(dimension=m*n) 20 | self.m = m 21 | self.n = n 22 | self.default_point_type = 'matrix' 23 | self.metric = MatricesMetric(m, n) 24 | 25 | def belongs(self, point): 26 | """ 27 | Check if point belongs to the Matrix space, i.e. has matrix shape (m, n). 28 | """ 29 | point = gs.to_ndarray(point, to_ndim=3) 30 | _, mat_dim_1, mat_dim_2 = point.shape 31 | return mat_dim_1 == self.m and mat_dim_2 == self.n 32 | 33 | @staticmethod 34 | def vector_from_matrix(matrix): 35 | """ 36 | Conversion function from (_, m, n) to (_, mn).
37 | """ 38 | matrix = gs.to_ndarray(matrix, to_ndim=3) 39 | n_mats, m, n = matrix.shape 40 | return gs.reshape(matrix, (n_mats, m*n)) 41 | 42 | @staticmethod 43 | def is_symmetric(matrix, tolerance=TOLERANCE): 44 | """Check if a matrix is symmetric.""" 45 | matrix = gs.to_ndarray(matrix, to_ndim=3) 46 | n_mats, m, n = matrix.shape 47 | assert m == n 48 | matrix_transpose = gs.transpose(matrix, axes=(0, 2, 1)) 49 | 50 | mask = gs.isclose(matrix, matrix_transpose, atol=tolerance) 51 | mask = gs.all(mask, axis=(1, 2)) 52 | 53 | mask = gs.to_ndarray(mask, to_ndim=1) 54 | mask = gs.to_ndarray(mask, to_ndim=2, axis=1) 55 | return mask 56 | 57 | @staticmethod 58 | def make_symmetric(matrix): 59 | """Make a matrix fully symmetric to avoid numerical issues.""" 60 | matrix = gs.to_ndarray(matrix, to_ndim=3) 61 | n_mats, m, n = matrix.shape 62 | assert m == n 63 | matrix = gs.to_ndarray(matrix, to_ndim=3) 64 | return (matrix + gs.transpose(matrix, axes=(0, 2, 1))) / 2 65 | 66 | def random_uniform(self, n_samples=1): 67 | point = gs.random.rand(n_samples, self.m, self.n) 68 | return point 69 | 70 | 71 | class MatricesMetric(RiemannianMetric): 72 | """ 73 | Euclidean metric on matrices given by the Frobenius inner product. 74 | """ 75 | def __init__(self, m, n): 76 | dimension = m*n 77 | super(MatricesMetric, self).__init__( 78 | dimension=dimension, 79 | signature=(dimension, 0, 0)) 80 | 81 | def inner_product(self, tangent_vec_a, tangent_vec_b, base_point=None): 82 | """ 83 | Compute the Frobenius inner product of tangent_vec_a and tangent_vec_b 84 | at base_point. 
85 | """ 86 | tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3) 87 | n_tangent_vecs_a, _, _ = tangent_vec_a.shape 88 | 89 | tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3) 90 | n_tangent_vecs_b, _, _ = tangent_vec_b.shape 91 | 92 | assert n_tangent_vecs_a == n_tangent_vecs_b 93 | 94 | inner_prod = gs.einsum("nij,nij->n", tangent_vec_a, tangent_vec_b) 95 | inner_prod = gs.to_ndarray(inner_prod, to_ndim=1) 96 | inner_prod = gs.to_ndarray(inner_prod, to_ndim=2, axis=1) 97 | return inner_prod 98 | -------------------------------------------------------------------------------- /tests/test_matrices_space.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the manifold of matrices. 3 | """ 4 | 5 | import geomstats.backend as gs 6 | import geomstats.tests 7 | import tests.helper as helper 8 | 9 | from geomstats.matrices_space import MatricesSpace 10 | 11 | 12 | class TestMatricesSpaceMethods(geomstats.tests.TestCase): 13 | _multiprocess_can_split_ = True 14 | 15 | def setUp(self): 16 | gs.random.seed(1234) 17 | 18 | self.n = 3 19 | self.space = MatricesSpace(m=self.n, n=self.n) 20 | self.metric = self.space.metric 21 | self.n_samples = 2 22 | 23 | @geomstats.tests.np_only 24 | def test_is_symmetric(self): 25 | sym_mat = gs.array([[1., 2.], 26 | [2., 1.]]) 27 | result = self.space.is_symmetric(sym_mat) 28 | expected = gs.array([True]) 29 | self.assertAllClose(result, expected) 30 | 31 | not_a_sym_mat = gs.array([[1., 0.6, -3.], 32 | [6., -7., 0.], 33 | [0., 7., 8.]]) 34 | result = self.space.is_symmetric(not_a_sym_mat) 35 | expected = gs.array([False]) 36 | self.assertAllClose(result, expected) 37 | 38 | @geomstats.tests.np_and_tf_only 39 | def test_is_symmetric_vectorization(self): 40 | points = gs.array([ 41 | [[1., 2.], 42 | [2., 1.]], 43 | [[3., 4.], 44 | [4., 5.]]]) 45 | result = gs.all(self.space.is_symmetric(points)) 46 | expected = True 47 | self.assertAllClose(result, expected) 48 | 49 | def 
test_make_symmetric(self): 50 | sym_mat = gs.array([[1., 2.], 51 | [2., 1.]]) 52 | result = self.space.make_symmetric(sym_mat) 53 | expected = helper.to_matrix(sym_mat) 54 | self.assertAllClose(result, expected) 55 | 56 | mat = gs.array([[1., 2., 3.], 57 | [0., 0., 0.], 58 | [3., 1., 1.]]) 59 | result = self.space.make_symmetric(mat) 60 | expected = gs.array([[[1., 1., 3.], 61 | [1., 0., 0.5], 62 | [3., 0.5, 1.]]]) 63 | self.assertAllClose(result, expected) 64 | 65 | mat = gs.array([[[1e100, 1e-100, 1e100], 66 | [1e100, 1e-100, 1e100], 67 | [1e-100, 1e-100, 1e100]]]) 68 | result = self.space.make_symmetric(mat) 69 | 70 | res = 0.5 * (1e100 + 1e-100) 71 | 72 | expected = gs.array([[[1e100, res, res], 73 | [res, 1e-100, res], 74 | [res, res, 1e100]]]) 75 | self.assertAllClose(result, expected) 76 | 77 | @geomstats.tests.np_and_tf_only 78 | def test_make_symmetric_and_is_symmetric_vectorization(self): 79 | points = gs.array([ 80 | [[1., 2.], 81 | [3., 4.]], 82 | [[5., 6.], 83 | [4., 9.]]]) 84 | 85 | sym_points = self.space.make_symmetric(points) 86 | result = gs.all(self.space.is_symmetric(sym_points)) 87 | expected = True 88 | self.assertAllClose(result, expected) 89 | 90 | def test_inner_product(self): 91 | base_point = gs.array([ 92 | [1., 2., 3.], 93 | [0., 0., 0.], 94 | [3., 1., 1.]]) 95 | 96 | tangent_vector_1 = gs.array([ 97 | [1., 2., 3.], 98 | [0., -10., 0.], 99 | [30., 1., 1.]]) 100 | 101 | tangent_vector_2 = gs.array([ 102 | [1., 4., 3.], 103 | [5., 0., 0.], 104 | [3., 1., 1.]]) 105 | 106 | result = self.metric.inner_product( 107 | tangent_vector_1, 108 | tangent_vector_2, 109 | base_point=base_point) 110 | 111 | expected = gs.trace( 112 | gs.matmul( 113 | gs.transpose(tangent_vector_1), 114 | tangent_vector_2)) 115 | expected = helper.to_scalar(expected) 116 | 117 | self.assertAllClose(result, expected) 118 | 119 | 120 | if __name__ == '__main__': 121 | geomstats.tests.main() 122 | 
-------------------------------------------------------------------------------- /geomstats/euclidean_space.py: -------------------------------------------------------------------------------- 1 | """ 2 | Euclidean space. 3 | """ 4 | 5 | import geomstats.backend as gs 6 | 7 | from geomstats.manifold import Manifold 8 | from geomstats.riemannian_metric import RiemannianMetric 9 | 10 | 11 | class EuclideanSpace(Manifold): 12 | """ 13 | Class for Euclidean spaces. 14 | 15 | By definition, a Euclidean space is a vector space of a given 16 | dimension, equipped with a Euclidean metric. 17 | """ 18 | 19 | def __init__(self, dimension): 20 | assert isinstance(dimension, int) and dimension > 0 21 | self.dimension = dimension 22 | self.metric = EuclideanMetric(dimension) 23 | 24 | def belongs(self, point): 25 | """ 26 | Evaluate if a point belongs to the Euclidean space. 27 | 28 | Parameters 29 | ---------- 30 | point : array-like, shape=[n_samples, dimension] 31 | Input points. 32 | 33 | Returns 34 | ------- 35 | belongs : array-like, shape=[n_samples, 1] 36 | """ 37 | point = gs.to_ndarray(point, to_ndim=2) 38 | n_points, point_dim = point.shape 39 | belongs = point_dim == self.dimension 40 | belongs = gs.to_ndarray(belongs, to_ndim=1) 41 | belongs = gs.to_ndarray(belongs, to_ndim=2, axis=1) 42 | belongs = gs.tile(belongs, (n_points, 1)) 43 | 44 | return belongs 45 | 46 | def random_uniform(self, n_samples=1, bound=1.): 47 | """ 48 | Sample uniformly in the hypercube [-bound, bound]^dimension. 49 | """ 50 | size = (n_samples, self.dimension) 51 | point = bound * (gs.random.rand(*size) - 0.5) * 2 52 | 53 | return point 54 | 55 | 56 | class EuclideanMetric(RiemannianMetric): 57 | """ 58 | Class for Euclidean metrics. 59 | 60 | As a Riemannian metric, the Euclidean metric is: 61 | - flat: the inner product is independent of the base point. 62 | - positive definite: it has signature (dimension, 0, 0), 63 | where dimension is the dimension of the Euclidean space.
64 | """ 65 | def __init__(self, dimension): 66 | assert isinstance(dimension, int) and dimension > 0 67 | super(EuclideanMetric, self).__init__( 68 | dimension=dimension, 69 | signature=(dimension, 0, 0)) 70 | 71 | def inner_product_matrix(self, base_point=None): 72 | """ 73 | Inner product matrix, independent of the base point. 74 | """ 75 | mat = gs.eye(self.dimension) 76 | mat = gs.to_ndarray(mat, to_ndim=3) 77 | return mat 78 | 79 | def exp(self, tangent_vec, base_point): 80 | """ 81 | The Riemannian exponential is the addition in the Euclidean space. 82 | """ 83 | tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2) 84 | base_point = gs.to_ndarray(base_point, to_ndim=2) 85 | exp = base_point + tangent_vec 86 | return exp 87 | 88 | def log(self, point, base_point): 89 | """ 90 | The Riemannian logarithm is the subtraction in the Euclidean space. 91 | """ 92 | point = gs.to_ndarray(point, to_ndim=2) 93 | base_point = gs.to_ndarray(base_point, to_ndim=2) 94 | log = point - base_point 95 | return log 96 | 97 | def mean(self, points, weights=None): 98 | """ 99 | The Frechet mean of (weighted) points computed with the 100 | Euclidean metric is the weighted average of the points 101 | in the Euclidean space. 102 | """ 103 | if isinstance(points, list): 104 | points = gs.vstack(points) 105 | points = gs.to_ndarray(points, to_ndim=2) 106 | n_points = gs.shape(points)[0] 107 | 108 | if isinstance(weights, list): 109 | weights = gs.vstack(weights) 110 | elif weights is None: 111 | weights = gs.ones((n_points,)) 112 | 113 | weighted_points = gs.einsum('n,nj->nj', weights, points) 114 | mean = (gs.sum(weighted_points, axis=0) 115 | / gs.sum(weights)) 116 | mean = gs.to_ndarray(mean, to_ndim=2) 117 | return mean 118 | -------------------------------------------------------------------------------- /tests/test_backend_numpy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for numpy backend. 
3 | """ 4 | 5 | import importlib 6 | import os 7 | import unittest 8 | import warnings 9 | 10 | import geomstats.backend as gs 11 | from geomstats.special_orthogonal_group import SpecialOrthogonalGroup 12 | 13 | 14 | class TestBackendNumpy(unittest.TestCase): 15 | _multiprocess_can_split_ = True 16 | 17 | @classmethod 18 | def setUpClass(cls): 19 | cls.initial_backend = os.environ['GEOMSTATS_BACKEND'] 20 | os.environ['GEOMSTATS_BACKEND'] = 'numpy' 21 | importlib.reload(gs) 22 | 23 | @classmethod 24 | def tearDownClass(cls): 25 | os.environ['GEOMSTATS_BACKEND'] = cls.initial_backend 26 | importlib.reload(gs) 27 | 28 | def setUp(self): 29 | warnings.simplefilter('ignore', category=ImportWarning) 30 | 31 | self.so3_group = SpecialOrthogonalGroup(n=3) 32 | self.n_samples = 2 33 | 34 | def test_logm(self): 35 | point = gs.array([[2., 0., 0.], 36 | [0., 3., 0.], 37 | [0., 0., 4.]]) 38 | result = gs.linalg.logm(point) 39 | expected = gs.array([[0.693147180, 0., 0.], 40 | [0., 1.098612288, 0.], 41 | [0., 0., 1.38629436]]) 42 | 43 | self.assertTrue(gs.allclose(result, expected)) 44 | 45 | def test_expm_and_logm(self): 46 | point = gs.array([[2., 0., 0.], 47 | [0., 3., 0.], 48 | [0., 0., 4.]]) 49 | result = gs.linalg.expm(gs.linalg.logm(point)) 50 | expected = point 51 | 52 | self.assertTrue(gs.allclose(result, expected)) 53 | 54 | def test_expm_vectorization(self): 55 | point = gs.array([[[2., 0., 0.], 56 | [0., 3., 0.], 57 | [0., 0., 4.]], 58 | [[1., 0., 0.], 59 | [0., 5., 0.], 60 | [0., 0., 6.]]]) 61 | 62 | expected = gs.array([[[7.38905609, 0., 0.], 63 | [0., 20.0855369, 0.], 64 | [0., 0., 54.5981500]], 65 | [[2.718281828, 0., 0.], 66 | [0., 148.413159, 0.], 67 | [0., 0., 403.42879349]]]) 68 | 69 | result = gs.linalg.expm(point) 70 | 71 | self.assertTrue(gs.allclose(result, expected)) 72 | 73 | def test_logm_vectorization_diagonal(self): 74 | point = gs.array([[[2., 0., 0.], 75 | [0., 3., 0.], 76 | [0., 0., 4.]], 77 | [[1., 0., 0.], 78 | [0., 5., 0.], 79 | [0., 0., 
6.]]]) 80 | 81 | expected = gs.array([[[0.693147180, 0., 0.], 82 | [0., 1.09861228866, 0.], 83 | [0., 0., 1.38629436]], 84 | [[0., 0., 0.], 85 | [0., 1.609437912, 0.], 86 | [0., 0., 1.79175946]]]) 87 | 88 | result = gs.linalg.logm(point) 89 | 90 | self.assertTrue(gs.allclose(result, expected)) 91 | 92 | def test_expm_and_logm_vectorization_random_rotation(self): 93 | point = self.so3_group.random_uniform(self.n_samples) 94 | point = self.so3_group.matrix_from_rotation_vector(point) 95 | 96 | result = gs.linalg.expm(gs.linalg.logm(point)) 97 | expected = point 98 | 99 | self.assertTrue(gs.allclose(result, expected)) 100 | 101 | def test_expm_and_logm_vectorization(self): 102 | point = gs.array([[[2., 0., 0.], 103 | [0., 3., 0.], 104 | [0., 0., 4.]], 105 | [[1., 0., 0.], 106 | [0., 5., 0.], 107 | [0., 0., 6.]]]) 108 | result = gs.linalg.expm(gs.linalg.logm(point)) 109 | expected = point 110 | 111 | self.assertTrue(gs.allclose(result, expected)) 112 | 113 | 114 | if __name__ == '__main__': 115 | unittest.main() 116 | -------------------------------------------------------------------------------- /geomstats/general_linear_group.py: -------------------------------------------------------------------------------- 1 | """ 2 | The General Linear Group, i.e. the matrix group GL(n). 3 | """ 4 | 5 | import geomstats.backend as gs 6 | 7 | from geomstats.lie_group import LieGroup 8 | from geomstats.matrices_space import MatricesSpace 9 | 10 | 11 | class GeneralLinearGroup(LieGroup, MatricesSpace): 12 | """ 13 | Class for the General Linear Group, i.e. the matrix group GL(n). 14 | 15 | 16 | Note: The default representation for elements of GL(n) 17 | are matrices. 18 | For now, SO(n) and SE(n) elements are represented 19 | by a vector by default. 
20 | """ 21 | 22 | def __init__(self, n): 23 | assert isinstance(n, int) and n > 0 24 | LieGroup.__init__(self, dimension=n*n) 25 | MatricesSpace.__init__(self, m=n, n=n) 26 | 27 | def get_identity(self, point_type=None): 28 | if point_type is None: 29 | point_type = self.default_point_type 30 | if point_type == 'matrix': 31 | return gs.eye(self.n) 32 | else: 33 | raise NotImplementedError( 34 | 'The identity of the general linear group is not' 35 | ' implemented for a point_type that is not \'matrix\'.') 36 | identity = property(get_identity) 37 | 38 | def belongs(self, mat): 39 | """ 40 | Check if mat belongs to GL(n). 41 | """ 42 | mat = gs.to_ndarray(mat, to_ndim=3) 43 | 44 | det = gs.linalg.det(mat) 45 | belongs = ~gs.isclose(det, 0.) 46 | 47 | belongs = gs.to_ndarray(belongs, to_ndim=1) 48 | belongs = gs.to_ndarray(belongs, to_ndim=2, axis=1) 49 | 50 | return belongs 51 | 52 | def compose(self, mat_a, mat_b): 53 | """ 54 | Matrix composition. 55 | """ 56 | mat_a = gs.to_ndarray(mat_a, to_ndim=3) 57 | mat_b = gs.to_ndarray(mat_b, to_ndim=3) 58 | composition = gs.einsum('nij,njk->nik', mat_a, mat_b) 59 | return composition 60 | 61 | def inverse(self, mat): 62 | """ 63 | Matrix inverse. 64 | """ 65 | mat = gs.to_ndarray(mat, to_ndim=3) 66 | return gs.linalg.inv(mat) 67 | 68 | def group_exp_from_identity(self, tangent_vec, point_type=None): 69 | """ 70 | Group exponential of the Lie group of 71 | all invertible matrices at the identity. 72 | """ 73 | tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=3) 74 | group_exp = gs.linalg.expm(tangent_vec) 75 | 76 | return gs.real(group_exp) 77 | 78 | def group_exp_not_from_identity( 79 | self, tangent_vec, base_point, point_type=None): 80 | """ 81 | Group exponential of the Lie group of 82 | all invertible matrices. 
83 | """ 84 | tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=3) 85 | base_point = gs.to_ndarray(base_point, to_ndim=3) 86 | 87 | tangent_vec_at_identity = self.compose( 88 | self.inverse(base_point), tangent_vec) 89 | 90 | group_exp_from_identity = self.group_exp_from_identity( 91 | tangent_vec_at_identity) 92 | 93 | group_exp = self.compose( 94 | base_point, group_exp_from_identity) 95 | 96 | return group_exp 97 | 98 | def group_log_from_identity(self, point, point_type=None): 99 | """ 100 | Group logarithm of the Lie group of 101 | all invertible matrices at the identity. 102 | """ 103 | point = gs.to_ndarray(point, to_ndim=3) 104 | group_log = gs.linalg.logm(point) 105 | 106 | return gs.real(group_log) 107 | 108 | def group_log_not_from_identity( 109 | self, point, base_point, point_type=None): 110 | """ 111 | Group logarithm of the Lie group of 112 | all invertible matrices. 113 | """ 114 | point = gs.to_ndarray(point, to_ndim=3) 115 | base_point = gs.to_ndarray(base_point, to_ndim=3) 116 | 117 | point_near_identity = self.compose( 118 | self.inverse(base_point), point) 119 | 120 | group_log_from_identity = self.group_log_from_identity( 121 | point_near_identity) 122 | 123 | group_log = self.compose( 124 | base_point, group_log_from_identity) 125 | 126 | return group_log 127 | -------------------------------------------------------------------------------- /examples/loss_and_gradient_so3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Predict on manifolds: losses. 
3 | """ 4 | 5 | import os 6 | import tensorflow as tf 7 | 8 | import geomstats.backend as gs 9 | import geomstats.lie_group as lie_group 10 | 11 | from geomstats.special_orthogonal_group import SpecialOrthogonalGroup 12 | 13 | 14 | SO3 = SpecialOrthogonalGroup(n=3) 15 | 16 | 17 | def loss(y_pred, y_true, 18 | metric=SO3.bi_invariant_metric, 19 | representation='vector'): 20 | 21 | if representation == 'quaternion': 22 | y_pred = SO3.rotation_vector_from_quaternion(y_pred) 23 | y_true = SO3.rotation_vector_from_quaternion(y_true) 24 | 25 | loss = lie_group.loss(y_pred, y_true, SO3, metric) 26 | return loss 27 | 28 | 29 | def grad(y_pred, y_true, 30 | metric=SO3.bi_invariant_metric, 31 | representation='vector'): 32 | 33 | y_pred = gs.expand_dims(y_pred, axis=0) 34 | y_true = gs.expand_dims(y_true, axis=0) 35 | 36 | if representation == 'vector': 37 | grad = lie_group.grad(y_pred, y_true, SO3, metric) 38 | 39 | if representation == 'quaternion': 40 | quat_scalar = y_pred[:, :1] 41 | quat_vec = y_pred[:, 1:] 42 | 43 | quat_vec_norm = gs.linalg.norm(quat_vec, axis=1) 44 | quat_sq_norm = quat_vec_norm ** 2 + quat_scalar ** 2 45 | 46 | quat_arctan2 = gs.arctan2(quat_vec_norm, quat_scalar) 47 | differential_scalar = - 2 * quat_vec / (quat_sq_norm) 48 | differential_scalar = gs.to_ndarray(differential_scalar, to_ndim=2) 49 | differential_scalar = gs.transpose(differential_scalar) 50 | 51 | differential_vec = (2 * (quat_scalar / quat_sq_norm 52 | - 2 * quat_arctan2 / quat_vec_norm) 53 | * (gs.einsum('ni,nj->nij', quat_vec, quat_vec) 54 | / quat_vec_norm ** 2) 55 | + 2 * quat_arctan2 / quat_vec_norm * gs.eye(3)) 56 | differential_vec = gs.squeeze(differential_vec) 57 | 58 | differential = gs.concatenate( 59 | [differential_scalar, differential_vec], 60 | axis=1) 61 | 62 | y_pred = SO3.rotation_vector_from_quaternion(y_pred) 63 | y_true = SO3.rotation_vector_from_quaternion(y_true) 64 | 65 | grad = lie_group.grad(y_pred, y_true, SO3, metric) 66 | 67 | grad = gs.matmul(grad, 
differential) 68 | 69 | grad = gs.squeeze(grad, axis=0) 70 | return grad 71 | 72 | 73 | def main(): 74 | y_pred = gs.array([1., 1.5, -0.3]) 75 | y_true = gs.array([0.1, 1.8, -0.1]) 76 | 77 | loss_rot_vec = loss(y_pred, y_true) 78 | grad_rot_vec = grad(y_pred, y_true) 79 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 80 | with tf.Session() as sess: 81 | loss_rot_vec = sess.run(loss_rot_vec) 82 | grad_rot_vec = sess.run(grad_rot_vec) 83 | print('The loss between the rotation vectors is: {}'.format( 84 | loss_rot_vec[0, 0])) 85 | print('The riemannian gradient is: {}'.format( 86 | grad_rot_vec)) 87 | 88 | angle = gs.pi / 6 89 | cos = gs.cos(angle / 2) 90 | sin = gs.sin(angle / 2) 91 | u = gs.array([1., 2., 3.]) 92 | u = u / gs.linalg.norm(u) 93 | scalar = gs.to_ndarray(cos, to_ndim=1) 94 | vec = sin * u 95 | y_pred_quaternion = gs.concatenate([scalar, vec], axis=0) 96 | 97 | angle = gs.pi / 7 98 | cos = gs.cos(angle / 2) 99 | sin = gs.sin(angle / 2) 100 | u = gs.array([1., 2., 3.]) 101 | u = u / gs.linalg.norm(u) 102 | scalar = gs.to_ndarray(cos, to_ndim=1) 103 | vec = sin * u 104 | y_true_quaternion = gs.concatenate([scalar, vec], axis=0) 105 | 106 | loss_quaternion = loss(y_pred_quaternion, y_true_quaternion, 107 | representation='quaternion') 108 | grad_quaternion = grad(y_pred_quaternion, y_true_quaternion, 109 | representation='quaternion') 110 | 111 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 112 | with tf.Session() as sess: 113 | loss_quaternion = sess.run(loss_quaternion) 114 | grad_quaternion = sess.run(grad_quaternion) 115 | print('The loss between the quaternions is: {}'.format( 116 | loss_quaternion[0, 0])) 117 | print('The riemannian gradient is: {}'.format( 118 | grad_quaternion)) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /examples/gradient_descent_s2.py: -------------------------------------------------------------------------------- 1 
| """ 2 | Gradient descent on a sphere. 3 | 4 | We solve the following optimization problem: 5 | 6 | minimize: x^{T}Ax 7 | subject to: x^{T}x = 1 8 | 9 | We do so by running a gradient descent of the quadratic form 10 | on the sphere. We solve this in 3 dimensions on the 2-sphere 11 | manifold so that we can visualize and render the path as a video. 12 | """ 13 | 14 | import matplotlib 15 | matplotlib.use("Agg") # NOQA 16 | import matplotlib.animation as animation 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | 20 | import geomstats.backend as gs 21 | import geomstats.visualization as visualization 22 | 23 | from geomstats.hypersphere import Hypersphere 24 | from geomstats.spd_matrices_space import SPDMatricesSpace 25 | 26 | 27 | SPHERE2 = Hypersphere(dimension=2) 28 | METRIC = SPHERE2.metric 29 | 30 | 31 | def gradient_descent(start, 32 | loss, 33 | grad, 34 | manifold, 35 | lr=0.01, 36 | max_iter=256, 37 | precision=1e-5): 38 | """Run a gradient descent on a given manifold until either max_iter 39 | iterations are done or the loss change falls below precision.""" 40 | x = start 41 | for i in range(max_iter): 42 | x_prev = x 43 | euclidean_grad = - lr * grad(x) 44 | tangent_vec = manifold.projection_to_tangent_space( 45 | vector=euclidean_grad, base_point=x) 46 | x = manifold.metric.exp(base_point=x, tangent_vec=tangent_vec)[0] 47 | if (gs.abs(loss(x, use_gs=True) - loss(x_prev, use_gs=True)) 48 | <= precision): 49 | print('x: %s' % x) 50 | print('reached precision %s' % precision) 51 | print('iterations: %d' % i) 52 | break 53 | yield x, loss(x) 54 | 55 | 56 | def plot_and_save_video(geodesics, 57 | loss, 58 | size=20, 59 | fps=10, 60 | dpi=100, 61 | out='out.mp4', 62 | color='red'): 63 | """Render a set of geodesics and save it to an mpeg 4 file.""" 64 | FFMpegWriter = animation.writers['ffmpeg'] 65 | writer = FFMpegWriter(fps=fps) 66 | fig = plt.figure(figsize=(size, size)) 67 | ax = fig.add_subplot(111, projection='3d', aspect='equal') 68 | sphere =
visualization.Sphere() 69 | sphere.plot_heatmap(ax, loss) 70 | points = gs.to_ndarray(geodesics[0], to_ndim=2) 71 | sphere.add_points(points) 72 | sphere.draw(ax, color=color, marker='.') 73 | with writer.saving(fig, out, dpi=dpi): 74 | for points in geodesics[1:]: 75 | points = gs.to_ndarray(points, to_ndim=2) 76 | sphere.draw_points(ax, points=points, color=color, marker='.') 77 | writer.grab_frame() 78 | 79 | 80 | def generate_well_behaved_matrix(): 81 | """Generate a matrix with real eigenvalues.""" 82 | matrix = 2 * SPDMatricesSpace(n=3).random_uniform()[0] 83 | assert np.linalg.det(matrix) > 0 84 | return matrix 85 | 86 | 87 | def main(output_file='out.mp4', max_iter=128): 88 | gs.random.seed(1985) 89 | A = generate_well_behaved_matrix() 90 | 91 | def grad(x): 92 | return 2 * gs.matmul(A, x) 93 | 94 | def loss(x, use_gs=False): 95 | if use_gs: 96 | return gs.matmul(x, gs.matmul(A, x)) 97 | return np.matmul(x, np.matmul(A, x)) 98 | 99 | initial_point = gs.array([0., 1., 0.]) 100 | previous_x = initial_point 101 | geodesics = [] 102 | n_steps = 20 103 | for x, fx in gradient_descent(initial_point, 104 | loss, 105 | grad, 106 | max_iter=max_iter, 107 | manifold=SPHERE2): 108 | initial_tangent_vec = METRIC.log(point=x, base_point=previous_x) 109 | geodesic = METRIC.geodesic(initial_point=previous_x, 110 | initial_tangent_vec=initial_tangent_vec) 111 | 112 | t = np.linspace(0, 1, n_steps) 113 | geodesics.append(geodesic(t)) 114 | previous_x = x 115 | if output_file: 116 | plot_and_save_video(geodesics, loss, out=output_file) 117 | eig, _ = np.linalg.eig(A) 118 | np.testing.assert_almost_equal(loss(x), np.min(eig), decimal=2) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /geomstats/minkowski_space.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minkowski space. 
3 | """ 4 | 5 | 6 | from geomstats.manifold import Manifold 7 | from geomstats.riemannian_metric import RiemannianMetric 8 | import geomstats.backend as gs 9 | 10 | 11 | class MinkowskiSpace(Manifold): 12 | """Class for Minkowski Space.""" 13 | 14 | def __init__(self, dimension): 15 | assert isinstance(dimension, int) and dimension > 0 16 | self.dimension = dimension 17 | self.metric = MinkowskiMetric(dimension) 18 | 19 | def belongs(self, point): 20 | """ 21 | Evaluate if a point belongs to the Minkowski space. 22 | 23 | Parameters 24 | ---------- 25 | point : array-like, shape=[n_samples, dimension] 26 | Input points. 27 | 28 | Returns 29 | ------- 30 | belongs : array-like, shape=[n_samples, 1] 31 | """ 32 | point = gs.to_ndarray(point, to_ndim=2) 33 | n_points, point_dim = point.shape 34 | belongs = point_dim == self.dimension 35 | belongs = gs.to_ndarray(belongs, to_ndim=1) 36 | belongs = gs.to_ndarray(belongs, to_ndim=2, axis=1) 37 | belongs = gs.tile(belongs, (n_points, 1)) 38 | 39 | return belongs 40 | 41 | def random_uniform(self, n_samples=1, bound=1.): 42 | """ 43 | Sample uniformly in the hypercube [-bound, bound]^dimension. 44 | 45 | Returns 46 | ------- 47 | points : array-like, shape=[n_samples, dimension] 48 | Sampled points. 49 | """ 50 | size = (n_samples, self.dimension) 51 | point = bound * (gs.random.rand(*size) * 2 - 1) 52 | 53 | return point 54 | 55 | 56 | class MinkowskiMetric(RiemannianMetric): 57 | """ 58 | Class for the pseudo-Riemannian Minkowski metric. 59 | The metric is flat: the inner product is independent of the base point. 60 | """ 61 | def __init__(self, dimension): 62 | super(MinkowskiMetric, self).__init__( 63 | dimension=dimension, 64 | signature=(dimension - 1, 1, 0)) 65 | 66 | def inner_product_matrix(self, base_point=None): 67 | """ 68 | Inner product matrix, independent of the base point.
69 | 70 | Parameters 71 | ---------- 72 | base_point: array-like, shape=[n_samples, dimension] 73 | """ 74 | inner_prod_mat = gs.eye(self.dimension-1, self.dimension-1) 75 | first_row = gs.array([0.] * (self.dimension - 1)) 76 | first_row = gs.to_ndarray(first_row, to_ndim=2, axis=1) 77 | inner_prod_mat = gs.vstack([gs.transpose(first_row), 78 | inner_prod_mat]) 79 | 80 | first_column = gs.array([-1.] + [0.] * (self.dimension - 1)) 81 | first_column = gs.to_ndarray(first_column, to_ndim=2, axis=1) 82 | inner_prod_mat = gs.hstack([first_column, 83 | inner_prod_mat]) 84 | 85 | return inner_prod_mat 86 | 87 | def exp(self, tangent_vec, base_point): 88 | """ 89 | The Riemannian exponential is the addition in the Minkowski space. 90 | 91 | Parameters 92 | ---------- 93 | tangent_vec: array-like, shape=[n_samples, dimension] 94 | or shape=[1, dimension] 95 | 96 | base_point: array-like, shape=[n_samples, dimension] 97 | or shape=[1, dimension] 98 | """ 99 | tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2) 100 | base_point = gs.to_ndarray(base_point, to_ndim=2) 101 | return base_point + tangent_vec 102 | 103 | def log(self, point, base_point): 104 | """ 105 | The Riemannian logarithm is the subtraction in the Minkowski space. 106 | 107 | Parameters 108 | ---------- 109 | point: array-like, shape=[n_samples, dimension] 110 | or shape=[1, dimension] 111 | 112 | base_point: array-like, shape=[n_samples, dimension] 113 | or shape=[1, dimension] 114 | """ 115 | point = gs.to_ndarray(point, to_ndim=2) 116 | base_point = gs.to_ndarray(base_point, to_ndim=2) 117 | return point - base_point 118 | 119 | def mean(self, points, weights=None): 120 | """ 121 | The Frechet mean of (weighted) points is the weighted average of 122 | the points in the Minkowski space. 
123 | 124 | Parameters 125 | ---------- 126 | points: array-like, shape=[n_samples, dimension] 127 | 128 | weights: array-like, shape=[n_samples, 1], optional 129 | """ 130 | if isinstance(points, list): 131 | points = gs.vstack(points) 132 | points = gs.to_ndarray(points, to_ndim=2) 133 | n_points = gs.shape(points)[0] 134 | 135 | if isinstance(weights, list): 136 | weights = gs.vstack(weights) 137 | elif weights is None: 138 | weights = gs.ones((n_points,)) 139 | 140 | weighted_points = gs.einsum('n,nj->nj', weights, points) 141 | mean = (gs.sum(weighted_points, axis=0) 142 | / gs.sum(weights)) 143 | mean = gs.to_ndarray(mean, to_ndim=2) 144 | return mean 145 | -------------------------------------------------------------------------------- /examples/loss_and_gradient_se3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Predict on SE3: losses. 3 | """ 4 | 5 | import os 6 | import tensorflow as tf 7 | 8 | import geomstats.backend as gs 9 | import geomstats.lie_group as lie_group 10 | 11 | from geomstats.special_euclidean_group import SpecialEuclideanGroup 12 | from geomstats.special_orthogonal_group import SpecialOrthogonalGroup 13 | 14 | 15 | SE3 = SpecialEuclideanGroup(n=3) 16 | SO3 = SpecialOrthogonalGroup(n=3) 17 | 18 | 19 | def loss(y_pred, y_true, 20 | metric=SE3.left_canonical_metric, 21 | representation='vector'): 22 | """ 23 | Loss function given by a riemannian metric on a Lie group, 24 | by default the left-invariant canonical metric. 
25 | """ 26 | if gs.ndim(y_pred) == 1: 27 | y_pred = gs.expand_dims(y_pred, axis=0) 28 | if gs.ndim(y_true) == 1: 29 | y_true = gs.expand_dims(y_true, axis=0) 30 | 31 | if representation == 'quaternion': 32 | y_pred_rot_vec = SO3.rotation_vector_from_quaternion(y_pred[:, :4]) 33 | y_pred = gs.hstack([y_pred_rot_vec, y_pred[:, 4:]]) 34 | y_true_rot_vec = SO3.rotation_vector_from_quaternion(y_true[:, :4]) 35 | y_true = gs.hstack([y_true_rot_vec, y_true[:, 4:]]) 36 | 37 | loss = lie_group.loss(y_pred, y_true, SE3, metric) 38 | return loss 39 | 40 | 41 | def grad(y_pred, y_true, 42 | metric=SE3.left_canonical_metric, 43 | representation='vector'): 44 | """ 45 | Closed-form for the gradient of pose_loss. 46 | 47 | :return: tangent vector at point y_pred. 48 | """ 49 | if gs.ndim(y_pred) == 1: 50 | y_pred = gs.expand_dims(y_pred, axis=0) 51 | if gs.ndim(y_true) == 1: 52 | y_true = gs.expand_dims(y_true, axis=0) 53 | 54 | if representation == 'vector': 55 | grad = lie_group.grad(y_pred, y_true, SE3, metric) 56 | 57 | if representation == 'quaternion': 58 | 59 | y_pred_rot_vec = SO3.rotation_vector_from_quaternion(y_pred[:, :4]) 60 | y_pred_pose = gs.hstack([y_pred_rot_vec, y_pred[:, 4:]]) 61 | y_true_rot_vec = SO3.rotation_vector_from_quaternion(y_true[:, :4]) 62 | y_true_pose = gs.hstack([y_true_rot_vec, y_true[:, 4:]]) 63 | grad = lie_group.grad(y_pred_pose, y_true_pose, SE3, metric) 64 | 65 | quat_scalar = y_pred[:, :1] 66 | quat_vec = y_pred[:, 1:4] 67 | 68 | quat_vec_norm = gs.linalg.norm(quat_vec, axis=1) 69 | quat_sq_norm = quat_vec_norm ** 2 + quat_scalar ** 2 70 | 71 | quat_arctan2 = gs.arctan2(quat_vec_norm, quat_scalar) 72 | differential_scalar = - 2 * quat_vec / (quat_sq_norm) 73 | differential_vec = (2 * (quat_scalar / quat_sq_norm 74 | - 2 * quat_arctan2 / quat_vec_norm) 75 | * (gs.einsum('ni,nj->nij', quat_vec, quat_vec) 76 | / quat_vec_norm * quat_vec_norm) 77 | + 2 * quat_arctan2 / quat_vec_norm * gs.eye(3)) 78 | 79 | differential_scalar_t = 
gs.transpose(differential_scalar, axes=(1, 0)) 80 | 81 | upper_left_block = gs.hstack( 82 | (differential_scalar_t, differential_vec[0])) 83 | upper_right_block = gs.zeros((3, 3)) 84 | lower_right_block = gs.eye(3) 85 | lower_left_block = gs.zeros((3, 4)) 86 | 87 | top = gs.hstack((upper_left_block, upper_right_block)) 88 | bottom = gs.hstack((lower_left_block, lower_right_block)) 89 | 90 | differential = gs.vstack((top, bottom)) 91 | differential = gs.expand_dims(differential, axis=0) 92 | 93 | grad = gs.einsum('ni,nij->ni', grad, differential) 94 | 95 | grad = gs.squeeze(grad, axis=0) 96 | return grad 97 | 98 | 99 | def main(): 100 | y_pred = gs.array([1., 1.5, -0.3, 5., 6., 7.]) 101 | y_true = gs.array([0.1, 1.8, -0.1, 4., 5., 6.]) 102 | 103 | loss_rot_vec = loss(y_pred, y_true) 104 | grad_rot_vec = grad(y_pred, y_true) 105 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 106 | with tf.Session() as sess: 107 | loss_rot_vec = sess.run(loss_rot_vec) 108 | grad_rot_vec = sess.run(grad_rot_vec) 109 | print('The loss between the poses using rotation vectors is: {}'.format( 110 | loss_rot_vec[0, 0])) 111 | print('The riemannian gradient is: {}'.format(grad_rot_vec)) 112 | 113 | angle = gs.pi / 6 114 | cos = gs.cos(angle / 2) 115 | sin = gs.sin(angle / 2) 116 | u = gs.array([1., 2., 3.]) 117 | u = u / gs.linalg.norm(u) 118 | scalar = gs.array(cos) 119 | vec = sin * u 120 | translation = gs.array([5., 6., 7.]) 121 | y_pred_quaternion = gs.concatenate([[scalar], vec, translation], axis=0) 122 | 123 | angle = gs.pi / 7 124 | cos = gs.cos(angle / 2) 125 | sin = gs.sin(angle / 2) 126 | u = gs.array([1., 2., 3.]) 127 | u = u / gs.linalg.norm(u) 128 | scalar = gs.array(cos) 129 | vec = sin * u 130 | translation = gs.array([4., 5., 6.]) 131 | y_true_quaternion = gs.concatenate([[scalar], vec, translation], axis=0) 132 | 133 | loss_quaternion = loss(y_pred_quaternion, y_true_quaternion, 134 | representation='quaternion') 135 | grad_quaternion = grad(y_pred_quaternion, 
y_true_quaternion, 136 | representation='quaternion') 137 | if os.environ['GEOMSTATS_BACKEND'] == 'tensorflow': 138 | with tf.Session() as sess: 139 | loss_quaternion = sess.run(loss_quaternion) 140 | grad_quaternion = sess.run(grad_quaternion) 141 | print('The loss between the poses using quaternions is: {}'.format( 142 | loss_quaternion[0, 0])) 143 | print('The riemannian gradient is: {}'.format( 144 | grad_quaternion)) 145 | 146 | 147 | if __name__ == "__main__": 148 | main() 149 | -------------------------------------------------------------------------------- /tests/test_general_linear_group.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for General Linear group. 3 | """ 4 | 5 | import warnings 6 | 7 | import geomstats.backend as gs 8 | import geomstats.tests 9 | import tests.helper as helper 10 | 11 | from geomstats.general_linear_group import GeneralLinearGroup 12 | from geomstats.special_orthogonal_group import SpecialOrthogonalGroup 13 | 14 | RTOL = 1e-5 15 | 16 | 17 | class TestGeneralLinearGroupMethods(geomstats.tests.TestCase): 18 | _multiprocess_can_split_ = True 19 | 20 | def setUp(self): 21 | gs.random.seed(1234) 22 | self.n = 3 23 | self.n_samples = 2 24 | self.group = GeneralLinearGroup(n=self.n) 25 | # We generate invertible matrices using so3_group 26 | self.so3_group = SpecialOrthogonalGroup(n=self.n) 27 | 28 | warnings.simplefilter('ignore', category=ImportWarning) 29 | 30 | @geomstats.tests.np_only 31 | def test_belongs(self): 32 | """ 33 | A rotation matrix belongs to the matrix Lie group 34 | of invertible matrices. 35 | """ 36 | rot_vec = gs.array([0.2, -0.1, 0.1]) 37 | rot_mat = self.so3_group.matrix_from_rotation_vector(rot_vec) 38 | result = self.group.belongs(rot_mat) 39 | expected = gs.array([True]) 40 | 41 | self.assertAllClose(result, expected) 42 | 43 | def test_compose(self): 44 | # 1. 
Composition by identity, on the right 45 | # Expect the original transformation 46 | rot_vec = gs.array([0.2, -0.1, 0.1]) 47 | mat = self.so3_group.matrix_from_rotation_vector(rot_vec) 48 | 49 | result = self.group.compose(mat, self.group.identity) 50 | expected = mat 51 | expected = helper.to_matrix(mat) 52 | 53 | self.assertAllClose(result, expected) 54 | 55 | # 2. Composition by identity, on the left 56 | # Expect the original transformation 57 | rot_vec = gs.array([0.2, 0.1, -0.1]) 58 | mat = self.so3_group.matrix_from_rotation_vector(rot_vec) 59 | 60 | result = self.group.compose(self.group.identity, mat) 61 | expected = mat 62 | 63 | self.assertAllClose(result, expected) 64 | 65 | def test_inverse(self): 66 | mat = gs.array([ 67 | [1., 2., 3.], 68 | [4., 5., 6.], 69 | [7., 8., 10.]]) 70 | result = self.group.inverse(mat) 71 | expected = 1. / 3. * gs.array([ 72 | [-2., -4., 3.], 73 | [-2., 11., -6.], 74 | [3., -6., 3.]]) 75 | expected = helper.to_matrix(expected) 76 | 77 | self.assertAllClose(result, expected) 78 | 79 | def test_compose_and_inverse(self): 80 | # 1. Compose transformation by its inverse on the right 81 | # Expect the group identity 82 | rot_vec = gs.array([0.2, 0.1, 0.1]) 83 | mat = self.so3_group.matrix_from_rotation_vector(rot_vec) 84 | inv_mat = self.group.inverse(mat) 85 | 86 | result = self.group.compose(mat, inv_mat) 87 | expected = self.group.identity 88 | expected = helper.to_matrix(expected) 89 | 90 | self.assertAllClose(result, expected) 91 | 92 | # 2. 
Compose transformation by its inverse on the left 93 | # Expect the group identity 94 | rot_vec = gs.array([0.7, 0.1, 0.1]) 95 | mat = self.so3_group.matrix_from_rotation_vector(rot_vec) 96 | inv_mat = self.group.inverse(mat) 97 | 98 | result = self.group.compose(inv_mat, mat) 99 | expected = self.group.identity 100 | expected = helper.to_matrix(expected) 101 | 102 | self.assertAllClose(result, expected) 103 | 104 | @geomstats.tests.np_and_tf_only 105 | def test_group_log_and_exp(self): 106 | point = 5 * gs.eye(self.n) 107 | 108 | group_log = self.group.group_log(point) 109 | result = self.group.group_exp(group_log) 110 | expected = point 111 | expected = helper.to_matrix(expected) 112 | 113 | self.assertAllClose(result, expected) 114 | 115 | @geomstats.tests.np_and_tf_only 116 | def test_group_exp_vectorization(self): 117 | point = gs.array([[[2., 0., 0.], 118 | [0., 3., 0.], 119 | [0., 0., 4.]], 120 | [[1., 0., 0.], 121 | [0., 5., 0.], 122 | [0., 0., 6.]]]) 123 | 124 | expected = gs.array([[[7.38905609, 0., 0.], 125 | [0., 20.0855369, 0.], 126 | [0., 0., 54.5981500]], 127 | [[2.718281828, 0., 0.], 128 | [0., 148.413159, 0.], 129 | [0., 0., 403.42879349]]]) 130 | 131 | result = self.group.group_exp(point) 132 | 133 | self.assertAllClose(result, expected, rtol=1e-3) 134 | 135 | @geomstats.tests.np_and_tf_only 136 | def test_group_log_vectorization(self): 137 | point = gs.array([[[2., 0., 0.], 138 | [0., 3., 0.], 139 | [0., 0., 4.]], 140 | [[1., 0., 0.], 141 | [0., 5., 0.], 142 | [0., 0., 6.]]]) 143 | 144 | expected = gs.array([[[0.693147180, 0., 0.], 145 | [0., 1.09861228866, 0.], 146 | [0., 0., 1.38629436]], 147 | [[0., 0., 0.], 148 | [0., 1.609437912, 0.], 149 | [0., 0., 1.79175946]]]) 150 | 151 | result = self.group.group_log(point) 152 | 153 | self.assertAllClose(result, expected, atol=1e-4) 154 | 155 | @geomstats.tests.np_and_tf_only 156 | def test_expm_and_logm_vectorization_symmetric(self): 157 | point = gs.array([[[2., 0., 0.], 158 | [0., 3., 0.], 159 | 
[0., 0., 4.]], 160 | [[1., 0., 0.], 161 | [0., 5., 0.], 162 | [0., 0., 6.]]]) 163 | result = self.group.group_exp(self.group.group_log(point)) 164 | expected = point 165 | 166 | self.assertAllClose(result, expected) 167 | 168 | 169 | if __name__ == '__main__': 170 | geomstats.tests.main() 171 | -------------------------------------------------------------------------------- /geomstats/backend/tensorflow.py: -------------------------------------------------------------------------------- 1 | """Tensorflow based computation backend.""" 2 | 3 | import tensorflow as tf 4 | 5 | 6 | int8 = tf.int8 7 | int32 = tf.int32 8 | int64 = tf.int64 9 | float16 = tf.float16 10 | float32 = tf.float32 11 | float64 = tf.float64 12 | 13 | 14 | def while_loop(*args, **kwargs): 15 | return tf.while_loop(*args, **kwargs) 16 | 17 | 18 | def logical_or(x, y): 19 | return tf.logical_or(x, y) 20 | 21 | 22 | def get_mask_i_float(i, n): 23 | range_n = arange(n) 24 | i_float = cast(array([i]), int32)[0] 25 | mask_i = equal(range_n, i_float) 26 | mask_i_float = cast(mask_i, float32) 27 | return mask_i_float 28 | 29 | 30 | def gather(*args, **kwargs): 31 | return tf.gather(*args, **kwargs) 32 | 33 | 34 | def where(*args, **kwargs): 35 | return tf.where(*args, **kwargs) 36 | 37 | 38 | def vectorize(x, pyfunc, multiple_args=False, dtype=None, **kwargs): 39 | if multiple_args: 40 | return tf.map_fn(lambda x: pyfunc(*x), elems=x, dtype=dtype) 41 | return tf.map_fn(pyfunc, elems=x, dtype=dtype) 42 | 43 | 44 | def sign(x): 45 | return tf.sign(x) 46 | 47 | 48 | def hsplit(x, n_splits): 49 | return tf.split(x, num_or_size_splits=n_splits, axis=1) 50 | 51 | 52 | def amax(x): 53 | return tf.reduce_max(x) 54 | 55 | 56 | def real(x): 57 | return tf.real(x) 58 | 59 | 60 | def cond(*args, **kwargs): 61 | return tf.cond(*args, **kwargs) 62 | 63 | 64 | def reshape(*args, **kwargs): 65 | return tf.reshape(*args, **kwargs) 66 | 67 | 68 | def arange(*args, **kwargs): 69 | return tf.range(*args, **kwargs) 70 | 71 | 
def outer(x, y):
    """Outer product of two 1-D tensors via ``tf.einsum('i,j->ij')``."""
    return tf.einsum('i,j->ij', x, y)


def copy(x):
    """Return a new ``tf.Variable`` initialised from ``x``.

    NOTE(review): the result is a Variable, not a plain tensor copy.
    """
    return tf.Variable(x)


def linspace(start, stop, num):
    """Evenly spaced values between ``start`` and ``stop`` (``tf.linspace``)."""
    return tf.linspace(start, stop, num)


def mod(x, y):
    """Element-wise remainder, delegating to ``tf.mod``."""
    return tf.mod(x, y)


def boolean_mask(x, mask, name='boolean_mask', axis=None):
    """Keep the entries of ``x`` where ``mask`` is true."""
    return tf.boolean_mask(x, mask, name, axis)


def exp(x):
    """Element-wise exponential."""
    return tf.exp(x)


def log(x):
    """Element-wise natural logarithm."""
    return tf.log(x)


def hstack(x):
    """Concatenate a sequence of tensors along axis 1.

    NOTE: unlike ``numpy.hstack``, 1-D inputs are not joined along axis 0
    here; ``tf.concat(..., axis=1)`` needs inputs of rank >= 2.
    """
    return tf.concat(x, axis=1)


def vstack(x):
    """Concatenate a sequence of tensors along axis 0."""
    return tf.concat(x, axis=0)


def cast(x, dtype):
    """Cast ``x`` to ``dtype``."""
    return tf.cast(x, dtype)


def divide(x1, x2):
    """Element-wise division ``x1 / x2``."""
    return tf.divide(x1, x2)


def tile(x, reps):
    """Repeat ``x`` along each axis according to ``reps``."""
    return tf.tile(x, reps)


def eval(x):
    """Return the value of ``x``.

    In eager mode ``x`` is returned unchanged; otherwise ``x.eval()`` is
    called, which in graph mode uses the default session.
    """
    if tf.executing_eagerly():
        return x
    return x.eval()


def abs(x):
    """Element-wise absolute value."""
    return tf.abs(x)


def zeros(x):
    """Tensor of zeros whose shape is given by ``x``."""
    return tf.zeros(x)


def ones(x):
    """Tensor of ones whose shape is given by ``x``."""
    return tf.ones(x)


def sin(x):
    """Element-wise sine."""
    return tf.sin(x)


def cos(x):
    """Element-wise cosine."""
    return tf.cos(x)


def cosh(x):
    """Element-wise hyperbolic cosine."""
    return tf.cosh(x)


def sinh(x):
    """Element-wise hyperbolic sine."""
    return tf.sinh(x)


def tanh(x):
    """Element-wise hyperbolic tangent."""
    return tf.tanh(x)


def arccosh(x):
    """Element-wise inverse hyperbolic cosine."""
    return tf.acosh(x)


def tan(x):
    """Element-wise tangent."""
    return tf.tan(x)


def arcsin(x):
    """Element-wise inverse sine."""
    return tf.asin(x)


def arccos(x):
    """Element-wise inverse cosine."""
    return tf.acos(x)


def shape(x):
    """Dynamic shape of ``x`` as a tensor (``tf.shape``), not a tuple."""
    return tf.shape(x)


def ndim(x):
    """Return the static rank of ``x``, or ``None`` if it is unknown.

    ``x`` is first converted to a tensor with this module's ``array``.
    """
    x = array(x)
    # Public ``TensorShape.ndims`` instead of the private ``_dims``
    # attribute: it is ``None`` exactly when the rank is unknown and
    # ``len(dims)`` otherwise, so the result is unchanged.
    return x.get_shape().ndims


def dot(x, y):
    """Contract the last axis of ``x`` with the first axis of ``y``."""
    return tf.tensordot(x, y, axes=1)


def maximum(x, y):
    """Element-wise maximum of ``x`` and ``y``."""
    return tf.maximum(x, y)


def greater(x, y):
    """Element-wise ``x > y``."""
    return tf.greater(x, y)
| 197 | 198 | def greater_equal(x, y): 199 | return tf.greater_equal(x, y) 200 | 201 | 202 | def equal(x, y): 203 | return tf.equal(x, y) 204 | 205 | 206 | def to_ndarray(x, to_ndim, axis=0): 207 | if ndim(x) == to_ndim - 1: 208 | x = tf.expand_dims(x, axis=axis) 209 | 210 | return x 211 | 212 | 213 | def sqrt(x): 214 | return tf.sqrt(x) 215 | 216 | 217 | def isclose(x, y, rtol=1e-05, atol=1e-08): 218 | rhs = tf.constant(atol) + tf.constant(rtol) * tf.abs(y) 219 | return tf.less_equal(tf.abs(tf.subtract(x, y)), rhs) 220 | 221 | 222 | def allclose(x, y, rtol=1e-05, atol=1e-08): 223 | return tf.reduce_all(isclose(x, y, rtol=rtol, atol=atol)) 224 | 225 | 226 | def less(x, y): 227 | return tf.less(x, y) 228 | 229 | 230 | def less_equal(x, y): 231 | return tf.less_equal(x, y) 232 | 233 | 234 | def eye(n, m=None): 235 | if m is None: 236 | m = n 237 | n = cast(n, dtype=int32) 238 | m = cast(m, dtype=int32) 239 | return tf.eye(num_rows=n, num_columns=m) 240 | 241 | 242 | def matmul(x, y): 243 | return tf.matmul(x, y) 244 | 245 | 246 | def argmax(*args, **kwargs): 247 | return tf.argmax(*args, **kwargs) 248 | 249 | 250 | def sum(*args, **kwargs): 251 | return tf.reduce_sum(*args, **kwargs) 252 | 253 | 254 | def einsum(equation, *inputs, **kwargs): 255 | return tf.einsum(equation, *inputs, **kwargs) 256 | 257 | 258 | def transpose(x, axes=None): 259 | return tf.transpose(x, perm=axes) 260 | 261 | 262 | def squeeze(x, **kwargs): 263 | return tf.squeeze(x, **kwargs) 264 | 265 | 266 | def zeros_like(x): 267 | return tf.zeros_like(x) 268 | 269 | 270 | def ones_like(x): 271 | return tf.ones_like(x) 272 | 273 | 274 | def trace(x, **kwargs): 275 | return tf.trace(x) 276 | 277 | 278 | def array(x): 279 | return tf.convert_to_tensor(x) 280 | 281 | 282 | def all(bool_tensor, axis=None, keepdims=False): 283 | bool_tensor = tf.cast(bool_tensor, tf.bool) 284 | all_true = tf.reduce_all(bool_tensor, axis, keepdims) 285 | return all_true 286 | 287 | 288 | def concatenate(*args, **kwargs): 
289 | return tf.concat(*args, **kwargs) 290 | 291 | 292 | def asarray(x): 293 | return x 294 | 295 | 296 | def expand_dims(x, axis=None): 297 | return tf.expand_dims(x, axis) 298 | 299 | 300 | def clip(x, min_value, max_value): 301 | return tf.clip_by_value(x, min_value, max_value) 302 | 303 | 304 | def floor(x): 305 | return tf.floor(x) 306 | 307 | 308 | def diag(a): 309 | return tf.map_fn( 310 | lambda x: tf.diag(x), 311 | a) 312 | 313 | 314 | def cross(a, b): 315 | return tf.cross(a, b) 316 | 317 | 318 | def stack(*args, **kwargs): 319 | return tf.stack(*args, **kwargs) 320 | 321 | 322 | def arctan2(*args, **kwargs): 323 | return tf.atan2(*args, **kwargs) 324 | 325 | 326 | def diagonal(*args, **kwargs): 327 | return tf.linalg.diag_part(*args) 328 | 329 | 330 | def mean(x, axis=None): 331 | return tf.reduce_mean(x, axis) 332 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Geomstats 2 | 3 | Welcome to the geomstats repository! 4 | We are excited you are here and want to contribute. 5 | 6 | More details on contributing can be found on the [documentation website][link_doc]. 7 | 8 | ## Practical Guide to Submitting your Contribution 9 | 10 | These guidelines are designed to make it as easy as possible to get involved. 11 | If you have any questions that are not discussed below, 12 | please let us know by opening an [issue][link_issues]! 13 | 14 | Before you start you will need to set up a free [GitHub][link_github] account and sign in. 15 | Here are some [instructions][link_signupinstructions]. 16 | 17 | 18 | ## Joining the Conversation 19 | 20 | `geomstats` is maintained by a growing group of enthusiastic developers! 21 | Most of our discussions take place on [issues][link_issues]. 22 | 23 | 24 | ## Contributing through GitHub 25 | 26 | [git][link_git] is a really useful tool for version control. 
27 | [GitHub][link_github] sits on top of git and supports collaborative and distributed working. 28 | 29 | If you are not yet familiar with `git`, there are lots of great resources to help you *git* started! 30 | Some of our favorites include the [git Handbook][link_handbook] and 31 | the [Software Carpentry introduction to git][link_swc_intro]. 32 | 33 | On GitHub, you will use [Markdown][markdown] to chat in issues and pull requests. 34 | 35 | GitHub has a really helpful page for getting started with 36 | [writing and formatting Markdown on GitHub][writing_formatting_github]. 37 | 38 | 39 | ## Understanding Issues 40 | 41 | Every project on GitHub uses [issues][link_issues] slightly differently. 42 | 43 | The following outlines how ``geomstats`` developers think about these tools. 44 | 45 | * **Issues** are individual pieces of work that need to be completed to move the project forwards. 46 | A general guideline: if you find yourself tempted to write a great big issue that 47 | is difficult to describe as one unit of work, please consider splitting it into two or more issues. 48 | 49 | * Issues are assigned [labels][link_labels] which explain how they relate to the overall project's 50 | goals and immediate next steps. 51 | 52 | 53 | ## Making a Change 54 | 55 | We appreciate all contributions to ``geomstats``, 56 | but those accepted fastest will follow a workflow similar to the following: 57 | 58 | **1. Comment on an existing issue or open a new issue referencing your addition.** 59 | 60 | This allows other members of the ``geomstats`` development team to confirm that you are not 61 | overlapping with work that is currently underway and that everyone is on the same page 62 | with the goal of the work you are going to carry out. 63 | 64 | [This blog][link_pushpullblog] is a nice explanation of why putting this work in up front 65 | is so useful to everyone involved. 66 | 67 | **2.
[Fork][link_fork] the [geomstats repository][link_geomstats] to your profile.** 68 | 69 | This is now your own unique copy of ``geomstats``. 70 | Changes here will not affect anyone else's work, so it is a safe space to explore edits to the code! 71 | 72 | Make sure to [keep your fork up to date][link_updateupstreamwiki] with the master repository. 73 | 74 | **3. Make the changes you have discussed, following the geomstats coding style guide (see step 5).** 75 | 76 | - Create your feature branch (`git checkout -b feature/fooBar`) 77 | - Commit your changes (`git commit -am 'Add some fooBar'`) 78 | - Push to the branch (`git push origin feature/fooBar`) 79 | 80 | Try to keep the changes focused. 81 | If you feel tempted to "branch out" then please make a [new branch][link_branches]. 82 | 83 | **4. Add the corresponding unit tests.** 84 | 85 | If you are adding a new feature, do not forget to add the corresponding [unit tests][link_unit_tests]. 86 | As ``geomstats`` supports both the numpy and tensorflow backends, your unit tests should run on both. 87 | 88 | **5. Follow the Geomstats Coding Style Guide.** 89 | 90 | Ensure that your code is compliant with [PEP8][link_pep8], 91 | the coding style guide for Python. 92 | 93 | Use [flake8][link_flake8] to check and [yapf][link_yapf] to automatically enforce this style. 94 | 95 | 96 | **6. Submit a [pull request][link_pullrequest].** 97 | 98 | A member of the development team will review your changes to confirm 99 | that they can be merged into the main code base.
100 | 101 | Use the GitHub labels to label your pull request: 102 | * ``enhancement``: enhancements or new features 103 | * ``bug``: bug fixes 104 | * ``test``: new or updated tests 105 | * ``documentation``: new or updated documentation 106 | * ``style``: style changes 107 | * ``refactoring``: refactoring existing code 108 | * ``continuous integration``: updates to continuous integration infrastructure 109 | 110 | 111 | ## Recognizing Contributions 112 | 113 | We welcome and recognize all contributions, from documentation to testing to code development. 114 | You can see a list of current contributors in our [README.md][link_readme]. 115 | If you are new to the project, do not forget to add your name and affiliation there! 116 | 117 | ## Thank You! 118 | 119 |
120 | 121 | *— Based on contributing guidelines from the [fMRIprep][link_fmriprep] project.* 122 | 123 | [link_github]: https://github.com/ 124 | [link_geomstats]: https://github.com/geomstats/geomstats 125 | [link_signupinstructions]: https://help.github.com/articles/signing-up-for-a-new-github-account 126 | 127 | [link_git]: https://git-scm.com/ 128 | [link_handbook]: https://guides.github.com/introduction/git-handbook/ 129 | [link_swc_intro]: http://swcarpentry.github.io/git-novice/ 130 | 131 | [writing_formatting_github]: https://help.github.com/articles/getting-started-with-writing-and-formatting-on-github 132 | [markdown]: https://daringfireball.net/projects/markdown 133 | [rick_roll]: https://www.youtube.com/watch?v=dQw4w9WgXcQ 134 | 135 | [link_issues]: https://github.com/geomstats/geomstats/issues 136 | [link_labels]: https://github.com/geomstats/geomstats/labels 137 | [link_discussingissues]: https://help.github.com/articles/discussing-projects-in-issues-and-pull-requests 138 | 139 | [link_pullrequest]: https://help.github.com/articles/creating-a-pull-request/ 140 | [link_fork]: https://help.github.com/articles/fork-a-repo/ 141 | [link_pushpullblog]: https://www.igvita.com/2011/12/19/dont-push-your-pull-requests/ 142 | [link_branches]: https://help.github.com/articles/creating-and-deleting-branches-within-your-repository/ 143 | [link_unit_tests]: https://github.com/geomstats/geomstats/tree/master/tests 144 | [link_updateupstreamwiki]: https://help.github.com/articles/syncing-a-fork/ 145 | [link_pep8]: https://www.python.org/dev/peps/pep-0008/ 146 | [link_linters]: https://en.wikipedia.org/wiki/Lint_(software) 147 | [link_flake8]: http://flake8.pycqa.org/en/latest/ 148 | [link_yapf]: https://github.com/google/yapf 149 | [link_readme]: https://github.com/geomstats/geomstats/blob/master/README.md 150 | [link_fmriprep]: https://github.com/poldracklab/fmriprep/ 151 | [link_doc]: https://geomstats.github.io/contributing.html 152 | 
-------------------------------------------------------------------------------- /geomstats/learning/_template.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a module to be used as a reference for building other modules 3 | """ 4 | import numpy as np 5 | from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin 6 | from sklearn.utils.validation import check_X_y, check_array, check_is_fitted 7 | from sklearn.utils.multiclass import unique_labels 8 | from sklearn.metrics import euclidean_distances 9 | 10 | 11 | class TemplateEstimator(BaseEstimator): 12 | """ A template estimator to be used as a reference implementation. 13 | 14 | For more information regarding how to build your own estimator, read more 15 | in the :ref:`User Guide `. 16 | 17 | Parameters 18 | ---------- 19 | demo_param : str, default='demo_param' 20 | A parameter used for demonstation of how to pass and store paramters. 21 | """ 22 | def __init__(self, demo_param='demo_param'): 23 | self.demo_param = demo_param 24 | 25 | def fit(self, X, y): 26 | """A reference implementation of a fitting function. 27 | 28 | Parameters 29 | ---------- 30 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 31 | The training input samples. 32 | y : array-like, shape (n_samples,) or (n_samples, n_outputs) 33 | The target values (class labels in classification, real numbers in 34 | regression). 35 | 36 | Returns 37 | ------- 38 | self : object 39 | Returns self. 40 | """ 41 | X, y = check_X_y(X, y, accept_sparse=True) 42 | self.is_fitted_ = True 43 | # `fit` should always return `self` 44 | return self 45 | 46 | def predict(self, X): 47 | """ A reference implementation of a predicting function. 48 | 49 | Parameters 50 | ---------- 51 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 52 | The training input samples. 53 | 54 | Returns 55 | ------- 56 | y : ndarray, shape (n_samples,) 57 | Returns an array of ones. 
58 | """ 59 | X = check_array(X, accept_sparse=True) 60 | check_is_fitted(self, 'is_fitted_') 61 | return np.ones(X.shape[0], dtype=np.int64) 62 | 63 | 64 | class TemplateClassifier(BaseEstimator, ClassifierMixin): 65 | """ An example classifier which implements a 1-NN algorithm. 66 | 67 | For more information regarding how to build your own classifier, read more 68 | in the :ref:`User Guide `. 69 | 70 | Parameters 71 | ---------- 72 | demo_param : str, default='demo' 73 | A parameter used for demonstation of how to pass and store paramters. 74 | 75 | Attributes 76 | ---------- 77 | X_ : ndarray, shape (n_samples, n_features) 78 | The input passed during :meth:`fit`. 79 | y_ : ndarray, shape (n_samples,) 80 | The labels passed during :meth:`fit`. 81 | classes_ : ndarray, shape (n_classes,) 82 | The classes seen at :meth:`fit`. 83 | """ 84 | def __init__(self, demo_param='demo'): 85 | self.demo_param = demo_param 86 | 87 | def fit(self, X, y): 88 | """A reference implementation of a fitting function for a classifier. 89 | 90 | Parameters 91 | ---------- 92 | X : array-like, shape (n_samples, n_features) 93 | The training input samples. 94 | y : array-like, shape (n_samples,) 95 | The target values. An array of int. 96 | 97 | Returns 98 | ------- 99 | self : object 100 | Returns self. 101 | """ 102 | # Check that X and y have correct shape 103 | X, y = check_X_y(X, y) 104 | # Store the classes seen during fit 105 | self.classes_ = unique_labels(y) 106 | 107 | self.X_ = X 108 | self.y_ = y 109 | # Return the classifier 110 | return self 111 | 112 | def predict(self, X): 113 | """ A reference implementation of a prediction for a classifier. 114 | 115 | Parameters 116 | ---------- 117 | X : array-like, shape (n_samples, n_features) 118 | The input samples. 119 | 120 | Returns 121 | ------- 122 | y : ndarray, shape (n_samples,) 123 | The label for each sample is the label of the closest sample 124 | seen during fit. 
125 | """ 126 | # Check is fit had been called 127 | check_is_fitted(self, ['X_', 'y_']) 128 | 129 | # Input validation 130 | X = check_array(X) 131 | 132 | closest = np.argmin(euclidean_distances(X, self.X_), axis=1) 133 | return self.y_[closest] 134 | 135 | 136 | class TemplateTransformer(BaseEstimator, TransformerMixin): 137 | """ An example transformer that returns the element-wise square root. 138 | 139 | For more information regarding how to build your own transformer, read more 140 | in the :ref:`User Guide `. 141 | 142 | Parameters 143 | ---------- 144 | demo_param : str, default='demo' 145 | A parameter used for demonstation of how to pass and store paramters. 146 | 147 | Attributes 148 | ---------- 149 | n_features_ : int 150 | The number of features of the data passed to :meth:`fit`. 151 | """ 152 | def __init__(self, demo_param='demo'): 153 | self.demo_param = demo_param 154 | 155 | def fit(self, X, y=None): 156 | """A reference implementation of a fitting function for a transformer. 157 | 158 | Parameters 159 | ---------- 160 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 161 | The training input samples. 162 | y : None 163 | There is no need of a target in a transformer, yet the pipeline API 164 | requires this parameter. 165 | 166 | Returns 167 | ------- 168 | self : object 169 | Returns self. 170 | """ 171 | X = check_array(X, accept_sparse=True) 172 | 173 | self.n_features_ = X.shape[1] 174 | 175 | # Return the transformer 176 | return self 177 | 178 | def transform(self, X): 179 | """ A reference implementation of a transform function. 180 | 181 | Parameters 182 | ---------- 183 | X : {array-like, sparse-matrix}, shape (n_samples, n_features) 184 | The input samples. 185 | 186 | Returns 187 | ------- 188 | X_transformed : array, shape (n_samples, n_features) 189 | The array containing the element-wise square roots of the values 190 | in ``X``. 
191 | """ 192 | # Check is fit had been called 193 | check_is_fitted(self, 'n_features_') 194 | 195 | # Input validation 196 | X = check_array(X, accept_sparse=True) 197 | 198 | # Check that the input is of the same shape as the one passed 199 | # during fit. 200 | if X.shape[1] != self.n_features_: 201 | raise ValueError('Shape of input is different from what was seen' 202 | 'in `fit`') 203 | return np.sqrt(X) 204 | -------------------------------------------------------------------------------- /geomstats/learning/mean_shift.py: -------------------------------------------------------------------------------- 1 | """Mean shift clustering algorithm on Manifolds. 2 | """ 3 | 4 | 5 | from sklearn.base import BaseEstimator, ClusterMixin 6 | from sklearn.metrics.pairwise import pairwise_distances_argmin 7 | from sklearn.utils import check_array 8 | from sklearn.utils.validation import check_is_fitted 9 | 10 | 11 | def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, 12 | min_bin_freq=1, cluster_all=True, max_iter=300, 13 | n_jobs=None): 14 | """Perform mean shift clustering of data. 15 | 16 | Parameters 17 | ---------- 18 | 19 | X : array-like, shape=[n_samples, n_features] 20 | Input data. 21 | 22 | bandwidth : float, optional 23 | Kernel bandwidth. 24 | 25 | If bandwidth is not given, it is determined using a heuristic based on 26 | the median of all pairwise distances. This will take quadratic time in 27 | the number of samples. The sklearn.cluster.estimate_bandwidth function 28 | can be used to do this more efficiently. 29 | 30 | seeds : array-like, shape=[n_seeds, n_features] or None 31 | Point used as initial kernel locations. If None and bin_seeding=False, 32 | each data point is used as a seed. If None and bin_seeding=True, 33 | see bin_seeding. 
34 | 35 | bin_seeding : boolean, default=False 36 | If true, initial kernel locations are not locations of all 37 | points, but rather the location of the discretized version of 38 | points, where points are binned onto a grid whose coarseness 39 | corresponds to the bandwidth. Setting this option to True will speed 40 | up the algorithm because fewer seeds will be initialized. 41 | Ignored if seeds argument is not None. 42 | 43 | min_bin_freq : int, default=1 44 | To speed up the algorithm, accept only those bins with at least 45 | min_bin_freq points as seeds. 46 | 47 | cluster_all : boolean, default True 48 | If true, then all points are clustered, even those orphans that are 49 | not within any kernel. Orphans are assigned to the nearest kernel. 50 | If false, then orphans are given cluster label -1. 51 | 52 | max_iter : int, default 300 53 | Maximum number of iterations, per seed point before the clustering 54 | operation terminates (for that seed point), if has not converged yet. 55 | 56 | n_jobs : int or None, optional (default=None) 57 | The number of jobs to use for the computation. This works by computing 58 | each of the n_init runs in parallel. 59 | 60 | ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. 61 | ``-1`` means using all processors. See :term:`Glossary ` 62 | for more details. 63 | 64 | .. versionadded:: 0.17 65 | Parallel Execution using *n_jobs*. 66 | 67 | Returns 68 | ------- 69 | 70 | cluster_centers : array, shape=[n_clusters, n_features] 71 | Coordinates of cluster centers. 72 | 73 | labels : array, shape=[n_samples] 74 | Cluster labels for each point. 75 | 76 | """ 77 | raise NotImplementedError() 78 | 79 | 80 | class MeanShift(BaseEstimator, ClusterMixin): 81 | """Mean shift clustering using a flat kernel. 82 | 83 | Mean shift clustering aims to discover "blobs" in a smooth density of 84 | samples. 
It is a centroid-based algorithm, which works by updating 85 | candidates for centroids to be the mean of the points within a given 86 | region. These candidates are then filtered in a post-processing stage to 87 | eliminate near-duplicates to form the final set of centroids. 88 | 89 | Seeding is performed using a binning technique for scalability. 90 | 91 | Parameters 92 | ---------- 93 | bandwidth : float, optional 94 | Bandwidth used in the RBF kernel. 95 | 96 | If not given, the bandwidth is estimated using 97 | sklearn.cluster.estimate_bandwidth; see the documentation for that 98 | function for hints on scalability (see also the Notes, below). 99 | 100 | seeds : array, shape=[n_samples, n_features], optional 101 | Seeds used to initialize kernels. If not set, 102 | the seeds are calculated by clustering.get_bin_seeds 103 | with bandwidth as the grid size and default values for 104 | other parameters. 105 | 106 | bin_seeding : boolean, optional 107 | If true, initial kernel locations are not locations of all 108 | points, but rather the location of the discretized version of 109 | points, where points are binned onto a grid whose coarseness 110 | corresponds to the bandwidth. Setting this option to True will speed 111 | up the algorithm because fewer seeds will be initialized. 112 | default value: False 113 | Ignored if seeds argument is not None. 114 | 115 | min_bin_freq : int, optional 116 | To speed up the algorithm, accept only those bins with at least 117 | min_bin_freq points as seeds. If not defined, set to 1. 118 | 119 | cluster_all : boolean, default True 120 | If true, then all points are clustered, even those orphans that are 121 | not within any kernel. Orphans are assigned to the nearest kernel. 122 | If false, then orphans are given cluster label -1. 123 | 124 | n_jobs : int or None, optional (default=None) 125 | The number of jobs to use for the computation. This works by computing 126 | each of the n_init runs in parallel. 
127 | 128 | ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. 129 | ``-1`` means using all processors. See :term:`Glossary ` 130 | for more details. 131 | 132 | Attributes 133 | ---------- 134 | cluster_centers_ : array, [n_clusters, n_features] 135 | Coordinates of cluster centers. 136 | 137 | labels_ : 138 | Labels of each point. 139 | """ 140 | def __init__(self, bandwidth=None, seeds=None, bin_seeding=False, 141 | min_bin_freq=1, cluster_all=True, n_jobs=None): 142 | self.bandwidth = bandwidth 143 | self.seeds = seeds 144 | self.bin_seeding = bin_seeding 145 | self.cluster_all = cluster_all 146 | self.min_bin_freq = min_bin_freq 147 | self.n_jobs = n_jobs 148 | 149 | def fit(self, X, y=None): 150 | """Perform clustering. 151 | 152 | Parameters 153 | ---------- 154 | X : array-like, shape=[n_samples, n_features] 155 | Samples to cluster. 156 | 157 | y : Ignored 158 | 159 | """ 160 | X = check_array(X) 161 | self.cluster_centers_, self.labels_ = \ 162 | mean_shift(X, bandwidth=self.bandwidth, seeds=self.seeds, 163 | min_bin_freq=self.min_bin_freq, 164 | bin_seeding=self.bin_seeding, 165 | cluster_all=self.cluster_all, n_jobs=self.n_jobs) 166 | return self 167 | 168 | def predict(self, X): 169 | """Predict the closest cluster each sample in X belongs to. 170 | 171 | Parameters 172 | ---------- 173 | X : {array-like, sparse matrix}, shape=[n_samples, n_features] 174 | New data to predict. 175 | 176 | Returns 177 | ------- 178 | labels : array, shape [n_samples,] 179 | Index of the cluster each sample belongs to. 
180 | """ 181 | check_is_fitted(self, "cluster_centers_") 182 | 183 | return pairwise_distances_argmin(X, self.cluster_centers_) 184 | -------------------------------------------------------------------------------- /geomstats/backend/numpy.py: -------------------------------------------------------------------------------- 1 | """Numpy based computation backend.""" 2 | 3 | import numpy as np 4 | 5 | 6 | int32 = np.int32 7 | int8 = np.int8 8 | float32 = np.float32 9 | float64 = np.float64 10 | 11 | 12 | def while_loop(cond, body, loop_vars, maximum_iterations): 13 | iteration = 0 14 | while cond(*loop_vars): 15 | loop_vars = body(*loop_vars) 16 | iteration += 1 17 | if iteration >= maximum_iterations: 18 | break 19 | return loop_vars 20 | 21 | 22 | def logical_or(x, y): 23 | bool_result = x or y 24 | return bool_result 25 | 26 | 27 | def get_mask_i_float(i, n): 28 | range_n = arange(n) 29 | i_float = cast(array([i]), int32)[0] 30 | mask_i = equal(range_n, i_float) 31 | mask_i_float = cast(mask_i, float32) 32 | return mask_i_float 33 | 34 | 35 | def gather(x, indices): 36 | return x[indices] 37 | 38 | 39 | def vectorize(x, pyfunc, multiple_args=False, signature=None, **kwargs): 40 | if multiple_args: 41 | return np.vectorize(pyfunc, signature=signature)(*x) 42 | return np.vectorize(pyfunc, signature=signature)(x) 43 | 44 | 45 | def cond(pred, true_fn, false_fn): 46 | if pred: 47 | return true_fn() 48 | return false_fn() 49 | 50 | 51 | def real(x): 52 | return np.real(x) 53 | 54 | 55 | def reshape(*args, **kwargs): 56 | return np.reshape(*args, **kwargs) 57 | 58 | 59 | def cast_to_complex(x): 60 | return np.vectorize(complex)(x) 61 | 62 | 63 | def boolean_mask(x, mask): 64 | return x[mask] 65 | 66 | 67 | def flip(*args, **kwargs): 68 | return np.flip(*args, **kwargs) 69 | 70 | 71 | def amax(*args, **kwargs): 72 | return np.amax(*args, **kwargs) 73 | 74 | 75 | def arctan2(*args, **kwargs): 76 | return np.arctan2(*args, **kwargs) 77 | 78 | 79 | def cast(x, dtype): 80 | 
return x.astype(dtype) 81 | 82 | 83 | def divide(*args, **kwargs): 84 | return np.divide(*args, **kwargs) 85 | 86 | 87 | def repeat(*args, **kwargs): 88 | return np.repeat(*args, **kwargs) 89 | 90 | 91 | def asarray(*args, **kwargs): 92 | return np.asarray(*args, **kwargs) 93 | 94 | 95 | def concatenate(*args, **kwargs): 96 | return np.concatenate(*args, **kwargs) 97 | 98 | 99 | def identity(val): 100 | return np.identity(val) 101 | 102 | 103 | def hstack(val): 104 | return np.hstack(val) 105 | 106 | 107 | def stack(*args, **kwargs): 108 | return np.stack(*args, **kwargs) 109 | 110 | 111 | def vstack(val): 112 | return np.vstack(val) 113 | 114 | 115 | def array(val): 116 | return np.array(val) 117 | 118 | 119 | def abs(val): 120 | return np.abs(val) 121 | 122 | 123 | def zeros(val): 124 | return np.zeros(val) 125 | 126 | 127 | def ones(val): 128 | return np.ones(val) 129 | 130 | 131 | def ones_like(*args, **kwargs): 132 | return np.ones_like(*args, **kwargs) 133 | 134 | 135 | def empty_like(*args, **kwargs): 136 | return np.empty_like(*args, **kwargs) 137 | 138 | 139 | def all(*args, **kwargs): 140 | return np.all(*args, **kwargs) 141 | 142 | 143 | def allclose(a, b, **kwargs): 144 | return np.allclose(a, b, **kwargs) 145 | 146 | 147 | def sin(val): 148 | return np.sin(val) 149 | 150 | 151 | def cos(val): 152 | return np.cos(val) 153 | 154 | 155 | def cosh(*args, **kwargs): 156 | return np.cosh(*args, **kwargs) 157 | 158 | 159 | def sinh(*args, **kwargs): 160 | return np.sinh(*args, **kwargs) 161 | 162 | 163 | def tanh(*args, **kwargs): 164 | return np.tanh(*args, **kwargs) 165 | 166 | 167 | def arccosh(*args, **kwargs): 168 | return np.arccosh(*args, **kwargs) 169 | 170 | 171 | def tan(val): 172 | return np.tan(val) 173 | 174 | 175 | def arcsin(val): 176 | return np.arcsin(val) 177 | 178 | 179 | def arccos(val): 180 | return np.arccos(val) 181 | 182 | 183 | def shape(val): 184 | return val.shape 185 | 186 | 187 | def dot(a, b): 188 | return np.dot(a, b) 189 | 190 
def maximum(a, b):
    """Return the element-wise maximum of ``a`` and ``b``."""
    return np.maximum(a, b)


def greater(a, b):
    """Return the element-wise truth value of ``a > b``."""
    return np.greater(a, b)


def greater_equal(a, b):
    """Return the element-wise truth value of ``a >= b``."""
    return np.greater_equal(a, b)


def to_ndarray(x, to_ndim, axis=0):
    """Convert ``x`` to an array with at least ``to_ndim`` dimensions.

    When ``x`` is exactly one dimension short, a length-1 axis is inserted
    at position ``axis``; any other rank is left as is.

    Raises
    ------
    AssertionError
        If the result has fewer than ``to_ndim`` dimensions (only when
        assertions are enabled).
    """
    arr = np.asarray(x)
    one_dim_short = arr.ndim == to_ndim - 1
    if one_dim_short:
        arr = np.expand_dims(arr, axis=axis)
    assert arr.ndim >= to_ndim
    return arr


def sqrt(val):
    """Return the element-wise square root of ``val``."""
    return np.sqrt(val)


def norm(val, axis):
    """Return the vector norm of ``val`` along ``axis``."""
    return np.linalg.norm(val, axis=axis)


def rand(*args, **kwargs):
    """Draw uniform samples in [0, 1) with ``np.random.rand``."""
    return np.random.rand(*args, **kwargs)


def randint(*args, **kwargs):
    """Draw random integers with ``np.random.randint``."""
    return np.random.randint(*args, **kwargs)


def isclose(*args, **kwargs):
    """Element-wise approximate equality (``np.isclose``)."""
    return np.isclose(*args, **kwargs)


def less_equal(a, b):
    """Return the element-wise truth value of ``a <= b``."""
    return np.less_equal(a, b)


def less(a, b):
    """Return the element-wise truth value of ``a < b``."""
    return np.less(a, b)


def eye(*args, **kwargs):
    """Return a 2-D array with ones on a diagonal (``np.eye``)."""
    return np.eye(*args, **kwargs)


def average(*args, **kwargs):
    """Optionally weighted average (``np.average``)."""
    return np.average(*args, **kwargs)


def matmul(*args, **kwargs):
    """Matrix product (``np.matmul``)."""
    return np.matmul(*args, **kwargs)


def sum(*args, **kwargs):
    """Sum of array elements (``np.sum``)."""
    return np.sum(*args, **kwargs)


def einsum(*args, **kwargs):
    """Einstein summation (``np.einsum``)."""
    return np.einsum(*args, **kwargs)


def transpose(*args, **kwargs):
    """Permute array axes (``np.transpose``)."""
    return np.transpose(*args, **kwargs)


def squeeze(*args, **kwargs):
    """Remove length-1 axes (``np.squeeze``)."""
    return np.squeeze(*args, **kwargs)


def zeros_like(*args, **kwargs):
    """Array of zeros shaped like the input (``np.zeros_like``)."""
    return np.zeros_like(*args, **kwargs)


def trace(*args, **kwargs):
    """Sum along diagonals (``np.trace``)."""
    return np.trace(*args, **kwargs)


def mod(*args, **kwargs):
    """Element-wise remainder (``np.mod``)."""
    return np.mod(*args, **kwargs)


def linspace(*args, **kwargs):
    """Evenly spaced numbers over an interval (``np.linspace``)."""
    return np.linspace(*args, **kwargs)


def equal(*args, **kwargs):
    """Element-wise equality (``np.equal``)."""
    return np.equal(*args, **kwargs)
def floor(*args, **kwargs):
    return np.floor(*args, **kwargs)


def cross(*args, **kwargs):
    return np.cross(*args, **kwargs)


def triu_indices(*args, **kwargs):
    return np.triu_indices(*args, **kwargs)


def where(*args, **kwargs):
    return np.where(*args, **kwargs)


def tile(*args, **kwargs):
    return np.tile(*args, **kwargs)


def clip(*args, **kwargs):
    return np.clip(*args, **kwargs)


def diag(x):
    """Turn each row of ``x`` into a diagonal matrix.

    Parameters
    ----------
    x : array-like, shape=[n_diags, n] or [n]
        Each row holds the diagonal entries of one output matrix; a 1-D
        input is treated as a single row.

    Returns
    -------
    result : ndarray, shape=[n_diags, n, n], dtype float64
        ``result[i]`` is the diagonal matrix built from ``x[i]``.
    """
    # Same promotion as to_ndarray(x, to_ndim=2): add a leading axis to 1-D
    # input and require at least two dimensions.
    x = np.asarray(x)
    if x.ndim == 1:
        x = np.expand_dims(x, axis=0)
    assert x.ndim >= 2
    n_diags, n = x.shape
    # Write the diagonals in place with fancy indexing instead of building a
    # (n_diags * n) x (n_diags * n) block-diagonal matrix and slicing it.
    result = np.zeros((n_diags, n, n))
    diag_idx = np.arange(n)
    result[:, diag_idx, diag_idx] = x
    return result


def any(*args, **kwargs):
    return np.any(*args, **kwargs)


def expand_dims(*args, **kwargs):
    return np.expand_dims(*args, **kwargs)


def outer(*args, **kwargs):
    return np.outer(*args, **kwargs)


def hsplit(*args, **kwargs):
    return np.hsplit(*args, **kwargs)


def argmax(*args, **kwargs):
    return np.argmax(*args, **kwargs)


def argmin(*args, **kwargs):
    return np.argmin(*args, **kwargs)


def diagonal(*args, **kwargs):
    return np.diagonal(*args, **kwargs)


def exp(*args, **kwargs):
    return np.exp(*args, **kwargs)


def log(*args, **kwargs):
    return np.log(*args, **kwargs)


def cov(*args, **kwargs):
    return np.cov(*args, **kwargs)


def eval(x):
    # Identity: numpy values are already concrete.
    return x


def ndim(x):
    return x.ndim


def nonzero(x):
    return np.nonzero(x)


def copy(x):
    return np.copy(x)


def ix_(*args):
    return np.ix_(*args)
| def arange(*args, **kwargs): 386 | return np.arange(*args, **kwargs) 387 | 388 | 389 | def prod(x, axis=None): 390 | return np.prod(x, axis=axis) 391 | 392 | 393 | def sign(*args, **kwargs): 394 | return np.sign(*args, **kwargs) 395 | 396 | 397 | def mean(x, axis=None): 398 | return np.mean(x, axis) 399 | 400 | 401 | def normal(*args, **kwargs): 402 | return np.random.normal(*args, **kwargs) 403 | -------------------------------------------------------------------------------- /geomstats/learning/pca.py: -------------------------------------------------------------------------------- 1 | """ Principal Component Analysis on Manifolds 2 | """ 3 | 4 | from math import log 5 | import numbers 6 | 7 | import numpy as np 8 | from scipy import linalg 9 | from scipy.special import gammaln 10 | 11 | from sklearn.decomposition.base import _BasePCA 12 | from sklearn.utils.extmath import svd_flip 13 | from sklearn.utils.extmath import stable_cumsum 14 | from sklearn.utils.validation import check_array 15 | 16 | 17 | def _assess_dimension_(spectrum, rank, n_samples, n_features): 18 | """Compute the likelihood of a rank ``rank`` dataset 19 | 20 | The dataset is assumed to be embedded in gaussian noise of shape(n, 21 | dimf) having spectrum ``spectrum``. 22 | 23 | Parameters 24 | ---------- 25 | spectrum : array of shape (n) 26 | Data spectrum. 27 | rank : int 28 | Tested rank value. 29 | n_samples : int 30 | Number of samples. 31 | n_features : int 32 | Number of features. 33 | 34 | Returns 35 | ------- 36 | ll : float, 37 | The log-likelihood 38 | 39 | Notes 40 | ----- 41 | This implements the method of `Thomas P. Minka: 42 | Automatic Choice of Dimensionality for PCA. NIPS 2000: 598-604` 43 | """ 44 | if rank > len(spectrum): 45 | raise ValueError("The tested rank cannot exceed the rank of the" 46 | " dataset") 47 | 48 | pu = -rank * log(2.) 49 | for i in range(rank): 50 | pu += (gammaln((n_features - i) / 2.) - 51 | log(np.pi) * (n_features - i) / 2.) 
52 | 53 | pl = np.sum(np.log(spectrum[:rank])) 54 | pl = -pl * n_samples / 2. 55 | 56 | if rank == n_features: 57 | pv = 0 58 | v = 1 59 | else: 60 | v = np.sum(spectrum[rank:]) / (n_features - rank) 61 | pv = -np.log(v) * n_samples * (n_features - rank) / 2. 62 | 63 | m = n_features * rank - rank * (rank + 1.) / 2. 64 | pp = log(2. * np.pi) * (m + rank + 1.) / 2. 65 | 66 | pa = 0. 67 | spectrum_ = spectrum.copy() 68 | spectrum_[rank:n_features] = v 69 | for i in range(rank): 70 | for j in range(i + 1, len(spectrum)): 71 | pa += log((spectrum[i] - spectrum[j]) * 72 | (1. / spectrum_[j] - 1. / spectrum_[i])) + log(n_samples) 73 | 74 | ll = pu + pl + pv + pp - pa / 2. - rank * log(n_samples) / 2. 75 | 76 | return ll 77 | 78 | 79 | def _infer_dimension_(spectrum, n_samples, n_features): 80 | """Infers the dimension of a dataset of shape (n_samples, n_features) 81 | 82 | The dataset is described by its spectrum `spectrum`. 83 | """ 84 | n_spectrum = len(spectrum) 85 | ll = np.empty(n_spectrum) 86 | for rank in range(n_spectrum): 87 | ll[rank] = _assess_dimension_(spectrum, rank, n_samples, n_features) 88 | return ll.argmax() 89 | 90 | 91 | class TangentPCA(_BasePCA): 92 | """Tangent Principal component analysis (tPCA) 93 | 94 | Linear dimensionality reduction using 95 | Singular Value Decomposition of the 96 | Riemannian Log of the data at the tangent space 97 | of the mean. 98 | """ 99 | 100 | def __init__(self, metric, n_components=None, copy=True, 101 | whiten=False, tol=0.0, iterated_power='auto', 102 | random_state=None): 103 | self.metric = metric 104 | self.n_components = n_components 105 | self.copy = copy 106 | self.whiten = whiten 107 | self.tol = tol 108 | self.iterated_power = iterated_power 109 | self.random_state = random_state 110 | 111 | def fit(self, X, 112 | base_point=None, point_type='vector', y=None): 113 | """Fit the model with X. 
114 | 115 | Parameters 116 | ---------- 117 | X : array-like, shape (n_samples, n_features) 118 | Training data, where n_samples is the number of samples 119 | and n_features is the number of features. 120 | 121 | y : Ignored 122 | 123 | Returns 124 | ------- 125 | self : object 126 | Returns the instance itself. 127 | """ 128 | self._fit(X, base_point, point_type) 129 | return self 130 | 131 | def fit_transform(self, X, 132 | base_point=None, point_type='vector', 133 | y=None): 134 | """Fit the model with X and apply the dimensionality reduction on X. 135 | 136 | Parameters 137 | ---------- 138 | X : array-like, shape (n_samples, n_features) 139 | Training data, where n_samples is the number of samples 140 | and n_features is the number of features. 141 | 142 | y : Ignored 143 | 144 | Returns 145 | ------- 146 | X_new : array-like, shape (n_samples, n_components) 147 | 148 | """ 149 | U, S, V = self._fit(X, base_point, point_type) 150 | U = U[:, :self.n_components_] 151 | 152 | U *= S[:self.n_components_] 153 | 154 | return U 155 | 156 | def _fit(self, X, base_point=None, point_type='vector'): 157 | """Fit the model by computing full SVD on X""" 158 | if point_type == 'matrix': 159 | raise NotImplementedError( 160 | 'This is currently only implemented for vectors.') 161 | if base_point is None: 162 | base_point = self.metric.mean(X) 163 | 164 | tangent_vecs = self.metric.log(X, base_point=base_point) 165 | 166 | # Convert to sklearn format 167 | X = tangent_vecs 168 | 169 | X = check_array(X, dtype=[np.float64, np.float32], ensure_2d=True, 170 | copy=self.copy) 171 | 172 | # Handle n_components==None 173 | if self.n_components is None: 174 | n_components = min(X.shape) 175 | else: 176 | n_components = self.n_components 177 | n_samples, n_features = X.shape 178 | 179 | if n_components == 'mle': 180 | if n_samples < n_features: 181 | raise ValueError("n_components='mle' is only supported " 182 | "if n_samples >= n_features") 183 | elif not 0 <= n_components <= 
min(n_samples, n_features): 184 | raise ValueError("n_components=%r must be between 0 and " 185 | "min(n_samples, n_features)=%r with " 186 | "svd_solver='full'" 187 | % (n_components, min(n_samples, n_features))) 188 | elif n_components >= 1: 189 | if not isinstance(n_components, (numbers.Integral, np.integer)): 190 | raise ValueError("n_components=%r must be of type int " 191 | "when greater than or equal to 1, " 192 | "was of type=%r" 193 | % (n_components, type(n_components))) 194 | 195 | # Center data 196 | self.mean_ = np.mean(X, axis=0) 197 | X -= self.mean_ 198 | 199 | U, S, V = linalg.svd(X, full_matrices=False) 200 | # flip eigenvectors' sign to enforce deterministic output 201 | U, V = svd_flip(U, V) 202 | 203 | components_ = V 204 | 205 | # Get variance explained by singular values 206 | explained_variance_ = (S ** 2) / (n_samples - 1) 207 | total_var = explained_variance_.sum() 208 | explained_variance_ratio_ = explained_variance_ / total_var 209 | singular_values_ = S.copy() # Store the singular values. 210 | 211 | # Postprocess the number of components required 212 | if n_components == 'mle': 213 | n_components = \ 214 | _infer_dimension_(explained_variance_, n_samples, n_features) 215 | elif 0 < n_components < 1.0: 216 | # number of components for which the cumulated explained 217 | # variance percentage is superior to the desired threshold 218 | ratio_cumsum = stable_cumsum(explained_variance_ratio_) 219 | n_components = np.searchsorted(ratio_cumsum, n_components) + 1 220 | 221 | # Compute noise covariance using Probabilistic PCA model 222 | # The sigma2 maximum likelihood (cf. eq. 12.46) 223 | if n_components < min(n_features, n_samples): 224 | self.noise_variance_ = explained_variance_[n_components:].mean() 225 | else: 226 | self.noise_variance_ = 0. 
227 | 228 | self.n_samples_, self.n_features_ = n_samples, n_features 229 | self.components_ = components_[:n_components] 230 | self.n_components_ = n_components 231 | self.explained_variance_ = explained_variance_[:n_components] 232 | self.explained_variance_ratio_ = \ 233 | explained_variance_ratio_[:n_components] 234 | self.singular_values_ = singular_values_[:n_components] 235 | 236 | return U, S, V 237 | -------------------------------------------------------------------------------- /geomstats/backend/pytorch.py: -------------------------------------------------------------------------------- 1 | """Pytorch based computation backend.""" 2 | 3 | import numpy as np 4 | import torch 5 | 6 | double = 'torch.DoubleTensor' 7 | float16 = 'torch.Float' 8 | float32 = 'torch.FloatTensor' 9 | float64 = 'torch.DoubleTensor' 10 | int32 = 'torch.LongTensor' 11 | int8 = 'torch.ByteTensor' 12 | 13 | 14 | def cond(pred, true_fn, false_fn): 15 | if pred: 16 | return true_fn() 17 | return false_fn() 18 | 19 | 20 | def amax(x): 21 | return torch.max(x) 22 | 23 | 24 | def boolean_mask(x, mask): 25 | idx = np.argwhere(np.asarray(mask)) 26 | return x[idx] 27 | 28 | 29 | def arctan2(*args, **kwargs): 30 | return torch.arctan2(*args, **kwargs) 31 | 32 | 33 | def cast(x, dtype): 34 | x = array(x) 35 | return x.type(dtype) 36 | 37 | 38 | def divide(*args, **kwargs): 39 | return torch.div(*args, **kwargs) 40 | 41 | 42 | def repeat(*args, **kwargs): 43 | return torch.repeat(*args, **kwargs) 44 | 45 | 46 | def asarray(x): 47 | return np.asarray(x) 48 | 49 | 50 | def concatenate(seq, axis=0, out=None): 51 | seq = [t.float() for t in seq] 52 | return torch.cat(seq, dim=axis, out=out) 53 | 54 | 55 | def identity(val): 56 | return torch.eye(val) 57 | 58 | 59 | def hstack(seq): 60 | return concatenate(seq, axis=1) 61 | 62 | 63 | def stack(*args, **kwargs): 64 | return torch.stack(*args, **kwargs) 65 | 66 | 67 | def vstack(seq): 68 | return concatenate(seq) 69 | 70 | 71 | def array(val): 72 | 
if type(val) == list: 73 | if type(val[0]) != torch.Tensor: 74 | val = np.copy(np.array(val)) 75 | else: 76 | val = concatenate(val) 77 | 78 | if type(val) == bool: 79 | val = np.array(val) 80 | if type(val) == np.ndarray: 81 | if val.dtype == bool: 82 | val = torch.from_numpy(np.array(val, dtype=np.uint8)) 83 | elif val.dtype == np.float32 or val.dtype == np.float64: 84 | val = torch.from_numpy(np.array(val, dtype=np.float32)) 85 | else: 86 | val = torch.from_numpy(val) 87 | 88 | if type(val) != torch.Tensor: 89 | val = torch.Tensor([val]) 90 | if val.dtype == torch.float64: 91 | val = val.float() 92 | return val 93 | 94 | 95 | def abs(val): 96 | return torch.abs(val) 97 | 98 | 99 | def zeros(*args): 100 | return torch.from_numpy(np.zeros(*args)).float() 101 | 102 | 103 | def ones(*args): 104 | return torch.from_numpy(np.ones(*args)).float() 105 | 106 | 107 | def ones_like(*args, **kwargs): 108 | return torch.ones_like(*args, **kwargs) 109 | 110 | 111 | def empty_like(*args, **kwargs): 112 | return torch.empty_like(*args, **kwargs) 113 | 114 | 115 | def all(x, axis=None): 116 | if axis is None: 117 | return x.byte().all() 118 | return torch.from_numpy(np.all(x, axis=axis).astype(int)) 119 | 120 | 121 | def allclose(a, b, **kwargs): 122 | a = torch.tensor(a) 123 | b = torch.tensor(b) 124 | a = a.float() 125 | b = b.float() 126 | a = to_ndarray(a, to_ndim=1) 127 | b = to_ndarray(b, to_ndim=1) 128 | n_a = a.shape[0] 129 | n_b = b.shape[0] 130 | ndim = len(a.shape) 131 | if n_a > n_b: 132 | reps = (int(n_a / n_b),) + (ndim-1) * (1,) 133 | b = tile(b, reps) 134 | elif n_a < n_b: 135 | reps = (int(n_b / n_a),) + (ndim-1) * (1,) 136 | a = tile(a, reps) 137 | return torch.allclose(a, b, **kwargs) 138 | 139 | 140 | def sin(val): 141 | return torch.sin(val) 142 | 143 | 144 | def cos(val): 145 | return torch.cos(val) 146 | 147 | 148 | def cosh(*args, **kwargs): 149 | return torch.cosh(*args, **kwargs) 150 | 151 | 152 | def arccosh(x): 153 | c0 = torch.log(x) 154 | c1 = 
torch.log1p(torch.sqrt(x * x - 1) / x) 155 | return c0 + c1 156 | 157 | 158 | def sinh(*args, **kwargs): 159 | return torch.sinh(*args, **kwargs) 160 | 161 | 162 | def tanh(*args, **kwargs): 163 | return torch.tanh(*args, **kwargs) 164 | 165 | 166 | def arcsinh(x): 167 | return torch.log(x + torch.sqrt(x*x+1)) 168 | 169 | 170 | def arcosh(x): 171 | return torch.log(x + torch.sqrt(x*x-1)) 172 | 173 | 174 | def tan(val): 175 | return torch.tan(val) 176 | 177 | 178 | def arcsin(val): 179 | return torch.asin(val) 180 | 181 | 182 | def arccos(val): 183 | return torch.acos(val) 184 | 185 | 186 | def shape(val): 187 | return val.shape 188 | 189 | 190 | def dot(a, b): 191 | dot = np.dot(a, b) 192 | return torch.from_numpy(np.array(dot)).float() 193 | 194 | 195 | def maximum(a, b): 196 | return torch.max(array(a), array(b)) 197 | 198 | 199 | def greater(a, b): 200 | return torch.gt(a, b) 201 | 202 | 203 | def greater_equal(a, b): 204 | return torch.greater_equal(a, b) 205 | 206 | 207 | def to_ndarray(x, to_ndim, axis=0): 208 | x = array(x) 209 | if x.dim() == to_ndim - 1: 210 | x = torch.unsqueeze(x, dim=axis) 211 | assert x.dim() >= to_ndim 212 | return x 213 | 214 | 215 | def sqrt(val): 216 | return torch.sqrt(torch.tensor(val).float()) 217 | 218 | 219 | def norm(val, axis): 220 | return torch.linalg.norm(val, axis=axis) 221 | 222 | 223 | def rand(*args, **largs): 224 | return torch.random.rand(*args, **largs) 225 | 226 | 227 | def isclose(*args, **kwargs): 228 | return torch.from_numpy(np.isclose(*args, **kwargs).astype(int)).byte() 229 | 230 | 231 | def less(a, b): 232 | return torch.le(a, b) 233 | 234 | 235 | def less_equal(a, b): 236 | return np.less_equal(a, b) 237 | 238 | 239 | def eye(*args, **kwargs): 240 | return torch.eye(*args, **kwargs) 241 | 242 | 243 | def average(*args, **kwargs): 244 | return torch.average(*args, **kwargs) 245 | 246 | 247 | def matmul(*args, **kwargs): 248 | return torch.matmul(*args, **kwargs) 249 | 250 | 251 | def sum(x, axis=None, 
**kwargs): 252 | if axis is None: 253 | return torch.sum(x, **kwargs) 254 | return torch.sum(x, dim=axis, **kwargs) 255 | 256 | 257 | def einsum(*args, **kwargs): 258 | return torch.from_numpy(np.einsum(*args, **kwargs)).float() 259 | 260 | 261 | def T(x): 262 | return torch.t(x) 263 | 264 | 265 | def transpose(x, axes=None): 266 | if axes: 267 | return x.permute(axes) 268 | if len(shape(x)) == 1: 269 | return x 270 | return x.t() 271 | 272 | 273 | def squeeze(x, axis=None): 274 | return torch.squeeze(x, dim=axis) 275 | 276 | 277 | def zeros_like(*args, **kwargs): 278 | return torch.zeros_like(*args, **kwargs) 279 | 280 | 281 | def trace(*args, **kwargs): 282 | trace = np.trace(*args, **kwargs) 283 | return torch.from_numpy(np.array(trace)).float() 284 | 285 | 286 | def mod(*args, **kwargs): 287 | return torch.fmod(*args, **kwargs) 288 | 289 | 290 | def linspace(start, stop, num): 291 | return torch.linspace(start=start, end=stop, steps=num) 292 | 293 | 294 | def equal(a, b, **kwargs): 295 | if a.dtype == torch.ByteTensor: 296 | a = cast(a, torch.uint8).float() 297 | if b.dtype == torch.ByteTensor: 298 | b = cast(b, torch.uint8).float() 299 | return torch.equal(a, b, **kwargs) 300 | 301 | 302 | def floor(*args, **kwargs): 303 | return torch.floor(*args, **kwargs) 304 | 305 | 306 | def cross(x, y): 307 | return torch.from_numpy(np.cross(x, y)) 308 | 309 | 310 | def triu_indices(*args, **kwargs): 311 | return torch.triu_indices(*args, **kwargs) 312 | 313 | 314 | def where(*args, **kwargs): 315 | return torch.where(*args, **kwargs) 316 | 317 | 318 | def tile(x, y): 319 | # TODO(johmathe): Native tile implementation 320 | return array(np.tile(x, y)) 321 | 322 | 323 | def clip(x, amin, amax): 324 | return np.clip(x, amin, amax) 325 | 326 | 327 | def diag(*args, **kwargs): 328 | return torch.diag(*args, **kwargs) 329 | 330 | 331 | def any(x): 332 | return x.byte().any() 333 | 334 | 335 | def expand_dims(x, axis=0): 336 | return torch.unsqueeze(x, dim=axis) 337 | 338 | 
339 | def outer(*args, **kwargs): 340 | return torch.ger(*args, **kwargs) 341 | 342 | 343 | def hsplit(*args, **kwargs): 344 | return torch.hsplit(*args, **kwargs) 345 | 346 | 347 | def argmax(*args, **kwargs): 348 | return torch.argmax(*args, **kwargs) 349 | 350 | 351 | def diagonal(*args, **kwargs): 352 | return torch.diagonal(*args, **kwargs) 353 | 354 | 355 | def exp(*args, **kwargs): 356 | return torch.exp(*args, **kwargs) 357 | 358 | 359 | def log(*args, **kwargs): 360 | return torch.log(*args, **kwargs) 361 | 362 | 363 | def cov(*args, **kwargs): 364 | return torch.cov(*args, **kwargs) 365 | 366 | 367 | def eval(x): 368 | return x 369 | 370 | 371 | def ndim(x): 372 | return x.dim() 373 | 374 | 375 | def gt(*args, **kwargs): 376 | return torch.gt(*args, **kwargs) 377 | 378 | 379 | def eq(*args, **kwargs): 380 | return torch.eq(*args, **kwargs) 381 | 382 | 383 | def nonzero(*args, **kwargs): 384 | return torch.nonzero(*args, **kwargs) 385 | 386 | 387 | def copy(x): 388 | return x.clone() 389 | 390 | 391 | def seed(x): 392 | torch.manual_seed(x) 393 | 394 | 395 | def sign(*args, **kwargs): 396 | return torch.sign(*args, **kwargs) 397 | 398 | 399 | def mean(x, axis=None): 400 | if axis is None: 401 | return torch.mean(x) 402 | else: 403 | return np.mean(x, axis) 404 | -------------------------------------------------------------------------------- /geomstats/lie_group.py: -------------------------------------------------------------------------------- 1 | """Lie groups.""" 2 | 3 | 4 | import geomstats.backend as gs 5 | import geomstats.riemannian_metric as riemannian_metric 6 | 7 | from geomstats.invariant_metric import InvariantMetric 8 | from geomstats.manifold import Manifold 9 | 10 | 11 | def loss(y_pred, y_true, group, metric=None): 12 | """ 13 | Loss function given by a riemannian metric. 
14 | """ 15 | if metric is None: 16 | metric = group.left_invariant_metric 17 | loss = riemannian_metric.loss(y_pred, y_true, metric) 18 | return loss 19 | 20 | 21 | def grad(y_pred, y_true, group, metric=None): 22 | """ 23 | Closed-form for the gradient of the loss function. 24 | 25 | :return: tangent vector at point y_pred. 26 | """ 27 | if metric is None: 28 | metric = group.left_invariant_metric 29 | grad = riemannian_metric.grad(y_pred, y_true, metric) 30 | return grad 31 | 32 | 33 | class LieGroup(Manifold): 34 | """ Class for Lie groups.""" 35 | 36 | def __init__(self, dimension): 37 | assert dimension > 0 38 | Manifold.__init__(self, dimension) 39 | 40 | self.left_canonical_metric = InvariantMetric( 41 | group=self, 42 | inner_product_mat_at_identity=gs.eye(self.dimension), 43 | left_or_right='left') 44 | 45 | self.right_canonical_metric = InvariantMetric( 46 | group=self, 47 | inner_product_mat_at_identity=gs.eye(self.dimension), 48 | left_or_right='right') 49 | 50 | self.metrics = [] 51 | 52 | def get_identity(self, point_type=None): 53 | """ 54 | Get the identity of the group. 55 | """ 56 | raise NotImplementedError('The Lie group identity' 57 | ' is not implemented.') 58 | identity = property(get_identity) 59 | 60 | def compose(self, point_a, point_b, point_type=None): 61 | """ 62 | Composition of the Lie group. 63 | """ 64 | raise NotImplementedError('The Lie group composition' 65 | ' is not implemented.') 66 | 67 | def inverse(self, point, point_type=None): 68 | """ 69 | Inverse law of the Lie group. 70 | """ 71 | raise NotImplementedError('The Lie group inverse is not implemented.') 72 | 73 | def jacobian_translation( 74 | self, point, left_or_right='left', point_type=None): 75 | """ 76 | Compute the jacobian matrix of the differential 77 | of the left translation by the point. 
78 | """ 79 | raise NotImplementedError( 80 | 'The jacobian of the Lie group translation is not implemented.') 81 | 82 | def group_exp_from_identity(self, tangent_vec, point_type=None): 83 | """ 84 | Compute the group exponential 85 | of tangent vector tangent_vec from the identity. 86 | """ 87 | raise NotImplementedError( 88 | 'The group exponential from the identity is not implemented.') 89 | 90 | def group_exp_not_from_identity(self, tangent_vec, base_point, point_type): 91 | jacobian = self.jacobian_translation( 92 | point=base_point, 93 | left_or_right='left', 94 | point_type=point_type) 95 | 96 | if point_type == 'vector': 97 | tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2) 98 | inv_jacobian = gs.linalg.inv(jacobian) 99 | 100 | tangent_vec_at_id = gs.einsum('ni,nij->nj', 101 | tangent_vec, 102 | gs.transpose(inv_jacobian, 103 | axes=(0, 2, 1))) 104 | group_exp_from_identity = self.group_exp_from_identity( 105 | tangent_vec=tangent_vec_at_id, 106 | point_type=point_type) 107 | group_exp = self.compose(base_point, 108 | group_exp_from_identity, 109 | point_type=point_type) 110 | group_exp = self.regularize(group_exp, point_type=point_type) 111 | return group_exp 112 | 113 | elif point_type == 'matrix': 114 | tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=3) 115 | raise NotImplementedError() 116 | 117 | def group_exp(self, tangent_vec, base_point=None, point_type=None): 118 | """ 119 | Compute the group exponential at point base_point 120 | of tangent vector tangent_vec. 
121 | """ 122 | if point_type is None: 123 | point_type = self.default_point_type 124 | 125 | identity = self.get_identity(point_type=point_type) 126 | identity = self.regularize(identity, point_type=point_type) 127 | if base_point is None: 128 | base_point = identity 129 | base_point = self.regularize(base_point, point_type=point_type) 130 | 131 | if point_type == 'vector': 132 | tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2) 133 | base_point = gs.to_ndarray(base_point, to_ndim=2) 134 | if point_type == 'matrix': 135 | tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=3) 136 | base_point = gs.to_ndarray(base_point, to_ndim=3) 137 | 138 | n_tangent_vecs = tangent_vec.shape[0] 139 | n_base_points = base_point.shape[0] 140 | 141 | assert (tangent_vec.shape == base_point.shape 142 | or n_tangent_vecs == 1 143 | or n_base_points == 1) 144 | 145 | if n_tangent_vecs == 1: 146 | tangent_vec = gs.array([tangent_vec[0]] * n_base_points) 147 | 148 | if n_base_points == 1: 149 | base_point = gs.array([base_point[0]] * n_tangent_vecs) 150 | 151 | result = gs.cond( 152 | pred=gs.allclose(base_point, identity), 153 | true_fn=lambda: self.group_exp_from_identity( 154 | tangent_vec, point_type=point_type), 155 | false_fn=lambda: self.group_exp_not_from_identity( 156 | tangent_vec, base_point, point_type)) 157 | return result 158 | 159 | def group_log_from_identity(self, point, point_type=None): 160 | """ 161 | Compute the group logarithm 162 | of the point point from the identity. 
163 | """ 164 | raise NotImplementedError( 165 | 'The group logarithm from the identity is not implemented.') 166 | 167 | def group_log_not_from_identity(self, point, base_point, point_type): 168 | jacobian = self.jacobian_translation(point=base_point, 169 | left_or_right='left', 170 | point_type=point_type) 171 | point_near_id = self.compose( 172 | self.inverse(base_point), point, point_type=point_type) 173 | group_log_from_id = self.group_log_from_identity( 174 | point=point_near_id, 175 | point_type=point_type) 176 | 177 | group_log = gs.einsum('ni,nij->nj', 178 | group_log_from_id, 179 | gs.transpose(jacobian, axes=(0, 2, 1))) 180 | 181 | assert gs.ndim(group_log) == 2 182 | return group_log 183 | 184 | def group_log(self, point, base_point=None, point_type=None): 185 | """ 186 | Compute the group logarithm at point base_point 187 | of the point point. 188 | """ 189 | if point_type is None: 190 | point_type = self.default_point_type 191 | 192 | identity = self.get_identity(point_type=point_type) 193 | if base_point is None: 194 | base_point = identity 195 | 196 | if point_type == 'vector': 197 | point = gs.to_ndarray(point, to_ndim=2) 198 | base_point = gs.to_ndarray(base_point, to_ndim=2) 199 | if point_type == 'matrix': 200 | point = gs.to_ndarray(point, to_ndim=3) 201 | base_point = gs.to_ndarray(base_point, to_ndim=3) 202 | 203 | point = self.regularize(point, point_type=point_type) 204 | base_point = self.regularize(base_point, point_type=point_type) 205 | 206 | n_points = point.shape[0] 207 | n_base_points = base_point.shape[0] 208 | 209 | assert (point.shape == base_point.shape 210 | or n_points == 1 211 | or n_base_points == 1) 212 | 213 | if n_points == 1: 214 | point = gs.array([point[0]] * n_base_points) 215 | 216 | if n_base_points == 1: 217 | base_point = gs.array([base_point[0]] * n_points) 218 | 219 | result = gs.cond( 220 | pred=gs.allclose(base_point, identity), 221 | true_fn=lambda: self.group_log_from_identity( 222 | point, 
point_type=point_type), 223 | false_fn=lambda: self.group_log_not_from_identity( 224 | point, base_point, point_type)) 225 | 226 | return result 227 | 228 | def group_exponential_barycenter( 229 | self, points, weights=None, point_type=None): 230 | """ 231 | Compute the group exponential barycenter. 232 | """ 233 | raise NotImplementedError( 234 | 'The group exponential barycenter is not implemented.') 235 | 236 | def add_metric(self, metric): 237 | self.metrics.append(metric) 238 | -------------------------------------------------------------------------------- /tests/test_stiefel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for Stiefel manifolds. 3 | """ 4 | 5 | import warnings 6 | 7 | import geomstats.backend as gs 8 | import geomstats.tests 9 | import tests.helper as helper 10 | 11 | from geomstats.stiefel import Stiefel 12 | 13 | ATOL = 1e-6 14 | 15 | 16 | class TestStiefelMethods(geomstats.tests.TestCase): 17 | _multiprocess_can_split_ = True 18 | 19 | def setUp(self): 20 | """ 21 | Tangent vectors constructed following: 22 | http://noodle.med.yale.edu/hdtag/notes/steifel_notes.pdf 23 | """ 24 | warnings.filterwarnings('ignore') 25 | 26 | gs.random.seed(1234) 27 | 28 | self.p = 3 29 | self.n = 4 30 | self.space = Stiefel(self.n, self.p) 31 | self.n_samples = 10 32 | self.dimension = int( 33 | self.p * self.n - (self.p * (self.p + 1) / 2)) 34 | 35 | self.point_a = gs.array([ 36 | [1., 0., 0.], 37 | [0., 1., 0.], 38 | [0., 0., 1.], 39 | [0., 0., 0.]]) 40 | 41 | self.point_b = gs.array([ 42 | [1. / gs.sqrt(2.), 0., 0.], 43 | [0., 1., 0.], 44 | [0., 0., 1.], 45 | [1. 
/ gs.sqrt(2.), 0., 0.]]) 46 | 47 | point_perp = gs.array([ 48 | [0.], 49 | [0.], 50 | [0.], 51 | [1.]]) 52 | 53 | matrix_a_1 = gs.array([ 54 | [0., 2., -5.], 55 | [-2., 0., -1.], 56 | [5., 1., 0.]]) 57 | 58 | matrix_b_1 = gs.array([ 59 | [-2., 1., 4.]]) 60 | 61 | matrix_a_2 = gs.array([ 62 | [0., 2., -5.], 63 | [-2., 0., -1.], 64 | [5., 1., 0.]]) 65 | 66 | matrix_b_2 = gs.array([ 67 | [-2., 1., 4.]]) 68 | 69 | self.tangent_vector_1 = ( 70 | gs.matmul(self.point_a, matrix_a_1) 71 | + gs.matmul(point_perp, matrix_b_1)) 72 | 73 | self.tangent_vector_2 = ( 74 | gs.matmul(self.point_a, matrix_a_2) 75 | + gs.matmul(point_perp, matrix_b_2)) 76 | 77 | self.metric = self.space.canonical_metric 78 | 79 | @geomstats.tests.np_and_tf_only 80 | def test_belongs(self): 81 | point = self.space.random_uniform() 82 | belongs = self.space.belongs(point) 83 | 84 | self.assertAllClose(gs.shape(belongs), (1, 1)) 85 | 86 | @geomstats.tests.np_and_tf_only 87 | def test_random_uniform_and_belongs(self): 88 | point = self.space.random_uniform() 89 | result = self.space.belongs(point) 90 | expected = gs.array([[True]]) 91 | 92 | self.assertAllClose(result, expected) 93 | 94 | @geomstats.tests.np_and_tf_only 95 | def test_random_uniform(self): 96 | result = self.space.random_uniform() 97 | 98 | self.assertAllClose(gs.shape(result), (1, self.n, self.p)) 99 | 100 | @geomstats.tests.np_only 101 | def test_log_and_exp(self): 102 | """ 103 | Test that the riemannian exponential 104 | and the riemannian logarithm are inverse. 105 | 106 | Expect their composition to give the identity function. 
107 | """ 108 | # Riemannian Log then Riemannian Exp 109 | # General case 110 | base_point = self.point_a 111 | point = self.point_b 112 | 113 | log = self.metric.log(point=point, base_point=base_point) 114 | result = self.metric.exp(tangent_vec=log, base_point=base_point) 115 | expected = helper.to_matrix(point) 116 | 117 | self.assertAllClose(result, expected, atol=ATOL) 118 | 119 | @geomstats.tests.np_and_tf_only 120 | def test_exp_and_belongs(self): 121 | base_point = self.point_a 122 | tangent_vec = self.tangent_vector_1 123 | 124 | exp = self.metric.exp( 125 | tangent_vec=tangent_vec, 126 | base_point=base_point) 127 | result = self.space.belongs(exp) 128 | expected = gs.array([[True]]) 129 | self.assertAllClose(result, expected) 130 | 131 | @geomstats.tests.np_and_tf_only 132 | def test_exp_vectorization(self): 133 | n_samples = self.n_samples 134 | n = self.n 135 | p = self.p 136 | 137 | one_base_point = self.point_a 138 | n_base_points = gs.tile( 139 | gs.to_ndarray(self.point_a, to_ndim=3), 140 | (n_samples, 1, 1)) 141 | 142 | one_tangent_vec = self.tangent_vector_1 143 | result = self.metric.exp(one_tangent_vec, one_base_point) 144 | self.assertAllClose(gs.shape(result), (1, n, p)) 145 | 146 | n_tangent_vecs = gs.tile( 147 | gs.to_ndarray(self.tangent_vector_2, to_ndim=3), 148 | (n_samples, 1, 1)) 149 | 150 | result = self.metric.exp(n_tangent_vecs, one_base_point) 151 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 152 | 153 | result = self.metric.exp(one_tangent_vec, n_base_points) 154 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 155 | 156 | @geomstats.tests.np_and_tf_only 157 | def test_log_vectorization(self): 158 | n_samples = self.n_samples 159 | n = self.n 160 | p = self.p 161 | 162 | one_point = self.space.random_uniform() 163 | one_base_point = self.space.random_uniform() 164 | n_points = self.space.random_uniform(n_samples=n_samples) 165 | n_base_points = self.space.random_uniform(n_samples=n_samples) 166 | 167 | result 
= self.metric.log(one_point, one_base_point) 168 | self.assertAllClose(gs.shape(result), (1, n, p)) 169 | 170 | result = self.metric.log(n_points, one_base_point) 171 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 172 | 173 | result = self.metric.log(one_point, n_base_points) 174 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 175 | 176 | result = self.metric.log(n_points, n_base_points) 177 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 178 | 179 | @geomstats.tests.np_only 180 | def test_retractation_and_lifting(self): 181 | """ 182 | Test that the riemannian exponential 183 | and the riemannian logarithm are inverse. 184 | 185 | Expect their composition to give the identity function. 186 | """ 187 | # Riemannian Log then Riemannian Exp 188 | # General case 189 | base_point = self.point_a 190 | point = self.point_b 191 | tangent_vec = self.tangent_vector_1 192 | 193 | lifted = self.metric.lifting(point=point, base_point=base_point) 194 | result = self.metric.retraction( 195 | tangent_vec=lifted, base_point=base_point) 196 | expected = helper.to_matrix(point) 197 | 198 | self.assertAllClose(result, expected, atol=ATOL) 199 | 200 | retract = self.metric.retraction( 201 | tangent_vec=tangent_vec, base_point=base_point) 202 | result = self.metric.lifting(point=retract, base_point=base_point) 203 | expected = helper.to_matrix(tangent_vec) 204 | 205 | self.assertAllClose(result, expected, atol=ATOL) 206 | 207 | @geomstats.tests.np_only 208 | def test_lifting_vectorization(self): 209 | n_samples = self.n_samples 210 | n = self.n 211 | p = self.p 212 | 213 | one_point = self.point_a 214 | one_base_point = self.point_b 215 | n_points = gs.tile( 216 | gs.to_ndarray(self.point_a, to_ndim=3), 217 | (n_samples, 1, 1)) 218 | n_base_points = gs.tile( 219 | gs.to_ndarray(self.point_b, to_ndim=3), 220 | (n_samples, 1, 1)) 221 | 222 | result = self.metric.lifting(one_point, one_base_point) 223 | self.assertAllClose(gs.shape(result), (1, n, p)) 224 
| 225 | result = self.metric.lifting(n_points, one_base_point) 226 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 227 | 228 | result = self.metric.lifting(one_point, n_base_points) 229 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 230 | 231 | result = self.metric.lifting(n_points, n_base_points) 232 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 233 | 234 | @geomstats.tests.np_and_tf_only 235 | def test_retraction_vectorization(self): 236 | n_samples = self.n_samples 237 | n = self.n 238 | p = self.p 239 | 240 | one_point = self.point_a 241 | n_points = gs.tile( 242 | gs.to_ndarray(one_point, to_ndim=3), 243 | (n_samples, 1, 1)) 244 | one_tangent_vec = self.tangent_vector_1 245 | n_tangent_vecs = gs.tile( 246 | gs.to_ndarray(self.tangent_vector_2, to_ndim=3), 247 | (n_samples, 1, 1)) 248 | 249 | result = self.metric.retraction(one_tangent_vec, one_point) 250 | self.assertAllClose(gs.shape(result), (1, n, p)) 251 | 252 | result = self.metric.retraction(n_tangent_vecs, one_point) 253 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 254 | 255 | result = self.metric.retraction(one_tangent_vec, n_points) 256 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 257 | 258 | result = self.metric.retraction(n_tangent_vecs, n_points) 259 | self.assertAllClose(gs.shape(result), (n_samples, n, p)) 260 | 261 | def test_inner_product(self): 262 | base_point = self.point_a 263 | tangent_vector_1 = self.tangent_vector_1 264 | tangent_vector_2 = self.tangent_vector_2 265 | 266 | result = self.metric.inner_product( 267 | tangent_vector_1, 268 | tangent_vector_2, 269 | base_point=base_point) 270 | self.assertAllClose(gs.shape(result), (1, 1)) 271 | 272 | 273 | if __name__ == '__main__': 274 | geomstats.tests.main() 275 | -------------------------------------------------------------------------------- /geomstats/spd_matrices_space.py: -------------------------------------------------------------------------------- 1 | """ 2 | The 
manifold of symmetric positive definite (SPD) matrices. 3 | """ 4 | 5 | import geomstats.backend as gs 6 | 7 | from geomstats.embedded_manifold import EmbeddedManifold 8 | from geomstats.general_linear_group import GeneralLinearGroup 9 | from geomstats.riemannian_metric import RiemannianMetric 10 | 11 | EPSILON = 1e-6 12 | TOLERANCE = 1e-12 13 | 14 | 15 | class SPDMatricesSpace(EmbeddedManifold): 16 | """ 17 | Class for the manifold of symmetric positive definite (SPD) matrices. 18 | """ 19 | def __init__(self, n): 20 | assert isinstance(n, int) and n > 0 21 | super(SPDMatricesSpace, self).__init__( 22 | dimension=int(n * (n + 1) / 2), 23 | embedding_manifold=GeneralLinearGroup(n=n)) 24 | self.n = n 25 | self.metric = SPDMetric(n=n) 26 | 27 | def belongs(self, mat, tolerance=TOLERANCE): 28 | """ 29 | Check if a matrix belongs to the manifold of 30 | symmetric positive definite matrices. 31 | """ 32 | mat = gs.to_ndarray(mat, to_ndim=3) 33 | n_mats, mat_dim, _ = mat.shape 34 | 35 | mask_is_symmetric = self.embedding_manifold.is_symmetric( 36 | mat, tolerance=tolerance) 37 | mask_is_invertible = self.embedding_manifold.belongs(mat) 38 | 39 | belongs = mask_is_symmetric & mask_is_invertible 40 | belongs = gs.to_ndarray(belongs, to_ndim=1) 41 | belongs = gs.to_ndarray(belongs, to_ndim=2, axis=1) 42 | return belongs 43 | 44 | def vector_from_symmetric_matrix(self, mat): 45 | """ 46 | Convert the symmetric part of a symmetric matrix 47 | into a vector. 
48 | """ 49 | mat = gs.to_ndarray(mat, to_ndim=3) 50 | assert gs.all(self.embedding_manifold.is_symmetric(mat)) 51 | mat = self.embedding_manifold.make_symmetric(mat) 52 | 53 | _, mat_dim, _ = mat.shape 54 | vec_dim = int(mat_dim * (mat_dim + 1) / 2) 55 | vec = gs.zeros(vec_dim) 56 | 57 | idx = 0 58 | for i in range(mat_dim): 59 | for j in range(i + 1): 60 | if i == j: 61 | vec[idx] = mat[j, j] 62 | else: 63 | vec[idx] = mat[j, i] 64 | idx += 1 65 | 66 | return vec 67 | 68 | def symmetric_matrix_from_vector(self, vec): 69 | """ 70 | Convert a vector into a symmetric matrix. 71 | """ 72 | vec = gs.to_ndarray(vec, to_ndim=2) 73 | _, vec_dim = vec.shape 74 | mat_dim = int((gs.sqrt(8 * vec_dim + 1) - 1) / 2) 75 | mat = gs.zeros((mat_dim,) * 2) 76 | 77 | lower_triangle_indices = gs.tril_indices(mat_dim) 78 | diag_indices = gs.diag_indices(mat_dim) 79 | 80 | mat[lower_triangle_indices] = 2 * vec 81 | mat[diag_indices] = vec 82 | 83 | mat = self.embedding_manifold.make_symmetric(mat) 84 | return mat 85 | 86 | def random_uniform(self, n_samples=1): 87 | mat = 2 * gs.random.rand(n_samples, self.n, self.n) - 1 88 | 89 | spd_mat = self.embedding_manifold.group_exp( 90 | mat + gs.transpose(mat, axes=(0, 2, 1))) 91 | return spd_mat 92 | 93 | def random_tangent_vec_uniform(self, n_samples=1, base_point=None): 94 | if base_point is None: 95 | base_point = gs.eye(self.n) 96 | 97 | base_point = gs.to_ndarray(base_point, to_ndim=3) 98 | n_base_points, _, _ = base_point.shape 99 | 100 | assert n_base_points == n_samples or n_base_points == 1 101 | if n_base_points == 1: 102 | base_point = gs.tile(base_point, (n_samples, 1, 1)) 103 | 104 | sqrt_base_point = gs.linalg.sqrtm(base_point) 105 | 106 | tangent_vec_at_id = (2 * gs.random.rand(n_samples, 107 | self.n, 108 | self.n) 109 | - 1) 110 | tangent_vec_at_id = (tangent_vec_at_id 111 | + gs.transpose(tangent_vec_at_id, 112 | axes=(0, 2, 1))) 113 | 114 | tangent_vec = gs.matmul(sqrt_base_point, tangent_vec_at_id) 115 | tangent_vec = 
gs.matmul(tangent_vec, sqrt_base_point) 116 | 117 | return tangent_vec 118 | 119 | 120 | class SPDMetric(RiemannianMetric): 121 | 122 | def __init__(self, n): 123 | super(SPDMetric, self).__init__( 124 | dimension=int(n * (n + 1) / 2), 125 | signature=(int(n * (n + 1) / 2), 0, 0)) 126 | self.n = n 127 | 128 | def inner_product(self, tangent_vec_a, tangent_vec_b, base_point): 129 | """ 130 | Compute the inner product of tangent_vec_a and tangent_vec_b 131 | at point base_point using the affine invariant Riemannian metric. 132 | """ 133 | tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=3) 134 | n_tangent_vecs_a, _, _ = tangent_vec_a.shape 135 | tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=3) 136 | n_tangent_vecs_b, _, _ = tangent_vec_b.shape 137 | 138 | base_point = gs.to_ndarray(base_point, to_ndim=3) 139 | n_base_points, _, _ = base_point.shape 140 | 141 | assert (n_tangent_vecs_a == n_tangent_vecs_b == n_base_points 142 | or n_tangent_vecs_a == n_tangent_vecs_b and n_base_points == 1 143 | or n_base_points == n_tangent_vecs_a and n_tangent_vecs_b == 1 144 | or n_base_points == n_tangent_vecs_b and n_tangent_vecs_a == 1 145 | or n_tangent_vecs_a == 1 and n_tangent_vecs_b == 1 146 | or n_base_points == 1 and n_tangent_vecs_a == 1 147 | or n_base_points == 1 and n_tangent_vecs_b == 1) 148 | 149 | if n_tangent_vecs_a == 1: 150 | tangent_vec_a = gs.tile( 151 | tangent_vec_a, 152 | (gs.maximum(n_base_points, n_tangent_vecs_b), 1, 1)) 153 | 154 | if n_tangent_vecs_b == 1: 155 | tangent_vec_b = gs.tile( 156 | tangent_vec_b, 157 | (gs.maximum(n_base_points, n_tangent_vecs_a), 1, 1)) 158 | 159 | if n_base_points == 1: 160 | base_point = gs.tile( 161 | base_point, 162 | (gs.maximum(n_tangent_vecs_a, n_tangent_vecs_b), 1, 1)) 163 | 164 | inv_base_point = gs.linalg.inv(base_point) 165 | 166 | aux_a = gs.matmul(inv_base_point, tangent_vec_a) 167 | aux_b = gs.matmul(inv_base_point, tangent_vec_b) 168 | inner_product = gs.trace(gs.matmul(aux_a, aux_b), axis1=1, 
axis2=2) 169 | inner_product = gs.to_ndarray(inner_product, to_ndim=2, axis=1) 170 | return inner_product 171 | 172 | def exp(self, tangent_vec, base_point): 173 | """ 174 | Compute the Riemannian exponential at point base_point 175 | of tangent vector tangent_vec wrt the metric 176 | defined in inner_product. 177 | 178 | This gives a symmetric positive definite matrix. 179 | """ 180 | tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=3) 181 | n_tangent_vecs, _, _ = tangent_vec.shape 182 | 183 | base_point = gs.to_ndarray(base_point, to_ndim=3) 184 | n_base_points, mat_dim, _ = base_point.shape 185 | 186 | assert (n_tangent_vecs == n_base_points 187 | or n_tangent_vecs == 1 188 | or n_base_points == 1) 189 | 190 | if n_tangent_vecs == 1: 191 | tangent_vec = gs.tile(tangent_vec, (n_base_points, 1, 1)) 192 | if n_base_points == 1: 193 | base_point = gs.tile(base_point, (n_tangent_vecs, 1, 1)) 194 | 195 | sqrt_base_point = gs.linalg.sqrtm(base_point) 196 | 197 | inv_sqrt_base_point = gs.linalg.inv(sqrt_base_point) 198 | 199 | tangent_vec_at_id = gs.matmul(inv_sqrt_base_point, 200 | tangent_vec) 201 | tangent_vec_at_id = gs.matmul(tangent_vec_at_id, 202 | inv_sqrt_base_point) 203 | exp_from_id = gs.linalg.expm(tangent_vec_at_id) 204 | 205 | exp = gs.matmul(exp_from_id, sqrt_base_point) 206 | exp = gs.matmul(sqrt_base_point, exp) 207 | 208 | return exp 209 | 210 | def log(self, point, base_point): 211 | """ 212 | Compute the Riemannian logarithm at point base_point, 213 | of point wrt the metric defined in inner_product. 214 | 215 | This gives a tangent vector at point base_point. 
216 | """ 217 | point = gs.to_ndarray(point, to_ndim=3) 218 | n_points, _, _ = point.shape 219 | 220 | base_point = gs.to_ndarray(base_point, to_ndim=3) 221 | n_base_points, mat_dim, _ = base_point.shape 222 | 223 | assert (n_points == n_base_points 224 | or n_points == 1 225 | or n_base_points == 1) 226 | 227 | if n_points == 1: 228 | point = gs.tile(point, (n_base_points, 1, 1)) 229 | if n_base_points == 1: 230 | base_point = gs.tile(base_point, (n_points, 1, 1)) 231 | 232 | sqrt_base_point = gs.zeros((n_base_points,) + (mat_dim,) * 2) 233 | sqrt_base_point = gs.linalg.sqrtm(base_point) 234 | 235 | inv_sqrt_base_point = gs.linalg.inv(sqrt_base_point) 236 | point_near_id = gs.matmul(inv_sqrt_base_point, point) 237 | point_near_id = gs.matmul(point_near_id, inv_sqrt_base_point) 238 | log_at_id = gs.linalg.logm(point_near_id) 239 | 240 | log = gs.matmul(sqrt_base_point, log_at_id) 241 | log = gs.matmul(log, sqrt_base_point) 242 | 243 | return log 244 | 245 | def geodesic(self, initial_point, initial_tangent_vec): 246 | return super(SPDMetric, self).geodesic( 247 | initial_point=initial_point, 248 | initial_tangent_vec=initial_tangent_vec, 249 | point_type='matrix') 250 | -------------------------------------------------------------------------------- /tests/test_minkowski_space.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for Minkowski space. 
3 | """ 4 | import math 5 | import numpy as np 6 | 7 | import geomstats.backend as gs 8 | import geomstats.tests 9 | import tests.helper as helper 10 | 11 | from geomstats.minkowski_space import MinkowskiSpace 12 | 13 | 14 | class TestMinkowskiSpaceMethods(geomstats.tests.TestCase): 15 | _multiprocess_can_split_ = True 16 | 17 | def setUp(self): 18 | gs.random.seed(1234) 19 | 20 | self.time_like_dim = 0 21 | self.dimension = 2 22 | self.space = MinkowskiSpace(self.dimension) 23 | self.metric = self.space.metric 24 | self.n_samples = 10 25 | 26 | def test_belongs(self): 27 | point = gs.array([-1., 3.]) 28 | result = self.space.belongs(point) 29 | expected = gs.array([[True]]) 30 | 31 | self.assertAllClose(result, expected) 32 | 33 | def test_random_uniform(self): 34 | point = self.space.random_uniform() 35 | self.assertAllClose(gs.shape(point), (1, self.dimension)) 36 | 37 | def test_random_uniform_and_belongs(self): 38 | point = self.space.random_uniform() 39 | result = self.space.belongs(point) 40 | expected = gs.array([[True]]) 41 | self.assertAllClose(result, expected) 42 | 43 | def test_inner_product_matrix(self): 44 | result = self.metric.inner_product_matrix() 45 | 46 | expected = gs.array([[-1., 0.], [0., 1.]]) 47 | self.assertAllClose(result, expected) 48 | 49 | def test_inner_product(self): 50 | point_a = gs.array([0., 1.]) 51 | point_b = gs.array([2., 10.]) 52 | 53 | result = self.metric.inner_product(point_a, point_b) 54 | expected = helper.to_scalar(gs.dot(point_a, point_b)) 55 | expected -= (2 * point_a[self.time_like_dim] 56 | * point_b[self.time_like_dim]) 57 | 58 | self.assertAllClose(result, expected) 59 | 60 | def test_inner_product_vectorization(self): 61 | n_samples = 3 62 | one_point_a = gs.array([[-1., 0.]]) 63 | one_point_b = gs.array([[1.0, 0.]]) 64 | 65 | n_points_a = gs.array([ 66 | [-1., 0.], 67 | [1., 0.], 68 | [2., math.sqrt(3)]]) 69 | n_points_b = gs.array([ 70 | [2., -math.sqrt(3)], 71 | [4.0, math.sqrt(15)], 72 | [-4.0, 
math.sqrt(15)]]) 73 | 74 | result = self.metric.inner_product(one_point_a, one_point_b) 75 | expected = gs.dot(one_point_a, gs.transpose(one_point_b)) 76 | expected -= (2 * one_point_a[:, self.time_like_dim] 77 | * one_point_b[:, self.time_like_dim]) 78 | expected = helper.to_scalar(expected) 79 | 80 | result_no = self.metric.inner_product(n_points_a, 81 | one_point_b) 82 | result_on = self.metric.inner_product(one_point_a, n_points_b) 83 | 84 | result_nn = self.metric.inner_product(n_points_a, n_points_b) 85 | 86 | self.assertAllClose(result, expected) 87 | self.assertAllClose(gs.shape(result_no), (n_samples, 1)) 88 | self.assertAllClose(gs.shape(result_on), (n_samples, 1)) 89 | self.assertAllClose(gs.shape(result_nn), (n_samples, 1)) 90 | 91 | with self.session(): 92 | expected = np.zeros(n_samples) 93 | for i in range(n_samples): 94 | expected[i] = gs.eval(gs.dot(n_points_a[i], 95 | n_points_b[i])) 96 | expected[i] -= (2 * gs.eval(n_points_a[i, self.time_like_dim]) 97 | * gs.eval(n_points_b[i, self.time_like_dim])) 98 | expected = helper.to_scalar(gs.array(expected)) 99 | 100 | self.assertAllClose(result_nn, expected) 101 | 102 | def test_squared_norm(self): 103 | point = gs.array([-2., 4.]) 104 | 105 | result = self.metric.squared_norm(point) 106 | expected = gs.array([[12.]]) 107 | self.assertAllClose(result, expected) 108 | 109 | def test_squared_norm_vectorization(self): 110 | n_samples = 3 111 | n_points = gs.array([ 112 | [-1., 0.], 113 | [1., 0.], 114 | [2., math.sqrt(3)]]) 115 | 116 | result = self.metric.squared_norm(n_points) 117 | self.assertAllClose(gs.shape(result), (n_samples, 1)) 118 | 119 | def test_exp(self): 120 | base_point = gs.array([1.0, 0.]) 121 | vector = gs.array([2., math.sqrt(3)]) 122 | 123 | result = self.metric.exp(tangent_vec=vector, 124 | base_point=base_point) 125 | expected = base_point + vector 126 | expected = helper.to_vector(expected) 127 | self.assertAllClose(result, expected) 128 | 129 | def test_exp_vectorization(self): 
130 | dim = self.dimension 131 | n_samples = 3 132 | one_tangent_vec = gs.array([[-1., 0.]]) 133 | one_base_point = gs.array([[1.0, 0.]]) 134 | 135 | n_tangent_vecs = gs.array([ 136 | [-1., 0.], 137 | [1., 0.], 138 | [2., math.sqrt(3)]]) 139 | n_base_points = gs.array([ 140 | [2., -math.sqrt(3)], 141 | [4.0, math.sqrt(15)], 142 | [-4.0, math.sqrt(15)]]) 143 | 144 | result = self.metric.exp(one_tangent_vec, one_base_point) 145 | expected = one_tangent_vec + one_base_point 146 | expected = helper.to_vector(expected) 147 | self.assertAllClose(result, expected) 148 | 149 | result = self.metric.exp(n_tangent_vecs, one_base_point) 150 | self.assertAllClose(gs.shape(result), (n_samples, dim)) 151 | 152 | result = self.metric.exp(one_tangent_vec, n_base_points) 153 | self.assertAllClose(gs.shape(result), (n_samples, dim)) 154 | 155 | result = self.metric.exp(n_tangent_vecs, n_base_points) 156 | self.assertAllClose(gs.shape(result), (n_samples, dim)) 157 | 158 | def test_log(self): 159 | base_point = gs.array([-1., 0.]) 160 | point = gs.array([2., math.sqrt(3)]) 161 | 162 | result = self.metric.log(point=point, base_point=base_point) 163 | expected = point - base_point 164 | expected = helper.to_vector(expected) 165 | self.assertAllClose(result, expected) 166 | 167 | def test_log_vectorization(self): 168 | dim = self.dimension 169 | n_samples = 3 170 | one_point = gs.array([[-1., 0.]]) 171 | one_base_point = gs.array([[1.0, 0.]]) 172 | 173 | n_points = gs.array([ 174 | [-1., 0.], 175 | [1., 0.], 176 | [2., math.sqrt(3)]]) 177 | n_base_points = gs.array([ 178 | [2., -math.sqrt(3)], 179 | [4.0, math.sqrt(15)], 180 | [-4.0, math.sqrt(15)]]) 181 | 182 | result = self.metric.log(one_point, one_base_point) 183 | expected = one_point - one_base_point 184 | expected = helper.to_vector(expected) 185 | self.assertAllClose(result, expected) 186 | 187 | result = self.metric.log(n_points, one_base_point) 188 | self.assertAllClose(gs.shape(result), (n_samples, dim)) 189 | 190 | result = 
self.metric.log(one_point, n_base_points) 191 | self.assertAllClose(gs.shape(result), (n_samples, dim)) 192 | 193 | result = self.metric.log(n_points, n_base_points) 194 | self.assertAllClose(gs.shape(result), (n_samples, dim)) 195 | 196 | def test_squared_dist(self): 197 | point_a = gs.array([2., -math.sqrt(3)]) 198 | point_b = gs.array([4.0, math.sqrt(15)]) 199 | 200 | result = self.metric.squared_dist(point_a, point_b) 201 | vec = point_b - point_a 202 | expected = gs.dot(vec, vec) 203 | expected -= 2 * vec[self.time_like_dim] * vec[self.time_like_dim] 204 | expected = helper.to_scalar(expected) 205 | self.assertAllClose(result, expected) 206 | 207 | def test_geodesic_and_belongs(self): 208 | n_geodesic_points = 100 209 | initial_point = gs.array([2., -math.sqrt(3)]) 210 | initial_tangent_vec = gs.array([2., 0.]) 211 | 212 | geodesic = self.metric.geodesic( 213 | initial_point=initial_point, 214 | initial_tangent_vec=initial_tangent_vec) 215 | 216 | t = gs.linspace(start=0., stop=1., num=n_geodesic_points) 217 | points = geodesic(t) 218 | 219 | result = self.space.belongs(points) 220 | expected = gs.array(n_geodesic_points * [[True]]) 221 | 222 | self.assertAllClose(result, expected) 223 | 224 | def test_mean(self): 225 | point = gs.array([[2., -math.sqrt(3)]]) 226 | result = self.metric.mean(points=[point, point, point]) 227 | expected = point 228 | expected = helper.to_vector(expected) 229 | 230 | self.assertAllClose(result, expected) 231 | 232 | points = gs.array([ 233 | [1., 0.], 234 | [2., math.sqrt(3)], 235 | [3., math.sqrt(8)], 236 | [4., math.sqrt(24)]]) 237 | weights = gs.array([1., 2., 1., 2.]) 238 | result = self.metric.mean(points, weights) 239 | result = self.space.belongs(result) 240 | expected = gs.array([[True]]) 241 | 242 | self.assertAllClose(result, expected) 243 | 244 | def test_variance(self): 245 | points = gs.array([ 246 | [1., 0.], 247 | [2., math.sqrt(3)], 248 | [3., math.sqrt(8)], 249 | [4., math.sqrt(24)]]) 250 | weights = 
gs.array([1., 2., 1., 2.]) 251 | base_point = gs.array([-1., 0.]) 252 | variance = self.metric.variance(points, weights, base_point) 253 | result = helper.to_scalar(variance != 0) 254 | # we expect the average of the points' Minkowski sq norms. 255 | expected = helper.to_scalar(gs.array([True])) 256 | self.assertAllClose(result, expected) 257 | 258 | 259 | if __name__ == '__main__': 260 | geomstats.tests.main() 261 | -------------------------------------------------------------------------------- /tests/test_spd_matrices_space.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the manifold of symmetric positive definite matrices. 3 | """ 4 | 5 | import warnings 6 | 7 | import geomstats.backend as gs 8 | import geomstats.tests 9 | import tests.helper as helper 10 | 11 | from geomstats.spd_matrices_space import SPDMatricesSpace 12 | 13 | 14 | class TestSPDMatricesSpaceMethods(geomstats.tests.TestCase): 15 | _multiprocess_can_split_ = True 16 | 17 | def setUp(self): 18 | warnings.simplefilter('ignore', category=ImportWarning) 19 | 20 | gs.random.seed(1234) 21 | 22 | self.n = 3 23 | self.space = SPDMatricesSpace(n=self.n) 24 | self.metric = self.space.metric 25 | self.n_samples = 4 26 | 27 | @geomstats.tests.np_and_tf_only 28 | def test_random_uniform_and_belongs(self): 29 | point = self.space.random_uniform() 30 | result = self.space.belongs(point) 31 | expected = gs.array([[True]]) 32 | self.assertAllClose(result, expected) 33 | 34 | @geomstats.tests.np_and_tf_only 35 | def test_random_uniform_and_belongs_vectorization(self): 36 | """ 37 | Test that the random uniform method samples 38 | on the hypersphere space. 
39 | """ 40 | n_samples = self.n_samples 41 | points = self.space.random_uniform(n_samples=n_samples) 42 | result = self.space.belongs(points) 43 | self.assertAllClose(gs.shape(result), (n_samples, 1)) 44 | 45 | @geomstats.tests.np_and_tf_only 46 | def vector_from_symmetric_matrix_and_symmetric_matrix_from_vector(self): 47 | sym_mat_1 = gs.array([[1., 0.6, -3.], 48 | [0.6, 7., 0.], 49 | [-3., 0., 8.]]) 50 | vector_1 = self.space.vector_from_symmetric_matrix(sym_mat_1) 51 | result_1 = self.space.symmetric_matrix_from_vector(vector_1) 52 | expected_1 = sym_mat_1 53 | 54 | self.assertTrue(gs.allclose(result_1, expected_1)) 55 | 56 | vector_2 = gs.array([1, 2, 3, 4, 5, 6]) 57 | sym_mat_2 = self.space.symmetric_matrix_from_vector(vector_2) 58 | result_2 = self.space.vector_from_symmetric_matrix(sym_mat_2) 59 | expected_2 = vector_2 60 | 61 | self.assertTrue(gs.allclose(result_2, expected_2)) 62 | 63 | @geomstats.tests.np_and_tf_only 64 | def vector_and_symmetric_matrix_vectorization(self): 65 | n_samples = self.n_samples 66 | vector = gs.random.rand(n_samples, 6) 67 | sym_mat = self.space.symmetric_matrix_from_vector(vector) 68 | result = self.space.vector_from_symmetric_matrix(sym_mat) 69 | expected = vector 70 | 71 | self.assertTrue(gs.allclose(result, expected)) 72 | 73 | sym_mat = self.space.random_uniform(n_samples) 74 | vector = self.space.vector_from_symmetric_matrix(sym_mat) 75 | result = self.space.symmetric_matrix_from_vector(vector) 76 | expected = sym_mat 77 | 78 | self.assertTrue(gs.allclose(result, expected)) 79 | 80 | @geomstats.tests.np_and_tf_only 81 | def test_log_and_exp(self): 82 | base_point = gs.array([[5., 0., 0.], 83 | [0., 7., 2.], 84 | [0., 2., 8.]]) 85 | point = gs.array([[9., 0., 0.], 86 | [0., 5., 0.], 87 | [0., 0., 1.]]) 88 | 89 | log = self.metric.log(point=point, base_point=base_point) 90 | result = self.metric.exp(tangent_vec=log, base_point=base_point) 91 | expected = helper.to_matrix(point) 92 | 93 | self.assertAllClose(result, 
expected) 94 | 95 | @geomstats.tests.np_and_tf_only 96 | def test_exp_and_belongs(self): 97 | n_samples = self.n_samples 98 | base_point = self.space.random_uniform(n_samples=1) 99 | tangent_vec = self.space.random_tangent_vec_uniform( 100 | n_samples=n_samples, 101 | base_point=base_point) 102 | exps = self.metric.exp(tangent_vec, base_point) 103 | result = self.space.belongs(exps) 104 | expected = gs.array([[True]] * n_samples) 105 | 106 | self.assertAllClose(result, expected) 107 | 108 | @geomstats.tests.np_and_tf_only 109 | def test_exp_vectorization(self): 110 | n_samples = self.n_samples 111 | one_base_point = self.space.random_uniform(n_samples=1) 112 | n_base_point = self.space.random_uniform(n_samples=n_samples) 113 | 114 | n_tangent_vec_same_base = self.space.random_tangent_vec_uniform( 115 | n_samples=n_samples, 116 | base_point=one_base_point) 117 | n_tangent_vec = self.space.random_tangent_vec_uniform( 118 | n_samples=n_samples, 119 | base_point=n_base_point) 120 | 121 | # Test with the 1 base_point, and several different tangent_vecs 122 | result = self.metric.exp(n_tangent_vec_same_base, one_base_point) 123 | 124 | self.assertAllClose( 125 | gs.shape(result), (n_samples, self.space.n, self.space.n)) 126 | 127 | # Test with the same number of base_points and tangent_vecs 128 | result = self.metric.exp(n_tangent_vec, n_base_point) 129 | 130 | self.assertAllClose( 131 | gs.shape(result), (n_samples, self.space.n, self.space.n)) 132 | 133 | @geomstats.tests.np_and_tf_only 134 | def test_log_vectorization(self): 135 | n_samples = self.n_samples 136 | one_base_point = self.space.random_uniform(n_samples=1) 137 | n_base_point = self.space.random_uniform(n_samples=n_samples) 138 | 139 | one_point = self.space.random_uniform(n_samples=1) 140 | n_point = self.space.random_uniform(n_samples=n_samples) 141 | 142 | # Test with different points, one base point 143 | result = self.metric.log(n_point, one_base_point) 144 | 145 | self.assertAllClose( 146 | 
gs.shape(result), (n_samples, self.space.n, self.space.n)) 147 | 148 | # Test with the same number of points and base points 149 | result = self.metric.log(n_point, n_base_point) 150 | 151 | self.assertAllClose( 152 | gs.shape(result), (n_samples, self.space.n, self.space.n)) 153 | 154 | # Test with the one point and n base points 155 | result = self.metric.log(one_point, n_base_point) 156 | 157 | self.assertAllClose( 158 | gs.shape(result), (n_samples, self.space.n, self.space.n)) 159 | 160 | @geomstats.tests.np_and_tf_only 161 | def test_geodesic_and_belongs(self): 162 | initial_point = self.space.random_uniform() 163 | initial_tangent_vec = self.space.random_tangent_vec_uniform( 164 | n_samples=1, 165 | base_point=initial_point) 166 | geodesic = self.metric.geodesic( 167 | initial_point=initial_point, 168 | initial_tangent_vec=initial_tangent_vec) 169 | 170 | n_points = 10 171 | t = gs.linspace(start=0., stop=1., num=n_points) 172 | points = geodesic(t) 173 | result = self.space.belongs(points) 174 | expected = gs.array([[True]] * n_points) 175 | 176 | self.assertAllClose(result, expected) 177 | 178 | @geomstats.tests.np_only 179 | def test_squared_dist_is_symmetric(self): 180 | n_samples = self.n_samples 181 | 182 | point_1 = self.space.random_uniform(n_samples=1) 183 | point_2 = self.space.random_uniform(n_samples=1) 184 | 185 | sq_dist_1_2 = self.metric.squared_dist(point_1, point_2) 186 | sq_dist_2_1 = self.metric.squared_dist(point_2, point_1) 187 | 188 | self.assertAllClose(sq_dist_1_2, sq_dist_2_1) 189 | 190 | point_1 = self.space.random_uniform(n_samples=1) 191 | point_2 = self.space.random_uniform(n_samples=n_samples) 192 | 193 | sq_dist_1_2 = self.metric.squared_dist(point_1, point_2) 194 | sq_dist_2_1 = self.metric.squared_dist(point_2, point_1) 195 | 196 | self.assertAllClose(sq_dist_1_2, sq_dist_2_1) 197 | 198 | point_1 = self.space.random_uniform(n_samples=n_samples) 199 | point_2 = self.space.random_uniform(n_samples=1) 200 | 201 | sq_dist_1_2 = 
self.metric.squared_dist(point_1, point_2) 202 | sq_dist_2_1 = self.metric.squared_dist(point_2, point_1) 203 | 204 | self.assertAllClose(sq_dist_1_2, sq_dist_2_1) 205 | 206 | point_1 = self.space.random_uniform(n_samples=n_samples) 207 | point_2 = self.space.random_uniform(n_samples=n_samples) 208 | 209 | sq_dist_1_2 = self.metric.squared_dist(point_1, point_2) 210 | sq_dist_2_1 = self.metric.squared_dist(point_2, point_1) 211 | 212 | self.assertAllClose(sq_dist_1_2, sq_dist_2_1) 213 | 214 | @geomstats.tests.np_and_tf_only 215 | def test_squared_dist_vectorization(self): 216 | n_samples = self.n_samples 217 | point_1 = self.space.random_uniform(n_samples=n_samples) 218 | point_2 = self.space.random_uniform(n_samples=n_samples) 219 | 220 | result = self.metric.squared_dist(point_1, point_2) 221 | 222 | self.assertAllClose(gs.shape(result), (n_samples, 1)) 223 | 224 | point_1 = self.space.random_uniform(n_samples=1) 225 | point_2 = self.space.random_uniform(n_samples=n_samples) 226 | 227 | result = self.metric.squared_dist(point_1, point_2) 228 | 229 | self.assertAllClose(gs.shape(result), (n_samples, 1)) 230 | 231 | point_1 = self.space.random_uniform(n_samples=n_samples) 232 | point_2 = self.space.random_uniform(n_samples=1) 233 | 234 | result = self.metric.squared_dist(point_1, point_2) 235 | 236 | self.assertAllClose(gs.shape(result), (n_samples, 1)) 237 | 238 | point_1 = self.space.random_uniform(n_samples=1) 239 | point_2 = self.space.random_uniform(n_samples=1) 240 | 241 | result = self.metric.squared_dist(point_1, point_2) 242 | 243 | self.assertAllClose(gs.shape(result), (1, 1)) 244 | 245 | 246 | if __name__ == '__main__': 247 | geomstats.tests.main() 248 | -------------------------------------------------------------------------------- /geomstats/connection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Affine connections. 
3 | """ 4 | 5 | import autograd 6 | 7 | import geomstats.backend as gs 8 | 9 | 10 | class Connection(object): 11 | 12 | def __init__(self, dimension): 13 | self.dimension = dimension 14 | 15 | def christoffel_symbol(self, base_point): 16 | """ 17 | Christoffel symbols associated with the connection. 18 | 19 | Parameters 20 | ---------- 21 | base_point : array-like, shape=[n_samples, dimension] 22 | """ 23 | raise NotImplementedError( 24 | 'The Christoffel symbols are not implemented.') 25 | 26 | def connection(self, tangent_vector_a, tangent_vector_b, base_point): 27 | """ 28 | Connection applied to tangent_vector_b in the direction of 29 | tangent_vector_a, both tangent at base_point. 30 | 31 | Parameters 32 | ---------- 33 | tangent_vec_a: array-like, shape=[n_samples, dimension] 34 | or shape=[1, dimension] 35 | 36 | tangent_vec_b: array-like, shape=[n_samples, dimension] 37 | or shape=[1, dimension] 38 | 39 | base_point: array-like, shape=[n_samples, dimension] 40 | or shape=[1, dimension] 41 | """ 42 | raise NotImplementedError( 43 | 'connection is not implemented.') 44 | 45 | def exp(self, tangent_vector, base_point): 46 | """ 47 | Exponential map associated to the affine connection. 48 | 49 | Parameters 50 | ---------- 51 | tangent_vec: array-like, shape=[n_samples, dimension] 52 | or shape=[1, dimension] 53 | 54 | base_point: array-like, shape=[n_samples, dimension] 55 | or shape=[1, dimension] 56 | """ 57 | raise NotImplementedError( 58 | 'The affine connection exponential is not implemented.') 59 | 60 | def log(self, point, base_point): 61 | """ 62 | Logarithm map associated to the affine connection. 
63 | 64 | Parameters 65 | ---------- 66 | point: array-like, shape=[n_samples, dimension] 67 | or shape=[1, dimension] 68 | 69 | base_point: array-like, shape=[n_samples, dimension] 70 | or shape=[1, dimension] 71 | """ 72 | raise NotImplementedError( 73 | 'The affine connection logarithm is not implemented.') 74 | 75 | def pole_ladder_transport( 76 | self, tangent_vector_a, tangent_vector_b, base_point): 77 | """ 78 | One step of pole ladder (parallel transport associated with the 79 | symmetric part of the connection using transvections). 80 | 81 | Parameters 82 | ---------- 83 | tangent_vec_a: array-like, shape=[n_samples, dimension] 84 | or shape=[1, dimension] 85 | 86 | tangent_vec_b: array-like, shape=[n_samples, dimension] 87 | or shape=[1, dimension] 88 | 89 | base_point: array-like, shape=[n_samples, dimension] 90 | or shape=[1, dimension] 91 | """ 92 | half_tangent_vector_b = 1. / 2. * tangent_vector_b 93 | mid_point = self.exp( 94 | base_point=base_point, 95 | tangent_vector=half_tangent_vector_b) 96 | 97 | mid_tangent_vector = - self.log( 98 | base_point=mid_point, 99 | point=base_point) 100 | end_point = self.exp( 101 | base_point=mid_point, 102 | tangent_vector=mid_tangent_vector) 103 | 104 | base_shoot = self.exp( 105 | base_point=base_point, 106 | tangent_vector=tangent_vector_a) 107 | mid_tangent_vector_to_shoot = - self.log( 108 | base_point=mid_point, 109 | end_point=base_shoot) 110 | end_shoot = self.exp( 111 | base_point=mid_point, 112 | tangent_vector=mid_tangent_vector_to_shoot) 113 | 114 | tangent_vector = - self.log(base_point=end_point, point=end_shoot) 115 | return tangent_vector 116 | 117 | def parallel_transport( 118 | self, tangent_vector_a, tangent_vector_b, base_point, n_points=1): 119 | """ 120 | Parallel transport of tangent vector a integrating the connection 121 | along the (affine connection) geodesic starting at the initial point 122 | base_point with initial tangent vector the tangent vector b. 
123 | 124 | Returns a tangent vector at the point 125 | exp_(base_point)(tangent_vector_b). 126 | 127 | Parameters 128 | ---------- 129 | tangent_vec_a: array-like, shape=[n_samples, dimension] 130 | or shape=[1, dimension] 131 | 132 | tangent_vec_b: array-like, shape=[n_samples, dimension] 133 | or shape=[1, dimension] 134 | 135 | base_point: array-like, shape=[n_samples, dimension] 136 | or shape=[1, dimension] 137 | """ 138 | current_point = gs.copy(base_point) 139 | geodesic_tangent_vector = 1. / n_points * tangent_vector_b 140 | transported_tangent_vector = gs.copy(tangent_vector_a) 141 | for i_point in range(1, n_points): 142 | transported_tangent_vector = self.pole_ladder_transport( 143 | tangent_vector_a=transported_tangent_vector, 144 | tangent_vector_b=geodesic_tangent_vector, 145 | base_point=current_point) 146 | current_point = self.exp( 147 | base_point=current_point, 148 | tangent_vector=geodesic_tangent_vector) 149 | 150 | frac_tangent_vector_b = (i_point + 1) / n_points * tangent_vector_b 151 | next_point = self.exp( 152 | base_point=base_point, 153 | tangent_vector=frac_tangent_vector_b) 154 | geodesic_tangent_vector = self.log( 155 | base_point=current_point, 156 | point=next_point) 157 | 158 | return transported_tangent_vector 159 | 160 | def riemannian_curvature(self, base_point): 161 | """ 162 | Riemannian curvature tensor associated with the connection. 163 | 164 | Parameters 165 | ---------- 166 | base_point: array-like, shape=[n_samples, dimension] 167 | or shape=[1, dimension] 168 | """ 169 | raise NotImplementedError( 170 | 'The Riemannian curvature tensor is not implemented.') 171 | 172 | def geodesic_equation(self): 173 | """ 174 | The geodesic ordinary differential equation associated 175 | with the connection. 
176 | """ 177 | raise NotImplementedError( 178 | 'The geodesic equation tensor is not implemented.') 179 | 180 | def geodesic(self, initial_point, 181 | end_point=None, initial_tangent_vec=None, point_ndim=1): 182 | """ 183 | Geodesic associated with the connection. 184 | """ 185 | raise NotImplementedError( 186 | 'Geodesics are not implemented.') 187 | 188 | def torsion(self, base_point): 189 | """ 190 | Torsion tensor associated with the connection. 191 | 192 | Parameters 193 | ---------- 194 | base_point: array-like, shape=[n_samples, dimension] 195 | or shape=[1, dimension] 196 | """ 197 | raise NotImplementedError( 198 | 'The torsion tensor is not implemented.') 199 | 200 | 201 | class LeviCivitaConnection(Connection): 202 | """ 203 | Levi-Civita connection associated with a Riemannian metric. 204 | """ 205 | def __init__(self, metric): 206 | self.metric = metric 207 | self.dimension = metric.dimension 208 | 209 | def metric_matrix(self, base_point): 210 | metric_matrix = self.metric.inner_product_matrix(base_point) 211 | return metric_matrix 212 | 213 | def cometric_matrix(self, base_point): 214 | """ 215 | The cometric is the inverse of the metric. 216 | 217 | Parameters 218 | ---------- 219 | base_point: array-like, shape=[n_samples, dimension] 220 | or shape=[1, dimension] 221 | """ 222 | metric_matrix = self.metric_matrix(base_point) 223 | cometric_matrix = gs.linalg.inv(metric_matrix) 224 | return cometric_matrix 225 | 226 | def metric_derivative(self, base_point): 227 | """ 228 | 229 | Parameters 230 | ---------- 231 | base_point: array-like, shape=[n_samples, dimension] 232 | or shape=[1, dimension] 233 | """ 234 | metric_derivative = autograd.jacobian(self.metric_matrix) 235 | return metric_derivative(base_point) 236 | 237 | def christoffel_symbols(self, base_point): 238 | """ 239 | Christoffel symbols associated with the connection. 
240 | 241 | Parameters 242 | ---------- 243 | base_point: array-like, shape=[n_samples, dimension] 244 | or shape=[1, dimension] 245 | """ 246 | term_1 = gs.einsum('nim,nmkl->nikl', 247 | self.cometric_matrix(base_point), 248 | self.metric_derivative(base_point)) 249 | term_2 = gs.einsum('nim,nmlk->nilk', 250 | self.cometric_matrix(base_point), 251 | self.metric_derivative(base_point)) 252 | term_3 = - gs.einsum('nim,nklm->nikl', 253 | self.cometric_matrix(base_point), 254 | self.metric_derivative(base_point)) 255 | 256 | christoffel_symbols = 0.5 * (term_1 + term_2 + term_3) 257 | return christoffel_symbols 258 | 259 | def torsion(self, base_point): 260 | """ 261 | Torsion tensor associated with the Levi-Civita connection is zero. 262 | 263 | Parameters 264 | ---------- 265 | base_point: array-like, shape=[n_samples, dimension] 266 | or shape=[1, dimension] 267 | """ 268 | return gs.zeros((self.dimension,) * 3) 269 | --------------------------------------------------------------------------------