├── .github
│   └── workflows
│       └── build.yml
├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── usage.png
│   ├── api
│   │   ├── index.rst
│   │   ├── layers.rst
│   │   ├── manifolds.rst
│   │   ├── optimizers.rst
│   │   └── variable.rst
│   ├── conf.py
│   └── index.md
├── examples
│   ├── grnet
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── grnet.png
│   │   ├── model.py
│   │   ├── requirements.txt
│   │   └── task.py
│   ├── hyperbolic_nn
│   │   └── .gitkeep
│   ├── lienet
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── lienet.png
│   │   ├── model.py
│   │   ├── requirements.txt
│   │   └── task.py
│   ├── poincare_glove
│   │   └── .gitkeep
│   ├── shared
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── spdnet
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── requirements.txt
│   │   ├── spdnet.png
│   │   └── task.py
│   ├── tutorial.ipynb
│   ├── usage.ipynb
│   └── usage.png
├── requirements.txt
├── requirements_dev.txt
├── setup.cfg
├── setup.py
└── tensorflow_riemopt
    ├── __init__.py
    ├── layers
    │   ├── __init__.py
    │   ├── embeddings.py
    │   └── embeddings_test.py
    ├── manifolds
    │   ├── __init__.py
    │   ├── approximate_mixin.py
    │   ├── approximate_mixin_test.py
    │   ├── cholesky.py
    │   ├── cholesky_test.py
    │   ├── euclidean.py
    │   ├── euclidean_test.py
    │   ├── grassmannian.py
    │   ├── grassmannian_test.py
    │   ├── hyperboloid.py
    │   ├── hyperboloid_test.py
    │   ├── manifold.py
    │   ├── poincare.py
    │   ├── poincare_test.py
    │   ├── product.py
    │   ├── product_test.py
    │   ├── special_orthogonal.py
    │   ├── special_orthogonal_test.py
    │   ├── sphere.py
    │   ├── sphere_test.py
    │   ├── stiefel.py
    │   ├── stiefel_test.py
    │   ├── symmetric_positive.py
    │   ├── symmetric_positive_test.py
    │   ├── test_invariants.py
    │   └── utils.py
    ├── mcmc
    │   └── .gitkeep
    ├── optimizers
    │   ├── __init__.py
    │   ├── constrained_rmsprop.py
    │   ├── constrained_rmsprop_test.py
    │   ├── riemannian_adam.py
    │   ├── riemannian_adam_test.py
    │   ├── riemannian_gradient_descent.py
    │   └── riemannian_gradient_descent_test.py
    ├── variable.py
    └── variable_test.py
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   build:
7 |     runs-on: ubuntu-latest
8 |     strategy:
9 |       matrix:
10 |         python-version: ["3.10", "3.11", "3.12"]
11 | 
12 |     steps:
13 |     - uses: actions/checkout@v4
14 |     - name: Set up Python ${{ matrix.python-version }}
15 |       uses: actions/setup-python@v5
16 |       with:
17 |         python-version: ${{ matrix.python-version }}
18 |     - name: Install dependencies
19 |       run: |
20 |         python -m pip install --upgrade pip
21 |         pip install -r requirements_dev.txt
22 |         python -m pip install --editable . 
23 | - name: Lint with black 24 | run: | 25 | black --line-length 80 --check --diff tensorflow_riemopt 26 | - name: Test with pytest 27 | run: | 28 | pytest -v tensorflow_riemopt --cov tensorflow_riemopt 29 | - name: Publish to coveralls.io 30 | env: 31 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 32 | run: | 33 | coveralls --service=github 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__ 3 | *.pyc 4 | build 5 | dist 6 | *.egg-info 7 | docs/_build/ 8 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version, and other tools you might need 8 | build: 9 | os: ubuntu-24.04 10 | tools: 11 | python: "3.13" 12 | 13 | # Build documentation in the "docs/" directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | 17 | # Optionally, but recommended, 18 | # declare the Python requirements required to build your documentation 19 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - requirements: requirements_dev.txt 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Oleg Smirnov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow RiemOpt
2 | 
3 | [![PyPI](https://img.shields.io/pypi/v/tensorflow-riemopt.svg)](https://pypi.org/project/tensorflow-riemopt/)
4 | [![arXiv](https://img.shields.io/badge/arXiv-2105.13921-b31b1b.svg)](https://arxiv.org/abs/2105.13921)
5 | [![Build Status](https://github.com/master/tensorflow-riemopt/actions/workflows/build.yml/badge.svg)](https://github.com/master/tensorflow-riemopt/actions)
6 | [![Documentation Status](https://readthedocs.org/projects/tensorflow-riemopt/badge/?version=latest)](https://tensorflow-riemopt.readthedocs.io)
7 | [![Coverage Status](https://coveralls.io/repos/github/master/tensorflow-riemopt/badge.svg)](https://coveralls.io/github/master/tensorflow-riemopt)
8 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)
9 | [![License](https://img.shields.io/:license-mit-blue.svg)](https://badges.mit-license.org)
10 | 
11 | A library for manifold-constrained optimization in TensorFlow.
12 | 
13 | ## Installation
14 | 
15 | To install the latest development version from GitHub:
16 | 
17 | ```bash
18 | pip install git+https://github.com/master/tensorflow-riemopt.git
19 | ```
20 | 
21 | To install the latest release from PyPI:
22 | 
23 | ```bash
24 | pip install tensorflow-riemopt
25 | ```
26 | 
27 | ## Features
28 | 
29 | The core package implements concepts in differential geometry, such as
30 | manifolds and Riemannian metrics with associated exponential and logarithmic
31 | maps, geodesics, retractions, and transports. For manifolds where closed-form
32 | expressions are not available, the library provides numerical approximations.
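For instance, parallel transport can be approximated with Schild's ladder or the pole ladder by mixing `ApproximateMixin` into a manifold class. The sketch below mirrors the library's own test suite; the choice of the pole ladder and of `n_steps=5` is arbitrary:

```python
from tensorflow_riemopt.manifolds import Sphere
from tensorflow_riemopt.manifolds.approximate_mixin import ApproximateMixin


class SpherePoleLadder(ApproximateMixin, Sphere):
    """Sphere whose vector transport is a pole ladder approximation."""

    def transp(self, x, y, v):
        return self.ladder_transp(x, y, v, method="pole", n_steps=5)
```

Basic operations share a common interface across all manifolds: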
33 | 
34 | 
35 | 
36 | ```python
37 | import tensorflow as tf
38 | import tensorflow_riemopt as riemopt
39 | 
40 | S = riemopt.manifolds.Sphere()
41 | 
42 | x = S.projx(tf.constant([0.1, -0.1, 0.1]))
43 | u = S.proju(x, tf.constant([1., 1., 1.]))
44 | v = S.proju(x, tf.constant([-0.7, -1.4, 1.4]))
45 | 
46 | y = S.exp(x, v)
47 | u_ = S.transp(x, y, u)
48 | v_ = S.transp(x, y, v)
49 | ```
50 | 
51 | ### Manifolds
52 | 
53 | - `manifolds.Cholesky` - manifold of lower triangular matrices with positive diagonal elements
54 | - `manifolds.Euclidean` - unconstrained manifold with the Euclidean metric
55 | - `manifolds.Grassmannian` - manifold of `p`-dimensional linear subspaces of the `n`-dimensional space
56 | - `manifolds.Hyperboloid` - manifold of `n`-dimensional hyperbolic space embedded in the `n+1`-dimensional Minkowski space
57 | - `manifolds.Poincare` - the Poincaré ball model of the hyperbolic space
58 | - `manifolds.Product` - Cartesian product of manifolds
59 | - `manifolds.SPDAffineInvariant` - manifold of symmetric positive definite (SPD) matrices endowed with the affine-invariant metric
60 | - `manifolds.SPDLogCholesky` - SPD manifold with the Log-Cholesky metric
61 | - `manifolds.SPDLogEuclidean` - SPD manifold with the Log-Euclidean metric
62 | - `manifolds.SpecialOrthogonal` - manifold of rotation matrices
63 | - `manifolds.Sphere` - manifold of unit-normalized points
64 | - `manifolds.StiefelEuclidean` - manifold of orthonormal `p`-frames in the `n`-dimensional space endowed with the Euclidean metric
65 | - `manifolds.StiefelCanonical` - Stiefel manifold with the canonical metric
66 | - `manifolds.StiefelCayley` - Stiefel manifold with the retraction map computed via an iterative Cayley transform
67 | 
68 | 
69 | ### Optimizers
70 | 
71 | Constrained optimization algorithms work as drop-in replacements for Keras
72 | optimizers for sparse and dense updates in both Eager and Graph modes.
73 | 
74 | - `optimizers.RiemannianSGD` - Riemannian Gradient Descent
75 | - `optimizers.RiemannianAdam` - Riemannian Adam and AMSGrad
76 | - `optimizers.ConstrainedRMSProp` - Constrained RMSProp
77 | 
78 | ### Layers
79 | 
80 | - `layers.ManifoldEmbedding` - constrained `keras.layers.Embedding` layer
81 | 
82 | ## Examples
83 | 
84 | - [ANTHEM](https://github.com/amazon-science/hyperbolic-embeddings/tree/main/product_matching) - Choudhary, Nurendra, Nikhil Rao, Sumeet Katariya, Karthik Subbian, and Chandan K. Reddy. "ANTHEM: Attentive Hyperbolic Entity Model for Product Search." In Proceedings of the Fifteenth ACM International Conference on Web Search and Data Mining, 2022.
85 | - [SPDNet](examples/spdnet/) - Huang, Zhiwu, and Luc Van Gool. "A Riemannian network for SPD matrix learning." Proceedings of the Thirty-First AAAI Conference on Artificial Intelligence. AAAI Press, 2017.
86 | - [LieNet](examples/lienet/) - Huang, Zhiwu, et al. "Deep learning on Lie groups for skeleton-based action recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
87 | - [GrNet](examples/grnet/) - Huang, Zhiwu, Jiqing Wu, and Luc Van Gool. "Building Deep Networks on Grassmann Manifolds." AAAI. AAAI Press, 2018.
88 | - [Hyperbolic Neural Network](examples/hyperbolic_nn/) - Ganea, Octavian, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic neural networks." Advances in neural information processing systems. 2018.
89 | - [Poincaré GloVe](examples/poincare_glove/) - Tifrea, Alexandru, Gary Becigneul, and Octavian-Eugen Ganea. "Poincaré Glove: Hyperbolic Word Embeddings." International Conference on Learning Representations. 2018.
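The bundled examples all follow the same recipe: assign selected weights to a manifold and compile the Keras model with a Riemannian optimizer. A minimal sketch of that recipe is shown below; the dataset size, the embedding dimensions, and the Euclidean pooling/classification head are illustrative assumptions, not part of the library:

```python
import tensorflow as tf

from tensorflow_riemopt.layers import ManifoldEmbedding
from tensorflow_riemopt.manifolds import Poincare
from tensorflow_riemopt.optimizers import RiemannianAdam

# Embedding table constrained to the Poincaré ball; variables without
# an assigned manifold receive the standard Euclidean Adam update.
model = tf.keras.Sequential(
    [
        ManifoldEmbedding(10000, 64, manifold=Poincare()),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(2),
    ]
)
model.compile(
    optimizer=RiemannianAdam(1e-3),
    loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.metrics.SparseCategoricalAccuracy()],
)
```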
90 | 
91 | ## References
92 | 
93 | If you find TensorFlow RiemOpt useful in your research, please cite:
94 | 
95 | ```
96 | @misc{smirnov2021tensorflow,
97 |     title={TensorFlow RiemOpt: a library for optimization on Riemannian manifolds},
98 |     author={Oleg Smirnov},
99 |     year={2021},
100 |     eprint={2105.13921},
101 |     archivePrefix={arXiv},
102 |     primaryClass={cs.MS}
103 | }
104 | ```
105 | 
106 | ## Acknowledgment
107 | 
108 | TensorFlow RiemOpt was inspired by many similar projects:
109 | 
110 | - [Manopt](https://www.manopt.org/), a Matlab toolbox for optimization on manifolds
111 | - [Pymanopt](https://www.pymanopt.org/), a Python toolbox for optimization on manifolds
112 | - [Geoopt](https://geoopt.readthedocs.io/), Riemannian optimization in PyTorch
113 | - [Geomstats](https://geomstats.github.io/), an open-source Python package for computations and statistics on nonlinear manifolds
114 | 
115 | ## License
116 | 
117 | The code is MIT-licensed.
118 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | ## Makefile for building the Sphinx documentation
2 | # You can also run sphinx-build directly: sphinx-build -b html . _build/html
3 | 
4 | # Directory for source files
5 | SOURCEDIR = .
6 | # Directory for build output
7 | BUILDDIR = _build
8 | 
9 | .PHONY: help clean html
10 | 
11 | help:
12 | 	@echo "Please use 'make <target>' where <target> is one of"
13 | 	@echo "  html   to build the HTML documentation"
14 | 
15 | html:
16 | 	@sphinx-build -b html $(SOURCEDIR) $(BUILDDIR)/html
17 | 
18 | clean:
19 | 	@rm -rf $(BUILDDIR)
20 | 
--------------------------------------------------------------------------------
/docs/_static/usage.png:
--------------------------------------------------------------------------------
../../examples/usage.png
--------------------------------------------------------------------------------
/docs/api/index.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =============
3 | 
4 | This section contains the API reference for the ``tensorflow_riemopt`` package.
5 | 
6 | .. toctree::
7 |     :maxdepth: 2
8 |     :titlesonly:
9 | 
10 |     variable
11 |     layers
12 |     manifolds
13 |     optimizers
14 | 
--------------------------------------------------------------------------------
/docs/api/layers.rst:
--------------------------------------------------------------------------------
1 | Layers Package
2 | ==============
3 | 
4 | This package provides manifold-aware neural network layers.
5 | 
6 | .. automodule:: tensorflow_riemopt.layers
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/manifolds.rst:
--------------------------------------------------------------------------------
1 | Manifolds Package
2 | =================
3 | 
4 | The module implements a variety of Riemannian manifolds as Python classes. Each manifold class encapsulates core geometric operations, such as distance functions, geodesics, exponential/logarithm maps, projections, and retractions, enabling integration of non-Euclidean geometry into neural network models.
5 | 
6 | Certain manifolds (e.g., SPD and Stiefel) can be endowed with different Riemannian metrics, resulting in distinct operations implemented in their corresponding classes.
7 | 
8 | .. automodule:: tensorflow_riemopt.manifolds
9 |     :members:
10 |     :undoc-members:
11 |     :show-inheritance:
12 | 
--------------------------------------------------------------------------------
/docs/api/optimizers.rst:
--------------------------------------------------------------------------------
1 | Optimizers Package
2 | ==================
3 | 
4 | The module provides Riemannian optimization algorithms for TensorFlow and its Keras API. These optimizers adapt standard methods to manifold settings by projecting gradients onto tangent spaces, retracting updates back onto the manifold, and optionally using vector transport to handle adaptive schemes.
5 | 
6 | .. automodule:: tensorflow_riemopt.optimizers
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/variable.rst:
--------------------------------------------------------------------------------
1 | Variable Utilities
2 | ==================
3 | 
4 | Utilities for associating TensorFlow variables with Riemannian manifolds.
5 | 
6 | .. automodule:: tensorflow_riemopt.variable
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | import os
3 | import sys
4 | 
5 | # Add project root to sys.path for autodoc
6 | sys.path.insert(0, os.path.abspath('..'))
7 | 
8 | project = 'tensorflow-riemopt'
9 | author = 'Oleg Smirnov'
10 | # The full version, including alpha/beta/rc tags
11 | release = '0.3.0'
12 | 
13 | extensions = [
14 |     'sphinx.ext.autodoc',
15 |     'sphinx.ext.napoleon',
16 |     'sphinx.ext.viewcode',
17 |     'sphinx.ext.mathjax',
18 |     'sphinx_autodoc_typehints',
19 |     'myst_parser',
20 | ]
21 | 
22 | templates_path = ['_templates']
23 | # Exclude Jupyter checkpoints
24 | exclude_patterns = ['**.ipynb_checkpoints']
25 | 
26 | # HTML output
27 | html_title = "tensorflow-riemopt documentation"
28 | html_theme = 'pydata_sphinx_theme'
29 | html_static_path = ['_static']
30 | html_theme_options = {
31 |     "primary_sidebar_end": [],
32 |     "icon_links": [
33 |         {
34 |             "name": "GitHub",
35 |             "url": "https://github.com/master/tensorflow-riemopt",
36 |             "icon": "fa-brands fa-square-github",
37 |             "type": "fontawesome",
38 |         }
39 |     ],
40 |     "use_edit_page_button": False,
41 |     "collapse_navigation": True,
42 | }
43 | html_context = {
44 |     "github_user": "master",
45 |     "github_repo": "tensorflow-riemopt",
46 |     "doc_path": "docs",
47 |     "default_mode": "light",
48 | }
49 | 
50 | # Autodoc settings
51 | autodoc_member_order = 'bysource'
52 | autoclass_content = 'both'
53 | 
54 | # Source suffix and master doc
55 | source_suffix = {
56 |     '.rst': 'restructuredtext',
57 |     '.md': 'markdown',
58 | }
59 | master_doc = 'index'
60 | 
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # TensorFlow RiemOpt
2 | 
3 | ```{figure} _static/usage.png
4 | :align: right
5 | :width: 300px
6 | ```
7 | 
8 | **TensorFlow RiemOpt** is a flexible, extensible library for Riemannian optimization and geometric deep learning in TensorFlow.
9 | It provides:
10 | - Riemannian manifold classes with associated exponential and logarithmic maps, geodesics, and transports
11 | - Riemannian optimizers (SGD, RMSProp, and adaptive methods such as Adam)
12 | - Manifold-aware TensorFlow/Keras layers (e.g., Embedding)
13 | 
14 | ## Installation
15 | 
16 | Install via PyPI:
17 | 
18 | ```bash
19 | pip install tensorflow-riemopt
20 | ```
21 | 
22 | ## Quickstart
23 | 
24 | ```python
25 | import tensorflow as tf
26 | from tensorflow_riemopt.optimizers import ConstrainedRMSprop
27 | from tensorflow_riemopt.manifolds import Grassmannian
28 | from tensorflow_riemopt.variable import assign_to_manifold
29 | 
30 | # Create a variable on the Grassmannian manifold (3×2 matrix)
31 | manifold = Grassmannian()
32 | x = tf.Variable(manifold.projx(tf.random.uniform((3, 2))))
33 | assign_to_manifold(x, manifold)
34 | 
35 | # Define a simple loss (squared Frobenius norm)
36 | with tf.GradientTape() as tape:
37 |     loss = tf.reduce_sum(x * x)
38 | 
39 | # Compute gradients and update
40 | grads = tape.gradient(loss, [x])
41 | opt = ConstrainedRMSprop(learning_rate=0.1, rho=0.9)
42 | opt.apply_gradients(zip(grads, [x]))
43 | ```
44 | 
45 | ## Documentation
46 | 
47 | ```{toctree}
48 | :titlesonly:
49 | :glob:
50 | :maxdepth: 1
51 | 
52 | api/*
53 | ```
54 | 
55 | ## Examples
56 | 
57 | The repository provides several fully implemented example network projects:
58 | 
59 | - **GrNet**: Deep networks on Grassmann manifolds. [examples/grnet](https://github.com/master/tensorflow-riemopt/tree/master/examples/grnet)
60 | - **LieNet**: Deep learning on Lie groups for action recognition. [examples/lienet](https://github.com/master/tensorflow-riemopt/tree/master/examples/lienet)
61 | - **SPDNet**: Riemannian network for SPD matrix learning. [examples/spdnet](https://github.com/master/tensorflow-riemopt/tree/master/examples/spdnet)
62 | 
--------------------------------------------------------------------------------
/examples/grnet/README.md:
--------------------------------------------------------------------------------
1 | # GrNet in TensorFlow
2 | 
3 | Implementation of GrNet [1], a deep network on Grassmann manifolds.
4 | 
5 | 
6 | 
7 | ## Requirements
8 | 
9 | * Python 3.10+
10 | * SciPy
11 | * NumPy
12 | * TensorFlow 2.0+
13 | * TensorFlow RiemOpt
14 | 
15 | ## Training
16 | 
17 | Configure `gcloud` to use Python 3:
18 | 
19 | ```bash
20 | gcloud config set ml_engine/local_python /usr/bin/python3
21 | ```
22 | 
23 | Train GrNet locally on the Acted Facial Expressions in the Wild [2] dataset:
24 | 
25 | ```bash
26 | gcloud ai-platform local train \
27 |     --module-name grnet.task \
28 |     --package-path . \
29 |     -- \
30 |     --data-dir data \
31 |     --job-dir ckpt
32 | ```
33 | 
34 | ## References
35 | 
36 | 1. Huang, Zhiwu, Jiqing Wu, and Luc Van Gool. "Building Deep Networks on
37 |    Grassmann Manifolds." AAAI. AAAI Press, 2018.
38 | 2. Dhall, Abhinav, et al. "Acted facial expressions in the wild database."
39 |    Australian National University, Canberra, Australia, Technical Report
40 |    TR-CS-11 2 (2011): 1.
41 | -------------------------------------------------------------------------------- /examples/grnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/grnet/__init__.py -------------------------------------------------------------------------------- /examples/grnet/grnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/grnet/grnet.png -------------------------------------------------------------------------------- /examples/grnet/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow_riemopt.variable import assign_to_manifold 4 | from tensorflow_riemopt.manifolds import Grassmannian 5 | from tensorflow_riemopt.manifolds import utils 6 | from tensorflow_riemopt.optimizers import RiemannianSGD 7 | 8 | 9 | @tf.keras.utils.register_keras_serializable(name="FRMap") 10 | class FRMap(tf.keras.layers.Layer): 11 | """Full Rank Mapping layer.""" 12 | 13 | def __init__(self, output_dim, num_proj=8, *args, **kwargs): 14 | """Instantiate the FRMap layer. 15 | 16 | Args: 17 | output_dim: projection output dimension 18 | num_proj: number of projections to compute 19 | """ 20 | super().__init__(*args, **kwargs) 21 | self.output_dim = output_dim 22 | self.num_proj = num_proj 23 | 24 | def build(self, input_shape): 25 | grassmannian = Grassmannian() 26 | self.w = self.add_weight( 27 | "w", 28 | shape=[self.num_proj, input_shape[-2], self.output_dim], 29 | initializer=grassmannian.random, 30 | ) 31 | assign_to_manifold(self.w, grassmannian) 32 | self._expand = len(input_shape) == 3 33 | 34 | def call(self, inputs): 35 | if self._expand: 36 | inputs = tf.expand_dims(inputs, -3) 37 | return utils.transposem(self.w) @ inputs 38 | 39 | def get_config(self): 40 | config = {"output_dim": self.output_dim, "num_proj": self.num_proj} 41 | return dict(list(super().get_config().items()) + list(config.items())) 42 | 43 | 44 | @tf.keras.utils.register_keras_serializable(name="ReOrth") 45 | class ReOrth(tf.keras.layers.Layer): 46 | """Re-Orthonormalization layer.""" 47 | 48 | def call(self, inputs): 49 | q, _r = tf.linalg.qr(inputs) 50 | return q 51 | 52 | 53 | @tf.keras.utils.register_keras_serializable(name="ProjMap") 54 | class ProjMap(tf.keras.layers.Layer): 55 | """Projection Mapping layer.""" 56 | 57 | def call(self, inputs): 58 | return inputs @ utils.transposem(inputs) 59 | 60 | 61 | @tf.keras.utils.register_keras_serializable(name="ProjPooling") 62 | class ProjPooling(tf.keras.layers.Layer): 63 | """Projection Pooling layer.""" 64 | 65 | def __init__(self, stride=2, *args, **kwargs): 66 | """Instantiate the ProjPooling layer. 
67 | 
68 |         Args:
69 |             stride: factor by which to downscale
70 |         """
71 |         super().__init__(*args, **kwargs)
72 |         self.stride = stride
73 | 
74 |     def call(self, inputs):
75 |         shape = tf.shape(inputs)
76 |         new_shape = [
77 |             shape[0],
78 |             shape[1],
79 |             self.stride,
80 |             shape[2] // self.stride,
81 |             shape[3],
82 |         ]
83 |         return tf.reduce_mean(
84 |             tf.reshape(inputs, new_shape), axis=-3, keepdims=False
85 |         )
86 | 
87 |     def get_config(self):
88 |         config = {"stride": self.stride}
89 |         return dict(list(super().get_config().items()) + list(config.items()))
90 | 
91 | 
92 | @tf.keras.utils.register_keras_serializable(name="OrthMap")
93 | class OrthMap(tf.keras.layers.Layer):
94 |     """Orthonormal Mapping layer."""
95 | 
96 |     def __init__(self, top_eigen, *args, **kwargs):
97 |         """Instantiate the OrthMap layer.
98 | 
99 |         Args:
100 |             top_eigen: number of top eigenvectors to retain
101 |         """
102 |         super().__init__(*args, **kwargs)
103 |         self.top_eigen = top_eigen
104 | 
105 |     def call(self, inputs):
106 |         _s, u, _vt = tf.linalg.svd(inputs)
107 |         return u[..., : self.top_eigen]
108 | 
109 |     def get_config(self):
110 |         config = {"top_eigen": self.top_eigen}
111 |         return dict(list(super().get_config().items()) + list(config.items()))
112 | 
113 | 
114 | def create_model(
115 |     learning_rate,
116 |     num_classes,
117 |     frmap_dims=[300, 100],
118 |     pool_stride=2,
119 |     top_eigen=10,
120 | ):
121 |     """Instantiate the GrNet architecture.
122 | 
123 |     Huang, Zhiwu, Jiqing Wu, and Luc Van Gool. "Building Deep Networks on
124 |     Grassmann Manifolds." AAAI. AAAI Press, 2018.
125 | 
126 |     Args:
127 |         learning_rate: model learning rate
128 |         num_classes: number of output classes
129 |         frmap_dims: dimensions of FRMap layers
130 |         pool_stride: pooling stride
131 |         top_eigen: number of eigenvectors to retain in OrthMap
132 |     """
133 |     model = tf.keras.Sequential()
134 |     for output_dim in frmap_dims:
135 |         model.add(FRMap(output_dim))
136 |         model.add(ReOrth())
137 |         model.add(ProjMap())
138 |         model.add(ProjPooling(pool_stride))
139 |         model.add(OrthMap(top_eigen))
140 |     model.add(ProjMap())
141 |     model.add(tf.keras.layers.Flatten())
142 |     model.add(tf.keras.layers.Dense(num_classes, use_bias=False))
143 |     model.compile(
144 |         optimizer=RiemannianSGD(learning_rate),
145 |         loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
146 |         metrics=[tf.metrics.SparseCategoricalAccuracy()],
147 |     )
148 |     return model
149 | 
--------------------------------------------------------------------------------
/examples/grnet/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | tensorflow-riemopt
3 | 
--------------------------------------------------------------------------------
/examples/grnet/task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import os
4 | import tensorflow as tf
5 | 
6 | import model
7 | from shared import utils
8 | 
9 | DATA_URL = "https://data.vision.ee.ethz.ch/zzhiwu/ManifoldNetData/GrData/AFEW_Gr_data.zip"
10 | DATA_FOLDER = "grface_400_inter_histeq"
11 | AFEW_CLASSES = 7
12 | 
13 | 
14 | def get_args():
15 |     parser = argparse.ArgumentParser()
16 |     parser.add_argument(
17 |         '--job-dir', type=str, required=True, help='checkpoint dir'
18 |     )
19 |     parser.add_argument('--data-dir', type=str, required=True, help='data dir')
20 |     parser.add_argument(
21 |         '--num-epochs',
22 |         type=int,
23 |         default=50,
24 |         help='number of training epochs (default 50)',
25 |     )
26 |     parser.add_argument(
27 |         '--batch-size',
28 |         default=30,
29 |         type=int,
30 |         help='number of examples per batch (default 30)',
31 |     )
32 |     parser.add_argument(
33 |         '--shuffle-buffer',
34 |         default=100,
35 |         type=int,
36 |         help='shuffle buffer size (default 100)',
37 |     )
38 |     parser.add_argument(
39 |         '--learning-rate',
40 |         default=0.01,
41 |         type=float,
42 |         help='learning rate (default .01)',
43 |     )
44 |     return parser.parse_args()
45 | 
46 | 
47 | def train_and_evaluate(args):
48 |     utils.download_data(args.data_dir, DATA_URL, unpack=True)
49 |     train = utils.load_matlab_data("Y1", args.data_dir, DATA_FOLDER, "train")
50 |     val = utils.load_matlab_data("Y1", args.data_dir, DATA_FOLDER, "val")
51 |     train_dataset = (
52 |         tf.data.Dataset.from_tensor_slices(train)
53 |         .repeat(args.num_epochs)
54 |         .shuffle(args.shuffle_buffer)
55 |         .batch(args.batch_size, drop_remainder=True)
56 |     )
57 |     val_dataset = tf.data.Dataset.from_tensor_slices(val).batch(
58 |         args.batch_size, drop_remainder=True
59 |     )
60 | 
61 |     grnet = model.create_model(args.learning_rate, num_classes=AFEW_CLASSES)
62 | 
63 |     os.makedirs(args.job_dir, exist_ok=True)
64 |     checkpoint_path = os.path.join(args.job_dir, "afew-grnet.ckpt")
65 |     cp_callback = tf.keras.callbacks.ModelCheckpoint(
66 |         filepath=checkpoint_path, save_weights_only=True, verbose=1
67 |     )
68 |     log_dir = os.path.join(args.job_dir, "logs")
69 |     tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
70 | 
71 |     grnet.fit(
72 |         train_dataset,
73 |         epochs=args.num_epochs,
74 |         validation_data=val_dataset,
75 |         callbacks=[cp_callback, tb_callback],
76 |     )
77 |     _, acc = grnet.evaluate(val_dataset, verbose=2)
78 |     print("Final accuracy: {}%".format(acc * 100))
79 | 
80 | 
81 | if __name__ == "__main__":
82 |     tf.get_logger().setLevel("INFO")
83 |     train_and_evaluate(get_args())
84 | 
--------------------------------------------------------------------------------
/examples/hyperbolic_nn/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/hyperbolic_nn/.gitkeep
--------------------------------------------------------------------------------
/examples/lienet/README.md:
--------------------------------------------------------------------------------
1 | # LieNet in TensorFlow
2 | 
3 | Implementation of LieNet [1], a deep learning network on Lie Groups for
4 | skeleton-based action recognition.
5 | 
6 | 
7 | 
8 | ## Requirements
9 | 
10 | * Python 3.10+
11 | * SciPy
12 | * NumPy
13 | * TensorFlow 2.0+
14 | * TensorFlow RiemOpt
15 | 
16 | ## Training
17 | 
18 | Configure `gcloud` to use Python 3:
19 | 
20 | ```bash
21 | gcloud config set ml_engine/local_python /usr/bin/python3
22 | ```
23 | 
24 | Train LieNet locally on the G3D-Gaming [2] dataset:
25 | 
26 | ```bash
27 | gcloud ai-platform local train \
28 |     --module-name lienet.task \
29 |     --package-path . \
30 |     -- \
31 |     --data-dir data \
32 |     --job-dir ckpt
33 | ```
34 | 
35 | ## References
36 | 
37 | 1. Huang, Zhiwu, et al. "Deep learning on Lie groups for skeleton-based
38 |    action recognition." Proceedings of the IEEE conference on computer vision
39 |    and pattern recognition. 2017.
40 | 
41 | 2. Bloom, Victoria, Dimitrios Makris, and Vasileios Argyriou. "G3D: A gaming
42 |    action dataset and real time action recognition evaluation framework." 2012
43 |    IEEE Computer Society Conference on Computer Vision and Pattern Recognition
44 |    Workshops. IEEE, 2012.
45 | -------------------------------------------------------------------------------- /examples/lienet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/lienet/__init__.py -------------------------------------------------------------------------------- /examples/lienet/lienet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/lienet/lienet.png -------------------------------------------------------------------------------- /examples/lienet/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow_riemopt.variable import assign_to_manifold 4 | from tensorflow_riemopt.manifolds import SpecialOrthogonal 5 | from tensorflow_riemopt.manifolds import utils 6 | from tensorflow_riemopt.optimizers import RiemannianSGD 7 | 8 | 9 | EPS = 1e-7 10 | 11 | 12 | def rot_angle(inputs): 13 | cos = (tf.linalg.trace(inputs) - 1) / 2.0 14 | return tf.math.acos(tf.clip_by_value(cos, -1.0, 1.0)) 15 | 16 | 17 | @tf.keras.utils.register_keras_serializable(name="RotMap") 18 | class RotMap(tf.keras.layers.Layer): 19 | """Rotation Mapping layer.""" 20 | 21 | def build(self, input_shape): 22 | """Create weights depending on the shape of the inputs. 23 | 24 | Expected `input_shape`: 25 | `[batch_size, spatial_dim, temp_dim, num_rows, num_cols]`, where 26 | `num_rows` = 3, `num_cols` = 3, `temp_dim` is the number of frames, 27 | and `spatial_dim` is the number of edges. 28 | """ 29 | input_shape = input_shape.as_list() 30 | so = SpecialOrthogonal() 31 | self.w = self.add_weight( 32 | "w", 33 | shape=[input_shape[-4], input_shape[-2], input_shape[-1]], 34 | initializer=so.random, 35 | ) 36 | assign_to_manifold(self.w, so) 37 | 38 | def call(self, inputs): 39 | return tf.einsum("sij,...stjk->...stik", self.w, inputs) 40 | 41 | 42 | @tf.keras.utils.register_keras_serializable(name="RotPooling") 43 | class RotPooling(tf.keras.layers.Layer): 44 | """Rotation Pooling layer.""" 45 | 46 | def __init__(self, pooling, *args, **kwargs): 47 | """Instantiate the RotPooling layer. 
48 | 
49 |         Args:
50 |             pooling: `spatial` or `temporal` pooling type
51 |         """
52 |         super().__init__(*args, **kwargs)
53 |         if not pooling in ["spatial", "temporal"]:
54 |             raise ValueError("Invalid pooling type {}".format(pooling))
55 |         self.pooling = pooling
56 | 
57 |     def call(self, inputs):
58 |         if self.pooling == "spatial":
59 |             stride = 2
60 |             # reshape to [batch_size, temp_dim, spatial_dim, num_rows, num_cols]
61 |             inputs = tf.transpose(inputs, perm=[0, 2, 1, 3, 4])
62 |         elif self.pooling == "temporal":
63 |             stride = 4
64 |             # pad temp_dim to a multiple of stride
65 |             temp_dim = tf.shape(inputs)[2]
66 |             temp_pad = (
67 |                 tf.cast(tf.math.ceil(temp_dim / stride), tf.int32) * stride
68 |                 - temp_dim
69 |             )
70 |             padding = [[0, 0], [0, 0], [0, temp_pad], [0, 0], [0, 0]]
71 |             inputs = tf.pad(inputs, padding)
72 |         shape = tf.shape(inputs)
73 |         new_shape = [
74 |             shape[0],
75 |             shape[1],
76 |             shape[2] // stride,
77 |             stride,
78 |             shape[3],
79 |             shape[4],
80 |         ]
81 |         inputs = tf.reshape(inputs, new_shape)
82 |         thetas = rot_angle(inputs)
83 |         indices = tf.math.argmax(thetas, axis=-1)
84 |         results = tf.gather(inputs, indices, batch_dims=3)
85 |         if self.pooling == "spatial":
86 |             results = tf.transpose(results, perm=[0, 2, 1, 3, 4])
87 |         return results
88 | 
89 |     def get_config(self):
90 |         config = {"pooling": self.pooling}
91 |         return dict(list(super().get_config().items()) + list(config.items()))
92 | 
93 | 
94 | @tf.keras.utils.register_keras_serializable(name="LogMap")
95 | class LogMap(tf.keras.layers.Layer):
96 |     """Logarithmic Map layer."""
97 | 
98 |     def call(self, inputs):
99 |         thetas = rot_angle(inputs)[..., tf.newaxis, tf.newaxis]
100 |         zeros = tf.zeros_like(thetas)
101 |         skew = (inputs - utils.transposem(inputs)) / 2.0
102 |         log = skew * thetas / tf.math.sin(thetas)
103 |         return tf.where(thetas < EPS, zeros, log)
104 | 
105 | 
106 | def create_model(
107 |     learning_rate,
108 |     num_classes,
109 |     pooling_types=["spatial", "temporal", "temporal"],
110 | ):
111 |     """Instantiate the LieNet architecture.
112 | 
113 |     Huang, Zhiwu, et al. "Deep learning on Lie groups for skeleton-based
114 |     action recognition." Proceedings of the IEEE conference on computer vision
115 |     and pattern recognition. 2017.
116 | 
117 |     Args:
118 |         learning_rate: model learning rate
119 |         num_classes: number of output classes
120 |         pooling_types: types of pooling layers in RotMap/RotPooling blocks
121 |     """
122 |     model = tf.keras.Sequential()
123 |     for pooling in pooling_types:
124 |         model.add(RotMap())
125 |         model.add(RotPooling(pooling))
126 |     model.add(LogMap())
127 |     model.add(tf.keras.layers.Flatten())
128 |     model.add(tf.keras.layers.Dense(num_classes, use_bias=False))
129 |     model.compile(
130 |         optimizer=RiemannianSGD(learning_rate),
131 |         loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
132 |         metrics=[tf.metrics.SparseCategoricalAccuracy()],
133 |     )
134 |     return model
135 | 
--------------------------------------------------------------------------------
/examples/lienet/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | tensorflow-riemopt
3 | 
--------------------------------------------------------------------------------
/examples/lienet/task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import os
4 | import tensorflow as tf
5 | import numpy as np
6 | 
7 | import model
8 | from shared import utils
9 | 
10 | DATA_URL = "https://data.vision.ee.ethz.ch/zzhiwu/ManifoldNetData/LieData/G3D_Lie_data.zip"
11 | DATA_FOLDER = "lie20_half_inter1"
12 | G3D_CLASSES = 20
13 | VAL_SPLIT = 0.2
14 | 
15 | 
16 | def get_args():
17 |     parser = argparse.ArgumentParser()
18 |     parser.add_argument(
19 |         '--job-dir', type=str, required=True, help='checkpoint dir'
20 |     )
21 |     parser.add_argument('--data-dir', type=str, required=True, help='data dir')
22 |     parser.add_argument(
23 |         '--num-epochs',
24 |         type=int,
25 |         default=100,
26 |         help='number of training epochs (default 100)',
27 |     )
28 |     parser.add_argument(
29 |         '--batch-size',
30 |         default=30,
31 |         type=int,
32 |         help='number of examples per batch (default 30)',
33 |     )
34 |     parser.add_argument(
35 |         '--shuffle-buffer',
36 |         default=100,
37 |         type=int,
38 |         help='shuffle buffer size (default 100)',
39 |     )
40 |     parser.add_argument(
41 |         '--learning-rate',
42 |         default=0.01,
43 |         type=float,
44 |         help='learning rate (default .01)',
45 |     )
46 |     return parser.parse_args()
47 | 
48 | 
49 | def prepare_data(args):
50 |     features, labels = utils.load_matlab_data("fea", args.data_dir, DATA_FOLDER)
51 |     features = np.array([np.stack(example) for example in features.squeeze()])
52 |     # reshape to [batch_size, spatial_dim, temp_dim, num_rows, num_cols]
53 |     features = np.transpose(features, axes=[0, 1, 4, 2, 3])
54 |     indices = np.random.permutation(len(features))
55 |     features, labels = features[indices], labels[indices]
56 |     val_len = int(len(features) * VAL_SPLIT)
57 |     X_train, X_val = features[:-val_len, ...], features[-val_len:, ...]
58 |     y_train, y_val = labels[:-val_len, ...], labels[-val_len:, ...]
59 | return (X_train, y_train), (X_val, y_val) 60 | 61 | 62 | def train_and_evaluate(args): 63 | utils.download_data(args.data_dir, DATA_URL, unpack=True) 64 | train, val = prepare_data(args) 65 | 66 | train_dataset = ( 67 | tf.data.Dataset.from_tensor_slices(train) 68 | .repeat(args.num_epochs) 69 | .shuffle(args.shuffle_buffer) 70 | .batch(args.batch_size, drop_remainder=True) 71 | ) 72 | val_dataset = tf.data.Dataset.from_tensor_slices(val).batch( 73 | args.batch_size, drop_remainder=True 74 | ) 75 | 76 | lienet = model.create_model(args.learning_rate, num_classes=G3D_CLASSES) 77 | 78 | os.makedirs(args.job_dir, exist_ok=True) 79 | checkpoint_path = os.path.join(args.job_dir, "g3d-lienet.ckpt") 80 | cp_callback = tf.keras.callbacks.ModelCheckpoint( 81 | filepath=checkpoint_path, save_weights_only=True, verbose=1 82 | ) 83 | log_dir = os.path.join(args.job_dir, "logs") 84 | tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir) 85 | 86 | lienet.fit( 87 | train_dataset, 88 | epochs=args.num_epochs, 89 | validation_data=val_dataset, 90 | callbacks=[cp_callback, tb_callback], 91 | ) 92 | _, acc = lienet.evaluate(val_dataset, verbose=2) 93 | print("Final accuracy: {}%".format(acc * 100)) 94 | 95 | 96 | if __name__ == "__main__": 97 | tf.get_logger().setLevel("INFO") 98 | train_and_evaluate(get_args()) 99 | -------------------------------------------------------------------------------- /examples/poincare_glove/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/poincare_glove/.gitkeep -------------------------------------------------------------------------------- /examples/shared/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/shared/__init__.py -------------------------------------------------------------------------------- /examples/shared/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import zipfile 3 | import os 4 | import glob 5 | from pathlib import Path 6 | 7 | import tqdm 8 | import scipy.io 9 | import numpy as np 10 | 11 | 12 | def download_data(data_dir, url, unpack=True, block_size=10 * 1024): 13 | filename = os.path.join(data_dir, os.path.basename(url)) 14 | os.makedirs(data_dir, exist_ok=True) 15 | 16 | if os.path.exists(filename): 17 | print("{} already exists. 
Skipping download".format(filename))
18 |         return
19 | 
20 |     print("Downloading {} to {}".format(url, filename))
21 |     response = requests.get(url, stream=True)
22 |     total = int(response.headers.get("content-length", 0))
23 |     progress_bar = tqdm.tqdm(total=total, unit="iB", unit_scale=True)
24 |     with open(filename, "wb") as f:
25 |         for data in response.iter_content(block_size):
26 |             progress_bar.update(len(data))
27 |             f.write(data)
28 |     progress_bar.close()
29 | 
30 |     if total != 0 and progress_bar.n != total:
31 |         raise RuntimeError("Error downloading {}".format(url))
32 | 
33 |     if unpack and filename[-3:] == "zip":
34 |         with open(filename, "rb") as f:
35 |             with zipfile.ZipFile(f) as zip_ref:
36 |                 zip_ref.extractall(data_dir)
37 |         print("Unzipped {} to {}".format(filename, data_dir))
38 | 
39 | 
40 | def load_matlab_data(key, data_dir, *folders):
41 |     folders = [data_dir] + list(folders) + ["*/*.mat"]
42 |     examples, labels = [], []
43 |     for filename in glob.glob(os.path.join(*folders)):
44 |         examples.append(scipy.io.loadmat(filename)[key])
45 |         labels.append(int(Path(filename).parts[-2]))
46 |     return np.stack(examples), np.array(labels) - 1
47 | 
--------------------------------------------------------------------------------
/examples/spdnet/README.md:
--------------------------------------------------------------------------------
1 | # SPDNet in TensorFlow
2 | 
3 | Implementation of SPDNet [1], a Riemannian network for SPD matrix learning.
4 | 
5 | 
6 | 
7 | ## Requirements
8 | 
9 | * Python 3.10+
10 | * SciPy
11 | * NumPy
12 | * TensorFlow 2.0+
13 | * TensorFlow RiemOpt
14 | 
15 | ## Training
16 | 
17 | Configure `gcloud` to use Python 3:
18 | 
19 | ```bash
20 | gcloud config set ml_engine/local_python /usr/bin/python3
21 | ```
22 | 
23 | Train SPDNet locally on the Acted Facial Expressions in the Wild [2] dataset:
24 | 
25 | ```bash
26 | gcloud ai-platform local train \
27 |     --module-name spdnet.task \
28 |     --package-path . \
29 |     -- \
30 |     --data-dir data \
31 |     --job-dir ckpt
32 | ```
33 | 
34 | ## References
35 | 
36 | 1. Huang, Zhiwu, and Luc Van Gool. "A Riemannian network for SPD matrix
37 |    learning." Proceedings of the Thirty-First AAAI Conference on Artificial
38 |    Intelligence. AAAI Press, 2017.
39 | 2. Dhall, Abhinav, et al. "Acted facial expressions in the wild database."
40 |    Australian National University, Canberra, Australia, Technical Report
41 |    TR-CS-11 2 (2011): 1.
42 | 
--------------------------------------------------------------------------------
/examples/spdnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/spdnet/__init__.py
--------------------------------------------------------------------------------
/examples/spdnet/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | from tensorflow_riemopt.variable import assign_to_manifold
4 | from tensorflow_riemopt.manifolds import StiefelEuclidean
5 | from tensorflow_riemopt.manifolds import utils
6 | from tensorflow_riemopt.optimizers import RiemannianSGD
7 | 
8 | 
9 | @tf.keras.utils.register_keras_serializable(name="BiMap")
10 | class BiMap(tf.keras.layers.Layer):
11 |     """Bilinear Mapping layer."""
12 | 
13 |     def __init__(self, output_dim, *args, **kwargs):
14 |         """Instantiate the BiMap layer.
15 | 
16 |         Args:
17 |             output_dim: projection output dimension
18 |         """
19 |         super().__init__(*args, **kwargs)
20 |         self.output_dim = output_dim
21 | 
22 |     def build(self, input_shape):
23 |         # Initialize on the manifold so optimization starts from a valid point
24 |         stiefel = StiefelEuclidean()
25 |         self.w = self.add_weight(
26 |             "w",
27 |             shape=[int(input_shape[-1]), self.output_dim],
28 |             initializer=stiefel.random,
29 |         )
30 |         assign_to_manifold(self.w, stiefel)
31 | 
32 |     def call(self, inputs):
33 |         return utils.transposem(self.w) @ inputs @ self.w
34 | 
35 |     def get_config(self):
36 |         config = {"output_dim": self.output_dim}
37 |         return dict(list(super().get_config().items()) + list(config.items()))
38 | 
39 | 
40 | @tf.keras.utils.register_keras_serializable(name="ReEig")
41 | class ReEig(tf.keras.layers.Layer):
42 |     """Eigen Rectifier layer."""
43 | 
44 |     def __init__(self, epsilon=1e-4, *args, **kwargs):
45 |         """Instantiate the ReEig layer.
46 | 
47 |         Args:
48 |             epsilon: a rectification threshold value
49 |         """
50 |         super().__init__(*args, **kwargs)
51 |         self.epsilon = epsilon
52 | 
53 |     def call(self, inputs):
54 |         s, u, v = tf.linalg.svd(inputs)
55 |         sigma = tf.maximum(s, self.epsilon)
56 |         return u @ tf.linalg.diag(sigma) @ utils.transposem(v)
57 | 
58 |     def get_config(self):
59 |         config = {"epsilon": self.epsilon}
60 |         return dict(list(super().get_config().items()) + list(config.items()))
61 | 
62 | 
63 | @tf.keras.utils.register_keras_serializable(name="LogEig")
64 | class LogEig(tf.keras.layers.Layer):
65 |     """Eigen Log layer."""
66 | 
67 |     def call(self, inputs):
68 |         s, u, v = tf.linalg.svd(inputs)
69 |         log_s = tf.math.log(s)
70 |         return u @ tf.linalg.diag(log_s) @ utils.transposem(v)
71 | 
72 | 
73 | def create_model(
74 |     learning_rate, num_classes, bimap_dims=[200, 100, 50], eig_eps=1e-4
75 | ):
76 |     """Instantiate the SPDNet architecture.
77 | 
78 |     Huang, Zhiwu, and Luc Van Gool. "A Riemannian network for SPD matrix
79 |     learning." Proceedings of the Thirty-First AAAI Conference on Artificial
80 |     Intelligence. 2017.
81 | 
82 |     Args:
83 |         learning_rate: model learning rate
84 |         num_classes: number of output classes
85 |         bimap_dims: dimensions of BiMap layers
86 |         eig_eps: a rectification threshold value
87 |     """
88 | 
89 |     model = tf.keras.Sequential()
90 |     for output_dim in bimap_dims:
91 |         model.add(BiMap(output_dim))
92 |         model.add(ReEig(eig_eps))
93 |     model.add(LogEig())
94 |     model.add(tf.keras.layers.Flatten())
95 |     model.add(tf.keras.layers.Dense(num_classes, use_bias=False))
96 |     model.compile(
97 |         optimizer=RiemannianSGD(learning_rate),
98 |         loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
99 |         metrics=[tf.metrics.SparseCategoricalAccuracy()],
100 |     )
101 |     return model
102 | 
--------------------------------------------------------------------------------
/examples/spdnet/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | tensorflow-riemopt
3 | 
--------------------------------------------------------------------------------
/examples/spdnet/spdnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/spdnet/spdnet.png
--------------------------------------------------------------------------------
/examples/spdnet/task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import os
4 | import tensorflow as tf
5 | 
6 | from . import model
7 | from shared import utils
8 | 
9 | DATA_URL = "https://data.vision.ee.ethz.ch/zzhiwu/ManifoldNetData/SPDData/AFEW_SPD_data.zip"
10 | DATA_FOLDER = "spdface_400_inter_histeq"
11 | AFEW_CLASSES = 7
12 | 
13 | 
14 | def get_args():
15 |     parser = argparse.ArgumentParser()
16 |     parser.add_argument(
17 |         '--job-dir', type=str, required=True, help='checkpoint dir'
18 |     )
19 |     parser.add_argument('--data-dir', type=str, required=True, help='data dir')
20 |     parser.add_argument(
21 |         '--num-epochs',
22 |         type=int,
23 |         default=50,
24 |         help='number of training epochs (default 50)',
25 |     )
26 |     parser.add_argument(
27 |         '--batch-size',
28 |         default=30,
29 |         type=int,
30 |         help='number of examples per batch (default 30)',
31 |     )
32 |     parser.add_argument(
33 |         '--shuffle-buffer',
34 |         default=100,
35 |         type=int,
36 |         help='shuffle buffer size (default 100)',
37 |     )
38 |     parser.add_argument(
39 |         '--learning-rate',
40 |         default=0.01,
41 |         type=float,
42 |         help='learning rate (default .01)',
43 |     )
44 |     return parser.parse_args()
45 | 
46 | 
47 | def train_and_evaluate(args):
48 |     utils.download_data(args.data_dir, DATA_URL, unpack=True)
49 |     train = utils.load_matlab_data("Y1", args.data_dir, DATA_FOLDER, "train")
50 |     val = utils.load_matlab_data("Y1", args.data_dir, DATA_FOLDER, "val")
51 |     train_dataset = (
52 |         tf.data.Dataset.from_tensor_slices(train)
53 |         .repeat(args.num_epochs)
54 |         .shuffle(args.shuffle_buffer)
55 |         .batch(args.batch_size, drop_remainder=True)
56 |     )
57 |     val_dataset = tf.data.Dataset.from_tensor_slices(val).batch(
58 |         args.batch_size, drop_remainder=True
59 |     )
60 | 
61 |     spdnet = model.create_model(args.learning_rate, num_classes=AFEW_CLASSES)
62 | 
63 |     os.makedirs(args.job_dir, exist_ok=True)
64 |     checkpoint_path = os.path.join(args.job_dir, "afew-spdnet.ckpt")
65 |     cp_callback = tf.keras.callbacks.ModelCheckpoint(
66 |         filepath=checkpoint_path, save_weights_only=True, verbose=1
67 |     )
68 |     log_dir = os.path.join(args.job_dir, "logs")
69 |     tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
70 | 
71 |     spdnet.fit(
72 |         train_dataset,
73 |         epochs=args.num_epochs,
74 |         validation_data=val_dataset,
75 |         callbacks=[cp_callback, tb_callback],
76 |     )
77 |     _, acc = spdnet.evaluate(val_dataset, verbose=2)
78 |     print("Final accuracy: {}%".format(acc * 100))
79 | 
80 | 
81 | if __name__ == "__main__":
82 |     tf.get_logger().setLevel("INFO")
83 |     train_and_evaluate(get_args())
84 | 
--------------------------------------------------------------------------------
/examples/usage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/examples/usage.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow
2 | 
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | pip>=20.0.0
2 | black
3 | coveralls
4 | pytest-cov
5 | sphinx
6 | myst-parser
7 | sphinx-autodoc-typehints
8 | pydata-sphinx-theme
9 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description_file = README.md
3 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from setuptools import setup, find_packages
3 | 
4 | setup(
5 |     name="tensorflow-riemopt",
6 |     version="0.3.0",
7 |     description="a library for optimization on Riemannian manifolds",
8 |     long_description=Path("README.md").read_text(),
9 |     long_description_content_type="text/markdown",
10 |     author="Oleg Smirnov",
11 |     author_email="oleg.smirnov@gmail.com",
12 |     packages=find_packages(),
13 |     install_requires=["tensorflow"],
14 |     python_requires=">=3.10.0",
15 |     url="https://github.com/master/tensorflow-riemopt",
16 |     zip_safe=True,
17 |     classifiers=[
18 |         "Development Status :: 4 - Beta",
19 |         "Intended Audience :: Science/Research",
20 |         "License :: OSI Approved :: MIT License",
21 |         "Programming Language :: Python :: 3",
22 |         "Programming Language :: Python :: 3.10",
23 |         "Programming Language :: Python :: 3.11",
24 |         "Programming Language :: Python :: 3.12",
25 |         "Topic :: Scientific/Engineering :: Mathematics",
26 |         "Topic :: Software Development :: Libraries :: Python Modules",
27 |         "Topic :: Software Development :: Libraries",
28 |         "Operating System :: OS Independent",
29 |     ],
30 |     keywords="tensorflow optimization machine learning",
31 | )
32 | 
--------------------------------------------------------------------------------
/tensorflow_riemopt/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_riemopt import manifolds, optimizers
2 | 
3 | __all__ = ["manifolds", "optimizers"]
4 | 
--------------------------------------------------------------------------------
/tensorflow_riemopt/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_riemopt.layers.embeddings import ManifoldEmbedding
2 | 
3 | __all__ = ["ManifoldEmbedding"]
4 | 
--------------------------------------------------------------------------------
/tensorflow_riemopt/layers/embeddings.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | from tensorflow_riemopt.variable import assign_to_manifold
4 | 
5 | 
6 | @tf.keras.utils.register_keras_serializable(name="ManifoldEmbedding")
7 | class ManifoldEmbedding(tf.keras.layers.Embedding):
8 |     def __init__(self, *args, manifold, **kwargs):
9 |         super().__init__(*args, **kwargs)
10 |         self.manifold = manifold
11 | 
12 |     def build(self, input_shape):
13 |         super().build(input_shape)
14 |         assign_to_manifold(self.embeddings, self.manifold)
15 | 
16 |     def get_config(self):
17 |         config = {"manifold": self.manifold}
18 |         return dict(list(super().get_config().items()) + list(config.items()))
19 | 
--------------------------------------------------------------------------------
/tensorflow_riemopt/layers/embeddings_test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | from tensorflow_riemopt.layers import ManifoldEmbedding
4 | from tensorflow_riemopt.manifolds import Grassmannian
5 | from tensorflow_riemopt.variable import get_manifold
6 | 
7 | 
8 | class EmbeddingsTest(tf.test.TestCase):
9 |     def test_layer(self):
10 |         grassmannian = Grassmannian()
11 |         with self.cached_session(use_gpu=True):
12 |             inp = tf.keras.Input((5, 2))
13 |             layer = ManifoldEmbedding(1000, 64, manifold=grassmannian)
14 |             _ = layer(inp)
15 |             self.assertEqual(
16 |                 type(get_manifold(layer.embeddings)), type(grassmannian)
17 |             )
18 |             config = layer.get_config()
19 |             from_config = ManifoldEmbedding(**config)
20 |             _ = from_config(inp)
21 |             self.assertEqual(
22 |                 type(get_manifold(from_config.embeddings)), type(grassmannian)
23 |             )
24 | 
--------------------------------------------------------------------------------
/tensorflow_riemopt/manifolds/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_riemopt.manifolds.cholesky import Cholesky
2 | from tensorflow_riemopt.manifolds.euclidean import Euclidean
3 | from tensorflow_riemopt.manifolds.grassmannian import Grassmannian
4 | from tensorflow_riemopt.manifolds.hyperboloid import Hyperboloid
5 | from tensorflow_riemopt.manifolds.manifold import Manifold
6 | from tensorflow_riemopt.manifolds.poincare import Poincare
7 | from tensorflow_riemopt.manifolds.product import Product
8 | from tensorflow_riemopt.manifolds.special_orthogonal import SpecialOrthogonal
9 | from tensorflow_riemopt.manifolds.sphere import Sphere
10 | from tensorflow_riemopt.manifolds.stiefel import StiefelCanonical
11 | from tensorflow_riemopt.manifolds.stiefel import StiefelCayley
12 | from tensorflow_riemopt.manifolds.stiefel import StiefelEuclidean
13 | from tensorflow_riemopt.manifolds.symmetric_positive import SPDAffineInvariant
14 | from tensorflow_riemopt.manifolds.symmetric_positive import SPDLogCholesky
15 | from tensorflow_riemopt.manifolds.symmetric_positive import SPDLogEuclidean
16 | 
17 | __all__ = [
18 |     "Cholesky",
19 |     "Euclidean",
20 |     "Grassmannian",
21 |     "Hyperboloid",
22 |     "Manifold",
23 |     "Poincare",
24 |     "Product",
25 |     "SPDAffineInvariant",
26 |     "SPDLogCholesky",
27 |     "SPDLogEuclidean",
28 |     "SpecialOrthogonal",
29 |     "Sphere",
30 |     "StiefelCanonical",
31 |     "StiefelCayley",
32 |     "StiefelEuclidean",
33 | ]
34 | 
--------------------------------------------------------------------------------
/tensorflow_riemopt/manifolds/approximate_mixin.py:
--------------------------------------------------------------------------------
1 | """Numerical approximations."""
2 | 
3 | 
4 | class ApproximateMixin:
5 |     def ladder_transp(self, x, y, v, method, n_steps):
6 |         """Perform an approximate parallel transport.
7 | 
8 |         Lorenzi, Marco, and Xavier Pennec. "Parallel transport with Pole
9 |         ladder: Application to deformations of time series of images."
10 |         International Conference on Geometric Science of
11 |         Information. Springer, Berlin, Heidelberg, 2013.
12 | 13 | Args: 14 | method: either "pole" or "schild" transport algorithm 15 | n_steps: number of iterations 16 | """ 17 | if method not in ["pole", "schild"]: 18 | raise ValueError("Invalid transport method {}".format(method)) 19 | if n_steps <= 0: 20 | raise ValueError("n_steps should be greater than zero") 21 | u = self.log(x, y) 22 | v_i = v / n_steps 23 | x_i = self.exp(x, v_i) 24 | x_prev = x 25 | for i in range(1, n_steps + 1): 26 | u_i = u * i / n_steps 27 | y_i = self.exp(x, u_i) 28 | if method == "pole": 29 | u_1 = self.log(x_prev, y_i) / 2.0 30 | half_geo = self.exp(x_prev, u_1) 31 | u_2 = -self.log(half_geo, x_i) 32 | end_geo = self.exp(half_geo, u_2) 33 | v_i = -self.log(y_i, end_geo) 34 | x_i = self.exp(y_i, v_i) 35 | elif method == "schild": 36 | u_1 = self.log(x_i, y_i) / 2.0 37 | half_geo = self.exp(x_i, u_1) 38 | u_2 = -self.log(half_geo, x_prev) 39 | end_geo = self.exp(half_geo, u_2) 40 | v_i = self.log(y_i, end_geo) 41 | x_i = end_geo 42 | x_prev = y_i 43 | return v_i * n_steps 44 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/approximate_mixin_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import TestInvariants 7 | from tensorflow_riemopt.manifolds.sphere import Sphere 8 | from tensorflow_riemopt.manifolds.special_orthogonal import SpecialOrthogonal 9 | from tensorflow_riemopt.manifolds.approximate_mixin import ApproximateMixin 10 | 11 | 12 | class SpherePoleLadder(ApproximateMixin, Sphere): 13 | def transp(self, x, y, v): 14 | return self.ladder_transp(x, y, v, method="pole", n_steps=1) 15 | 16 | 17 | class SphereSchildLadder(ApproximateMixin, Sphere): 18 | def transp(self, x, y, v): 19 | return self.ladder_transp(x, y, v, method="schild", n_steps=1) 20 | 21 | 22 | @combinations.generate( 23 | combinations.combine( 24 | mode=["graph", "eager"], 25 | manifold=[SpherePoleLadder()], 26 | shape=[(3,), (3, 3)], 27 | dtype=[tf.float32, tf.float64], 28 | ) 29 | ) 30 | class PoleLadderTest(tf.test.TestCase, parameterized.TestCase): 31 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 32 | 33 | test_ptransp_inner = TestInvariants.check_ptransp_inner 34 | 35 | 36 | @combinations.generate( 37 | combinations.combine( 38 | mode=["graph", "eager"], 39 | manifold=[SphereSchildLadder()], 40 | shape=[(3,), (3, 3)], 41 | dtype=[tf.float32, tf.float64], 42 | ) 43 | ) 44 | class SchildLadderTest(tf.test.TestCase, parameterized.TestCase): 45 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 46 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/cholesky.py: -------------------------------------------------------------------------------- 1 | """Manifold of the Cholesky space.""" 2 | 3 | import tensorflow as tf 4 | 5 | from tensorflow_riemopt.manifolds.manifold import Manifold 6 | from tensorflow_riemopt.manifolds import utils 7 | 8 | 9 | class Cholesky(Manifold): 10 | """Manifold of lower triangular matrices with positive diagonal elements. 11 | 12 | Lin, Zhenhua. "Riemannian Geometry of Symmetric Positive Definite Matrices 13 | via Cholesky Decomposition." SIAM Journal on Matrix Analysis and 14 | Applications 40.4 (2019): 1353-1370.
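Example (a minimal sketch; the matrix entries are illustrative):

>>> import tensorflow as tf
>>> chol = Cholesky()
>>> x = tf.constant([[1.0, 0.0], [0.5, 2.0]])  # lower triangular, positive diagonal
>>> v = chol.log(x, x)  # zero tangent vector at x
>>> y = chol.exp(x, v)  # recovers x, since exp and log are mutually inverse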
15 | 16 | """ 17 | 18 | name = "Cholesky" 19 | ndims = 2 20 | 21 | def _check_shape(self, shape): 22 | return shape[-1] == shape[-2] 23 | 24 | def _check_point_on_manifold(self, x, atol, rtol): 25 | lower_triang = tf.linalg.band_part(x, -1, 0) 26 | is_lower_triang = utils.allclose(x, lower_triang, atol, rtol) 27 | diag = tf.linalg.diag_part(x) 28 | is_pos_diag = utils.allclose(diag, tf.abs(diag), atol, rtol) 29 | return is_lower_triang & is_pos_diag 30 | 31 | def _check_vector_on_tangent(self, x, u, atol, rtol): 32 | lower_triang = tf.linalg.band_part(u, -1, 0) 33 | return utils.allclose(u, lower_triang, atol, rtol) 34 | 35 | def _diag_and_strictly_lower(self, x): 36 | x = tf.linalg.band_part(x, -1, 0) 37 | diag = tf.linalg.diag_part(x) 38 | return diag, x - tf.linalg.diag(diag) 39 | 40 | def projx(self, x): 41 | x_sym = (utils.transposem(x) + x) / 2.0 42 | s, _u, v = tf.linalg.svd(x_sym) 43 | sigma = tf.linalg.diag(tf.maximum(s, 0.0)) 44 | spd = v @ sigma @ utils.transposem(v) 45 | return tf.linalg.cholesky(spd) 46 | 47 | def proju(self, x, u): 48 | u_sym = (utils.transposem(u) + u) / 2.0 49 | u_diag, u_lower = self._diag_and_strictly_lower(u_sym) 50 | x_diag = tf.linalg.diag_part(x) 51 | return u_lower + tf.linalg.diag(u_diag * x_diag**2) 52 | 53 | def inner(self, x, u, v, keepdims=False): 54 | u_diag, u_lower = self._diag_and_strictly_lower(u) 55 | v_diag, v_lower = self._diag_and_strictly_lower(v) 56 | x_diag = tf.linalg.diag_part(x) 57 | lower = tf.reduce_sum( 58 | u_lower * v_lower, axis=[-2, -1], keepdims=keepdims 59 | ) 60 | diag = tf.reduce_sum( 61 | u_diag * v_diag / tf.math.square(x_diag), axis=[-1] 62 | ) 63 | return lower + tf.reshape(diag, lower.shape) 64 | 65 | def geodesic(self, x, u, t): 66 | x_diag, x_lower = self._diag_and_strictly_lower(x) 67 | u_diag, u_lower = self._diag_and_strictly_lower(u) 68 | diag = tf.linalg.diag(x_diag * tf.math.exp(t * u_diag / x_diag)) 69 | return x_lower + t * u_lower + diag 70 | 71 | def exp(self, x, u): 72 | x_diag, x_lower = self._diag_and_strictly_lower(x) 73 | u_diag, u_lower = self._diag_and_strictly_lower(u) 74 | diag = tf.linalg.diag(x_diag * tf.math.exp(u_diag / x_diag)) 75 | return x_lower + u_lower + diag 76 | 77 | retr = exp 78 | 79 | def log(self, x, y): 80 | x_diag, x_lower = self._diag_and_strictly_lower(x) 81 | y_diag, y_lower = self._diag_and_strictly_lower(y) 82 | diag = tf.linalg.diag(x_diag * tf.math.log(y_diag / x_diag)) 83 | return y_lower - x_lower + diag 84 | 85 | def dist(self, x, y, keepdims=False): 86 | x_diag, x_lower = self._diag_and_strictly_lower(x) 87 | y_diag, y_lower = self._diag_and_strictly_lower(y) 88 | lower = tf.linalg.norm( 89 | x_lower - y_lower, axis=[-2, -1], ord="fro", keepdims=keepdims 90 | ) 91 | diag = tf.linalg.norm( 92 | tf.math.log(x_diag) - tf.math.log(y_diag), 93 | axis=-1, 94 | keepdims=keepdims, 95 | ) 96 | return tf.math.sqrt(lower**2 + tf.reshape(diag, lower.shape) ** 2) 97 | 98 | def ptransp(self, x, y, v): 99 | x_diag, _ = self._diag_and_strictly_lower(x) 100 | y_diag, _ = self._diag_and_strictly_lower(y) 101 | v_diag, v_lower = self._diag_and_strictly_lower(v) 102 | diag = tf.linalg.diag(v_diag * y_diag / x_diag) 103 | return v_lower + diag 104 | 105 | transp = ptransp 106 | 107 | def pairmean(self, x, y): 108 | return self.geodesic(x, self.log(x, y), 0.5) 109 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/cholesky_test.py: -------------------------------------------------------------------------------- 1 | import
tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import TestInvariants 7 | from tensorflow_riemopt.manifolds.cholesky import Cholesky 8 | 9 | 10 | @combinations.generate( 11 | combinations.combine( 12 | mode=["graph", "eager"], 13 | manifold=[Cholesky()], 14 | shape=[(2, 3, 3), (3, 3)], 15 | dtype=[tf.float32, tf.float64], 16 | ) 17 | ) 18 | class CholeskyTest(tf.test.TestCase, parameterized.TestCase): 19 | test_random = TestInvariants.check_random 20 | 21 | test_dist = TestInvariants.check_dist 22 | 23 | test_inner = TestInvariants.check_inner 24 | 25 | test_proj = TestInvariants.check_proj 26 | 27 | test_exp_log_inverse = TestInvariants.check_exp_log_inverse 28 | 29 | test_transp_retr = TestInvariants.check_transp_retr 30 | 31 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 32 | 33 | test_ptransp_inner = TestInvariants.check_ptransp_inner 34 | 35 | test_geodesic = TestInvariants.check_geodesic 36 | 37 | test_pairmean = TestInvariants.check_pairmean 38 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/euclidean.py: -------------------------------------------------------------------------------- 1 | """Manifold of the Euclidean space.""" 2 | 3 | import tensorflow as tf 4 | 5 | from tensorflow_riemopt.manifolds.manifold import Manifold 6 | 7 | 8 | class Euclidean(Manifold): 9 | name = "Euclidean" 10 | ndims = 0 11 | 12 | def __init__(self, ndims=0): 13 | """Instantiate the Euclidean manifold. 14 | 15 | Args: 16 | ndims: number of dimensions 17 | """ 18 | super().__init__() 19 | self.ndims = ndims 20 | 21 | def _check_point_on_manifold(self, x, atol, rtol): 22 | return tf.constant(True) 23 | 24 | def _check_vector_on_tangent(self, x, u, atol, rtol): 25 | return tf.constant(True) 26 | 27 | def dist(self, x, y, keepdims=False): 28 | return self.norm(x, x - y, keepdims=keepdims) 29 | 30 | def inner(self, x, u, v, keepdims=False): 31 | return tf.reduce_sum( 32 | u * v, axis=tuple(range(-self.ndims, 0)), keepdims=keepdims 33 | ) 34 | 35 | def proju(self, x, u): 36 | return u 37 | 38 | def projx(self, x): 39 | return x 40 | 41 | def exp(self, x, u): 42 | return x + u 43 | 44 | retr = exp 45 | 46 | def log(self, x, y): 47 | return y - x 48 | 49 | def ptransp(self, x, y, v): 50 | return v 51 | 52 | transp = ptransp 53 | 54 | def geodesic(self, x, u, t): 55 | return x + t * u 56 | 57 | def pairmean(self, x, y): 58 | return (x + y) / 2.0 59 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/euclidean_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import TestInvariants 7 | from tensorflow_riemopt.manifolds.euclidean import Euclidean 8 | 9 | 10 | @combinations.generate( 11 | combinations.combine( 12 | mode=["graph", "eager"], 13 | manifold=[Euclidean()], 14 | shape=[(2,), (2, 2)], 15 | dtype=[tf.float32, tf.float64], 16 | ) 17 | ) 18 | class EuclideanTest(tf.test.TestCase, parameterized.TestCase): 19 | test_random = TestInvariants.check_random 20 | 21 | test_dist = TestInvariants.check_dist 22 | 23 | test_inner = TestInvariants.check_inner 24 | 25 | test_proj = TestInvariants.check_proj 26 | 27 | test_exp_log_inverse = 
TestInvariants.check_exp_log_inverse 28 | 29 | test_transp_retr = TestInvariants.check_transp_retr 30 | 31 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 32 | 33 | test_ptransp_inner = TestInvariants.check_ptransp_inner 34 | 35 | test_geodesic = TestInvariants.check_geodesic 36 | 37 | test_pairmean = TestInvariants.check_pairmean 38 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/grassmannian.py: -------------------------------------------------------------------------------- 1 | """Manifold of linear subspaces of a vector space.""" 2 | 3 | import tensorflow as tf 4 | 5 | from tensorflow_riemopt.manifolds.manifold import Manifold 6 | from tensorflow_riemopt.manifolds import utils 7 | 8 | 9 | class Grassmannian(Manifold): 10 | """Manifold of :math:`k`-dimensional linear subspaces in the 11 | :math:`n`-dimensional Euclidean space. 12 | 13 | Edelman, Alan, Tomás A. Arias, and Steven T. Smith. "The geometry of 14 | algorithms with orthogonality constraints." SIAM journal on Matrix 15 | Analysis and Applications 20.2 (1998): 303-353. 16 | """ 17 | 18 | name = "Grassmannian" 19 | ndims = 2 20 | 21 | def _check_shape(self, shape): 22 | return shape[-2] >= shape[-1] 23 | 24 | def _check_point_on_manifold(self, x, atol, rtol): 25 | xtx = utils.transposem(x) @ x 26 | shape = xtx.shape.as_list() 27 | eye = tf.eye(shape[-1], batch_shape=shape[:-2]) 28 | is_idempotent = utils.allclose(xtx, tf.cast(eye, x.dtype), atol, rtol) 29 | s = tf.linalg.svd(x, compute_uv=False) 30 | rank = tf.math.count_nonzero(s, axis=-1, dtype=tf.float32) 31 | k = tf.ones_like(rank) * int(x.shape[-1]) 32 | is_col_rank = utils.allclose(rank, k, atol, rtol) 33 | return is_idempotent & is_col_rank 34 | 35 | def _check_vector_on_tangent(self, x, u, atol, rtol): 36 | xtu = utils.transposem(x) @ u 37 | return utils.allclose(xtu, tf.zeros_like(xtu), atol, rtol) 38 | 39 | def projx(self, x): 40 | q, _r = tf.linalg.qr(x) 41 | return q 42 | 43 | def proju(self, x, u): 44 | xtu = utils.transposem(x) @ u 45 | return u - x @ xtu 46 | 47 | def dist(self, x, y, keepdims=False): 48 | s = tf.linalg.svd(utils.transposem(x) @ y, compute_uv=False) 49 | theta = tf.math.acos(tf.clip_by_value(s, -1.0, 1.0)) 50 | norm = tf.linalg.norm(theta, axis=-1, keepdims=False) 51 | return norm[..., tf.newaxis, tf.newaxis] if keepdims else norm 52 | 53 | def inner(self, x, u, v, keepdims=False): 54 | return tf.reduce_sum(u * v, axis=[-2, -1], keepdims=keepdims) 55 | 56 | def exp(self, x, u): 57 | s, u, vt = tf.linalg.svd(u, full_matrices=False) 58 | cos_s = tf.linalg.diag(tf.math.cos(s)) 59 | sin_s = tf.linalg.diag(tf.math.sin(s)) 60 | return (x @ utils.transposem(vt) @ cos_s + u @ sin_s) @ vt 61 | 62 | def log(self, x, y): 63 | proj_xy = self.proju(x, y) 64 | s, u, vt = tf.linalg.svd(proj_xy, full_matrices=False) 65 | sigma = tf.linalg.diag(tf.math.asin(tf.clip_by_value(s, -1.0, 1.0))) 66 | return u @ sigma @ vt 67 | 68 | def geodesic(self, x, u, t): 69 | s, u, vt = tf.linalg.svd(u, full_matrices=False) 70 | cos_ts = tf.linalg.diag(tf.math.cos(t * s)) 71 | sin_ts = tf.linalg.diag(tf.math.sin(t * s)) 72 | return (x @ utils.transposem(vt) @ cos_ts + u @ sin_ts) @ vt 73 | 74 | def transp(self, x, y, v): 75 | return self.proju(y, v) 76 | 77 | def retr(self, x, u): 78 | _s, u, vt = tf.linalg.svd(x + u, full_matrices=False) 79 | return u @ vt 80 | 81 | def ptransp(self, x, y, v): 82 | log_xy = self.log(x, y) 83 | s, u, vt = tf.linalg.svd(log_xy, full_matrices=False) 84 | cos_s = 
tf.linalg.diag(tf.math.cos(s)) 85 | sin_s = tf.linalg.diag(tf.math.sin(s)) 86 | geod = ( 87 | (-x @ utils.transposem(vt) @ sin_s + u @ cos_s) 88 | @ utils.transposem(u) 89 | @ v 90 | ) 91 | proj = v - u @ utils.transposem(u) @ v 92 | return geod + proj 93 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/grassmannian_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import TestInvariants 7 | from tensorflow_riemopt.manifolds.grassmannian import Grassmannian 8 | 9 | 10 | @combinations.generate( 11 | combinations.combine( 12 | mode=["graph", "eager"], 13 | manifold=[Grassmannian()], 14 | shape=[(5, 3), (2, 5, 3)], 15 | dtype=[tf.float32, tf.float64], 16 | ) 17 | ) 18 | class GrassmannianTest(tf.test.TestCase, parameterized.TestCase): 19 | test_random = TestInvariants.check_random 20 | 21 | test_dist = TestInvariants.check_dist 22 | 23 | test_inner = TestInvariants.check_inner 24 | 25 | test_proj = TestInvariants.check_proj 26 | 27 | test_exp_log_inverse = TestInvariants.check_exp_log_inverse 28 | 29 | test_transp_retr = TestInvariants.check_transp_retr 30 | 31 | test_ptransp_inner = TestInvariants.check_ptransp_inner 32 | 33 | test_geodesic = TestInvariants.check_geodesic 34 | 35 | test_pairmean = TestInvariants.check_pairmean 36 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/hyperboloid.py: -------------------------------------------------------------------------------- 1 | """The Lorentz model.""" 2 | 3 | import tensorflow as tf 4 | 5 | from tensorflow_riemopt.manifolds.manifold import Manifold 6 | from tensorflow_riemopt.manifolds import utils 7 | 8 | 9 | class Hyperboloid(Manifold): 10 | """Manifold of :math:`n`-dimensional hyperbolic space as embedded in 11 | :math:`n+1`-dimensional Minkowski space, also known as the Lorentz model. 12 | 13 | """ 14 | 15 | name = "Hyperboloid" 16 | ndims = 1 17 | 18 | def __init__(self, k=1.0): 19 | """Instantiate the Hyperboloid manifold. 20 | 21 | Args: 22 | k: scale of the hyperbolic space.
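Example (a small sketch; the coordinates are illustrative):

>>> import tensorflow as tf
>>> H = Hyperboloid()
>>> x = H.projx(tf.constant([0.0, 0.3, 0.4]))  # lift onto the hyperboloid
>>> u = H.proju(x, tf.constant([0.1, 0.0, 0.0]))  # project to the tangent space at x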
23 | """ 24 | self.k = k 25 | super().__init__() 26 | 27 | def __repr__(self): 28 | return "{} (k={}, ndims={}) manifold".format( 29 | self.name, self.k, self.ndims 30 | ) 31 | 32 | def _check_point_on_manifold(self, x, atol, rtol): 33 | x_sq = tf.square(x) 34 | quad_form = -x_sq[..., :1] + tf.reduce_sum( 35 | x_sq[..., 1:], axis=-1, keepdims=True 36 | ) 37 | return utils.allclose( 38 | quad_form, tf.ones_like(quad_form) * -self.k, atol, rtol 39 | ) 40 | 41 | def _check_vector_on_tangent(self, x, u, atol, rtol): 42 | inner = self.inner(x, x, u) 43 | rtol = 100 * utils.get_eps(x) if rtol is None else rtol 44 | return utils.allclose(inner, tf.zeros_like(inner), atol, rtol) 45 | 46 | def dist(self, x, y, keepdims=False): 47 | d = -self.inner(x, x, y, keepdims=keepdims) 48 | k = tf.cast(self.k, x.dtype) 49 | return tf.math.sqrt(k) * tf.math.acosh(tf.maximum(d / k, 1.0)) 50 | 51 | def inner(self, x, u, v, keepdims=False): 52 | uv = u * v 53 | x0 = -tf.reduce_sum(uv[..., :1], axis=-1, keepdims=keepdims) 54 | return x0 + tf.reduce_sum(uv[..., 1:], axis=-1, keepdims=keepdims) 55 | 56 | def norm(self, x, u, keepdims=False): 57 | inner = self.inner(x, u, u, keepdims=keepdims) 58 | return tf.math.sqrt(tf.maximum(inner, utils.get_eps(x))) 59 | 60 | def proju(self, x, u): 61 | k = tf.cast(self.k, x.dtype) 62 | return u + self.inner(x, x, u, keepdims=True) * x / k 63 | 64 | def projx(self, x): 65 | k = tf.cast(self.k, x.dtype) 66 | x0 = tf.math.sqrt( 67 | k + tf.linalg.norm(x[..., 1:], axis=-1, keepdims=True) ** 2 68 | ) 69 | return tf.concat([x0, x[..., 1:]], axis=-1) 70 | 71 | def log(self, x, y): 72 | k = tf.cast(self.k, x.dtype) 73 | dist = self.dist(x, y, keepdims=True) 74 | inner = self.inner(x, x, y, keepdims=True) 75 | p = y + 1.0 / k * inner * x 76 | return dist * p / self.norm(x, p, keepdims=True) 77 | 78 | def exp(self, x, u): 79 | sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype)) 80 | norm_u = self.norm(x, u, keepdims=True) 81 | return ( 82 | tf.math.cosh(norm_u / sqrt_k) * x 83 | + sqrt_k * tf.math.sinh(norm_u / sqrt_k) * u / norm_u 84 | ) 85 | 86 | retr = exp 87 | 88 | def ptransp(self, x, y, v): 89 | log_xy = self.log(x, y) 90 | log_yx = self.log(y, x) 91 | dist_sq = self.dist(x, y, keepdims=True) ** 2 92 | inner = self.inner(x, log_xy, v, keepdims=True) 93 | return v - inner * (log_xy + log_yx) / dist_sq 94 | 95 | transp = ptransp 96 | 97 | def to_poincare(self, x, k): 98 | """Diffeomorphism that maps to the Poincaré ball""" 99 | k = tf.cast(k, x.dtype) 100 | return x[..., 1:] / (x[..., :1] + tf.math.sqrt(k)) 101 | 102 | def from_poincare(self, x, k): 103 | """Inverse of the diffeomorphism to the Poincaré ball""" 104 | k = tf.cast(k, x.dtype) 105 | x_sq_norm = tf.reduce_sum(x * x, axis=-1, keepdims=True) 106 | y = tf.math.sqrt(k) * tf.concat([1 + x_sq_norm, 2 * x], axis=-1) 107 | return y / (1.0 - x_sq_norm + utils.get_eps(x)) 108 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/hyperboloid_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import ( 7 | TestInvariants, 8 | random_constant, 9 | ) 10 | from tensorflow_riemopt.manifolds.hyperboloid import Hyperboloid 11 | 12 | 13 | @combinations.generate( 14 | combinations.combine( 15 | mode=["graph", "eager"], 16 | manifold=[Hyperboloid(), 
Hyperboloid(k=5.0)], 17 | shape=[(5,), (2, 2)], 18 | dtype=[tf.float64], 19 | ) 20 | ) 21 | class HyperboloidTest(tf.test.TestCase, parameterized.TestCase): 22 | test_random = TestInvariants.check_random 23 | 24 | test_dist = TestInvariants.check_dist 25 | 26 | test_inner = TestInvariants.check_inner 27 | 28 | test_proj = TestInvariants.check_proj 29 | 30 | test_exp_log_inverse = TestInvariants.check_exp_log_inverse 31 | 32 | test_transp_retr = TestInvariants.check_transp_retr 33 | 34 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 35 | 36 | test_ptransp_inner = TestInvariants.check_ptransp_inner 37 | 38 | def test_poincare(self, manifold, shape, dtype): 39 | with self.cached_session(use_gpu=True): 40 | x = manifold.projx(random_constant(shape=shape, dtype=dtype)) 41 | y = manifold.to_poincare(x, manifold.k) 42 | x_ = manifold.from_poincare(y, manifold.k) 43 | if not tf.executing_eagerly(): 44 | x_ = self.evaluate(x_) 45 | self.assertAllCloseAccordingToType(x, x_) 46 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/manifold.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import abc 3 | 4 | 5 | class Manifold(metaclass=abc.ABCMeta): 6 | name = "Base" 7 | ndims = None 8 | 9 | def __repr__(self): 10 | """Returns a string representation of the particular manifold.""" 11 | return "{} (ndims={}) manifold".format(self.name, self.ndims) 12 | 13 | def check_shape(self, shape_or_tensor): 14 | """Check if given shape is compatible with the manifold.""" 15 | shape = ( 16 | shape_or_tensor.shape 17 | if hasattr(shape_or_tensor, "shape") 18 | else shape_or_tensor 19 | ) 20 | return (len(shape) >= self.ndims) & self._check_shape(shape) 21 | 22 | def check_point_on_manifold(self, x, atol=None, rtol=None): 23 | """Check if point :math:`x` lies on the manifold.""" 24 | return self.check_shape(x) & self._check_point_on_manifold( 25 | x, atol=atol, rtol=rtol 26 | ) 27 | 28 | def check_vector_on_tangent(self, x, u, atol=None, rtol=None): 29 | """Check if vector :math:`u` lies on the tangent space at :math:`x`.""" 30 | return ( 31 | self._check_point_on_manifold(x, atol=atol, rtol=rtol) 32 | & self.check_shape(u) 33 | & self._check_vector_on_tangent(x, u, atol=atol, rtol=rtol) 34 | ) 35 | 36 | def _check_shape(self, shape): 37 | return tf.constant(True) 38 | 39 | @abc.abstractmethod 40 | def _check_point_on_manifold(self, x, atol, rtol): 41 | raise NotImplementedError 42 | 43 | @abc.abstractmethod 44 | def _check_vector_on_tangent(self, x, u, atol, rtol): 45 | raise NotImplementedError 46 | 47 | @abc.abstractmethod 48 | def dist(self, x, y, keepdims=False): 49 | """Compute the distance between two points :math:`x` and :math:`y` along a 50 | geodesic. 51 | """ 52 | raise NotImplementedError 53 | 54 | @abc.abstractmethod 55 | def inner(self, x, u, v, keepdims=False): 56 | """Return the inner product (i.e., the Riemannian metric) between two tangent 57 | vectors :math:`u` and :math:`v` in the tangent space at :math:`x`. 58 | """ 59 | raise NotImplementedError 60 | 61 | def norm(self, x, u, keepdims=False): 62 | """Compute the norm of a tangent vector :math:`u` in the tangent space at 63 | :math:`x`.
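By default, the norm is induced by the Riemannian metric, :math:`\|u\|_x = \sqrt{\langle u, u \rangle_x}`, as computed below; subclasses may override it with a numerically safer variant.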
64 | """ 65 | return self.inner(x, u, u, keepdims=keepdims) ** 0.5 66 | 67 | @abc.abstractmethod 68 | def proju(self, x, u): 69 | """Project a vector :math:`u` in the ambient space on the tangent space at 70 | :math:`x`. 71 | """ 72 | raise NotImplementedError 73 | 74 | def egrad2rgrad(self, x, u): 75 | """Map the Euclidean gradient :math:`u` in the ambient space on the tangent 76 | space at :math:`x`. 77 | """ 78 | return self.proju(x, u) 79 | 80 | @abc.abstractmethod 81 | def projx(self, x): 82 | """Project a point :math:`x` on the manifold.""" 83 | raise NotImplementedError 84 | 85 | @abc.abstractmethod 86 | def retr(self, x, u): 87 | """Perform a retraction from point :math:`x` with given direction :math:`u`.""" 88 | raise NotImplementedError 89 | 90 | @abc.abstractmethod 91 | def exp(self, x, u): 92 | r"""Perform an exponential map :math:`\operatorname{Exp}_x(u)`.""" 93 | raise NotImplementedError 94 | 95 | @abc.abstractmethod 96 | def log(self, x, y): 97 | r"""Perform a logarithmic map :math:`\operatorname{Log}_{x}(y)`.""" 98 | raise NotImplementedError 99 | 100 | @abc.abstractmethod 101 | def transp(self, x, y, v): 102 | r"""Perform a vector transport :math:`\mathfrak{T}_{x\to y}(v)`.""" 103 | raise NotImplementedError 104 | 105 | def ptransp(self, x, y, v): 106 | r"""Perform a parallel transport :math:`\operatorname{P}_{x\to y}(v)`.""" 107 | raise NotImplementedError 108 | 109 | def random(self, shape, dtype=tf.float32): 110 | """Sample a random point on the manifold.""" 111 | return self.projx(tf.random.uniform(shape, dtype=dtype)) 112 | 113 | def geodesic(self, x, u, t): 114 | """Geodesic from point :math:`x` in the direction of tangent vector 115 | :math:`u`. 116 | """ 117 | raise NotImplementedError 118 | 119 | def pairmean(self, x, y): 120 | """Compute a Riemannian (Fréchet) mean of points :math:`x` and :math:`y`""" 121 | return self.geodesic(x, self.log(x, y), 0.5) 122 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/poincare.py: -------------------------------------------------------------------------------- 1 | """The Poincaré ball model.""" 2 | 3 | import tensorflow as tf 4 | 5 | from tensorflow_riemopt.manifolds.manifold import Manifold 6 | from tensorflow_riemopt.manifolds import utils 7 | 8 | 9 | class Poincare(Manifold): 10 | """Manifold of :math:`n`-dimensional hyperbolic space as embedded in the 11 | Poincaré ball model. 12 | 13 | Nickel, Maximillian, and Douwe Kiela. "Poincaré embeddings for learning 14 | hierarchical representations." Advances in neural information processing 15 | systems. 2017. 16 | 17 | Ganea, Octavian, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic neural 18 | networks." Advances in neural information processing systems. 2018. 19 | 20 | """ 21 | 22 | name = "Poincaré" 23 | ndims = 1 24 | 25 | def __init__(self, k=1.0): 26 | """Instantiate the Poincaré manifold. 27 | 28 | Args: 29 | k: scale of the hyperbolic space, k > 0.
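Example (a minimal sketch; the point is illustrative):

>>> import tensorflow as tf
>>> ball = Poincare()
>>> x = ball.projx(tf.constant([0.3, 0.4]))  # already inside the unit ball
>>> d = ball.dist(x, -x)  # geodesic distance between two interior points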
30 | """ 31 | self.k = k 32 | super().__init__() 33 | 34 | def __repr__(self): 35 | return "{} (k={}, ndims={}) manifold".format( 36 | self.name, self.k, self.ndims 37 | ) 38 | 39 | def _check_point_on_manifold(self, x, atol, rtol): 40 | k = tf.cast(self.k, x.dtype) 41 | sq_norm = tf.reduce_sum(x * x, axis=-1, keepdims=True) 42 | return tf.reduce_all(sq_norm * k < tf.ones_like(sq_norm)) 43 | 44 | def _check_vector_on_tangent(self, x, u, atol, rtol): 45 | return tf.constant(True) 46 | 47 | def _mobius_add(self, x, y): 48 | r"""Compute the Möbius addition of :math:`x` and :math:`y` in 49 | :math:`\mathcal{D}^{n}_{k}`: 50 | 51 | :math:`x \oplus y = \frac{(1 + 2k\langle x, y\rangle + k||y||^2)x + (1 52 | - k||x||^2)y}{1 + 2k\langle x,y\rangle + k^2||x||^2||y||^2}` 53 | """ 54 | x_2 = tf.reduce_sum(tf.math.square(x), axis=-1, keepdims=True) 55 | y_2 = tf.reduce_sum(tf.math.square(y), axis=-1, keepdims=True) 56 | x_y = tf.reduce_sum(x * y, axis=-1, keepdims=True) 57 | k = tf.cast(self.k, x.dtype) 58 | return ((1 + 2 * k * x_y + k * y_2) * x + (1 - k * x_2) * y) / ( 59 | 1 + 2 * k * x_y + k**2 * x_2 * y_2 60 | ) 61 | 62 | def _mobius_scal_mul(self, x, r): 63 | r"""Compute the Möbius scalar multiplication of :math:`x \in 64 | \mathcal{D}^{n}_{k} \setminus \{0\}` by :math:`r`: 65 | 66 | :math:`x \otimes r = (1/\sqrt{k})\tanh(r 67 | \operatorname{atanh}(\sqrt{k}||x||))\frac{x}{||x||}` 68 | 69 | """ 70 | sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype)) 71 | norm_x = tf.linalg.norm(x, axis=-1, keepdims=True) 72 | eps = utils.get_eps(x) 73 | tan = tf.clip_by_value(sqrt_k * norm_x, -1.0 + eps, 1.0 - eps) 74 | return (1 / sqrt_k) * tf.math.tanh(r * tf.math.atanh(tan)) * x / norm_x 75 | 76 | def _gyration(self, u, v, w): 77 | r"""Compute the gyration of :math:`u`, :math:`v`, :math:`w`: 78 | 79 | :math:`\operatorname{gyr}[u, v]w = 80 | \ominus (u \oplus_\kappa v) \oplus (u \oplus_\kappa (v \oplus_\kappa w))` 81 | """ 82 | min_u_v = -self._mobius_add(u, v) 83 | v_w = self._mobius_add(v, w) 84 | u_v_w = self._mobius_add(u, v_w) 85 | return self._mobius_add(min_u_v, u_v_w) 86 | 87 | def _lambda(self, x, keepdims=False): 88 | r"""Compute the conformal factor :math:`\lambda_x^k`""" 89 | k = tf.cast(self.k, x.dtype) 90 | norm_x_2 = tf.reduce_sum(x * x, axis=-1, keepdims=keepdims) 91 | return 2.0 / (1.0 - k * norm_x_2) 92 | 93 | def inner(self, x, u, v, keepdims=False): 94 | lambda_x = self._lambda(x, keepdims=keepdims) 95 | return tf.reduce_sum(u * v, axis=-1, keepdims=keepdims) * lambda_x**2 96 | 97 | def norm(self, x, u, keepdims=False): 98 | lambda_x = self._lambda(x, keepdims=keepdims) 99 | return tf.linalg.norm(u, axis=-1, keepdims=keepdims) * lambda_x 100 | 101 | def proju(self, x, u): 102 | lambda_x = self._lambda(x, keepdims=True) 103 | return u / lambda_x**2 104 | 105 | def projx(self, x): 106 | sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype)) 107 | norm = tf.linalg.norm(x, axis=-1, keepdims=True) 108 | return tf.where( 109 | sqrt_k * norm < tf.ones_like(norm), 110 | x, 111 | x / (sqrt_k * norm + 10 * utils.get_eps(x)), 112 | ) 113 | 114 | def dist(self, x, y, keepdims=False): 115 | sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype)) 116 | x_y = self._mobius_add(-x, y) 117 | norm_x_y = tf.linalg.norm(x_y, axis=-1, keepdims=keepdims) 118 | eps = utils.get_eps(x) 119 | tanh = tf.clip_by_value(sqrt_k * norm_x_y, -1.0 + eps, 1.0 - eps) 120 | return 2 * tf.math.atanh(tanh) / sqrt_k 121 | 122 | def exp(self, x, u): 123 | sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype)) 124 | norm_u = tf.linalg.norm(u, axis=-1, keepdims=True) 125 |
lambda_x = self._lambda(x, keepdims=True) 126 | y = ( 127 | tf.math.tanh(sqrt_k * norm_u * lambda_x / 2.0) 128 | * u 129 | / (sqrt_k * norm_u) 130 | ) 131 | return self._mobius_add(x, y) 132 | 133 | def log(self, x, y): 134 | sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype)) 135 | x_y = self._mobius_add(-x, y) 136 | norm_x_y = tf.linalg.norm(x_y, axis=-1, keepdims=True) 137 | eps = utils.get_eps(x) 138 | tanh = tf.clip_by_value(sqrt_k * norm_x_y, -1.0 + eps, 1.0 - eps) 139 | lambda_x = self._lambda(x, keepdims=True) 140 | return 2 * (x_y / norm_x_y) * tf.math.atanh(tanh) / (sqrt_k * lambda_x) 141 | 142 | retr = exp 143 | 144 | def ptransp(self, x, y, v): 145 | lambda_x = self._lambda(x, keepdims=True) 146 | lambda_y = self._lambda(y, keepdims=True) 147 | return self._gyration(y, -x, v) * lambda_x / lambda_y 148 | 149 | transp = ptransp 150 | 151 | def geodesic(self, x, u, t): 152 | sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype)) 153 | norm_u = tf.linalg.norm(u, axis=-1, keepdims=True) 154 | y = tf.math.tanh(sqrt_k * t / 2.0) * u / (norm_u * sqrt_k) 155 | return self._mobius_add(x, y) 156 | 157 | def exp0(self, u): 158 | """Perform an exponential map from the origin""" 159 | sqrt_k = tf.math.sqrt(tf.cast(self.k, u.dtype)) 160 | norm_u = tf.linalg.norm(u, axis=-1, keepdims=True) 161 | return tf.math.tanh(sqrt_k * norm_u) * u / (sqrt_k * norm_u) 162 | 163 | def log0(self, y): 164 | """Perform a logarithmic map from the origin""" 165 | sqrt_k = tf.math.sqrt(tf.cast(self.k, y.dtype)) 166 | norm_y = tf.linalg.norm(y, axis=-1, keepdims=True) 167 | return tf.math.atanh(sqrt_k * norm_y) * y / (sqrt_k * norm_y) 168 | 169 | def ptransp0(self, y, v): 170 | """Perform a parallel transport from the origin""" 171 | lambda_y = self._lambda(y, keepdims=True) 172 | return 2 * v / lambda_y 173 | 174 | def random(self, shape, dtype=tf.float32): 175 | return self.projx( 176 | tf.random.uniform(shape, minval=-1e-3, maxval=1e-3, dtype=dtype) 177 | ) 178 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/poincare_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import ( 7 | TestInvariants, 8 | random_constant, 9 | ) 10 | from tensorflow_riemopt.manifolds.poincare import Poincare 11 | 12 | 13 | @combinations.generate( 14 | combinations.combine( 15 | mode=["graph", "eager"], 16 | manifold=[Poincare(), Poincare(k=5.0)], 17 | shape=[(5,), (2, 5)], 18 | dtype=[tf.float32, tf.float64], 19 | ) 20 | ) 21 | class PoincareTest(tf.test.TestCase, parameterized.TestCase): 22 | test_random = TestInvariants.check_random 23 | 24 | test_dist = TestInvariants.check_dist 25 | 26 | test_inner = TestInvariants.check_inner 27 | 28 | test_proj = TestInvariants.check_proj 29 | 30 | test_exp_log_inverse = TestInvariants.check_exp_log_inverse 31 | 32 | test_transp_retr = TestInvariants.check_transp_retr 33 | 34 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 35 | 36 | test_ptransp_inner = TestInvariants.check_ptransp_inner 37 | 38 | def test_oper0(self, manifold, shape, dtype): 39 | with self.cached_session(use_gpu=True): 40 | x = random_constant(shape=shape, dtype=dtype) 41 | u = manifold.proju(x, random_constant(shape=shape, dtype=dtype)) 42 | z = manifold.projx(tf.zeros_like(x)) 43 | u_at_0 = manifold.ptransp(x, z, u) 44 | u_ = 
manifold.ptransp0(x, u_at_0) 45 | self.assertAllCloseAccordingToType(u, u_) 46 | y = manifold.exp(z, u_at_0) 47 | y_ = manifold.exp0(u_at_0) 48 | self.assertAllCloseAccordingToType(y, y_) 49 | v = manifold.log0(y_) 50 | self.assertAllCloseAccordingToType(v, u_at_0) 51 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/product.py: -------------------------------------------------------------------------------- 1 | """Cartesian product of manifolds.""" 2 | 3 | import tensorflow as tf 4 | from functools import reduce 5 | from operator import mul 6 | 7 | from tensorflow_riemopt.manifolds.manifold import Manifold 8 | 9 | 10 | class Product(Manifold): 11 | """Product space of manifolds.""" 12 | 13 | name = "Product" 14 | ndims = 1 15 | 16 | def __init__(self, *manifolds): 17 | """Initialize a product of manifolds. 18 | 19 | Args: 20 | *manifolds: an iterable of (`manifold`, `shape`) tuples, where 21 | `manifold` is an instance of `Manifold` and `shape` is a tuple of 22 | `manifold` dimensions 23 | 24 | Example: 25 | 26 | >>> from tensorflow_riemopt import manifolds 27 | 28 | >>> S = manifolds.Sphere() 29 | >>> torus = manifolds.Product((S, (2,)), (S, (2,))) 30 | 31 | >>> St = manifolds.StiefelEuclidean() 32 | >>> St_2 = manifolds.Product((St, (5, 3)), (St, (5, 3))) 33 | """ 34 | self._indices = [0] 35 | self._manifolds = [] 36 | self._shapes = [] 37 | for i, (m, shape) in enumerate(manifolds): 38 | if not isinstance(m, Manifold): 39 | raise ValueError( 40 | "{} should be an instance of Manifold".format(m) 41 | ) 42 | if not m.check_shape(shape): 43 | raise ValueError( 44 | "Invalid shape {} for manifold {}".format(shape, m) 45 | ) 46 | self._indices.append(self._indices[i] + reduce(mul, shape)) 47 | self._manifolds.append(m) 48 | self._shapes.append(list(shape)) 49 | super().__init__() 50 | 51 | def __repr__(self): 52 | names = [ 53 | "{}{}".format(m.name, tuple(shape)) 54 | for m, shape in zip(self._manifolds, self._shapes) 55 | ] 56 | return " × ".join(names) 57 | 58 | def _check_shape(self, shape): 59 | return self._indices[-1] == shape[-1] 60 | 61 | def _check_point_on_manifold(self, x, atol, rtol): 62 | checks = [ 63 | m.check_point_on_manifold(self._get_slice(x, i), atol, rtol) 64 | for (i, m) in enumerate(self._manifolds) 65 | ] 66 | return reduce(tf.logical_and, checks) 67 | 68 | def _check_vector_on_tangent(self, x, u, atol, rtol): 69 | checks = [ 70 | m.check_vector_on_tangent( 71 | self._get_slice(x, i), self._get_slice(u, i), atol, rtol 72 | ) 73 | for (i, m) in enumerate(self._manifolds) 74 | ] 75 | return reduce(tf.logical_and, checks) 76 | 77 | def _get_slice(self, x, idx): 78 | if not 0 <= idx < len(self._indices) - 1: 79 | raise ValueError("Invalid index {}".format(idx)) 80 | slice = x[..., self._indices[idx] : self._indices[idx + 1]] 81 | shape = tf.concat([tf.shape(x)[:-1], self._shapes[idx]], axis=-1) 82 | return tf.reshape(slice, shape) 83 | 84 | def _product_fn(self, fn, *args, **kwargs): 85 | results = [] 86 | for i, m in enumerate(self._manifolds): 87 | arg_slices = [self._get_slice(arg, i) for arg in args] 88 | result = getattr(m, fn)(*arg_slices, **kwargs) 89 | shape = tf.concat([tf.shape(result)[:-1], [-1]], axis=-1) 90 | results.append(tf.reshape(result, shape)) 91 | return tf.concat(results, axis=-1) 92 | 93 | def random(self, shape, dtype=tf.float32): 94 | if not self.check_shape(shape): 95 | raise ValueError("Invalid shape {}".format(shape)) 96 | shape = list(shape) 97 | results = [] 98 | for i, m in
enumerate(self._manifolds): 99 | result = m.random(shape[:-1] + self._shapes[i], dtype=dtype) 100 | results.append(tf.reshape(result, shape[:-1] + [-1])) 101 | return tf.concat(results, axis=-1) 102 | 103 | def dist(self, x, y, keepdims=False): 104 | dists = self._product_fn("dist", x, y, keepdims=True) 105 | 106 | sq_dists = tf.reduce_sum(dists * dists, axis=-1, keepdims=keepdims) 107 | return tf.math.sqrt(sq_dists) 108 | 109 | def inner(self, x, u, v, keepdims=False): 110 | inners = self._product_fn("inner", x, u, v, keepdims=True) 111 | return tf.reduce_sum(inners, axis=-1, keepdims=keepdims) 112 | 113 | def proju(self, x, u): 114 | return self._product_fn("proju", x, u) 115 | 116 | def projx(self, x): 117 | return self._product_fn("projx", x) 118 | 119 | def exp(self, x, u): 120 | return self._product_fn("exp", x, u) 121 | 122 | def retr(self, x, u): 123 | return self._product_fn("retr", x, u) 124 | 125 | def log(self, x, y): 126 | return self._product_fn("log", x, y) 127 | 128 | def ptransp(self, x, y, v): 129 | return self._product_fn("ptransp", x, y, v) 130 | 131 | def transp(self, x, y, v): 132 | return self._product_fn("transp", x, y, v) 133 | 134 | def pairmean(self, x, y): 135 | return self._product_fn("pairmean", x, y) 136 | 137 | def geodesic(self, x, u, t): 138 | return self._product_fn("geodesic", x, u, t=t) 139 | 140 | 141 | 142 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/product_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import TestInvariants 7 | from tensorflow_riemopt.manifolds.product import Product 8 | from tensorflow_riemopt.manifolds.sphere import Sphere 9 | from tensorflow_riemopt.manifolds.euclidean import Euclidean 10 | 11 | 12 | @combinations.generate( 13 | combinations.combine( 14 | mode=["graph", "eager"], 15 | manifold=[ 16 | Product((Sphere(), (3,)), (Sphere(), (2,))), 17 | Product((Euclidean(), (3,)), (Sphere(), (2,))), 18 | ], 19 | shape=[(5,), (2, 5), (2, 2, 5)], 20 | dtype=[tf.float64], 21 | ) 22 | ) 23 | class ProductTest(tf.test.TestCase, parameterized.TestCase): 24 | test_random = TestInvariants.check_random 25 | 26 | test_dist = TestInvariants.check_dist 27 | 28 | test_inner = TestInvariants.check_inner 29 | 30 | test_proj = TestInvariants.check_proj 31 | 32 | test_exp_log_inverse = TestInvariants.check_exp_log_inverse 33 | 34 | test_transp_retr = TestInvariants.check_transp_retr 35 | 36 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 37 | 38 | test_ptransp_inner = TestInvariants.check_ptransp_inner 39 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/special_orthogonal.py: -------------------------------------------------------------------------------- 1 | """Manifold of rotation matrices.""" 2 | 3 | import tensorflow as tf 4 | 5 | from tensorflow_riemopt.manifolds.manifold import Manifold 6 | from tensorflow_riemopt.manifolds import utils 7 | 8 | 9 | class SpecialOrthogonal(Manifold): 10 | """Manifold of square orthogonal matrices of determinant 1.""" 11 | 12 | name = "Special Orthogonal" 13 | ndims = 2 14 | 15 | def _check_shape(self, shape): 16 | return
shape[-2] == shape[-1] 17 | 18 | def _check_point_on_manifold(self, x, atol, rtol): 19 | xtx = utils.transposem(x) @ x 20 | eye = tf.eye( 21 | tf.shape(xtx)[-1], batch_shape=tf.shape(xtx)[:-2], dtype=x.dtype 22 | ) 23 | is_orth = utils.allclose(xtx, eye, atol, rtol) 24 | det = tf.linalg.det(x) 25 | is_unit_det = utils.allclose(det, tf.ones_like(det), atol, rtol) 26 | return is_orth & is_unit_det 27 | 28 | def _check_vector_on_tangent(self, x, u, atol, rtol): 29 | diff = utils.transposem(u) + u 30 | return utils.allclose(diff, tf.zeros_like(diff), atol, rtol) 31 | 32 | def dist(self, x, y, keepdims=False): 33 | return tf.linalg.norm(self.log(x, y), axis=[-2, -1], keepdims=keepdims) 34 | 35 | def inner(self, x, u, v, keepdims=False): 36 | return tf.reduce_sum(u * v, axis=[-2, -1], keepdims=keepdims) 37 | 38 | def proju(self, x, u): 39 | xtu = utils.transposem(x) @ u 40 | return (xtu - utils.transposem(xtu)) / 2.0 41 | 42 | def projx(self, x): 43 | s, u, vt = tf.linalg.svd(x) 44 | ones = tf.ones_like(s)[..., :-1] 45 | signs = tf.cast(tf.sign(tf.linalg.det(u @ vt)), ones.dtype) 46 | flip = tf.concat([ones, tf.expand_dims(signs, -1)], axis=-1) 47 | return u @ tf.linalg.diag(flip) @ vt 48 | 49 | def exp(self, x, u): 50 | return x @ tf.linalg.expm(u) 51 | 52 | def retr(self, x, u): 53 | q, r = tf.linalg.qr(x + x @ u) 54 | unflip = tf.sign(tf.sign(tf.linalg.diag_part(r)) + 0.5) 55 | return q * unflip[..., tf.newaxis, :] 56 | 57 | def retr2(self, x, u): 58 | _s, u, vt = tf.linalg.svd(x + x @ u) 59 | return u @ vt 60 | 61 | def log(self, x, y): 62 | xty = utils.transposem(x) @ y 63 | u = utils.logm(xty) 64 | return (u - utils.transposem(u)) / 2.0 65 | 66 | def transp(self, x, y, v): 67 | return v 68 | 69 | ptransp = transp 70 | 71 | def pairmean(self, x, y): 72 | return self.exp(x, self.log(x, y) / 2.0) 73 | 74 | def geodesic(self, x, u, t): 75 | return x @ tf.linalg.expm(t * u) 76 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/special_orthogonal_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import TestInvariants 7 | from tensorflow_riemopt.manifolds.special_orthogonal import SpecialOrthogonal 8 | 9 | 10 | @combinations.generate( 11 | combinations.combine( 12 | mode=["graph", "eager"], 13 | manifold=[SpecialOrthogonal()], 14 | shape=[(2, 3, 3), (3, 3)], 15 | dtype=[tf.float32, tf.float64], 16 | ) 17 | ) 18 | class SpecialOrthogonalTest(tf.test.TestCase, parameterized.TestCase): 19 | test_random = TestInvariants.check_random 20 | 21 | test_dist = TestInvariants.check_dist 22 | 23 | test_inner = TestInvariants.check_inner 24 | 25 | test_proj = TestInvariants.check_proj 26 | 27 | test_exp_log_inverse = TestInvariants.check_exp_log_inverse 28 | 29 | test_transp_retr = TestInvariants.check_transp_retr 30 | 31 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 32 | 33 | test_ptransp_inner = TestInvariants.check_ptransp_inner 34 | 35 | test_geodesic = TestInvariants.check_geodesic 36 | 37 | test_pairmean = TestInvariants.check_pairmean 38 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/sphere.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from 
tensorflow_riemopt.manifolds.manifold import Manifold 4 | from tensorflow_riemopt.manifolds import utils 5 | 6 | 7 | class Sphere(Manifold): 8 | """Manifold of unit-norm points. 9 | 10 | Bergmann, Ronny, et al. "Priors with coupled first and second order 11 | differences for manifold-valued image processing." Journal of mathematical 12 | imaging and vision 60.9 (2018): 1459-1481. 13 | """ 14 | 15 | name = "Sphere" 16 | ndims = 1 17 | 18 | def _check_shape(self, shape): 19 | return shape[-1] > 1 20 | 21 | def _check_point_on_manifold(self, x, atol, rtol): 22 | norm = tf.linalg.norm(x, axis=-1) 23 | return utils.allclose(norm, tf.ones_like(norm), atol, rtol) 24 | 25 | def _check_vector_on_tangent(self, x, u, atol, rtol): 26 | inner = self.inner(x, x, u, keepdims=True) 27 | return utils.allclose(inner, tf.zeros_like(inner), atol, rtol) 28 | 29 | def dist(self, x, y, keepdims=False): 30 | inner = self.inner(x, x, y, keepdims=keepdims) 31 | cos_angle = tf.clip_by_value(inner, -1.0, 1.0) 32 | return tf.math.acos(cos_angle) 33 | 34 | def inner(self, x, u, v, keepdims=False): 35 | return tf.reduce_sum(u * v, axis=-1, keepdims=keepdims) 36 | 37 | def proju(self, x, u): 38 | return u - tf.reduce_sum(x * u, axis=-1, keepdims=True) * x 39 | 40 | def projx(self, x): 41 | return x / tf.linalg.norm(x, axis=-1, keepdims=True) 42 | 43 | def norm(self, x, u, keepdims=False): 44 | norm_u = tf.linalg.norm(u, axis=-1, keepdims=keepdims) 45 | return tf.maximum(norm_u, utils.get_eps(x)) 46 | 47 | def exp(self, x, u): 48 | norm_u = self.norm(x, u, keepdims=True) 49 | exp = x * tf.math.cos(norm_u) + u * tf.math.sin(norm_u) / norm_u 50 | retr = self.projx(x + u) 51 | return tf.where(norm_u > utils.get_eps(x), exp, retr) 52 | 53 | def retr(self, x, u): 54 | return self.projx(x + u) 55 | 56 | def log(self, x, y): 57 | u = self.proju(x, y - x) 58 | norm_u = self.norm(x, u, keepdims=True) 59 | dist = self.dist(x, y, keepdims=True) 60 | log = u * dist / norm_u 61 | return tf.where(dist > utils.get_eps(x), log, u) 62 | 63 | def ptransp(self, x, y, v): 64 | log_xy = self.log(x, y) 65 | log_yx = self.log(y, x) 66 | dist_sq = self.dist(x, y, keepdims=True) ** 2 67 | inner = self.inner(x, log_xy, v, keepdims=True) 68 | return v - inner * (log_xy + log_yx) / dist_sq 69 | 70 | def transp(self, x, y, v): 71 | return self.proju(y, v) 72 | 73 | def geodesic(self, x, u, t): 74 | norm_u = self.norm(x, u, keepdims=True) 75 | return ( 76 | x * tf.math.cos(norm_u * t) + u * tf.math.sin(norm_u * t) / norm_u 77 | ) 78 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/sphere_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import TestInvariants 7 | from tensorflow_riemopt.manifolds.sphere import Sphere 8 | 9 | 10 | @combinations.generate( 11 | combinations.combine( 12 | mode=["graph", "eager"], 13 | manifold=[Sphere()], 14 | shape=[(2,), (2, 2)], 15 | dtype=[tf.float64], 16 | ) 17 | ) 18 | class SphereTest(tf.test.TestCase, parameterized.TestCase): 19 | test_random = TestInvariants.check_random 20 | 21 | test_dist = TestInvariants.check_dist 22 | 23 | test_inner = TestInvariants.check_inner 24 | 25 | test_proj = TestInvariants.check_proj 26 | 27 | test_exp_log_inverse = TestInvariants.check_exp_log_inverse 28 | 29 | test_transp_retr = 
TestInvariants.check_transp_retr 30 | 31 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 32 | 33 | test_ptransp_inner = TestInvariants.check_ptransp_inner 34 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/stiefel.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow_riemopt.manifolds.manifold import Manifold 4 | from tensorflow_riemopt.manifolds import utils 5 | 6 | 7 | def blockm(a, b, c, d): 8 | """Form a block matrix. 9 | 10 | Returns a `Tensor` 11 | 12 | | a b | 13 | | c d | 14 | 15 | of shape [..., n, p], where n = n_1 + n_2, p = p_1 + p_2, and arguments 16 | have shapes 17 | 18 | `a` - [..., n_1, p_1] 19 | `b` - [..., n_1, p_2] 20 | `c` - [..., n_2, p_1] 21 | `d` - [..., n_2, p_2] 22 | 23 | """ 24 | a_b = tf.concat([a, b], axis=-1) 25 | c_d = tf.concat([c, d], axis=-1) 26 | return tf.concat([a_b, c_d], axis=-2) 27 | 28 | 29 | class _Stiefel(Manifold): 30 | """Manifold of orthonormal p-frames in the n-dimensional Euclidean space.""" 31 | 32 | ndims = 2 33 | 34 | def _check_shape(self, shape): 35 | return shape[-2] >= shape[-1] 36 | 37 | def _check_point_on_manifold(self, x, atol, rtol): 38 | xtx = utils.transposem(x) @ x 39 | eye = tf.eye( 40 | tf.shape(xtx)[-1], batch_shape=tf.shape(xtx)[:-2], dtype=xtx.dtype 41 | ) 42 | return utils.allclose(xtx, eye, atol, rtol) 43 | 44 | def _check_vector_on_tangent(self, x, u, atol, rtol): 45 | diff = utils.transposem(u) @ x + utils.transposem(x) @ u 46 | return utils.allclose(diff, tf.zeros_like(diff), atol, rtol) 47 | 48 | def projx(self, x): 49 | _s, u, vt = tf.linalg.svd(x) 50 | return u @ vt 51 | 52 | 53 | class StiefelEuclidean(_Stiefel): 54 | """Manifold of orthonormal p-frames in the n-dimensional space endowed with 55 | the Euclidean inner product. 56 | 57 | """ 58 | 59 | name = "Euclidean Stiefel" 60 | 61 | def inner(self, x, u, v, keepdims=False): 62 | return tf.reduce_sum(u * v, axis=[-2, -1], keepdims=keepdims) 63 | 64 | def proju(self, x, u): 65 | xtu = utils.transposem(x) @ u 66 | xtu_sym = (utils.transposem(xtu) + xtu) / 2.0 67 | return u - x @ xtu_sym 68 | 69 | def exp(self, x, u): 70 | return self.geodesic(x, u, 1.0) 71 | 72 | def retr(self, x, u): 73 | q, r = tf.linalg.qr(x + u) 74 | unflip = tf.cast(tf.sign(tf.linalg.diag_part(r)), r.dtype) 75 | return q * unflip[..., tf.newaxis, :] 76 | 77 | def geodesic(self, x, u, t): 78 | xtu = utils.transposem(x) @ u 79 | utu = utils.transposem(u) @ u 80 | eye = tf.eye( 81 | tf.shape(utu)[-1], batch_shape=tf.shape(utu)[:-2], dtype=x.dtype 82 | ) 83 | logw = blockm(xtu, -utu, eye, xtu) 84 | w = tf.linalg.expm(t * logw) 85 | z = tf.concat([tf.linalg.expm(-xtu * t), tf.zeros_like(utu)], axis=-2) 86 | y = tf.concat([x, u], axis=-1) @ w @ z 87 | return y 88 | 89 | def dist(self, x, y, keepdims=False): 90 | raise NotImplementedError 91 | 92 | def log(self, x, y): 93 | raise NotImplementedError 94 | 95 | def transp(self, x, y, v): 96 | return self.proju(y, v) 97 | 98 | 99 | class StiefelCanonical(_Stiefel): 100 | """Manifold of orthonormal p-frames in the n-dimensional space endowed with 101 | the canonical inner product derived from the quotient space 102 | representation. 103 | 104 | Zimmermann, Ralf. "A matrix-algebraic algorithm for the Riemannian 105 | logarithm on the Stiefel manifold under the canonical metric." SIAM 106 | Journal on Matrix Analysis and Applications 38.2 (2017): 322-342.
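Example (a brief sketch; the shapes are illustrative):

>>> import tensorflow as tf
>>> St = StiefelCanonical()
>>> x = St.random((5, 3), dtype=tf.float64)  # a random orthonormal 5x3 frame
>>> u = St.proju(x, tf.random.normal((5, 3), dtype=tf.float64))
>>> y = St.exp(x, u)  # move along the geodesic under the canonical metric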
107 | 108 | """ 109 | 110 | name = "Canonical Stiefel" 111 | 112 | def inner(self, x, u, v, keepdims=False): 113 | xtu = utils.transposem(x) @ u 114 | xtv = utils.transposem(x) @ v 115 | u_v_inner = tf.reduce_sum(u * v, axis=[-2, -1], keepdims=keepdims) 116 | xtu_xtv_inner = tf.reduce_sum( 117 | xtu * xtv, axis=[-2, -1], keepdims=keepdims 118 | ) 119 | return u_v_inner - 0.5 * xtu_xtv_inner 120 | 121 | def proju(self, x, u): 122 | return u - x @ utils.transposem(u) @ x 123 | 124 | def exp(self, x, u): 125 | return self.geodesic(x, u, 1.0) 126 | 127 | def retr(self, x, u): 128 | xut = x @ utils.transposem(u) 129 | xut_sym = (xut - utils.transposem(xut)) / 2.0 130 | eye = tf.eye( 131 | tf.shape(xut)[-1], batch_shape=tf.shape(xut)[:-2], dtype=x.dtype 132 | ) 133 | return tf.linalg.solve(xut_sym + eye, x - xut_sym @ x) 134 | 135 | def geodesic(self, x, u, t): 136 | xtu = utils.transposem(x) @ u 137 | utu = utils.transposem(u) @ u 138 | eye = tf.eye( 139 | tf.shape(utu)[-1], batch_shape=tf.shape(utu)[:-2], dtype=x.dtype 140 | ) 141 | q, r = tf.linalg.qr(u - x @ xtu) 142 | logw = blockm(xtu, -utils.transposem(r), r, tf.zeros_like(r)) 143 | w = tf.linalg.expm(t * logw) 144 | z = tf.concat([eye, tf.zeros_like(utu)], axis=-2) 145 | y = tf.concat([x, u], axis=-1) @ w @ z 146 | return y 147 | 148 | def transp(self, x, y, v): 149 | return self.proju(y, v) 150 | 151 | def dist(self, x, y, keepdims=False): 152 | raise NotImplementedError 153 | 154 | def log(self, x, y): 155 | raise NotImplementedError 156 | 157 | 158 | class StiefelCayley(_Stiefel): 159 | """Manifold of orthonormal p-frames in the n-dimensional space endowed with 160 | the Euclidean inner product. 161 | 162 | Retraction and vector transport operations are implemented using 163 | the iterative Cayley transform. 164 | 165 | Li, Jun, Fuxin Li, and Sinisa Todorovic. "Efficient Riemannian 166 | Optimization on the Stiefel Manifold via the Cayley Transform." 167 | International Conference on Learning Representations. 2020. 168 | 169 | """ 170 | 171 | name = "Cayley Stiefel" 172 | 173 | def __init__(self, num_iter=10): 174 | """Instantiate the Stiefel manifold with Cayley transform. 175 | 176 | Args: 177 | num_iter: number of Cayley iterations to perform.
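The retraction below iterates a fixed-point evaluation of the Cayley transform; more iterations give a tighter approximation at higher cost. A short sketch (the iteration count is illustrative):

>>> St = StiefelCayley(num_iter=20)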
178 | """ 179 | self.num_iter = num_iter 180 | super().__init__() 181 | 182 | def inner(self, x, u, v, keepdims=False): 183 | return tf.reduce_sum(u * v, axis=[-2, -1], keepdims=keepdims) 184 | 185 | def proju(self, x, u): 186 | xtu = utils.transposem(x) @ u 187 | w = (u - x @ xtu) @ utils.transposem(x) 188 | return (w - utils.transposem(w)) @ x 189 | 190 | def exp(self, x, u): 191 | return self.geodesic(x, u, 1.0) 192 | 193 | def geodesic(self, x, u, t): 194 | xtu = utils.transposem(x) @ u 195 | w_ = (u - x @ xtu) @ utils.transposem(x) 196 | w = w_ - utils.transposem(w_) 197 | eye = tf.linalg.eye( 198 | tf.shape(w)[-1], batch_shape=tf.shape(w)[:-2], dtype=x.dtype 199 | ) 200 | cayley_t = tf.linalg.inv(eye - t * w / 2.0) @ (eye + t * w / 2.0) 201 | return cayley_t @ x 202 | 203 | def retr(self, x, u): 204 | xtu = utils.transposem(x) @ u 205 | w_ = (u - x @ xtu) @ utils.transposem(x) 206 | w = w_ - utils.transposem(w_) 207 | y = x + u 208 | for _ in range(self.num_iter): 209 | y = x + w @ ((x + y) / 2.0) 210 | return y 211 | 212 | def transp(self, x, y, v): 213 | return self.proju(y, v) 214 | 215 | def dist(self, x, y, keepdims=False): 216 | raise NotImplementedError 217 | 218 | def log(self, x, y): 219 | raise NotImplementedError 220 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/stiefel_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import TestInvariants 7 | from tensorflow_riemopt.manifolds.stiefel import StiefelEuclidean 8 | from tensorflow_riemopt.manifolds.stiefel import StiefelCanonical 9 | from tensorflow_riemopt.manifolds.stiefel import StiefelCayley 10 | 11 | 12 | @combinations.generate( 13 | combinations.combine( 14 | mode=["graph", "eager"], 15 | manifold=[ 16 | StiefelEuclidean(), 17 | StiefelCanonical(), 18 | StiefelCayley(num_iter=100), 19 | ], 20 | shape=[(2, 5, 3), (5, 3)], 21 | dtype=[tf.float32, tf.float64], 22 | ) 23 | ) 24 | class StiefelTest(tf.test.TestCase, parameterized.TestCase): 25 | test_random = TestInvariants.check_random 26 | 27 | test_inner = TestInvariants.check_inner 28 | 29 | test_proj = TestInvariants.check_proj 30 | 31 | test_transp_retr = TestInvariants.check_transp_retr 32 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/symmetric_positive.py: -------------------------------------------------------------------------------- 1 | """Manifolds of symmetric positive definite matrices.""" 2 | 3 | import warnings 4 | 5 | import tensorflow as tf 6 | 7 | from tensorflow_riemopt.manifolds.manifold import Manifold 8 | from tensorflow_riemopt.manifolds.cholesky import Cholesky 9 | from tensorflow_riemopt.manifolds import utils 10 | 11 | 12 | class _SymmetricPositiveDefinite(Manifold): 13 | """Manifold of real symmetric positive definite matrices. 14 | 15 | Henrion, Didier, and Jérôme Malick. "Projection methods in conic 16 | optimization." Handbook on Semidefinite, Conic and Polynomial 17 | Optimization. Springer, Boston, MA, 2012. 565-600.
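The tangent space at any point consists of symmetric matrices, so `proju` symmetrizes an ambient matrix. A small sketch using the affine-invariant subclass defined below (entries are illustrative):

>>> import tensorflow as tf
>>> spd = SPDAffineInvariant()
>>> x = tf.constant([[2.0, 0.5], [0.5, 1.0]])
>>> u = spd.proju(x, tf.constant([[0.0, 1.0], [0.0, 0.0]]))  # (u + u^T) / 2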
18 | 19 | """ 20 | 21 | ndims = 2 22 | 23 | def _check_shape(self, shape): 24 | return shape[-1] == shape[-2] 25 | 26 | def _check_point_on_manifold(self, x, atol, rtol): 27 | x_t = utils.transposem(x) 28 | eigvals, _ = tf.linalg.eigh(x) 29 | is_symmetric = utils.allclose(x, x_t, atol, rtol) 30 | is_pos_vals = utils.allclose(eigvals, tf.abs(eigvals), atol, rtol) 31 | is_zero_vals = utils.allclose( 32 | eigvals, tf.zeros_like(eigvals), atol, rtol 33 | ) 34 | return is_symmetric & is_pos_vals & tf.logical_not(is_zero_vals) 35 | 36 | def _check_vector_on_tangent(self, x, u, atol, rtol): 37 | u_t = utils.transposem(u) 38 | return utils.allclose(u, u_t, atol, rtol) 39 | 40 | def projx(self, x): 41 | warnings.warn( 42 | ( 43 | "{}.projx performs a projection onto the open set of" 44 | + " PSD matrices" 45 | ).format(self.__class__.__name__) 46 | ) 47 | x_sym = (utils.transposem(x) + x) / 2.0 48 | s, u, v = tf.linalg.svd(x_sym) 49 | sigma = tf.linalg.diag(tf.maximum(s, 0.0)) 50 | return v @ sigma @ utils.transposem(v) 51 | 52 | def proju(self, x, u): 53 | return 0.5 * (utils.transposem(u) + u) 54 | 55 | 56 | class SPDAffineInvariant(_SymmetricPositiveDefinite): 57 | """Manifold of symmetric positive definite matrices endowed with the 58 | affine-invariant metric. 59 | 60 | Pennec, Xavier, Pierre Fillard, and Nicholas Ayache. "A Riemannian 61 | framework for tensor computing." International Journal of computer vision 62 | 66.1 (2006): 41-66. 63 | 64 | Sra, Suvrit, and Reshad Hosseini. "Conic geometric optimization on the 65 | manifold of positive definite matrices." SIAM Journal on Optimization 25.1 66 | (2015): 713-739. 67 | 68 | """ 69 | 70 | name = "Affine-Invariant SPD" 71 | 72 | def dist(self, x, y, keepdims=False): 73 | x_sqrt_inv = tf.linalg.inv(tf.linalg.sqrtm(x)) 74 | log = utils.logm(x_sqrt_inv @ y @ x_sqrt_inv) 75 | return tf.linalg.norm(log, axis=[-2, -1], ord="fro", keepdims=keepdims) 76 | 77 | def inner(self, x, u, v, keepdims=False): 78 | x_sqrt_inv = tf.linalg.inv(tf.linalg.sqrtm(x)) 79 | x_inv = tf.linalg.inv(x) 80 | traces = tf.linalg.trace(x_sqrt_inv @ u @ x_inv @ v @ x_sqrt_inv) 81 | return traces[..., tf.newaxis, tf.newaxis] if keepdims else traces 82 | 83 | def exp(self, x, u): 84 | x_sqrt = tf.linalg.sqrtm(x) 85 | x_sqrt_inv = tf.linalg.inv(x_sqrt) 86 | return x_sqrt @ tf.linalg.expm(x_sqrt_inv @ u @ x_sqrt_inv) @ x_sqrt 87 | 88 | retr = exp 89 | 90 | def log(self, x, y): 91 | x_sqrt = tf.linalg.sqrtm(x) 92 | x_sqrt_inv = tf.linalg.inv(x_sqrt) 93 | return x_sqrt @ utils.logm(x_sqrt_inv @ y @ x_sqrt_inv) @ x_sqrt 94 | 95 | def geodesic(self, x, u, t): 96 | x_sqrt = tf.linalg.sqrtm(x) 97 | x_sqrt_inv = tf.linalg.inv(x_sqrt) 98 | return x_sqrt @ tf.linalg.expm(t * x_sqrt_inv @ u @ x_sqrt_inv) @ x_sqrt 99 | 100 | def ptransp(self, x, y, v): 101 | e = tf.linalg.sqrtm(y @ tf.linalg.inv(x)) 102 | return e @ v @ utils.transposem(e) 103 | 104 | def ptransp1(self, x, y, v): 105 | x_sqrt = tf.linalg.sqrtm(x) 106 | x_sqrt_inv = tf.linalg.inv(x_sqrt) 107 | h = tf.linalg.expm(0.5 * x_sqrt_inv @ self.log(x, y) @ x_sqrt_inv) 108 | return x_sqrt @ h @ x_sqrt_inv @ v @ x_sqrt_inv @ h @ x_sqrt 109 | 110 | transp = ptransp 111 | 112 | def pairmean(self, x, y): 113 | x_sqrt = tf.linalg.sqrtm(x) 114 | x_sqrt_inv = tf.linalg.inv(x_sqrt) 115 | return x_sqrt @ tf.linalg.sqrtm(x_sqrt_inv @ y @ x_sqrt_inv) @ x_sqrt 116 | 117 | 118 | class SPDLogEuclidean(_SymmetricPositiveDefinite): 119 | """Manifold of symmetric positive definite matrices endowed with the 120 | Log-Euclidean metric. 
121 | 122 | Arsigny, Vincent, et al. "Geometric means in a novel vector space 123 | structure on symmetric positive-definite matrices." SIAM journal on matrix 124 | analysis and applications 29.1 (2007): 328-347. 125 | 126 | """ 127 | 128 | def dist(self, x, y, keepdims=False): 129 | diff = utils.logm(y) - utils.logm(x) 130 | return tf.linalg.norm(diff, axis=[-2, -1], ord="fro", keepdims=keepdims) 131 | 132 | def _diff_power(self, x, d, power): 133 | e, v = tf.linalg.eigh(x) 134 | v_t = utils.transposem(v) 135 | e = tf.expand_dims(e, -2) 136 | if power == "log": 137 | pow_e = tf.math.log(e) 138 | elif power == "exp": 139 | pow_e = tf.math.exp(e) 140 | s = utils.transposem(tf.ones_like(e)) @ e 141 | pow_s = utils.transposem(tf.ones_like(pow_e)) @ pow_e 142 | denom = utils.transposem(s) - s 143 | numer = utils.transposem(pow_s) - pow_s 144 | abs_denom = tf.math.abs(denom) 145 | eps = utils.get_eps(x) 146 | if power == "log": 147 | numer = tf.where(abs_denom < eps, tf.ones_like(numer), numer) 148 | denom = tf.where(abs_denom < eps, utils.transposem(s), denom) 149 | elif power == "exp": 150 | numer = tf.where(abs_denom < eps, utils.transposem(pow_s), numer) 151 | denom = tf.where(abs_denom < eps, tf.ones_like(denom), denom) 152 | t = v_t @ d @ v * numer / denom 153 | return v @ t @ v_t 154 | 155 | def _diff_exp(self, x, d): 156 | """Directional derivative of `expm` at :math:`x` along :math:`d`""" 157 | return self._diff_power(x, d, "exp") 158 | 159 | def _diff_log(self, x, d): 160 | """Directional derivative of `logm` at :math:`x` along :math:`d`""" 161 | return self._diff_power(x, d, "log") 162 | 163 | def inner(self, x, u, v, keepdims=False): 164 | dlog_u = self._diff_log(x, u) 165 | dlog_v = self._diff_log(x, v) 166 | return tf.reduce_sum(dlog_u * dlog_v, axis=[-2, -1], keepdims=keepdims) 167 | 168 | def exp(self, x, u): 169 | return tf.linalg.expm(utils.logm(x) + self._diff_log(x, u)) 170 | 171 | retr = exp 172 | 173 | def log(self, x, y): 174 | return self._diff_exp(utils.logm(x), utils.logm(y) - utils.logm(x)) 175 | 176 | def geodesic(self, x, u, t): 177 | return tf.linalg.expm(utils.logm(x) + t * self._diff_log(x, u)) 178 | 179 | def pairmean(self, x, y): 180 | return tf.linalg.expm((utils.logm(x) + utils.logm(y)) / 2.0) 181 | 182 | def transp(self, x, y, v): 183 | raise NotImplementedError 184 | 185 | 186 | class SPDLogCholesky(_SymmetricPositiveDefinite): 187 | """Manifold of symmetric positive definite matrices endowed with the 188 | Log-Cholesky metric. 189 | 190 | The geometry of the manifold is induced by the diffeomorphism between the SPD 191 | and Cholesky spaces. 192 | 193 | Lin, Zhenhua. "Riemannian Geometry of Symmetric Positive Definite Matrices 194 | via Cholesky Decomposition." SIAM Journal on Matrix Analysis and 195 | Applications 40.4 (2019): 1353-1370.
196 | 197 | """ 198 | 199 | name = "Log-Cholesky SPD" 200 | 201 | def __init__(self): 202 | self._cholesky = Cholesky() 203 | super().__init__() 204 | 205 | def to_cholesky(self, x): 206 | """Diffeomorphism that maps to the Cholesky space""" 207 | assert_x = tf.debugging.Assert(self.check_point_on_manifold(x), [x]) 208 | with tf.control_dependencies([assert_x]): 209 | return tf.linalg.cholesky(x) 210 | 211 | def from_cholesky(self, x): 212 | """Inverse of the diffeomorphism to the Cholesky space""" 213 | assert_x = tf.debugging.Assert( 214 | self._cholesky.check_point_on_manifold(x), [x] 215 | ) 216 | with tf.control_dependencies([assert_x]): 217 | return x @ utils.transposem(x) 218 | 219 | def diff_to_cholesky(self, x, u): 220 | """Differential of the diffeomorphism to the Cholesky space""" 221 | assert_x = tf.debugging.Assert(self.check_point_on_manifold(x), [x]) 222 | assert_u = tf.debugging.Assert(self.check_vector_on_tangent(x, u), [u]) 223 | with tf.control_dependencies([assert_x, assert_u]): 224 | y = self.to_cholesky(x) 225 | y_inv = tf.linalg.inv(y) 226 | p = y_inv @ u @ utils.transposem(y_inv) 227 | p_diag, p_lower = self._cholesky._diag_and_strictly_lower(p) 228 | return y @ (p_lower + 0.5 * tf.linalg.diag(p_diag)) 229 | 230 | def diff_from_cholesky(self, x, u): 231 | """Inverse of the differential of diffeomorphism to the Cholesky space""" 232 | assert_x = tf.debugging.Assert( 233 | self._cholesky.check_point_on_manifold(x), [x] 234 | ) 235 | assert_u = tf.debugging.Assert( 236 | self._cholesky.check_vector_on_tangent(x, u), [u] 237 | ) 238 | with tf.control_dependencies([assert_x, assert_u]): 239 | return x @ utils.transposem(u) + u @ utils.transposem(x) 240 | 241 | def dist(self, x, y, keepdims=False): 242 | x_chol = self.to_cholesky(x) 243 | y_chol = self.to_cholesky(y) 244 | return self._cholesky.dist(x_chol, y_chol, keepdims=keepdims) 245 | 246 | def inner(self, x, u, v, keepdims=False): 247 | x_chol = self.to_cholesky(x) 248 | u_diff_chol = self.diff_to_cholesky(x, u) 249 | v_diff_chol = self.diff_to_cholesky(x, v) 250 | return self._cholesky.inner( 251 | x_chol, u_diff_chol, v_diff_chol, keepdims=keepdims 252 | ) 253 | 254 | def exp(self, x, u): 255 | x_chol = self.to_cholesky(x) 256 | u_diff_chol = self.diff_to_cholesky(x, u) 257 | exp_chol = self._cholesky.exp(x_chol, u_diff_chol) 258 | return self.from_cholesky(exp_chol) 259 | 260 | retr = exp 261 | 262 | def log(self, x, y): 263 | x_chol = self.to_cholesky(x) 264 | y_chol = self.to_cholesky(y) 265 | log_chol = self._cholesky.log(x_chol, y_chol) 266 | return self.diff_from_cholesky(x_chol, log_chol) 267 | 268 | def ptransp(self, x, y, v): 269 | x_chol = self.to_cholesky(x) 270 | y_chol = self.to_cholesky(y) 271 | v_diff_chol = self.diff_to_cholesky(x, v) 272 | transp_chol = self._cholesky.transp(x_chol, y_chol, v_diff_chol) 273 | return self.diff_from_cholesky(y_chol, transp_chol) 274 | 275 | transp = ptransp 276 | 277 | def geodesic(self, x, u, t): 278 | x_chol = self.to_cholesky(x) 279 | u_diff_chol = self.diff_to_cholesky(x, u) 280 | geodesic_chol = self._cholesky.geodesic(x_chol, u_diff_chol, t) 281 | return self.from_cholesky(geodesic_chol) 282 | 283 | def pairmean(self, x, y): 284 | x_chol = self.to_cholesky(x) 285 | y_chol = self.to_cholesky(y) 286 | pairmean_chol = self._cholesky.pairmean(x_chol, y_chol) 287 | return self.from_cholesky(pairmean_chol) 288 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/symmetric_positive_test.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from absl.testing import parameterized 4 | from tensorflow.python.keras import combinations 5 | 6 | from tensorflow_riemopt.manifolds.test_invariants import TestInvariants 7 | from tensorflow_riemopt.manifolds.symmetric_positive import SPDAffineInvariant 8 | from tensorflow_riemopt.manifolds.symmetric_positive import SPDLogEuclidean 9 | from tensorflow_riemopt.manifolds.symmetric_positive import SPDLogCholesky 10 | 11 | 12 | @combinations.generate( 13 | combinations.combine( 14 | mode=["graph", "eager"], 15 | manifold=[SPDAffineInvariant(), SPDLogCholesky()], 16 | shape=[(2, 3, 3), (3, 3)], 17 | dtype=[tf.float64], 18 | ) 19 | ) 20 | class SymmetricPositiveTest(tf.test.TestCase, parameterized.TestCase): 21 | test_random = TestInvariants.check_random 22 | 23 | test_dist = TestInvariants.check_dist 24 | 25 | test_inner = TestInvariants.check_inner 26 | 27 | test_proj = TestInvariants.check_proj 28 | 29 | test_exp_log_inverse = TestInvariants.check_exp_log_inverse 30 | 31 | test_transp_retr = TestInvariants.check_transp_retr 32 | 33 | test_ptransp_inverse = TestInvariants.check_ptransp_inverse 34 | 35 | test_ptransp_inner = TestInvariants.check_ptransp_inner 36 | 37 | test_geodesic = TestInvariants.check_geodesic 38 | 39 | test_pairmean = TestInvariants.check_pairmean 40 | 41 | 42 | @combinations.generate( 43 | combinations.combine( 44 | mode=["graph", "eager"], 45 | manifold=[SPDLogEuclidean()], 46 | shape=[(2, 3, 3), (3, 3)], 47 | dtype=[tf.float32, tf.float64], 48 | ) 49 | ) 50 | class SymmetricPositiveLETest(tf.test.TestCase, parameterized.TestCase): 51 | test_random = TestInvariants.check_random 52 | 53 | test_dist = TestInvariants.check_dist 54 | 55 | test_inner = TestInvariants.check_inner 56 | 57 | test_proj = TestInvariants.check_proj 58 | 59 | test_exp_log_inverse = TestInvariants.check_exp_log_inverse 60 | 61 | test_geodesic = TestInvariants.check_geodesic 62 | 63 | test_pairmean = TestInvariants.check_pairmean 64 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/test_invariants.py: -------------------------------------------------------------------------------- 1 | """Manifold test invariants.""" 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | 7 | def random_constant(shape, dtype): 8 | return tf.constant( 9 | np.random.uniform(size=shape, high=1e-1), 10 | dtype=dtype.as_numpy_dtype, 11 | ) 12 | 13 | 14 | class TestInvariants(tf.test.TestCase): 15 | def check_random(self, manifold, shape, dtype): 16 | """Check random point generator""" 17 | with self.cached_session(use_gpu=True): 18 | x = manifold.random(shape=shape, dtype=dtype) 19 | self.assertEqual(list(shape), x.shape.as_list()) 20 | x_on_manifold = manifold.check_point_on_manifold(x) 21 | if not tf.executing_eagerly(): 22 | x_on_manifold = self.evaluate(x_on_manifold) 23 | self.assertTrue(x_on_manifold) 24 | 25 | def check_dist(self, manifold, shape, dtype): 26 | """Check the distance axioms""" 27 | with self.cached_session(use_gpu=True): 28 | x_rand = random_constant(shape=shape, dtype=dtype) 29 | x = manifold.projx(x_rand) 30 | y_rand = random_constant(shape=shape, dtype=dtype) 31 | y = manifold.projx(y_rand) 32 | dist_xy = manifold.dist(x, y) 33 | dist_yx = manifold.dist(y, x) 34 | self.assertAllCloseAccordingToType(dist_xy, dist_yx) 35 | dist_xx = manifold.dist(x, x) 36 | # low precision comparison for manifolds which use
trigonometric functions 37 | self.assertAllClose(dist_xx, tf.zeros_like(dist_xx), atol=1e-3) 38 | orig_shape = x.shape.as_list() 39 | keepdims_shape = manifold.dist(x, x, keepdims=True).shape.as_list() 40 | nokeepdims_shape = manifold.dist( 41 | x, x, keepdims=False 42 | ).shape.as_list() 43 | self.assertEqual(len(keepdims_shape), len(orig_shape)) 44 | self.assertEqual( 45 | len(nokeepdims_shape), len(orig_shape) - manifold.ndims 46 | ) 47 | if manifold.ndims > 0: 48 | self.assertEqual( 49 | keepdims_shape[-manifold.ndims :], [1] * manifold.ndims 50 | ) 51 | 52 | def check_inner(self, manifold, shape, dtype): 53 | """Check the inner product axioms""" 54 | with self.cached_session(use_gpu=True): 55 | x_rand = random_constant(shape=shape, dtype=dtype) 56 | x = manifold.projx(x_rand) 57 | u = manifold.proju(x, x_rand) 58 | orig_shape = u.shape.as_list() 59 | keepdims_shape = manifold.inner( 60 | x, u, u, keepdims=True 61 | ).shape.as_list() 62 | nokeepdims_shape = manifold.inner( 63 | x, u, u, keepdims=False 64 | ).shape.as_list() 65 | self.assertEqual(len(keepdims_shape), len(orig_shape)) 66 | self.assertEqual( 67 | len(nokeepdims_shape), len(orig_shape) - manifold.ndims 68 | ) 69 | if manifold.ndims > 0: 70 | self.assertEqual( 71 | keepdims_shape[-manifold.ndims :], [1] * manifold.ndims 72 | ) 73 | 74 | def check_proj(self, manifold, shape, dtype): 75 | """Check projection from the ambient space""" 76 | with self.cached_session(use_gpu=True): 77 | x_rand = random_constant(shape=shape, dtype=dtype) 78 | x = manifold.projx(x_rand) 79 | x_on_manifold = manifold.check_point_on_manifold(x) 80 | if not tf.executing_eagerly(): 81 | x_on_manifold = self.evaluate(x_on_manifold) 82 | self.assertTrue(x_on_manifold) 83 | u_rand = random_constant(shape=shape, dtype=dtype) 84 | u = manifold.proju(x, u_rand) 85 | u_on_tangent = manifold.check_vector_on_tangent(x, u) 86 | if not tf.executing_eagerly(): 87 | u_on_tangent = self.evaluate(u_on_tangent) 88 | self.assertTrue(u_on_tangent) 89 | 90 | def check_exp_log_inverse(self, manifold, shape, dtype): 91 | """Check that logarithmic map is the inverse of exponential map""" 92 | with self.cached_session(use_gpu=True): 93 | x = manifold.projx(random_constant(shape, dtype)) 94 | u = manifold.proju(x, random_constant(shape, dtype)) 95 | y = manifold.exp(x, u) 96 | y_on_manifold = manifold.check_point_on_manifold(y) 97 | if not tf.executing_eagerly(): 98 | y_on_manifold = self.evaluate(y_on_manifold) 99 | self.assertTrue(y_on_manifold) 100 | v = manifold.log(x, y) 101 | v_on_tangent = manifold.check_vector_on_tangent(x, v) 102 | if not tf.executing_eagerly(): 103 | v_on_tangent = self.evaluate(v_on_tangent) 104 | self.assertTrue(v_on_tangent) 105 | self.assertAllCloseAccordingToType(u, v) 106 | 107 | def check_transp_retr(self, manifold, shape, dtype): 108 | """Test that vector transport is compatible with retraction""" 109 | with self.cached_session(use_gpu=True): 110 | x = manifold.projx(random_constant(shape, dtype)) 111 | u = manifold.proju(x, random_constant(shape, dtype)) 112 | y = manifold.retr(x, u) 113 | y_on_manifold = manifold.check_point_on_manifold(y) 114 | if not tf.executing_eagerly(): 115 | y_on_manifold = self.evaluate(y_on_manifold) 116 | v = manifold.proju(x, random_constant(shape, dtype)) 117 | v_ = manifold.transp(x, y, v) 118 | v_on_tangent = manifold.check_vector_on_tangent(y, v_) 119 | if not tf.executing_eagerly(): 120 | v_on_tangent = self.evaluate(v_on_tangent) 121 | self.assertTrue(v_on_tangent) 122 | w = manifold.proju(x, 
random_constant(shape, dtype)) 123 | w_ = manifold.transp(x, y, w) 124 | w_v = v + w 125 | w_v_ = manifold.transp(x, y, w_v) 126 | self.assertAllCloseAccordingToType(w_v_, w_ + v_) 127 | 128 | def check_ptransp_inverse(self, manifold, shape, dtype): 129 | """Test that parallel transport is an invertible operation""" 130 | with self.cached_session(use_gpu=True): 131 | x = manifold.projx(random_constant(shape, dtype)) 132 | y = manifold.projx(random_constant(shape, dtype)) 133 | u = manifold.proju(x, random_constant(shape, dtype)) 134 | v = manifold.ptransp(x, y, u) 135 | v_on_tangent = manifold.check_vector_on_tangent(y, v) 136 | if not tf.executing_eagerly(): 137 | v_on_tangent = self.evaluate(v_on_tangent) 138 | self.assertTrue(v_on_tangent) 139 | w = manifold.ptransp(y, x, v) 140 | w_on_tangent = manifold.check_vector_on_tangent(x, w) 141 | if not tf.executing_eagerly(): 142 | w_on_tangent = self.evaluate(w_on_tangent) 143 | self.assertTrue(w_on_tangent) 144 | self.assertAllCloseAccordingToType(u, w) 145 | 146 | def check_ptransp_inner(self, manifold, shape, dtype): 147 | """Check that parallel transport preserves the inner product""" 148 | with self.cached_session(use_gpu=True): 149 | x = manifold.projx(random_constant(shape, dtype)) 150 | y = manifold.projx(random_constant(shape, dtype)) 151 | u = manifold.proju(x, random_constant(shape, dtype)) 152 | v = manifold.proju(x, random_constant(shape, dtype)) 153 | uv = manifold.inner(x, u, v) 154 | u_ = manifold.ptransp(x, y, u) 155 | v_ = manifold.ptransp(x, y, v) 156 | u_v_ = manifold.inner(y, u_, v_) 157 | self.assertAllCloseAccordingToType(uv, u_v_) 158 | 159 | def check_geodesic(self, manifold, shape, dtype): 160 | """Check that the exponential map lies on a geodesic""" 161 | with self.cached_session(use_gpu=True): 162 | x = manifold.projx(random_constant(shape=shape, dtype=dtype)) 163 | u = manifold.proju(x, random_constant(shape=shape, dtype=dtype)) 164 | y = manifold.geodesic(x, u, 1.0) 165 | y_on_manifold = manifold.check_point_on_manifold(y) 166 | if not tf.executing_eagerly(): 167 | y_on_manifold = self.evaluate(y_on_manifold) 168 | self.assertTrue(y_on_manifold) 169 | y_ = manifold.exp(x, u) 170 | self.assertAllCloseAccordingToType(y, y_) 171 | 172 | def check_pairmean(self, manifold, shape, dtype): 173 | """Check that the Riemannian mean is equidistant from points""" 174 | with self.cached_session(use_gpu=True): 175 | x_rand = random_constant(shape=shape, dtype=dtype) 176 | x = manifold.projx(x_rand) 177 | y_rand = random_constant(shape=shape, dtype=dtype) 178 | y = manifold.projx(y_rand) 179 | m = manifold.pairmean(x, y) 180 | m_on_manifold = manifold.check_point_on_manifold(m) 181 | if not tf.executing_eagerly(): 182 | m_on_manifold = self.evaluate(m_on_manifold) 183 | self.assertTrue(m_on_manifold) 184 | dist_x_m = manifold.dist(x, m) 185 | dist_y_m = manifold.dist(y, m) 186 | self.assertAllCloseAccordingToType(dist_x_m, dist_y_m) 187 | -------------------------------------------------------------------------------- /tensorflow_riemopt/manifolds/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def get_eps(val): 6 | return np.finfo(val.dtype.name).eps 7 | 8 | 9 | def allclose(x, y, rtol=None, atol=None): 10 | """Return True if two arrays are element-wise equal within a tolerance.""" 11 | rtol = 10 * get_eps(x) if rtol is None else rtol 12 | atol = 10 * get_eps(x) if atol is None else atol 13 | return tf.reduce_all(tf.abs(x
- y) <= tf.abs(y) * rtol + atol) 14 | 15 | 16 | def logm(inp): 17 | """Compute the matrix logarithm of positive-definite real matrices.""" 18 | inp = tf.convert_to_tensor(inp) 19 | complex_dtype = tf.complex128 if inp.dtype == tf.float64 else tf.complex64 20 | log = tf.linalg.logm(tf.cast(inp, dtype=complex_dtype)) 21 | return tf.cast(log, dtype=inp.dtype) 22 | 23 | 24 | def transposem(inp): 25 | """Transpose multiple matrices.""" 26 | perm = list(range(len(inp.shape))) 27 | perm[-2], perm[-1] = perm[-1], perm[-2] 28 | return tf.transpose(inp, perm) 29 | -------------------------------------------------------------------------------- /tensorflow_riemopt/mcmc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/master/tensorflow-riemopt/743ca9f0fad8b735de8b537c54fde15fa3b54ba6/tensorflow_riemopt/mcmc/.gitkeep -------------------------------------------------------------------------------- /tensorflow_riemopt/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorflow_riemopt.optimizers.constrained_rmsprop import ConstrainedRMSprop 2 | from tensorflow_riemopt.optimizers.riemannian_adam import RiemannianAdam 3 | from tensorflow_riemopt.optimizers.riemannian_gradient_descent import ( 4 | RiemannianSGD, 5 | ) 6 | 7 | __all__ = ["ConstrainedRMSprop", "RiemannianAdam", "RiemannianSGD"] 8 | -------------------------------------------------------------------------------- /tensorflow_riemopt/optimizers/constrained_rmsprop.py: -------------------------------------------------------------------------------- 1 | """Constrained RMSprop optimizer implementation. 2 | 3 | Kumar Roy, Soumava, Zakaria Mhammedi, and Mehrtash Harandi. "Geometry aware 4 | constrained optimization techniques for deep learning." Proceedings of the 5 | IEEE Conference on Computer Vision and Pattern Recognition. 2018. 6 | """ 7 | 8 | import numpy as np 9 | 10 | from tensorflow.python.eager import def_function 11 | from tensorflow.python.framework import ops 12 | from tensorflow.python.keras import backend_config 13 | from tensorflow.python.keras.utils import generic_utils 14 | from tensorflow.python.ops import array_ops 15 | from tensorflow.python.ops import control_flow_ops 16 | from tensorflow.python.ops import math_ops 17 | from tensorflow.python.ops import state_ops 18 | from tensorflow.python.ops import gen_training_ops 19 | from tensorflow.python.framework.indexed_slices import IndexedSlices 20 | from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2 21 | 22 | from tensorflow_riemopt.variable import get_manifold 23 | 24 | 25 | @generic_utils.register_keras_serializable(name="ConstrainedRMSprop") 26 | class ConstrainedRMSprop(OptimizerV2): 27 | """Optimizer that implements the RMSprop algorithm.""" 28 | 29 | _HAS_AGGREGATE_GRAD = True 30 | 31 | def __init__( 32 | self, 33 | learning_rate=0.001, 34 | rho=0.9, 35 | epsilon=1e-7, 36 | centered=False, 37 | stabilize=None, 38 | name="ConstrainedRMSprop", 39 | **kwargs, 40 | ): 41 | """Construct a new Constrained RMSprop optimizer. 42 | 43 | Kumar Roy, Soumava, Zakaria Mhammedi, and Mehrtash Harandi. "Geometry 44 | aware constrained optimization techniques for deep learning." Proceedings 45 | of the IEEE Conference on Computer Vision and Pattern Recognition. 2018. 
46 | 47 | Args: 48 | learning_rate: A `Tensor`, floating point value, or a schedule that is a 49 | `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable 50 | that takes no arguments and returns the actual value to use. The 51 | learning rate. Defaults to 0.001. 52 | rho: Discounting factor for the history/coming gradient. Defaults to 0.9. 53 | epsilon: A small constant for numerical stability. This epsilon is 54 | "epsilon hat" in the Kingma and Ba paper (in the formula just before 55 | Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to 56 | 1e-7. 57 | centered: Boolean. If `True`, gradients are normalized by the estimated 58 | variance of the gradient; if False, by the uncentered second moment. 59 | Setting this to `True` may help with training, but is slightly more 60 | expensive in terms of computation and memory. Defaults to `False`. 61 | stabilize: Project variables back to manifold every `stabilize` steps. 62 | Defaults to `None`. 63 | name: Optional name prefix for the operations created when applying 64 | gradients. Defaults to "ConstrainedRMSprop". 65 | **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, 66 | `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip 67 | gradients by value, `decay` is included for backward compatibility to 68 | allow time inverse decay of learning rate. `lr` is included for backward 69 | compatibility, recommended to use `learning_rate` instead. 70 | """ 71 | super().__init__(name, **kwargs) 72 | self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) 73 | self._set_hyper("decay", self._initial_decay) 74 | self._set_hyper("rho", rho) 75 | 76 | self.epsilon = epsilon or backend_config.epsilon() 77 | self.centered = centered 78 | self.stabilize = stabilize 79 | 80 | def _create_slots(self, var_list): 81 | for var in var_list: 82 | self.add_slot(var, "rms") 83 | if self.centered: 84 | for var in var_list: 85 | self.add_slot(var, "mg") 86 | 87 | def _prepare_local(self, var_device, var_dtype, apply_state): 88 | super()._prepare_local(var_device, var_dtype, apply_state) 89 | 90 | rho = array_ops.identity(self._get_hyper("rho", var_dtype)) 91 | apply_state[(var_device, var_dtype)].update( 92 | dict( 93 | neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"], 94 | epsilon=ops.convert_to_tensor(self.epsilon, var_dtype), 95 | rho=rho, 96 | one_minus_rho=1.0 - rho, 97 | ) 98 | ) 99 | 100 | @def_function.function(experimental_compile=True) 101 | def _resource_apply_dense(self, grad, var, apply_state=None): 102 | var_device, var_dtype = var.device, var.dtype.base_dtype 103 | coefficients = (apply_state or {}).get( 104 | (var_device, var_dtype) 105 | ) or self._fallback_apply_state(var_device, var_dtype) 106 | 107 | manifold = get_manifold(var) 108 | grad = manifold.egrad2rgrad(var, grad) 109 | grad_sq = manifold.egrad2rgrad(var, grad * grad) 110 | 111 | rms = self.get_slot(var, "rms") 112 | rms_t = ( 113 | coefficients["rho"] * rms + coefficients["one_minus_rho"] * grad_sq 114 | ) 115 | denom_t = rms_t 116 | if self.centered: 117 | mg = self.get_slot(var, "mg") 118 | mg_t = ( 119 | coefficients["rho"] * mg + coefficients["one_minus_rho"] * grad 120 | ) 121 | denom_t -= math_ops.square(mg_t) 122 | 123 | var_t = manifold.retr( 124 | var, 125 | -coefficients["lr_t"] 126 | * grad 127 | / (math_ops.sqrt(denom_t) + coefficients["epsilon"]), 128 | ) 129 | rms.assign(manifold.transp(var, var_t, rms_t)) 130 | if self.centered: 131 | mg.assign(manifold.transp(var, var_t, mg_t)) 132 |
var_update = var.assign(var_t) 133 | if self.stabilize is not None: 134 | self._stabilize(var) 135 | return var_update 136 | 137 | @def_function.function(experimental_compile=True) 138 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 139 | var_device, var_dtype = var.device, var.dtype.base_dtype 140 | coefficients = (apply_state or {}).get( 141 | (var_device, var_dtype) 142 | ) or self._fallback_apply_state(var_device, var_dtype) 143 | 144 | manifold = get_manifold(var) 145 | grad = manifold.egrad2rgrad(var, grad) 146 | grad_sq = manifold.egrad2rgrad(var, grad * grad) 147 | 148 | rms = self.get_slot(var, "rms") 149 | rms_scaled_g_values = grad_sq * coefficients["one_minus_rho"] 150 | rms_t_values = ( 151 | array_ops.gather(rms, indices) * coefficients["rho"] 152 | + rms_scaled_g_values 153 | ) 154 | 155 | denom_t = rms_t_values 156 | if self.centered: 157 | mg = self.get_slot(var, "mg") 158 | mg_scaled_g_values = grad * coefficients["one_minus_rho"] 159 | mg_t_values = ( 160 | array_ops.gather(mg, indices) * coefficients["rho"] 161 | + mg_scaled_g_values 162 | ) 163 | denom_t -= math_ops.square(mg_t_values) 164 | 165 | var_values = array_ops.gather(var, indices) 166 | var_t_values = manifold.retr( 167 | var_values, 168 | coefficients["neg_lr_t"] 169 | * grad 170 | / (math_ops.sqrt(denom_t) + coefficients["epsilon"]), 171 | ) 172 | 173 | rms_t_transp = manifold.transp(var_values, var_t_values, rms_t_values) 174 | rms.scatter_update(IndexedSlices(rms_t_transp, indices)) 175 | 176 | if self.centered: 177 | mg_t_transp = manifold.transp(var_values, var_t_values, mg_t_values) 178 | mg.scatter_update(IndexedSlices(mg_t_transp, indices)) 179 | 180 | var_update = var.scatter_update(IndexedSlices(var_t_values, indices)) 181 | if self.stabilize is not None: 182 | self._stabilize(var) 183 | return var_update 184 | 185 | @def_function.function(experimental_compile=True) 186 | def _stabilize(self, var): 187 | if math_ops.floor_mod(self.iterations, self.stabilize) == 0: 188 | manifold = get_manifold(var) 189 | var.assign(manifold.projx(var)) 190 | rms = self.get_slot(var, "rms") 191 | rms.assign(manifold.proju(var, rms)) 192 | if self.centered: 193 | mg = self.get_slot(var, "mg") 194 | mg.assign(manifold.proju(var, mg)) 195 | 196 | def set_weights(self, weights): 197 | params = self.weights 198 | if len(params) == len(weights) + 1: 199 | weights = [np.array(0)] + weights 200 | super().set_weights(weights) 201 | 202 | def get_config(self): 203 | config = super().get_config() 204 | config.update( 205 | { 206 | "learning_rate": self._serialize_hyperparameter( 207 | "learning_rate" 208 | ), 209 | "decay": self._serialize_hyperparameter("decay"), 210 | "rho": self._serialize_hyperparameter("rho"), 211 | "epsilon": self.epsilon, 212 | "centered": self.centered, 213 | "stabilize": self.stabilize, 214 | } 215 | ) 216 | return config 217 | -------------------------------------------------------------------------------- /tensorflow_riemopt/optimizers/constrained_rmsprop_test.py: -------------------------------------------------------------------------------- 1 | """Tests for RMSprop.""" 2 | 3 | from absl.testing import parameterized 4 | import numpy as np 5 | 6 | from tensorflow.python.eager import context 7 | from tensorflow.python.framework import constant_op 8 | from tensorflow.python.framework import dtypes 9 | from tensorflow.python.framework import ops 10 | from tensorflow.python.keras import combinations 11 | from tensorflow.python.keras.optimizer_v2 import rmsprop 12 | from 
tensorflow.python.ops import array_ops 13 | from tensorflow.python.ops import math_ops 14 | from tensorflow.python.ops import variables 15 | from tensorflow.python.platform import test 16 | from tensorflow.python.framework.indexed_slices import IndexedSlices 17 | 18 | from tensorflow_riemopt.optimizers.constrained_rmsprop import ( 19 | ConstrainedRMSprop, 20 | ) 21 | 22 | 23 | class ConstrainedRMSpropTest(test.TestCase, parameterized.TestCase): 24 | def testSparse(self): 25 | for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: 26 | for centered in [False, True]: 27 | with ops.Graph().as_default(), self.cached_session( 28 | use_gpu=True 29 | ): 30 | var0_np = np.array( 31 | [1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype 32 | ) 33 | grads0_np = np.array( 34 | [0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype 35 | ) 36 | var1_np = np.array( 37 | [3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype 38 | ) 39 | grads1_np = np.array( 40 | [0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype 41 | ) 42 | 43 | var0 = variables.Variable(var0_np) 44 | var1 = variables.Variable(var1_np) 45 | var0_ref = variables.Variable(var0_np) 46 | var1_ref = variables.Variable(var1_np) 47 | grads0_np_indices = np.array([0, 2], dtype=np.int32) 48 | grads0 = IndexedSlices( 49 | constant_op.constant(grads0_np[grads0_np_indices]), 50 | constant_op.constant(grads0_np_indices), 51 | constant_op.constant([3]), 52 | ) 53 | grads1_np_indices = np.array([0, 2], dtype=np.int32) 54 | grads1 = IndexedSlices( 55 | constant_op.constant(grads1_np[grads1_np_indices]), 56 | constant_op.constant(grads1_np_indices), 57 | constant_op.constant([3]), 58 | ) 59 | opt = ConstrainedRMSprop(centered=centered) 60 | update = opt.apply_gradients( 61 | zip([grads0, grads1], [var0, var1]) 62 | ) 63 | opt_ref = rmsprop.RMSprop(centered=centered) 64 | update_ref = opt_ref.apply_gradients( 65 | zip([grads0, grads1], [var0_ref, var1_ref]) 66 | ) 67 | self.evaluate(variables.global_variables_initializer()) 68 | 69 | # Run 3 steps 70 | for t in range(3): 71 | update.run() 72 | update_ref.run() 73 | 74 | # Validate updated params 75 | self.assertAllCloseAccordingToType( 76 | self.evaluate(var0_ref), self.evaluate(var0) 77 | ) 78 | self.assertAllCloseAccordingToType( 79 | self.evaluate(var1_ref), self.evaluate(var1) 80 | ) 81 | 82 | @combinations.generate(combinations.combine(mode=["graph", "eager"])) 83 | def testBasic(self): 84 | for i, dtype in enumerate( 85 | [dtypes.half, dtypes.float32, dtypes.float64] 86 | ): 87 | for centered in [False, True]: 88 | with self.cached_session(use_gpu=True): 89 | var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) 90 | grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) 91 | var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) 92 | grads1_np = np.array( 93 | [0.01, 0.01], dtype=dtype.as_numpy_dtype 94 | ) 95 | 96 | var0 = variables.Variable(var0_np, name="var0_%d" % i) 97 | var1 = variables.Variable(var1_np, name="var1_%d" % i) 98 | var0_ref = variables.Variable( 99 | var0_np, name="var0_ref_%d" % i 100 | ) 101 | var1_ref = variables.Variable( 102 | var1_np, name="var1_ref_%d" % i 103 | ) 104 | grads0 = constant_op.constant(grads0_np) 105 | grads1 = constant_op.constant(grads1_np) 106 | 107 | learning_rate = 0.001 108 | opt = ConstrainedRMSprop( 109 | learning_rate=learning_rate, 110 | centered=centered, 111 | ) 112 | opt_ref = rmsprop.RMSprop( 113 | learning_rate=learning_rate, 114 | centered=centered, 115 | ) 116 | 117 | if not context.executing_eagerly(): 118 | update = opt.apply_gradients( 119 | 
zip([grads0, grads1], [var0, var1]) 120 | ) 121 | update_ref = opt_ref.apply_gradients( 122 | zip([grads0, grads1], [var0_ref, var1_ref]) 123 | ) 124 | 125 | self.evaluate(variables.global_variables_initializer()) 126 | 127 | # Run 3 steps 128 | for t in range(3): 129 | if not context.executing_eagerly(): 130 | self.evaluate(update) 131 | self.evaluate(update_ref) 132 | else: 133 | opt.apply_gradients( 134 | zip([grads0, grads1], [var0, var1]) 135 | ) 136 | opt_ref.apply_gradients( 137 | zip([grads0, grads1], [var0_ref, var1_ref]) 138 | ) 139 | 140 | # Validate updated params 141 | self.assertAllCloseAccordingToType( 142 | self.evaluate(var0_ref), 143 | self.evaluate(var0), 144 | rtol=1e-3, 145 | atol=1e-3, 146 | ) 147 | self.assertAllCloseAccordingToType( 148 | self.evaluate(var1_ref), 149 | self.evaluate(var1), 150 | rtol=1e-2, 151 | atol=1e-2, 152 | ) 153 | 154 | 155 | if __name__ == "__main__": 156 | test.main() 157 | -------------------------------------------------------------------------------- /tensorflow_riemopt/optimizers/riemannian_adam.py: -------------------------------------------------------------------------------- 1 | """Riemannian Adam optimizer implementation. 2 | 3 | Becigneul, Gary, and Octavian-Eugen Ganea. "Riemannian Adaptive Optimization 4 | Methods." International Conference on Learning Representations. 2018. 5 | """ 6 | 7 | from tensorflow.python.eager import def_function 8 | from tensorflow.python.framework import ops 9 | from tensorflow.python.keras import backend_config 10 | from tensorflow.python.keras.utils import generic_utils 11 | from tensorflow.python.ops import array_ops 12 | from tensorflow.python.ops import control_flow_ops 13 | from tensorflow.python.ops import math_ops 14 | from tensorflow.python.ops import state_ops 15 | from tensorflow.python.ops import gen_training_ops 16 | from tensorflow.python.framework.indexed_slices import IndexedSlices 17 | from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2 18 | 19 | from tensorflow_riemopt.variable import get_manifold 20 | 21 | 22 | @generic_utils.register_keras_serializable(name="RiemannianAdam") 23 | class RiemannianAdam(OptimizerV2): 24 | """Optimizer that implements the Riemannian Adam algorithm.""" 25 | 26 | _HAS_AGGREGATE_GRAD = True 27 | 28 | def __init__( 29 | self, 30 | learning_rate=0.001, 31 | beta_1=0.9, 32 | beta_2=0.999, 33 | epsilon=1e-7, 34 | amsgrad=False, 35 | stabilize=None, 36 | name="RiemannianAdam", 37 | **kwargs, 38 | ): 39 | """Construct a new Riemannian Adam optimizer. 40 | 41 | Becigneul, Gary, and Octavian-Eugen Ganea. "Riemannian Adaptive 42 | Optimization Methods." International Conference on Learning 43 | Representations. 2018. 44 | 45 | Args: 46 | learning_rate: A `Tensor`, floating point value, or a schedule that is a 47 | `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable that 48 | takes no arguments and returns the actual value to use, The learning 49 | rate. Defaults to 0.001. 50 | beta_1: A float value or a constant float tensor, or a callable that takes 51 | no arguments and returns the actual value to use. The exponential decay 52 | rate for the 1st moment estimates. Defaults to 0.9. 53 | beta_2: A float value or a constant float tensor, or a callable that takes 54 | no arguments and returns the actual value to use, The exponential decay 55 | rate for the 2nd moment estimates. Defaults to 0.999. 56 | epsilon: A small constant for numerical stability. 
This epsilon is 57 | "epsilon hat" in the Kingma and Ba paper (in the formula just before 58 | Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to 59 | 1e-7. 60 | amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm from 61 | the paper "On the Convergence of Adam and beyond". Defaults to `False`. 62 | stabilize: Project variables back to manifold every `stabilize` steps. 63 | Defaults to `None`. 64 | name: Optional name for the operations created when applying gradients. 65 | Defaults to "RiemannianAdam". 66 | **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, 67 | `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip 68 | gradients by value, `decay` is included for backward compatibility to 69 | allow time inverse decay of learning rate. `lr` is included for backward 70 | compatibility, recommended to use `learning_rate` instead. 71 | 72 | """ 73 | 74 | super().__init__(name, **kwargs) 75 | self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) 76 | self._set_hyper("decay", self._initial_decay) 77 | self._set_hyper("beta_1", beta_1) 78 | self._set_hyper("beta_2", beta_2) 79 | self.epsilon = epsilon or backend_config.epsilon() 80 | self.amsgrad = amsgrad 81 | self.stabilize = stabilize 82 | 83 | def _create_slots(self, var_list): 84 | for var in var_list: 85 | self.add_slot(var, "m") 86 | for var in var_list: 87 | self.add_slot(var, "v") 88 | if self.amsgrad: 89 | for var in var_list: 90 | self.add_slot(var, "vhat") 91 | 92 | def _prepare_local(self, var_device, var_dtype, apply_state): 93 | super()._prepare_local(var_device, var_dtype, apply_state) 94 | 95 | local_step = math_ops.cast(self.iterations + 1, var_dtype) 96 | beta_1_t = array_ops.identity(self._get_hyper("beta_1", var_dtype)) 97 | beta_2_t = array_ops.identity(self._get_hyper("beta_2", var_dtype)) 98 | beta_1_power = math_ops.pow(beta_1_t, local_step) 99 | beta_2_power = math_ops.pow(beta_2_t, local_step) 100 | lr = apply_state[(var_device, var_dtype)]["lr_t"] * ( 101 | math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power) 102 | ) 103 | apply_state[(var_device, var_dtype)].update( 104 | dict( 105 | lr=lr, 106 | epsilon=ops.convert_to_tensor(self.epsilon, var_dtype), 107 | beta_1_t=beta_1_t, 108 | beta_1_power=beta_1_power, 109 | one_minus_beta_1_t=1 - beta_1_t, 110 | beta_2_t=beta_2_t, 111 | beta_2_power=beta_2_power, 112 | one_minus_beta_2_t=1 - beta_2_t, 113 | ) 114 | ) 115 | 116 | def set_weights(self, weights): 117 | params = self.weights 118 | num_vars = int((len(params) - 1) / 2) 119 | if len(weights) == 3 * num_vars + 1: 120 | weights = weights[: len(params)] 121 | super().set_weights(weights) 122 | 123 | @def_function.function(experimental_compile=True) 124 | def _resource_apply_dense(self, grad, var, apply_state=None): 125 | var_device, var_dtype = var.device, var.dtype.base_dtype 126 | coefficients = (apply_state or {}).get( 127 | (var_device, var_dtype) 128 | ) or self._fallback_apply_state(var_device, var_dtype) 129 | 130 | m = self.get_slot(var, "m") 131 | v = self.get_slot(var, "v") 132 | 133 | manifold = get_manifold(var) 134 | grad = manifold.egrad2rgrad(var, grad) 135 | 136 | alpha = ( 137 | coefficients["lr_t"] 138 | * math_ops.sqrt(1 - coefficients["beta_2_power"]) 139 | / (1 - coefficients["beta_1_power"]) 140 | ) 141 | m.assign_add((grad - m) * (1 - coefficients["beta_1_t"])) 142 | v.assign_add( 143 | (manifold.inner(var, grad, grad, keepdims=True) - v) 144 | * (1 - coefficients["beta_2_t"]) 145 | ) 146 | 147 | if 
self.amsgrad: 148 | vhat = self.get_slot(var, "vhat") 149 | vhat.assign(math_ops.maximum(vhat, v)) 150 | v = vhat 151 | var_t = manifold.retr( 152 | var, -(m * alpha) / (math_ops.sqrt(v) + coefficients["epsilon"]) 153 | ) 154 | m.assign(manifold.transp(var, var_t, m)) 155 | var_update = var.assign(var_t) 156 | if self.stabilize is not None: 157 | self._stabilize(var) 158 | return var_update 159 | 160 | @def_function.function(experimental_compile=True) 161 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 162 | var_device, var_dtype = var.device, var.dtype.base_dtype 163 | coefficients = (apply_state or {}).get( 164 | (var_device, var_dtype) 165 | ) or self._fallback_apply_state(var_device, var_dtype) 166 | 167 | manifold = get_manifold(var) 168 | grad = manifold.egrad2rgrad(var, grad) 169 | 170 | # m_t = beta1 * m + (1 - beta1) * g_t 171 | m = self.get_slot(var, "m") 172 | m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"] 173 | m_t_values = ( 174 | array_ops.gather(m, indices) * coefficients["beta_1_t"] 175 | + m_scaled_g_values 176 | ) 177 | 178 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 179 | v = self.get_slot(var, "v") 180 | v_scaled_g_values = ( 181 | manifold.inner(var, grad, grad, keepdims=True) 182 | * coefficients["one_minus_beta_2_t"] 183 | ) 184 | v_t_values = ( 185 | array_ops.gather(v, indices) * coefficients["beta_2_t"] 186 | + v_scaled_g_values 187 | ) 188 | 189 | if self.amsgrad: 190 | vhat = self.get_slot(var, "vhat") 191 | vhat.scatter_max(IndexedSlices(v_t_values, indices)) 192 | v_t_values = array_ops.gather(vhat, indices) 193 | 194 | var_values = array_ops.gather(var, indices) 195 | var_t_values = manifold.retr( 196 | var_values, 197 | -(m_t_values * coefficients["lr"]) 198 | / (math_ops.sqrt(v_t_values) + coefficients["epsilon"]), 199 | ) 200 | m_t_transp = manifold.transp(var_values, var_t_values, m_t_values) 201 | 202 | m.scatter_update(IndexedSlices(m_t_transp, indices)) 203 | v.scatter_update(IndexedSlices(v_t_values, indices)) 204 | var_update = var.scatter_update(IndexedSlices(var_t_values, indices)) 205 | if self.stabilize is not None: 206 | self._stabilize(var) 207 | return var_update 208 | 209 | @def_function.function(experimental_compile=True) 210 | def _stabilize(self, var): 211 | if math_ops.floor_mod(self.iterations, self.stabilize) == 0: 212 | manifold = get_manifold(var) 213 | m = self.get_slot(var, "m") 214 | var.assign(manifold.projx(var)) 215 | m.assign(manifold.proju(var, m)) 216 | 217 | def get_config(self): 218 | config = super().get_config() 219 | config.update( 220 | { 221 | "learning_rate": self._serialize_hyperparameter( 222 | "learning_rate" 223 | ), 224 | "decay": self._serialize_hyperparameter("decay"), 225 | "beta_1": self._serialize_hyperparameter("beta_1"), 226 | "beta_2": self._serialize_hyperparameter("beta_2"), 227 | "epsilon": self.epsilon, 228 | "amsgrad": self.amsgrad, 229 | "stabilize": self.stabilize, 230 | } 231 | ) 232 | return config 233 | -------------------------------------------------------------------------------- /tensorflow_riemopt/optimizers/riemannian_adam_test.py: -------------------------------------------------------------------------------- 1 | """Tests for Adam.""" 2 | 3 | from absl.testing import parameterized 4 | import numpy as np 5 | 6 | from tensorflow.python.eager import context 7 | from tensorflow.python.framework import constant_op 8 | from tensorflow.python.framework import dtypes 9 | from tensorflow.python.framework import ops 10 | from tensorflow.python.keras 
import combinations 11 | from tensorflow.python.keras.optimizer_v2 import adam 12 | from tensorflow.python.keras import optimizers 13 | from tensorflow.python.ops import array_ops 14 | from tensorflow.python.ops import math_ops 15 | from tensorflow.python.ops import variables 16 | from tensorflow.python.platform import test 17 | from tensorflow.python.framework.indexed_slices import IndexedSlices 18 | 19 | from tensorflow_riemopt.optimizers.riemannian_adam import RiemannianAdam 20 | 21 | 22 | def get_beta_accumulators(opt, dtype): 23 | local_step = math_ops.cast(opt.iterations + 1, dtype) 24 | beta_1_t = math_ops.cast(opt._get_hyper("beta_1"), dtype) 25 | beta_1_power = math_ops.pow(beta_1_t, local_step) 26 | beta_2_t = math_ops.cast(opt._get_hyper("beta_2"), dtype) 27 | beta_2_power = math_ops.pow(beta_2_t, local_step) 28 | return (beta_1_power, beta_2_power) 29 | 30 | 31 | class RiemannianAdamOptimizerTest(test.TestCase, parameterized.TestCase): 32 | @combinations.generate(combinations.combine(mode=["graph", "eager"])) 33 | def testBasic(self): 34 | for i, dtype in enumerate( 35 | [dtypes.half, dtypes.float32, dtypes.float64] 36 | ): 37 | for amsgrad in [False, True]: 38 | with self.cached_session(use_gpu=True): 39 | var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) 40 | grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) 41 | var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) 42 | grads1_np = np.array( 43 | [0.01, 0.01], dtype=dtype.as_numpy_dtype 44 | ) 45 | 46 | var0 = variables.Variable(var0_np, name="var0_%d" % i) 47 | var1 = variables.Variable(var1_np, name="var1_%d" % i) 48 | var0_ref = variables.Variable( 49 | var0_np, name="var0_ref_%d" % i 50 | ) 51 | var1_ref = variables.Variable( 52 | var1_np, name="var1_ref_%d" % i 53 | ) 54 | grads0 = constant_op.constant(grads0_np) 55 | grads1 = constant_op.constant(grads1_np) 56 | 57 | learning_rate = 0.001 58 | beta1 = 0.9 59 | beta2 = 0.999 60 | epsilon = 1e-8 61 | 62 | opt = RiemannianAdam( 63 | learning_rate=learning_rate, 64 | beta_1=beta1, 65 | beta_2=beta2, 66 | epsilon=epsilon, 67 | amsgrad=amsgrad, 68 | ) 69 | opt_ref = adam.Adam( 70 | learning_rate=learning_rate, 71 | beta_1=beta1, 72 | beta_2=beta2, 73 | epsilon=epsilon, 74 | amsgrad=amsgrad, 75 | ) 76 | 77 | if not context.executing_eagerly(): 78 | update = opt.apply_gradients( 79 | zip([grads0, grads1], [var0, var1]) 80 | ) 81 | update_ref = opt_ref.apply_gradients( 82 | zip([grads0, grads1], [var0_ref, var1_ref]) 83 | ) 84 | 85 | self.evaluate(variables.global_variables_initializer()) 86 | # Run 3 steps 87 | for t in range(3): 88 | beta_1_power, beta_2_power = get_beta_accumulators( 89 | opt, dtype 90 | ) 91 | self.assertAllCloseAccordingToType( 92 | beta1 ** (t + 1), self.evaluate(beta_1_power) 93 | ) 94 | self.assertAllCloseAccordingToType( 95 | beta2 ** (t + 1), self.evaluate(beta_2_power) 96 | ) 97 | if not context.executing_eagerly(): 98 | self.evaluate(update) 99 | self.evaluate(update_ref) 100 | else: 101 | opt.apply_gradients( 102 | zip([grads0, grads1], [var0, var1]) 103 | ) 104 | opt_ref.apply_gradients( 105 | zip([grads0, grads1], [var0_ref, var1_ref]) 106 | ) 107 | 108 | # Validate updated params 109 | self.assertAllCloseAccordingToType( 110 | self.evaluate(var0_ref), 111 | self.evaluate(var0), 112 | rtol=1e-4, 113 | atol=1e-4, 114 | ) 115 | self.assertAllCloseAccordingToType( 116 | self.evaluate(var1_ref), 117 | self.evaluate(var1), 118 | rtol=1e-4, 119 | atol=1e-4, 120 | ) 121 | 122 | def testSparse(self): 123 | for dtype in 
[dtypes.half, dtypes.float32, dtypes.float64]: 124 | for amsgrad in [False, True]: 125 | with ops.Graph().as_default(), self.cached_session( 126 | use_gpu=True 127 | ): 128 | var0_np = np.array( 129 | [1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype 130 | ) 131 | grads0_np = np.array( 132 | [0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype 133 | ) 134 | var1_np = np.array( 135 | [3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype 136 | ) 137 | grads1_np = np.array( 138 | [0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype 139 | ) 140 | 141 | var0 = variables.Variable(var0_np) 142 | var1 = variables.Variable(var1_np) 143 | var0_ref = variables.Variable(var0_np) 144 | var1_ref = variables.Variable(var1_np) 145 | grads0_np_indices = np.array([0, 2], dtype=np.int32) 146 | grads0 = IndexedSlices( 147 | constant_op.constant(grads0_np[grads0_np_indices]), 148 | constant_op.constant(grads0_np_indices), 149 | constant_op.constant([3]), 150 | ) 151 | grads1_np_indices = np.array([0, 2], dtype=np.int32) 152 | grads1 = IndexedSlices( 153 | constant_op.constant(grads1_np[grads1_np_indices]), 154 | constant_op.constant(grads1_np_indices), 155 | constant_op.constant([3]), 156 | ) 157 | opt = RiemannianAdam(amsgrad=amsgrad) 158 | update = opt.apply_gradients( 159 | zip([grads0, grads1], [var0, var1]) 160 | ) 161 | opt_ref = adam.Adam(amsgrad=amsgrad) 162 | update_ref = opt_ref.apply_gradients( 163 | zip([grads0, grads1], [var0_ref, var1_ref]) 164 | ) 165 | 166 | self.evaluate(variables.global_variables_initializer()) 167 | beta_1_power, beta_2_power = get_beta_accumulators( 168 | opt, dtype 169 | ) 170 | # Run 3 steps 171 | for t in range(3): 172 | self.assertAllCloseAccordingToType( 173 | 0.9 ** (t + 1), self.evaluate(beta_1_power) 174 | ) 175 | self.assertAllCloseAccordingToType( 176 | 0.999 ** (t + 1), self.evaluate(beta_2_power) 177 | ) 178 | update.run() 179 | update_ref.run() 180 | 181 | # Validate updated params 182 | self.assertAllCloseAccordingToType( 183 | self.evaluate(var0_ref), self.evaluate(var0) 184 | ) 185 | self.assertAllCloseAccordingToType( 186 | self.evaluate(var1_ref), self.evaluate(var1) 187 | ) 188 | 189 | def testSharing(self): 190 | for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: 191 | with ops.Graph().as_default(), self.cached_session(use_gpu=True): 192 | var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) 193 | grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) 194 | var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) 195 | grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) 196 | 197 | var0 = variables.Variable(var0_np) 198 | var1 = variables.Variable(var1_np) 199 | var0_ref = variables.Variable(var0_np) 200 | var1_ref = variables.Variable(var1_np) 201 | grads0 = constant_op.constant(grads0_np) 202 | grads1 = constant_op.constant(grads1_np) 203 | opt = RiemannianAdam() 204 | update1 = opt.apply_gradients( 205 | zip([grads0, grads1], [var0, var1]) 206 | ) 207 | update2 = opt.apply_gradients( 208 | zip([grads0, grads1], [var0, var1]) 209 | ) 210 | 211 | update1_ref = adam.Adam().apply_gradients( 212 | zip([grads0], [var0_ref]) 213 | ) 214 | update2_ref = adam.Adam().apply_gradients( 215 | zip([grads1], [var1_ref]) 216 | ) 217 | 218 | self.evaluate(variables.global_variables_initializer()) 219 | beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) 220 | # Run 3 steps 221 | for t in range(3): 222 | self.assertAllCloseAccordingToType( 223 | 0.9 ** (t + 1), self.evaluate(beta_1_power) 224 | ) 225 | self.assertAllCloseAccordingToType( 226 | 0.999 ** (t 
+ 1), self.evaluate(beta_2_power) 227 | ) 228 | if t % 2 == 0: 229 | update1.run() 230 | else: 231 | update2.run() 232 | 233 | update1_ref.run() 234 | update2_ref.run() 235 | 236 | # Validate updated params 237 | self.assertAllCloseAccordingToType( 238 | self.evaluate(var0_ref), self.evaluate(var0) 239 | ) 240 | self.assertAllCloseAccordingToType( 241 | self.evaluate(var1_ref), self.evaluate(var1) 242 | ) 243 | 244 | 245 | if __name__ == "__main__": 246 | test.main() 247 | -------------------------------------------------------------------------------- /tensorflow_riemopt/optimizers/riemannian_gradient_descent.py: -------------------------------------------------------------------------------- 1 | """Riemannian SGD optimizer implementation. 2 | 3 | Bonnabel, Silvere. "Stochastic gradient descent on Riemannian manifolds." 4 | IEEE Transactions on Automatic Control 58.9 (2013): 2217-2229. 5 | """ 6 | 7 | from tensorflow.python.eager import def_function 8 | from tensorflow.python.framework import ops 9 | from tensorflow.python.keras import backend_config 10 | from tensorflow.python.keras.utils import generic_utils 11 | from tensorflow.python.ops import array_ops 12 | from tensorflow.python.ops import control_flow_ops 13 | from tensorflow.python.ops import math_ops 14 | from tensorflow.python.ops import state_ops 15 | from tensorflow.python.ops import gen_training_ops 16 | from tensorflow.python.framework.tensor import Tensor 17 | from tensorflow.python.framework.indexed_slices import IndexedSlices 18 | from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2 19 | 20 | from tensorflow_riemopt.variable import get_manifold 21 | 22 | 23 | @generic_utils.register_keras_serializable(name="RiemannianSGD") 24 | class RiemannianSGD(OptimizerV2): 25 | """Optimizer that implements the Riemannian SGD algorithm.""" 26 | 27 | _HAS_AGGREGATE_GRAD = True 28 | 29 | def __init__( 30 | self, 31 | learning_rate=0.01, 32 | momentum=0.0, 33 | nesterov=False, 34 | stabilize=None, 35 | name="RiemannianSGD", 36 | **kwargs, 37 | ): 38 | """Construct a new Riemannian SGD optimizer. 39 | 40 | Bonnabel, Silvere. "Stochastic gradient descent on Riemannian 41 | manifolds." IEEE Transactions on Automatic Control 58.9 (2013): 42 | 2217-2229. 43 | 44 | Args: 45 | learning_rate: A `Tensor`, floating point value, or a schedule that is a 46 | `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable that 47 | takes no arguments and returns the actual value to use. The learning 48 | rate. Defaults to 0.01. 49 | momentum: A float hyperparameter >= 0 that accelerates gradient descent 50 | in the relevant direction and dampens oscillations. Defaults to 0, i.e., 51 | vanilla gradient descent. 52 | nesterov: Boolean. Whether to apply Nesterov momentum. Defaults to `False`. 53 | stabilize: Project variables back to manifold every `stabilize` steps. 54 | Defaults to `None`. 55 | name: Optional name for the operations created when applying gradients. 56 | Defaults to "RiemannianSGD". 57 | **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or 58 | `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` 59 | (float) clips gradients by value.
60 | 61 | """ 62 | 63 | super().__init__(name, **kwargs) 64 | self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) 65 | self._set_hyper("decay", self._initial_decay) 66 | self._momentum = False 67 | if isinstance(momentum, Tensor) or callable(momentum) or momentum > 0: 68 | self._momentum = True 69 | if isinstance(momentum, (int, float)) and ( 70 | momentum < 0 or momentum > 1 71 | ): 72 | raise ValueError("`momentum` must be between [0, 1].") 73 | self._set_hyper("momentum", momentum) 74 | self.nesterov = nesterov 75 | self.stabilize = stabilize 76 | 77 | def _create_slots(self, var_list): 78 | if self._momentum: 79 | for var in var_list: 80 | self.add_slot(var, "momentum") 81 | 82 | def _prepare_local(self, var_device, var_dtype, apply_state): 83 | super()._prepare_local(var_device, var_dtype, apply_state) 84 | apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity( 85 | self._get_hyper("momentum", var_dtype) 86 | ) 87 | 88 | @def_function.function(experimental_compile=True) 89 | def _resource_apply_dense(self, grad, var, apply_state=None): 90 | var_device, var_dtype = var.device, var.dtype.base_dtype 91 | coefficients = (apply_state or {}).get( 92 | (var_device, var_dtype) 93 | ) or self._fallback_apply_state(var_device, var_dtype) 94 | 95 | manifold = get_manifold(var) 96 | grad = manifold.egrad2rgrad(var, grad) 97 | 98 | if self._momentum: 99 | momentum = self.get_slot(var, "momentum") 100 | momentum_t = momentum * coefficients["momentum"] - grad * coefficients["lr_t"] 101 | if self.nesterov: 102 | var_t = manifold.retr( 103 | var, 104 | momentum_t * coefficients["momentum"] - grad * coefficients["lr_t"], 105 | ) 106 | else: 107 | var_t = manifold.retr(var, momentum_t) 108 | momentum.assign(manifold.transp(var, var_t, momentum_t)) 109 | var_update = var.assign(var_t) 110 | else: 111 | var_update = var.assign( 112 | manifold.retr(var, -grad * coefficients["lr_t"]) 113 | ) 114 | 115 | if self.stabilize is not None: 116 | self._stabilize(var) 117 | return var_update 118 | 119 | @def_function.function(experimental_compile=True) 120 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 121 | var_device, var_dtype = var.device, var.dtype.base_dtype 122 | coefficients = (apply_state or {}).get( 123 | (var_device, var_dtype) 124 | ) or self._fallback_apply_state(var_device, var_dtype) 125 | 126 | manifold = get_manifold(var) 127 | grad = manifold.egrad2rgrad(var, grad) 128 | 129 | var_values = array_ops.gather(var, indices) 130 | 131 | if self._momentum: 132 | momentum = self.get_slot(var, "momentum") 133 | momentum_t_values = ( 134 | array_ops.gather(momentum, indices) * coefficients["momentum"] 135 | - grad * coefficients["lr_t"] 136 | ) 137 | if self.nesterov: 138 | var_t_values = manifold.retr( 139 | var_values, 140 | momentum_t_values * coefficients["momentum"] 141 | - grad * coefficients["lr_t"], 142 | ) 143 | else: 144 | var_t_values = manifold.retr(var_values, momentum_t_values) 145 | momentum_transp_values = manifold.transp( 146 | var_values, var_t_values, momentum_t_values 147 | ) 148 | momentum.scatter_update( 149 | IndexedSlices(momentum_transp_values, indices) 150 | ) 151 | var_update = var.scatter_update( 152 | IndexedSlices(var_t_values, indices) 153 | ) 154 | else: 155 | var_t_values = manifold.retr( 156 | var_values, -grad * coefficients["lr_t"] 157 | ) 158 | var_update = var.scatter_update( 159 | IndexedSlices(var_t_values, indices) 160 | ) 161 | 162 | if self.stabilize is not None: 163 | self._stabilize(var) 164 | return var_update 165 | 166 |
@def_function.function(experimental_compile=True) 167 | def _stabilize(self, var): 168 | if math_ops.floor_mod(self.iterations, self.stabilize) == 0: 169 | manifold = get_manifold(var) 170 | var.assign(manifold.projx(var)) 171 | if self._momentum: 172 | momentum = self.get_slot(var, "momentum") 173 | momentum.assign(manifold.proju(var, momentum)) 174 | 175 | def get_config(self): 176 | config = super().get_config() 177 | config.update( 178 | { 179 | "learning_rate": self._serialize_hyperparameter( 180 | "learning_rate" 181 | ), 182 | "decay": self._serialize_hyperparameter("decay"), 183 | "momentum": self._serialize_hyperparameter("momentum"), 184 | "nesterov": self.nesterov, 185 | "stabilize": self.stabilize, 186 | } 187 | ) 188 | return config 189 | -------------------------------------------------------------------------------- /tensorflow_riemopt/optimizers/riemannian_gradient_descent_test.py: -------------------------------------------------------------------------------- 1 | """Tests for SGD.""" 2 | 3 | from absl.testing import parameterized 4 | import numpy as np 5 | 6 | from tensorflow.python.eager import context 7 | from tensorflow.python.framework import constant_op 8 | from tensorflow.python.framework import dtypes 9 | from tensorflow.python.framework import ops 10 | from tensorflow.python.keras import combinations 11 | from tensorflow.python.keras.optimizer_v2 import gradient_descent 12 | from tensorflow.python.keras import optimizers 13 | from tensorflow.python.ops import array_ops 14 | from tensorflow.python.ops import math_ops 15 | from tensorflow.python.ops import variables 16 | from tensorflow.python.platform import test 17 | from tensorflow.python.framework.indexed_slices import IndexedSlices 18 | 19 | from tensorflow_riemopt.optimizers.riemannian_gradient_descent import ( 20 | RiemannianSGD, 21 | ) 22 | 23 | 24 | class RiemannianSGDOptimizerTest(test.TestCase, parameterized.TestCase): 25 | def testSparse(self): 26 | for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: 27 | for momentum in [0.0, 0.9]: 28 | for nesterov in [False, True]: 29 | with ops.Graph().as_default(), self.cached_session( 30 | use_gpu=True 31 | ): 32 | var0_np = np.array( 33 | [1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype 34 | ) 35 | grads0_np = np.array( 36 | [0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype 37 | ) 38 | var1_np = np.array( 39 | [3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype 40 | ) 41 | grads1_np = np.array( 42 | [0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype 43 | ) 44 | 45 | var0 = variables.Variable(var0_np) 46 | var1 = variables.Variable(var1_np) 47 | var0_ref = variables.Variable(var0_np) 48 | var1_ref = variables.Variable(var1_np) 49 | grads0_np_indices = np.array([0, 2], dtype=np.int32) 50 | grads0 = IndexedSlices( 51 | constant_op.constant(grads0_np[grads0_np_indices]), 52 | constant_op.constant(grads0_np_indices), 53 | constant_op.constant([3]), 54 | ) 55 | grads1_np_indices = np.array([0, 2], dtype=np.int32) 56 | grads1 = IndexedSlices( 57 | constant_op.constant(grads1_np[grads1_np_indices]), 58 | constant_op.constant(grads1_np_indices), 59 | constant_op.constant([3]), 60 | ) 61 | opt = RiemannianSGD() 62 | update = opt.apply_gradients( 63 | zip([grads0, grads1], [var0, var1]) 64 | ) 65 | opt_ref = gradient_descent.SGD() 66 | update_ref = opt_ref.apply_gradients( 67 | zip([grads0, grads1], [var0_ref, var1_ref]) 68 | ) 69 | self.evaluate(variables.global_variables_initializer()) 70 | 71 | # Run 3 steps 72 | for t in range(3): 73 | update.run() 74 | update_ref.run() 75 | 76 | # Validate updated params 77 |
84 |                         self.assertAllCloseAccordingToType(
85 |                             self.evaluate(var0_ref), self.evaluate(var0)
86 |                         )
87 |                         self.assertAllCloseAccordingToType(
88 |                             self.evaluate(var1_ref), self.evaluate(var1)
89 |                         )
90 | 
91 |     @combinations.generate(combinations.combine(mode=["graph", "eager"]))
92 |     def testBasic(self):
93 |         for i, dtype in enumerate(
94 |             [dtypes.half, dtypes.float32, dtypes.float64]
95 |         ):
96 |             for momentum in [0.0, 0.9]:
97 |                 for nesterov in [False, True]:
98 |                     with self.cached_session(use_gpu=True):
99 |                         var0_np = np.array(
100 |                             [1.0, 2.0], dtype=dtype.as_numpy_dtype
101 |                         )
102 |                         grads0_np = np.array(
103 |                             [0.1, 0.1], dtype=dtype.as_numpy_dtype
104 |                         )
105 |                         var1_np = np.array(
106 |                             [3.0, 4.0], dtype=dtype.as_numpy_dtype
107 |                         )
108 |                         grads1_np = np.array(
109 |                             [0.01, 0.01], dtype=dtype.as_numpy_dtype
110 |                         )
111 | 
112 |                         var0 = variables.Variable(var0_np, name="var0_%d" % i)
113 |                         var1 = variables.Variable(var1_np, name="var1_%d" % i)
114 |                         var0_ref = variables.Variable(
115 |                             var0_np, name="var0_ref_%d" % i
116 |                         )
117 |                         var1_ref = variables.Variable(
118 |                             var1_np, name="var1_ref_%d" % i
119 |                         )
120 |                         grads0 = constant_op.constant(grads0_np)
121 |                         grads1 = constant_op.constant(grads1_np)
122 | 
123 |                         learning_rate = 0.001
124 |                         opt = RiemannianSGD(
125 |                             learning_rate=learning_rate,
126 |                             momentum=momentum,
127 |                             nesterov=nesterov,
128 |                         )
129 |                         opt_ref = gradient_descent.SGD(
130 |                             learning_rate=learning_rate,
131 |                             momentum=momentum,
132 |                             nesterov=nesterov,
133 |                         )
134 | 
135 |                         if not context.executing_eagerly():
136 |                             update = opt.apply_gradients(
137 |                                 zip([grads0, grads1], [var0, var1])
138 |                             )
139 |                             update_ref = opt_ref.apply_gradients(
140 |                                 zip([grads0, grads1], [var0_ref, var1_ref])
141 |                             )
142 | 
143 |                         self.evaluate(variables.global_variables_initializer())
144 | 
145 |                         # Run 3 steps
146 |                         for t in range(3):
147 |                             if not context.executing_eagerly():
148 |                                 self.evaluate(update)
149 |                                 self.evaluate(update_ref)
150 |                             else:
151 |                                 opt.apply_gradients(
152 |                                     zip([grads0, grads1], [var0, var1])
153 |                                 )
154 |                                 opt_ref.apply_gradients(
155 |                                     zip([grads0, grads1], [var0_ref, var1_ref])
156 |                                 )
157 | 
158 |                         # Validate updated params
159 |                         self.assertAllCloseAccordingToType(
160 |                             self.evaluate(var0_ref),
161 |                             self.evaluate(var0),
162 |                             rtol=1e-4,
163 |                             atol=1e-4,
164 |                         )
165 |                         self.assertAllCloseAccordingToType(
166 |                             self.evaluate(var1_ref),
167 |                             self.evaluate(var1),
168 |                             rtol=1e-3,
169 |                             atol=1e-3,
170 |                         )
171 | 
172 | 
173 | if __name__ == "__main__":
174 |     test.main()
175 | 
--------------------------------------------------------------------------------
/tensorflow_riemopt/variable.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow_riemopt.manifolds import Euclidean
3 | 
4 | 
5 | def assign_to_manifold(var, manifold):
6 |     if not hasattr(var, "shape"):
7 |         raise ValueError(
8 |             "var should be a valid variable with a 'shape' attribute"
9 |         )
10 |     if not manifold.check_shape(var):
11 |         raise ValueError("Invalid variable shape {}".format(var.shape))
12 |     setattr(var, "manifold", manifold)
13 | 
14 | 
15 | def get_manifold(var, default_manifold=Euclidean()):
16 |     if not hasattr(var, "shape"):
17 |         raise ValueError(
18 |             "var should be a valid variable with a 'shape' attribute"
19 |         )
20 |     return getattr(var, "manifold", default_manifold)
21 | 
--------------------------------------------------------------------------------
/tensorflow_riemopt/variable_test.py:
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow_riemopt import variable 4 | from tensorflow_riemopt.manifolds import Euclidean, Grassmannian 5 | 6 | 7 | class VariableTest(tf.test.TestCase): 8 | def test_variable(self): 9 | euclidean = Euclidean() 10 | grassmannian = Grassmannian() 11 | with self.cached_session(use_gpu=True): 12 | var_1x2 = tf.Variable([[5, 3]]) 13 | var_2x1 = tf.Variable([[5], [3]]) 14 | self.assertEqual( 15 | type(variable.get_manifold(var_1x2)), type(euclidean) 16 | ) 17 | self.assertEqual( 18 | type(variable.get_manifold(var_1x2)), 19 | type(variable.get_manifold(var_2x1)), 20 | ) 21 | with self.assertRaises(ValueError): 22 | variable.assign_to_manifold(var_1x2, grassmannian) 23 | variable.assign_to_manifold(var_2x1, grassmannian) 24 | self.assertEqual( 25 | type(variable.get_manifold(var_2x1)), type(grassmannian) 26 | ) 27 | --------------------------------------------------------------------------------
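Note: the following is a minimal usage sketch, not a file from the repository. It assumes `Sphere` is exported from `tensorflow_riemopt.manifolds` alongside the manifolds listed above; everything else here (`assign_to_manifold`, `RiemannianSGD` with its `learning_rate`/`momentum`/`stabilize` arguments, and the manifold's `projx`) appears in the listings in this dump.

    import tensorflow as tf

    from tensorflow_riemopt.manifolds import Sphere  # assumed export
    from tensorflow_riemopt.optimizers.riemannian_gradient_descent import (
        RiemannianSGD,
    )
    from tensorflow_riemopt.variable import assign_to_manifold

    sphere = Sphere()
    # Project an arbitrary starting point onto the unit sphere.
    var = tf.Variable(sphere.projx(tf.constant([0.6, 0.8, 0.0])))
    assign_to_manifold(var, sphere)

    target = tf.constant([0.0, 0.0, 1.0])
    opt = RiemannianSGD(learning_rate=0.1, momentum=0.9, stabilize=10)
    for _ in range(100):
        with tf.GradientTape() as tape:
            # Maximizing the inner product pulls `var` toward `target`.
            loss = -tf.reduce_sum(var * target)
        grads = tape.gradient(loss, [var])
        # apply_gradients converts the Euclidean gradient via egrad2rgrad
        # and updates `var` with a retraction, so the iterate never
        # leaves the sphere.
        opt.apply_gradients(zip(grads, [var]))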