├── .env ├── .gitignore ├── LICENSE ├── README.md ├── examples ├── Euclidean.ipynb ├── H2.ipynb ├── Heisenberg.ipynb ├── S2.ipynb ├── S2_statistics.ipynb ├── SO3.ipynb ├── SO3_stochastics.ipynb ├── SPD3.ipynb ├── T2.ipynb ├── cylinder.ipynb ├── diffusion_mean.ipynb ├── landmarks.ipynb ├── landmarks_stochastics.ipynb └── lifted_stochastics.ipynb ├── jaxgeometry ├── Riemannian │ ├── Log.py │ ├── __init__.py │ ├── curvature.py │ ├── geodesic.py │ ├── metric.py │ └── parallel_transport.py ├── __init__.py ├── dynamics │ ├── Hamiltonian.py │ ├── MPP_Kunita.py │ ├── MPP_Kunita_Log.py │ ├── MPP_group.py │ ├── MPP_landmarks.py │ ├── MPP_landmarks_Log.py │ ├── flow.py │ └── flow_differential.py ├── framebundle │ ├── FM.py │ └── MPP.py ├── group │ ├── EulerPoincare.py │ ├── LiePoisson.py │ ├── energy.py │ ├── invariant_metric.py │ └── quotient.py ├── groups │ ├── GLN.py │ ├── SON.py │ ├── __init__.py │ └── group.py ├── manifolds │ ├── Euclidean.py │ ├── H2.py │ ├── Heisenberg.py │ ├── S2.py │ ├── SPDN.py │ ├── cylinder.py │ ├── ellipsoid.py │ ├── landmarks.py │ ├── latent.py │ ├── manifold.py │ └── torus.py ├── params.py ├── plotting.py ├── sR │ ├── __init__.py │ └── metric.py ├── setup.py ├── statistics │ ├── Frechet_mean.py │ ├── diffusion_mean.py │ ├── iterative_mle.py │ └── tangent_PCA.py ├── stochastics │ ├── Brownian_coords.py │ ├── Brownian_development.py │ ├── Brownian_inv.py │ ├── Brownian_process.py │ ├── Brownian_sR.py │ ├── Eulerian.py │ ├── Langevin.py │ ├── __init__.py │ ├── diagonal_conditioning.py │ ├── guided_process.py │ ├── product_sde.py │ ├── stochastic_coadjoint.py │ └── stochastic_development.py └── utils.py ├── logo └── stocso31s.jpg ├── makefile ├── papers ├── Heisenberg_score.ipynb ├── SO3_S2_Most_probable_paths_and_development.ipynb ├── most_probable_landmark_paths.ipynb └── most_probable_transformation_for_Kunita_flows.ipynb ├── pyproject.toml ├── requirements.txt ├── setup.py └── tests └── test.sh /.env: 
-------------------------------------------------------------------------------- 1 | PYTHONPATH=../src 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Add any directories, files, or patterns you don't want to be tracked by version control 2 | .ipynb_checkpoints 3 | __pycache__ 4 | *.swp 5 | *.pyc 6 | old/ 7 | pyvenv.cfg 8 | *.egg-info/ 9 | dist/ 10 | 11 | bin/ 12 | include/ 13 | lib/ 14 | etc/ 15 | pip-* 16 | share/ 17 | backup/ 18 | 19 | # misc 20 | kent_distribution 21 | contributed 22 | 23 | # output figures 24 | *.pdf 25 | *.png 26 | 27 | # temporary files 28 | .tmp/ 29 | .DS_store 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](logo/stocso31s.jpg) 2 | 3 | # Jax Geometry # 4 | 5 | The code in this repository is based on the papers *Differential geometry and stochastic dynamics with deep learning numerics* [arXiv:1712.08364](https://arxiv.org/abs/1712.08364) and *Computational Anatomy in Theano* [arXiv:1706.07690](https://arxiv.org/abs/1706.07690). 6 | 7 | The code is a reimplementation of the Theano Geometry library https://bitbucket.org/stefansommer/jaxgeometry/ replacing Theano with Jax https://github.com/google/jax. 8 | 9 | The source repository is at https://bitbucket.org/stefansommer/jaxgeometry/ 10 | 11 | ### Who do I talk to? ### 12 | 13 | Please contact Stefan Sommer *sommer@di.ku.dk* 14 | 15 | ### Installation Instructions ### 16 | 17 | Please use Python 3.X. 18 | 19 | #### pip: 20 | Install with 21 | ``` 22 | pip install jaxdifferentialgeometry 23 | ``` 24 | 25 | #### from the repository: 26 | Check out the source with git and install required packages: 27 | ``` 28 | pip install -r requirements.txt 29 | ``` 30 | Use e.g. 
a Python 3 virtualenv: 31 | ``` 32 | virtualenv -p python3 . 33 | source bin/activate 34 | pip install -r requirements.txt 35 | ``` 36 | If you don't use a virtual environment, make sure that you are actually using Python 3, e.g. use pip3 instead of pip. 37 | 38 | Alternatively, use conda: 39 | ``` 40 | conda install -c conda-forge jaxlib 41 | conda install -c conda-forge jax 42 | ``` 43 | and similarly for the remaining requirements in requirements.txt. 44 | 45 | ### Viewing the example notebooks 46 | After cloning the source repository, start jupyter notebook 47 | ``` 48 | PYTHONPATH=$(pwd)/src jupyter notebook 49 | ``` 50 | Your browser should now open and you can find the example Jax Geometry notebooks in the examples folder. 51 | 52 | ### Why Jax? ### 53 | Some good discussions about the architectural differences between autodiff frameworks: https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/ and http://www.stochasticlifestyle.com/engineering-trade-offs-in-automatic-differentiation-from-tensorflow-and-pytorch-to-jax-and-julia/ 54 | -------------------------------------------------------------------------------- /examples/SO3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-02-18T11:05:48.736791Z", 9 | "start_time": "2021-02-18T11:05:48.732894Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "## This file is part of Jax Geometry\n", 15 | "#\n", 16 | "# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)\n", 17 | "# https://bitbucket.org/stefansommer/jaxgeometry\n", 18 | "#\n", 19 | "# Jax Geometry is free software: you can redistribute it and/or modify\n", 20 | "# it under the terms of the GNU General Public License as published by\n", 21 | "# the Free Software Foundation, either version 3 of the License, or\n", 22 | "# (at your option) any 
later version.\n", 23 | "#\n", 24 | "# Jax Geometry is distributed in the hope that it will be useful,\n", 25 | "# but WITHOUT ANY WARRANTY; without even the implied warranty of\n", 26 | "# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n", 27 | "# GNU General Public License for more details.\n", 28 | "#\n", 29 | "# You should have received a copy of the GNU General Public License\n", 30 | "# along with Jax Geometry. If not, see .\n", 31 | "#" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# SO(3) group operations and dynamics" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "%load_ext autoreload\n", 48 | "%autoreload 2" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "ExecuteTime": { 56 | "end_time": "2021-02-18T11:06:54.715340Z", 57 | "start_time": "2021-02-18T11:05:48.760484Z" 58 | } 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "from jaxgeometry.groups.SON import *\n", 63 | "G = SON(3)\n", 64 | "print(G)\n", 65 | "from jaxgeometry.plotting import *\n", 66 | "#%matplotlib notebook" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": { 73 | "ExecuteTime": { 74 | "end_time": "2021-02-18T11:07:03.517081Z", 75 | "start_time": "2021-02-18T11:06:54.718336Z" 76 | } 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "# visualization\n", 81 | "newfig()\n", 82 | "G.plotg(G.e)\n", 83 | "plt.show()\n", 84 | "\n", 85 | "# geodesics in three directions\n", 86 | "v=jnp.array([1,0,0])\n", 87 | "xiv=G.VtoLA(v)\n", 88 | "(ts,gsv) = G.expt(xiv)\n", 89 | "newfig()\n", 90 | "G.plot_path(gsv)\n", 91 | "plt.show()\n", 92 | "\n", 93 | "v=jnp.array([0,1,0])\n", 94 | "xiv=G.VtoLA(v)\n", 95 | "(ts,gsv) = G.expt(xiv)\n", 96 | "newfig()\n", 97 | "G.plot_path(gsv)\n", 98 | "plt.show()\n", 99 | "\n", 100 | "v=jnp.array([0,0,1])\n", 101 | 
"xiv=G.VtoLA(v)\n", 102 | "(ts,gsv) = G.expt(xiv)\n", 103 | "newfig()\n", 104 | "G.plot_path(gsv)\n", 105 | "plt.show()" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": { 112 | "ExecuteTime": { 113 | "end_time": "2021-02-18T11:07:10.319275Z", 114 | "start_time": "2021-02-18T11:07:03.520364Z" 115 | } 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "# plot path on S2\n", 120 | "from jaxgeometry.manifolds.S2 import *\n", 121 | "M = S2()\n", 122 | "print(M)\n", 123 | "\n", 124 | "# plot\n", 125 | "newfig()\n", 126 | "M.plot()\n", 127 | "x = M.F(M.coords([0.,0.]))\n", 128 | "M.plot_path(M.acts(gsv,x))\n", 129 | "plt.show()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": { 136 | "ExecuteTime": { 137 | "end_time": "2021-02-18T11:08:17.862357Z", 138 | "start_time": "2021-02-18T11:07:10.321261Z" 139 | } 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "# setup for testing different versions of dynamics\n", 144 | "q = jnp.array([1e-3,0.,0.])\n", 145 | "g = G.psi(q)\n", 146 | "v = jnp.array([0.,1.,1.])\n", 147 | "\n", 148 | "from jaxgeometry.group import invariant_metric\n", 149 | "invariant_metric.initialize(G)\n", 150 | "p = G.sharppsi(q,v)\n", 151 | "mu = G.sharpV(v)\n", 152 | "print(p)\n", 153 | "print(mu)\n", 154 | "\n", 155 | "from jaxgeometry.group import energy\n", 156 | "energy.initialize(G)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "ExecuteTime": { 164 | "end_time": "2021-02-18T11:08:36.457769Z", 165 | "start_time": "2021-02-18T11:08:17.864547Z" 166 | } 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "# Euler-Poincare dynamics\n", 171 | "from jaxgeometry.group import EulerPoincare\n", 172 | "EulerPoincare.initialize(G)# Euler-Poincare dynamics\n", 173 | "\n", 174 | "# geodesic\n", 175 | "(ts,gsv) = G.ExpEPt(G.psi(q),v)\n", 176 | "newfig()\n", 177 | "G.plot_path(gsv)\n", 178 | 
"plt.show()\n", 179 | "(ts,musv) = G.EP(mu)\n", 180 | "xisv = [G.invFl(mu) for mu in musv]\n", 181 | "print(\"Energy: \",np.array([G.l(xi) for xi in xisv]))\n", 182 | "print(\"Orthogonality: \",np.array([np.linalg.norm(np.dot(g,g.T)-np.eye(int(np.sqrt(G.emb_dim))),np.inf) for g in gsv]))\n", 183 | "\n", 184 | "# on S2\n", 185 | "newfig()\n", 186 | "M.plot(rotate=(30,-15))\n", 187 | "x = jnp.array([0,0,1])\n", 188 | "M.plot_path(M.acts(gsv,x))\n", 189 | "plt.show()" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": { 196 | "ExecuteTime": { 197 | "end_time": "2021-02-18T11:08:50.590252Z", 198 | "start_time": "2021-02-18T11:08:36.459751Z" 199 | } 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "# Lie-Poission dynamics\n", 204 | "from jaxgeometry.group import LiePoisson\n", 205 | "LiePoisson.initialize(G)\n", 206 | "\n", 207 | "# geodesic\n", 208 | "(ts,gsv) = G.ExpLPt(G.psi(q),v)\n", 209 | "newfig()\n", 210 | "G.plot_path(gsv)\n", 211 | "plt.show()\n", 212 | "(ts,musv) = G.LP(mu)\n", 213 | "print(\"Energy: \",np.array([G.Hminus(mu) for mu in musv]))\n", 214 | "print(\"Orthogonality: \",np.array([np.linalg.norm(np.dot(g,g.T)-np.eye(int(np.sqrt(G.dim))),np.inf) for g in gsv]))" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": { 221 | "ExecuteTime": { 222 | "end_time": "2021-02-18T11:12:05.111339Z", 223 | "start_time": "2021-02-18T11:08:50.592810Z" 224 | } 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "# Hamiltonian dynamics\n", 229 | "from jaxgeometry.dynamics import Hamiltonian\n", 230 | "Hamiltonian.initialize(G)\n", 231 | "\n", 232 | "# test Hamiltionian and gradients\n", 233 | "print(p)\n", 234 | "print(G.H(q,p))\n", 235 | "\n", 236 | "# geodesic\n", 237 | "qsv,_ = G.Exp_Hamiltoniant((q,None),p)\n", 238 | "gsv = np.array([G.psi(q) for q in qsv])\n", 239 | "newfig()\n", 240 | "G.plot_path(gsv)\n", 241 | "plt.show()\n", 242 | "(ts,qpsv,_) = 
G.Hamiltonian_dynamics((q,None),p,dts())\n", 243 | "psv = qpsv[:,1,:]\n", 244 | "print(\"Energy: \",np.array([G.H(q,p) for (q,p) in zip(qsv,psv)]))" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [] 253 | } 254 | ], 255 | "metadata": { 256 | "kernelspec": { 257 | "display_name": "Python 3 (ipykernel)", 258 | "language": "python", 259 | "name": "python3" 260 | }, 261 | "language_info": { 262 | "codemirror_mode": { 263 | "name": "ipython", 264 | "version": 3 265 | }, 266 | "file_extension": ".py", 267 | "mimetype": "text/x-python", 268 | "name": "python", 269 | "nbconvert_exporter": "python", 270 | "pygments_lexer": "ipython3", 271 | "version": "3.12.1" 272 | } 273 | }, 274 | "nbformat": 4, 275 | "nbformat_minor": 4 276 | } 277 | -------------------------------------------------------------------------------- /examples/SO3_stochastics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-02-18T19:20:32.491118Z", 9 | "start_time": "2021-02-18T19:20:32.488147Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "## This file is part of Jax Geometry\n", 15 | "#\n", 16 | "# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)\n", 17 | "# https://bitbucket.org/stefansommer/jaxgeometry\n", 18 | "#\n", 19 | "# Jax Geometry is free software: you can redistribute it and/or modify\n", 20 | "# it under the terms of the GNU General Public License as published by\n", 21 | "# the Free Software Foundation, either version 3 of the License, or\n", 22 | "# (at your option) any later version.\n", 23 | "#\n", 24 | "# Jax Geometry is distributed in the hope that it will be useful,\n", 25 | "# but WITHOUT ANY WARRANTY; without even the implied warranty of\n", 26 | "# MERCHANTABILITY or FITNESS FOR A PARTICULAR 
PURPOSE. See the\n", 27 | "# GNU General Public License for more details.\n", 28 | "#\n", 29 | "# You should have received a copy of the GNU General Public License\n", 30 | "# along with Jax Geometry. If not, see .\n", 31 | "#" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": { 37 | "collapsed": true 38 | }, 39 | "source": [ 40 | "# Stochastic Lie group dynamics" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "%load_ext autoreload\n", 50 | "%autoreload 2" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "ExecuteTime": { 58 | "end_time": "2021-02-18T19:21:18.542521Z", 59 | "start_time": "2021-02-18T19:20:32.493368Z" 60 | } 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "# SO(3)\n", 65 | "from jaxgeometry.groups.SON import *\n", 66 | "G = SON(3)\n", 67 | "print(G)\n", 68 | "\n", 69 | "# SO(3) acts on S^2\n", 70 | "from jaxgeometry.manifolds.S2 import *\n", 71 | "M = S2()\n", 72 | "print(M)\n", 73 | "\n", 74 | "from jaxgeometry.plotting import *" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": { 81 | "ExecuteTime": { 82 | "end_time": "2021-02-18T19:22:13.618982Z", 83 | "start_time": "2021-02-18T19:21:18.544893Z" 84 | } 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "# setup for testing different versions of stochastic dynamics\n", 89 | "q = jnp.array([1e-3,0.,0.])\n", 90 | "g = G.psi(q)\n", 91 | "v = jnp.array([0.,1.,1.])\n", 92 | "\n", 93 | "x = M.coords(jnp.array([0.,0.]))\n", 94 | "\n", 95 | "from jaxgeometry.group import invariant_metric\n", 96 | "invariant_metric.initialize(G)\n", 97 | "p = G.sharppsi(q,v)\n", 98 | "mu = G.sharpV(v)\n", 99 | "print(p)\n", 100 | "print(mu)\n", 101 | "\n", 102 | "from jaxgeometry.group import energy\n", 103 | "energy.initialize(G)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 
109 | "metadata": { 110 | "ExecuteTime": { 111 | "end_time": "2021-02-18T19:22:36.715889Z", 112 | "start_time": "2021-02-18T19:22:13.621269Z" 113 | }, 114 | "scrolled": false 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "# Brownian motion\n", 119 | "from jaxgeometry.stochastics import Brownian_inv\n", 120 | "Brownian_inv.initialize(G)\n", 121 | "\n", 122 | "_dts = dts(n_steps=1000)\n", 123 | "(ts,gs,_) = G.Brownian_inv(g,_dts,dWs(G.dim,_dts))\n", 124 | "\n", 125 | "# plot\n", 126 | "newfig()\n", 127 | "G.plot_path(gs,linewidth=0.1,alpha=0.1)\n", 128 | "plt.show()\n", 129 | "#plt.savefig('stocso3.pdf')\n", 130 | "\n", 131 | "# on S2\n", 132 | "newfig()\n", 133 | "M.plot()\n", 134 | "M.plot_path(M.acts(gs,M.F(x)))\n", 135 | "plt.show()\n", 136 | "#plt.savefig('stocso32.pdf')" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": { 143 | "ExecuteTime": { 144 | "end_time": "2021-02-18T19:22:53.566779Z", 145 | "start_time": "2021-02-18T19:22:36.717802Z" 146 | } 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "# Brownian processes\n", 151 | "from jaxgeometry.stochastics import Brownian_process\n", 152 | "Brownian_process.initialize(G)\n", 153 | "\n", 154 | "_dts = dts(n_steps=1000)\n", 155 | "(ts,gs,_) = G.Brownian_process(g,_dts,dWs(G.dim,_dts))\n", 156 | "\n", 157 | "# plot\n", 158 | "newfig()\n", 159 | "G.plot_path(gs,color_intensity=1,alpha=0.1)\n", 160 | "plt.show()" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "# Euler-Poincare dynamics\n", 170 | "from jaxgeometry.group import EulerPoincare\n", 171 | "EulerPoincare.initialize(G)# Euler-Poincare dynamics" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": { 178 | "ExecuteTime": { 179 | "end_time": "2021-02-18T19:23:24.463449Z", 180 | "start_time": "2021-02-18T19:22:53.568769Z" 181 | }, 182 | "scrolled": 
false 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "# Stochastic coadjoint motion\n", 187 | "from jaxgeometry.stochastics import stochastic_coadjoint\n", 188 | "stochastic_coadjoint.initialize(G)\n", 189 | "\n", 190 | "_dts = dts(n_steps=1000)\n", 191 | "(ts,mus,_) = G.stochastic_coadjoint(mu,_dts,dWs(G.dim,_dts))\n", 192 | "(ts,gs) = G.stochastic_coadjointrec(g,mus,_dts)\n", 193 | "\n", 194 | "# plot\n", 195 | "newfig()\n", 196 | "G.plot_path(gs,color_intensity=1,alpha=0.1)\n", 197 | "plt.show()\n", 198 | "#plt.savefig('coadgeo.pdf')" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": { 205 | "ExecuteTime": { 206 | "end_time": "2021-02-18T19:24:29.056415Z", 207 | "start_time": "2021-02-18T19:23:24.465976Z" 208 | }, 209 | "scrolled": false 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "# Delyon/Hu guided process\n", 214 | "from jaxgeometry.stochastics.guided_process import *\n", 215 | "\n", 216 | "# guide function\n", 217 | "phi = lambda g,v,sigma: jnp.tensordot(G.inv(sigma),G.LAtoV(G.log(G.invtrns(G.inv(g[0]),v))),(1,0))\n", 218 | "A = lambda g,v,w,sigma: G.gG(g[0],v,w,sigma)\n", 219 | "logdetA = lambda x,sigma: -2*(jnp.linalg.slogdet(sigma)[1])\n", 220 | "\n", 221 | "(Brownian_inv_guided,sde_Brownian_coords_inv,_,_,_) = get_guided(\n", 222 | " G,G.sde_Brownian_inv,None,phi,\n", 223 | " lambda g,sigma: sigma,A,logdetA,integration='stratonovich')\n", 224 | "\n", 225 | "_dts = dts(n_steps=1000)\n", 226 | "(ts,gs,_,log_likelihood,log_varphi) = Brownian_inv_guided((g,None),G.psi(v),_dts,dWs(G.dim,_dts),\n", 227 | " jnp.sqrt(.2)*jnp.eye(G.dim))\n", 228 | "print(\"log likelihood: \", log_likelihood[-1], \", log varphi: \", log_varphi[-1])\n", 229 | "\n", 230 | "newfig()\n", 231 | "w = G.psi(v)\n", 232 | "G.plot_path(gs,linewidth=0.1,alpha=0.1)\n", 233 | "G.plotg(w,color='k')\n", 234 | "plt.show()\n", 235 | "\n", 236 | "# on S2\n", 237 | "newfig()\n", 238 | "M.plot()\n", 239 | 
"M.plot_path(M.acts(gs,M.F(x)))\n", 240 | "M.plotx(M.act(w,M.F(x)),color='k')\n", 241 | "plt.show()" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [] 250 | } 251 | ], 252 | "metadata": { 253 | "kernelspec": { 254 | "display_name": "Python 3 (ipykernel)", 255 | "language": "python", 256 | "name": "python3" 257 | }, 258 | "language_info": { 259 | "codemirror_mode": { 260 | "name": "ipython", 261 | "version": 3 262 | }, 263 | "file_extension": ".py", 264 | "mimetype": "text/x-python", 265 | "name": "python", 266 | "nbconvert_exporter": "python", 267 | "pygments_lexer": "ipython3", 268 | "version": "3.12.1" 269 | } 270 | }, 271 | "nbformat": 4, 272 | "nbformat_minor": 1 273 | } 274 | -------------------------------------------------------------------------------- /examples/SPD3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "## This file is part of Jax Geometry\n", 10 | "#\n", 11 | "# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)\n", 12 | "# https://bitbucket.org/stefansommer/jaxgeometry\n", 13 | "#\n", 14 | "# Jax Geometry is free software: you can redistribute it and/or modify\n", 15 | "# it under the terms of the GNU General Public License as published by\n", 16 | "# the Free Software Foundation, either version 3 of the License, or\n", 17 | "# (at your option) any later version.\n", 18 | "#\n", 19 | "# Jax Geometry is distributed in the hope that it will be useful,\n", 20 | "# but WITHOUT ANY WARRANTY; without even the implied warranty of\n", 21 | "# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the\n", 22 | "# GNU General Public License for more details.\n", 23 | "#\n", 24 | "# You should have received a copy of the GNU General Public License\n", 25 | "# along with Jax Geometry. If not, see .\n", 26 | "#" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "collapsed": true 33 | }, 34 | "source": [ 35 | "# GLN and SPDN dynamics" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "%load_ext autoreload\n", 45 | "%autoreload 2" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": { 52 | "scrolled": false 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "from jaxgeometry.groups.GLN import *\n", 57 | "G = GLN(3)\n", 58 | "print(G)\n", 59 | "\n", 60 | "from jaxgeometry.manifolds.SPDN import *\n", 61 | "M = SPDN(3)\n", 62 | "print(M)\n", 63 | "\n", 64 | "from jaxgeometry.plotting import *\n", 65 | "figsize = 12,12\n", 66 | "plt.rcParams['figure.figsize'] = figsize" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# some values\n", 76 | "v=np.array([.5,0,0,0,0,0,0,0,0])+1e-6*np.random.normal(size=G.dim) # must be non-singular for Expm derivative\n", 77 | "xiv=G.VtoLA(v)\n", 78 | "x = G.exp(xiv)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "scrolled": false 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "# visualization\n", 90 | "newfig()\n", 91 | "G.plotg(x)\n", 92 | "plt.show()\n", 93 | "\n", 94 | "_dts = dts()\n", 95 | "gsv = np.zeros((_dts.shape[0],3,3))\n", 96 | "for i in range(_dts.shape[0]):\n", 97 | " gsv[i] = G.exp(_dts[i]*xiv)\n", 98 | "newfig()\n", 99 | "G.plot_path(gsv)\n", 100 | "plt.show()\n", 101 | "\n", 102 | "# on SPD(3)\n", 103 | "newfig()\n", 104 | "M.plot()\n", 105 | "x0 = np.eye(M.N).flatten()\n", 106 | 
"M.plot_path(M.acts(gsv,x0))\n", 107 | "plt.show()\n", 108 | "\n", 109 | "# ellipsoids\n", 110 | "plt.rcParams['figure.figsize'] = 23, 10\n", 111 | "M.plot_path(M.acts(gsv,x0),ellipsoid={'alpha': .2, 'step': _dts.shape[0]/4, 'subplot': True})\n", 112 | "plt.show()\n", 113 | "plt.rcParams['figure.figsize'] = figsize" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "# define invariant metric on GL(N)\n", 123 | "from jaxgeometry.group import invariant_metric\n", 124 | "invariant_metric.initialize(G)\n", 125 | "from jaxgeometry.group import energy\n", 126 | "energy.initialize(G)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "scrolled": false 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "# Euler-Poincare dynamics\n", 138 | "from jaxgeometry.group import EulerPoincare\n", 139 | "EulerPoincare.initialize(G)\n", 140 | "\n", 141 | "# geodesic\n", 142 | "(ts,gsv) = G.ExpEPt(G.e,v)\n", 143 | "newfig()\n", 144 | "G.plot_path(gsv)\n", 145 | "plt.show()\n", 146 | "(ts,musv) = G.EP(v)\n", 147 | "xisv = [G.invFl(mu) for mu in musv]\n", 148 | "print(\"Energy: \",np.array([G.l(xi) for xi in xisv]))\n", 149 | "\n", 150 | "# on SPD(3)\n", 151 | "newfig()\n", 152 | "M.plot()\n", 153 | "x0 = np.eye(M.N).flatten()\n", 154 | "M.plot_path(M.acts(gsv,x0))\n", 155 | "plt.show()\n", 156 | "\n", 157 | "# ellipsoids\n", 158 | "plt.rcParams['figure.figsize'] = 23, 10\n", 159 | "M.plot_path(M.acts(gsv,x0),ellipsoid={'alpha': .2, 'step': dts().shape[0]/4, 'subplot': True})\n", 160 | "plt.show()\n", 161 | "plt.rcParams['figure.figsize'] = figsize" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# Lie-Poission dynamics\n", 171 | "from jaxgeometry.group import LiePoisson\n", 172 | "LiePoisson.initialize(G)\n", 173 | "\n", 174 | 
"# geodesic\n", 175 | "(ts,gsv) = G.ExpLPt(G.e,v)\n", 176 | "newfig()\n", 177 | "G.plot_path(gsv)\n", 178 | "plt.show()\n", 179 | "(ts,musv) = G.LP(v)\n", 180 | "print(\"Energy: \",np.array([G.Hminus(mu) for mu in musv]))" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "# Brownian motion\n", 190 | "from jaxgeometry.stochastics import Brownian_inv\n", 191 | "Brownian_inv.initialize(G)\n", 192 | "\n", 193 | "_dts = dts(n_steps=100)\n", 194 | "(ts,gs,_) = G.Brownian_inv(G.e,_dts,dWs(G.dim,_dts),jnp.sqrt(.1)*jnp.eye(G.emb_dim))\n", 195 | "\n", 196 | "# on SPD(3)\n", 197 | "newfig()\n", 198 | "M.plot()\n", 199 | "x0 = np.eye(M.N).flatten()\n", 200 | "M.plot_path(M.acts(gs,x0))\n", 201 | "plt.show()\n", 202 | "\n", 203 | "# ellipsoids\n", 204 | "plt.rcParams['figure.figsize'] = 23, 10\n", 205 | "M.plot_path(M.acts(gsv,x0),ellipsoid={'alpha': .2, 'step': _dts.shape[0]/8, 'subplot': True})\n", 206 | "# plt.savefig('SPD3-path.pdf')\n", 207 | "plt.show()\n", 208 | "plt.rcParams['figure.figsize'] = figsize" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [] 217 | } 218 | ], 219 | "metadata": { 220 | "kernelspec": { 221 | "display_name": "Python 3 (ipykernel)", 222 | "language": "python", 223 | "name": "python3" 224 | }, 225 | "language_info": { 226 | "codemirror_mode": { 227 | "name": "ipython", 228 | "version": 3 229 | }, 230 | "file_extension": ".py", 231 | "mimetype": "text/x-python", 232 | "name": "python", 233 | "nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.9.14" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 1 240 | } 241 | -------------------------------------------------------------------------------- /jaxgeometry/Riemannian/Log.py: -------------------------------------------------------------------------------- 1 | ## 
def initialize(M,f=None,lossf=None,method='BFGS'):
    """ numerical Riemannian Logarithm map

    Attaches ``M.Log`` to the manifold object ``M``. The Logarithm is found
    by shooting: a tangent vector ``v`` is optimised so that ``f(x,v)``
    matches the target point ``y``.

    Args:
        M: manifold object. ``M.dim`` and ``M.update_coords`` are read, and
            ``M.Exp`` is used as the shooting map when ``f`` is None.
        f: shooting map called as ``f(x,v)`` and returning a pair
            ``(coordinates, chart)``. Defaults to ``M.Exp``.
        lossf: optional loss called as ``lossf(x1, y_coords)`` where ``x1``
            is the endpoint returned by ``f`` and ``y_coords`` are the
            coordinates of ``y`` after ``M.update_coords(y, chart1)``.
            Defaults to the sum of squared coordinate differences divided
            by ``M.dim``.
        method: solver name forwarded to ``jax.scipy.optimize.minimize``.

    Side effects:
        Sets ``M.Log``; ``M.Log(x, y, v0=None)`` returns ``(v, loss_value)``
        where ``v0`` is the initial guess (zeros of length ``M.dim`` when
        omitted) and at most 100 solver iterations are requested.
    """
    from jax.scipy.optimize import minimize

    if f is None:
        print("using M.Exp for Logarithm")
        f = M.Exp

    def loss(x,v,y):
        # shoot from x with v, then compare with y expressed in the
        # chart of the endpoint
        (x1,chart1) = f(x,v)
        y_chart1 = M.update_coords(y,chart1)
        if lossf is None:
            return 1./M.dim*jnp.sum(jnp.square(x1 - y_chart1[0]))
        return lossf(x1,y_chart1[0])

    # NOTE: the previous version also built jax.grad(loss,1) here but never
    # used it; the solver differentiates the objective itself.

    def shoot(x,y,v0=None):
        if v0 is None:
            v0 = jnp.zeros(M.dim)

        res = minimize(lambda w: loss(x,w,y), v0, method=method, options={'maxiter': 100})

        return (res.x,res.fun)

    M.Log = shoot
Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk) 4 | # https://bitbucket.org/stefansommer/jaxgeometry 5 | # 6 | # Jax Geometry is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Jax Geometry is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see https://www.gnu.org/licenses/. 18 | # 19 | 20 | -------------------------------------------------------------------------------- /jaxgeometry/Riemannian/curvature.py: -------------------------------------------------------------------------------- 1 | ## This file is part of Jax Geometry 2 | # 3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk) 4 | # https://bitbucket.org/stefansommer/jaxgeometry 5 | # 6 | # Jax Geometry is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Jax Geometry is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see https://www.gnu.org/licenses/. 
def initialize(M):
    """ Riemannian curvature

    Attaches curvature operators to the manifold object ``M``. ``M.Gamma_g``
    and ``M.DGamma_g`` are read by ``M.R``, and ``M.gsharp`` by ``M.S_curv``.
    Sets ``M.R``, ``M.R_u``, ``M.Ricci_curv`` and ``M.S_curv``, each wrapped
    in ``jit``.
    """

    def R(x):
        """
        Riemannian Curvature tensor

        Args:
            x: point on manifold

        Returns:
            4-tensor R_ijk^l with index order i,j,k,l
            (see e.g. https://en.wikipedia.org/wiki/List_of_formulas_in_Riemannian_geometry#(3,1)_Riemann_curvature_tensor )
            Note that sign convention follows e.g. Lee, Riemannian Manifolds.
        """
        Gamma = M.Gamma_g(x)
        DGamma = M.DGamma_g(x)
        return -(jnp.einsum('pik,ljp->ijkl',Gamma,Gamma)
                 -jnp.einsum('pjk,lip->ijkl',Gamma,Gamma)
                 +jnp.einsum('likj->ijkl',DGamma)
                 -jnp.einsum('ljki->ijkl',DGamma))
    M.R = jit(R)

    def R_u(x,u):
        """
        Riemannian Curvature form
        R_u (also denoted Omega) is the gl(n)-valued curvature form u^{-1}Ru for a frame
        u for T_xM

        Args:
            x: point on manifold
            u: frame for T_xM (square, invertible matrix)

        Returns:
            4-tensor (R_u)_ij^m_k with order i,j,m,k
        """
        # The curvature tensor must be looked up on M: a bare R(x) was used
        # here before, which is not defined in this function and raised
        # NameError on the first call.
        return jnp.einsum('ml,ijql,qk->ijmk',jnp.linalg.inv(u),M.R(x),u)
    M.R_u = jit(R_u)

    # NOTE: sectional curvature is disabled in this module; the original
    # commented-out implementation is kept for reference.
    # """
    # Sectional curvature
    #
    # Args:
    #     x: point on manifold
    #     e1,e2: two orthonormal vectors spanning the section
    #
    # Returns:
    #     sectional curvature K(e1,e2)
    # """
    # @jit
    # def sec_curv(x,e1,e2):
    #     Rflat = jnp.tensordot(M.R(x),M.g(x),[3,0])
    #     sec = jnp.tensordot(
    #             jnp.tensordot(
    #                 jnp.tensordot(
    #                     jnp.tensordot(
    #                         Rflat,
    #                         e1, [0,0]),
    #                     e2, [0,0]),
    #                 e2, [0,0]),
    #             e1, [0,0])
    #     return sec
    # M.sec_curv = sec_curv

    def Ricci_curv(x):
        """
        Ricci curvature

        Args:
            x: point on manifold

        Returns:
            2-tensor R_ij in order i,j
        """
        return jnp.einsum('kijk->ij',M.R(x))
    M.Ricci_curv = jit(Ricci_curv)

    def S_curv(x):
        """
        Scalar curvature

        Args:
            x: point on manifold

        Returns:
            scalar curvature
        """
        return jnp.einsum('ij,ij->',M.gsharp(x),M.Ricci_curv(x))
    M.S_curv = jit(S_curv)
# Module: jaxgeometry/Riemannian/geodesic.py
# jit, jnp, integrate, dts, T and n_steps are expected from the module's star
# imports of jaxgeometry.setup and jaxgeometry.utils.
def initialize(M):
    """ Geodesic integration from the second order geodesic equation

    Adds M.geodesic (integrated curve), M.Exp (end point) and M.Expt (all
    points along the curve) to the manifold M.
    """

    def ode_geodesic(c,y):
        # the integrated state stacks position (row 0) and velocity (row 1)
        _,state,chart = c
        pos = state[0]
        vel = state[1]
        acc = -jnp.einsum('ikl,k,l->i',M.Gamma_g((pos,chart)),vel,vel)
        return jnp.stack((vel,acc))

    def chart_update_geodesic(xv,chart,y):
        # manifolds without chart switching leave the state untouched
        if M.do_chart_update is None:
            return (xv,chart)

        point = (xv[0],chart)
        velocity = xv[1]

        switch = M.do_chart_update(point)
        chart_new = M.centered_chart(point)
        pos_new = M.update_coords(point,chart_new)[0]
        vel_new = M.update_vector(point,pos_new,chart_new,velocity)

        # select between the re-expressed and the unchanged state
        xv_out = jnp.where(switch,jnp.stack((pos_new,vel_new)),xv)
        chart_out = jnp.where(switch,chart_new,chart)
        return (xv_out,chart_out)

    def geodesic(x,v,dts):
        # x is a (coordinates, chart) pair, v the initial velocity
        initial_state = jnp.stack((x[0],v))
        return integrate(ode_geodesic,chart_update_geodesic,initial_state,x[1],dts)
    M.geodesic = jit(geodesic)

    def Exp(x,v,T=T,n_steps=n_steps):
        """ end point (coordinates, chart) of the geodesic from x with velocity v """
        curve = M.geodesic(x,v,dts(T,n_steps))
        return (curve[1][-1,0],curve[2][-1])
    M.Exp = Exp

    def Expt(x,v,T=T,n_steps=n_steps):
        """ all points (coordinates, charts) along the geodesic from x with velocity v """
        curve = M.geodesic(x,v,dts(T,n_steps))
        return (curve[1][:,0],curve[2])
    M.Expt = Expt
# Module: jaxgeometry/Riemannian/metric.py
# jnp, jacfwdx, gradx and GramSchmidt_f are expected from the module's star
# imports of jaxgeometry.setup and jaxgeometry.utils.
def initialize(M,truncate_high_order_derivatives=False):
    """ add metric related structures to manifold

    Args:
        M: manifold object defining at least one of M.g (metric) and
            M.gsharp (cometric); the missing one is derived by matrix inversion.
        truncate_high_order_derivatives: not read by this function; kept so
            that existing callers passing it keep working.

    Raises:
        ValueError: if M defines neither g nor gsharp.
    """

    has_g = hasattr(M, 'g')
    has_gsharp = hasattr(M, 'gsharp')
    if not (has_g or has_gsharp):
        raise ValueError('no metric or cometric defined on manifold')
    if not has_gsharp:
        M.gsharp = lambda x: jnp.linalg.inv(M.g(x))
    if not has_g:
        M.g = lambda x: jnp.linalg.inv(M.gsharp(x))

    M.Dg = jacfwdx(M.g) # derivative of metric

    ##### Measure
    # jnp.nlinalg.Det() was a Theano idiom that does not exist in jax.numpy
    M.mu_Q = lambda x: 1./jnp.linalg.det(M.g(x))

    ### Determinant
    def det(x,A=None):
        return jnp.linalg.det(M.g(x)) if A is None else jnp.linalg.det(jnp.tensordot(M.g(x),A,(1,0)))
    def detsharp(x,A=None):
        return jnp.linalg.det(M.gsharp(x)) if A is None else jnp.linalg.det(jnp.tensordot(M.gsharp(x),A,(1,0)))
    M.det = det
    M.detsharp = detsharp
    def logAbsDet(x,A=None):
        return jnp.linalg.slogdet(M.g(x))[1] if A is None else jnp.linalg.slogdet(jnp.tensordot(M.g(x),A,(1,0)))[1]
    def logAbsDetsharp(x,A=None):
        return jnp.linalg.slogdet(M.gsharp(x))[1] if A is None else jnp.linalg.slogdet(jnp.tensordot(M.gsharp(x),A,(1,0)))[1]
    M.logAbsDet = logAbsDet
    M.logAbsDetsharp = logAbsDetsharp

    ##### Sharp and flat map:
    M.flat = lambda x,v: jnp.tensordot(M.g(x),v,(1,0))
    M.sharp = lambda x,p: jnp.tensordot(M.gsharp(x),p,(1,0))

    ##### Christoffel symbols
    # \Gamma^i_{kl}, indices in that order
    def Gamma_g(x):
        # metric derivative and cometric are evaluated once and shared by the three terms
        Dgx = M.Dg(x)
        gsharpx = M.gsharp(x)
        return 0.5*(jnp.einsum('im,kml->ikl',gsharpx,Dgx)
                   +jnp.einsum('im,lmk->ikl',gsharpx,Dgx)
                   -jnp.einsum('im,klm->ikl',gsharpx,Dgx))
    M.Gamma_g = Gamma_g
    M.DGamma_g = jacfwdx(M.Gamma_g)

    # Inner Product from g
    M.dot = lambda x,v,w: jnp.tensordot(jnp.tensordot(M.g(x),w,(1,0)),v,(0,0))
    M.norm = lambda x,v: jnp.sqrt(M.dot(x,v,v))
    M.norm2 = lambda x,v: M.dot(x,v,v)
    M.dotsharp = lambda x,p,pp: jnp.tensordot(jnp.tensordot(M.gsharp(x),pp,(1,0)),p,(0,0))
    M.conorm = lambda x,p: jnp.sqrt(M.dotsharp(x,p,p))

    ##### Gram-Schmidt and basis
    # NOTE(review): M.dotf is not defined anywhere in this function; presumably
    # M.dot is intended - confirm against GramSchmidt_f before relying on this.
    M.gramSchmidt = lambda x,u: (GramSchmidt_f(M.dotf))(x,u)
    # jax.numpy spells the factorization in lower case (jnp.linalg.Cholesky does not exist)
    M.orthFrame = lambda x: jnp.linalg.cholesky(M.gsharp(x))

    ##### Hamiltonian
    M.H = lambda q,p: 0.5*jnp.tensordot(p,jnp.tensordot(M.gsharp(q),p,(1,0)),(0,0))

    # gradient, divergence, and Laplace-Beltrami
    M.grad = lambda x,f: M.sharp(x,gradx(f)(x))
    M.div = lambda x,X: jnp.trace(jacfwdx(X)(x))+.5*jnp.dot(X(x),gradx(M.logAbsDet)(x))
    M.divsharp = lambda x,X: jnp.trace(jacfwdx(X)(x))-.5*jnp.dot(X(x),gradx(M.logAbsDetsharp)(x))
    M.Laplacian = lambda x,f: M.div(x,lambda x: M.grad(x,f))
# Module: jaxgeometry/Riemannian/parallel_transport.py
# jit, jnp and integrate are expected from the module's star imports of
# jaxgeometry.setup and jaxgeometry.utils.
def initialize(M):
    """ Riemannian parallel transport

    Adds M.parallel_transport(v,dts,xs,charts,dxs), which integrates the
    parallel transport equation for v along the sampled curve xs (with charts
    and increments dxs) and returns the transported vectors.
    """

    def charts_agree(chart_a,chart_b):
        # two charts count as identical when their squared distance is tiny
        return jnp.sum(jnp.square(chart_a-chart_b)) <= 1e-5

    def ode_parallel_transport(c,y):
        _,xv,prevchart = c
        x,chart,dx = y
        prevx = xv[0]
        v = xv[1]

        if M.do_chart_update is not None:
            # re-express the curve increment when the chart changed between steps
            dx_converted = M.update_vector((x,chart),prevx,prevchart,dx)
            dx = jnp.where(charts_agree(chart,prevchart),dx,dx_converted)

        dv = -jnp.einsum('ikl,k,l->i',M.Gamma_g((x,chart)),dx,v)
        # the position row is driven by the chart update, not by the ODE
        return jnp.stack((jnp.zeros_like(x),dv))

    def chart_update_parallel_transport(xv,prevchart,y):
        x,chart,dx = y
        if M.do_chart_update is None:
            return (xv,chart)

        prevx = xv[0]
        v = xv[1]
        kept = jnp.stack((x,v))
        converted = jnp.stack((x,M.update_vector((prevx,prevchart),x,chart,v)))
        return (jnp.where(charts_agree(chart,prevchart),kept,converted),chart)

    def parallel_transport(v,dts,xs,charts,dxs):
        initial_state = jnp.stack((xs[0],v))
        return integrate(ode_parallel_transport,chart_update_parallel_transport,initial_state,charts[0],dts,xs,charts,dxs)

    def transported_vectors(v,dts,xs,charts,dxs):
        # keep only the vector row of the integrated states
        return parallel_transport(v,dts,xs,charts,dxs)[1][:,1]
    M.parallel_transport = jit(transported_vectors)
https://bitbucket.org/stefansommer/jaxgeometry 5 | # 6 | # Jax Geometry is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Jax Geometry is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see . 18 | # 19 | 20 | -------------------------------------------------------------------------------- /jaxgeometry/dynamics/Hamiltonian.py: -------------------------------------------------------------------------------- 1 | ## This file is part of Jax Geometry 2 | # 3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk) 4 | # https://bitbucket.org/stefansommer/jaxgeometry 5 | # 6 | # Jax Geometry is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Jax Geometry is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see . 
# Module: jaxgeometry/dynamics/Hamiltonian.py
# jit, jnp, grad, gradx, integrate, dts, T and n_steps are expected from the
# module's star imports of jaxgeometry.setup and jaxgeometry.utils.
###############################################################
# geodesic integration, Hamiltonian form                      #
###############################################################
def initialize(M):
    """ Hamiltonian dynamics for the Hamiltonian M.H

    Adds M.Hamiltonian_dynamics, M.Exp_Hamiltonian and M.Exp_Hamiltoniant to
    the manifold M. M.H must be defined before this function is called.
    """
    dq = grad(M.H,argnums=1)
    dp = lambda q,p: -gradx(M.H)(q,p)

    def ode_Hamiltonian(c,y):
        # the integrated state stacks position (row 0) and momentum (row 1)
        t,x,chart = c
        dqt = dq((x[0],chart),x[1])
        dpt = dp((x[0],chart),x[1])
        return jnp.stack((dqt,dpt))

    def chart_update_Hamiltonian(xp,chart,y):
        if M.do_chart_update is None:
            return (xp,chart)

        p = xp[1]
        x = (xp[0],chart)

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(x)
        new_x = M.update_coords(x,new_chart)[0]

        return (jnp.where(update,
                          jnp.stack((new_x,M.update_covector(x,new_x,new_chart,p))),
                          xp),
                jnp.where(update,
                          new_chart,
                          chart))

    def Hamiltonian_dynamics(q,p,dts):
        # q is either a (coordinates, chart) tuple or bare coordinates without
        # a chart; isinstance also accepts tuple subclasses such as namedtuples
        if isinstance(q,tuple):
            coords,chart = q[0],q[1]
        else:
            coords,chart = q,None
        return integrate(ode_Hamiltonian,chart_update_Hamiltonian,jnp.stack((coords,p)),chart,dts)
    M.Hamiltonian_dynamics = jit(Hamiltonian_dynamics)

    def Exp_Hamiltonian(q,p,T=T,n_steps=n_steps):
        """ end point (coordinates, chart) of the Hamiltonian flow from q with momentum p """
        curve = M.Hamiltonian_dynamics(q,p,dts(T,n_steps))
        q = curve[1][-1,0]
        chart = curve[2][-1]
        return(q,chart)
    M.Exp_Hamiltonian = Exp_Hamiltonian

    def Exp_Hamiltoniant(q,p,T=T,n_steps=n_steps):
        """ all points (coordinates, charts) along the Hamiltonian flow from q with momentum p """
        curve = M.Hamiltonian_dynamics(q,p,dts(T,n_steps))
        qs = curve[1][:,0]
        charts = curve[2]
        return(qs,charts)
    M.Exp_Hamiltoniant = Exp_Hamiltoniant
# Module: jaxgeometry/dynamics/MPP_Kunita.py
# jax, jit, jnp, gradx, jacrevx and integrate are expected from the module's
# star imports of jaxgeometry.setup and jaxgeometry.utils.
###############################################################
# Most probable paths for Kunita flows                        #
###############################################################
def initialize(M,N,sigmas,u):
    """ Most probable paths for Kunita flows

    M: shape manifold, N: embedding space, u: flow field

    sigmas: noise fields; they define the cometric on N as the contraction
        of sigmas(x[0]) with itself.

    Replaces the Riemannian structure of N by the one induced by sigmas,
    stores N.u, N.z and N.f, and adds the integrator M.MPP_AC.
    """

    # Riemannian structure on N
    N.gsharp = lambda x: jnp.einsum('pri,qrj->ij',sigmas(x[0]),sigmas(x[0]))
    # drop any previously set metric so that metric.initialize derives g from
    # the new cometric; N may have been set up with a cometric only
    if hasattr(N,'g'):
        delattr(N,'g')
    from jaxgeometry.Riemannian import metric
    metric.initialize(N)

    # scalar part of elliptic operator L = 1/2 \Delta_g + z
    z = lambda x,qp: (u(x,qp)
                      -0.25*jnp.einsum('ij,i->j',N.gsharp(x),gradx(N.logAbsDetsharp)(x))
                      -0.5*jnp.einsum('iji->j',jacrevx(N.gsharp)(x))
                      +0.5*jnp.einsum('...rj,...rii->j',sigmas(x[0]),jax.jacrev(sigmas)(x[0]))
                      )

    # Onsager-Machlup deviation from geodesic energy
    # f = lambda x,qp: .5*jnp.einsum('rs,sr->',N.gsharp(x),
    #             jacrevx(z)(x,qp)+jnp.einsum('k,srk->sr',z(x,qp),N.Gamma_g(x)))-1/12*N.S_curv(x)
    # NOTE(review): N.S_curv is not set in this function; presumably the
    # curvature module has been initialized on N by the caller - confirm.
    f = lambda x,qp: .5*N.divsharp(x,lambda x: z(x,qp))-1/12*N.S_curv(x)

    N.u = u
    N.z = z
    N.f = f

    def ode_MPP_AC(c,y):
        t,xx1,chart = c
        qp,dqp = y
        x = xx1[0] # point
        x1 = xx1[1] # derivative

        g = N.g((x,chart))
        gsharp = N.gsharp((x,chart))
        Gamma = N.Gamma_g((x,chart))

        zx = z((x,chart),qp)
        gradz = jacrevx(z)((x,chart),qp)
        dz = jnp.einsum('...ij,ij',jax.jacrev(z,argnums=1)((x,chart),qp),dqp)

        dx2 = (dz-jnp.einsum('i,j,kij->k',x1,x1,Gamma)
               +jnp.einsum('i,ki->k',x1,gradz+jnp.einsum('kij,j->ki',Gamma,zx))
               -jnp.einsum('rs,ri,s,ik->k',g,gradz+jnp.einsum('j,rij->ri',zx,Gamma),x1-zx,gsharp)
               +jnp.einsum('ik,i',gsharp,gradx(f)((x,chart),qp))
               )
        dx1 = x1
        return jnp.stack((dx1,dx2))

    def chart_update_MPP_AC(xv,chart,y):
        # NOTE(review): chart handling is delegated to M although the ODE is
        # evaluated with the structures of N - confirm this is intended.
        if M.do_chart_update is None:
            return (xv,chart)

        v = xv[1]
        x = (xv[0],chart)

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(x)
        new_x = M.update_coords(x,new_chart)[0]

        return (jnp.where(update,
                          jnp.stack((new_x,M.update_vector(x,new_x,new_chart,v))),
                          xv),
                jnp.where(update,
                          new_chart,
                          chart))

    M.MPP_AC = jit(lambda x,v,qps,dqps,dts: integrate(ode_MPP_AC,chart_update_MPP_AC,jnp.stack((x[0],v)),x[1],dts,qps,dqps))
# Module: jaxgeometry/dynamics/MPP_Kunita_Log.py
# jnp is expected from the module's star import of jaxgeometry.setup.
###############################################################
# Most probable paths for Kunita flows - BVP                  #
###############################################################
def initialize(M,N):
    """ Shooting solver for the most probable path boundary value problem

    Adds M.Log_MPP_AC(x,y,qps,dqps,_dts,v0=None), which searches for the
    initial velocity for which M.MPP_AC started at x ends at y. Returns the
    optimizer's solution and its final loss value.
    """
    from scipy.optimize import approx_fprime,minimize

    method = 'BFGS'

    def loss(x,v,y,qps,dqps,_dts):
        # mean squared distance between the path end point and the target,
        # with the target expressed in the chart of the end point
        (_,xx1,charts) = M.MPP_AC(x,v,qps,dqps,_dts)
        end_point = xx1[-1,0]
        end_chart = charts[-1]
        target = M.update_coords(y,end_chart)
        return 1./N.dim*jnp.sum(jnp.square(end_point - target[0]))

    def dloss(x,v,y,qps,dqps,_dts):
        # finite difference gradient of the loss with respect to v
        return approx_fprime(v,lambda w: loss(x,w,y,qps,dqps,_dts),1e-4)

    def shoot(x,y,qps,dqps,_dts,v0=None):
        if v0 is None:
            v0 = jnp.zeros(N.dim)

        def value_and_gradient(w):
            value = loss(x,w,y,qps,dqps,_dts)
            gradient = dloss(x,w,y,qps,dqps,_dts)
            return (value,gradient)

        res = minimize(value_and_gradient, v0, method=method, jac=True, options={'disp': False, 'maxiter': 100})

        return (res.x,res.fun)

    M.Log_MPP_AC = shoot
# Module: jaxgeometry/dynamics/MPP_group.py
# jax, jnp, partial, integrate, dts, T and n_steps are expected from the
# module's star imports of jaxgeometry.setup and jaxgeometry.utils;
# horz_vert_split presumably comes from its star import of
# jaxgeometry.group.quotient.
###############################################################
# Most probable paths for Lie groups via development          #
###############################################################
def initialize(G,Sigma=None,a=None):
    """ Most probable paths and development

    Args:
        G: Lie group object; G.invariance, G.dim, G.sharpV, G.coad, G.invpf,
            G.VtoLA and G.emb_dim are read below.
        Sigma: matrix applied to G.sharpV(alpha); identity of size G.dim when None.
        a: optional drift, called as a(t) in the reduced equations and as
            a(t,g) in the point-tracking variant; no drift when None.

    Adds G.mpp, G.mpprec, G.mpp_drift, G.MPP_forwardt, G.MPP and G.MPP_drift.
    """

    # the sign of the coadjoint term depends on the invariance of the metric
    sign = -1. if G.invariance == 'right' else 1.
    Sigma = Sigma if Sigma is not None else jnp.eye(G.dim)

    # NOTE(review): the sigma argument of the three ODEs below is bound with
    # partial() but never read; only the closed-over Sigma enters the
    # equations. Confirm whether sigma was meant to replace Sigma.
    def ode_mpp(sigma,c,y):
        t,alpha,_ = c

        at = a(t) if a is not None else jnp.zeros_like(alpha)

        z = jnp.dot(Sigma,G.sharpV(alpha))+at
        dalpha = sign*G.coad(z,alpha) # =-jnp.einsum('k,i,ijk->j',alpha,z,G.C)
        return dalpha
    G.mpp = lambda alpha,dts,sigma=jnp.eye(G.dim): integrate(partial(ode_mpp,sigma),None,alpha,None,dts)

    # reconstruction
    def ode_mpprec(sigma,c,y):
        t,g,_ = c

        # alpha is supplied per step from the previously integrated G.mpp path
        alpha, = y
        at = a(t) if a is not None else jnp.zeros_like(alpha)

        z = jnp.dot(Sigma,G.sharpV(alpha))+at
        dgt = G.invpf(g,G.VtoLA(z))
        return dgt
    G.mpprec = lambda g,alpha,dts,sigma=jnp.eye(G.dim): integrate(partial(ode_mpprec,sigma),None,g,None,dts,alpha)

    # tracking point (not reduced to Lie algebra) to allow point-depending drift
    def ode_mpp_drift(sigma,c,y):
        t,x,_ = c
        # the state is the flat concatenation of alpha and the flattened g
        alpha = x[0:G.dim]
        # NOTE(review): this reshape assumes the matrix representation of g
        # has G.dim rows and G.dim columns - confirm for groups where these differ.
        g = x[G.dim:].reshape((G.dim,G.dim))

        at = jnp.linalg.solve(g,a(t,g)) if a is not None else jnp.zeros_like(alpha)

        z = jnp.dot(Sigma,G.sharpV(alpha))+at
        dalpha = sign*G.coad(z,alpha) # =-jnp.einsum('k,i,ijk->j',alpha,z,G.C)
        dgt = G.invpf(g,G.VtoLA(z))
        return jnp.hstack((dalpha,dgt.flatten()))
    G.mpp_drift = lambda alpha,g,dts,sigma=jnp.eye(G.dim): integrate(partial(ode_mpp_drift,sigma),None,jnp.hstack((alpha,g.flatten())),None,dts)

    def MPP_forwardt(g,alpha,sigma,T=T,n_steps=n_steps):
        """ integrate alpha with G.mpp, then reconstruct the group path with G.mpprec """
        _dts = dts(T=T,n_steps=n_steps)
        (ts,alphas) = G.mpp(alpha,_dts,sigma)
        (ts,gs) = G.mpprec(g,alphas,_dts,sigma)

        return(gs,alphas)
    G.MPP_forwardt = MPP_forwardt

    # optimization to satisfy end-point conditions
    def MPP_f(g,alpha,y,sigma):
        # squared end-point mismatch, scaled by 1/G.emb_dim
        gs,alphas = G.MPP_forwardt(g,alpha,sigma)
        gT = gs[-1]
        return (1./G.emb_dim)*jnp.sum(jnp.square(gT-y))

    def MPP(g,y,sigma=jnp.eye(G.dim)):
        """ initial alpha found by BFGS, started at zero, minimizing the end-point mismatch to y """
        res = jax.scipy.optimize.minimize(lambda alpha: MPP_f(g,alpha,y,sigma),jnp.zeros(G.dim),method='BFGS')
        alpha = res.x

        return alpha
    G.MPP = MPP

    def MPP_drift_f(g,alpha,y,sigma,proj,M,_dts):
        # alpha is mapped through the last output of horz_vert_split before integration
        _,_,_,_,horz = horz_vert_split(g,proj,jnp.eye(G.dim),G,M)
        (ts,alphags) = G.mpp_drift(jnp.dot(horz,alpha),g,_dts,sigma)
        gT = alphags[-1,G.dim:].reshape((G.dim,G.dim))
        # mismatch is measured after projecting the end point with proj
        return (1./M.emb_dim)*jnp.sum(jnp.square(proj(gT)-M.F(y)))

    def MPP_drift(g,y,proj,M,sigma=jnp.eye(G.dim)):
        """ as MPP, but for the point-tracking dynamics G.mpp_drift and a target y on M """
        _dts = dts()
        res = jax.scipy.optimize.minimize(lambda alpha: MPP_drift_f(g,alpha,y,sigma,proj,M,_dts),jnp.zeros(M.dim),method='BFGS')
        _,_,_,_,horz = horz_vert_split(g,proj,jnp.eye(G.dim),G,M)
        alpha = jnp.dot(horz,res.x)

        return alpha
    G.MPP_drift = MPP_drift

# NOTE: the original file continued with a commented-out earlier variant,
# initialize(G,horz,vert,a=None). It asserted G.invariance == 'right',
# integrated a stacked (v,c,lambd) state using the horz/vert projections,
# reconstructed the group path with G.invpf and optimized over all three
# components in MPP. It was disabled code and is summarized here only.
# Module: jaxgeometry/dynamics/MPP_landmarks.py
# jnp, jacrev and integrate are expected from the module's star imports of
# jaxgeometry.setup and jaxgeometry.utils.
###############################################################
# Most probable paths for landmarks via development           #
###############################################################
def initialize(M,sigmas,dsigmas,a):
    """ Most probable paths for landmarks

    M: shape manifold, a: flow field
    sigmas, dsigmas: noise fields and their derivative, both evaluated at the
        landmark configuration.

    Adds the integrator M.MPP_landmarks(x,lambd,qps,dts).
    """

    def ode_MPP_landmarks(c,y):
        _,state,chart = c
        qp, = y
        landmark_shape = (M.N,M.m)
        x = state[0].reshape(landmark_shape) # points
        lambd = state[1].reshape(landmark_shape)

        sigmasx = sigmas(x)
        dsigmasx = dsigmas(x)
        # coefficients of the noise fields
        coeffs = jnp.einsum('ri,rai->a',lambd,sigmasx)

        dx = a(x,qp)+jnp.einsum('a,rak->rk',coeffs,sigmasx)
        noise_term = jnp.einsum('ri,a,raik->rk',lambd,coeffs,dsigmasx)
        drift_term = jnp.einsum('ri,rirk->rk',lambd,jacrev(a)(x,qp))
        dlambd = -noise_term-drift_term
        return jnp.stack((dx.flatten(),dlambd.flatten()))

    def chart_update_MPP_landmarks(xlambd,chart,y):
        if M.do_chart_update is None:
            return (xlambd,chart)

        lambd = xlambd[1].reshape((M.N,M.m))
        point = (xlambd[0],chart)

        switch = M.do_chart_update(point)
        chart_new = M.centered_chart(point)
        x_new = M.update_coords(point,chart_new)[0]
        lambd_new = M.update_covector(point,x_new,chart_new,lambd)

        state_out = jnp.where(switch,jnp.stack((x_new,lambd_new)),xlambd)
        chart_out = jnp.where(switch,chart_new,chart)
        return (state_out,chart_out)

    def MPP_landmarks(x,lambd,qps,dts):
        # x is a (coordinates, chart) pair; the result splits the integrated
        # state back into its point and lambd rows
        initial_state = jnp.stack((x[0],lambd))
        (ts,xlambds,charts) = integrate(ode_MPP_landmarks,chart_update_MPP_landmarks,initial_state,x[1],dts,qps)
        return (ts,xlambds[:,0],xlambds[:,1],charts)
    M.MPP_landmarks = MPP_landmarks
# Module: jaxgeometry/dynamics/MPP_landmarks_Log.py
# jax and jnp are expected from the module's star import of jaxgeometry.setup.
###############################################################
# Most probable paths for landmarks via development - BVP     #
###############################################################
def initialize(M):
    """ Shooting solver for the landmark most probable path boundary value problem

    Adds M.Log_MPP_landmarks(x,y,qps,_dts,lambd0=None), which searches for the
    initial lambd for which M.MPP_landmarks started at x ends at y. Returns
    the optimizer's solution and its final loss value.
    """
    from scipy.optimize import minimize

    method = 'BFGS'

    def loss(x,lambd,y,qps,_dts):
        # mean squared distance between the path end point and the target,
        # with the target expressed in the chart of the end point
        (_,xs,_,charts) = M.MPP_landmarks(x,lambd,qps,_dts)
        end_point = xs[-1]
        end_chart = charts[-1]
        target = M.update_coords(y,end_chart)
        return 1./M.dim*jnp.sum(jnp.square(end_point - target[0]))

    def shoot(x,y,qps,_dts,lambd0=None):
        if lambd0 is None:
            lambd0 = jnp.zeros(M.dim)

        # loss value and its gradient from automatic differentiation
        objective = jax.value_and_grad(lambda w: loss(x,w,y,qps,_dts))
        res = minimize(objective, lambd0, method=method, jac=True, options={'disp': False, 'maxiter': 100})

        return (res.x,res.fun)

    M.Log_MPP_landmarks = shoot
# Module: jaxgeometry/dynamics/flow.py
# jit, jnp and integrate are expected from the module's star imports of
# jaxgeometry.setup and jaxgeometry.utils.
def initialize(M):
    """ flow along a vector field X

    Adds M.flow(X), which returns a jitted function (x,dts) -> integrate(...)
    that follows X from the point x = (coordinates, chart).
    """
    def flow(X):

        def ode_flow(c,y):
            t,x,chart = c
            return X((x,chart))

        def chart_update_flow(x,chart,*ys):
            if M.do_chart_update is None:
                return (x,chart)

            # points are handed to the manifold as (coordinates, chart) pairs,
            # as in the two calls below; the bare coordinates were passed here before
            point = (x,chart)
            update = M.do_chart_update(point)
            new_chart = M.centered_chart(point)
            new_x = M.update_coords(point,new_chart)[0]

            return (jnp.where(update,
                              new_x,
                              x),
                    jnp.where(update,
                              new_chart,
                              chart),
                    )

        # named separately so the enclosing flow(X) is not shadowed
        flow_X = jit(lambda x,dts: integrate(ode_flow,chart_update_flow,x[0],x[1],dts))
        return flow_X
    M.flow = flow
#


from jaxgeometry.setup import *
from jaxgeometry.utils import *

###############################################################
# Compute differential d\phi along a phase-space path qt      #
# See Younes, Shapes and Diffeomorphisms, 2010 and            #
# Sommer et al., SIIMS 2013                                   #
###############################################################
def initialize(M):
    """ M: landmark manifold, scalar kernel

    Attaches M.flow_differential.
    """

    def ode_differential(c,y):
        t,dphi,chart = c
        qp, = y
        q = qp[0].reshape((M.N,M.m)) # points
        p = qp[1].reshape((M.N,M.m)) # second phase-space component (presumably momenta)

        dk = M.dk_q(q,q)
        ddphi = jnp.einsum('iab,jic,jb->iac',dphi,dk,p)

        return ddphi

    def chart_update_differential(dphi,chart,y):
        if M.do_chart_update is None:
            return (dphi,chart)

        # An explicit exception instead of assert(False): asserts are removed
        # under `python -O`, which would make this silently return None.
        raise NotImplementedError("chart update is not implemented for flow_differential")

    def flow_differential(qps,dts):
        """ Integrate the differential dphi along the phase-space path qps.

        Starts from one (M.m x M.m) identity matrix per landmark and returns
        (ts,dphis,charts) as produced by integrate.
        """
        dphi0 = jnp.tile(jnp.eye(M.m),(M.N,1,1))
        (ts,dphis,charts) = integrate(ode_differential,chart_update_differential,dphi0,None,dts,qps)
        return (ts,dphis,charts)
    M.flow_differential = flow_differential

# --------------------------------------------------------------------------------
# /jaxgeometry/framebundle/FM.py:
# --------------------------------------------------------------------------------
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see https://www.gnu.org/licenses/.
#


from jaxgeometry.setup import *
from jaxgeometry.params import *
from jaxgeometry.utils import *


def initialize(M):
    """ Frame Bundle geometry

    Attaches to M: chart_update_FM, to_D, from_D, to_Dstar, from_Dstar and
    Horizontal. A frame bundle element is u = (coords, chart) where coords
    stacks the d = M.dim base point coordinates followed by the flattened
    frame matrix nu with d rows.
    """

    d = M.dim

    def _split(u):
        """ Split u = (coords, chart) into base point x = (coords[:d], chart) and frame nu. """
        x = (u[0][0:d],u[1])
        nu = u[0][d:].reshape((d,-1))
        return (x,nu)

    def _Gammanu(x,nu):
        """ Contract the last index of M.Gamma_g(x) with nu: Gamma[h,j,i] nu[i,g] -> [h,g,j]. """
        return jnp.tensordot(M.Gamma_g(x),nu,(2,0)).swapaxes(1,2)

    def chart_update_FM(u,chart,*args):
        # M.do_chart_update is either None or a predicate on the base point.
        # The original compared it with `!= True`, which is always true for a
        # callable or None, so the update below could never run.
        if M.do_chart_update is None:
            return (u,chart)

        x = (u[0:d],chart)
        nu = u[d:].reshape((d,-1))

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(x)
        new_x = M.update_coords(x,new_chart)[0]

        # select old or new values without Python branching (works under jit)
        return (jnp.where(update,
                          jnp.concatenate((new_x,M.update_vector(x,new_x,new_chart,nu).flatten())),
                          u),
                jnp.where(update,
                          new_chart,
                          chart))
    M.chart_update_FM = chart_update_FM

    #### Bases shifts, see e.g. Sommer Entropy 2016 sec 2.3
    # D denotes frame adapted to the horizontal distribution
    def to_D(u,w):
        (x,nu) = _split(u)
        wx = w[0:d]
        wnu = w[d:].reshape((d,-1))

        # shift to D basis: the x part is unchanged, the frame part gains Gammanu.wx
        Dwx = wx
        Dwnu = jnp.tensordot(_Gammanu(x,nu),wx,(2,0))+wnu

        return jnp.concatenate((Dwx,Dwnu.flatten()))
    def from_D(u,Dw):
        (x,nu) = _split(u)
        Dwx = Dw[0:d]
        Dwnu = Dw[d:].reshape((d,-1))

        # shift back from D basis (inverse of to_D)
        wx = Dwx
        wnu = -jnp.tensordot(_Gammanu(x,nu),Dwx,(2,0))+Dwnu

        return jnp.concatenate((wx,wnu.flatten()))
    # corresponding dual space shifts
    def to_Dstar(u,p):
        (x,nu) = _split(u)
        px = p[0:d]
        pnu = p[d:].reshape((d,-1))

        # shift to dual D basis: the frame part is unchanged, the x part loses Gammanu:pnu
        Dpx = px-jnp.tensordot(_Gammanu(x,nu),pnu,((0,1),(0,1)))
        Dpnu = pnu

        return jnp.concatenate((Dpx,Dpnu.flatten()))
    def from_Dstar(u,Dp):
        (x,nu) = _split(u)
        Dpx = Dp[0:d]
        Dpnu = Dp[d:].reshape((d,-1))

        # shift back from dual D basis (inverse of to_Dstar)
        px = Dpx+jnp.tensordot(_Gammanu(x,nu),Dpnu,((0,1),(0,1)))
        pnu = Dpnu

        return jnp.concatenate((px,pnu.flatten()))
    M.to_D = to_D
    M.from_D = from_D
    M.to_Dstar = to_Dstar
    M.from_Dstar = from_Dstar

    ##### Horizontal vector fields:
    def Horizontal(u):
        (x,nu) = _split(u)

        # Contribution from the coordinate basis for x:
        dx = nu
        # Contribution from the basis for Xa:
        Gammahgammaj = jnp.einsum('hji,ig->hgj',M.Gamma_g(x),nu) # same contraction as _Gammanu
        dnu = -jnp.einsum('hgj,ji->hgi',Gammahgammaj,nu)

        # stack the x rows on top of the flattened frame rows, one column per frame vector
        return jnp.concatenate([dx,dnu.reshape((-1,nu.shape[1]))],axis=0)
    M.Horizontal = Horizontal


# --------------------------------------------------------------------------------
# /jaxgeometry/group/EulerPoincare.py:
# --------------------------------------------------------------------------------
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see https://www.gnu.org/licenses/.
#

from jaxgeometry.setup import *
from jaxgeometry.utils import *

def initialize(G):
    """ Euler-Poincare geodesic integration

    Attaches to G: EP, EPrec, coExpEP, ExpEP, ExpEPpsi, coExpEPt, ExpEPt,
    ExpEPpsit and DcoExpEP. Requires G.invariance == 'left'.
    """

    assert(G.invariance == 'left')

    def ode_EP(c,y):
        # evolution of mu: dmu/dt = -coad(xi,mu) with xi = invFl(mu)
        t,mu,_ = c
        xi = G.invFl(mu)
        dmut = -G.coad(xi,mu)
        return dmut
    G.EP = lambda mu,_dts=None: integrate(ode_EP,None,mu,None,dts() if _dts is None else _dts)

    # reconstruction
    def ode_EPrec(c,y):
        # evolution of the group element g driven by the mu path passed in y
        t,g,_ = c
        mu, = y
        xi = G.invFl(mu)
        dgt = G.dL(g,G.e,G.VtoLA(xi))
        return dgt
    G.EPrec = lambda g,mus,_dts=None: integrate(ode_EPrec,None,g,None,dts() if _dts is None else _dts,mus)

    ### geodesics
    # the *t variants return the full integrator output, the others only the end point
    G.coExpEP = lambda g,mu: G.EPrec(g,G.EP(mu)[1])[1][-1]
    G.ExpEP = lambda g,v: G.coExpEP(g,G.flatV(v))
    # NOTE(review): ExpEP/ExpEPt already apply G.flatV to their second argument,
    # so flatV is applied twice in ExpEPpsi/ExpEPpsit -- confirm this is intended.
    G.ExpEPpsi = lambda q,v: G.ExpEP(G.psi(q),G.flatV(v))
    G.coExpEPt = lambda g,mu: G.EPrec(g,G.EP(mu)[1])
    G.ExpEPt = lambda g,v: G.coExpEPt(g,G.flatV(v))
    G.ExpEPpsit = lambda q,v: G.ExpEPt(G.psi(q),G.flatV(v))
    # Jacobian of coExpEP with respect to its first argument g
    # (the original called the non-existent jax.jaxrev)
    G.DcoExpEP = lambda g,mu: jax.jacrev(G.coExpEP)(g,mu)

# --------------------------------------------------------------------------------
# /jaxgeometry/group/LiePoisson.py:
# --------------------------------------------------------------------------------
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see https://www.gnu.org/licenses/.
#

from jaxgeometry.setup import *
from jaxgeometry.utils import *

def initialize(G):
    """ Lie-Poisson geodesic integration

    Attaches to G: LP, LPrec, coExpLP, ExpLP, coExpLPt, ExpLPt and DcoExpLP.
    Requires G.invariance == 'left'.
    """

    assert(G.invariance == 'left')

    def ode_LP(c,y):
        # evolution of mu: dmu/dt = coad(dHminus/dmu, mu)
        t,mu,_ = c
        dmut = G.coad(G.dHminusdmu(mu),mu)
        return dmut
    G.LP = lambda mu,_dts=None: integrate(ode_LP,None,mu,None,dts() if _dts is None else _dts)

    # reconstruction
    def ode_LPrec(c,y):
        # evolution of the group element g driven by the mu path passed in y
        t,g,_ = c
        mu, = y
        dgt = G.dL(g,G.e,G.VtoLA(G.dHminusdmu(mu)))
        return dgt
    G.LPrec = lambda g,mus,_dts=None: integrate(ode_LPrec,None,g,None,dts() if _dts is None else _dts,mus)

    ### geodesics
    # the *t variants return the full integrator output, the others only the end point
    G.coExpLP = lambda g,mu: G.LPrec(g,G.LP(mu)[1])[1][-1]
    G.ExpLP = lambda g,v: G.coExpLP(g,G.flatV(v))
    G.coExpLPt = lambda g,mu: G.LPrec(g,G.LP(mu)[1])
    G.ExpLPt = lambda g,v: G.coExpLPt(g,G.flatV(v))
    # Jacobian of coExpLP with respect to its first argument g. The original
    # differentiated G.coExp, which this module never defines.
    G.DcoExpLP = lambda g,mu: jax.jacrev(G.coExpLP)(g,mu)

# --------------------------------------------------------------------------------
# /jaxgeometry/group/energy.py:
# --------------------------------------------------------------------------------
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see https://www.gnu.org/licenses/.
#

from jaxgeometry.setup import *
from jaxgeometry.utils import *

def initialize(G):
    """ group Lagrangian and Hamiltonian from invariant metric

    Attaches to G: Lagrangian, Lagrangianpsi (with gradients dLagrangianpsidq
    and dLagrangianpsidv), l, dldhatxi, Hpsi, Hminus, dHminusdmu, the Legendre
    transforms FLpsi/invFLpsi and Fl/invFl, the derived Hamiltonians HL and hl,
    and the default Hamiltonian H.
    """

    # Lagrangian
    def Lagrangian(g,vg):
        return .5*G.gG(g,vg,vg)
    G.Lagrangian = Lagrangian
    # Lagrangian using psi map
    def Lagrangianpsi(q,v):
        return .5*G.gpsi(q,v,v)
    G.Lagrangianpsi = Lagrangianpsi
    G.dLagrangianpsidq = jax.grad(G.Lagrangianpsi)
    # derivative with respect to v is argument 1; the original omitted argnums
    # and therefore differentiated with respect to q a second time
    G.dLagrangianpsidv = jax.grad(G.Lagrangianpsi,argnums=1)
    # LA restricted Lagrangian
    def l(hatxi):
        return 0.5*G.gV(hatxi,hatxi)
    G.l = l
    G.dldhatxi = jax.grad(G.l)

    # Hamiltonian using psi map
    def Hpsi(q,p):
        return .5*G.cogpsi(q,p,p)
    G.Hpsi = Hpsi
    # LA^* restricted Hamiltonian
    def Hminus(mu):
        return .5*G.cogV(mu,mu)
    G.Hminus = Hminus
    G.dHminusdmu = jax.grad(G.Hminus)

    # Legendre transformation. The above Lagrangian is hyperregular
    G.FLpsi = lambda q,v: (q,G.dLagrangianpsidv(q,v))
    G.invFLpsi = lambda q,p: (q,G.cogpsi(q,p))
    def HL(q,p):
        # Hamiltonian obtained from the psi Lagrangian: <p,v> - L(q,v).
        # The original referenced the undefined bare names invFLpsi and L.
        (q,v) = G.invFLpsi(q,p)
        return jnp.dot(p,v)-G.Lagrangianpsi(q,v)
    G.HL = HL
    G.Fl = lambda hatxi: G.dldhatxi(hatxi)
    G.invFl = lambda mu: G.cogV(mu)
    def hl(mu):
        # reduced Hamiltonian obtained from l: <mu,hatxi> - l(hatxi).
        # The original referenced the undefined bare name invFl.
        hatxi = G.invFl(mu)
        return jnp.dot(mu,hatxi)-G.l(hatxi)
    G.hl = hl

    # default Hamiltonian; q may be passed as a (coords, chart) tuple
    G.H = lambda q,p: G.Hpsi(q[0],p) if isinstance(q,tuple) else G.Hpsi(q,p)

# --------------------------------------------------------------------------------
# /jaxgeometry/group/invariant_metric.py:
# --------------------------------------------------------------------------------
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see https://www.gnu.org/licenses/.
#

from jaxgeometry.setup import *
from jaxgeometry.utils import *

def initialize(G,_sigma=None):
    """ add left-/right-invariant metric related structures to group

    parameter sigma is square root cometric / diffusion field
    (defaults to the identity of size G.dim)
    """

    if _sigma is None:
        _sigma = jnp.eye(G.dim)

    G.sqrtA = lambda sigma=_sigma: G.inv(sigma) # square root metric
    G.A = lambda sigma=_sigma: jnp.tensordot(G.sqrtA(sigma),G.sqrtA(sigma),(0,0)) # metric
    G.W = lambda sigma=_sigma: jnp.tensordot(sigma,sigma,(1,1)) # covariance (cometric)

    def _apply_form(B,a,b,what):
        """ Evaluate the matrix B as a bilinear form.

        No vectors: B itself. Only a: B applied to a. Both a and b (both 1-d
        or both 2-d): a^T B b. Any other combination raises ValueError.
        """
        if a is None and b is None:
            return B
        if a is not None and b is None:
            return jnp.tensordot(B,a,(1,0))
        if a is not None and a.ndim == 1 and b.ndim == 1:
            return jnp.dot(a,jnp.dot(B,b))
        if a is not None and a.ndim == 2 and b.ndim == 2:
            return jnp.tensordot(a,jnp.tensordot(B,b,(1,0)),(0,0))
        # explicit exception instead of assert(False), which is removed under -O
        raise ValueError("unsupported argument combination for %s" % what)

    def gV(v=None,w=None,sigma=_sigma):
        # The original had an extra branch `v.ndim == 1 and not w`; it was
        # unreachable because w is None is handled before it.
        return _apply_form(G.A(sigma),v,w,"gV")
    G.gV = gV
    def cogV(cov=None,cow=None,sigma=_sigma):
        return _apply_form(G.W(sigma),cov,cow,"cogV")
    G.cogV = cogV
    def gLA(xiv,xiw,sigma=_sigma):
        v = G.LAtoV(xiv)
        w = G.LAtoV(xiw)
        return G.gV(v,w,sigma)
    G.gLA = gLA
    def cogLA(coxiv,coxiw,sigma=_sigma):
        cov = G.LAtoV(coxiv)
        cow = G.LAtoV(coxiw)
        return G.cogV(cov,cow,sigma)
    G.cogLA = cogLA
    def gG(g,vg,wg,sigma=_sigma):
        # translate the tangent vectors at g back to the Lie algebra
        xiv = G.invpb(g,vg)
        xiw = G.invpb(g,wg)
        return G.gLA(xiv,xiw,sigma)
    G.gG = gG
    def gpsi(hatxi,v=None,w=None,sigma=_sigma):
        g = G.psi(hatxi)
        vg = G.dpsi(hatxi,v)
        wg = G.dpsi(hatxi,w)
        return G.gG(g,vg,wg,sigma)
    G.gpsi = gpsi
    def cogpsi(hatxi,p=None,pp=None,sigma=_sigma):
        invgpsi = G.inv(G.gpsi(hatxi,sigma=sigma))
        if p is not None and pp is not None:
            return jnp.tensordot(p,jnp.tensordot(invgpsi,pp,(1,0)),(0,0))
        # Compare with None: the original tested `p and not pp`, which raises
        # for arrays with more than one element.
        if p is not None and pp is None:
            return jnp.tensordot(invgpsi,p,(1,0))
        return invgpsi
    G.cogpsi = cogpsi

    # sharp/flat mappings
    def sharpV(mu,sigma=_sigma):
        return jnp.dot(G.W(sigma),mu)
    G.sharpV = sharpV
    def flatV(v,sigma=_sigma):
        return jnp.dot(G.A(sigma),v)
    G.flatV = flatV
    def sharp(g,pg,sigma=_sigma):
        return G.invpf(g,G.VtoLA(jnp.dot(G.W(sigma),G.LAtoV(G.invcopb(g,pg)))))
    G.sharp = sharp
    def flat(g,vg,sigma=_sigma):
        return G.invcopf(g,G.VtoLA(jnp.dot(G.A(sigma),G.LAtoV(G.invpb(g,vg)))))
    G.flat = flat
    def sharppsi(hatxi,p,sigma=_sigma):
        return jnp.tensordot(G.cogpsi(hatxi,sigma=sigma),p,(1,0))
    G.sharppsi = sharppsi
    def flatpsi(hatxi,v,sigma=_sigma):
        return jnp.tensordot(G.gpsi(hatxi,sigma=sigma),v,(1,0))
    G.flatpsi = flatpsi


# --------------------------------------------------------------------------------
# /jaxgeometry/group/quotient.py:
# --------------------------------------------------------------------------------
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see https://www.gnu.org/licenses/.
#


from jaxgeometry.setup import *
from jaxgeometry.utils import *

def horz_vert_split(x,proj,sigma,G,M):
    """ Split the frame at group element x into horizontal and vertical parts.

    Returns (Xframe,Xframe_inv,proj_horz,proj_ns,horz): the frame G.invpf(x,G.eiLA)
    contracted with sigma, its pseudo-inverse (frame flattened to G.dim columns),
    the projection matrices onto the horizontal space and onto the null space
    of d(proj) in frame coefficients, and the horizontal basis itself.
    """
    # compute kernel of proj derivative with respect to inv A metric
    rank = M.dim
    Xframe = jnp.tensordot(G.invpf(x,G.eiLA),sigma,(2,0))
    Xframe_inv = jnp.linalg.pinv(Xframe.reshape((-1,G.dim)))
    dproj = jnp.einsum('...ij,ijk->...k',jacrev(proj)(x), Xframe)
    # the SVD only selects subspaces; gradients are not propagated through it
    (_,_,Vh) = jnp.linalg.svd(jax.lax.stop_gradient(dproj),full_matrices=True)
    ns = Vh[rank:].T # null space
    proj_ns = jnp.tensordot(ns,ns,(1,1))
    horz = Vh[0:rank].T # horz space
    proj_horz = jnp.tensordot(horz,horz,(1,1))

    return (Xframe,Xframe_inv,proj_horz,proj_ns,horz)

def _project_sde_terms(Xframe,Xframe_inv,P,det,sto,X,g,G):
    """ Apply the coefficient-space projection P to the SDE terms det, sto and X. """
    def through_frame(cols):
        # frame coefficients -> project with P -> back to the embedding space
        return jnp.tensordot(Xframe,jnp.tensordot(P,jnp.tensordot(Xframe_inv,cols,(1,0)),(1,0)),(2,0))

    det = through_frame(det.flatten()).reshape(g.shape)
    sto = through_frame(sto.flatten()).reshape(g.shape)
    X = through_frame(X.reshape((-1,G.dim))).reshape(X.shape)
    return (det,sto,X)

# hit target v at time t=Tend
def get_sde_fiber(sde_f,proj,G,M):
    """ Wrap sde_f so that its det, sto and X terms are projected onto the null space of d(proj). """
    def sde_fiber(c,y):
        (det,sto,X,*dys_sde) = sde_f(c,y)
        _,g,_,sigma = c

        (Xframe,Xframe_inv,_,proj_ns,_) = horz_vert_split(g,proj,sigma,G,M)
        (det,sto,X) = _project_sde_terms(Xframe,Xframe_inv,proj_ns,det,sto,X,g,G)

        return (det,sto,X,*dys_sde)

    return sde_fiber

def get_sde_horz(sde_f,proj,G,M):
    """ Wrap sde_f so that its det, sto and X terms are projected onto the horizontal space. """
    def sde_horz(c,y):
        (det,sto,X,*dys_sde) = sde_f(c,y)
        _,g,_,sigma = c

        (Xframe,Xframe_inv,proj_horz,_,_) = horz_vert_split(g,proj,sigma,G,M)
        (det,sto,X) = _project_sde_terms(Xframe,Xframe_inv,proj_horz,det,sto,X,g,G)

        return (det,sto,X,*dys_sde)

    return sde_horz

def get_sde_lifted(sde_f,proj,G,M):
    """ Lift an SDE defined on M to the group using the horizontal basis at g. """
    def sde_lifted(c,y):
        t,g,chart,sigma,*cs = c

        # evaluate the SDE on M at the projection of g
        (det,sto,X,*dys_sde) = sde_f((t,M.invF((proj(g),chart)),chart,*cs),y)

        (Xframe,_,_,_,horz) = horz_vert_split(g,proj,sigma,G,M)

        det = jnp.tensordot(Xframe,jnp.tensordot(horz,det,(1,0)),(2,0)).reshape(g.shape)
        sto = jnp.tensordot(Xframe,jnp.tensordot(horz,sto,(1,0)),(2,0)).reshape(g.shape)
        # The lifted X has the shape of g plus one axis of length M.dim. The
        # original hard-coded (G.dim,G.dim,M.dim), which is only the same when
        # the matrix size of g equals G.dim.
        X = jnp.tensordot(Xframe,jnp.tensordot(horz,X,(1,0)),(2,0)).reshape(g.shape+(M.dim,))

        # sigma is kept constant along the lifted process
        return (det,sto,X,jnp.zeros_like(sigma),*dys_sde)

    return sde_lifted

# NOTE: the commented-out legacy helpers that followed here (lift_to_fiber,
# fiber_samples, estimate_fiber_volume) were removed as dead code.

# --------------------------------------------------------------------------------
# /jaxgeometry/groups/GLN.py:
# --------------------------------------------------------------------------------
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see https://www.gnu.org/licenses/.
#

from jaxgeometry.setup import *
from jaxgeometry.params import *

from jaxgeometry.groups.group import *

import matplotlib.pyplot as plt

class GLN(LieGroup):
    """ General linear group GL(N) """

    def __init__(self,N=3):
        dim = N*N # group dimension
        LieGroup.__init__(self,dim,N=N,invariance='left')

        # project to group, here with minimum eigenvalue 1e-3
        # NOTE(review): this helper is local and never attached to self, so it
        # is currently unused -- confirm whether it should be exposed.
        def to_group(g):
            _min_eig = 1e-3
            w, V = jnp.linalg.eig(g.astype('complex128'))
            w_prime = jnp.where(abs(w) < _min_eig, _min_eig, w)
            # eigendecomposition is g = V diag(w) V^{-1}; V is not orthogonal in
            # general, so the original V.T was not the inverse
            return jnp.dot(V,jnp.dot(jnp.diag(w_prime),jnp.linalg.inv(V))).real

        ## coordinate chart on the linking Lie algebra, trivial in this case
        def VtoLA(hatxi): # from \RR^G.dim to LA
            if hatxi.ndim == 1:
                return hatxi.reshape((N,N))
            else: # matrix
                return hatxi.reshape((N,N,-1))
        self.VtoLA = VtoLA
        def LAtoV(m): # from LA to \RR^G.dim
            if m.ndim == 2:
                return m.reshape((self.dim,))
            elif m.ndim == 3:
                return m.reshape((self.dim,-1))
            else:
                assert(False)
        self.LAtoV = LAtoV

        self.Expm = jax.scipy.linalg.expm
        def logm(b):
            # truncated power series of log(I+(b-I)): sum_k (-1)^(k+1) (b-I)^k / k
            I = jnp.eye(b.shape[0])
            res = jnp.zeros_like(b)
            ITERATIONS = 20
            for k in range(1, ITERATIONS):
                res += pow(-1, k+1) * jnp.linalg.matrix_power(b-I, k)/k
            return res
        self.Logm = logm

        super(GLN,self).initialize()

    def __str__(self):
        return "GL(%d) (dimension %d)" % (self.N,self.dim)

    def plot_path(self, g,color_intensity=1.,color=None,linewidth=3.,prevg=None):
        """ Plot a sequence g of group elements; first and last are emphasized. """
        assert(len(g.shape)>2)
        last = g.shape[0]-1
        for i in range(g.shape[0]):
            endpoint = i==0 or i==last
            self.plotg(g[i],
                       linewidth=linewidth if endpoint else .3,
                       color_intensity=color_intensity if endpoint else .7,
                       color=color, # previously accepted but silently ignored
                       prevg=g[i-1] if i>0 else None)
        return

    def plotg(self, g,color_intensity=1.,color=None,linewidth=3.,prevg=None):
        """ Plot the columns of g as arrows from the origin.

        Uses three coordinates and three colors, so this assumes N == 3.
        If prevg is given, thin lines connect its columns to those of g.
        """
        s0 = np.eye(self.N) # shape
        s = np.dot(g,s0) # transformed shape
        if prevg is not None:
            prevs = np.dot(prevg,s0)

        colors = color_intensity*np.array([[1,0,0],[0,1,0],[0,0,1]])
        for i in range(s.shape[1]):
            plt.quiver(0,0,0,s[0,i],s[1,i],s[2,i],pivot='tail',linewidth=linewidth,color=colors[i] if color is None else color,arrow_length_ratio=.15,length=1)
            if prevg is not None:
                ss = np.stack((prevs,s))
                plt.plot(ss[:,0,i],ss[:,1,i],ss[:,2,i],linewidth=.3,color=colors[i])

# --------------------------------------------------------------------------------
# /jaxgeometry/groups/SON.py:
# --------------------------------------------------------------------------------
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see https://www.gnu.org/licenses/.
#

from jaxgeometry.setup import *
from jaxgeometry.params import *

from jaxgeometry.groups.group import *

import matplotlib.pyplot as plt

class SON(LieGroup):
    """ Special Orthogonal Group SO(N) """

    def __init__(self,N=3,invariance='left'):
        dim = N*(N-1)//2 # group dimension
        LieGroup.__init__(self,dim,N,invariance=invariance)

        self.injectivity_radius = 2*jnp.pi

        # project to group (here using QR factorization)
        # NOTE(review): this helper is local and never attached to self, so it
        # is currently unused -- confirm whether it should be exposed.
        def to_group(g):
            (q,r) = jnp.linalg.qr(g)
            return jnp.dot(q,jnp.diag(jnp.diag(r)))

        ## coordinate chart linking the Lie algebra LA of skew-symmetric NxN matrices and V=\RR^G_dim
        # derived from https://stackoverflow.com/questions/25326462/initializing-a-symmetric-theano-dmatrix-from-its-upper-triangle
        r = jnp.arange(N)
        tmp_mat = r[jnp.newaxis, :] + ((N * (N - 3)) // 2-(r * (r - 1)) // 2)[::-1,jnp.newaxis]
        # strictly upper triangular index matrix; index 0 selects the prepended zero in VtoLA
        triu_index_matrix = jnp.triu(tmp_mat+1)-jnp.diag(jnp.diagonal(tmp_mat+1))

        def VtoLA(hatxi): # from \RR^G_dim to LA
            if hatxi.ndim == 1:
                m = jnp.concatenate((jnp.zeros(1),hatxi))[triu_index_matrix]
                return m-m.T
            else: # matrix
                m = jnp.concatenate((jnp.zeros((1,hatxi.shape[1])),hatxi))[triu_index_matrix,:]
                return m-m.transpose((1,0,2))
        self.VtoLA = VtoLA
        self.LAtoV = lambda m: m[np.triu_indices(N, 1)]

        def Expm(g): # hardcoded for skew symmetric matrices to allow higher-order gradients
            # 1j*g is passed to eigh as a Hermitian matrix; -1j*w are the eigenvalues of g
            (w,V) = jnp.linalg.eigh(1.j*g)
            w = -1j*w
            expm = jnp.real(jnp.tensordot(V,jnp.tensordot(jnp.diag(jnp.exp(w)),jnp.conj(V.T),(1,0)),(1,0)))
            return expm
        self.Expm = Expm
        def logm(b):
            # truncated power series of log(I+(b-I)): sum_k (-1)^(k+1) (b-I)^k / k
            I = jnp.eye(b.shape[0])
            res = jnp.zeros_like(b)
            ITERATIONS = 20
            for k in range(1, ITERATIONS):
                res += pow(-1, k+1) * jnp.linalg.matrix_power(b-I, k)/k
            return res
        self.Logm = logm

        super(SON,self).initialize()

    def __str__(self):
        return "SO(%d) (dimension %d)" % (self.N,self.dim)

    def newfig(self):
        newfig3d()

    ### plotting
    def plot_path(self,g,color_intensity=1.,color=None,linewidth=3.,alpha=1.,prevg=None):
        """ Plot a sequence g of group elements; first and last are emphasized. """
        assert(len(g.shape)>2)
        last = g.shape[0]-1
        for i in range(g.shape[0]):
            endpoint = i==0 or i==last
            self.plotg(g[i],
                       linewidth=linewidth if endpoint else .3,
                       color_intensity=color_intensity if endpoint else .7,
                       color=color, # previously accepted but silently ignored
                       alpha=alpha,
                       prevg=g[i-1] if i>0 else None)
        return

    def plotg(self,g,color_intensity=1.,color=None,linewidth=3.,alpha=1.,prevg=None):
        """ Plot the columns of g as arrows from the origin on the current 3d axes.

        The reference shape is the 3x3 identity, so this assumes N == 3.
        If prevg is given, lines connect its normalized columns to those of g.
        """
        # Grid Settings:
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = 0.3
        ax.yaxis._axinfo["grid"]['linewidth'] = 0.3
        ax.zaxis._axinfo["grid"]['linewidth'] = 0.3
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")

        s0 = jnp.eye(3) # shape
        s = jnp.dot(g,s0) # rotated shape
        if prevg is not None:
            prevs = jnp.dot(prevg,s0)

        colors = color_intensity*np.array([[1,0,0],[0,1,0],[0,0,1]])
        for i in range(s.shape[1]):
            plt.quiver(0,0,0,s[0,i],s[1,i],s[2,i],pivot='tail',linewidth=linewidth,color=colors[i] if color is None else color,arrow_length_ratio=.15,length=1,alpha=alpha)
            if prevg is not None:
                ss = jnp.stack((prevs,s))
                ss = ss/jnp.linalg.norm(ss,axis=1)[:,jnp.newaxis,:]
                plt.plot(ss[:,0,i],ss[:,1,i],ss[:,2,i],linewidth=1,color=colors[i])


# --------------------------------------------------------------------------------
# /jaxgeometry/groups/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ComputationalEvolutionaryMorphometry/jaxgeometry/70e22f05966d5076adab37d9579a4900f486485f/jaxgeometry/groups/__init__.py
# --------------------------------------------------------------------------------
# /jaxgeometry/groups/group.py:
# --------------------------------------------------------------------------------
## This file is part of Jax Geometry
#
# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
# https://bitbucket.org/stefansommer/jaxgeometry
#
# Jax Geometry is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry.
from jaxgeometry.setup import *
from jaxgeometry.params import *

from jaxgeometry.manifolds.manifold import *


class LieGroup(EmbeddedManifold):
    """ Base Lie Group class

    Stores the dimensions and the identity/zero elements of a matrix Lie group
    and, in initialize(), attaches the group operations as attributes.
    Sub-classes are expected to define Expm, Logm, VtoLA and LAtoV before
    calling initialize(); those names are read there but not defined here.
    """

    def __init__(self,dim,N,invariance='left'):
        """ Record group dimensions and the chosen invariance.

        dim        -- dimension of the coordinate space V of the Lie algebra
        N          -- matrix size; group elements are NxN matrices
        invariance -- 'left' selects the left-invariant operations in
                      initialize(), any other value the right-invariant ones
        """
        EmbeddedManifold.__init__(self)

        self.dim = dim
        self.N = N # N in SO(N)
        self.emb_dim = N*N # matrix/embedding space dimension
        self.invariance = invariance

        self.e = jnp.eye(N,N) # identity element
        self.zeroLA = jnp.zeros((N,N)) # zero element in LA
        self.zeroV = jnp.zeros((self.dim,)) # zero element in V

    def initialize(self):
        r""" Initial group operations. To be called by sub-classes after definition of dimension, Expm etc.

        Notation:
            hatxi # \RR^G_dim vector
            xi    # matrix in LA
            eta   # matrix in LA
            alpha # matrix in LA^*
            beta  # matrix in LA^*
            g     # \RR^{NxN} matrix
            gs    # sequence of \RR^{NxN} matrices
            h     # \RR^{NxN} matrix
            vg    # \RR^{NxN} tangent vector at g
            wg    # \RR^{NxN} tangent vector at g
            vh    # \RR^{NxN} tangent vector at h
            w     # \RR^G_dim tangent vector in coordinates
            v     # \RR^G_dim tangent vector in coordinates
            pg    # \RR^{NxN} cotangent vector at g
            ph    # \RR^{NxN} cotangent vector at h
            p     # \RR^G_dim cotangent vector in coordinates
            pp    # \RR^G_dim cotangent vector in coordinates
            mu    # \RR^G_dim LA cotangent vector in coordinates
        """

        ## group operations
        self.inv = lambda g: jnp.linalg.inv(g)

        ## group exp/log maps
        self.exp = self.Expm
        def expt(xi,_dts=None):
            """ Scan exp(t*xi) over the time increments _dts (default dts()). """
            if _dts is None:
                _dts = dts()
            return lax.scan(lambda t,dt: (t+dt,self.exp(t*xi)),0.,_dts)
        self.expt = expt
        self.log = self.Logm

        ## Lie algebra
        self.eiV = jnp.eye(self.dim) # standard basis for V
        # pushforward of the eiV basis to LA; indexed below as eiLA[:,:,i]
        self.eiLA = self.VtoLA(self.eiV)
        def bracket(xi,eta):
            """ Matrix commutator [xi,eta].

            Accepts either two matrices (ndim 2) or two stacks with a trailing
            basis index (ndim 3); in the latter case the two matrix axes are
            moved to the front and the two basis indices to the back.
            """
            if xi.ndim == 2 and eta.ndim == 2:
                return jnp.tensordot(xi,eta,(1,0))-jnp.tensordot(eta,xi,(1,0))
            if xi.ndim == 3 and eta.ndim == 3:
                # jnp.transpose replaces the Theano-only .dimshuffle, which
                # does not exist on JAX arrays
                return (jnp.transpose(jnp.tensordot(xi,eta,(1,0)),(0,2,1,3))
                        -jnp.transpose(jnp.tensordot(eta,xi,(1,0)),(0,2,1,3)))
            # explicit raise (same exception type as the former assert) so the
            # check survives python -O
            raise AssertionError("bracket expects two 2-d or two 3-d arrays")
        self.bracket = bracket

        # structure constants: coordinates of [e_i,e_j] in the eiLA basis,
        # obtained by least squares against the flattened basis matrices
        self.C = jnp.zeros((self.dim,self.dim,self.dim))
        basis = self.eiLA.reshape(self.N*self.N, self.dim)
        for i in range(self.dim):
            for j in range(self.dim):
                xij = self.bracket(self.eiLA[:,:,i],self.eiLA[:,:,j])
                self.C = self.C.at[i,j].set(jnp.linalg.lstsq(
                    basis,
                    xij.reshape(self.N*self.N)
                )[0])

        ## surjective mapping \psi:\RR^G_dim\rightarrow G
        self.psi = lambda hatxi: self.exp(self.VtoLA(hatxi))
        self.invpsi = lambda g: self.LAtoV(self.log(g))
        def dpsi(hatxi,v=None):
            """ Jacobian of psi at hatxi, applied to v when v is given. """
            dpsi = jax.jacrev(self.psi)(hatxi)
            # `is not None`: truth-testing an array is ambiguous/raises
            if v is not None:
                return jnp.tensordot(dpsi,v,(2,0))
            return dpsi
        self.dpsi = dpsi
        def dinvpsi(g,vg=None):
            """ Jacobian of invpsi at g, applied to vg when vg is given. """
            dinvpsi = jax.jacrev(self.invpsi)(g)
            if vg is not None:
                return jnp.tensordot(dinvpsi,vg,((1,2),(0,1)))
            return dinvpsi
        self.dinvpsi = dinvpsi

        ## left/right translation
        self.L = lambda g,h: jnp.tensordot(g,h,(1,0)) # left translation L_g(h)=gh
        self.R = lambda g,h: jnp.tensordot(h,g,(1,0)) # right translation R_g(h)=hg
        # pushforward of L/R of vh\in T_hG
        def dL(g,h,vh=None):
            """ Jacobian of L_g with respect to h, applied to vh when given. """
            dL = jax.jacrev(self.L,1)(g,h)
            if vh is not None:
                return jnp.tensordot(dL,vh,((2,3),(0,1)))
            return dL
        self.dL = dL
        def dR(g,h,vh=None):
            """ Jacobian of R_g with respect to h, applied to vh when given. """
            dR = jax.jacrev(self.R,1)(g,h)
            if vh is not None:
                return jnp.tensordot(dR,vh,((2,3),(0,1)))
            return dR
        self.dR = dR
        # pullback of L/R of vh\in T_h^*G
        self.codL = lambda g,h,vh: self.dL(g,h,vh).T
        self.codR = lambda g,h,vh: self.dR(g,h,vh).T

        ## actions
        self.Ad = lambda g,xi: self.dR(self.inv(g),g,self.dL(g,self.e,xi))
        self.ad = lambda xi,eta: self.bracket(xi,eta)
        self.coad = lambda v,p: jnp.tensordot(jnp.tensordot(self.C,v,(0,0)),p,(1,0))

        ## invariance
        if self.invariance == 'left':
            self.invtrns = self.L # invariance translation
            self.invpb = lambda g,vg: self.dL(self.inv(g),g,vg) # left invariance pullback from TgG to LA
            self.invpf = lambda g,xi: self.dL(g,self.e,xi) # left invariance pushforward from LA to TgG
            self.invcopb = lambda g,pg: self.codL(self.inv(g),g,pg) # left invariance pullback from Tg^*G to LA^*
            self.invcopf = lambda g,alpha: self.codL(g,self.e,alpha) # left invariance pushforward from LA^* to Tg^*G
            self.infgen = lambda xi,g: self.dR(g,self.e,xi) # infinitesimal generator
        else:
            self.invtrns = self.R # invariance translation
            self.invpb = lambda g,vg: self.dR(self.inv(g),g,vg) # right invariance pullback from TgG to LA
            self.invpf = lambda g,xi: self.dR(g,self.e,xi) # right invariance pushforward from LA to TgG
            self.invcopb = lambda g,pg: self.codR(self.inv(g),g,pg) # right invariance pullback from Tg^*G to LA^*
            self.invcopf = lambda g,alpha: self.codR(g,self.e,alpha) # right invariance pushforward from LA^* to Tg^*G
            self.infgen = lambda xi,g: self.dL(g,self.e,xi) # infinitesimal generator

    def __str__(self):
        return "abstract Lie group"
from jaxgeometry.setup import *
from jaxgeometry.params import *

from jaxgeometry.manifolds.manifold import *

from jaxgeometry.plotting import *
import matplotlib.pyplot as plt


class Euclidean(Manifold):
    """ Euclidean space """

    def __init__(self,N=3):
        """ Set up N-dimensional Euclidean space with the identity metric. """
        Manifold.__init__(self)
        self.dim = N

        # coordinates are global: a chart change leaves them untouched
        self.update_coords = lambda coords,_: coords

        ##### Metric:
        self.g = lambda x: jnp.eye(self.dim)

        # action of matrix group on elements
        self.act = lambda g,x: jnp.tensordot(g,x,(1,0))

    def __str__(self):
        return "Euclidean manifold of dimension %d" % (self.dim)

    def newfig(self):
        """ Open a 2d or 3d figure matching self.dim (no-op for other dims). """
        if self.dim == 2:
            newfig2d()
        elif self.dim == 3:
            newfig3d()

    def plot(self):
        """ Prepare the current axes; only the 2d case needs equal aspect. """
        if self.dim == 2:
            plt.axis('equal')

    def plot_path(self, xs, u=None, color='b', color_intensity=1., linewidth=1., prevx=None, last=True, s=20, arrowcolor='k'):
        """ Plot a sequence of points xs as a connected path.

        u is drawn (as an arrow) at the first point only. The prevx and last
        arguments are accepted for signature compatibility but are not used:
        the path always starts without a predecessor and only its final point
        is flagged as last.
        """
        xs = list(xs)
        N = len(xs)
        prevx = None
        for i,x in enumerate(xs):
            endpoint = i == 0 or i == N-1
            self.plotx(x, u=u if i == 0 else None,
                       color=color,
                       color_intensity=color_intensity if endpoint else .7,
                       linewidth=linewidth,
                       s=s,
                       prevx=prevx,
                       last=i == N-1,
                       arrowcolor=arrowcolor) # was silently dropped before
            prevx = x
        return

    def plotx(self, x, u=None, color='b', color_intensity=1., linewidth=1., prevx=None, last=True, s=20, arrowcolor='k'):
        """ Plot a single point x, optionally connected to prevx.

        x and prevx may be coordinate tuples (the first entry is used) or plain
        arrays of length self.dim. A marker is drawn when x is the last point
        or has no predecessor; a line segment from prevx is drawn whenever
        prevx is given, including for the last point (previously the final
        segment of a path was omitted). u, if given, is drawn as an arrow at x
        in the 2d case. color_intensity is accepted but not used here.
        """
        assert isinstance(x, tuple) or x.shape[0] == self.dim
        if isinstance(x, tuple):
            x = x[0]
        if isinstance(prevx, tuple):
            prevx = prevx[0]

        ax = plt.gca()

        # segment from the previous point
        if prevx is not None:
            xx = np.stack((prevx,x))
            if self.dim == 2:
                plt.plot(xx[:,0],xx[:,1],linewidth=linewidth,color=color)
            elif self.dim == 3:
                ax.plot(xx[:,0],xx[:,1],xx[:,2],linewidth=linewidth,color=color)

        # marker at end points and at isolated points
        if last or prevx is None:
            if self.dim == 2:
                plt.scatter(x[0],x[1],color=color,s=s)
            elif self.dim == 3:
                ax.scatter(x[0],x[1],x[2],color=color,s=s)

        # tangent vector; explicit guard instead of a bare try/except.
        # Only the 2d call signature is supported, as before (the 4-argument
        # quiver call never succeeded on 3d axes and was silently skipped).
        if u is not None and self.dim == 2:
            plt.quiver(x[0], x[1], u[0], u[1], pivot='tail', linewidth=linewidth, scale=5, color=arrowcolor)
from jaxgeometry.setup import *
from jaxgeometry.params import *

from jaxgeometry.manifolds.ellipsoid import *

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.ticker as ticker


class H2(EmbeddedManifold):
    """ hyperbolic plane """

    def __init__(self,):
        """ Hyperboloid embedding (cosh r, sinh r cos t, sinh r sin t) with a single chart. """
        F = lambda x: jnp.stack([jnp.cosh(x[0][0]),jnp.sinh(x[0][0])*jnp.cos(x[0][1]),jnp.sinh(x[0][0])*jnp.sin(x[0][1])])
        invF = lambda x: jnp.stack([jnp.arccosh(x[0][0]),jnp.arctan2(x[0][2],x[0][1])])
        # one global chart: never request a chart update
        self.do_chart_update = lambda x: False

        EmbeddedManifold.__init__(self,F,2,3,invF=invF)

        # metric matrix from embedding into Minkowski space
        self.g = lambda x: jnp.einsum('ji,j,jl',self.JF(x),np.array([-1.,1.,1.]),self.JF(x))

    def __str__(self):
        return "%dd dim hyperbolic space" % (self.dim,)

    def newfig(self):
        newfig3d()

    def plot(self,rotate=None,alpha=None,lw=0.3):
        """ Draw the hyperboloid sheet as a wireframe on the current 3d axes.

        rotate -- optional (elev, azim) pair passed to ax.view_init
        alpha  -- if given, additionally draw a filled surface with this alpha
        lw     -- grid line width
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(1.,2.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])
        plt.xlabel('x')
        plt.ylabel('y')

        # draw hyperboloid sheet
        u, v = np.mgrid[-1.5:1.5:20j, 0:2*np.pi:20j]
        x=np.cosh(u)
        y=np.sinh(u)*np.cos(v)
        z=np.sinh(u)*np.sin(v)
        ax.plot_wireframe(x, y, z, color='gray', alpha=0.5)

        if alpha is not None:
            ax.plot_surface(x, y, z, color=cm.jet(0.), alpha=alpha)

    def plot_field(self, field,lw=.3):
        """ Evaluate field on a grid of surface points and plot each vector.

        field is called with a (coordinates, chart) tuple and its result is
        handed to self.plotx as v.
        """
        # plt.gca(projection='3d') is rejected by current Matplotlib; reuse the
        # current axes when there is one (as plot() does), else create 3d axes.
        fig = plt.gcf()
        ax = fig.gca() if fig.axes else fig.add_subplot(projection='3d')
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")

        plt.xlabel('x')
        plt.ylabel('y')

        # sample points on the surface
        # NOTE(review): grid range and axis limits differ from plot() and look
        # copied from the ellipsoid code (x = cosh(u) >= 1 lies outside the
        # x-limits set above) -- confirm the intended sampling.
        u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
        x=np.cosh(u)
        y=np.sinh(u)*np.cos(v)
        z=np.sinh(u)*np.sin(v)

        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                Fx = np.array([x[i,j],y[i,j],z[i,j]])
                chart = self.centered_chart(Fx)
                xcoord = self.invF((Fx,chart))
                vec = field((xcoord,chart)) # distinct name: v is the grid above
                self.plotx((xcoord,chart),v=vec)
from jaxgeometry.setup import *
from jaxgeometry.params import *

from jaxgeometry.manifolds.manifold import *

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.ticker as ticker


class Heisenberg(Manifold):
    """ Heisenberg group """

    def __init__(self,):
        Manifold.__init__(self)

        self.dim = 3
        self.sR_dim = 2

        # keep the coordinates, swap in the new chart
        self.update_coords = lambda coords,chart: (coords[0],chart)

        ##### (orthonormal) distribution
        self.D = lambda x: jnp.array([[1,0,-x[0][1]/2],[0,1,x[0][0]/2]]).T

    def __str__(self):
        return "Heisenberg group"

    def plot(self):
        """ Nothing to draw for the ambient space. """
        return None

    def plot_path(self, xs, vs=None, v_steps=None, i0=0, color='b',
                  color_intensity=1., linewidth=1., s=15., prevx=None, prevchart=None, last=True):
        """ Plot the points xs as a connected path in 3d.

        vs      -- optional per-point vectors; vs[i] is passed to plotx as v
        v_steps -- indices at which the vectors are drawn. A caller-supplied
                   value is now passed through unchanged (it used to be
                   overwritten with every step); when None, plotx applies its
                   own default.
        i0, prevx, prevchart and last are accepted for signature compatibility
        but are not used: the path starts without a predecessor and only its
        final point is flagged as last.
        """
        xs = list(xs)
        N = len(xs)
        prevx = None
        for i,x in enumerate(xs):
            endpoint = i == 0 or i == N-1
            self.plotx(x, v=vs[i] if vs is not None else None,
                       v_steps=v_steps,i=i,
                       color=color,
                       color_intensity=color_intensity if endpoint else .7,
                       linewidth=linewidth,
                       s=s,
                       prevx=prevx,
                       last=i == (N-1))
            prevx = x
        return

    # plot x in coordinates
    def plotx(self, x, u=None, v=None, v_steps=None, i=0, color='b',
              color_intensity=1., linewidth=1., s=15., prevx=None, prevchart=None, last=True):
        """ Plot one point x, optionally connected to prevx, with optional arrows.

        x and prevx may be coordinate tuples or plain arrays (which are wrapped
        into 1-tuples). u is drawn as an arrow of length 0.5; v is drawn as an
        arrow of length 1.0 only when step index i is contained in v_steps
        (default: every 10th index below n_steps). color_intensity and
        prevchart are accepted but not used here.
        """
        if not isinstance(x, tuple):
            x = (x,)
        if prevx is not None and not isinstance(prevx, tuple):
            prevx = (prevx,)

        if v is not None and v_steps is None:
            v_steps = np.arange(0,n_steps,10)

        ax = plt.gca()
        # marker at isolated points and at the end of a path
        if prevx is None or last:
            ax.scatter(x[0][0],x[0][1],x[0][2],color=color,s=s)
        # segment from the previous point
        if prevx is not None:
            xx = np.stack((prevx[0],x[0]))
            ax.plot(xx[:,0],xx[:,1],xx[:,2],linewidth=linewidth,color=color)

        if u is not None:
            ax.quiver(x[0][0], x[0][1], x[0][2], u[0], u[1], u[2],
                      pivot='tail',
                      arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                      color='black')

        if v is not None and i in v_steps:
            ax.quiver(x[0][0], x[0][1], x[0][2], v[0], v[1], v[2],
                      pivot='tail',
                      arrow_length_ratio = 0.15, linewidths=linewidth, length=1.0,
                      color='black')


# function to evaluate spherical harmonics
def Y(l,m,theta,phi):
    """ Evaluate the spherical harmonic of degree l and order m at (theta, phi).

    Only l <= 2 is tabulated; any other (l, m) pair returns 0.
    NOTE(review): the prefactors and the sign of the phase for m != 0 do not
    match the usual Condon-Shortley convention -- confirm which convention the
    callers expect before relying on these values.
    """
    if l == 0 and m == 0:
        return 1/np.sqrt(4*np.pi)
    elif l == 1 and m == -1:
        return np.sqrt(3/(4*np.pi))*np.sin(theta)*np.exp(1j*phi)
    elif l == 1 and m == 0:
        return np.sqrt(3/(4*np.pi))*np.cos(theta)
    elif l == 1 and m == 1:
        return np.sqrt(3/(4*np.pi))*np.sin(theta)*np.exp(-1j*phi)
    elif l == 2 and m == -2:
        return np.sqrt(15/(16*np.pi))*np.sin(theta)**2*np.exp(2j*phi)
    elif l == 2 and m == -1:
        return np.sqrt(15/(8*np.pi))*np.sin(theta)*np.cos(theta)*np.exp(1j*phi)
    elif l == 2 and m == 0:
        return np.sqrt(5/(16*np.pi))*(3*np.cos(theta)**2-1)
    elif l == 2 and m == 1:
        return np.sqrt(15/(8*np.pi))*np.sin(theta)*np.cos(theta)*np.exp(-1j*phi)
    elif l == 2 and m == 2:
        return np.sqrt(15/(16*np.pi))*np.sin(theta)**2*np.exp(-2j*phi)
    else:
        return 0
from jaxgeometry.setup import *
from jaxgeometry.params import *

from jaxgeometry.manifolds.ellipsoid import *

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.ticker as ticker


class S2(Ellipsoid):
    """ 2d Sphere: the ellipsoid with all three parameters equal to one """

    def __init__(self,use_spherical_coords=False,chart_center='z'):
        """ Build the unit sphere, forwarding the chart options to Ellipsoid. """
        unit_params = [1.,1.,1.]
        super().__init__(
            params=unit_params,
            chart_center=chart_center,
            use_spherical_coords=use_spherical_coords,
        )

    def __str__(self):
        fields = (self.dim,self.params,self.use_spherical_coords)
        return "%dd sphere (ellipsoid parameters %s, spherical_coords: %s)" % fields
from jaxgeometry.setup import *
from jaxgeometry.params import *

from jaxgeometry.manifolds.manifold import *

from jaxgeometry.plotting import *
import matplotlib.pyplot as plt
from matplotlib import cm


class SPDN(EmbeddedManifold):
    """ manifold of symmetric positive definite matrices """

    def __init__(self,N=3):
        """ NxN SPD matrices, stored flattened (length N*N) in the embedding space. """
        self.N = N
        dim = N*(N+1)//2
        emb_dim = N*N
        EmbeddedManifold.__init__(self,dim=dim,emb_dim=emb_dim)

        # congruence action g q g^T on a flattened matrix q
        self.act = lambda g,q: jnp.tensordot(g,jnp.tensordot(q.reshape((N,N)),g,(1,1)),(1,0)).flatten()
        self.acts = jax.vmap(self.act,(0,None))

    def __str__(self):
        return "SPDN(%d), dim %d" % (self.N,self.dim)

    def plot(self, rotate=None, alpha = None):
        """ Prepare the current axes; rotate is an optional (elev, azim) pair. """
        ax = plt.gca()
        #ax.set_aspect("equal")
        # identity test: `rotate != None` is elementwise for array arguments
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])
        plt.xlabel('x')
        plt.ylabel('y')

    def plot_path(self, x,color_intensity=1.,color=None,linewidth=3.,prevx=None,ellipsoid=None,i=None,maxi=None):
        """ Plot a sequence x of flattened matrices, one plotx call per entry.

        End points are drawn with the given linewidth/color_intensity,
        interior points thinner and dimmer. prevx, i and maxi are accepted for
        signature compatibility; they are derived from the sequence itself.
        """
        assert(len(x.shape)>1)
        n = x.shape[0]
        for k in range(n):
            endpoint = k == 0 or k == n-1
            self.plotx(x[k],
                       linewidth=linewidth if endpoint else .3,
                       color_intensity=color_intensity if endpoint else .7,
                       color=color, # was silently dropped before
                       prevx=x[k-1] if k > 0 else None,
                       ellipsoid=ellipsoid,i=k,maxi=n)
        return

    def _plot_eigen_arrows(self, s, colors, color, linewidth):
        """ Draw the columns of s as arrows from the origin (3d only). """
        for k in range(s.shape[1]):
            plt.quiver(0,0,0,s[0,k],s[1,k],s[2,k],pivot='tail',linewidth=linewidth,color=colors[k] if color is None else color,arrow_length_ratio=.15,length=1)

    def plotx(self, x,color_intensity=1.,color=None,linewidth=3.,prevx=None,ellipsoid=None,i=None,maxi=None):
        """ Plot one SPD matrix x (flattened) by its scaled eigenvectors.

        Without ellipsoid, arrows are drawn on the current axes and, if prevx
        is given, thin lines connect the arrow tips of prevx and x.
        With ellipsoid (a mapping), a new figure/axes is opened and the
        ellipsoid surface is drawn as well. Keys read from it:
            'alpha'   -- surface alpha (required)
            'step'    -- optional; only every step-th index i (and the last
                         index maxi-1) is drawn
            'subplot' -- optional; if truthy, place the plot in a row of
                         subplots indexed by i//step
        """
        mat = x.reshape((self.N,self.N))
        (w,V) = np.linalg.eigh(mat)
        s = np.sqrt(w[np.newaxis,:])*V # scaled eigenvectors
        trace = None
        if prevx is not None:
            prevmat = prevx.reshape((self.N,self.N))
            (prevw,prevV) = np.linalg.eigh(prevmat)
            prevs = np.sqrt(prevw[np.newaxis,:])*prevV # scaled eigenvectors
            trace = np.stack((prevs,s))

        colors = color_intensity*np.array([[1,0,0],[0,1,0],[0,0,1]])
        if ellipsoid is None:
            self._plot_eigen_arrows(s,colors,color,linewidth)
            if trace is not None:
                for k in range(s.shape[1]):
                    plt.plot(trace[:,0,k],trace[:,1,k],trace[:,2,k],linewidth=.3,color=colors[k])
            return

        # 'step' is optional; an unusable value disables thinning and subplots
        try:
            step = int(ellipsoid['step'])
        except (KeyError, TypeError, ValueError):
            step = None
        indexed = bool(step) and i is not None and maxi is not None
        if indexed and i % step != 0 and i != maxi-1:
            return

        try:
            subplot = bool(ellipsoid['subplot'])
        except KeyError:
            subplot = False
        if subplot and indexed:
            (fig,ax) = newfig3d(1,maxi//step+1,i//step+1,new_figure=i==0)
        else:
            # also covers a present-but-falsy 'subplot', which previously left
            # ax undefined
            (fig,ax) = newfig3d()

        #draw ellipsoid, from https://stackoverflow.com/questions/7819498/plotting-ellipsoid-with-matplotlib
        _, sing_vals, rotation = np.linalg.svd(mat)
        radii = np.sqrt(sing_vals)
        u = np.linspace(0., 2.*np.pi, 20)
        v = np.linspace(0., np.pi, 10)
        pts = np.stack((radii[0] * np.outer(np.cos(u), np.sin(v)),
                        radii[1] * np.outer(np.sin(u), np.sin(v)),
                        radii[2] * np.outer(np.ones_like(u), np.cos(v))),axis=-1)
        # rotate every surface point at once: pts[l,k,:] . rotation
        pts = np.dot(pts,rotation)
        ex, ey, ez = pts[...,0], pts[...,1], pts[...,2]
        ax.plot_surface(ex, ey, ez, facecolors=cm.winter(ey/np.amax(ey)), linewidth=0, alpha=ellipsoid['alpha'])
        self._plot_eigen_arrows(s,colors,color,linewidth)
        plt.axis('off')
from jaxgeometry.setup import *
from jaxgeometry.params import *

from jaxgeometry.manifolds.manifold import *

from jaxgeometry.plotting import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.ticker as ticker


class Cylinder(EmbeddedManifold):
    """ 2d Cylinder """

    def chart(self):
        """ return default coordinate chart """
        return jnp.zeros(self.dim)

    def centered_chart(self,x):
        """ return centered coordinate chart """
        if isinstance(x, tuple): # coordinate tuple
            Fx = jax.lax.stop_gradient(self.F(x))
        else:
            Fx = x # already in embedding space
        return self.invF((Fx,self.chart())) # chart centered at coords

    def get_B(self,v):
        """ R^3 basis with first basis vector v """
        b1 = v
        # build b2 from the coordinate axis along which v is smallest
        k = jnp.argmin(jnp.abs(v))
        ek = jnp.eye(3)[:,k]
        b2 = ek-v[k]*v
        b3 = cross(b1,b2)
        return jnp.stack((b1,b2,b3),axis=1)

    # Logarithm with standard Riemannian metric
    def StdLog(self,_x,y):
        """ Coordinate difference from _x to the embedded point y.

        The difference is taken in a chart centered at _x and the result is
        mapped back to the chart of _x with update_vector.
        """
        (x,chart) = self.update_coords(_x,self.centered_chart(self.F(_x)))
        y = self.invF((y,chart))
        return self.update_vector((x,chart),_x[0],_x[1],y-x)

    def __init__(self,params=(1.,jnp.array([0.,1.,0.]),0.)):
        """ params = (radius, axis direction, rotation angle around the axis). """
        self.radius = params[0] # radius of cylinder
        self.orientation = jnp.array(params[1]) # axis of cylinder
        self.theta = params[2] # angle around rotation axis

        # coordinates x[0] = (angle, height) are offsets relative to chart x[1]
        F = lambda x: jnp.dot(self.get_B(self.orientation),
                jnp.stack([x[0][1]+x[1][1],self.radius*jnp.cos(self.theta+x[1][0]+x[0][0]),self.radius*jnp.sin(self.theta+x[1][0]+x[0][0])]))
        def invF(x):
            # express the point in the cylinder basis, then undo the rotation
            # by theta plus the chart angle before reading off the angle
            Rinvx = jnp.linalg.solve(self.get_B(self.orientation),x[0])
            rotangle = -(self.theta+x[1][0])
            rot = jnp.dot(jnp.stack(
                (jnp.stack((jnp.cos(rotangle),-jnp.sin(rotangle))),
                 jnp.stack((jnp.sin(rotangle),jnp.cos(rotangle))))),
                Rinvx[1:])
            return jnp.stack([jnp.arctan2(rot[1],rot[0]),Rinvx[0]-x[1][1]])
        self.do_chart_update = lambda x: jnp.max(jnp.abs(x[0])) >= np.pi/4 # look for a new chart if true

        EmbeddedManifold.__init__(self,F,2,3,invF=invF)

    def __str__(self):
        return "cylinder in R^3, radius %s, axis %s, rotation around axis %s" % (self.radius,self.orientation,self.theta)

    def newfig(self):
        newfig3d()

    def _surface_grid(self, nu, nv):
        """ Embedded points for an nu x nv grid of (angle, height) in [-pi, pi]^2. """
        u, v = np.mgrid[-np.pi:np.pi:complex(0,nu), -np.pi:np.pi:complex(0,nv)]
        x = np.zeros(u.shape)
        y = np.zeros(u.shape)
        z = np.zeros(u.shape)
        for i in range(u.shape[0]):
            for j in range(u.shape[1]):
                w = self.F(self.coords(jnp.array([u[i,j],v[i,j]])))
                x[i,j] = w[0]; y[i,j] = w[1]; z[i,j] = w[2]
        return x, y, z

    def plot(self, rotate=None,alpha=None,lw=0.3):
        """ Draw the cylinder as a wireframe on the current 3d axes.

        rotate -- optional (elev, azim) pair passed to ax.view_init
        alpha  -- if given, additionally draw a filled surface with this alpha
        lw     -- grid line width
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])
        plt.xlabel('x')
        plt.ylabel('y')

        # draw cylinder
        x, y, z = self._surface_grid(20, 10)
        ax.plot_wireframe(x, y, z, color='gray', alpha=0.5)

        if alpha is not None:
            ax.plot_surface(x, y, z, color=cm.jet(0.), alpha=alpha)

    def plot_field(self, field,lw=.3):
        """ Evaluate field on a grid of surface points and plot each vector.

        field is called with a (coordinates, chart) tuple and its result is
        handed to self.plotx as v.
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")

        plt.xlabel('x')
        plt.ylabel('y')

        x, y, z = self._surface_grid(40, 20)

        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                Fx = np.array([x[i,j],y[i,j],z[i,j]])
                # centered_chart/invF are the methods this class provides; the
                # former centered_chartf/invFf names do not exist
                chart = self.centered_chart(Fx)
                xcoord = self.invF((Fx,chart))
                vec = field((xcoord,chart))
                self.plotx((xcoord,chart),v=vec)
class Ellipsoid(EmbeddedManifold):
    """ 2d Ellipsoid

    Ellipsoid in R^3 with semi-axes ``params``. Points are represented as
    ``(coords,chart)`` tuples: ``coords`` are stereographic coordinates in R^2
    and ``chart`` is the unit vector in R^3 at which the chart is centered.
    """

    def chart(self):
        """ return default coordinate chart

        The chart is the unit vector along the axis named by ``chart_center``.

        Raises:
            ValueError: if ``chart_center`` is not one of 'x', 'y', 'z'.
        """
        axes = {'x': 0, 'y': 1, 'z': 2}
        if self.chart_center not in axes:
            # explicit error instead of assert(False), which is stripped under -O
            raise ValueError("chart_center must be 'x', 'y' or 'z', got %r" % (self.chart_center,))
        return jnp.eye(3)[:,axes[self.chart_center]]

    def centered_chart(self,x):
        """ return centered coordinate chart

        ``x`` is either a (coords,chart) tuple or a point on the ellipsoid in R^3.
        The gradient is stopped through the chart choice in the tuple case.
        """
        if isinstance(x,tuple): # coordinate tuple
            return jax.lax.stop_gradient(self.F(x))/self.params
        else:
            return x/self.params # already in embedding space

    def get_B(self,v):
        """ R^3 basis with first basis vector v """
        b1 = v
        k = jnp.argmin(jnp.abs(v))
        ek = jnp.eye(3)[:,k]
        b2 = ek-v[k]*v
        b3 = cross(b1,b2)
        return jnp.stack((b1,b2,b3),axis=1)

    # Logarithm with standard Riemannian metric on S^2
    def StdLogEmb(self, x,y):
        """ S^2 logarithm of y at x, as a vector in the embedding space R^3

        ``x`` is a (coords,chart) tuple, ``y`` a point on the ellipsoid in R^3.
        The result is scaled back by ``params``.
        """
        y = y/self.params # from ellipsoid to S^2
        proj = lambda x,y: jnp.dot(x,y)*x
        Fx = self.F(x)/self.params
        v = y-proj(Fx,y)
        # clip guards against round-off pushing the dot product outside [-1,1],
        # which would make arccos return NaN
        theta = jnp.arccos(jnp.clip(jnp.dot(Fx,y),-1.,1.))
        normv = jnp.linalg.norm(v,2)
        w = jax.lax.cond(normv >= 1e-5,
                         lambda _: theta/normv*v,
                         lambda _: jnp.zeros_like(v),
                         None)
        return self.params*w

    def StdLog(self, x,y):
        """ S^2 logarithm of y at x, expressed in the coordinates of x's chart """
        # NOTE(review): invF itself divides by params, so Fx is rescaled twice
        # for non-unit params -- confirm this is intended
        Fx = self.F(x)/self.params
        return jnp.dot(self.invJF((Fx,x[1])),self.StdLogEmb(x,y))

    def __init__(self,params=np.array([1.,1.,1.]),chart_center='z',use_spherical_coords=False):
        """
        Args:
            params: the three semi-axes (e.g. [1.,1.,1.] for sphere)
            chart_center: axis ('x', 'y' or 'z') of the default chart
            use_spherical_coords: chart computations in spherical coordinates;
                not implemented

        Raises:
            NotImplementedError: if ``use_spherical_coords`` is true.
        """
        self.params = jnp.array(params) # ellipsoid parameters (e.g. [1.,1.,1.] for sphere)
        self.use_spherical_coords = use_spherical_coords
        self.chart_center = chart_center

        if use_spherical_coords:
            # previously this path failed with an unbound F/invF; fail explicitly
            raise NotImplementedError("spherical coordinates in chart computations are not implemented")

        # stereographic coordinates in the basis get_B(chart)
        F = lambda x: self.params*jnp.dot(self.get_B(x[1]),jnp.stack([-(-1+x[0][0]**2+x[0][1]**2),2*x[0][0],2*x[0][1]])/(1+x[0][0]**2+x[0][1]**2))
        def invF(x):
            Rinvx = jnp.linalg.solve(self.get_B(x[1]),x[0]/self.params)
            return jnp.stack([Rinvx[1]/(1+Rinvx[0]),Rinvx[2]/(1+Rinvx[0])])
        self.do_chart_update = lambda x: jnp.linalg.norm(x[0]) > .1 # look for a new chart if true

        # spherical coordinates, no charts
        self.F_spherical = lambda phitheta: self.params*jnp.stack([jnp.sin(phitheta[1]-np.pi/2)*jnp.cos(phitheta[0]),jnp.sin(phitheta[1]-np.pi/2)*jnp.sin(phitheta[0]),jnp.cos(phitheta[1]-np.pi/2)])
        # Jacobian of F_spherical by forward-mode autodiff
        self.JF_spherical = lambda x: jax.jacfwd(self.F_spherical)(x)
        # NOTE(review): ignores params and the pi/2 offset used in F_spherical,
        # so this is not an exact inverse -- confirm before relying on it
        self.F_spherical_inv = lambda x: jnp.stack([jnp.arctan2(x[1],x[0]),jnp.arccos(x[2])])
        self.g_spherical = lambda x: jnp.dot(self.JF_spherical(x).T,self.JF_spherical(x))
        self.mu_Q_spherical = lambda x: 1./jnp.linalg.det(self.g_spherical(x))

        EmbeddedManifold.__init__(self,F,2,3,invF=invF)

        # action of matrix group on elements
        self.act = lambda g,x: jnp.tensordot(g,x,(1,0))
        self.acts = lambda g,x: jnp.tensordot(g,x,(2,0))

    def __str__(self):
        return "%dd ellipsoid, parameters %s, spherical coords %s" % (self.dim,self.params,self.use_spherical_coords)

    def newfig(self):
        """ open new 3d figure """
        newfig3d()

    def plot(self,rotate=None,alpha=None,lw=0.3,color='gray',scale=1.):
        """ draw the ellipsoid as a wireframe on the current axes

        Args:
            rotate: optional (elev,azim) passed to ``view_init``
            alpha: if not None, also draw a surface with this opacity
            lw: grid line width
            color: wireframe/surface color
            scale: scaling applied to the drawn ellipsoid
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])
        plt.xlabel('x')
        plt.ylabel('y')

        # draw ellipsoid
        u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
        x=scale*self.params[0]*np.cos(u)*np.sin(v)
        y=scale*self.params[1]*np.sin(u)*np.sin(v)
        z=scale*self.params[2]*np.cos(v)
        ax.plot_wireframe(x, y, z, color=color, alpha=0.5)

        if alpha is not None:
            ax.plot_surface(x, y, z, color=color, alpha=alpha)

    def plot_field(self, field,lw=.3, scale=1.):
        """ plot a vector field sampled on a grid of points on the ellipsoid

        ``field`` maps a (coords,chart) tuple to a tangent vector; each vector
        is drawn with ``plotx`` after scaling by ``scale``.
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")

        plt.xlabel('x')
        plt.ylabel('y')

        # grid of points on the ellipsoid
        u, v = np.mgrid[0:2*np.pi:40j, 0:np.pi:20j]
        x=self.params[0]*np.cos(u)*np.sin(v)
        y=self.params[1]*np.sin(u)*np.sin(v)
        z=self.params[2]*np.cos(v)

        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                Fx = np.array([x[i,j],y[i,j],z[i,j]])
                chart = self.centered_chart(Fx)
                xcoord = self.invF((Fx,chart))
                vec = field((xcoord,chart))
                self.plotx((xcoord,chart),v=scale*vec)
class Latent(EmbeddedManifold):
    """ Latent space manifold defined from embedding function F:R^dim->R^emb_dim, e.g. a neural network """

    def __init__(self,F,dim,emb_dim,invF=None):
        EmbeddedManifold.__init__(self,F,dim,emb_dim,invF)

        # metric matrix (pull-back of the Euclidean metric by F)
        self.g = lambda x: jnp.dot(self.JF(x).T,self.JF(x))

    def _embed(self, x):
        """ map raw coordinates x to the embedding space using the default chart """
        return np.asarray(self.F(self.coords(x)))

    def _jacobian(self, x):
        """ Jacobian of the embedding at raw coordinates x, default chart """
        return np.asarray(self.JF(self.coords(x)))

    def newfig(self):
        """ open a new 3d or 2d figure depending on the dimensions """
        if self.emb_dim == 3:
            newfig3d()
        elif self.dim == 2:
            newfig2d()

    def plot(self, rotate=None, alpha=None, lw=0.3):
        """ draw the embedded surface as a wireframe (only when emb_dim is 3)

        Args:
            rotate: optional (elev,azim) passed to ``view_init``
            alpha: if not None, also draw a surface with this opacity
            lw: grid line width
        """
        if self.emb_dim == 3:
            ax = plt.gca()
            ax.xaxis._axinfo["grid"]['linewidth'] = lw
            ax.yaxis._axinfo["grid"]['linewidth'] = lw
            ax.zaxis._axinfo["grid"]['linewidth'] = lw
            ax.set_xlim(-1., 1.)
            ax.set_ylim(-1., 1.)
            ax.set_zlim(-1., 1.)
            #ax.set_aspect("equal")
            if rotate is not None:
                ax.view_init(rotate[0], rotate[1])
            plt.xlabel('x')
            plt.ylabel('y')

            # draw surface: grid of latent points at normal quantiles
            X, Y = np.meshgrid(norm.ppf(np.linspace(0.05, 0.95, 20)), norm.ppf(np.linspace(0.05, 0.95, 20)))
            xy = np.vstack([X.ravel(), Y.ravel()]).T
            xyz = np.array([self._embed(p) for p in xy])
            x = xyz[:, 0].reshape(X.shape)
            y = xyz[:, 1].reshape(X.shape)
            z = xyz[:, 2].reshape(X.shape)
            ax.plot_wireframe(x, y, z, color='gray', alpha=0.5)

            if alpha is not None:
                ax.plot_surface(x, y, z, color=cm.jet(0.), alpha=alpha)

    # plot x on the manifold. x can be either in coordinates or in R^3
    def plotx(self, x, u=None, v=None, N_vec=None, i0=0, color='b', color_intensity=1., linewidth=1., s=15., prevx=None, last=True):
        """ plot a point or a path of points, optionally with attached vectors

        Args:
            x: point (1d array) or path (2d array, one point per row)
            u: optional tangent vector drawn at the (first) point
            v: optional tangent vector(s), drawn when ``i0`` is in ``N_vec``
            N_vec: step indices at which ``v`` is drawn; defaults to
                ``np.arange(0,n_steps)``
            i0: index of this point along the path
            prevx: previous point; a line segment is drawn from it
            last: whether this is the last point of the path
        """
        if N_vec is None:
            # resolved at call time; a default evaluated in the signature would
            # run at class-definition time
            N_vec = np.arange(0,n_steps)

        if len(x.shape)>1:
            for i in range(x.shape[0]):
                self.plotx(x[i], u=u if i == 0 else None, v=v[i] if v is not None else None,
                           N_vec=N_vec,i0=i,
                           color=color,
                           color_intensity=color_intensity if i==0 or i==x.shape[0]-1 else .7,
                           linewidth=linewidth,
                           s=s,
                           prevx=x[i-1] if i>0 else None,
                           last=i==(x.shape[0]-1))
            return

        if self.emb_dim == 3:
            xcoords = x
            if x.shape[0] < 3: # map to embedding space
                x = self._embed(x)

            ax = plt.gca()
            if prevx is None or last:
                ax.scatter(x[0],x[1],x[2],color=color,s=s)
            if prevx is not None:
                if prevx.shape[0] < 3:
                    prevx = self._embed(prevx)
                xx = np.stack((prevx,x))
                ax.plot(xx[:,0],xx[:,1],xx[:,2],linewidth=linewidth,color=color)

            if u is not None:
                JFx = self._jacobian(xcoords)
                u = np.dot(JFx, u)
                ax.quiver(x[0], x[1], x[2], u[0], u[1], u[2],
                          pivot='tail',
                          arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                          color='black')

            if v is not None:
                if i0 in N_vec:
                    JFx = self._jacobian(xcoords)
                    v = np.dot(JFx, v)
                    ax.quiver(x[0], x[1], x[2], v[0], v[1], v[2],
                              pivot='tail',
                              arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                              color='black')
        elif self.dim == 2:
            if prevx is None or last:
                plt.scatter(x[0],x[1],color=color,s=s)
            if prevx is not None:
                xx = np.stack((prevx,x))
                plt.plot(xx[:,0],xx[:,1],linewidth=linewidth,color=color)
            if v is not None:
                if i0 in N_vec:
                    plt.quiver(x[0], x[1], v[0], v[1], pivot='tail', linewidth=linewidth, color='black',
                               angles='xy', scale_units='xy', scale=1)
class Manifold(object):
    """ Base manifold class """

    def __init__(self):
        self.dim = None
        if not hasattr(self, 'do_chart_update'):
            self.do_chart_update = None # set to relevant function if chart updates are desired

    def chart(self):
        """ return default or specified coordinate chart. This method will generally be overridden by inheriting classes """
        # default value
        return jnp.zeros(1)

    def centered_chart(self,coords):
        """ return centered coordinate chart. Must be implemented by inheriting classes
        Generally wish to stop gradient computations through the chart choice
        """
        return jax.lax.stop_gradient(jnp.zeros(1))

    def coords(self,coords=None,chart=None):
        """ return coordinate representation of point in manifold

        Defaults to the zero coordinate vector in the default chart.
        """
        if coords is None:
            coords = jnp.zeros(self.dim)
        if chart is None:
            chart = self.chart()

        return (jnp.array(coords),chart)

    def update_coords(self,coords,new_chart):
        """ change between charts """
        # raise explicitly: a bare assert is stripped when running with -O
        raise NotImplementedError("update_coords is not implemented for %s" % type(self).__name__)

    def update_vector(self,coords,new_coords,new_chart,v):
        """ change tangent vector between charts """
        raise NotImplementedError("update_vector is not implemented for %s" % type(self).__name__)

    def update_covector(self,coords,new_coords,new_chart,p):
        """ change cotangent vector between charts """
        raise NotImplementedError("update_covector is not implemented for %s" % type(self).__name__)

    def newfig(self):
        """ open new plot for manifold """

    def __str__(self):
        return "abstract manifold"

class EmbeddedManifold(Manifold):
    """ Embedded manifold base class

    Points are (coords,chart) tuples; ``F`` maps such a tuple to the embedding
    space and ``invF`` maps an (embedded point,chart) tuple back to coordinates.
    """

    def update_coords(self,coords,new_chart):
        """ change between charts """
        return (self.invF((self.F(coords),new_chart)),new_chart)

    def update_vector(self,coords,new_coords,new_chart,v):
        """ change tangent vector between charts """
        return jnp.tensordot(self.invJF((self.F((new_coords,new_chart)),new_chart)),jnp.tensordot(self.JF(coords),v,(1,0)),(1,0))

    def update_covector(self,coords,new_coords,new_chart,p):
        """ change cotangent vector between charts """
        return jnp.tensordot(self.JF((new_coords,new_chart)).T,jnp.tensordot(self.invJF((self.F(coords),coords[1])).T,p,(1,0)),(1,0))

    def __init__(self,F=None,dim=None,emb_dim=None,invF=None):
        """
        Args:
            F: embedding map taking a (coords,chart) tuple
            dim: manifold dimension
            emb_dim: dimension of the embedding space
            invF: inverse of F taking an (embedded point,chart) tuple; optional
        """
        Manifold.__init__(self)
        self.dim = dim
        self.emb_dim = emb_dim

        # embedding map and its inverse
        if F is not None:
            self.F = F
            self.invF = invF
            self.JF = jacfwdx(self.F)
            # only differentiate the inverse when one was supplied
            self.invJF = jacfwdx(self.invF) if self.invF is not None else None

            # metric matrix
            self.g = lambda x: jnp.tensordot(self.JF(x),self.JF(x),(0,0))

    def plot_path(self, xs, vs=None, v_steps=None, i0=0, color='b',
                  color_intensity=1., linewidth=1., s=15., prevx=None, prevchart=None, last=True):
        """ plot a path of points, optionally with vectors attached

        Args:
            xs: iterable of points, each a (coords,chart) tuple or an array.
                Coordinates longer than ``dim`` (and not of length ``emb_dim``)
                are read as coordinates followed by attached vectors.
            vs: optional vectors, one per point
            v_steps: step indices at which vectors are drawn; defaults to
                ``np.arange(0,n_steps)`` when ``vs`` is given
            prevx: point preceding the path; a segment is drawn from it to the
                first point
        """
        # default only when the caller did not choose the steps
        if vs is not None and v_steps is None:
            v_steps = np.arange(0,n_steps)

        xs = list(xs)
        N = len(xs)
        for i,x in enumerate(xs):
            xx = x[0] if type(x) is tuple else x
            if xx.shape[0] > self.dim and (self.emb_dim is None or xx.shape[0] != self.emb_dim): # attached vectors to display
                v = xx[self.dim:].reshape((self.dim,-1))
                x = (xx[0:self.dim],x[1]) if type(x) is tuple else xx[0:self.dim]
            elif vs is not None:
                v = vs[i]
            else:
                v = None
            self.plotx(x, v=v,
                       v_steps=v_steps,i=i,
                       color=color,
                       color_intensity=color_intensity if i==0 or i==N-1 else .7,
                       linewidth=linewidth,
                       s=s,
                       prevx=prevx,
                       last=i==(N-1))
            prevx = x
        return

    # plot x. x can be either in coordinates or in R^3
    def plotx(self, x, u=None, v=None, v_steps=None, i=0, color='b',
              color_intensity=1., linewidth=1., s=15., prevx=None, prevchart=None, last=True):
        """ plot a single point on the current 3d axes

        Args:
            x: (coords,chart) tuple or point in the embedding space
            u: optional tangent vector in coordinates, drawn at x
            v: optional vector (coordinates or embedding space), drawn when
                ``i`` is in ``v_steps``
            v_steps: step indices at which ``v`` is drawn; defaults to
                ``np.arange(0,n_steps)``
            i: index of this point along a path
            prevx: previous point; a line segment is drawn from it
            last: whether this is the last point of a path

        Raises:
            ValueError: if x is neither a tuple nor of length ``emb_dim``.
        """
        if not (isinstance(x,tuple) or x.shape[0] == self.emb_dim):
            raise ValueError("x must be a (coords,chart) tuple or a point in R^%s" % (self.emb_dim,))

        if v is not None and v_steps is None:
            v_steps = np.arange(0,n_steps)

        if isinstance(x,tuple): # map to manifold
            Fx = self.F(x)
            chart = x[1]
        else: # get coordinates
            Fx = x
            chart = self.centered_chart(Fx)
            x = (self.invF((Fx,chart)),chart)

        if prevx is not None:
            # only the embedded previous point is needed for drawing
            Fprevx = self.F(prevx) if isinstance(prevx,tuple) else prevx

        ax = plt.gca()
        if prevx is None or last:
            ax.scatter(Fx[0],Fx[1],Fx[2],color=color,s=s)
        if prevx is not None:
            xx = np.stack((Fprevx,Fx))
            ax.plot(xx[:,0],xx[:,1],xx[:,2],linewidth=linewidth,color=color)

        if u is not None:
            Fu = np.dot(self.JF(x), u)
            ax.quiver(Fx[0], Fx[1], Fx[2], Fu[0], Fu[1], Fu[2],
                      pivot='tail',
                      arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                      color='black')

        if v is not None:
            if i in v_steps:
                if not v.shape[0] == self.emb_dim:
                    v = np.dot(self.JF(x), v) # push forward to the embedding space
                ax.quiver(Fx[0], Fx[1], Fx[2], v[0], v[1], v[2],
                          pivot='tail',
                          arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                          color='black')

    def __str__(self):
        return "dim %d manifold embedded in R^%d" % (self.dim,self.emb_dim)
class Torus(EmbeddedManifold):
    """ 2d torus, embedded metric

    Points are (coords,chart) tuples of angle pairs; the embedding uses the
    sums ``coords+chart``, so a chart is an angular offset.
    """

    def chart(self):
        """ return default coordinate chart """
        return jnp.zeros(self.dim)

    def centered_chart(self,x):
        """ return centered coordinate chart

        ``x`` is either a (coords,chart) tuple or a point in R^3. The gradient
        is stopped through the chart choice in the tuple case.
        """
        if isinstance(x,tuple): # coordinate tuple
            Fx = jax.lax.stop_gradient(self.F(x))
        else:
            Fx = x # already in embedding space
        return self.invF((Fx,self.chart())) # chart centered at coords

    def get_B(self,v):
        """ R^3 basis with first basis vector v """
        b1 = v
        k = jnp.argmin(jnp.abs(v))
        ek = jnp.eye(3)[:,k]
        b2 = ek-v[k]*v
        b3 = cross(b1,b2)
        return jnp.stack((b1,b2,b3),axis=1)

    # Logarithm with standard Riemannian metric
    def StdLog(self,_x,y):
        """ coordinate difference y-x in the chart centered at _x, moved back to the chart of _x """
        (x,chart) = self.update_coords(_x,self.centered_chart(self.F(_x)))
        y = self.invF((y,chart))
        return self.update_vector((x,chart),_x[0],_x[1],y-x)

    def __init__(self,params=(1.,2.,jnp.array([0.,1.,0.]))):
        """
        Args:
            params: (radius of small circle, radius of large circle, axis of the torus)
        """
        self.radius = params[0] # axis of small circle
        self.Radius = params[1] # axis of large circle
        self.orientation = jnp.array(params[2]) # axis of torus

        F = lambda x: jnp.dot(self.get_B(self.orientation),
                              jnp.stack([self.radius*jnp.sin(x[0][1]+x[1][1]),
                                         (self.Radius+self.radius*jnp.cos(x[0][1]+x[1][1]))*jnp.cos(x[0][0]+x[1][0]),
                                         (self.Radius+self.radius*jnp.cos(x[0][1]+x[1][1]))*jnp.sin(x[0][0]+x[1][0])]))
        def invF(x):
            Rinvx = jnp.linalg.solve(self.get_B(self.orientation),x[0])
            # rotate by minus the chart angle to get the angle relative to the chart
            rotangle0 = -x[1][0]
            rot0 = jnp.dot(jnp.stack(
                (jnp.stack((jnp.cos(rotangle0),-jnp.sin(rotangle0))),
                 jnp.stack((jnp.sin(rotangle0),jnp.cos(rotangle0))))),
                Rinvx[1:])
            phi = jnp.arctan2(rot0[1],rot0[0])
            rotangle1 = -x[1][1]
            # Distance to the tube center circle. Divide by whichever of
            # cos(phi), sin(phi) is away from zero: testing |cos(phi)| keeps
            # phi near +-pi out of the sin branch, where sin(phi) is ~0.
            rcosphi = jax.lax.cond(jnp.abs(jnp.cos(phi)) >= 1e-4,
                                   lambda _: rot0[0]/jnp.cos(phi)-self.Radius,
                                   lambda _: rot0[1]/jnp.sin(phi)-self.Radius,operand=None)
            rot1 = jnp.dot(jnp.stack(
                (jnp.stack((jnp.cos(rotangle1),-jnp.sin(rotangle1))),
                 jnp.stack((jnp.sin(rotangle1),jnp.cos(rotangle1))))),
                jnp.stack((rcosphi,Rinvx[0])))
            theta = jnp.arctan2(rot1[1],rot1[0])
            return jnp.stack([phi,theta])
        self.do_chart_update = lambda x: jnp.max(jnp.abs(x[0])) >= np.pi/4 # look for a new chart if true

        EmbeddedManifold.__init__(self,F,2,3,invF=invF)

    def __str__(self):
        return "torus in R^3, radius %s, Radius %s, axis %s" % (self.radius,self.Radius,self.orientation)

    def newfig(self):
        """ open new 3d figure """
        newfig3d()

    def _surface_grid(self, nu, nv):
        """ embedded points on an nu x nv grid of angles in [-pi,pi]^2 """
        u, v = np.mgrid[-np.pi:np.pi:complex(0,nu), -np.pi:np.pi:complex(0,nv)]
        x = np.zeros(u.shape)
        y = np.zeros(u.shape)
        z = np.zeros(u.shape)
        for i in range(u.shape[0]):
            for j in range(u.shape[1]):
                w = self.F(self.coords(jnp.array([u[i,j],v[i,j]])))
                x[i,j] = w[0]; y[i,j] = w[1]; z[i,j] = w[2]
        return x, y, z

    def plot(self, rotate=None,alpha=None,lw=0.3):
        """ draw the torus as a wireframe on the current axes

        Args:
            rotate: optional (elev,azim) passed to ``view_init``
            alpha: if not None, also draw a surface with this opacity
            lw: grid line width
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])
        plt.xlabel('x')
        plt.ylabel('y')

        # draw torus
        x, y, z = self._surface_grid(20, 10)
        ax.plot_wireframe(x, y, z, color='gray', alpha=0.5)

        if alpha is not None:
            ax.plot_surface(x, y, z, color=cm.jet(0.), alpha=alpha)

    def plot_field(self, field,lw=.3):
        """ plot a vector field sampled on a grid of points on the torus

        ``field`` maps a (coords,chart) tuple to a tangent vector; each vector
        is drawn with ``plotx``.
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")

        plt.xlabel('x')
        plt.ylabel('y')

        x, y, z = self._surface_grid(40, 20)

        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                Fx = np.array([x[i,j],y[i,j],z[i,j]])
                chart = self.centered_chart(Fx)
                xcoord = self.invF((Fx,chart))
                vec = field((xcoord,chart))
                self.plotx((xcoord,chart),v=vec)
##########################################################################
# this file contains various object definitions, and standard parameters #
##########################################################################

# default integration times and time steps
# T: default total integration time
T = 1.
# n_steps: default number of time steps; the manifold plotting code also uses
# it as the default range of steps at which vectors are drawn
n_steps = 100

# Integrator variables:
# name of the default integration scheme ('rk4' is the alternative kept below)
default_method = 'euler'
#default_method = 'rk4'
18 | # 19 | 20 | -------------------------------------------------------------------------------- /jaxgeometry/sR/metric.py: -------------------------------------------------------------------------------- 1 | ## This file is part of Jax Geometry 2 | # 3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk) 4 | # https://bitbucket.org/stefansommer/jaxgeometry 5 | # 6 | # Jax Geometry is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Jax Geometry is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see https://www.gnu.org/licenses/.
def initialize(M):
    """ add sR structure to manifold

    currently assumes distribution and that ambient Riemannian manifold is Euclidean

    Requires ``M.D`` mapping a point to a matrix whose columns span the
    distribution. Sets ``M.sR_dim``, ``M.a`` and ``M.H`` (unless already
    present), the trivial embedding ``M.F``/``M.invF`` with Jacobians,
    ``M.sharp`` and ``M.div``.

    Raises:
        ValueError: if ``M`` has no distribution ``D``.
    """

    if not hasattr(M, 'D'):
        raise ValueError('no distribution defined on manifold')

    # rank of the distribution, read off at the zero coordinates
    M.sR_dim = M.D(M.coords(jnp.zeros(M.dim))).shape[1]

    if not hasattr(M,'a'):
        M.a = lambda x: mmT(M.D(x))
    else:
        print('using existing M.a')

    ### trivial embedding
    M.F = lambda x: x[0]
    # NOTE(review): returns a (coords,chart) tuple, unlike EmbeddedManifold.invF
    # which returns coordinates only -- confirm callers expect this
    M.invF = lambda x: (x,M.chart())
    M.JF = jacfwdx(M.F)
    M.invJF = jacfwdx(M.invF)

    ##### sharp map:
    M.sharp = lambda x,p: jnp.tensordot(M.a(x),p,(1,0))

    ##### Hamiltonian
    if not hasattr(M,'H'):
        # H(x,p) = 1/2 |D(x)^T p|^2 = 1/2 p^T a(x) p
        M.H = lambda x,p: .5*jnp.sum(jnp.einsum('i,ij->j',p,M.D(x))**2)
    else:
        print('using existing M.H')

    ##### divergence in divergence free orthonormal distribution
    M.div = lambda x,X: jnp.einsum('ij,ji->',jacfwdx(X)(x)[:M.sR_dim,:],M.D(x))
15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see . 18 | # 19 | 20 | 21 | import warnings 22 | warnings.simplefilter(action='ignore', category=FutureWarning) 23 | 24 | import numpy as np 25 | import scipy 26 | 27 | import jax 28 | import jax.numpy as jnp 29 | from jax import lax 30 | from jax import grad, jacfwd, jacrev, jit, vmap 31 | from jax import random 32 | from jax.scipy import optimize 33 | 34 | from functools import partial 35 | 36 | from jaxgeometry.utils import * 37 | 38 | import time 39 | 40 | from jaxgeometry.params import * 41 | 42 | import itertools 43 | from functools import partial 44 | 45 | -------------------------------------------------------------------------------- /jaxgeometry/statistics/Frechet_mean.py: -------------------------------------------------------------------------------- 1 | ## This file is part of Jax Geometry 2 | # 3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk) 4 | # https://bitbucket.org/stefansommer/jaxgeometry 5 | # 6 | # Jax Geometry is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Jax Geometry is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see . 
def initialize(M,Exp=None):
    """ add Frechet mean estimation to manifold M

    Sets ``M.Frechet_mean`` together with the objective/constraint helpers
    ``M.Frechet_mean_f``, ``M.Frechet_mean_vgv_c``, ``M.Frechet_mean_jacxv_c``,
    ``M.Frechet_mean_jacxv_f`` and ``M.Frechet_mean_vgx_f``.

    Args:
        M: manifold providing ``dim``, ``g`` and ``update_coords``
        Exp: exponential map ``Exp((x,chart),v) -> (xT,chartT)``; defaults to
            ``M.Exp``. If neither is available nothing is added to M.
    """

    try:
        if Exp is None:
            Exp = M.Exp
    except AttributeError:
        # no exponential map available: leave M untouched
        return

    # objective: squared length of v at (x,chart)
    def f(chart,x,v):
        return jnp.dot(v,jnp.dot(M.g((x,chart)),v))
    M.Frechet_mean_f = f

    # constraint: Exp_x(v) should hit y
    def _c(chart,x,v,y,ychart):
        # use the resolved Exp so that an explicitly passed map is honoured
        xT,chartT = Exp((x,chart),v)
        y_chartT = M.update_coords((y,ychart),chartT)
        return jnp.sqrt(M.dim)*(xT-y_chartT[0])
    def c(chart,x,v,y,ychart):
        return jnp.sum(jnp.square(_c(chart,x,v,y,ychart)))

    # derivatives
    M.Frechet_mean_vgv_c = jit(jax.value_and_grad(c,(2,)))
    M.Frechet_mean_jacxv_c = jit(jax.jacrev(_c,(1,2)))
    M.Frechet_mean_jacxv_f = jit(jax.value_and_grad(f,(1,2)))
    def vgx_f(chart,x,v,y,ychart):
        _jacxv_c = M.Frechet_mean_jacxv_c(chart,x,v,y,ychart)
        jacv = -jnp.linalg.solve(_jacxv_c[1],_jacxv_c[0]) # implicit function theorem

        v_f, g_f = M.Frechet_mean_jacxv_f(chart,x,v)
        g_f = g_f[0]+jnp.dot(g_f[1],jacv)

        return v_f, g_f
    M.Frechet_mean_vgx_f = vgx_f

    def Frechet_mean(ys,x0,Log=None,options=None):
        """ estimate the Frechet mean of ys starting from x0

        Args:
            ys: iterable of (coords,chart) data points
            x0: initial guess as (coords,chart)
            Log: optional logarithm ``Log((x,chart),y)`` whose first return
                value is the log vector; if given, BFGS on the Log-based
                objective is used, otherwise a combined Adam optimization over
                the mean and the per-sample vectors
            options: without Log, may contain 'step_sizex', 'step_sizevs',
                'num_steps' and 'optx_update_mod'; with Log, it is passed on to
                ``scipy.optimize.minimize``

        Returns:
            ``(m,value,steps,vs)`` without Log and ``(m,value,steps)`` with Log,
            where ``m`` is the estimated mean as (coords,chart) and ``steps``
            the tuple of visited iterates.
        """
        options = {} if options is None else options

        # data
        ys = list(ys) # make sure y is subscriptable
        N = len(ys)
        chart = x0[1] # single chart for now, could be updated

        if Log is None:
            # combined optimization, no use of Log maps
            step_sizex=options.get('step_sizex',1e-1)
            step_sizevs=options.get('step_sizevs',1e-1)
            num_steps=options.get('num_steps',200)
            optx_update_mod=options.get('optx_update_mod',5)

            opt_initx, opt_updatex, get_paramsx = optimizers.adam(step_sizex)
            opt_initvs, opt_updatevs, get_paramsvs = optimizers.adam(step_sizevs)

            # tracking steps
            steps = (x0,)

            def step(k, params, ys, y_charts, opt_statex, opt_statevs):
                paramsx = get_paramsx(opt_statex); paramsvs = get_paramsvs(opt_statevs)
                valuex = None

                valuevs,gradvs = jax.vmap(M.Frechet_mean_vgv_c,(None,None,0,0,0))(params,paramsx,paramsvs,ys,y_charts)
                # gradvs is a 1-tuple holding the (N,dim) gradient; indexing
                # keeps that shape also when N == 1 or dim == 1
                opt_statevs = opt_updatevs(k, gradvs[0], opt_statevs)
                if k % optx_update_mod == 0:
                    valuex,gradx = jax.vmap(M.Frechet_mean_vgx_f,(None,None,0,0,0))(params,paramsx,paramsvs,ys,y_charts)
                    valuex = jnp.mean(valuex,0); gradx = jnp.mean(gradx,0)
                    opt_statex = opt_updatex(k, gradx, opt_statex)
                return (valuex, valuevs), (opt_statex, opt_statevs)

            # optim setup
            params = x0[1]
            paramsx = x0[0]
            paramsvs = jnp.zeros((N,M.dim))
            opt_statex = opt_initx(paramsx)
            opt_statevs = opt_initvs(paramsvs)
            valuex = 0
            ys,y_charts=list(zip(*ys))
            ys = jnp.array(ys); y_charts = jnp.array(y_charts)

            for i in range(num_steps):
                (_valuex, valuevs), (opt_statex, opt_statevs) = step(i, params, ys, y_charts, opt_statex, opt_statevs)
                # None marks steps without an update of x; 0.0 is a valid value
                if _valuex is not None:
                    valuex = _valuex
                if i % 10 == 0:
                    print("Step {} | T: {:0.6e} | T: {:0.6e}".format(i, valuex, jnp.max(valuevs)))
                if i % optx_update_mod == 0:
                    steps += ((get_paramsx(opt_statex),chart),)
            if num_steps > 0:
                print("Step {} | T: {:0.6e} | T: {:0.6e} ".format(i, valuex, jnp.max(valuevs)))

            m = (get_paramsx(opt_statex),params)
            vs = get_paramsvs(opt_statevs)

            return (m,valuex,steps,vs)

        else:
            # Log based optimization
            def fopts(x):
                Logs = np.zeros((N, x.shape[0]))
                for i in range(N):
                    Logs[i] = Log((x,chart), ys[i])[0]

                res = (1. / N) * np.sum(np.square(Logs))
                grad = -(2. / N) * np.sum(Logs, 0)

                return (res, grad)

            # tracking steps (kept local to this call)
            steps = [x0]
            def save_step(k):
                steps.append((k,chart))

            res = minimize(fopts, x0[0], method='BFGS', jac=True, options=options, callback=save_step)

            return ((res.x,x0[1]), res.fun, tuple(steps))

    M.Frechet_mean = Frechet_mean
def initialize(M):
    """Attach the diffusion mean estimator M.diffusion_mean to the manifold M.

    Builds a guided Brownian motion through get_guided from
    M.sde_Brownian_coords and M.chart_update_Brownian_coords, and wraps
    iterative_mle so that it optimises the parameters with indices (0,5)
    of the resulting negative log-likelihood.
    """

    # guide function
    def phi(q,v,s):
        return jnp.tensordot((1/s)*jnp.linalg.cholesky(M.g(q)).T,
                             M.StdLog(q,M.F((v,q[1]))).flatten(),
                             (1,0))

    def A(x,v,w,s):
        return (s**(-2))*jnp.dot(v,jnp.dot(M.g(x),w))

    def logdetA(x,s):
        return jnp.linalg.slogdet(s**(-2)*M.g(x))[1]

    def sqrtCov(x,s):
        return s*jnp.linalg.cholesky(M.gsharp(x))

    # only the batched negative log-likelihood is used below
    (_, _, _, _, neg_log_p_Ts) = get_guided(
        M,M.sde_Brownian_coords,M.chart_update_Brownian_coords,phi,
        sqrtCov,A,logdetA)

    # optimization setup
    N = 1 # bridge samples per datapoint
    _dts = dts(n_steps=100,T=1.)

    # define parameters
    x = M.coords(jnp.zeros(M.dim))
    params_inds = (0,5)

    def _recenter(pos,chart):
        # move the position to a centered chart when the manifold asks for it
        if M.do_chart_update((pos,chart)):
            new_chart = M.centered_chart((pos,chart))
            (pos,chart) = M.update_coords((pos,chart),new_chart)
        return pos,chart

    # chart update for the position dependent parameter stored in the
    # optimizer state
    def params_update(state, chart):
        try:
            ((pos,m,v),),*rest = state
            pos,chart = _recenter(pos,chart)
            return optimizers.OptimizerState(((pos,m,v),),*rest),chart
        except ValueError: # state is packed
            states_flat, tree_def, subtree_defs = state
            ((pos,m,v),*rest) = states_flat
            pos,chart = _recenter(pos,chart)
            return (((pos,m,v),*rest),tree_def,subtree_defs),chart

    # the default starting point is drawn once, when initialize runs
    default_params = (x[0]+.1*np.random.normal(size=M.dim),jnp.array(.2,dtype="float32"))

    def diffusion_mean(samples,params=default_params,N=N,num_steps=80):
        return iterative_mle(samples,
                             neg_log_p_Ts,
                             params,params_inds,params_update,x[1],_dts,M,
                             N=N,num_steps=num_steps,step_size=1e-2)

    M.diffusion_mean = diffusion_mean
-------------------------------------------------------------------------------- /jaxgeometry/statistics/iterative_mle.py: -------------------------------------------------------------------------------- 1 | ## This file is part of Jax Geometry 2 | # 3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk) 4 | # https://bitbucket.org/stefansommer/jaxgeometry 5 | # 6 | # Jax Geometry is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Jax Geometry is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see . 
def iterative_mle(obss,neg_log_p_Ts,params,params_inds,params_update,chart,_dts,M,N=1,step_size=1e-1,num_steps=50):
    """Minimise neg_log_p_Ts over the selected parameters with Adam.

    Args:
        obss: observations, forwarded to neg_log_p_Ts; len(obss[0]) sets the
            number of noise realisations drawn per step.
        neg_log_p_Ts: objective, called as
            neg_log_p_Ts(params[0],chart,obss,dWs,_dts,*params[1:]).
        params: initial parameter tuple handed to the optimizer.
        params_inds: argument indices of neg_log_p_Ts to differentiate.
        params_update: callable (opt_state,chart) -> (opt_state,chart) applied
            after every optimizer update.
        chart: initial chart.
        _dts: time increments; _dts.shape[0] is the number of time steps.
        M: manifold object; only M.dim is read here.
        N: bridge samples per observation.
        step_size: Adam step size.
        num_steps: number of optimisation steps.

    Returns:
        (params, chart, last value, array of all values, tuple of per-step
        (*params, chart)). With num_steps == 0 the last value is None and the
        histories are empty.
    """
    opt_init, opt_update, get_params = optimizers.adam(step_size)
    vg = jax.value_and_grad(neg_log_p_Ts,params_inds)

    def _step(k, opt_state, chart):
        cur = get_params(opt_state)
        # fresh noise for every step, shaped (-1, time steps, N, M.dim)
        noise = dWs(len(obss[0])*N*M.dim,_dts).reshape(-1,_dts.shape[0],N,M.dim)
        value,grads = vg(cur[0],chart,obss,noise,_dts,*cur[1:])
        opt_state = opt_update(k, grads, opt_state)
        opt_state,chart = params_update(opt_state, chart)
        return (value,opt_state,chart)

    opt_state = opt_init(params)
    value = None # stays None when no step is taken
    values = []; paramss = []

    for i in range(num_steps):
        (value, opt_state, chart) = _step(i, opt_state, chart)
        values.append(value); paramss.append((*get_params(opt_state),chart))
        print("Step {} | T: {:0.6e} | T: {}".format(i, value, str((get_params(opt_state),chart))))
    if num_steps > 0: # i and value are only bound after at least one step
        print("Final {} | T: {:0.6e} | T: {}".format(i, value, str(get_params(opt_state))))

    return (get_params(opt_state),chart,value,jnp.array(values),tuple(paramss))
def tangent_PCA(M, Log, mean, ys):
    """Fit a PCA to the Log-map images of the data at the given mean.

    Args:
        M: manifold object (not used in this function).
        Log: logarithm map called as Log(mean,(y,chart)); its first return
            value is used.
        mean: base point forwarded to Log.
        ys: pair (coordinates, charts), mapped over their leading axis.

    Returns:
        The fitted PCA object, with the projected data stored in the extra
        attribute transformed_Logs.
    """
    def log_at_mean(y,chart):
        return Log(mean,(y,chart))[0]

    coords,charts = ys[0],ys[1]
    Logs = jax.vmap(log_at_mean)(coords,charts)

    print(Logs.shape)
    pca = PCA()
    pca.fit(Logs)
    pca.transformed_Logs = pca.transform(Logs)

    return pca
def initialize(M):
    """ Brownian motion in coordinates

    Attaches M.sde_Brownian_coords, M.chart_update_Brownian_coords and the
    jitted integrator M.Brownian_coords(x,dts,dWs,stdCov=1.) to M.
    """

    def sde_Brownian_coords(c,y):
        # c: (time, coordinates, chart, scaling s); y: (dt, dW)
        t,x,chart,s = c
        dt,dW = y

        gsharpx = M.gsharp((x,chart))
        X = s*jnp.linalg.cholesky(gsharpx)
        det = -.5*(s**2)*jnp.einsum('kl,ikl->i',gsharpx,M.Gamma_g((x,chart)))
        sto = jnp.tensordot(X,dW,(1,0))
        return (det,sto,X,0.)

    def chart_update_Brownian_coords(x,chart,*ys):
        # returns (x,chart,*ys), moved to a centered chart when
        # M.do_chart_update requests it
        if M.do_chart_update is None:
            return (x,chart,*ys)

        xchart = (x,chart)
        # do_chart_update gets the (coordinates, chart) pair like
        # centered_chart/update_coords below and like the other chart
        # updates (Eulerian, Langevin, diffusion mean); previously only the
        # bare coordinates were passed
        update = M.do_chart_update(xchart)
        new_chart = M.centered_chart(xchart)
        new_x = M.update_coords(xchart,new_chart)[0]

        # both branches are evaluated; jnp.where selects the result
        return (jnp.where(update,
                          new_x,
                          x),
                jnp.where(update,
                          new_chart,
                          chart),
                *ys)

    M.sde_Brownian_coords = sde_Brownian_coords
    M.chart_update_Brownian_coords = chart_update_Brownian_coords
    M.Brownian_coords = jit(lambda x,dts,dWs,stdCov=1.: integrate_sde(sde_Brownian_coords,integrator_ito,chart_update_Brownian_coords,x[0],x[1],dts,dWs,stdCov)[0:3])
def initialize(M):
    """ Brownian motion from stochastic development

    Attaches M.Brownian_development(x,dts,dWs) to M; it relies on
    M.stochastic_development being available when it is called.
    """

    def Brownian_development(x,dts,dWs):
        coords,chart = x[0],x[1]

        # initial frame bundle element: the point amended with a basis taken
        # from the Cholesky factor of the cometric
        frame = jnp.linalg.cholesky(M.gsharp(x))
        u0 = (jnp.concatenate((coords,frame.flatten())),chart)

        ts,us,charts = M.stochastic_development(u0,dts,dWs)

        # keep only the base point part of the frame bundle path
        return (ts,us[:,0:M.dim],charts)

    M.Brownian_development = Brownian_development
def initialize(G):
    """ Brownian motion with respect to left/right invariant metric

    Attaches G.sde_Brownian_inv and G.Brownian_inv(g,dts,dWt,sigma) to the
    group G. Only left invariance is accepted.
    """

    assert G.invariance == 'left'

    def sde_Brownian_inv(c,y):
        # c: (time, group element, unused, sigma); y: (dt, dW)
        _t,g,_unused,sigma = c
        _dt,dW = y

        # basis G.eiLA mapped by G.invpf at g, combined with sigma
        X = jnp.tensordot(G.invpf(g,G.eiLA),sigma,(2,0))
        # drift weights: G.C summed over the diagonal of its last two axes
        weights = jnp.diagonal(G.C,0,2).sum(1)
        det = -.5*jnp.tensordot(weights,X,(0,2))
        sto = jnp.tensordot(X,dW,(2,0))
        # sigma is constant along the path
        return (det,sto,X,jnp.zeros_like(sigma))

    G.sde_Brownian_inv = sde_Brownian_inv

    def Brownian_inv(g,dts,dWt,sigma=jnp.eye(G.dim)):
        path = integrate_sde(G.sde_Brownian_inv,integrator_stratonovich,None,g,None,dts,dWt,sigma)
        return path[0:3]

    G.Brownian_inv = Brownian_inv
def initialize(G):
    """ Brownian motion with respect to left/right invariant metric

    Attaches G.sde_Brownian_process and G.Brownian_process(g,dts,dWt,sigma)
    to the group G. Only left invariance is accepted.
    """

    assert G.invariance == 'left'

    def sde_Brownian_process(c,y):
        # c: (time, group element, unused, sigma); y: (dt, dW)
        _t,g,_unused,sigma = c
        _dt,dW = y

        # basis G.eiLA mapped by G.invpf at g, combined with sigma
        X = jnp.tensordot(G.invpf(g,G.eiLA),sigma,(2,0))
        drift = jnp.zeros_like(g) # this process has no drift term
        noise = jnp.tensordot(X,dW,(2,0))
        return (drift,noise,X,0.)

    G.sde_Brownian_process = sde_Brownian_process

    def Brownian_process(g,dts,dWt,sigma=jnp.eye(G.dim)):
        path = integrate_sde(G.sde_Brownian_process,integrator_stratonovich,None,g,None,dts,dWt,sigma)
        return path[0:3]

    G.Brownian_process = Brownian_process
def initialize(M):
    """ sub-Riemannian Brownian motion

    Attaches M.sde_Brownian_sR, M.chart_update_Brownian_sR and the jitted
    integrator M.Brownian_sR(x,dts,dWs,stdCov=1.) to M.
    """

    def sde_Brownian_sR(c,y):
        # c: (time, coordinates, chart, scaling s); y: (dt, dW)
        t,x,chart,s = c
        dt,dW = y

        D = s*M.D((x,chart))
        # D0 = \sum_{i=1}^m div_\mu(X_i) X_i) - not implemented yet
        det = jnp.zeros_like(x) # Y^k(x)=X_0^k(x)+(1/2)\sum_{i=1}^m \langle \nabla X_i^k(x),X_i(x)\rangle
        sto = jnp.tensordot(D,dW,(1,0))
        return (det,sto,D,0.)

    def chart_update_Brownian_sR(x,chart,*ys):
        # returns (x,chart,*ys), moved to a centered chart when
        # M.do_chart_update requests it
        if M.do_chart_update is None:
            return (x,chart,*ys)

        xchart = (x,chart)
        # do_chart_update gets the (coordinates, chart) pair like
        # centered_chart/update_coords below and like the other chart
        # updates (Eulerian, Langevin, diffusion mean); previously only the
        # bare coordinates were passed
        update = M.do_chart_update(xchart)
        new_chart = M.centered_chart(xchart)
        new_x = M.update_coords(xchart,new_chart)[0]

        # both branches are evaluated; jnp.where selects the result
        return (jnp.where(update,
                          new_x,
                          x),
                jnp.where(update,
                          new_chart,
                          chart),
                *ys)

    M.sde_Brownian_sR = sde_Brownian_sR
    M.chart_update_Brownian_sR = chart_update_Brownian_sR
    M.Brownian_sR = jit(lambda x,dts,dWs,stdCov=1.: integrate_sde(sde_Brownian_sR,integrator_ito,chart_update_Brownian_sR,x[0],x[1],dts,dWs,stdCov)[0:3])
###############################################################
# Eulerian / stochastic EPDiff acting on landmarks
###############################################################
def initialize(M,k=None):
    """Attach the stochastic Eulerian landmark flows M.Eulerian_qp and M.Eulerian.

    Args:
        M: manifold object; reads M.H, M.k, M.m, M.dim and the chart update
            helpers used below.
        k: kernel for the noise fields; defaults to M.k.
    """
    # derivatives of the Hamiltonian w.r.t. momentum and position
    dq = jit(grad(M.H,argnums=1))
    dp = jit(lambda q,p: -gradx(M.H)(q,p))

    # noise basis, use landmark kernel per default
    kernel = M.k if k is None else k

    def k_q(q1,q2):
        # kernel evaluated on all pairwise differences of the points
        pts1 = q1.reshape((-1,M.m))
        pts2 = q2.reshape((-1,M.m))
        return kernel(pts1[:,np.newaxis,:]-pts2[np.newaxis,:,:])

    def K(q1,q2):
        # kernel matrix with one M.m x M.m identity block per point pair
        blocks = k_q(q1,q2)[:,:,np.newaxis,np.newaxis]*jnp.eye(M.m)[np.newaxis,np.newaxis,:,:]
        return blocks.transpose((0,2,1,3)).reshape((M.dim,-1))

    def sde_Eulerian(c,y):
        # c: (time, stacked (q,p), chart, noise positions, noise amplitudes)
        t,x,chart,sigmas_x,sigmas_a = c
        dt,dW = y
        q,p = x[0],x[1]

        dqt = dq((q,chart),p)
        dpt = dp((q,chart),p)

        amplitudes = (sigmas_a*dW[:,np.newaxis]).flatten()

        def noise_at(lq):
            return jnp.tensordot(K(lq,sigmas_x),amplitudes,(1,0))

        sigmadWq = noise_at(q)
        # momentum noise: Jacobian of the position noise contracted with p
        noise_jac = jax.jacrev(lambda lq: noise_at(lq).flatten())(q)
        sigmadWp = jnp.tensordot(noise_jac,p,(1,0))

        X = None # to be implemented
        det = jnp.stack((dqt,dpt))
        sto = jnp.stack((sigmadWq,sigmadWp))
        return (det,sto,X,jnp.zeros_like(sigmas_x),jnp.zeros_like(sigmas_a))

    def chart_update_Eulerian(xp,chart,*cy):
        if M.do_chart_update is None:
            return (xp,chart,*cy)

        q,p = xp[0],xp[1]
        x = (q,chart)

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(M.F(x))
        new_x = M.update_coords(x,new_chart)[0]
        new_xp = jnp.stack((new_x,M.update_covector(x,new_x,new_chart,p)))

        # both branches are evaluated; jnp.where selects the result
        return (jnp.where(update,new_xp,xp),
                jnp.where(update,new_chart,chart),
                *cy)

    def Eulerian_qp(q,p,sigmas_x,sigmas_a,dts,dWs):
        return integrate_sde(sde_Eulerian,integrator_stratonovich,chart_update_Eulerian,jnp.stack((q[0],p)),q[1],dts,dWs,sigmas_x,sigmas_a)

    M.Eulerian_qp = Eulerian_qp
    M.Eulerian = lambda q,p,sigmas_x,sigmas_a,dts,dWs: M.Eulerian_qp(q,p,sigmas_x,sigmas_a,dts,dWs)[0:3]
###############################################################
# Langevin equations https://arxiv.org/abs/1605.09276
###############################################################
def initialize(M):
    """Attach the Langevin dynamics M.Langevin_qp and M.Langevin to M.

    Both take (q,p,l,s,dts,dWt); l and s are carried through the integration
    unchanged. M.Langevin returns the first three outputs of M.Langevin_qp.
    """
    # derivatives of the Hamiltonian w.r.t. momentum and position
    dq = jit(grad(M.H,argnums=1))
    dp = jit(lambda q,p: -gradx(M.H)(q,p))

    def sde_Langevin(c,y):
        # c: (time, stacked (q,p), chart, l, s); y: (dt, dW)
        t,x,chart,l,s = c
        dt,dW = y
        q,p = x[0],x[1]

        dqt = dq((q,chart),p)
        # Hamiltonian force minus the l-scaled velocity term
        dpt = dp((q,chart),p)-l*dqt

        # noise enters the momentum equation only
        X = jnp.stack((jnp.zeros((M.dim,M.dim)),s*jnp.eye(M.dim)))
        det = jnp.stack((dqt,dpt))
        sto = jnp.tensordot(X,dW,(1,0))
        return (det,sto,X,jnp.zeros_like(l),jnp.zeros_like(s))

    def chart_update_Langevin(xp,chart,*cy):
        if M.do_chart_update is None:
            return (xp,chart,*cy)

        q,p = xp[0],xp[1]
        x = (q,chart)

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(M.F(x))
        new_x = M.update_coords(x,new_chart)[0]
        new_xp = jnp.stack((new_x,M.update_covector(x,new_x,new_chart,p)))

        # both branches are evaluated; jnp.where selects the result
        return (jnp.where(update,new_xp,xp),
                jnp.where(update,new_chart,chart),
                *cy)

    def Langevin_qp(q,p,l,s,dts,dWt):
        return integrate_sde(sde_Langevin,integrator_ito,chart_update_Langevin,jnp.stack((q[0],p)),q[1],dts,dWt,l,s)

    M.Langevin_qp = Langevin_qp

    M.Langevin = lambda q,p,l,s,dts,dWt: M.Langevin_qp(q,p,l,s,dts,dWt)[0:3]
either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Jax Geometry is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see . 18 | # 19 | 20 | 21 | -------------------------------------------------------------------------------- /jaxgeometry/stochastics/diagonal_conditioning.py: -------------------------------------------------------------------------------- 1 | ## This file is part of Jax Geometry 2 | # 3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk) 4 | # https://bitbucket.org/stefansommer/jaxgeometry 5 | # 6 | # Jax Geometry is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # Jax Geometry is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Jax Geometry. If not, see . 
18 | # 19 | 20 | from jaxgeometry.setup import * 21 | from jaxgeometry.utils import * 22 | 23 | def initialize(M,sde_product,chart_update_product,integrator=integrator_ito,T=1): 24 | """ diagonally conditioned product diffusions """ 25 | 26 | def sde_diagonal(c,y): 27 | if M.do_chart_update is None: 28 | t,x,chart,T,*cy = c 29 | else: 30 | t,x,chart,T,ref_chart,*cy = c 31 | dt,dW = y 32 | 33 | (det,sto,X,*dcy) = sde_product((t,x,chart,*cy),y) 34 | 35 | if M.do_chart_update is None: 36 | xref = x 37 | else: 38 | xref = jax.vmap(lambda x,chart: M.update_coords((x,chart),ref_chart)[0],0)(x,chart) 39 | m = jnp.mean(xref,0) # mean 40 | href = jax.lax.cond(t. 18 | # 19 | 20 | from jaxgeometry.setup import * 21 | from jaxgeometry.utils import * 22 | 23 | ####################################################################### 24 | # guided processes, Delyon/Hu 2006 # 25 | ####################################################################### 26 | 27 | # hit target v at time t=Tend 28 | def get_guided(M,sde,chart_update,phi,sqrtCov=None,A=None,logdetA=None,method='DelyonHu',integration='ito'): 29 | """ guided diffusions """ 30 | 31 | def sde_guided(c,y): 32 | t,x,chart,log_likelihood,log_varphi,T,v,*cy = c 33 | xchart = (x,chart) 34 | dt,dW = y 35 | 36 | (det,sto,X,*dcy) = sde((t,x,chart,*cy),y) 37 | 38 | h = jax.lax.cond(t. 
def initialize(M,sde,chart_update,integrator=integrator_ito):
    """ product diffusions

    Lifts an SDE and its chart update to a batch of processes by mapping
    over the leading axis of the state, the charts, the noise and any extra
    carried values.

    Returns:
        (product, sde_product, chart_update_product), where
        product(x,dts,dWs,*cy) is the jitted integrator of the batched SDE.
    """

    def sde_product(c,y):
        t,x,chart,*cy = c
        dt,dW = y

        # time and dt are shared; everything else is mapped per component
        def single(xi,charti,dWi,*cyi):
            return sde((t,xi,charti,*cyi),(dt,dWi))

        (det,sto,X,*dcy) = jax.vmap(single,0)(x,chart,dW,*cy)

        return (det,sto,X,*dcy)

    chart_update_product = jax.vmap(chart_update)

    def _integrate(x,dts,dWs,*cy):
        return integrate_sde(sde_product,integrator,chart_update_product,x[0],x[1],dts,dWs,*cy)

    product = jit(_integrate)

    return (product,sde_product,chart_update_product)

# for initializing parameters
def tile(x,N):
    """Repeat x N times: arrays get a new leading axis, other sequences are tiled element-wise."""
    try:
        # array-like with .ndim: stack N copies along a new first axis
        return jnp.tile(x,(N,)+(1,)*x.ndim)
    except AttributeError:
        try:
            return jnp.tile(x,N)
        except TypeError:
            # nested structure: recurse into the elements
            return tuple(tile(item,N) for item in x)
def initialize(G,Psi=None,r=None):
    """ stochastic coadjoint motion with left/right invariant metric
    see Noise and dissipation on coadjoint orbits arXiv:1601.02249 [math.DS]
    and EulerPoincare.py

    Attaches G.sde_stochastic_coadjoint, G.stochastic_coadjoint(mu,dts,dWt)
    and G.stochastic_coadjointrec to the group G. When Psi is None a linear
    map with identity matrix is used and r is set to G.dim.
    """

    assert G.invariance == 'left'

    # Matrix function Psi:LA\rightarrow R^r must be defined beforehand
    # example here from arXiv:1601.02249
    if Psi is None:
        sigmaPsi = jnp.eye(G.dim)
        def Psi(mu):
            return jnp.dot(sigmaPsi,mu)
        # r = Psi.shape[0]
        r = G.dim
    assert Psi is not None and r is not None

    def sde_stochastic_coadjoint(c,y):
        # c: (time, mu, unused); y: (dt, dW)
        _t,mu,_unused = c
        _dt,dW = y

        drift = -G.coad(G.invFl(mu),mu)
        dPsi = jax.jacrev(Psi)(mu).transpose((1,0))
        Sigma = G.coad(mu,dPsi)
        noise = jnp.tensordot(Sigma,dW,(1,0))
        return (drift,noise,Sigma)

    G.sde_stochastic_coadjoint = sde_stochastic_coadjoint

    def stochastic_coadjoint(mu,dts,dWt):
        return integrate_sde(G.sde_stochastic_coadjoint,integrator_stratonovich,None,mu,None,dts,dWt)

    G.stochastic_coadjoint = stochastic_coadjoint

    # reconstruction as in Euler-Poincare / Lie-Poisson reconstruction
    if not hasattr(G,'EPrec'):
        from jaxgeometry.group import EulerPoincare
        EulerPoincare.initialize(G)
    G.stochastic_coadjointrec = G.EPrec
def initialize(M):
    """ development and stochastic development from R^d to M

    Attaches M.development(u,dgamma,dts), M.sde_development and
    M.stochastic_development(u,dts,dWs) to M. The state u holds M.dim point
    coordinates followed by the flattened frame.
    """

    def horizontal_fields(u,chart):
        # the frame part of u fixes how many horizontal fields are used
        frame = u[M.dim:].reshape((M.dim,-1))
        m = frame.shape[1]
        return M.Horizontal((u,chart))[:,0:m]

    # Deterministic development
    def ode_development(c,y):
        t,u,chart = c
        dgamma, = y
        return jnp.tensordot(horizontal_fields(u,chart),dgamma,(1,0))

    M.development = jit(lambda u,dgamma,dts: integrate(ode_development,M.chart_update_FM,u[0],u[1],dts,dgamma))

    # Stochastic development
    def sde_development(c,y):
        t,u,chart = c
        dt,dW = y

        H = horizontal_fields(u,chart)
        sto = jnp.tensordot(H,dW,(1,0))

        # no drift term
        return (jnp.zeros_like(sto), sto, H)

    M.sde_development = sde_development
    M.stochastic_development = jit(lambda u,dts,dWs: integrate_sde(sde_development,integrator_stratonovich,M.chart_update_FM,u[0],u[1],dts,dWs))
# Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Jax Geometry is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
#

from jaxgeometry.setup import *
from jaxgeometry.params import *

#######################################################################
# various useful functions                                            #
#######################################################################

def _differentiate_x(transform,f):
    """ Apply a jax derivative transform to f w.r.t. the coordinates only.

    f takes a pair (x,chart) as first argument. The returned function has
    the same calling convention, but differentiates with respect to x alone;
    the chart and all remaining arguments are passed through unchanged.
    """
    def fxchart(x,chart,*args,**kwargs):
        return f((x,chart),*args,**kwargs)
    def df(x,*args,**kwargs):
        return transform(fxchart,argnums=0)(x[0],x[1],*args,**kwargs)
    return df

# jax.grad but only for the x variable of a function taking coordinates and chart
def gradx(f):
    """ jax.grad of f w.r.t. x, where f takes (x,chart) as first argument. """
    return _differentiate_x(jax.grad,f)

# jax.jacfwd but only for the x variable of a function taking coordinates and chart
def jacfwdx(f):
    """ jax.jacfwd of f w.r.t. x, where f takes (x,chart) as first argument. """
    return _differentiate_x(jax.jacfwd,f)

# jax.jacrev but only for the x variable of a function taking coordinates and chart
def jacrevx(f):
    """ jax.jacrev of f w.r.t. x, where f takes (x,chart) as first argument. """
    return _differentiate_x(jax.jacrev,f)

# hessian only for the x variable of a function taking coordinates and chart
def hessianx(f):
    """ Forward-over-reverse Hessian of f w.r.t. x only. """
    return jacfwdx(jacrevx(f))

# evaluation with pass through derivatives
def straight_through(f,x,*ys):
    """ Evaluate f(x,*ys) while letting derivatives pass straight through to x.

    The returned value equals f(x,*ys), but its derivative with respect to x
    is the identity. If x is a tuple, this is done componentwise; f is then
    expected to return a tuple with one entry per entry of x.
    """
    # Create an exactly-zero expression with Sterbenz lemma that has
    # an exactly-one gradient.
    if isinstance(x,tuple):
        zeros = tuple(xi - jax.lax.stop_gradient(xi) for xi in x)
        fx = jax.lax.stop_gradient(f(x,*ys))
        # value comes from f, gradient from the matching component of x
        return tuple(zero + fxi for zero,fxi in zip(zeros,fx))
    zero = x - jax.lax.stop_gradient(x)
    return zero + jax.lax.stop_gradient(f(x,*ys))


# time increments, deterministic
def dts(T=T,n_steps=n_steps):
    """ Return n_steps equal time increments that sum to T. """
    return jnp.array([T/n_steps]*n_steps)

# standard noise realisations
# (module-level PRNG state: every call to dWs advances `key`)
seed = 42
key = jax.random.PRNGKey(seed)
def dWs(d,_dts=None,num=1):
    """ Sample Gaussian increments scaled by sqrt(dt).

    d:    dimension of each increment
    _dts: time increments; dts() is used when None
    num:  number of independent realisations

    Returns an array of shape (len(_dts),d) for num == 1, and of shape
    (num,len(_dts),d) otherwise. Advances the module-level `key`.
    """
    global key
    keys = jax.random.split(key,num=num+1)
    key = keys[0]
    subkeys = keys[1:]
    # identity test: `== None` is elementwise (or ambiguous) for array input
    if _dts is None:
        _dts = dts()
    if num == 1:
        return jnp.sqrt(_dts)[:,None]*random.normal(subkeys[0],(_dts.shape[0],d))
    else:
        return jax.vmap(lambda subkey: jnp.sqrt(_dts)[:,None]*random.normal(subkey,(_dts.shape[0],d)))(subkeys)

# Integrator (deterministic)
def integrator(ode_f,chart_update=None,method=default_method):
    """ Build a single-step function for lax.scan from an ODE right-hand side.

    ode_f(c,ys) is called with the carry c = (t,x,chart) and the per-step
    inputs after dt. chart_update(x,chart,ys) must return the pair
    (x,chart); when None, x and chart are passed through unchanged.
    method is 'euler' or 'rk4'; any other value raises ValueError.

    The step function maps (carry,(dt,*ys)) to (new_carry,new_carry).
    """
    if chart_update is None: # no chart update
        chart_update = lambda *args: args[0:2]

    # euler:
    def euler(c,y):
        t,x,chart = c
        dt,*_ = y
        return ((t+dt,*chart_update(x+dt*ode_f(c,y[1:]),chart,y[1:])),)*2

    # Runge-kutta:
    def rk4(c,y):
        t,x,chart = c
        dt,*_ = y
        k1 = ode_f(c,y[1:])
        k2 = ode_f((t+dt/2,x + dt/2*k1,chart),y[1:])
        k3 = ode_f((t+dt/2,x + dt/2*k2,chart),y[1:])
        # the last stage of classical RK4 is evaluated at the end of the step
        k4 = ode_f((t+dt,x + dt*k3,chart),y[1:])
        return ((t+dt,*chart_update(x + dt/6*(k1 + 2*k2 + 2*k3 + k4),chart,y[1:])),)*2

    if method == 'euler':
        return euler
    elif method == 'rk4':
        return rk4
    else:
        # explicit error: an assert would be stripped under `python -O`
        raise ValueError("unknown integration method: %r (expected 'euler' or 'rk4')" % (method,))

# return symbolic path given ode and integrator
def integrate(ode,chart_update,x,chart,dts,*ys):
    """ Integrate ode from (x,chart) at t=0 over the time increments dts.

    ys are additional per-step inputs scanned together with dts. Returns the
    scanned (ts,xs,charts), or only (ts,xs) when chart_update is None.
    """
    _,xs = lax.scan(integrator(ode,chart_update),
            (0.,x,chart),
            (dts,*ys))
    return xs if chart_update is not None else xs[0:2]

# sde functions should return (det,sto,Sigma) where
# det is deterministic part, sto is stochastic part,
# and Sigma stochastic generator (i.e. often sto=dot(Sigma,dW))


def integrate_sde(sde,integrator,chart_update,x,chart,dts,dWs,*cy):
    """ Integrate sde from (x,chart,*cy) at t=0 given increments dts and dWs.

    integrator is a factory such as integrator_stratonovich or
    integrator_ito. cy are additional carried values. Returns the scanned
    states (ts,xs,charts,*cys).
    """
    _,xs = lax.scan(integrator(sde,chart_update),
            (0.,x,chart,*cy),
            (dts,dWs,))
    return xs

def integrator_stratonovich(sde_f,chart_update=None):
    """ Build an Euler-Heun step function (Stratonovich) for lax.scan.

    sde_f(c,y) returns (det,sto,Sigma,*dcy) for carry c = (t,x,chart,*cy)
    and step input y = (dt,dW). chart_update(x,chart,*cy) returns the
    updated (x,chart,*cy); when None, the values pass through unchanged.
    """
    if chart_update is None: # no chart update
        chart_update = lambda xp,chart,*cy: (xp,chart,*cy)

    def euler_heun(c,y):
        t,x,chart,*cy = c
        dt,_ = y

        detx,stox,_,*dcy = sde_f(c,y)
        # predictor: state advanced by the stochastic increment only
        tx = x + stox
        # extra carried values are advanced with a plain Euler step
        cy_new = tuple(ci + dt*dci for ci,dci in zip(cy,dcy))
        # corrector: stochastic part re-evaluated at the predicted state
        stox_pred = sde_f((t+dt,tx,chart,*cy),y)[1]
        return ((t+dt,*chart_update(x + dt*detx + 0.5*(stox + stox_pred), chart, *cy_new)),)*2

    return euler_heun

def integrator_ito(sde_f,chart_update=None):
    """ Build an Euler step function (Ito) for lax.scan.

    Same conventions for sde_f and chart_update as integrator_stratonovich.
    """
    if chart_update is None: # no chart update
        chart_update = lambda xp,chart,*cy: (xp,chart,*cy)

    def euler(c,y):
        t,x,chart,*cy = c
        dt,_ = y

        detx,stox,_,*dcy = sde_f(c,y)
        cy_new = tuple(ci + dt*dci for ci,dci in zip(cy,dcy))
        return ((t+dt,*chart_update(x + dt*detx + stox, chart, *cy_new)),)*2

    return euler


def cross(a, b):
    """ Cross product of two 3-vectors, built from indexed components. """
    return jnp.array([
        a[1]*b[2] - a[2]*b[1],
        a[2]*b[0] - a[0]*b[2],
        a[0]*b[1] - a[1]*b[0]])

def mmT(A,C=None):
    """ Return A A^T, or A C A^T when C is given. """
    return A@A.T if C is None else A@C@A.T

def mTm(A,C=None):
    """ Return A^T A, or A^T C A when C is given. """
    return A.T@A if C is None else A.T@C@A

#import numpy as np
#def python_scan(f, init, xs, length=None):
#    if xs is None:
#        xs = [None] * length
#    carry = init
#    ys = []
#    for i in range(xs[0].shape[0]):
#        x =
(xs[0][i],) 185 | # carry, y = f(carry, x) 186 | # ys.append(y) 187 | # return carry, np.stack(ys) 188 | # 189 | -------------------------------------------------------------------------------- /logo/stocso31s.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ComputationalEvolutionaryMorphometry/jaxgeometry/70e22f05966d5076adab37d9579a4900f486485f/logo/stocso31s.jpg -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test 2 | 3 | test: 4 | tests/test.sh 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "jaxdifferentialgeometry" 3 | version = "0.9.4" 4 | authors = [ 5 | { name="Stefan Sommer", email="sommer@di.ku.dk" }, 6 | ] 7 | description = "Differential geometry using jax" 8 | readme = "README.md" 9 | requires-python = ">=3.8" 10 | classifiers = [ 11 | "Programming Language :: Python :: 3", 12 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 13 | "Operating System :: OS Independent", 14 | ] 15 | dependencies = [ 16 | "jax", 17 | "jaxlib", 18 | "scikit-learn", 19 | ] 20 | 21 | [project.optional-dependencies] 22 | examples = [ 23 | "jupyter", 24 | "matplotlib", 25 | ] 26 | 27 | [project.urls] 28 | Homepage = "https://github.com/computationalevolutionarymorphometry/jaxgeometry/" 29 | Issues = "https://github.com/computationalevolutionarymorphometry/jaxgeometry/issues" 30 | 31 | [build-system] 32 | requires = [ 33 | "setuptools>=60", 34 | "setuptools-scm>=8.0", 35 | ] 36 | build-backend = "setuptools.build_meta" 37 | 38 | [tool.setuptools_scm] 39 | version_file = "jaxgeometry/_version.py" 40 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | jax 2 | jaxlib 3 | matplotlib 4 | jupyter 5 | scikit-learn 6 | flax 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /tests/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # install dependencies 4 | pip install pytest nbmake 5 | 6 | # run tests 7 | PYTHONPATH=$(pwd) pytest --nbmake --nbmake-timeout=12000 examples/*.ipynb papers/*.ipynb 8 | --------------------------------------------------------------------------------