├── .env
├── .gitignore
├── LICENSE
├── README.md
├── examples
├── Euclidean.ipynb
├── H2.ipynb
├── Heisenberg.ipynb
├── S2.ipynb
├── S2_statistics.ipynb
├── SO3.ipynb
├── SO3_stochastics.ipynb
├── SPD3.ipynb
├── T2.ipynb
├── cylinder.ipynb
├── diffusion_mean.ipynb
├── landmarks.ipynb
├── landmarks_stochastics.ipynb
└── lifted_stochastics.ipynb
├── jaxgeometry
├── Riemannian
│ ├── Log.py
│ ├── __init__.py
│ ├── curvature.py
│ ├── geodesic.py
│ ├── metric.py
│ └── parallel_transport.py
├── __init__.py
├── dynamics
│ ├── Hamiltonian.py
│ ├── MPP_Kunita.py
│ ├── MPP_Kunita_Log.py
│ ├── MPP_group.py
│ ├── MPP_landmarks.py
│ ├── MPP_landmarks_Log.py
│ ├── flow.py
│ └── flow_differential.py
├── framebundle
│ ├── FM.py
│ └── MPP.py
├── group
│ ├── EulerPoincare.py
│ ├── LiePoisson.py
│ ├── energy.py
│ ├── invariant_metric.py
│ └── quotient.py
├── groups
│ ├── GLN.py
│ ├── SON.py
│ ├── __init__.py
│ └── group.py
├── manifolds
│ ├── Euclidean.py
│ ├── H2.py
│ ├── Heisenberg.py
│ ├── S2.py
│ ├── SPDN.py
│ ├── cylinder.py
│ ├── ellipsoid.py
│ ├── landmarks.py
│ ├── latent.py
│ ├── manifold.py
│ └── torus.py
├── params.py
├── plotting.py
├── sR
│ ├── __init__.py
│ └── metric.py
├── setup.py
├── statistics
│ ├── Frechet_mean.py
│ ├── diffusion_mean.py
│ ├── iterative_mle.py
│ └── tangent_PCA.py
├── stochastics
│ ├── Brownian_coords.py
│ ├── Brownian_development.py
│ ├── Brownian_inv.py
│ ├── Brownian_process.py
│ ├── Brownian_sR.py
│ ├── Eulerian.py
│ ├── Langevin.py
│ ├── __init__.py
│ ├── diagonal_conditioning.py
│ ├── guided_process.py
│ ├── product_sde.py
│ ├── stochastic_coadjoint.py
│ └── stochastic_development.py
└── utils.py
├── logo
└── stocso31s.jpg
├── makefile
├── papers
├── Heisenberg_score.ipynb
├── SO3_S2_Most_probable_paths_and_development.ipynb
├── most_probable_landmark_paths.ipynb
└── most_probable_transformation_for_Kunita_flows.ipynb
├── pyproject.toml
├── requirements.txt
├── setup.py
└── tests
└── test.sh
/.env:
--------------------------------------------------------------------------------
1 | PYTHONPATH=../src
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Add any directories, files, or patterns you don't want to be tracked by version control
2 | .ipynb_checkpoints
3 | __pycache__
4 | *.swp
5 | *.pyc
6 | old/
7 | pyvenv.cfg
8 | *.egg-info/
9 | dist/
10 |
11 | bin/
12 | include/
13 | lib/
14 | etc/
15 | pip-*
16 | share/
17 | backup/
18 |
19 | # misc
20 | kent_distribution
21 | contributed
22 |
23 | # output figures
24 | *.pdf
25 | *.png
26 |
27 | # temporary files
28 | .tmp/
29 | .DS_Store
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # Jax Geometry #
4 |
5 | The code in this repository is based on the papers *Differential geometry and stochastic dynamics with deep learning numerics* [arXiv:1712.08364](https://arxiv.org/abs/1712.08364) and *Computational Anatomy in Theano* [arXiv:1706.07690](https://arxiv.org/abs/1706.07690).
6 |
7 | The code is a reimplementation of the Theano Geometry library https://bitbucket.org/stefansommer/theanogeometry/ replacing Theano with Jax https://github.com/google/jax.
8 |
9 | The source repository is at https://bitbucket.org/stefansommer/jaxgeometry/
10 |
11 | ### Who do I talk to? ###
12 |
13 | Please contact Stefan Sommer *sommer@di.ku.dk*
14 |
15 | ### Installation Instructions ###
16 |
17 | Please use Python 3.X.
18 |
19 | #### pip:
20 | Install with
21 | ```
22 | pip install jaxdifferentialgeometry
23 | ```
24 |
25 | #### from the repository:
26 | Check out the source with git and install the required packages:
27 | ```
28 | pip install -r requirements.txt
29 | ```
30 | Use e.g. a Python 3 virtualenv:
31 | ```
32 | virtualenv -p python3 .
33 | source bin/activate
34 | pip install -r requirements.txt
35 | ```
36 | If you don't use a virtual environment, make sure that you are actually using Python 3, e.g. use pip3 instead of pip.
37 |
38 | Alternatively, use conda:
39 | ```
40 | conda install -c conda-forge jaxlib
41 | conda install -c conda-forge jax
42 | ```
43 | and similarly for the remaining requirements in requirements.txt.
44 |
45 | ### Viewing the example notebooks
46 | After cloning the source repository, start jupyter notebook
47 | ```
48 | PYTHONPATH=$(pwd)/src jupyter notebook
49 | ```
50 | Your browser should now open and you can find the example Jax Geometry notebooks in the examples folder.
51 |
52 | ### Why Jax? ###
53 | Some good discussions about the architectural differences between autodiff frameworks: https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/ and http://www.stochasticlifestyle.com/engineering-trade-offs-in-automatic-differentiation-from-tensorflow-and-pytorch-to-jax-and-julia/
54 |
--------------------------------------------------------------------------------
/examples/SO3.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-02-18T11:05:48.736791Z",
9 | "start_time": "2021-02-18T11:05:48.732894Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "## This file is part of Jax Geometry\n",
15 | "#\n",
16 | "# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)\n",
17 | "# https://bitbucket.org/stefansommer/jaxgeometry\n",
18 | "#\n",
19 | "# Jax Geometry is free software: you can redistribute it and/or modify\n",
20 | "# it under the terms of the GNU General Public License as published by\n",
21 | "# the Free Software Foundation, either version 3 of the License, or\n",
22 | "# (at your option) any later version.\n",
23 | "#\n",
24 | "# Jax Geometry is distributed in the hope that it will be useful,\n",
25 | "# but WITHOUT ANY WARRANTY; without even the implied warranty of\n",
26 | "# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n",
27 | "# GNU General Public License for more details.\n",
28 | "#\n",
29 | "# You should have received a copy of the GNU General Public License\n",
30 | "# along with Jax Geometry. If not, see .\n",
31 | "#"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "# SO(3) group operations and dynamics"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "%load_ext autoreload\n",
48 | "%autoreload 2"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {
55 | "ExecuteTime": {
56 | "end_time": "2021-02-18T11:06:54.715340Z",
57 | "start_time": "2021-02-18T11:05:48.760484Z"
58 | }
59 | },
60 | "outputs": [],
61 | "source": [
62 | "from jaxgeometry.groups.SON import *\n",
63 | "G = SON(3)\n",
64 | "print(G)\n",
65 | "from jaxgeometry.plotting import *\n",
66 | "#%matplotlib notebook"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {
73 | "ExecuteTime": {
74 | "end_time": "2021-02-18T11:07:03.517081Z",
75 | "start_time": "2021-02-18T11:06:54.718336Z"
76 | }
77 | },
78 | "outputs": [],
79 | "source": [
80 | "# visualization\n",
81 | "newfig()\n",
82 | "G.plotg(G.e)\n",
83 | "plt.show()\n",
84 | "\n",
85 | "# geodesics in three directions\n",
86 | "v=jnp.array([1,0,0])\n",
87 | "xiv=G.VtoLA(v)\n",
88 | "(ts,gsv) = G.expt(xiv)\n",
89 | "newfig()\n",
90 | "G.plot_path(gsv)\n",
91 | "plt.show()\n",
92 | "\n",
93 | "v=jnp.array([0,1,0])\n",
94 | "xiv=G.VtoLA(v)\n",
95 | "(ts,gsv) = G.expt(xiv)\n",
96 | "newfig()\n",
97 | "G.plot_path(gsv)\n",
98 | "plt.show()\n",
99 | "\n",
100 | "v=jnp.array([0,0,1])\n",
101 | "xiv=G.VtoLA(v)\n",
102 | "(ts,gsv) = G.expt(xiv)\n",
103 | "newfig()\n",
104 | "G.plot_path(gsv)\n",
105 | "plt.show()"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "metadata": {
112 | "ExecuteTime": {
113 | "end_time": "2021-02-18T11:07:10.319275Z",
114 | "start_time": "2021-02-18T11:07:03.520364Z"
115 | }
116 | },
117 | "outputs": [],
118 | "source": [
119 | "# plot path on S2\n",
120 | "from jaxgeometry.manifolds.S2 import *\n",
121 | "M = S2()\n",
122 | "print(M)\n",
123 | "\n",
124 | "# plot\n",
125 | "newfig()\n",
126 | "M.plot()\n",
127 | "x = M.F(M.coords([0.,0.]))\n",
128 | "M.plot_path(M.acts(gsv,x))\n",
129 | "plt.show()"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {
136 | "ExecuteTime": {
137 | "end_time": "2021-02-18T11:08:17.862357Z",
138 | "start_time": "2021-02-18T11:07:10.321261Z"
139 | }
140 | },
141 | "outputs": [],
142 | "source": [
143 | "# setup for testing different versions of dynamics\n",
144 | "q = jnp.array([1e-3,0.,0.])\n",
145 | "g = G.psi(q)\n",
146 | "v = jnp.array([0.,1.,1.])\n",
147 | "\n",
148 | "from jaxgeometry.group import invariant_metric\n",
149 | "invariant_metric.initialize(G)\n",
150 | "p = G.sharppsi(q,v)\n",
151 | "mu = G.sharpV(v)\n",
152 | "print(p)\n",
153 | "print(mu)\n",
154 | "\n",
155 | "from jaxgeometry.group import energy\n",
156 | "energy.initialize(G)"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {
163 | "ExecuteTime": {
164 | "end_time": "2021-02-18T11:08:36.457769Z",
165 | "start_time": "2021-02-18T11:08:17.864547Z"
166 | }
167 | },
168 | "outputs": [],
169 | "source": [
170 | "# Euler-Poincare dynamics\n",
171 | "from jaxgeometry.group import EulerPoincare\n",
172 | "EulerPoincare.initialize(G)# Euler-Poincare dynamics\n",
173 | "\n",
174 | "# geodesic\n",
175 | "(ts,gsv) = G.ExpEPt(G.psi(q),v)\n",
176 | "newfig()\n",
177 | "G.plot_path(gsv)\n",
178 | "plt.show()\n",
179 | "(ts,musv) = G.EP(mu)\n",
180 | "xisv = [G.invFl(mu) for mu in musv]\n",
181 | "print(\"Energy: \",np.array([G.l(xi) for xi in xisv]))\n",
182 | "print(\"Orthogonality: \",np.array([np.linalg.norm(np.dot(g,g.T)-np.eye(int(np.sqrt(G.emb_dim))),np.inf) for g in gsv]))\n",
183 | "\n",
184 | "# on S2\n",
185 | "newfig()\n",
186 | "M.plot(rotate=(30,-15))\n",
187 | "x = jnp.array([0,0,1])\n",
188 | "M.plot_path(M.acts(gsv,x))\n",
189 | "plt.show()"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "metadata": {
196 | "ExecuteTime": {
197 | "end_time": "2021-02-18T11:08:50.590252Z",
198 | "start_time": "2021-02-18T11:08:36.459751Z"
199 | }
200 | },
201 | "outputs": [],
202 | "source": [
203 | "# Lie-Poisson dynamics\n",
204 | "from jaxgeometry.group import LiePoisson\n",
205 | "LiePoisson.initialize(G)\n",
206 | "\n",
207 | "# geodesic\n",
208 | "(ts,gsv) = G.ExpLPt(G.psi(q),v)\n",
209 | "newfig()\n",
210 | "G.plot_path(gsv)\n",
211 | "plt.show()\n",
212 | "(ts,musv) = G.LP(mu)\n",
213 | "print(\"Energy: \",np.array([G.Hminus(mu) for mu in musv]))\n",
214 | "print(\"Orthogonality: \",np.array([np.linalg.norm(np.dot(g,g.T)-np.eye(int(np.sqrt(G.emb_dim))),np.inf) for g in gsv]))"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "metadata": {
221 | "ExecuteTime": {
222 | "end_time": "2021-02-18T11:12:05.111339Z",
223 | "start_time": "2021-02-18T11:08:50.592810Z"
224 | }
225 | },
226 | "outputs": [],
227 | "source": [
228 | "# Hamiltonian dynamics\n",
229 | "from jaxgeometry.dynamics import Hamiltonian\n",
230 | "Hamiltonian.initialize(G)\n",
231 | "\n",
232 | "# test Hamiltionian and gradients\n",
233 | "print(p)\n",
234 | "print(G.H(q,p))\n",
235 | "\n",
236 | "# geodesic\n",
237 | "qsv,_ = G.Exp_Hamiltoniant((q,None),p)\n",
238 | "gsv = np.array([G.psi(q) for q in qsv])\n",
239 | "newfig()\n",
240 | "G.plot_path(gsv)\n",
241 | "plt.show()\n",
242 | "(ts,qpsv,_) = G.Hamiltonian_dynamics((q,None),p,dts())\n",
243 | "psv = qpsv[:,1,:]\n",
244 | "print(\"Energy: \",np.array([G.H(q,p) for (q,p) in zip(qsv,psv)]))"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "metadata": {},
251 | "outputs": [],
252 | "source": []
253 | }
254 | ],
255 | "metadata": {
256 | "kernelspec": {
257 | "display_name": "Python 3 (ipykernel)",
258 | "language": "python",
259 | "name": "python3"
260 | },
261 | "language_info": {
262 | "codemirror_mode": {
263 | "name": "ipython",
264 | "version": 3
265 | },
266 | "file_extension": ".py",
267 | "mimetype": "text/x-python",
268 | "name": "python",
269 | "nbconvert_exporter": "python",
270 | "pygments_lexer": "ipython3",
271 | "version": "3.12.1"
272 | }
273 | },
274 | "nbformat": 4,
275 | "nbformat_minor": 4
276 | }
277 |
--------------------------------------------------------------------------------
/examples/SO3_stochastics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-02-18T19:20:32.491118Z",
9 | "start_time": "2021-02-18T19:20:32.488147Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "## This file is part of Jax Geometry\n",
15 | "#\n",
16 | "# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)\n",
17 | "# https://bitbucket.org/stefansommer/jaxgeometry\n",
18 | "#\n",
19 | "# Jax Geometry is free software: you can redistribute it and/or modify\n",
20 | "# it under the terms of the GNU General Public License as published by\n",
21 | "# the Free Software Foundation, either version 3 of the License, or\n",
22 | "# (at your option) any later version.\n",
23 | "#\n",
24 | "# Jax Geometry is distributed in the hope that it will be useful,\n",
25 | "# but WITHOUT ANY WARRANTY; without even the implied warranty of\n",
26 | "# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n",
27 | "# GNU General Public License for more details.\n",
28 | "#\n",
29 | "# You should have received a copy of the GNU General Public License\n",
30 | "# along with Jax Geometry. If not, see .\n",
31 | "#"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {
37 | "collapsed": true
38 | },
39 | "source": [
40 | "# Stochastic Lie group dynamics"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "%load_ext autoreload\n",
50 | "%autoreload 2"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {
57 | "ExecuteTime": {
58 | "end_time": "2021-02-18T19:21:18.542521Z",
59 | "start_time": "2021-02-18T19:20:32.493368Z"
60 | }
61 | },
62 | "outputs": [],
63 | "source": [
64 | "# SO(3)\n",
65 | "from jaxgeometry.groups.SON import *\n",
66 | "G = SON(3)\n",
67 | "print(G)\n",
68 | "\n",
69 | "# SO(3) acts on S^2\n",
70 | "from jaxgeometry.manifolds.S2 import *\n",
71 | "M = S2()\n",
72 | "print(M)\n",
73 | "\n",
74 | "from jaxgeometry.plotting import *"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {
81 | "ExecuteTime": {
82 | "end_time": "2021-02-18T19:22:13.618982Z",
83 | "start_time": "2021-02-18T19:21:18.544893Z"
84 | }
85 | },
86 | "outputs": [],
87 | "source": [
88 | "# setup for testing different versions of stochastic dynamics\n",
89 | "q = jnp.array([1e-3,0.,0.])\n",
90 | "g = G.psi(q)\n",
91 | "v = jnp.array([0.,1.,1.])\n",
92 | "\n",
93 | "x = M.coords(jnp.array([0.,0.]))\n",
94 | "\n",
95 | "from jaxgeometry.group import invariant_metric\n",
96 | "invariant_metric.initialize(G)\n",
97 | "p = G.sharppsi(q,v)\n",
98 | "mu = G.sharpV(v)\n",
99 | "print(p)\n",
100 | "print(mu)\n",
101 | "\n",
102 | "from jaxgeometry.group import energy\n",
103 | "energy.initialize(G)"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {
110 | "ExecuteTime": {
111 | "end_time": "2021-02-18T19:22:36.715889Z",
112 | "start_time": "2021-02-18T19:22:13.621269Z"
113 | },
114 | "scrolled": false
115 | },
116 | "outputs": [],
117 | "source": [
118 | "# Brownian motion\n",
119 | "from jaxgeometry.stochastics import Brownian_inv\n",
120 | "Brownian_inv.initialize(G)\n",
121 | "\n",
122 | "_dts = dts(n_steps=1000)\n",
123 | "(ts,gs,_) = G.Brownian_inv(g,_dts,dWs(G.dim,_dts))\n",
124 | "\n",
125 | "# plot\n",
126 | "newfig()\n",
127 | "G.plot_path(gs,linewidth=0.1,alpha=0.1)\n",
128 | "plt.show()\n",
129 | "#plt.savefig('stocso3.pdf')\n",
130 | "\n",
131 | "# on S2\n",
132 | "newfig()\n",
133 | "M.plot()\n",
134 | "M.plot_path(M.acts(gs,M.F(x)))\n",
135 | "plt.show()\n",
136 | "#plt.savefig('stocso32.pdf')"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {
143 | "ExecuteTime": {
144 | "end_time": "2021-02-18T19:22:53.566779Z",
145 | "start_time": "2021-02-18T19:22:36.717802Z"
146 | }
147 | },
148 | "outputs": [],
149 | "source": [
150 | "# Brownian processes\n",
151 | "from jaxgeometry.stochastics import Brownian_process\n",
152 | "Brownian_process.initialize(G)\n",
153 | "\n",
154 | "_dts = dts(n_steps=1000)\n",
155 | "(ts,gs,_) = G.Brownian_process(g,_dts,dWs(G.dim,_dts))\n",
156 | "\n",
157 | "# plot\n",
158 | "newfig()\n",
159 | "G.plot_path(gs,color_intensity=1,alpha=0.1)\n",
160 | "plt.show()"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "# Euler-Poincare dynamics\n",
170 | "from jaxgeometry.group import EulerPoincare\n",
171 | "EulerPoincare.initialize(G)# Euler-Poincare dynamics"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {
178 | "ExecuteTime": {
179 | "end_time": "2021-02-18T19:23:24.463449Z",
180 | "start_time": "2021-02-18T19:22:53.568769Z"
181 | },
182 | "scrolled": false
183 | },
184 | "outputs": [],
185 | "source": [
186 | "# Stochastic coadjoint motion\n",
187 | "from jaxgeometry.stochastics import stochastic_coadjoint\n",
188 | "stochastic_coadjoint.initialize(G)\n",
189 | "\n",
190 | "_dts = dts(n_steps=1000)\n",
191 | "(ts,mus,_) = G.stochastic_coadjoint(mu,_dts,dWs(G.dim,_dts))\n",
192 | "(ts,gs) = G.stochastic_coadjointrec(g,mus,_dts)\n",
193 | "\n",
194 | "# plot\n",
195 | "newfig()\n",
196 | "G.plot_path(gs,color_intensity=1,alpha=0.1)\n",
197 | "plt.show()\n",
198 | "#plt.savefig('coadgeo.pdf')"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "metadata": {
205 | "ExecuteTime": {
206 | "end_time": "2021-02-18T19:24:29.056415Z",
207 | "start_time": "2021-02-18T19:23:24.465976Z"
208 | },
209 | "scrolled": false
210 | },
211 | "outputs": [],
212 | "source": [
213 | "# Delyon/Hu guided process\n",
214 | "from jaxgeometry.stochastics.guided_process import *\n",
215 | "\n",
216 | "# guide function\n",
217 | "phi = lambda g,v,sigma: jnp.tensordot(G.inv(sigma),G.LAtoV(G.log(G.invtrns(G.inv(g[0]),v))),(1,0))\n",
218 | "A = lambda g,v,w,sigma: G.gG(g[0],v,w,sigma)\n",
219 | "logdetA = lambda x,sigma: -2*(jnp.linalg.slogdet(sigma)[1])\n",
220 | "\n",
221 | "(Brownian_inv_guided,sde_Brownian_coords_inv,_,_,_) = get_guided(\n",
222 | " G,G.sde_Brownian_inv,None,phi,\n",
223 | " lambda g,sigma: sigma,A,logdetA,integration='stratonovich')\n",
224 | "\n",
225 | "_dts = dts(n_steps=1000)\n",
226 | "(ts,gs,_,log_likelihood,log_varphi) = Brownian_inv_guided((g,None),G.psi(v),_dts,dWs(G.dim,_dts),\n",
227 | " jnp.sqrt(.2)*jnp.eye(G.dim))\n",
228 | "print(\"log likelihood: \", log_likelihood[-1], \", log varphi: \", log_varphi[-1])\n",
229 | "\n",
230 | "newfig()\n",
231 | "w = G.psi(v)\n",
232 | "G.plot_path(gs,linewidth=0.1,alpha=0.1)\n",
233 | "G.plotg(w,color='k')\n",
234 | "plt.show()\n",
235 | "\n",
236 | "# on S2\n",
237 | "newfig()\n",
238 | "M.plot()\n",
239 | "M.plot_path(M.acts(gs,M.F(x)))\n",
240 | "M.plotx(M.act(w,M.F(x)),color='k')\n",
241 | "plt.show()"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "metadata": {},
248 | "outputs": [],
249 | "source": []
250 | }
251 | ],
252 | "metadata": {
253 | "kernelspec": {
254 | "display_name": "Python 3 (ipykernel)",
255 | "language": "python",
256 | "name": "python3"
257 | },
258 | "language_info": {
259 | "codemirror_mode": {
260 | "name": "ipython",
261 | "version": 3
262 | },
263 | "file_extension": ".py",
264 | "mimetype": "text/x-python",
265 | "name": "python",
266 | "nbconvert_exporter": "python",
267 | "pygments_lexer": "ipython3",
268 | "version": "3.12.1"
269 | }
270 | },
271 | "nbformat": 4,
272 | "nbformat_minor": 1
273 | }
274 |
--------------------------------------------------------------------------------
/examples/SPD3.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "## This file is part of Jax Geometry\n",
10 | "#\n",
11 | "# Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)\n",
12 | "# https://bitbucket.org/stefansommer/jaxgeometry\n",
13 | "#\n",
14 | "# Jax Geometry is free software: you can redistribute it and/or modify\n",
15 | "# it under the terms of the GNU General Public License as published by\n",
16 | "# the Free Software Foundation, either version 3 of the License, or\n",
17 | "# (at your option) any later version.\n",
18 | "#\n",
19 | "# Jax Geometry is distributed in the hope that it will be useful,\n",
20 | "# but WITHOUT ANY WARRANTY; without even the implied warranty of\n",
21 | "# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n",
22 | "# GNU General Public License for more details.\n",
23 | "#\n",
24 | "# You should have received a copy of the GNU General Public License\n",
25 | "# along with Jax Geometry. If not, see .\n",
26 | "#"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "collapsed": true
33 | },
34 | "source": [
35 | "# GLN and SPDN dynamics"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "%load_ext autoreload\n",
45 | "%autoreload 2"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {
52 | "scrolled": false
53 | },
54 | "outputs": [],
55 | "source": [
56 | "from jaxgeometry.groups.GLN import *\n",
57 | "G = GLN(3)\n",
58 | "print(G)\n",
59 | "\n",
60 | "from jaxgeometry.manifolds.SPDN import *\n",
61 | "M = SPDN(3)\n",
62 | "print(M)\n",
63 | "\n",
64 | "from jaxgeometry.plotting import *\n",
65 | "figsize = 12,12\n",
66 | "plt.rcParams['figure.figsize'] = figsize"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "# some values\n",
76 | "v=np.array([.5,0,0,0,0,0,0,0,0])+1e-6*np.random.normal(size=G.dim) # must be non-singular for Expm derivative\n",
77 | "xiv=G.VtoLA(v)\n",
78 | "x = G.exp(xiv)"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {
85 | "scrolled": false
86 | },
87 | "outputs": [],
88 | "source": [
89 | "# visualization\n",
90 | "newfig()\n",
91 | "G.plotg(x)\n",
92 | "plt.show()\n",
93 | "\n",
94 | "_dts = dts()\n",
95 | "gsv = np.zeros((_dts.shape[0],3,3))\n",
96 | "for i in range(_dts.shape[0]):\n",
97 | " gsv[i] = G.exp(_dts[i]*xiv)\n",
98 | "newfig()\n",
99 | "G.plot_path(gsv)\n",
100 | "plt.show()\n",
101 | "\n",
102 | "# on SPD(3)\n",
103 | "newfig()\n",
104 | "M.plot()\n",
105 | "x0 = np.eye(M.N).flatten()\n",
106 | "M.plot_path(M.acts(gsv,x0))\n",
107 | "plt.show()\n",
108 | "\n",
109 | "# ellipsoids\n",
110 | "plt.rcParams['figure.figsize'] = 23, 10\n",
111 | "M.plot_path(M.acts(gsv,x0),ellipsoid={'alpha': .2, 'step': _dts.shape[0]/4, 'subplot': True})\n",
112 | "plt.show()\n",
113 | "plt.rcParams['figure.figsize'] = figsize"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "# define invariant metric on GL(N)\n",
123 | "from jaxgeometry.group import invariant_metric\n",
124 | "invariant_metric.initialize(G)\n",
125 | "from jaxgeometry.group import energy\n",
126 | "energy.initialize(G)"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {
133 | "scrolled": false
134 | },
135 | "outputs": [],
136 | "source": [
137 | "# Euler-Poincare dynamics\n",
138 | "from jaxgeometry.group import EulerPoincare\n",
139 | "EulerPoincare.initialize(G)\n",
140 | "\n",
141 | "# geodesic\n",
142 | "(ts,gsv) = G.ExpEPt(G.e,v)\n",
143 | "newfig()\n",
144 | "G.plot_path(gsv)\n",
145 | "plt.show()\n",
146 | "(ts,musv) = G.EP(v)\n",
147 | "xisv = [G.invFl(mu) for mu in musv]\n",
148 | "print(\"Energy: \",np.array([G.l(xi) for xi in xisv]))\n",
149 | "\n",
150 | "# on SPD(3)\n",
151 | "newfig()\n",
152 | "M.plot()\n",
153 | "x0 = np.eye(M.N).flatten()\n",
154 | "M.plot_path(M.acts(gsv,x0))\n",
155 | "plt.show()\n",
156 | "\n",
157 | "# ellipsoids\n",
158 | "plt.rcParams['figure.figsize'] = 23, 10\n",
159 | "M.plot_path(M.acts(gsv,x0),ellipsoid={'alpha': .2, 'step': dts().shape[0]/4, 'subplot': True})\n",
160 | "plt.show()\n",
161 | "plt.rcParams['figure.figsize'] = figsize"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "# Lie-Poission dynamics\n",
171 | "from jaxgeometry.group import LiePoisson\n",
172 | "LiePoisson.initialize(G)\n",
173 | "\n",
174 | "# geodesic\n",
175 | "(ts,gsv) = G.ExpLPt(G.e,v)\n",
176 | "newfig()\n",
177 | "G.plot_path(gsv)\n",
178 | "plt.show()\n",
179 | "(ts,musv) = G.LP(v)\n",
180 | "print(\"Energy: \",np.array([G.Hminus(mu) for mu in musv]))"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "metadata": {},
187 | "outputs": [],
188 | "source": [
189 | "# Brownian motion\n",
190 | "from jaxgeometry.stochastics import Brownian_inv\n",
191 | "Brownian_inv.initialize(G)\n",
192 | "\n",
193 | "_dts = dts(n_steps=100)\n",
194 | "(ts,gs,_) = G.Brownian_inv(G.e,_dts,dWs(G.dim,_dts),jnp.sqrt(.1)*jnp.eye(G.emb_dim))\n",
195 | "\n",
196 | "# on SPD(3)\n",
197 | "newfig()\n",
198 | "M.plot()\n",
199 | "x0 = np.eye(M.N).flatten()\n",
200 | "M.plot_path(M.acts(gs,x0))\n",
201 | "plt.show()\n",
202 | "\n",
203 | "# ellipsoids\n",
204 | "plt.rcParams['figure.figsize'] = 23, 10\n",
205 | "M.plot_path(M.acts(gs,x0),ellipsoid={'alpha': .2, 'step': _dts.shape[0]/8, 'subplot': True})\n",
206 | "# plt.savefig('SPD3-path.pdf')\n",
207 | "plt.show()\n",
208 | "plt.rcParams['figure.figsize'] = figsize"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": null,
214 | "metadata": {},
215 | "outputs": [],
216 | "source": []
217 | }
218 | ],
219 | "metadata": {
220 | "kernelspec": {
221 | "display_name": "Python 3 (ipykernel)",
222 | "language": "python",
223 | "name": "python3"
224 | },
225 | "language_info": {
226 | "codemirror_mode": {
227 | "name": "ipython",
228 | "version": 3
229 | },
230 | "file_extension": ".py",
231 | "mimetype": "text/x-python",
232 | "name": "python",
233 | "nbconvert_exporter": "python",
234 | "pygments_lexer": "ipython3",
235 | "version": "3.9.14"
236 | }
237 | },
238 | "nbformat": 4,
239 | "nbformat_minor": 1
240 | }
241 |
--------------------------------------------------------------------------------
/jaxgeometry/Riemannian/Log.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(M, f=None, lossf=None, method='BFGS', maxiter=100):
    """ numerical Riemannian Logarithm map

    Installs ``M.Log(x, y, v0=None)``. It searches for a vector ``v`` such that
    shooting with ``f`` from ``x`` along ``v`` reaches ``y``. The search
    minimizes a discrepancy loss with ``jax.scipy.optimize.minimize``.
    ``M.Log`` returns the pair ``(v, final loss value)``.

    Args:
        M: manifold object. ``M.dim`` and ``M.update_coords`` are used, and
            ``M.Exp`` is used when ``f`` is None.
        f: shooting map ``(x, v) -> (x1, chart1)``. Defaults to ``M.Exp``.
        lossf: optional loss ``(x1, y1) -> scalar``, where ``y1`` is the
            target expressed in ``chart1`` coordinates. Defaults to the sum
            of squared coordinate differences divided by ``M.dim``.
        method: optimizer name passed to ``jax.scipy.optimize.minimize``.
        maxiter: maximum number of optimizer iterations (default 100).
    """

    if f is None:
        print("using M.Exp for Logarithm")
        f = M.Exp

    def loss(x, v, y):
        """Discrepancy between the endpoint f(x, v) and the target y."""
        (x1, chart1) = f(x, v)
        # Re-express the target in the endpoint's chart so that the
        # coordinates being compared live in the same chart.
        y_chart1 = M.update_coords(y, chart1)
        if lossf is None:
            return 1. / M.dim * jnp.sum(jnp.square(x1 - y_chart1[0]))
        return lossf(x1, y_chart1[0])

    from jax.scipy.optimize import minimize

    def shoot(x, y, v0=None):
        """Return (v, loss): the optimized shooting vector and its final loss.

        The optimization starts from v0, or from the zero vector of length
        M.dim when v0 is None.
        """
        if v0 is None:
            v0 = jnp.zeros(M.dim)
        res = minimize(lambda w: loss(x, w, y), v0, method=method,
                       options={'maxiter': maxiter})
        return (res.x, res.fun)

    M.Log = shoot
51 |
--------------------------------------------------------------------------------
/jaxgeometry/Riemannian/__init__.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 |
--------------------------------------------------------------------------------
/jaxgeometry/Riemannian/curvature.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 | def initialize(M):
24 | """ Riemannian curvature """
25 |
26 | """
27 | Riemannian Curvature tensor
28 |
29 | Args:
30 | x: point on manifold
31 |
32 | Returns:
33 | 4-tensor R_ijk^l in with order i,j,k,l
34 | (see e.g. https://en.wikipedia.org/wiki/List_of_formulas_in_Riemannian_geometry#(3,1)_Riemann_curvature_tensor )
35 | Note that sign convention follows e.g. Lee, Riemannian Manifolds.
36 | """
37 | M.R = jit(lambda x: -(jnp.einsum('pik,ljp->ijkl',M.Gamma_g(x),M.Gamma_g(x))
38 | -jnp.einsum('pjk,lip->ijkl',M.Gamma_g(x),M.Gamma_g(x))
39 | +jnp.einsum('likj->ijkl',M.DGamma_g(x))
40 | -jnp.einsum('ljki->ijkl',M.DGamma_g(x))))
41 |
42 | """
43 | Riemannian Curvature form
44 | R_u (also denoted Omega) is the gl(n)-valued curvature form u^{-1}Ru for a frame
45 | u for T_xM
46 |
47 | Args:
48 | x: point on manifold
49 |
50 | Returns:
51 | 4-tensor (R_u)_ij^m_k with order i,j,m,k
52 | """
53 | M.R_u = jit(lambda x,u: jnp.einsum('ml,ijql,qk->ijmk',jnp.linalg.inv(u),R(x),u))
54 |
55 | # """
56 | # Sectional curvature
57 | #
58 | # Args:
59 | # x: point on manifold
60 | # e1,e2: two orthonormal vectors spanning the section
61 | #
62 | # Returns:
63 | # sectional curvature K(e1,e2)
64 | # """
65 | # @jit
66 | # def sec_curv(x,e1,e2):
67 | # Rflat = jnp.tensordot(M.R(x),M.g(x),[3,0])
68 | # sec = jnp.tensordot(
69 | # jnp.tensordot(
70 | # jnp.tensordot(
71 | # jnp.tensordot(
72 | # Rflat,
73 | # e1, [0,0]),
74 | # e2, [0,0]),
75 | # e2, [0,0]),
76 | # e1, [0,0])
77 | # return sec
78 | # M.sec_curv = sec_curv
79 |
80 | """
81 | Ricci curvature
82 |
83 | Args:
84 | x: point on manifold
85 |
86 | Returns:
87 | 2-tensor R_ij in order i,j
88 | """
89 | M.Ricci_curv = jit(lambda x: jnp.einsum('kijk->ij',M.R(x)))
90 |
91 | """
92 | Scalar curvature
93 |
94 | Args:
95 | x: point on manifold
96 |
97 | Returns:
98 | scalar curvature
99 | """
100 | M.S_curv = jit(lambda x: jnp.einsum('ij,ij->',M.gsharp(x),M.Ricci_curv(x)))
101 |
102 |
--------------------------------------------------------------------------------
/jaxgeometry/Riemannian/geodesic.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(M):
    """Add geodesic integration to ``M``: ``M.geodesic``, ``M.Exp`` and ``M.Expt``."""

    def ode_geodesic(c, y):
        # The integrated state stacks coordinates (row 0) and velocity (row 1),
        # giving the first-order system x' = v, v^i' = -Gamma^i_kl v^k v^l.
        _, xv, chart = c
        position, velocity = xv[0], xv[1]
        Gamma = M.Gamma_g((position, chart))
        acceleration = -jnp.einsum('ikl,k,l->i', Gamma, velocity, velocity)
        return jnp.stack((velocity, acceleration))

    def chart_update_geodesic(xv, chart, y):
        # Without a chart-update hook, state and chart pass through unchanged.
        if M.do_chart_update is None:
            return (xv, chart)

        point = (xv[0], chart)
        switch = M.do_chart_update(point)
        target_chart = M.centered_chart(point)
        moved_x = M.update_coords(point, target_chart)[0]
        moved_v = M.update_vector(point, moved_x, target_chart, xv[1])

        # Select branch-free so the function stays traceable under jit.
        new_state = jnp.where(switch, jnp.stack((moved_x, moved_v)), xv)
        new_chart = jnp.where(switch, target_chart, chart)
        return (new_state, new_chart)

    def _geodesic(x, v, dts):
        # x is a (coords, chart) pair; v is the initial velocity in that chart.
        return integrate(ode_geodesic, chart_update_geodesic,
                         jnp.stack((x[0], v)), x[1], dts)
    M.geodesic = jit(_geodesic)

    def Exp(x, v, T=T, n_steps=n_steps):
        """Final ``(coords, chart)`` of the geodesic from ``x`` with velocity ``v``,
        integrated over the steps ``dts(T, n_steps)``."""
        curve = M.geodesic(x, v, dts(T, n_steps))
        states, charts = curve[1], curve[2]
        return (states[-1, 0], charts[-1])
    M.Exp = Exp

    def Expt(x, v, T=T, n_steps=n_steps):
        """Coordinates and charts along the whole integrated geodesic."""
        curve = M.geodesic(x, v, dts(T, n_steps))
        states, charts = curve[1], curve[2]
        return (states[:, 0], charts)
    M.Expt = Expt
62 |
--------------------------------------------------------------------------------
/jaxgeometry/Riemannian/metric.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
def initialize(M, truncate_high_order_derivatives=False):
    """Add metric-related structures to the manifold ``M``.

    ``M`` must define the metric ``M.g`` or the cometric ``M.gsharp`` (both
    called with a point x); the missing one is obtained by matrix inversion.

    Args:
        M: manifold object; attributes are added in place.
        truncate_high_order_derivatives: accepted for API compatibility; not
            used by this function.

    Raises:
        ValueError: if neither ``M.g`` nor ``M.gsharp`` is defined.
    """
    if hasattr(M, 'g'):
        if not hasattr(M, 'gsharp'):
            M.gsharp = lambda x: jnp.linalg.inv(M.g(x))
    elif hasattr(M, 'gsharp'):
        M.g = lambda x: jnp.linalg.inv(M.gsharp(x))
    else:
        raise ValueError('no metric or cometric defined on manifold')

    M.Dg = jacfwdx(M.g)  # derivative of metric

    ##### Measure
    # Bug fix: jnp.nlinalg.Det does not exist (Theano leftover); use jnp.linalg.det.
    M.mu_Q = lambda x: 1. / jnp.linalg.det(M.g(x))

    ### Determinant of g (or gsharp), optionally of the product g.A (gsharp.A)
    def det(x, A=None):
        return jnp.linalg.det(M.g(x)) if A is None else jnp.linalg.det(jnp.tensordot(M.g(x), A, (1, 0)))

    def detsharp(x, A=None):
        return jnp.linalg.det(M.gsharp(x)) if A is None else jnp.linalg.det(jnp.tensordot(M.gsharp(x), A, (1, 0)))
    M.det = det
    M.detsharp = detsharp

    # log|det| via slogdet, which is numerically safer than log(det(.)).
    def logAbsDet(x, A=None):
        return jnp.linalg.slogdet(M.g(x))[1] if A is None else jnp.linalg.slogdet(jnp.tensordot(M.g(x), A, (1, 0)))[1]

    def logAbsDetsharp(x, A=None):
        return jnp.linalg.slogdet(M.gsharp(x))[1] if A is None else jnp.linalg.slogdet(jnp.tensordot(M.gsharp(x), A, (1, 0)))[1]
    M.logAbsDet = logAbsDet
    M.logAbsDetsharp = logAbsDetsharp

    ##### Sharp and flat map:
    M.flat = lambda x, v: jnp.tensordot(M.g(x), v, (1, 0))
    M.sharp = lambda x, p: jnp.tensordot(M.gsharp(x), p, (1, 0))

    ##### Christoffel symbols
    # \Gamma^i_{kl}, indices in that order. The einsums treat Dg[a, b, c] as
    # d_c g_ab (derivative index last):
    #   Gamma^i_kl = 1/2 g^im (d_l g_km + d_k g_lm - d_m g_kl)
    def Gamma_g(x):
        Dgx = M.Dg(x)
        gsharpx = M.gsharp(x)
        return 0.5 * (jnp.einsum('im,kml->ikl', gsharpx, Dgx)
                      + jnp.einsum('im,lmk->ikl', gsharpx, Dgx)
                      - jnp.einsum('im,klm->ikl', gsharpx, Dgx))
    M.Gamma_g = Gamma_g
    M.DGamma_g = jacfwdx(M.Gamma_g)

    # Inner Product from g
    M.dot = lambda x, v, w: jnp.tensordot(jnp.tensordot(M.g(x), w, (1, 0)), v, (0, 0))
    M.norm = lambda x, v: jnp.sqrt(M.dot(x, v, v))
    M.norm2 = lambda x, v: M.dot(x, v, v)
    M.dotsharp = lambda x, p, pp: jnp.tensordot(jnp.tensordot(M.gsharp(x), pp, (1, 0)), p, (0, 0))
    M.conorm = lambda x, p: jnp.sqrt(M.dotsharp(x, p, p))

    ##### Gram-Schmidt and basis
    # Bug fix: M.dotf is never defined; the inner product on M is M.dot.
    M.gramSchmidt = lambda x, u: GramSchmidt_f(M.dot)(x, u)
    # Bug fix: the function is jnp.linalg.cholesky (lower-case). Its lower factor
    # L satisfies L L^T = g^{-1}, so L^T g L = I: the columns of L are g-orthonormal.
    M.orthFrame = lambda x: jnp.linalg.cholesky(M.gsharp(x))

    ##### Hamiltonian H(q, p) = 1/2 p^T g^{-1}(q) p
    M.H = lambda q, p: 0.5 * jnp.tensordot(p, jnp.tensordot(M.gsharp(q), p, (1, 0)), (0, 0))

    # gradient, divergence, and Laplace-Beltrami
    # div X = tr(DX) + 1/2 X . grad log|det g|  (gsharp variant: minus log|det g^{-1}|)
    M.grad = lambda x, f: M.sharp(x, gradx(f)(x))
    M.div = lambda x, X: jnp.trace(jacfwdx(X)(x)) + .5 * jnp.dot(X(x), gradx(M.logAbsDet)(x))
    M.divsharp = lambda x, X: jnp.trace(jacfwdx(X)(x)) - .5 * jnp.dot(X(x), gradx(M.logAbsDetsharp)(x))
    M.Laplacian = lambda x, f: M.div(x, lambda x: M.grad(x, f))
94 |
--------------------------------------------------------------------------------
/jaxgeometry/Riemannian/parallel_transport.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
def initialize(M):
    """Add Riemannian parallel transport along a discretized curve (``M.parallel_transport``)."""

    def same_chart(chart, other):
        # Charts count as identical when their squared difference is at most 1e-5.
        return jnp.sum(jnp.square(chart - other)) <= 1e-5

    def ode_parallel_transport(c, y):
        # c: (t, stacked (previous curve point, transported vector), previous chart)
        # y: (curve point, chart, curve velocity) at the current step
        _, xv, prevchart = c
        x, chart, dx = y
        prevx, v = xv[0], xv[1]

        if M.do_chart_update is not None:
            # When the curve's chart differs from the state chart, re-express
            # the curve velocity through M.update_vector.
            converted_dx = M.update_vector((x, chart), prevx, prevchart, dx)
            dx = jnp.where(same_chart(chart, prevchart), dx, converted_dx)

        Gamma = M.Gamma_g((x, chart))
        dv = -jnp.einsum('ikl,k,l->i', Gamma, dx, v)
        # Only the vector evolves; the point row is overwritten by the chart update.
        return jnp.stack((jnp.zeros_like(x), dv))

    def chart_update_parallel_transport(xv, prevchart, y):
        x, chart, dx = y
        if M.do_chart_update is None:
            return (xv, chart)

        prevx, v = xv[0], xv[1]
        unchanged = jnp.stack((x, v))
        converted = jnp.stack((x, M.update_vector((prevx, prevchart), x, chart, v)))
        return (jnp.where(same_chart(chart, prevchart), unchanged, converted), chart)

    def transport_curve(v, dts, xs, charts, dxs):
        return integrate(ode_parallel_transport, chart_update_parallel_transport,
                         jnp.stack((xs[0], v)), charts[0], dts, xs, charts, dxs)

    # Return only the transported vectors (row 1 of the integrated states).
    M.parallel_transport = jit(lambda v, dts, xs, charts, dxs: transport_curve(v, dts, xs, charts, dxs)[1][:, 1])
55 |
--------------------------------------------------------------------------------
/jaxgeometry/__init__.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
--------------------------------------------------------------------------------
/jaxgeometry/dynamics/Hamiltonian.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
23 | ###############################################################
24 | # geodesic integration, Hamiltonian form #
25 | ###############################################################
def initialize(M):
    """Add Hamiltonian geodesic integration to ``M``.

    Uses ``M.H(q, p)``. Adds ``M.Hamiltonian_dynamics``, ``M.Exp_Hamiltonian``
    and ``M.Exp_Hamiltoniant``.
    """
    dq = grad(M.H, argnums=1)             # dq/dt =  dH/dp
    dp = lambda q, p: -gradx(M.H)(q, p)   # dp/dt = -dH/dq

    def ode_Hamiltonian(c, y):
        # Integrated state stacks point coordinates (row 0) and covector (row 1).
        t, x, chart = c
        dqt = dq((x[0], chart), x[1])
        dpt = dp((x[0], chart), x[1])
        return jnp.stack((dqt, dpt))

    def chart_update_Hamiltonian(xp, chart, y):
        if M.do_chart_update is None:
            return (xp, chart)

        p = xp[1]
        x = (xp[0], chart)

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(x)
        new_x = M.update_coords(x, new_chart)[0]

        return (jnp.where(update,
                          jnp.stack((new_x, M.update_covector(x, new_x, new_chart, p))),
                          xp),
                jnp.where(update,
                          new_chart,
                          chart))

    def _Hamiltonian_dynamics(q, p, dts):
        # q is either a (coords, chart) tuple or bare coordinates (chart None).
        # isinstance replaces the former duplicated type(q)==type(()) checks.
        if isinstance(q, tuple):
            coords, chart = q[0], q[1]
        else:
            coords, chart = q, None
        return integrate(ode_Hamiltonian, chart_update_Hamiltonian,
                         jnp.stack((coords, p)), chart, dts)
    M.Hamiltonian_dynamics = jit(_Hamiltonian_dynamics)

    def Exp_Hamiltonian(q, p, T=T, n_steps=n_steps):
        """Final ``(coords, chart)`` of the Hamiltonian flow from ``(q, p)``."""
        curve = M.Hamiltonian_dynamics(q, p, dts(T, n_steps))
        q = curve[1][-1, 0]
        chart = curve[2][-1]
        return (q, chart)
    M.Exp_Hamiltonian = Exp_Hamiltonian

    def Exp_Hamiltoniant(q, p, T=T, n_steps=n_steps):
        """Coordinates and charts along the whole Hamiltonian flow from ``(q, p)``."""
        curve = M.Hamiltonian_dynamics(q, p, dts(T, n_steps))
        qs = curve[1][:, 0]
        charts = curve[2]
        return (qs, charts)
    M.Exp_Hamiltoniant = Exp_Hamiltoniant
68 |
--------------------------------------------------------------------------------
/jaxgeometry/dynamics/MPP_Kunita.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
24 | ###############################################################
25 | # Most probable paths for Kunita flows #
26 | ###############################################################
def initialize(M, N, sigmas, u):
    """Most probable paths for Kunita flows.

    Args:
        M: shape manifold; its chart-update hooks are used during integration.
        N: embedding space. Its cometric is replaced by the one induced by
            ``sigmas`` and its metric structures are re-initialized.
        sigmas: noise fields, called on point coordinates ``x[0]``.
        u: flow field, called as ``u(x, qp)``.

    Adds ``N.u``, ``N.z``, ``N.f`` and ``M.MPP_AC``.
    """

    # Riemannian structure on N induced by the noise fields.
    # NOTE(review): 'pri,qrj->ij' sums p and q independently, i.e.
    # sum_r (sum_p sigma_pri)(sum_q sigma_qrj); confirm this is intended
    # rather than 'pri,prj->ij'.
    N.gsharp = lambda x: jnp.einsum('pri,qrj->ij', sigmas(x[0]), sigmas(x[0]))
    # Drop any existing metric so metric.initialize derives g from gsharp.
    # Guarded: an unconditional delattr raised AttributeError when N had no g.
    if hasattr(N, 'g'):
        delattr(N, 'g')
    from jaxgeometry.Riemannian import metric
    metric.initialize(N)

    def z(x, qp):
        """First-order (drift) part of the elliptic operator L = 1/2 Delta_g + z.

        Returns a vector (indexed by j in the einsums below).
        """
        return (u(x, qp)
                - 0.25 * jnp.einsum('ij,i->j', N.gsharp(x), gradx(N.logAbsDetsharp)(x))
                - 0.5 * jnp.einsum('iji->j', jacrevx(N.gsharp)(x))
                + 0.5 * jnp.einsum('...rj,...rii->j', sigmas(x[0]), jax.jacrev(sigmas)(x[0]))
                )

    def f(x, qp):
        """Onsager-Machlup deviation from the geodesic energy."""
        # NOTE(review): N.S_curv is not set by metric.initialize; presumably it
        # comes from a prior curvature initialization of N -- confirm.
        return .5 * N.divsharp(x, lambda x: z(x, qp)) - 1 / 12 * N.S_curv(x)

    N.u = u
    N.z = z
    N.f = f

    def ode_MPP_AC(c, y):
        t, xx1, chart = c
        qp, dqp = y
        x = xx1[0]   # point
        x1 = xx1[1]  # derivative

        g = N.g((x, chart))
        gsharp = N.gsharp((x, chart))
        Gamma = N.Gamma_g((x, chart))

        zx = z((x, chart), qp)
        gradz = jacrevx(z)((x, chart), qp)
        # time derivative of z through its dependence on qp
        dz = jnp.einsum('...ij,ij', jax.jacrev(z, argnums=1)((x, chart), qp), dqp)

        dx2 = (dz - jnp.einsum('i,j,kij->k', x1, x1, Gamma)
               + jnp.einsum('i,ki->k', x1, gradz + jnp.einsum('kij,j->ki', Gamma, zx))
               - jnp.einsum('rs,ri,s,ik->k', g, gradz + jnp.einsum('j,rij->ri', zx, Gamma), x1 - zx, gsharp)
               + jnp.einsum('ik,i', gsharp, gradx(f)((x, chart), qp))
               )
        dx1 = x1
        return jnp.stack((dx1, dx2))

    def chart_update_MPP_AC(xv, chart, y):
        # NOTE(review): uses M's chart machinery while the dynamics use N's
        # metric; presumably the charts coincide -- confirm.
        if M.do_chart_update is None:
            return (xv, chart)

        v = xv[1]
        x = (xv[0], chart)

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(x)
        new_x = M.update_coords(x, new_chart)[0]

        return (jnp.where(update,
                          jnp.stack((new_x, M.update_vector(x, new_x, new_chart, v))),
                          xv),
                jnp.where(update,
                          new_chart,
                          chart))

    M.MPP_AC = jit(lambda x, v, qps, dqps, dts: integrate(ode_MPP_AC, chart_update_MPP_AC, jnp.stack((x[0], v)), x[1], dts, qps, dqps))
94 |
95 |
--------------------------------------------------------------------------------
/jaxgeometry/dynamics/MPP_Kunita_Log.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
24 | ###############################################################
25 | # Most probable paths for Kunita flows - BVP #
26 | ###############################################################
def initialize(M, N):
    """Attach ``M.Log_MPP_AC``: shooting solver for Kunita-flow MPP boundary problems.

    ``M.Log_MPP_AC(x, y, qps, dqps, _dts, v0=None)`` searches for the initial
    velocity whose ``M.MPP_AC`` endpoint matches ``y`` and returns
    ``(velocity, final loss)``.
    """
    from scipy.optimize import approx_fprime, minimize

    method = 'BFGS'

    def loss(x, v, y, qps, dqps, _dts):
        # Mean squared endpoint mismatch, compared in the chart of the endpoint.
        _, xx1, charts = M.MPP_AC(x, v, qps, dqps, _dts)
        endpoint, end_chart = xx1[-1, 0], charts[-1]
        y_in_end_chart = M.update_coords(y, end_chart)
        return 1. / N.dim * jnp.sum(jnp.square(endpoint - y_in_end_chart[0]))

    def dloss(x, v, y, qps, dqps, _dts):
        # Finite-difference gradient in v with step 1e-4.
        return approx_fprime(v, lambda w: loss(x, w, y, qps, dqps, _dts), 1e-4)

    def shoot(x, y, qps, dqps, _dts, v0=None):
        if v0 is None:
            v0 = jnp.zeros(N.dim)

        def objective(w):
            return (loss(x, w, y, qps, dqps, _dts), dloss(x, w, y, qps, dqps, _dts))

        result = minimize(objective, v0, method=method, jac=True,
                          options={'disp': False, 'maxiter': 100})
        return (result.x, result.fun)

    M.Log_MPP_AC = shoot
53 |
--------------------------------------------------------------------------------
/jaxgeometry/dynamics/MPP_group.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
24 | from jaxgeometry.group.quotient import *
25 |
26 | ###############################################################
27 | # Most probable paths for Lie groups via development #
28 | ###############################################################
def initialize(G, Sigma=None, a=None):
    """Most probable paths (MPPs) on a Lie group ``G`` and their development.

    Args:
        G: Lie group object; MPP integrators and solvers are attached to it.
        Sigma: matrix applied to ``G.sharpV(alpha)``; the identity when omitted.
        a: optional drift, called as ``a(t)`` in ``mpp``/``mpprec`` and as
            ``a(t, g)`` in ``mpp_drift``.
    """
    sign = 1. if G.invariance != 'right' else -1.
    if Sigma is None:
        Sigma = jnp.eye(G.dim)

    def time_drift(t, alpha):
        return jnp.zeros_like(alpha) if a is None else a(t)

    def lie_velocity(alpha, drift):
        return jnp.dot(Sigma, G.sharpV(alpha)) + drift

    # Lie-algebra dual ODE for alpha (the sigma argument is not used by the ODE).
    def ode_mpp(sigma, c, y):
        t, alpha, _ = c
        z = lie_velocity(alpha, time_drift(t, alpha))
        return sign * G.coad(z, alpha)  # = -jnp.einsum('k,i,ijk->j', alpha, z, G.C)
    G.mpp = lambda alpha, dts, sigma=jnp.eye(G.dim): integrate(partial(ode_mpp, sigma), None, alpha, None, dts)

    # reconstruction of the group curve from the alpha path
    def ode_mpprec(sigma, c, y):
        t, g, _ = c
        alpha, = y
        z = lie_velocity(alpha, time_drift(t, alpha))
        return G.invpf(g, G.VtoLA(z))
    G.mpprec = lambda g, alpha, dts, sigma=jnp.eye(G.dim): integrate(partial(ode_mpprec, sigma), None, g, None, dts, alpha)

    # Track the group point together with alpha so the drift may depend on it.
    # NOTE(review): reshaping to (G.dim, G.dim) assumes the matrix size equals
    # G.dim (true e.g. for SO(3)); confirm for other groups.
    def ode_mpp_drift(sigma, c, y):
        t, state, _ = c
        alpha = state[:G.dim]
        g = state[G.dim:].reshape((G.dim, G.dim))

        drift = jnp.linalg.solve(g, a(t, g)) if a is not None else jnp.zeros_like(alpha)
        z = lie_velocity(alpha, drift)
        dalpha = sign * G.coad(z, alpha)
        dg = G.invpf(g, G.VtoLA(z))
        return jnp.hstack((dalpha, dg.flatten()))
    G.mpp_drift = lambda alpha, g, dts, sigma=jnp.eye(G.dim): integrate(partial(ode_mpp_drift, sigma), None, jnp.hstack((alpha, g.flatten())), None, dts)

    def MPP_forwardt(g, alpha, sigma, T=T, n_steps=n_steps):
        steps = dts(T=T, n_steps=n_steps)
        _, alphas = G.mpp(alpha, steps, sigma)
        _, gs = G.mpprec(g, alphas, steps, sigma)
        return (gs, alphas)
    G.MPP_forwardt = MPP_forwardt

    # Optimization to satisfy end-point conditions.
    def MPP_f(g, alpha, y, sigma):
        gs, _ = G.MPP_forwardt(g, alpha, sigma)
        return (1. / G.emb_dim) * jnp.sum(jnp.square(gs[-1] - y))

    def MPP(g, y, sigma=jnp.eye(G.dim)):
        objective = lambda alpha: MPP_f(g, alpha, y, sigma)
        result = jax.scipy.optimize.minimize(objective, jnp.zeros(G.dim), method='BFGS')
        return result.x
    G.MPP = MPP

    def MPP_drift_f(g, alpha, y, sigma, proj, M, _dts):
        _, _, _, _, horz = horz_vert_split(g, proj, jnp.eye(G.dim), G, M)
        _, alphags = G.mpp_drift(jnp.dot(horz, alpha), g, _dts, sigma)
        g_end = alphags[-1, G.dim:].reshape((G.dim, G.dim))
        return (1. / M.emb_dim) * jnp.sum(jnp.square(proj(g_end) - M.F(y)))

    def MPP_drift(g, y, proj, M, sigma=jnp.eye(G.dim)):
        _dts = dts()
        objective = lambda alpha: MPP_drift_f(g, alpha, y, sigma, proj, M, _dts)
        result = jax.scipy.optimize.minimize(objective, jnp.zeros(M.dim), method='BFGS')
        _, _, _, _, horz = horz_vert_split(g, proj, jnp.eye(G.dim), G, M)
        return jnp.dot(horz, result.x)
    G.MPP_drift = MPP_drift
106 |
107 |
108 | # # Most probable paths
109 | # def initialize(G,horz,vert,a=None):
110 | # """ Most probable paths and development """
111 |
112 | # assert(G.invariance == 'right')
113 |
114 | # def ode_mpp(sigma,c,y):
115 | # t,x,_ = c
116 |
117 | # vt = x[0]
118 | # ct = x[1]
119 | # lambdt = x[2]
120 | # at = a(t) if a is not None else jnp.zeros_like(vt)
121 | # Sigma = G.W(sigma)
122 |
123 | # domegat = jnp.dot(Sigma,vt)-at
124 | # dvt = horz(G.coad(domegat,vt-lambdt+ct))
125 | # dct = -vert(G.coad(domegat,vt))
126 | # dlambdt = vert(G.coad(domegat,ct))
127 | # return jnp.stack((dvt,ct,dlambdt))
128 | # G.mpp = lambda v,c,lambd,dts,sigma=jnp.eye(G.dim): integrate(partial(ode_mpp,sigma),None,jnp.stack((v,c,lambd)),None,dts)
129 |
130 | # # reconstruction
131 | # def ode_mpprec(sigma,c,y):
132 | # t,g,_ = c
133 |
134 | # x, = y
135 | # vt = x[0]
136 | # ct = x[1]
137 | # lambdt = x[2]
138 | # at = a(t) if a is not None else jnp.zeros_like(vt)
139 | # Sigma = G.W(sigma)
140 |
141 | # domegat = jnp.dot(Sigma,vt)-at
142 | # dgt = G.invpf(g,G.VtoLA(domegat))
143 | # return dgt
144 | # G.mpprec = lambda g,vclambds,dts,sigma=jnp.eye(G.dim): integrate(partial(ode_mpprec,sigma),None,g,None,dts,vclambds)
145 |
146 | # def MPP_forwardt(g,v,c,lambd,sigma,T=T,n_steps=n_steps):
147 | # _dts = dts(T=T,n_steps=n_steps)
148 | # (ts,xs) = G.mpp(v,c,lambd,_dts,sigma)
149 | # (ts,gs) = G.mpprec(g,xs,_dts,sigma)
150 |
151 | # vs = xs[0]
152 | # cs = xs[1]
153 | # lambds = xs[2]
154 | # return(gs,cs)
155 | # G.MPP_forwardt = MPP_forwardt
156 |
157 | # # optimization to satisfy end-point conditions
158 | # def MPP_f(g,x,y,sigma):
159 | # x = x.reshape((3,G.dim))
160 | # v = x[0]
161 | # c = x[1]
162 | # lambd = x[2]
163 | # gs,cs = G.MPP_forwardt(g,v,c,lambd,sigma)
164 | # gT = gs[-1]
165 | # cT = cs[-1]
166 | # return (1./G.emb_dim)*jnp.sum(jnp.square(gT-y))+(1./G.dim)*jnp.sum(jnp.square(cT))
167 |
168 | # def MPP(g,y,sigma=jnp.eye(G.dim)):
169 | # res = jax.scipy.optimize.minimize(lambda x: MPP_f(g,x,y,sigma),jnp.zeros((3,G.dim)).flatten(),method='BFGS')
170 | # x = res.x.reshape((3,G.dim))
171 | # v = x[0]
172 | # c = x[1]
173 | # lambd = x[2]
174 |
175 | # return (v,c,lambd)
176 | # G.MPP = MPP
177 |
178 | # E = jnp.eye(G.dim)
179 | # horz = lambda v: jnp.dot(E[:,:2],jnp.dot(E[:,:2].T,v))
180 | # vert = lambda v: jnp.dot(E[:,2:],jnp.dot(E[:,2:].T,v))
181 | # initialize(G,horz,vert)
182 |
--------------------------------------------------------------------------------
/jaxgeometry/dynamics/MPP_landmarks.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
24 | ###############################################################
25 | # Most probable paths for landmarks via development #
26 | ###############################################################
def initialize(M, sigmas, dsigmas, a):
    """Most probable paths for landmarks via development.

    Args:
        M: landmark manifold with ``M.N`` points of dimension ``M.m``.
        sigmas: noise fields evaluated at the landmarks; the einsums below use
            ``sigmas(x)`` with axes (landmark, field, dim).
        dsigmas: derivative of ``sigmas``; used with axes (landmark, field, dim, dim).
        a: flow field, called as ``a(x, qp)``.

    Adds ``M.MPP_landmarks``.
    """

    def ode_MPP_landmarks(c, y):
        t, xlambd, chart = c
        qp, = y
        points = xlambd[0].reshape((M.N, M.m))
        lambd = xlambd[1].reshape((M.N, M.m))  # treated as a covector in chart updates

        sig = sigmas(points)
        dsig = dsigmas(points)
        coeff = jnp.einsum('ri,rai->a', lambd, sig)  # per-field coefficient

        dpoints = a(points, qp) + jnp.einsum('a,rak->rk', coeff, sig)
        dlambd = (-jnp.einsum('ri,a,raik->rk', lambd, coeff, dsig)
                  - jnp.einsum('ri,rirk->rk', lambd, jacrev(a)(points, qp)))
        return jnp.stack((dpoints.flatten(), dlambd.flatten()))

    def chart_update_MPP_landmarks(xlambd, chart, y):
        if M.do_chart_update is None:
            return (xlambd, chart)

        # NOTE(review): lambd is passed reshaped to (M.N, M.m) while new_x is
        # flat; the stack below presumes update_covector returns a flat array.
        lambd = xlambd[1].reshape((M.N, M.m))
        point = (xlambd[0], chart)

        switch = M.do_chart_update(point)
        target_chart = M.centered_chart(point)
        moved_x = M.update_coords(point, target_chart)[0]

        new_state = jnp.where(switch,
                              jnp.stack((moved_x, M.update_covector(point, moved_x, target_chart, lambd))),
                              xlambd)
        return (new_state, jnp.where(switch, target_chart, chart))

    def MPP_landmarks(x, lambd, qps, dts):
        """Integrate from ``x`` (a (coords, chart) pair) with covector ``lambd``.

        Returns ``(ts, points, covectors, charts)``.
        """
        ts, xlambds, charts = integrate(ode_MPP_landmarks, chart_update_MPP_landmarks,
                                        jnp.stack((x[0], lambd)), x[1], dts, qps)
        return (ts, xlambds[:, 0], xlambds[:, 1], charts)
    M.MPP_landmarks = MPP_landmarks
68 |
69 |
--------------------------------------------------------------------------------
/jaxgeometry/dynamics/MPP_landmarks_Log.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
24 | ###############################################################
25 | # Most probable paths for landmarks via development - BVP #
26 | ###############################################################
def initialize(M):
    """Attach ``M.Log_MPP_landmarks``: shooting solver for landmark MPP boundary problems.

    ``M.Log_MPP_landmarks(x, y, qps, _dts, lambd0=None)`` searches for the
    initial covector whose ``M.MPP_landmarks`` endpoint matches ``y`` and
    returns ``(covector, final loss)``.
    """
    from scipy.optimize import minimize

    method = 'BFGS'

    def loss(x, lambd, y, qps, _dts):
        # Mean squared endpoint mismatch, compared in the chart of the endpoint.
        _, xs, _, charts = M.MPP_landmarks(x, lambd, qps, _dts)
        endpoint, end_chart = xs[-1], charts[-1]
        y_in_end_chart = M.update_coords(y, end_chart)
        return 1. / M.dim * jnp.sum(jnp.square(endpoint - y_in_end_chart[0]))

    def shoot(x, y, qps, _dts, lambd0=None):
        if lambd0 is None:
            lambd0 = jnp.zeros(M.dim)

        # value_and_grad supplies both loss and exact gradient (jac=True).
        objective = jax.value_and_grad(lambda w: loss(x, w, y, qps, _dts))
        result = minimize(objective, lambd0, method=method, jac=True,
                          options={'disp': False, 'maxiter': 100})
        return (result.x, result.fun)

    M.Log_MPP_landmarks = shoot
47 |
--------------------------------------------------------------------------------
/jaxgeometry/dynamics/flow.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
def initialize(M):
    """ Attach ``M.flow``: builds integrators for the flow along a vector field X. """

    def flow(X):
        """ Return a jitted integrator for the flow of the vector field X.

        Args:
            X: callable taking a point ``(coords, chart)`` and returning the
                coordinate velocity at that point.

        Returns:
            A jitted function ``f(x, dts)`` with ``x = (coords, chart)`` the
            initial point and ``dts`` the time steps; it returns the result of
            ``integrate``.
        """

        def ode_flow(c, y):
            t, x, chart = c
            return X((x, chart))

        def chart_update_flow(x, chart, *ys):
            if M.do_chart_update is None:
                return (x, chart)

            # do_chart_update, centered_chart and update_coords all take the
            # point as a (coords, chart) pair, as in the other chart updaters
            point = (x, chart)
            update = M.do_chart_update(point)
            new_chart = M.centered_chart(point)
            new_x = M.update_coords(point, new_chart)[0]

            return (jnp.where(update, new_x, x),
                    jnp.where(update, new_chart, chart))

        return jit(lambda x, dts: integrate(ode_flow, chart_update_flow, x[0], x[1], dts))
    M.flow = flow
51 |
--------------------------------------------------------------------------------
/jaxgeometry/dynamics/flow_differential.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
24 | ###############################################################
25 | # Compute differential d\phi along a phase-space path qt #
26 | # See Younes, Shapes and Diffeomorphisms, 2010 and #
27 | # Sommer et al., SIIMS 2013 #
28 | ###############################################################
def initialize(M):
    """ Attach ``M.flow_differential``; M: landmark manifold, scalar kernel """

    def ode_differential(c, y):
        t, dphi, chart = c
        qp, = y
        q = qp[0].reshape((M.N, M.m))  # landmark positions
        p = qp[1].reshape((M.N, M.m))  # landmark momenta

        dk = M.dk_q(q, q)
        ddphi = jnp.einsum('iab,jic,jb->iac', dphi, dk, p)

        return ddphi

    def chart_update_differential(dphi, chart, y):
        if M.do_chart_update is None:
            return (dphi, chart)

        # raise explicitly: a bare ``assert False`` disappears under
        # ``python -O`` and this function would then silently return None
        raise NotImplementedError('chart updates are not implemented for flow_differential')

    def flow_differential(qps, dts):
        """ Integrate the differential dphi along the phase-space path qps.

        The integration starts from one M.m x M.m identity matrix per landmark.

        Returns:
            (ts, dphis, charts) as returned by ``integrate``.
        """
        dphi0 = jnp.tile(jnp.eye(M.m), (M.N, 1, 1))
        (ts, dphis, charts) = integrate(ode_differential, chart_update_differential, dphi0, None, dts, qps)
        return (ts, dphis, charts)
    M.flow_differential = flow_differential
54 |
--------------------------------------------------------------------------------
/jaxgeometry/framebundle/FM.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.params import *
23 | from jaxgeometry.utils import *
24 |
25 |
def initialize(M):
    """ Frame Bundle geometry

    Attaches chart_update_FM, to_D, from_D, to_Dstar, from_Dstar and
    Horizontal to M. Frame-bundle coordinates concatenate the base point
    (first M.dim entries) with the flattened frame nu, which is reshaped to
    (M.dim, -1).
    """

    d = M.dim

    def _unpack(u):
        """ Split u=(coords, chart) into the base point (x, chart) and the frame nu. """
        return ((u[0][0:d], u[1]), u[0][d:].reshape((d, -1)))

    def _Gammanu(x, nu):
        """ Christoffel symbols at x contracted with the frame nu (axes 1 and 2 swapped). """
        return jnp.tensordot(M.Gamma_g(x), nu, (2, 0)).swapaxes(1, 2)

    def chart_update_FM(u, chart, *args):
        """ Move the frame-bundle point (u, chart) to a centered chart when M asks for it. """
        # test for None like the other chart updaters: a callable never equals
        # True, so a ``!= True`` test would skip the update unconditionally
        if M.do_chart_update is None:
            return (u, chart)

        x = (u[0:d], chart)
        nu = u[d:].reshape((d, -1))

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(x)
        new_x = M.update_coords(x, new_chart)[0]

        return (jnp.where(update,
                          jnp.concatenate((new_x, M.update_vector(x, new_x, new_chart, nu).flatten())),
                          u),
                jnp.where(update,
                          new_chart,
                          chart))
    M.chart_update_FM = chart_update_FM

    #### Basis shifts, see e.g. Sommer Entropy 2016 sec 2.3
    # D denotes frame adapted to the horizontal distribution
    def to_D(u, w):
        """ Shift tangent vector w at u to the D basis. """
        x, nu = _unpack(u)
        wx = w[0:d]
        wnu = w[d:].reshape((d, -1))

        Gammanu = _Gammanu(x, nu)
        Dwx = wx
        Dwnu = jnp.tensordot(Gammanu, wx, (2, 0)) + wnu

        return jnp.concatenate((Dwx, Dwnu.flatten()))
    def from_D(u, Dw):
        """ Shift tangent vector Dw at u from the D basis back (inverse of to_D). """
        x, nu = _unpack(u)
        Dwx = Dw[0:d]
        Dwnu = Dw[d:].reshape((d, -1))

        Gammanu = _Gammanu(x, nu)
        wx = Dwx
        wnu = -jnp.tensordot(Gammanu, Dwx, (2, 0)) + Dwnu

        return jnp.concatenate((wx, wnu.flatten()))
    # corresponding dual space shifts
    def to_Dstar(u, p):
        """ Shift covector p at u to the dual D basis. """
        x, nu = _unpack(u)
        px = p[0:d]
        pnu = p[d:].reshape((d, -1))

        Gammanu = _Gammanu(x, nu)
        Dpx = px - jnp.tensordot(Gammanu, pnu, ((0, 1), (0, 1)))
        Dpnu = pnu

        return jnp.concatenate((Dpx, Dpnu.flatten()))
    def from_Dstar(u, Dp):
        """ Shift covector Dp at u from the dual D basis back (inverse of to_Dstar). """
        x, nu = _unpack(u)
        Dpx = Dp[0:d]
        Dpnu = Dp[d:].reshape((d, -1))

        Gammanu = _Gammanu(x, nu)
        px = Dpx + jnp.tensordot(Gammanu, Dpnu, ((0, 1), (0, 1)))
        pnu = Dpnu

        return jnp.concatenate((px, pnu.flatten()))
    M.to_D = to_D
    M.from_D = from_D
    M.to_Dstar = to_Dstar
    M.from_Dstar = from_Dstar

    ##### Horizontal vector fields:
    def Horizontal(u):
        """ Horizontal vector fields at u, one column per column of the frame nu.

        The first d rows are the base-point part (nu itself); the remaining
        rows are the flattened frame part -Gamma(nu, nu).
        """
        x, nu = _unpack(u)

        # Contribution from the coordinate basis for x:
        dx = nu
        # Contribution from the basis for Xa:
        Gammahgammaj = jnp.einsum('hji,ig->hgj', M.Gamma_g(x), nu)  # same as Gammanu above
        dnu = -jnp.einsum('hgj,ji->hgi', Gammahgammaj, nu)

        return jnp.concatenate([dx, dnu.reshape((-1, nu.shape[1]))], axis=0)
    M.Horizontal = Horizontal
119 |
120 |
121 |
--------------------------------------------------------------------------------
/jaxgeometry/group/EulerPoincare.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(G):
    """ Euler-Poincare geodesic integration

    Attaches the EP/EPrec integrators and the derived geodesic maps
    coExpEP, ExpEP, ExpEPpsi, coExpEPt, ExpEPt, ExpEPpsit and DcoExpEP to
    the left-invariant group G. When ``_dts`` is None the integrators use
    ``dts()``.
    """

    assert(G.invariance == 'left')

    def ode_EP(c, y):
        t, mu, _ = c
        xi = G.invFl(mu)
        # NOTE(review): LiePoisson.py integrates +G.coad(G.dHminusdmu(mu), mu),
        # and dHminusdmu(mu) equals invFl(mu) for the invariant metric, so
        # the two equations differ in sign; confirm against G.coad's convention
        dmut = -G.coad(xi, mu)
        return dmut
    G.EP = lambda mu, _dts=None: integrate(ode_EP, None, mu, None, dts() if _dts is None else _dts)

    # reconstruction
    def ode_EPrec(c, y):
        t, g, _ = c
        mu, = y
        xi = G.invFl(mu)
        dgt = G.dL(g, G.e, G.VtoLA(xi))
        return dgt
    G.EPrec = lambda g, mus, _dts=None: integrate(ode_EPrec, None, g, None, dts() if _dts is None else _dts, mus)

    ### geodesics
    G.coExpEP = lambda g, mu: G.EPrec(g, G.EP(mu)[1])[1][-1]
    G.ExpEP = lambda g, v: G.coExpEP(g, G.flatV(v))
    # ExpEP/ExpEPt already apply flatV, so v is passed through unchanged
    G.ExpEPpsi = lambda q, v: G.ExpEP(G.psi(q), v)
    G.coExpEPt = lambda g, mu: G.EPrec(g, G.EP(mu)[1])
    G.ExpEPt = lambda g, v: G.coExpEPt(g, G.flatV(v))
    G.ExpEPpsit = lambda q, v: G.ExpEPt(G.psi(q), v)
    # Jacobian of coExpEP w.r.t. g (jacrev's default argnums=0)
    G.DcoExpEP = lambda g, mu: jax.jacrev(G.coExpEP)(g, mu)
52 |
--------------------------------------------------------------------------------
/jaxgeometry/group/LiePoisson.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(G):
    """ Lie-Poisson geodesic integration

    Attaches the LP/LPrec integrators and the derived geodesic maps
    coExpLP, ExpLP, coExpLPt, ExpLPt and DcoExpLP to the left-invariant
    group G. When ``_dts`` is None the integrators use ``dts()``.
    """

    assert(G.invariance == 'left')

    def ode_LP(c, y):
        t, mu, _ = c
        # NOTE(review): EulerPoincare.py integrates -G.coad(G.invFl(mu), mu),
        # and invFl(mu) equals dHminusdmu(mu) for the invariant metric, so the
        # two equations differ in sign; confirm against G.coad's convention
        dmut = G.coad(G.dHminusdmu(mu), mu)
        return dmut
    G.LP = lambda mu, _dts=None: integrate(ode_LP, None, mu, None, dts() if _dts is None else _dts)

    # reconstruction
    def ode_LPrec(c, y):
        t, g, _ = c
        mu, = y
        dgt = G.dL(g, G.e, G.VtoLA(G.dHminusdmu(mu)))
        return dgt
    G.LPrec = lambda g, mus, _dts=None: integrate(ode_LPrec, None, g, None, dts() if _dts is None else _dts, mus)

    ### geodesics
    G.coExpLP = lambda g, mu: G.LPrec(g, G.LP(mu)[1])[1][-1]
    G.ExpLP = lambda g, v: G.coExpLP(g, G.flatV(v))
    G.coExpLPt = lambda g, mu: G.LPrec(g, G.LP(mu)[1])
    G.ExpLPt = lambda g, v: G.coExpLPt(g, G.flatV(v))
    # Jacobian of the Lie-Poisson exponential coExpLP w.r.t. g
    G.DcoExpLP = lambda g, mu: jax.jacrev(G.coExpLP)(g, mu)
48 |
--------------------------------------------------------------------------------
/jaxgeometry/group/energy.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(G):
    """ group Lagrangian and Hamiltonian from invariant metric

    Attaches Lagrangians (Lagrangian, Lagrangianpsi, l), Hamiltonians
    (Hpsi, Hminus, H), their gradients and the Legendre transforms to G.
    The metric pieces (gG, gpsi, gV, cogpsi, cogV) are looked up on G at
    call time.
    """

    # Lagrangian
    def Lagrangian(g, vg):
        return .5 * G.gG(g, vg, vg)
    G.Lagrangian = Lagrangian
    # Lagrangian using psi map
    def Lagrangianpsi(q, v):
        return .5 * G.gpsi(q, v, v)
    G.Lagrangianpsi = Lagrangianpsi
    G.dLagrangianpsidq = jax.grad(G.Lagrangianpsi, argnums=0)
    # derivative w.r.t. the velocity v (second argument)
    G.dLagrangianpsidv = jax.grad(G.Lagrangianpsi, argnums=1)
    # LA restricted Lagrangian
    def l(hatxi):
        return 0.5 * G.gV(hatxi, hatxi)
    G.l = l
    G.dldhatxi = jax.grad(G.l)

    # Hamiltonian using psi map
    def Hpsi(q, p):
        return .5 * G.cogpsi(q, p, p)
    G.Hpsi = Hpsi
    # LA^* restricted Hamiltonian
    def Hminus(mu):
        return .5 * G.cogV(mu, mu)
    G.Hminus = Hminus
    G.dHminusdmu = jax.grad(G.Hminus)

    # Legendre transformation. The above Lagrangian is hyperregular
    G.FLpsi = lambda q, v: (q, G.dLagrangianpsidv(q, v))
    G.invFLpsi = lambda q, p: (q, G.cogpsi(q, p))
    def HL(q, p):
        """ Hamiltonian obtained from Lagrangianpsi by the Legendre transform. """
        (q, v) = G.invFLpsi(q, p)
        return jnp.dot(p, v) - G.Lagrangianpsi(q, v)
    G.HL = HL
    G.Fl = lambda hatxi: G.dldhatxi(hatxi)
    G.invFl = lambda mu: G.cogV(mu)
    def hl(mu):
        """ Reduced Hamiltonian obtained from l by the Legendre transform. """
        hatxi = G.invFl(mu)
        return jnp.dot(mu, hatxi) - l(hatxi)
    G.hl = hl

    # default Hamiltonian; q may be given as a (coords, chart) pair
    G.H = lambda q, p: G.Hpsi(q[0], p) if isinstance(q, tuple) else G.Hpsi(q, p)
68 |
69 | # A.set_value(np.diag([3,2,1]))
70 | # print(FLpsif(q0,v0))
71 | # print(invFLpsif(q0,p0))
72 | # (flq0,flv0)=FLpsif(q0,v0)
73 | # print(q0,v0)
74 | # print(invFLpsif(flq0,flv0))
75 |
--------------------------------------------------------------------------------
/jaxgeometry/group/invariant_metric.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(G, _sigma=None):
    """ add left-/right-invariant metric related structures to group

    parameter sigma is square root cometric / diffusion field
    (defaults to the identity of size G.dim).

    The metric is A = sqrtA^T sqrtA with sqrtA = G.inv(sigma), and the
    cometric is W = sigma sigma^T.
    """

    if _sigma is None:
        _sigma = jnp.eye(G.dim)

    G.sqrtA = lambda sigma=_sigma: G.inv(sigma)  # square root metric
    G.A = lambda sigma=_sigma: jnp.tensordot(G.sqrtA(sigma), G.sqrtA(sigma), (0, 0))  # metric
    G.W = lambda sigma=_sigma: jnp.tensordot(sigma, sigma, (1, 1))  # covariance (cometric)
    def gV(v=None, w=None, sigma=_sigma):
        """ Metric on V: A if no vectors, A v if one, <v, w>_A for vectors or matrix columns. """
        if v is None and w is None:
            return G.A(sigma)
        elif v is not None and w is None:
            return jnp.tensordot(G.A(sigma), v, (1, 0))
        elif v.ndim == 1 and w.ndim == 1:
            return jnp.dot(v, jnp.dot(G.A(sigma), w))
        elif v.ndim == 2 and w.ndim == 2:
            return jnp.tensordot(v, jnp.tensordot(G.A(sigma), w, (1, 0)), (0, 0))
        raise ValueError('gV: unsupported argument ranks')
    G.gV = gV
    def cogV(cov=None, cow=None, sigma=_sigma):
        """ Cometric on V^*: W if no covectors, W cov if one, <cov, cow>_W otherwise. """
        if cov is None and cow is None:
            return G.W(sigma)
        elif cov is not None and cow is None:
            return jnp.tensordot(G.W(sigma), cov, (1, 0))
        elif cov.ndim == 1 and cow.ndim == 1:
            return jnp.dot(cov, jnp.dot(G.W(sigma), cow))
        elif cov.ndim == 2 and cow.ndim == 2:
            return jnp.tensordot(cov, jnp.tensordot(G.W(sigma), cow, (1, 0)), (0, 0))
        raise ValueError('cogV: unsupported argument ranks')
    G.cogV = cogV
    def gLA(xiv, xiw, sigma=_sigma):
        v = G.LAtoV(xiv)
        w = G.LAtoV(xiw)
        return G.gV(v, w, sigma)
    G.gLA = gLA
    def cogLA(coxiv, coxiw, sigma=_sigma):
        cov = G.LAtoV(coxiv)
        cow = G.LAtoV(coxiw)
        return G.cogV(cov, cow, sigma)
    G.cogLA = cogLA
    def gG(g, vg, wg, sigma=_sigma):
        xiv = G.invpb(g, vg)
        xiw = G.invpb(g, wg)
        return G.gLA(xiv, xiw, sigma)
    G.gG = gG
    def gpsi(hatxi, v=None, w=None, sigma=_sigma):
        g = G.psi(hatxi)
        vg = G.dpsi(hatxi, v)
        wg = G.dpsi(hatxi, w)
        return G.gG(g, vg, wg, sigma)
    G.gpsi = gpsi
    def cogpsi(hatxi, p=None, pp=None, sigma=_sigma):
        """ Inverse of the psi-metric: matrix, applied to p, or paired with (p, pp). """
        invgpsi = G.inv(G.gpsi(hatxi, sigma=sigma))
        if p is not None and pp is not None:
            return jnp.tensordot(p, jnp.tensordot(invgpsi, pp, (1, 0)), (0, 0))
        elif p is not None:
            # explicit None test: truthiness of a multi-element array raises
            return jnp.tensordot(invgpsi, p, (1, 0))
        return invgpsi
    G.cogpsi = cogpsi

    # sharp/flat mappings
    def sharpV(mu, sigma=_sigma):
        return jnp.dot(G.W(sigma), mu)
    G.sharpV = sharpV
    def flatV(v, sigma=_sigma):
        return jnp.dot(G.A(sigma), v)
    G.flatV = flatV
    def sharp(g, pg, sigma=_sigma):
        return G.invpf(g, G.VtoLA(jnp.dot(G.W(sigma), G.LAtoV(G.invcopb(g, pg)))))
    G.sharp = sharp
    def flat(g, vg, sigma=_sigma):
        return G.invcopf(g, G.VtoLA(jnp.dot(G.A(sigma), G.LAtoV(G.invpb(g, vg)))))
    G.flat = flat
    def sharppsi(hatxi, p, sigma=_sigma):
        return jnp.tensordot(G.cogpsi(hatxi, sigma=sigma), p, (1, 0))
    G.sharppsi = sharppsi
    def flatpsi(hatxi, v, sigma=_sigma):
        return jnp.tensordot(G.gpsi(hatxi, sigma=sigma), v, (1, 0))
    G.flatpsi = flatpsi
110 |
111 |
--------------------------------------------------------------------------------
/jaxgeometry/group/quotient.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
def horz_vert_split(x,proj,sigma,G,M):
    """ Split directions at group element x into horizontal and vertical parts.

    The frame is G.invpf(x, G.eiLA) contracted with sigma. The differential
    of proj expressed in that frame is factored by an SVD; its first M.dim
    right singular vectors span the horizontal space and the remaining ones
    span the kernel of the differential (the vertical space).

    Returns:
        (Xframe, Xframe_inv, proj_horz, proj_ns, horz): the frame, the
        pseudo-inverse of the flattened frame, the projections onto the
        horizontal space and onto the null space, and the horizontal basis.
    """
    n_horz = M.dim  # assumes the differential of proj has rank M.dim

    frame = jnp.tensordot(G.invpf(x, G.eiLA), sigma, (2, 0))
    frame_pinv = jnp.linalg.pinv(frame.reshape((-1, G.dim)))

    # differential of proj, expressed in the frame
    jac = jacrev(proj)(x)
    dproj = jnp.einsum('...ij,ijk->...k', jac, frame)

    # no gradients are propagated through the SVD
    right_sv = jnp.linalg.svd(jax.lax.stop_gradient(dproj), full_matrices=True)[2]
    null_space = right_sv[n_horz:].T
    horz = right_sv[0:n_horz].T

    proj_ns = jnp.tensordot(null_space, null_space, (1, 1))
    proj_horz = jnp.tensordot(horz, horz, (1, 1))

    return (frame, frame_pinv, proj_horz, proj_ns, horz)
37 |
38 | # restrict an SDE to the fibers of proj (vertical directions)
def get_sde_fiber(sde_f,proj,G,M):
    """ Wrap sde_f so its outputs are projected onto the null space of d(proj).

    The returned function has sde_f's (c, y) -> (det, sto, X, *rest)
    signature. c must unpack as (t, g, chart, sigma) and y as (dt, dW).
    det, sto and X are mapped to frame coordinates, projected onto the null
    space (fiber directions) and mapped back through the frame.
    """
    def sde_fiber(c, y):
        det, sto, X, *rest = sde_f(c, y)
        _, g, _, sigma = c
        _dt, _dW = y  # unpacked only to enforce the (dt, dW) layout

        Xframe, Xframe_inv, _, proj_ns, _ = horz_vert_split(g, proj, sigma, G, M)

        def vertical(a):
            # to frame coordinates -> project on null space -> back through the frame
            in_frame = jnp.tensordot(Xframe_inv, a, (1, 0))
            projected = jnp.tensordot(proj_ns, in_frame, (1, 0))
            return jnp.tensordot(Xframe, projected, (2, 0))

        new_det = vertical(det.flatten()).reshape(g.shape)
        new_sto = vertical(sto.flatten()).reshape(g.shape)
        new_X = vertical(X.reshape((-1, G.dim))).reshape(X.shape)

        return (new_det, new_sto, new_X, *rest)

    return sde_fiber
54 |
def get_sde_horz(sde_f,proj,G,M):
    """ Wrap sde_f so its outputs are projected onto the horizontal space of proj.

    The returned function has sde_f's (c, y) -> (det, sto, X, *rest)
    signature. c must unpack as (t, g, chart, sigma) and y as (dt, dW).
    det, sto and X are mapped to frame coordinates, projected onto the
    horizontal space and mapped back through the frame.
    """
    def sde_horz(c, y):
        det, sto, X, *rest = sde_f(c, y)
        _, g, _, sigma = c
        _dt, _dW = y  # unpacked only to enforce the (dt, dW) layout

        Xframe, Xframe_inv, proj_horz, _, _ = horz_vert_split(g, proj, sigma, G, M)

        def horizontal(a):
            # to frame coordinates -> project on horizontal space -> back through the frame
            in_frame = jnp.tensordot(Xframe_inv, a, (1, 0))
            projected = jnp.tensordot(proj_horz, in_frame, (1, 0))
            return jnp.tensordot(Xframe, projected, (2, 0))

        new_det = horizontal(det.flatten()).reshape(g.shape)
        new_sto = horizontal(sto.flatten()).reshape(g.shape)
        new_X = horizontal(X.reshape((-1, G.dim))).reshape(X.shape)

        return (new_det, new_sto, new_X, *rest)

    return sde_horz
69 |
def get_sde_lifted(sde_f,proj,G,M):
    """ Lift an SDE on M to the group G along the horizontal space of proj.

    The returned function takes c = (t, g, chart, sigma, *cs) and
    y = (dt, dW). It evaluates sde_f at the base point M.invF((proj(g), chart)),
    then maps det, sto and X through the horizontal basis and the frame of g.
    The sigma component is returned with zero derivative.
    """
    def sde_lifted(c, y):
        t, g, chart, sigma, *cs = c
        dt, dW = y

        # evaluate the base-manifold SDE at the projection of g
        (det, sto, X, *dys_sde) = sde_f((t, M.invF((proj(g), chart)), chart, *cs), y)

        (Xframe, Xframe_inv, proj_horz, _, horz) = horz_vert_split(g, proj, sigma, G, M)

        det = jnp.tensordot(Xframe, jnp.tensordot(horz, det, (1, 0)), (2, 0)).reshape(g.shape)
        sto = jnp.tensordot(Xframe, jnp.tensordot(horz, sto, (1, 0)), (2, 0)).reshape(g.shape)
        # shape follows g like det/sto; a hard-coded (G.dim, G.dim, M.dim)
        # only matches g.shape when the matrix size equals G.dim (e.g. SO(3))
        X_tail = X.shape[1:]
        X = jnp.tensordot(Xframe, jnp.tensordot(horz, X, (1, 0)), (2, 0)).reshape(g.shape + X_tail)

        return (det, sto, X, jnp.zeros_like(sigma), *dys_sde)

    return sde_lifted
87 |
88 | ## find g in fiber above x closests to g0
89 | #from scipy.optimize import minimize
90 | #def lift_to_fiber(x,x0,G,M):
91 | # shoot = lambda hatxi: G.gV(hatxi,hatxi)
92 | # try:
93 | # hatxi = minimize(shoot,
94 | # np.zeros(G.dim),
95 | # method='COBYLA',
96 | # constraints={'type':'ineq','fun':lambda hatxi: np.min((G.injectivity_radius-np.max(hatxi),
97 | # 1e-8-np.linalg.norm(M.act(G.exp(G.VtoLA(hatxi)),x0)-x)**2))},
98 | # ).x
99 | # hatxi = minimize(lambda hatxi: np.linalg.norm(M.act(G.exp(G.VtoLA(hatxi)),x0)-x)**2,
100 | # hatxi).x # fine tune
101 | # except AttributeError: # injectivity radius not defined
102 | # hatxi = minimize(shoot,
103 | # np.zeros(G.dim),
104 | # method='COBYLA',
105 | # constraints={'type':'ineq','fun':lambda hatxi: 1e-8-np.linalg.norm(M.act(G.exp(G.VtoLA(hatxi)),x0)-x)**2}).x
106 | # hatxi = minimize(lambda hatxi: np.linalg.norm(M.act(G.exp(G.VtoLA(hatxi)),x0)-x)**2,
107 | # hatxi).x # fine tune
108 | # l0 = G.exp(G.VtoLA(hatxi))
109 | # try: # project to group if to_group function is available
110 | # l0 = G.to_group(l0)
111 | # except NameError:
112 | # pass
113 | # return (l0,hatxi)
114 | #
115 | ## estimate fiber volume
116 | #import scipy.special
117 | #from jaxgeometry.plotting import *
118 | #
119 | #def fiber_samples(G,Brownian_fiberf,L,pars):
120 | # (seed,) = pars
121 | # if seed:
122 | # srng.seed(seed)
123 | # gsl = np.zeros((L,) + G.e.shape)
124 | # dsl = np.zeros(L)
125 | # (ts, gs) = Brownian_fiber(G.e, dWs(G.dim))
126 | # vl = gs[-1] # starting point
127 | # for l in range(L):
128 | # (ts, gs) = Brownian_fiber(vl, dWs(G.dim))
129 | # gsl[l] = gs[-1]
130 | # dsl[l] = np.linalg.norm(G.LAtoV(G.log(gs[-1]))) # distance to sample with canonical biinvariant metric
131 | # vl = gs[-1]
132 | #
133 | # return (gsl, dsl)
134 | #
135 | #def estimate_fiber_volume(G, M, lfiber_samples, nr_samples=100, plot_dist_histogram=False, plot_samples=False):
136 | # """ estimate fiber volume with restricted Riemannian G volume element (biinvariant metric) """
137 | # L = nr_samples // (cpu_count() // 2) # samples per processor
138 | #
139 | # try:
140 | # mpu.openPool()
141 | # sol = mpu.pool.imap(partial(lfiber_samples, L), mpu.inputArgs(np.random.randint(1000, size=cpu_count() // 2)))
142 | # res = list(sol)
143 | # gsl = mpu.getRes(res, 0).reshape((-1,) + G.e.shape)
144 | # dsl = mpu.getRes(res, 1).flatten()
145 | # except:
146 | # mpu.closePool()
147 | # raise
148 | # else:
149 | # mpu.closePool()
150 | #
151 | # if plot_dist_histogram:
152 | # # distance histogram
153 | # plt.hist(dsl, 20)
154 | #
155 | # if plot_samples:
156 | # # plot samples
157 | # newfig()
158 | # for l in range(0, L):
159 | # G.plotg(gsl[l])
160 | # plt.show()
161 | #
162 | # # count percentage of samples below distance d to e relative to volume of d-ball
163 | # d = np.max(dsl) # distance must be smaller than fiber radius
164 | # fiber_dim = G.dim - M.dim
165 | # ball_volume = np.pi ** (fiber_dim / 2) / scipy.special.gamma(fiber_dim / 2 + 1) * d ** fiber_dim
166 | #
167 | # return ball_volume / (np.sum(dsl < d) / (dsl.size))
168 |
--------------------------------------------------------------------------------
/jaxgeometry/groups/GLN.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | from jaxgeometry.groups.group import *
24 |
25 | import matplotlib.pyplot as plt
26 |
class GLN(LieGroup):
    """ General linear group GL(N) """

    def __init__(self, N=3):
        dim = N*N  # group dimension
        LieGroup.__init__(self, dim, N=N, invariance='left')

        # project to group, here with minimum eigenvalue 1e-3
        # NOTE(review): this helper is not stored on self, so nothing in the
        # class uses it; confirm whether it should be exposed as self.to_group
        def to_group(g):
            _min_eig = 1e-3
            w, V = jnp.linalg.eig(g.astype('complex128'))
            w_prime = jnp.where(abs(w) < _min_eig, _min_eig, w)
            # reassemble with V^{-1}: eigenvectors of a general (non-symmetric)
            # matrix are not orthonormal, so V.T is not the inverse of V
            return jnp.dot(V, jnp.dot(jnp.diag(w_prime), jnp.linalg.inv(V))).real

        ## coordinate chart on the linking Lie algebra, trivial in this case
        def VtoLA(hatxi):  # from \RR^G.dim to LA
            if hatxi.ndim == 1:
                return hatxi.reshape((N, N))
            else:  # matrix
                return hatxi.reshape((N, N, -1))
        self.VtoLA = VtoLA
        def LAtoV(m):  # from LA to \RR^G.dim
            if m.ndim == 2:
                return m.reshape((self.dim,))
            elif m.ndim == 3:
                return m.reshape((self.dim, -1))
            raise ValueError('LAtoV expects a 2D or 3D array, got ndim=%d' % m.ndim)
        self.LAtoV = LAtoV

        self.Expm = jax.scipy.linalg.expm
        def logm(b):
            """ Truncated series log(I + (b - I)); only converges for b close to I. """
            I = jnp.eye(b.shape[0])
            res = jnp.zeros_like(b)
            ITERATIONS = 20
            for k in range(1, ITERATIONS):
                res += pow(-1, k+1) * jnp.linalg.matrix_power(b-I, k)/k
            return res
        self.Logm = logm

        super(GLN, self).initialize()

    def __str__(self):
        return "GL(%d) (dimension %d)" % (self.N, self.dim)

    def plot_path(self, g, color_intensity=1., color=None, linewidth=3., prevg=None):
        """ Plot a path of group elements g[0], ..., g[-1] (endpoints emphasized). """
        assert(len(g.shape) > 2)
        last = g.shape[0] - 1
        for i in range(g.shape[0]):
            endpoint = i == 0 or i == last
            self.plotg(g[i],
                       linewidth=linewidth if endpoint else .3,
                       color_intensity=color_intensity if endpoint else .7,
                       color=color,
                       prevg=g[i-1] if i > 0 else None)
        return

    def plotg(self, g, color_intensity=1., color=None, linewidth=3., prevg=None):
        """ Plot the columns of g as 3D arrows, optionally linked to the previous element.

        Assumes N == 3 (three colors and three quiver components are used).
        """
        s0 = np.eye(self.N)  # shape
        s = np.dot(g, s0)  # transformed shape
        if prevg is not None:
            prevs = np.dot(prevg, s0)

        colors = color_intensity*np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        for i in range(s.shape[1]):
            c = colors[i] if color is None else color
            plt.quiver(0, 0, 0, s[0, i], s[1, i], s[2, i], pivot='tail', linewidth=linewidth, color=c, arrow_length_ratio=.15, length=1)
            if prevg is not None:
                ss = np.stack((prevs, s))
                plt.plot(ss[:, 0, i], ss[:, 1, i], ss[:, 2, i], linewidth=.3, color=c)
93 |
--------------------------------------------------------------------------------
/jaxgeometry/groups/SON.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | from jaxgeometry.groups.group import *
24 |
25 | import matplotlib.pyplot as plt
26 |
class SON(LieGroup):
    """ Special Orthogonal Group SO(N) """

    def __init__(self, N=3, invariance='left'):
        dim = N*(N-1)//2  # group dimension
        LieGroup.__init__(self, dim, N, invariance=invariance)

        self.injectivity_radius = 2*jnp.pi

        # project to group (here using QR factorization)
        # NOTE(review): this helper is not stored on self, so nothing in the
        # class uses it; confirm whether it should be exposed as self.to_group
        def to_group(g):
            (q, r) = jnp.linalg.qr(g)
            # only flip column signs so R gets a positive diagonal; scaling by
            # diag(r) itself would destroy the orthogonality of q
            return jnp.dot(q, jnp.diag(jnp.sign(jnp.diag(r))))

        ## coordinate chart linking the Lie algebra LA={A\in\RR^{NxN}|A^T=-A}
        ## (skew-symmetric matrices) and V=\RR^G_dim
        # derived from https://stackoverflow.com/questions/25326462/initializing-a-symmetric-theano-dmatrix-from-its-upper-triangle
        r = jnp.arange(N)
        tmp_mat = r[jnp.newaxis, :] + ((N * (N - 3)) // 2-(r * (r - 1)) // 2)[::-1, jnp.newaxis]
        # strictly upper triangular entries index (1-based) into hatxi; 0 selects the padding zero
        triu_index_matrix = jnp.triu(tmp_mat+1)-jnp.diag(jnp.diagonal(tmp_mat+1))

        def VtoLA(hatxi):  # from \RR^G_dim to LA
            if hatxi.ndim == 1:
                m = jnp.concatenate((jnp.zeros(1), hatxi))[triu_index_matrix]
                return m-m.T
            else:  # matrix
                m = jnp.concatenate((jnp.zeros((1, hatxi.shape[1])), hatxi))[triu_index_matrix, :]
                return m-m.transpose((1, 0, 2))
        self.VtoLA = VtoLA
        self.LAtoV = lambda m: m[np.triu_indices(N, 1)]

        def Expm(g):  # hardcoded for skew symmetric matrices to allow higher-order gradients
            # 1j*g is Hermitian when g is skew-symmetric, so eigh applies
            (w, V) = jnp.linalg.eigh(1.j*g)
            w = -1j*w
            expm = jnp.real(jnp.tensordot(V, jnp.tensordot(jnp.diag(jnp.exp(w)), jnp.conj(V.T), (1, 0)), (1, 0)))
            return expm
        self.Expm = Expm
        def logm(b):
            """ Truncated series log(I + (b - I)); only converges for b close to I. """
            I = jnp.eye(b.shape[0])
            res = jnp.zeros_like(b)
            ITERATIONS = 20
            for k in range(1, ITERATIONS):
                res += pow(-1, k+1) * jnp.linalg.matrix_power(b-I, k)/k
            return res
        self.Logm = logm

        super(SON, self).initialize()

    def __str__(self):
        return "SO(%d) (dimension %d)" % (self.N, self.dim)

    def newfig(self):
        newfig3d()

    ### plotting
    import matplotlib.pyplot as plt
    def plot_path(self, g, color_intensity=1., color=None, linewidth=3., alpha=1., prevg=None):
        """ Plot a path of rotations g[0], ..., g[-1] (endpoints emphasized). """
        assert(len(g.shape) > 2)
        last = g.shape[0] - 1
        for i in range(g.shape[0]):
            endpoint = i == 0 or i == last
            self.plotg(g[i],
                       linewidth=linewidth if endpoint else .3,
                       color_intensity=color_intensity if endpoint else .7,
                       color=color,
                       alpha=alpha,
                       prevg=g[i-1] if i > 0 else None)
        return

    def plotg(self, g, color_intensity=1., color=None, linewidth=3., alpha=1., prevg=None):
        """ Plot the columns of the 3x3 rotation g as arrows on the current 3D axes.

        When prevg is given, thin lines connect the normalized columns of
        prevg and g.
        """
        # Grid Settings:
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = 0.3
        ax.yaxis._axinfo["grid"]['linewidth'] = 0.3
        ax.zaxis._axinfo["grid"]['linewidth'] = 0.3
        ax.set_xlim(-1., 1.)
        ax.set_ylim(-1., 1.)
        ax.set_zlim(-1., 1.)

        s0 = jnp.eye(3)  # shape
        s = jnp.dot(g, s0)  # rotated shape
        if prevg is not None:
            prevs = jnp.dot(prevg, s0)

        colors = color_intensity*np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        for i in range(s.shape[1]):
            c = colors[i] if color is None else color
            plt.quiver(0, 0, 0, s[0, i], s[1, i], s[2, i], pivot='tail', linewidth=linewidth, color=c, arrow_length_ratio=.15, length=1, alpha=alpha)
            if prevg is not None:
                ss = jnp.stack((prevs, s))
                ss = ss/jnp.linalg.norm(ss, axis=1)[:, jnp.newaxis, :]
                plt.plot(ss[:, 0, i], ss[:, 1, i], ss[:, 2, i], linewidth=1, color=c)
119 |
120 |
--------------------------------------------------------------------------------
/jaxgeometry/groups/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ComputationalEvolutionaryMorphometry/jaxgeometry/70e22f05966d5076adab37d9579a4900f486485f/jaxgeometry/groups/__init__.py
--------------------------------------------------------------------------------
/jaxgeometry/groups/group.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <https://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | from jaxgeometry.manifolds.manifold import *
24 |
class LieGroup(EmbeddedManifold):
    """ Base Lie Group class

    Matrix Lie group of N x N matrices. Sub-classes define dim, VtoLA, LAtoV,
    Expm and Logm and then call initialize() to build the group operations,
    translations, actions and invariance maps.
    """

    def __init__(self,dim,N,invariance='left'):
        EmbeddedManifold.__init__(self)

        self.dim = dim
        self.N = N # N in SO(N)
        self.emb_dim = N*N # matrix/embedding space dimension
        self.invariance = invariance # 'left'; any other value selects right invariance

        self.e = jnp.eye(N,N) # identity element
        self.zeroLA = jnp.zeros((N,N)) # zero element in LA
        self.zeroV = jnp.zeros((self.dim,)) # zero element in V

    def initialize(self):
        r""" Initial group operations. To be called by sub-classes after definition of dimension, Expm etc.

        Notation:
        hatxi # \RR^G_dim vector
        xi # matrix in LA
        eta # matrix in LA
        alpha # matrix in LA^*
        beta # matrix in LA^*
        g # \RR^{NxN} matrix
        gs # sequence of \RR^{NxN} matrices
        h # \RR^{NxN} matrix
        vg # \RR^{NxN} tangent vector at g
        wg # \RR^{NxN} tangent vector at g
        vh # \RR^{NxN} tangent vector at h
        w # \RR^G_dim tangent vector in coordinates
        v # \RR^G_dim tangent vector in coordinates
        pg # \RR^{NxN} cotangent vector at g
        ph # \RR^{NxN} cotangent vector at h
        p # \RR^G_dim cotangent vector in coordinates
        pp # \RR^G_dim cotangent vector in coordinates
        mu # \RR^G_dim LA cotangent vector in coordinates
        """

        ## group operations
        self.inv = lambda g: jnp.linalg.inv(g)

        ## group exp/log maps
        self.exp = self.Expm
        def expt(xi,_dts=None):
            """ exp(t*xi) along a time discretisation _dts (default: dts()).

            Returns (final time, stacked exp(t*xi)); each stacked entry is
            evaluated at the time before the corresponding increment dt is
            added, so the first entry is exp(0).
            """
            if _dts is None: _dts = dts()
            return lax.scan(lambda t,dt: (t+dt,self.exp(t*xi)),0.,_dts)
        self.expt = expt
        self.log = self.Logm

        ## Lie algebra
        self.eiV = jnp.eye(self.dim) # standard basis for V
        self.eiLA = self.VtoLA(self.eiV) # pushforward eiV basis for LA, indexed as eiLA[:,:,i]
        def bracket(xi,eta):
            """ Matrix commutator [xi,eta] = xi eta - eta xi.

            Accepts two matrices of shape (N,N), or two stacks of shape (N,N,k)
            and (N,N,l), in which case result[:,:,a,b] = [xi_a,eta_b] and the
            result has shape (N,N,k,l).
            """
            if xi.ndim == 2 and eta.ndim == 2:
                return jnp.tensordot(xi,eta,(1,0))-jnp.tensordot(eta,xi,(1,0))
            elif xi.ndim == 3 and eta.ndim == 3:
                # tensordot(xi,eta) has axes (N,k,N,l) and tensordot(eta,xi) has (N,l,N,k);
                # both are moved to (N,N,k,l). JAX arrays have no Theano-style .dimshuffle.
                xieta = jnp.transpose(jnp.tensordot(xi,eta,(1,0)),(0,2,1,3))
                etaxi = jnp.transpose(jnp.tensordot(eta,xi,(1,0)),(0,2,3,1))
                return xieta-etaxi
            raise ValueError("bracket expects two (N,N) matrices or two (N,N,k) stacks, got ndim %d and %d" % (xi.ndim,eta.ndim))
        self.bracket = bracket
        # structure constants: C[i,j] are the eiLA-coordinates of [e_i,e_j], found by least squares
        self.C = jnp.zeros((self.dim,self.dim,self.dim))
        for i in range(self.dim):
            for j in range(self.dim):
                xij = self.bracket(self.eiLA[:,:,i],self.eiLA[:,:,j])
                self.C = self.C.at[i,j].set(jnp.linalg.lstsq(
                    self.eiLA.reshape(self.N*self.N, self.dim),
                    xij.reshape(self.N*self.N)
                    )[0])

        ## surjective mapping \psi:\RR^G_dim\rightarrow G
        self.psi = lambda hatxi: self.exp(self.VtoLA(hatxi))
        self.invpsi = lambda g: self.LAtoV(self.log(g))
        def dpsi(hatxi,v=None):
            """ Jacobian of psi at hatxi, contracted with v when v is given. """
            dpsi = jax.jacrev(self.psi)(hatxi)
            # explicit None test: the truth value of a multi-element array is ambiguous
            if v is not None:
                return jnp.tensordot(dpsi,v,(2,0))
            return dpsi
        self.dpsi = dpsi
        def dinvpsi(g,vg=None):
            """ Jacobian of invpsi at g, contracted with vg when vg is given. """
            dinvpsi = jax.jacrev(self.invpsi)(g)
            if vg is not None:
                return jnp.tensordot(dinvpsi,vg,((1,2),(0,1)))
            return dinvpsi
        self.dinvpsi = dinvpsi

        ## left/right translation
        self.L = lambda g,h: jnp.tensordot(g,h,(1,0)) # left translation L_g(h)=gh
        self.R = lambda g,h: jnp.tensordot(h,g,(1,0)) # right translation R_g(h)=hg
        # pushforward of L/R of vh\in T_hG
        def dL(g,h,vh=None):
            """ Jacobian of L_g at h, shape (N,N,N,N); applied to vh when given. """
            dL = jax.jacrev(self.L,1)(g,h)
            if vh is not None:
                return jnp.tensordot(dL,vh,((2,3),(0,1)))
            return dL
        self.dL = dL
        def dR(g,h,vh=None):
            """ Jacobian of R_g at h, shape (N,N,N,N); applied to vh when given. """
            dR = jax.jacrev(self.R,1)(g,h)
            if vh is not None:
                return jnp.tensordot(dR,vh,((2,3),(0,1)))
            return dR
        self.dR = dR
        # pullback of L/R of vh\in T_h^*G
        # NOTE(review): this transposes the pushed-forward matrix (e.g. (g vh)^T for L);
        # under the Frobenius pairing the pullback would be g^T vh -- confirm the intended convention.
        self.codL = lambda g,h,vh: self.dL(g,h,vh).T
        self.codR = lambda g,h,vh: self.dR(g,h,vh).T

        ## actions
        self.Ad = lambda g,xi: self.dR(self.inv(g),g,self.dL(g,self.e,xi)) # g xi g^{-1}
        self.ad = lambda xi,eta: self.bracket(xi,eta)
        # coad(v,p)_k = sum_{i,j} C[i,j,k] v_i p_j
        self.coad = lambda v,p: jnp.tensordot(jnp.tensordot(self.C,v,(0,0)),p,(1,0))

        ## invariance
        if self.invariance == 'left':
            self.invtrns = self.L # invariance translation
            self.invpb = lambda g,vg: self.dL(self.inv(g),g,vg) # left invariance pullback from TgG to LA
            self.invpf = lambda g,xi: self.dL(g,self.e,xi) # left invariance pushforward from LA to TgG
            self.invcopb = lambda g,pg: self.codL(self.inv(g),g,pg) # left invariance pullback from Tg^*G to LA^*
            self.invcopf = lambda g,alpha: self.codL(g,self.e,alpha) # left invariance pushforward from LA^* to Tg^*G
            self.infgen = lambda xi,g: self.dR(g,self.e,xi) # infinitesimal generator
        else:
            self.invtrns = self.R # invariance translation
            self.invpb = lambda g,vg: self.dR(self.inv(g),g,vg) # right invariance pullback from TgG to LA
            self.invpf = lambda g,xi: self.dR(g,self.e,xi) # right invariance pushforward from LA to TgG
            self.invcopb = lambda g,pg: self.codR(self.inv(g),g,pg) # right invariance pullback from Tg^*G to LA^*
            self.invcopf = lambda g,alpha: self.codR(g,self.e,alpha) # right invariance pushforward from LA^* to Tg^*G
            self.infgen = lambda xi,g: self.dL(g,self.e,xi) # infinitesimal generator

    def __str__(self):
        return "abstract Lie group"
165 |
166 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/Euclidean.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <https://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | from jaxgeometry.manifolds.manifold import *
24 |
25 | from jaxgeometry.plotting import *
26 | import matplotlib.pyplot as plt
27 |
class Euclidean(Manifold):
    """ Euclidean space R^N with the identity metric and a single global chart """

    def __init__(self,N=3):
        Manifold.__init__(self)
        self.dim = N

        # single chart: coordinate updates are the identity
        self.update_coords = lambda coords,_: coords

        ##### Metric:
        self.g = lambda x: jnp.eye(self.dim)

        # action of matrix group on elements
        self.act = lambda g,x: jnp.tensordot(g,x,(1,0))

    def __str__(self):
        return "Euclidean manifold of dimension %d" % (self.dim)

    def newfig(self):
        if self.dim == 2:
            newfig2d()
        elif self.dim == 3:
            newfig3d()

    def plot(self):
        if self.dim == 2:
            plt.axis('equal')

    def plot_path(self, xs, u=None, color='b', color_intensity=1., linewidth=1., prevx=None, last=True, s=20, arrowcolor='k'):
        """ Plot the points xs as a connected path.

        The arrow u (if given) is drawn at the first point only, in arrowcolor.
        The first point is connected to prevx when given. The final point is
        marked with a dot unless last is False.
        """
        xs = list(xs)
        N = len(xs)
        for i,x in enumerate(xs):
            self.plotx(x, u=u if i == 0 else None,
                       color=color,
                       color_intensity=color_intensity if i==0 or i==N-1 else .7,
                       linewidth=linewidth,
                       s=s,
                       prevx=prevx,
                       last=last and i==N-1,
                       arrowcolor=arrowcolor)
            prevx = x
        return

    def plotx(self, x, u=None, color='b', color_intensity=1., linewidth=1., prevx=None, last=True, s=20, arrowcolor='k'):
        """ Plot a single point x (array or coordinate tuple) for dim 2 or 3.

        A dot is drawn when x is the last point or there is no previous point;
        otherwise a segment from prevx to x is drawn. In 2d an arrow u is drawn
        at x when given (3d arrows are not drawn). color_intensity is currently unused.
        """
        if isinstance(x, tuple):
            x = x[0]
        elif x.shape[0] != self.dim:
            raise ValueError("point has dimension %d, expected %d" % (x.shape[0],self.dim))
        if isinstance(prevx, tuple):
            prevx = prevx[0]

        ax = plt.gca()

        if last or prevx is None:
            if self.dim == 2:
                plt.scatter(x[0],x[1],color=color,s=s)
            elif self.dim == 3:
                ax.scatter(x[0],x[1],x[2],color=color,s=s)
        else:
            xx = np.stack((prevx,x))
            if self.dim == 2:
                plt.plot(xx[:,0],xx[:,1],linewidth=linewidth,color=color)
            elif self.dim == 3:
                ax.plot(xx[:,0],xx[:,1],xx[:,2],linewidth=linewidth,color=color)

        if u is not None and self.dim == 2:
            plt.quiver(x[0], x[1], u[0], u[1], pivot='tail', linewidth=linewidth, scale=5, color=arrowcolor)
102 |
103 |
104 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/H2.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <https://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.params import *
23 |
24 | from jaxgeometry.manifolds.ellipsoid import *
25 |
26 | import matplotlib.pyplot as plt
27 | from mpl_toolkits.mplot3d import Axes3D
28 | from matplotlib import cm
29 | import matplotlib.ticker as ticker
30 |
class H2(EmbeddedManifold):
    """ hyperbolic plane

    Embedded as the sheet (cosh r, sinh r cos a, sinh r sin a) of the
    hyperboloid, with the metric pulled back from the Minkowski form diag(-1,1,1).
    """

    def __init__(self,):
        # coordinates (r, a) -> point on the hyperboloid, and its inverse
        F = lambda x: jnp.stack([jnp.cosh(x[0][0]),jnp.sinh(x[0][0])*jnp.cos(x[0][1]),jnp.sinh(x[0][0])*jnp.sin(x[0][1])])
        invF = lambda x: jnp.stack([jnp.arccosh(x[0][0]),jnp.arctan2(x[0][2],x[0][1])])
        self.do_chart_update = lambda x: False # a single chart is used

        EmbeddedManifold.__init__(self,F,2,3,invF=invF)

        # metric matrix from embedding into Minkowski space
        self.g = lambda x: jnp.einsum('ji,j,jl',self.JF(x),np.array([-1.,1.,1.]),self.JF(x))

    def __str__(self):
        return "%dd dim hyperbolic space" % (self.dim,)

    def newfig(self):
        newfig3d()

    def _setup_axes(self, lw):
        """ Thin grid lines, axis limits covering the sampled sheet (x = cosh r >= 1), axis labels. """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(1.,2.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        plt.xlabel('x')
        plt.ylabel('y')
        return ax

    def plot(self,rotate=None,alpha=None,lw=0.3):
        """ Draw the hyperboloid as a wireframe (and a surface if alpha is given). """
        ax = self._setup_axes(lw)
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])

        # draw hyperboloid
        u, v = np.mgrid[-1.5:1.5:20j, 0:2*np.pi:20j]
        x=np.cosh(u)
        y=np.sinh(u)*np.cos(v)
        z=np.sinh(u)*np.sin(v)
        ax.plot_wireframe(x, y, z, color='gray', alpha=0.5)

        if alpha is not None:
            ax.plot_surface(x, y, z, color=cm.jet(0.), alpha=alpha)

    def plot_field(self, field,lw=.3):
        """ Evaluate field on a grid of points of the hyperboloid and plot it with plotx.

        field is called with a coordinate tuple (coords, chart); its value is
        passed to plotx as v.
        """
        # plt.gca() no longer accepts projection kwargs (Matplotlib >= 3.6)
        self._setup_axes(lw)

        # sample in the same (r, a) parametrisation as F: r in [0,1.5], a around the full circle
        u, a = np.mgrid[0.:1.5:10j, 0:2*np.pi:20j]
        x=np.cosh(u)
        y=np.sinh(u)*np.cos(a)
        z=np.sinh(u)*np.sin(a)

        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                Fx = np.array([x[i,j],y[i,j],z[i,j]])
                chart = self.centered_chart(Fx)
                xcoord = self.invF((Fx,chart))
                vec = field((xcoord,chart))
                self.plotx((xcoord,chart),v=vec)
108 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/Heisenberg.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <https://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.params import *
23 |
24 | from jaxgeometry.manifolds.manifold import *
25 |
26 | import matplotlib.pyplot as plt
27 | from mpl_toolkits.mplot3d import Axes3D
28 | from matplotlib import cm
29 | import matplotlib.ticker as ticker
30 |
class Heisenberg(Manifold):
    """ Heisenberg group

    Three-dimensional manifold (dim = 3) with a rank-2 distribution D (sR_dim = 2).
    """

    def __init__(self,):
        Manifold.__init__(self)

        self.dim = 3
        self.sR_dim = 2

        self.update_coords = lambda coords,chart: (coords[0],chart)

        ##### (orthonormal) distribution, returned as a 3 x 2 matrix of column vectors
        self.D = lambda x: jnp.array([[1,0,-x[0][1]/2],[0,1,x[0][0]/2]]).T

    def __str__(self):
        return "Heisenberg group"

    def plot(self):
        """ No-op: there is no background surface to draw. """
        pass

    def plot_path(self, xs, vs=None, v_steps=None, i0=0, color='b',
                  color_intensity=1., linewidth=1., s=15., prevx=None, prevchart=None, last=True):
        """ Plot the points xs as a connected path with plotx.

        vs[i] (if given) is the vector at step i; it is drawn at the steps in
        v_steps. A user-supplied v_steps is respected (it used to be overwritten);
        when it is None, plotx picks its default. The first point is connected
        to prevx when given.
        """
        xs = list(xs)
        N = len(xs)
        for i,x in enumerate(xs):
            self.plotx(x, v=vs[i] if vs is not None else None,
                       v_steps=v_steps,i=i,
                       color=color,
                       color_intensity=color_intensity if i==0 or i==N-1 else .7,
                       linewidth=linewidth,
                       s=s,
                       prevx=prevx,
                       last=i==(N-1))
            prevx = x
        return

    # plot x in coordinates
    def plotx(self, x, u=None, v=None, v_steps=None, i=0, color='b',
              color_intensity=1., linewidth=1., s=15., prevx=None, prevchart=None, last=True):
        """ Plot a single point x (array or coordinate tuple) in 3d coordinates.

        A dot is drawn for the first/last point, a segment from prevx otherwise.
        u is drawn as an arrow at x; v is drawn only when step i is in v_steps
        (default: every 10th of n_steps).
        """
        if not isinstance(x, tuple):
            x = (x,)
        if prevx is not None and not isinstance(prevx, tuple):
            prevx = (prevx,)

        if v is not None and v_steps is None:
            v_steps = np.arange(0,n_steps,10)

        ax = plt.gca()
        if prevx is None or last:
            ax.scatter(x[0][0],x[0][1],x[0][2],color=color,s=s)
        if prevx is not None:
            xx = np.stack((prevx[0],x[0]))
            ax.plot(xx[:,0],xx[:,1],xx[:,2],linewidth=linewidth,color=color)

        if u is not None:
            ax.quiver(x[0][0], x[0][1], x[0][2], u[0], u[1], u[2],
                      pivot='tail',
                      arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                      color='black')

        if v is not None:
            if i in v_steps:
                ax.quiver(x[0][0], x[0][1], x[0][2], v[0], v[1], v[2],
                          pivot='tail',
                          arrow_length_ratio = 0.15, linewidths=linewidth, length=1.0,
                          color='black')
103 |
# function to evaluate spherical harmonics
def Y(l,m,theta,phi):
    """ Complex spherical harmonic Y_l^m(theta, phi) for l <= 2.

    Orthonormal physics convention with the Condon-Shortley phase: theta is the
    polar angle, phi the azimuth, Y_l^m is proportional to exp(1j*m*phi), and
    Y_l^{-m} = (-1)^m conj(Y_l^m). The previous table was not normalised for
    |m| = l >= 1 and used exp(-1j*m*phi).

    theta and phi may be scalars or arrays; Y_0^0 is broadcast to the shape of
    theta. Returns 0 for (l, m) not in the table (|m| > l or l > 2).
    """
    if l == 0 and m == 0:
        return np.full(np.shape(theta), 1/np.sqrt(4*np.pi))
    elif l == 1 and m == -1:
        return np.sqrt(3/(8*np.pi))*np.sin(theta)*np.exp(-1j*phi)
    elif l == 1 and m == 0:
        return np.sqrt(3/(4*np.pi))*np.cos(theta)
    elif l == 1 and m == 1:
        return -np.sqrt(3/(8*np.pi))*np.sin(theta)*np.exp(1j*phi)
    elif l == 2 and m == -2:
        return np.sqrt(15/(32*np.pi))*np.sin(theta)**2*np.exp(-2j*phi)
    elif l == 2 and m == -1:
        return np.sqrt(15/(8*np.pi))*np.sin(theta)*np.cos(theta)*np.exp(-1j*phi)
    elif l == 2 and m == 0:
        return np.sqrt(5/(16*np.pi))*(3*np.cos(theta)**2-1)
    elif l == 2 and m == 1:
        return -np.sqrt(15/(8*np.pi))*np.sin(theta)*np.cos(theta)*np.exp(1j*phi)
    elif l == 2 and m == 2:
        return np.sqrt(15/(32*np.pi))*np.sin(theta)**2*np.exp(2j*phi)
    else:
        return 0
126 |
127 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/S2.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.params import *
23 |
24 | from jaxgeometry.manifolds.ellipsoid import *
25 |
26 | import matplotlib.pyplot as plt
27 | from mpl_toolkits.mplot3d import Axes3D
28 | from matplotlib import cm
29 | import matplotlib.ticker as ticker
30 |
class S2(Ellipsoid):
    """ 2d Sphere: the Ellipsoid whose three axis parameters all equal 1 """

    def __init__(self,use_spherical_coords=False,chart_center='z'):
        unit_axes = [1.,1.,1.]
        Ellipsoid.__init__(self,
                           params=unit_axes,
                           chart_center=chart_center,
                           use_spherical_coords=use_spherical_coords)

    def __str__(self):
        fields = (self.dim, self.params, self.use_spherical_coords)
        return "%dd sphere (ellipsoid parameters %s, spherical_coords: %s)" % fields
39 |
40 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/SPDN.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <https://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | from jaxgeometry.manifolds.manifold import *
24 |
25 | from jaxgeometry.plotting import *
26 | import matplotlib.pyplot as plt
27 | from matplotlib import cm
28 |
class SPDN(EmbeddedManifold):
    """ manifold of symmetric positive definite matrices

    Points are flattened N x N matrices (emb_dim = N*N); dim = N*(N+1)/2.
    """

    def __init__(self,N=3):
        self.N = N
        dim = N*(N+1)//2
        emb_dim = N*N
        EmbeddedManifold.__init__(self,dim=dim,emb_dim=emb_dim)

        # congruence action g.q = g q g^T on a flattened matrix q; result is flattened
        self.act = lambda g,q: jnp.tensordot(g,jnp.tensordot(q.reshape((N,N)),g,(1,1)),(1,0)).flatten()
        # same action for a batch of group elements (leading axis of g) on one q
        self.acts = jax.vmap(self.act,(0,None))

    def __str__(self):
        return "SPDN(%d), dim %d" % (self.N,self.dim)

    def plot(self, rotate=None, alpha = None):
        """ Label the current axes; rotate=(elev, azim) sets the 3d view. alpha is unused. """
        ax = plt.gca()
        # `is not None`: `!=` on a tuple/array of angles is not a None test
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])
        plt.xlabel('x')
        plt.ylabel('y')

    def plot_path(self, x,color_intensity=1.,color=None,linewidth=3.,prevx=None,ellipsoid=None,i=None,maxi=None):
        """ Plot a sequence x of SPD matrices (leading axis = time) with plotx.

        Endpoints use the given linewidth/color_intensity, intermediate steps
        thinner (.3) and dimmer (.7) ones. color is forwarded to every step and
        the first step is linked to prevx when given. The i/maxi arguments are
        ignored; the step index and path length are passed to plotx instead.
        """
        if len(x.shape) <= 1:
            raise ValueError("plot_path expects a sequence of matrices, got shape %s" % (x.shape,))
        n = x.shape[0]
        for k in range(n):
            endpoint = k == 0 or k == n-1
            self.plotx(x[k],
                       linewidth=linewidth if endpoint else .3,
                       color_intensity=color_intensity if endpoint else .7,
                       color=color,
                       prevx=x[k-1] if k > 0 else prevx,
                       ellipsoid=ellipsoid,i=k,maxi=n)
        return

    def plotx(self, x,color_intensity=1.,color=None,linewidth=3.,prevx=None,ellipsoid=None,i=None,maxi=None):
        """ Plot one SPD matrix x (flattened N x N, N = 3 for plotting).

        Without ellipsoid: arrows along the eigenvectors scaled by sqrt of the
        eigenvalues, plus thin segments from the previous matrix's arrows.
        With ellipsoid (a dict with 'alpha' and optional 'step', 'subplot'):
        the ellipsoid x^{1/2}(unit sphere) in a new 3d figure (or subplot),
        drawn only every 'step'-th index i and for the last one (i == maxi-1).
        """
        x = x.reshape((self.N,self.N))
        (w,V) = np.linalg.eigh(x)
        s = np.sqrt(w[np.newaxis,:])*V # scaled eigenvectors
        colors = color_intensity*np.array([[1,0,0],[0,1,0],[0,0,1]])

        if ellipsoid is None:
            if prevx is not None:
                prevx = prevx.reshape((self.N,self.N))
                (prevw,prevV) = np.linalg.eigh(prevx)
                prevs = np.sqrt(prevw[np.newaxis,:])*prevV # scaled eigenvectors
                trail = np.stack((prevs,s))
            for k in range(s.shape[1]):
                c = colors[k] if color is None else color
                plt.quiver(0,0,0,s[0,k],s[1,k],s[2,k],pivot='tail',linewidth=linewidth,color=c,arrow_length_ratio=.15,length=1)
                if prevx is not None:
                    plt.plot(trail[:,0,k],trail[:,1,k],trail[:,2,k],linewidth=.3,color=c)
            return

        step = ellipsoid.get('step')
        step = int(step) if step is not None else None
        indexed = bool(step) and i is not None and maxi is not None
        # skip intermediate steps that are not a multiple of step (the last one is always drawn)
        if indexed and i % step != 0 and i != maxi-1:
            return
        # a falsy 'subplot' used to leave ax unassigned; fall back to a fresh figure instead
        if indexed and ellipsoid.get('subplot', False):
            (fig,ax) = newfig3d(1,maxi//step+1,i//step+1,new_figure=i==0)
        else:
            (fig,ax) = newfig3d()

        # draw ellipsoid, from https://stackoverflow.com/questions/7819498/plotting-ellipsoid-with-matplotlib
        _, sing, rotation = np.linalg.svd(x)
        radii = np.sqrt(sing)
        u = np.linspace(0., 2.*np.pi, 20)
        v = np.linspace(0., np.pi, 10)
        pts = np.stack((radii[0]*np.outer(np.cos(u),np.sin(v)),
                        radii[1]*np.outer(np.sin(u),np.sin(v)),
                        radii[2]*np.outer(np.ones_like(u),np.cos(v))),axis=-1)
        pts = pts @ rotation # rotate every surface point at once
        ex, ey, ez = pts[...,0], pts[...,1], pts[...,2]
        ax.plot_surface(ex, ey, ez, facecolors=cm.winter(ey/np.amax(ey)), linewidth=0, alpha=ellipsoid['alpha'])
        for k in range(s.shape[1]):
            plt.quiver(0,0,0,s[0,k],s[1,k],s[2,k],pivot='tail',linewidth=linewidth,color=colors[k] if color is None else color,arrow_length_ratio=.15,length=1)
        plt.axis('off')
107 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/cylinder.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <https://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.params import *
23 |
24 | from jaxgeometry.manifolds.manifold import *
25 |
26 | from jaxgeometry.plotting import *
27 | import matplotlib.pyplot as plt
28 | from mpl_toolkits.mplot3d import Axes3D
29 | from matplotlib import cm
30 | import matplotlib.ticker as ticker
31 |
class Cylinder(EmbeddedManifold):
    """ 2d Cylinder in R^3

    params = (radius, axis direction, angle theta of rotation around the axis).
    Coordinates are (angle, height) relative to a chart (angle offset, height offset).
    """

    def chart(self):
        """ return default coordinate chart """
        return jnp.zeros(self.dim)

    def centered_chart(self,x):
        """ return centered coordinate chart """
        if type(x) == type(()): # coordinate tuple
            Fx = jax.lax.stop_gradient(self.F(x))
        else:
            Fx = x # already in embedding space
        return self.invF((Fx,self.chart())) # chart centered at coords

    def get_B(self,v):
        """ Orthonormal R^3 basis (columns) whose first vector is v/|v|.

        The second vector is e_k for the smallest |v_k|, made orthogonal to v
        and normalised; the third is their cross product. Without the
        normalisation the cross-section radius was scaled by sqrt(1-v_k^2)
        for axes with no zero component.
        """
        b1 = v/jnp.linalg.norm(v)
        k = jnp.argmin(jnp.abs(b1))
        ek = jnp.eye(3)[:,k]
        b2 = ek-b1[k]*b1
        b2 = b2/jnp.linalg.norm(b2)
        b3 = cross(b1,b2)
        return jnp.stack((b1,b2,b3),axis=1)

    # Logarithm with standard Riemannian metric
    def StdLog(self,_x,y):
        (x,chart) = self.update_coords(_x,self.centered_chart(self.F(_x)))
        y = self.invF((y,chart))
        return self.update_vector((x,chart),_x[0],_x[1],y-x)

    def __init__(self,params=(1.,jnp.array([0.,1.,0.]),0.)):
        self.radius = params[0] # radius of cylinder
        self.orientation = jnp.array(params[1]) # axis of cylinder
        self.theta = params[2] # angle around rotation axis

        F = lambda x: jnp.dot(self.get_B(self.orientation),
                              jnp.stack([x[0][1]+x[1][1],self.radius*jnp.cos(self.theta+x[1][0]+x[0][0]),self.radius*jnp.sin(self.theta+x[1][0]+x[0][0])]))
        def invF(x):
            Rinvx = jnp.linalg.solve(self.get_B(self.orientation),x[0])
            rotangle = -(self.theta+x[1][0])
            rot = jnp.dot(jnp.stack(
                (jnp.stack((jnp.cos(rotangle),-jnp.sin(rotangle))),
                 jnp.stack((jnp.sin(rotangle),jnp.cos(rotangle))))),
                Rinvx[1:])
            return jnp.stack([jnp.arctan2(rot[1],rot[0]),Rinvx[0]-x[1][1]])
        self.do_chart_update = lambda x: jnp.max(jnp.abs(x[0])) >= np.pi/4 # look for a new chart if true

        EmbeddedManifold.__init__(self,F,2,3,invF=invF)

    def __str__(self):
        return "cylinder in R^3, radius %s, axis %s, rotation around axis %s" % (self.radius,self.orientation,self.theta)

    def newfig(self):
        newfig3d()

    def _setup_axes(self, lw):
        """ Thin grid lines, unit axis limits and x/y labels on the current axes. """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        plt.xlabel('x')
        plt.ylabel('y')
        return ax

    def _embedded_grid(self, u, v):
        """ Map a grid of default-chart coordinates (u, v) to embedding arrays x, y, z. """
        x = np.zeros(u.shape)
        y = np.zeros(u.shape)
        z = np.zeros(u.shape)
        for i in range(u.shape[0]):
            for j in range(u.shape[1]):
                w = self.F(self.coords(jnp.array([u[i,j],v[i,j]])))
                x[i,j], y[i,j], z[i,j] = w[0], w[1], w[2]
        return x, y, z

    def plot(self, rotate=None,alpha=None,lw=0.3):
        """ Draw the cylinder as a wireframe (and a surface if alpha is given). """
        ax = self._setup_axes(lw)
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])

        # draw cylinder
        u, v = np.mgrid[-np.pi:np.pi:20j, -np.pi:np.pi:10j]
        x, y, z = self._embedded_grid(u, v)
        ax.plot_wireframe(x, y, z, color='gray', alpha=0.5)

        if alpha is not None:
            ax.plot_surface(x, y, z, color=cm.jet(0.), alpha=alpha)

    def plot_field(self, field,lw=.3):
        """ Evaluate field on a grid of points of the cylinder and plot it with plotx.

        field is called with a coordinate tuple (coords, chart); its value is
        passed to plotx as v.
        """
        self._setup_axes(lw)

        u, v = np.mgrid[-np.pi:np.pi:40j, -np.pi:np.pi:20j]
        x, y, z = self._embedded_grid(u, v)

        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                Fx = np.array([x[i,j],y[i,j],z[i,j]])
                # centered_chart/invF are the chart maps this class provides (there is no *f variant)
                chart = self.centered_chart(Fx)
                xcoord = self.invF((Fx,chart))
                vec = field((xcoord,chart))
                self.plotx((xcoord,chart),v=vec)
150 |
151 |
152 |
153 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/ellipsoid.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | from jaxgeometry.manifolds.manifold import *
24 |
25 | from jaxgeometry.plotting import *
26 | import matplotlib.pyplot as plt
27 | from mpl_toolkits.mplot3d import Axes3D
28 | from matplotlib import cm
29 | import matplotlib.ticker as ticker
30 |
class Ellipsoid(EmbeddedManifold):
    """ 2d ellipsoid embedded in R^3

    The ellipsoid is the unit sphere scaled componentwise by ``params``.
    Charts are stereographic projections of the sphere. A chart is a vector in
    R^3 used (through get_B) as the first basis vector of the frame in which
    the projection is taken, so coordinates (0,0) map to the chart direction.
    """

    def chart(self):
        """ return default coordinate chart: the unit vector along the axis named by chart_center

        Raises ValueError if chart_center is not one of 'x', 'y', 'z'.
        """
        axes = {'x': 0, 'y': 1, 'z': 2}
        if self.chart_center not in axes:
            raise ValueError("chart_center must be one of 'x', 'y', 'z', got %r" % (self.chart_center,))
        return jnp.eye(3)[:,axes[self.chart_center]]

    def centered_chart(self,x):
        """ return the chart centered at x, i.e. x mapped back to the unit sphere

        x: either a (coords,chart) tuple or a point already in the embedding space R^3.
        Gradients are not propagated through the chart choice for coordinate tuples.
        """
        if isinstance(x,tuple): # coordinate tuple
            return jax.lax.stop_gradient(self.F(x))/self.params
        return x/self.params # already in embedding space

    def get_B(self,v):
        """ R^3 basis (as matrix columns) with first basis vector v """
        b1 = v
        k = jnp.argmin(jnp.abs(v))
        ek = jnp.eye(3)[:,k]
        b2 = ek-v[k]*v
        b3 = cross(b1,b2)
        return jnp.stack((b1,b2,b3),axis=1)

    # Logarithm with standard Riemannian metric on S^2
    def StdLogEmb(self, x,y):
        """ logarithm of the round metric on S^2, expressed in R^3

        x: (coords,chart) base point; y: target point in R^3 on the ellipsoid.
        Both points are mapped to the unit sphere by dividing by params; the
        sphere logarithm is computed there and rescaled by params.
        """
        y = y/self.params # from ellipsoid to S^2
        Fx = self.F(x)/self.params
        v = y-jnp.dot(Fx,y)*Fx # component of y orthogonal to Fx
        # clip: rounding can push |<Fx,y>| slightly above 1, which would make arccos return NaN
        theta = jnp.arccos(jnp.clip(jnp.dot(Fx,y),-1.,1.))
        normv = jnp.linalg.norm(v,2)
        w = jax.lax.cond(normv >= 1e-5,
                         lambda _: theta/normv*v,
                         lambda _: jnp.zeros_like(v),
                         None)
        return self.params*w

    def StdLog(self, x,y):
        """ StdLogEmb(x,y) expressed in the coordinates of x's chart

        invF takes points on the ellipsoid (it divides by params itself), so its
        Jacobian is evaluated at F(x), not at F(x)/params.
        """
        return jnp.dot(self.invJF((self.F(x),x[1])),self.StdLogEmb(x,y))

    def __init__(self,params=np.array([1.,1.,1.]),chart_center='z',use_spherical_coords=False):
        self.params = jnp.array(params) # ellipsoid parameters (e.g. [1.,1.,1.] for sphere)
        self.use_spherical_coords = use_spherical_coords
        self.chart_center = chart_center

        if use_spherical_coords:
            # chart computations in spherical coordinates are not implemented
            # (this previously failed with an unbound F/invF)
            raise NotImplementedError('use_spherical_coords=True is not supported; use the default stereographic charts')

        # stereographic chart: coords (a,b) -> point on S^2 in the frame get_B(chart), scaled to the ellipsoid
        def F(x):
            a, b = x[0][0], x[0][1]
            r2 = a**2+b**2
            return self.params*jnp.dot(self.get_B(x[1]),jnp.stack([1.-r2,2*a,2*b])/(1+r2))
        def invF(x):
            Rinvx = jnp.linalg.solve(self.get_B(x[1]),x[0]/self.params)
            return jnp.stack([Rinvx[1]/(1+Rinvx[0]),Rinvx[2]/(1+Rinvx[0])])
        self.do_chart_update = lambda x: jnp.linalg.norm(x[0]) > .1 # look for a new chart if true

        # spherical coordinates, no charts
        self.F_spherical = lambda phitheta: self.params*jnp.stack([jnp.sin(phitheta[1]-np.pi/2)*jnp.cos(phitheta[0]),jnp.sin(phitheta[1]-np.pi/2)*jnp.sin(phitheta[0]),jnp.cos(phitheta[1]-np.pi/2)])
        self.JF_spherical = lambda x: jax.jacfwd(self.F_spherical)(x)
        # NOTE(review): this inverse ignores params and uses arccos(z) while F_spherical uses
        # theta-pi/2; it does not appear to invert F_spherical exactly -- confirm intended convention
        self.F_spherical_inv = lambda x: jnp.stack([jnp.arctan2(x[1],x[0]),jnp.arccos(x[2])])
        self.g_spherical = lambda x: jnp.dot(self.JF_spherical(x).T,self.JF_spherical(x))
        self.mu_Q_spherical = lambda x: 1./jnp.linalg.det(self.g_spherical(x))

        EmbeddedManifold.__init__(self,F,2,3,invF=invF)

        # action of matrix group on elements
        self.act = lambda g,x: jnp.tensordot(g,x,(1,0))
        self.acts = lambda g,x: jnp.tensordot(g,x,(2,0))

    def __str__(self):
        return "%dd ellipsoid, parameters %s, spherical coords %s" % (self.dim,self.params,self.use_spherical_coords)

    def newfig(self):
        newfig3d()

    def plot(self,rotate=None,alpha=None,lw=0.3,color='gray',scale=1.):
        """ draw the ellipsoid as a wireframe on the current axes

        rotate: optional (elevation,azimuth) passed to view_init
        alpha: if given, also draw a surface with this opacity
        lw: axis grid line width; color: wireframe/surface color
        scale: additional scaling of the drawn ellipsoid
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])
        plt.xlabel('x')
        plt.ylabel('y')

        #draw ellipsoid: azimuth u in [0,2pi], polar angle v in [0,pi]
        u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
        x=scale*self.params[0]*np.cos(u)*np.sin(v)
        y=scale*self.params[1]*np.sin(u)*np.sin(v)
        z=scale*self.params[2]*np.cos(v)
        ax.plot_wireframe(x, y, z, color=color, alpha=0.5)

        if alpha is not None:
            ax.plot_surface(x, y, z, color=color, alpha=alpha)

    def plot_field(self, field,lw=.3, scale=1.):
        """ plot a vector field on the ellipsoid

        field: callable taking a (coords,chart) tuple and returning a tangent vector in those coordinates
        lw: axis grid line width; scale: factor applied to the drawn vectors
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")

        plt.xlabel('x')
        plt.ylabel('y')

        # sample points on the ellipsoid
        u, w = np.mgrid[0:2*np.pi:40j, 0:np.pi:20j]
        x=self.params[0]*np.cos(u)*np.sin(w)
        y=self.params[1]*np.sin(u)*np.sin(w)
        z=self.params[2]*np.cos(w)

        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                Fx = np.array([x[i,j],y[i,j],z[i,j]])
                # evaluate the field in the chart centered at the sample point
                chart = self.centered_chart(Fx)
                xcoord = self.invF((Fx,chart))
                self.plotx((xcoord,chart),v=scale*field((xcoord,chart)))
170 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/latent.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | from jaxgeometry.manifolds.manifold import *
24 |
25 | from jaxgeometry.plotting import *
26 | import matplotlib.pyplot as plt
27 | from mpl_toolkits.mplot3d import Axes3D
28 | from matplotlib import cm
29 | import matplotlib.ticker as ticker
30 | from scipy.stats import norm
31 |
class Latent(EmbeddedManifold):
    """ Latent space manifold defined from embedding function F:R^dim->R^emb_dim, F e.g. a neural network

    F (and invF, if given) follow the EmbeddedManifold convention of taking a
    (coords,chart) tuple. Plotting helpers also accept plain coordinate arrays,
    which are interpreted in the default chart.
    """

    def __init__(self,F,dim,emb_dim,invF=None):
        EmbeddedManifold.__init__(self,F,dim,emb_dim,invF)

        # metric matrix: pull-back of the Euclidean metric, g = JF^T JF
        self.g = lambda x: jnp.tensordot(self.JF(x),self.JF(x),(0,0))

    def _as_coords(self,x):
        """ wrap a plain coordinate array as a (coords,chart) tuple in the default chart; tuples pass through """
        return x if isinstance(x,tuple) else self.coords(x)

    def _in_embedding(self,x):
        """ True if x is a plain array already living in R^emb_dim """
        return not isinstance(x,tuple) and x.shape[0] == self.emb_dim

    def _embed(self,x):
        """ map x (tuple, coordinate array, or point in R^emb_dim) to R^emb_dim as a numpy array """
        if self._in_embedding(x):
            return np.asarray(x)
        return np.asarray(self.F(self._as_coords(x)))

    def _push_forward(self,xcoords,w):
        """ map tangent vector w at xcoords to R^emb_dim with JF; vectors already in R^emb_dim pass through

        Raises ValueError if w is a coordinate vector but the base point was only given in R^emb_dim.
        """
        if w.shape[0] == self.emb_dim:
            return np.asarray(w)
        if xcoords is None:
            raise ValueError('cannot push forward a coordinate vector at a point given in embedding space')
        return np.dot(self.JF(self._as_coords(xcoords)),w)

    @staticmethod
    def _axes3d():
        """ return the current axes if they are 3d, otherwise add a 3d subplot to the current figure

        Replaces plt.gca(projection='3d'), which newer matplotlib versions reject.
        """
        fig = plt.gcf()
        if fig.axes and getattr(fig.gca(),'name',None) == '3d':
            return fig.gca()
        return fig.add_subplot(projection='3d')

    def newfig(self):
        if self.emb_dim == 3:
            newfig3d()
        elif self.dim == 2:
            newfig2d()

    def plot(self, rotate=None, alpha=None, lw=0.3):
        """ draw the image under F of a grid in the latent plane (only when emb_dim == 3)

        rotate: optional (elevation,azimuth) passed to view_init
        alpha: if given, also draw a surface with this opacity
        lw: axis grid line width
        """
        if self.emb_dim != 3:
            return
        ax = self._axes3d()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1., 1.)
        ax.set_ylim(-1., 1.)
        ax.set_zlim(-1., 1.)
        try:
            ax.set_aspect("equal")
        except NotImplementedError:
            pass # older matplotlib versions do not support equal aspect on 3d axes
        if rotate is not None:
            ax.view_init(rotate[0], rotate[1])
        plt.xlabel('x')
        plt.ylabel('y')

        # draw surface: grid of standard-normal quantiles (5%..95%) in both latent coordinates
        q = norm.ppf(np.linspace(0.05, 0.95, 20))
        X, Y = np.meshgrid(q, q)
        xyz = np.stack([self._embed(np.array(p)) for p in zip(X.ravel(), Y.ravel())])
        x = xyz[:, 0].reshape(X.shape)
        y = xyz[:, 1].reshape(X.shape)
        z = xyz[:, 2].reshape(X.shape)
        ax.plot_wireframe(x, y, z, color='gray', alpha=0.5)

        if alpha is not None:
            ax.plot_surface(x, y, z, color=cm.jet(0.), alpha=alpha)

    # plot x on the latent manifold. x can be a (coords,chart) tuple, coordinates, or a point in R^emb_dim.
    # A 2d array is treated as a path of points.
    def plotx(self, x, u=None, v=None, N_vec=None, i0=0, color='b', color_intensity=1., linewidth=1., s=15., prevx=None, last=True):
        if N_vec is None:
            N_vec = np.arange(0,n_steps) # step indices at which v is drawn
        if not isinstance(x,tuple) and len(x.shape)>1:
            N = x.shape[0]
            for i in range(N):
                self.plotx(x[i], u=u if i == 0 else None, v=v[i] if v is not None else None,
                           N_vec=N_vec,i0=i,
                           color=color,
                           color_intensity=color_intensity if i==0 or i==N-1 else .7,
                           linewidth=linewidth,
                           s=s,
                           prevx=x[i-1] if i>0 else None,
                           last=i==(N-1))
            return

        if self.emb_dim == 3:
            xcoords = None if self._in_embedding(x) else x # coordinates, when known
            Fx = self._embed(x)

            ax = self._axes3d()
            if prevx is None or last:
                ax.scatter(Fx[0],Fx[1],Fx[2],color=color,s=s)
            if prevx is not None:
                xx = np.stack((self._embed(prevx),Fx))
                ax.plot(xx[:,0],xx[:,1],xx[:,2],linewidth=linewidth,color=color)

            if u is not None:
                Fu = self._push_forward(xcoords,u)
                ax.quiver(Fx[0], Fx[1], Fx[2], Fu[0], Fu[1], Fu[2],
                          pivot='tail',
                          arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                          color='black')

            if v is not None and i0 in N_vec:
                Fv = self._push_forward(xcoords,v)
                ax.quiver(Fx[0], Fx[1], Fx[2], Fv[0], Fv[1], Fv[2],
                          pivot='tail',
                          arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                          color='black')
        elif self.dim == 2:
            xc = x[0] if isinstance(x,tuple) else x
            if prevx is None or last:
                plt.scatter(xc[0],xc[1],color=color,s=s)
            if prevx is not None:
                pc = prevx[0] if isinstance(prevx,tuple) else prevx
                xx = np.stack((pc,xc))
                plt.plot(xx[:,0],xx[:,1],linewidth=linewidth,color=color)
            if v is not None and i0 in N_vec:
                plt.quiver(xc[0], xc[1], v[0], v[1], pivot='tail', linewidth=linewidth, color='black',
                           angles='xy', scale_units='xy', scale=1)
135 |
136 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/manifold.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | import matplotlib.pyplot as plt
24 |
class Manifold(object):
    """ Base manifold class

    Points are represented as (coords,chart) tuples. Inheriting classes provide
    the chart machinery (chart, centered_chart, update_coords/vector/covector).
    """

    def __init__(self):
        self.dim = None
        if not hasattr(self, 'do_chart_update'):
            self.do_chart_update = None # subclasses set this to a predicate when chart updates are desired

    def chart(self):
        """ return default or specified coordinate chart. This method will generally be overridden by inheriting classes """
        # default value
        return jnp.zeros(1)

    def centered_chart(self,coords):
        """ return centered coordinate chart. Should be implemented by inheriting classes;
        this default returns a fixed chart.
        Generally wish to stop gradient computations through the chart choice
        """
        return jax.lax.stop_gradient(jnp.zeros(1))

    def coords(self,coords=None,chart=None):
        """ return (coords,chart) representation of a point in the manifold

        coords defaults to the origin of R^dim, chart to self.chart().
        """
        if coords is None:
            coords = jnp.zeros(self.dim)
        if chart is None:
            chart = self.chart()

        return (jnp.array(coords),chart)

    def update_coords(self,coords,new_chart):
        """ change between charts. Raises NotImplementedError in the base class """
        raise NotImplementedError('update_coords must be implemented by inheriting classes')

    def update_vector(self,coords,new_coords,new_chart,v):
        """ change tangent vector between charts. Raises NotImplementedError in the base class """
        raise NotImplementedError('update_vector must be implemented by inheriting classes')

    def update_covector(self,coords,new_coords,new_chart,p):
        """ change cotangent vector between charts. Raises NotImplementedError in the base class """
        raise NotImplementedError('update_covector must be implemented by inheriting classes')

    def newfig(self):
        """ open new plot for manifold (no-op in the base class) """

    def __str__(self):
        return "abstract manifold"
70 |
class EmbeddedManifold(Manifold):
    """ Embedded manifold base class

    F maps a (coords,chart) tuple to the embedding space R^emb_dim; invF maps a
    (point in R^emb_dim, chart) tuple back to coordinates in that chart.
    """

    def update_coords(self,coords,new_chart):
        """ change between charts: return the (coords,chart) tuple of the same point in new_chart """
        return (self.invF((self.F(coords),new_chart)),new_chart)

    def update_vector(self,coords,new_coords,new_chart,v):
        """ change tangent vector between charts: push v forward with JF, pull back with invJF at the new chart """
        return jnp.tensordot(self.invJF((self.F((new_coords,new_chart)),new_chart)),jnp.tensordot(self.JF(coords),v,(1,0)),(1,0))

    def update_covector(self,coords,new_coords,new_chart,p):
        """ change cotangent vector between charts, using the transposed Jacobians """
        return jnp.tensordot(self.JF((new_coords,new_chart)).T,jnp.tensordot(self.invJF((self.F(coords),coords[1])).T,p,(1,0)),(1,0))

    def __init__(self,F=None,dim=None,emb_dim=None,invF=None):
        Manifold.__init__(self)
        self.dim = dim
        self.emb_dim = emb_dim

        # embedding map and its inverse
        if F is not None:
            self.F = F
            self.invF = invF
            self.JF = jacfwdx(self.F)
            self.invJF = jacfwdx(self.invF)

            # metric matrix: pull-back of the Euclidean metric, g = JF^T JF
            self.g = lambda x: jnp.tensordot(self.JF(x),self.JF(x),(0,0))

    def plot_path(self, xs, vs=None, v_steps=None, i0=0, color='b',
                  color_intensity=1., linewidth=1., s=15., prevx=None, prevchart=None, last=True):
        """ plot a sequence of points, connecting consecutive points with lines

        xs: iterable of points ((coords,chart) tuples or points in R^emb_dim). Coordinate arrays
            longer than dim (and not of length emb_dim) carry attached vectors after the first dim entries.
        vs: optional vectors, one per point, used when no vectors are attached
        v_steps: indices of points at which vectors are drawn; defaults to all steps when vs is given
        i0, prevx, prevchart and last are accepted for interface compatibility but not used.
        """
        if vs is not None and v_steps is None:
            v_steps = np.arange(0,n_steps)

        xs = list(xs)
        N = len(xs)
        prevx = None
        for i,x in enumerate(xs):
            xx = x[0] if isinstance(x,tuple) else x
            if xx.shape[0] > self.dim and (self.emb_dim is None or xx.shape[0] != self.emb_dim): # attached vectors to display
                v = xx[self.dim:].reshape((self.dim,-1))
                x = (xx[0:self.dim],x[1]) if isinstance(x,tuple) else xx[0:self.dim]
            elif vs is not None:
                v = vs[i]
            else:
                v = None
            self.plotx(x, v=v,
                       v_steps=v_steps,i=i,
                       color=color,
                       color_intensity=color_intensity if i==0 or i==N-1 else .7,
                       linewidth=linewidth,
                       s=s,
                       prevx=prevx,
                       last=i==(N-1))
            prevx = x
        return

    # plot x. x can be either in coordinates or in R^3
    def plotx(self, x, u=None, v=None, v_steps=None, i=0, color='b',
              color_intensity=1., linewidth=1., s=15., prevx=None, prevchart=None, last=True):
        """ plot a single point in R^3, optionally with a segment from prevx and tangent vectors

        x: (coords,chart) tuple or point in R^emb_dim
        u: optional tangent vector (coordinates) drawn at x
        v: optional vector drawn at x when i is in v_steps; coordinate vectors are pushed forward with JF
        Raises ValueError if x is neither a tuple nor a point in R^emb_dim.
        """
        if not (isinstance(x,tuple) or x.shape[0] == self.emb_dim):
            raise ValueError('x must be a (coords,chart) tuple or a point in R^%s' % (self.emb_dim,))

        if v is not None and v_steps is None:
            v_steps = np.arange(0,n_steps)

        if isinstance(x,tuple): # map to manifold
            Fx = self.F(x)
        else: # get coordinates in the chart centered at x
            Fx = x
            chart = self.centered_chart(Fx)
            x = (self.invF((Fx,chart)),chart)

        if prevx is not None:
            Fprevx = self.F(prevx) if isinstance(prevx,tuple) else prevx

        ax = plt.gca()
        if prevx is None or last:
            ax.scatter(Fx[0],Fx[1],Fx[2],color=color,s=s)
        if prevx is not None:
            xx = np.stack((Fprevx,Fx))
            ax.plot(xx[:,0],xx[:,1],xx[:,2],linewidth=linewidth,color=color)

        if u is not None:
            Fu = np.dot(self.JF(x), u)
            ax.quiver(Fx[0], Fx[1], Fx[2], Fu[0], Fu[1], Fu[2],
                      pivot='tail',
                      arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                      color='black')

        if v is not None and i in v_steps:
            if not v.shape[0] == self.emb_dim:
                v = np.dot(self.JF(x), v)
            ax.quiver(Fx[0], Fx[1], Fx[2], v[0], v[1], v[2],
                      pivot='tail',
                      arrow_length_ratio = 0.15, linewidths=linewidth, length=0.5,
                      color='black')

    def __str__(self):
        return "dim %d manifold embedded in R^%d" % (self.dim,self.emb_dim)
180 |
--------------------------------------------------------------------------------
/jaxgeometry/manifolds/torus.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.params import *
23 |
24 | from jaxgeometry.manifolds.manifold import *
25 |
26 | from jaxgeometry.plotting import *
27 | import matplotlib.pyplot as plt
28 | from mpl_toolkits.mplot3d import Axes3D
29 | from matplotlib import cm
30 | import matplotlib.ticker as ticker
31 |
class Torus(EmbeddedManifold):
    """ 2d torus, embedded metric

    Coordinates are angles (phi,theta) of the large and small circle; a chart
    is a pair of angle offsets added to the coordinates.
    """

    def chart(self):
        """ return default coordinate chart (zero angle offsets) """
        return jnp.zeros(self.dim)

    def centered_chart(self,x):
        """ return the chart centered at x: the angles of x in the default chart

        x: either a (coords,chart) tuple or a point already in R^3.
        """
        if isinstance(x,tuple): # coordinate tuple
            Fx = jax.lax.stop_gradient(self.F(x))
        else:
            Fx = x # already in embedding space
        return self.invF((Fx,self.chart())) # chart centered at coords

    def get_B(self,v):
        """ R^3 basis (as matrix columns) with first basis vector v """
        b1 = v
        k = jnp.argmin(jnp.abs(v))
        ek = jnp.eye(3)[:,k]
        b2 = ek-v[k]*v
        b3 = cross(b1,b2)
        return jnp.stack((b1,b2,b3),axis=1)

    # Logarithm with standard Riemannian metric
    def StdLog(self,_x,y):
        """ coordinate difference y-x in the chart centered at _x, mapped back to _x's chart

        _x: (coords,chart) base point; y: target point in R^3
        """
        (x,chart) = self.update_coords(_x,self.centered_chart(self.F(_x)))
        y = self.invF((y,chart))
        return self.update_vector((x,chart),_x[0],_x[1],y-x)

    def __init__(self,params=(1.,2.,jnp.array([0.,1.,0.]))):
        self.radius = params[0] # radius of the small (tube) circle
        self.Radius = params[1] # radius of the large circle
        self.orientation = jnp.array(params[2]) # symmetry axis of the torus

        # embedding: first component along the orientation axis, the other two in its orthogonal plane
        F = lambda x: jnp.dot(self.get_B(self.orientation),
                              jnp.stack([self.radius*jnp.sin(x[0][1]+x[1][1]),
                                         (self.Radius+self.radius*jnp.cos(x[0][1]+x[1][1]))*jnp.cos(x[0][0]+x[1][0]),
                                         (self.Radius+self.radius*jnp.cos(x[0][1]+x[1][1]))*jnp.sin(x[0][0]+x[1][0])]))
        def invF(x):
            Rinvx = jnp.linalg.solve(self.get_B(self.orientation),x[0])
            # undo the chart offset of the large-circle angle
            rotangle0 = -x[1][0]
            rot0 = jnp.dot(jnp.stack(
                (jnp.stack((jnp.cos(rotangle0),-jnp.sin(rotangle0))),
                 jnp.stack((jnp.sin(rotangle0),jnp.cos(rotangle0))))),
                Rinvx[1:])
            phi = jnp.arctan2(rot0[1],rot0[0])
            # rot0 = (Radius+radius*cos(theta))*(cos(phi),sin(phi)); recover radius*cos(theta) by dividing
            # by whichever of cos(phi), sin(phi) is larger in magnitude (>= 1/sqrt(2)). A fixed threshold on
            # cos(phi) alone picked sin(phi) ~ 0 near phi = +-pi and returned -Radius there.
            rcostheta = jax.lax.cond(jnp.abs(jnp.cos(phi)) >= jnp.abs(jnp.sin(phi)),
                                     lambda _: rot0[0]/jnp.cos(phi)-self.Radius,
                                     lambda _: rot0[1]/jnp.sin(phi)-self.Radius,operand=None)
            # undo the chart offset of the small-circle angle
            rotangle1 = -x[1][1]
            rot1 = jnp.dot(jnp.stack(
                (jnp.stack((jnp.cos(rotangle1),-jnp.sin(rotangle1))),
                 jnp.stack((jnp.sin(rotangle1),jnp.cos(rotangle1))))),
                jnp.stack((rcostheta,Rinvx[0])))
            theta = jnp.arctan2(rot1[1],rot1[0])
            return jnp.stack([phi,theta])
        self.do_chart_update = lambda x: jnp.max(jnp.abs(x[0])) >= np.pi/4 # look for a new chart if true

        EmbeddedManifold.__init__(self,F,2,3,invF=invF)

    def __str__(self):
        return "torus in R^3, radius %s, Radius %s, axis %s" % (self.radius,self.Radius,self.orientation)

    def newfig(self):
        newfig3d()

    def _sample_surface(self,nu,nv):
        """ evaluate F on an nu x nv grid of default-chart angles in [-pi,pi]^2; returns x,y,z arrays """
        u, v = np.mgrid[-np.pi:np.pi:nu*1j, -np.pi:np.pi:nv*1j]
        x = np.zeros(u.shape)
        y = np.zeros(u.shape)
        z = np.zeros(u.shape)
        for i in range(u.shape[0]):
            for j in range(u.shape[1]):
                w = self.F(self.coords(jnp.array([u[i,j],v[i,j]])))
                x[i,j] = w[0]; y[i,j] = w[1]; z[i,j] = w[2]
        return x, y, z

    def plot(self, rotate=None,alpha=None,lw=0.3):
        """ draw the torus as a wireframe on the current axes

        rotate: optional (elevation,azimuth) passed to view_init
        alpha: if given, also draw a surface with this opacity
        lw: axis grid line width
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")
        if rotate is not None:
            ax.view_init(rotate[0],rotate[1])
        plt.xlabel('x')
        plt.ylabel('y')

        #draw torus
        x, y, z = self._sample_surface(20,10)
        ax.plot_wireframe(x, y, z, color='gray', alpha=0.5)

        if alpha is not None:
            ax.plot_surface(x, y, z, color=cm.jet(0.), alpha=alpha)

    def plot_field(self, field,lw=.3):
        """ plot a vector field on the torus

        field: callable taking a (coords,chart) tuple and returning a tangent vector in those coordinates
        lw: axis grid line width
        """
        ax = plt.gca()
        ax.xaxis._axinfo["grid"]['linewidth'] = lw
        ax.yaxis._axinfo["grid"]['linewidth'] = lw
        ax.zaxis._axinfo["grid"]['linewidth'] = lw
        ax.set_xlim(-1.,1.)
        ax.set_ylim(-1.,1.)
        ax.set_zlim(-1.,1.)
        #ax.set_aspect("equal")

        plt.xlabel('x')
        plt.ylabel('y')

        x, y, z = self._sample_surface(40,20)
        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                Fx = np.array([x[i,j],y[i,j],z[i,j]])
                # evaluate the field in the chart centered at the sample point
                chart = self.centered_chart(Fx)
                xcoord = self.invF((Fx,chart))
                self.plotx((xcoord,chart),v=field((xcoord,chart)))
170 |
171 |
172 |
173 |
--------------------------------------------------------------------------------
/jaxgeometry/params.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 |
23 | ##########################################################################
24 | # this file contains various object definitions, and standard parameters #
25 | ##########################################################################
26 |
# Default time horizon of integrations and the number of steps it is divided into.
T = 1.0
n_steps = 100

# Integration scheme used when none is requested explicitly
# (the commented-out alternative in earlier versions was 'rk4').
default_method = 'euler'
34 |
35 |
36 |
--------------------------------------------------------------------------------
/jaxgeometry/sR/__init__.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 |
--------------------------------------------------------------------------------
/jaxgeometry/sR/metric.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
def initialize(M):
    """ add sub-Riemannian structure to manifold M

    Currently assumes a distribution M.D and that the ambient Riemannian manifold
    is Euclidean, so the embedding is trivial (F(x) = coordinates of x).
    Sets M.sR_dim, M.a (unless present), M.F, M.invF, M.JF, M.invJF, M.sharp,
    M.H (unless present) and M.div.

    Raises ValueError if M has no distribution D.
    """

    if not hasattr(M, 'D'):
        raise ValueError('no distribution defined on manifold')

    # rank of the distribution: number of columns of D at the origin
    M.sR_dim = M.D(M.coords(jnp.zeros(M.dim))).shape[1]

    # sub-Riemannian cometric; mmT(D) presumably computes D D^T -- TODO confirm in utils
    if not hasattr(M,'a'):
        M.a = lambda x: mmT(M.D(x))
    else:
        print('using existing M.a')

    ### trivial embedding
    M.F = lambda x: x[0]
    # invF maps a (point,chart) tuple to coordinates (EmbeddedManifold convention);
    # returning a tuple here made update_coords produce nested tuples
    M.invF = lambda x: x[0]
    M.JF = jacfwdx(M.F)
    M.invJF = jacfwdx(M.invF)

    ##### sharp map: p -> a(x) p
    M.sharp = lambda x,p: jnp.tensordot(M.a(x),p,(1,0))

    ##### Hamiltonian H(x,p) = 1/2 |p^T D(x)|^2; with a = D D^T, dH/dp = a p matches M.sharp
    ##### (the previous factor 5 was a typo for .5)
    if not hasattr(M,'H'):
        M.H = lambda x,p: .5*jnp.sum(jnp.einsum('i,ij->j',p,M.D(x))**2)
    else:
        print('using existing M.H')

    ##### divergence in divergence free orthonormal distribution
    M.div = lambda x,X: jnp.einsum('ij,ji->',jacfwdx(X)(x)[:M.sR_dim,:],M.D(x))
57 |
--------------------------------------------------------------------------------
/jaxgeometry/setup.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 |
21 | import warnings
22 | warnings.simplefilter(action='ignore', category=FutureWarning)
23 |
24 | import numpy as np
25 | import scipy
26 |
27 | import jax
28 | import jax.numpy as jnp
29 | from jax import lax
30 | from jax import grad, jacfwd, jacrev, jit, vmap
31 | from jax import random
32 | from jax.scipy import optimize
33 |
34 | from functools import partial
35 |
36 | from jaxgeometry.utils import *
37 |
38 | import time
39 |
40 | from jaxgeometry.params import *
41 |
42 | import itertools
43 | from functools import partial
44 |
45 |
--------------------------------------------------------------------------------
/jaxgeometry/statistics/Frechet_mean.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see .
18 | #
19 |
20 | from jaxgeometry.setup import *
21 |
22 | from scipy.optimize import minimize
23 | from jax.example_libraries import optimizers
24 |
def initialize(M, Exp=None):
    """Attach Frechet-mean estimation routines to the manifold object ``M``.

    Sets ``M.Frechet_mean_f``, ``M.Frechet_mean_vgv_c``,
    ``M.Frechet_mean_jacxv_c``, ``M.Frechet_mean_jacxv_f``,
    ``M.Frechet_mean_vgx_f`` and ``M.Frechet_mean``.

    Args:
        M: manifold object; uses ``M.g``, ``M.dim``, ``M.update_coords`` and,
            when ``Exp`` is not given, ``M.Exp``.
        Exp: exponential map called as ``Exp((x, chart), v)`` and returning a
            ``(coords, chart)`` pair. Defaults to ``M.Exp``. If it is not given
            and ``M`` has no ``Exp`` attribute, nothing is attached.
    """
    try:
        if Exp is None:
            Exp = M.Exp
    except AttributeError:
        return

    # objective: squared metric norm of the tangent vector v at (x, chart)
    def f(chart, x, v):
        return jnp.dot(v, jnp.dot(M.g((x, chart)), v))
    M.Frechet_mean_f = f

    # constraint residual Exp_x(v) - y, with y expressed in the chart returned
    # by Exp. Uses the Exp resolved above so a caller-supplied map is honoured.
    def _c(chart, x, v, y, ychart):
        xT, chartT = Exp((x, chart), v)
        y_chartT = M.update_coords((y, ychart), chartT)
        return jnp.sqrt(M.dim) * (xT - y_chartT[0])

    def c(chart, x, v, y, ychart):
        return jnp.sum(jnp.square(_c(chart, x, v, y, ychart)))

    # derivatives
    M.Frechet_mean_vgv_c = jit(jax.value_and_grad(c, (2,)))
    M.Frechet_mean_jacxv_c = jit(jax.jacrev(_c, (1, 2)))
    M.Frechet_mean_jacxv_f = jit(jax.value_and_grad(f, (1, 2)))

    def vgx_f(chart, x, v, y, ychart):
        """Value of f and its total x-gradient, treating v as an implicit
        function of x through the constraint _c(x, v) = 0 (implicit function
        theorem: dv/dx = -(d_c/dv)^{-1} d_c/dx)."""
        jac_x, jac_v = M.Frechet_mean_jacxv_c(chart, x, v, y, ychart)
        dv_dx = -jnp.linalg.solve(jac_v, jac_x)
        v_f, (g_x, g_v) = M.Frechet_mean_jacxv_f(chart, x, v)
        return v_f, g_x + jnp.dot(g_v, dv_dx)
    M.Frechet_mean_vgx_f = vgx_f

    def _combined_mean(ys, x0, options):
        """Jointly optimise x and one tangent vector v_i per sample (no Log map).

        Returns ``(m, value, steps, vs)``.
        """
        N = len(ys)
        chart = x0[1]  # single chart for now, could be updated
        step_sizex = options.get('step_sizex', 1e-1)
        step_sizevs = options.get('step_sizevs', 1e-1)
        num_steps = options.get('num_steps', 200)
        optx_update_mod = options.get('optx_update_mod', 5)

        opt_initx, opt_updatex, get_paramsx = optimizers.adam(step_sizex)
        opt_initvs, opt_updatevs, get_paramsvs = optimizers.adam(step_sizevs)

        def step(i, params, ys, y_charts, opt_statex, opt_statevs):
            paramsx = get_paramsx(opt_statex)
            paramsvs = get_paramsvs(opt_statevs)

            # inner problem: move each v_i so that Exp_x(v_i) approaches y_i
            valuevs, gradvs = jax.vmap(M.Frechet_mean_vgv_c, (None, None, 0, 0, 0))(
                params, paramsx, paramsvs, ys, y_charts)
            # gradvs is the 1-tuple produced by argnums=(2,); index it rather
            # than squeeze() so the (N, dim) shape survives N == 1 or dim == 1
            opt_statevs = opt_updatevs(i, gradvs[0], opt_statevs)

            valuex = None
            if i % optx_update_mod == 0:
                # outer problem: move x along the implicit-function gradient
                vx, gx = jax.vmap(M.Frechet_mean_vgx_f, (None, None, 0, 0, 0))(
                    params, paramsx, paramsvs, ys, y_charts)
                valuex = jnp.mean(vx, 0)
                opt_statex = opt_updatex(i, jnp.mean(gx, 0), opt_statex)
            return (valuex, valuevs), (opt_statex, opt_statevs)

        params = chart
        opt_statex = opt_initx(x0[0])
        opt_statevs = opt_initvs(jnp.zeros((N, M.dim)))
        ys_coords, y_charts = zip(*ys)
        ys_coords = jnp.array(ys_coords)
        y_charts = jnp.array(y_charts)

        steps = (x0,)  # tracking steps
        valuex = 0
        valuevs = None
        for i in range(num_steps):
            (_valuex, valuevs), (opt_statex, opt_statevs) = step(
                i, params, ys_coords, y_charts, opt_statex, opt_statevs)
            # only steps that updated x report a value; 0 is a valid value
            if _valuex is not None:
                valuex = _valuex
            if i % 10 == 0:
                print("Step {} | T: {:0.6e} | T: {:0.6e}".format(i, valuex, jnp.max(valuevs)))
            if i % optx_update_mod == 0:
                steps += ((get_paramsx(opt_statex), chart),)
        if num_steps > 0:
            print("Step {} | T: {:0.6e} | T: {:0.6e} ".format(i, valuex, jnp.max(valuevs)))

        m = (get_paramsx(opt_statex), params)
        vs = get_paramsvs(opt_statevs)
        return (m, valuex, steps, vs)

    def _Log_mean(ys, x0, Log, options):
        """Minimise the mean squared Log-norm with scipy BFGS.

        Returns ``((x, chart), objective, steps)``.
        """
        N = len(ys)
        chart = x0[1]

        def fopts(x):
            Logs = np.zeros((N, x.shape[0]))
            for i in range(N):
                Logs[i] = Log((x, chart), ys[i])[0]
            res = (1. / N) * np.sum(np.square(Logs))
            # -2 * mean Log is used as the gradient of the mean squared distance
            grad_x = -(2. / N) * np.sum(Logs, 0)
            return (res, grad_x)

        # iterates recorded by the BFGS callback; kept local (not a module
        # global) so repeated or concurrent calls do not share state
        steps = [x0]

        def save_step(xk):
            steps.append((xk, chart))

        res = minimize(fopts, x0[0], method='BFGS', jac=True, options=options, callback=save_step)
        return ((res.x, x0[1]), res.fun, tuple(steps))

    def Frechet_mean(ys, x0, Log=None, options=None):
        """Estimate the Frechet mean of the data ``ys`` starting from ``x0``.

        Args:
            ys: iterable of ``(coords, chart)`` data points.
            x0: initial ``(coords, chart)`` guess; its chart is used throughout.
            Log: optional logarithm map ``Log((x, chart), y)``. If ``None``, x
                and per-sample tangent vectors are optimised jointly with Adam
                (``options`` keys: step_sizex, step_sizevs, num_steps,
                optx_update_mod). Otherwise scipy BFGS is used and ``options``
                is forwarded to ``scipy.optimize.minimize``.
            options: dict of options (see above); defaults to ``{}``.

        Returns:
            ``(m, value, steps, vs)`` without Log, ``(m, value, steps)`` with Log.

        Raises:
            ValueError: if ``ys`` is empty.
        """
        if options is None:
            options = {}
        ys = list(ys)  # make sure ys is subscriptable
        if not ys:
            raise ValueError("Frechet_mean needs at least one data point")
        if Log is None:
            return _combined_mean(ys, x0, options)
        return _Log_mean(ys, x0, Log, options)

    M.Frechet_mean = Frechet_mean
148 |
--------------------------------------------------------------------------------
/jaxgeometry/statistics/diffusion_mean.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | # Delyon/Hu guided process
24 | from jaxgeometry.stochastics.guided_process import *
25 |
26 | from jaxgeometry.statistics.iterative_mle import *
27 |
def initialize(M):
    """Set up diffusion-mean estimation on ``M`` with Delyon/Hu guided bridges.

    Attaches ``M.diffusion_mean(samples, params=..., N=1, num_steps=80)``,
    which runs ``iterative_mle`` on the negative log-likelihood returned by
    ``get_guided`` for coordinate Brownian motion.
    """

    # guide function
    def phi(q, v, s):
        scaled_chol = (1 / s) * jnp.linalg.cholesky(M.g(q)).T
        return jnp.tensordot(scaled_chol, M.StdLog(q, M.F((v, q[1]))).flatten(), (1, 0))

    def A(x, v, w, s):
        return (s ** (-2)) * jnp.dot(v, jnp.dot(M.g(x), w))

    def logdetA(x, s):
        return jnp.linalg.slogdet(s ** (-2) * M.g(x))[1]

    def sqrtCov(x, s):
        return s * jnp.linalg.cholesky(M.gsharp(x))

    # only the batched negative log-likelihood is used below
    _, _, _, _, neg_log_p_Ts = get_guided(
        M, M.sde_Brownian_coords, M.chart_update_Brownian_coords, phi,
        sqrtCov, A, logdetA)

    # optimization setup
    N = 1  # bridge samples per datapoint
    _dts = dts(n_steps=100, T=1.)

    # starting point (origin in coordinates) and the argnums that
    # iterative_mle differentiates neg_log_p_Ts with respect to
    x = M.coords(jnp.zeros(M.dim))
    params_inds = (0, 5)

    def params_update(state, chart):
        """Re-chart the position parameter stored in the Adam state if needed.

        Handles both the OptimizerState form and the packed
        ``(states_flat, tree_def, subtree_defs)`` form (detected by the
        ValueError raised when unpacking the former pattern fails).
        """
        try:
            ((pos, m, v),), *rest = state
            if M.do_chart_update((pos, chart)):
                new_chart = M.centered_chart((pos, chart))
                (pos, chart) = M.update_coords((pos, chart), new_chart)
            return optimizers.OptimizerState(((pos, m, v),), *rest), chart
        except ValueError:  # state is packed
            states_flat, tree_def, subtree_defs = state
            ((pos, m, v), *rest) = states_flat
            if M.do_chart_update((pos, chart)):
                new_chart = M.centered_chart((pos, chart))
                (pos, chart) = M.update_coords((pos, chart), new_chart)
            states_flat = ((pos, m, v), *rest)
            return (states_flat, tree_def, subtree_defs), chart

    # NOTE: default params are drawn once, when initialize runs
    def diffusion_mean(samples,
                       params=(x[0] + .1 * np.random.normal(size=M.dim), jnp.array(.2, dtype="float32")),
                       N=N,
                       num_steps=80):
        return iterative_mle(samples,
                             neg_log_p_Ts,
                             params, params_inds, params_update, x[1], _dts, M,
                             N=N, num_steps=num_steps, step_size=1e-2)

    M.diffusion_mean = diffusion_mean
69 |
70 |
--------------------------------------------------------------------------------
/jaxgeometry/statistics/iterative_mle.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 |
22 | from jax.example_libraries import optimizers
23 |
def iterative_mle(obss, neg_log_p_Ts, params, params_inds, params_update, chart, _dts, M, N=1, step_size=1e-1, num_steps=50):
    """Maximum-likelihood parameter estimation with Adam on a Monte Carlo objective.

    Each iteration draws fresh Brownian increments, evaluates
    ``jax.value_and_grad(neg_log_p_Ts, params_inds)``, takes an Adam step and
    lets ``params_update`` adjust the optimizer state and chart.

    Args:
        obss: observations passed to ``neg_log_p_Ts``; ``len(obss[0])`` sets
            the number of noise paths drawn.
        neg_log_p_Ts: objective called as
            ``neg_log_p_Ts(params[0], chart, obss, dWs, dts, *params[1:])``.
        params: initial parameter tuple.
        params_inds: argnums of ``neg_log_p_Ts`` to differentiate.
        params_update: ``(opt_state, chart) -> (opt_state, chart)`` hook.
        chart: initial chart.
        _dts: time steps.
        M: manifold object (``M.dim`` is used).
        N: samples per observation.
        step_size: Adam step size.
        num_steps: number of iterations.

    Returns:
        ``(params, chart, value, values, paramss)``: final parameters and chart,
        last objective value (``None`` if ``num_steps == 0``), an array of all
        objective values, and a tuple of ``(*params, chart)`` per iteration.
    """
    opt_init, opt_update, get_params = optimizers.adam(step_size)
    vg = jax.value_and_grad(neg_log_p_Ts, params_inds)

    def step(i, opt_state, chart):
        params = get_params(opt_state)
        # fresh noise every iteration, reshaped to (-1, n_steps, N, M.dim)
        noise = dWs(len(obss[0]) * N * M.dim, _dts).reshape(-1, _dts.shape[0], N, M.dim)
        value, grads = vg(params[0], chart, obss, noise, _dts, *params[1:])
        opt_state = opt_update(i, grads, opt_state)
        opt_state, chart = params_update(opt_state, chart)
        return (value, opt_state, chart)

    opt_state = opt_init(params)
    values = []
    paramss = ()
    value = None

    for i in range(num_steps):
        value, opt_state, chart = step(i, opt_state, chart)
        values.append(value)
        paramss += ((*get_params(opt_state), chart),)
        print("Step {} | T: {:0.6e} | T: {}".format(i, value, str((get_params(opt_state), chart))))
    if num_steps > 0:
        print("Final {} | T: {:0.6e} | T: {}".format(i, value, str(get_params(opt_state))))

    return (get_params(opt_state), chart, value, jnp.array(values), paramss)
46 |
--------------------------------------------------------------------------------
/jaxgeometry/statistics/tangent_PCA.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 |
23 | from jaxgeometry.utils import *
24 | from sklearn.decomposition import PCA
25 |
def tangent_PCA(M, Log, mean, ys):
    """Principal component analysis of data mapped to the tangent space at ``mean``.

    Args:
        M: manifold (not used here; kept for interface compatibility).
        Log: logarithm map called as ``Log(mean, (y, chart))``; the first
            element of its result is used as the tangent vector.
        mean: base point passed to ``Log``.
        ys: pair ``(points, charts)``; both are mapped over their leading axis
            with ``jax.vmap``.

    Returns:
        A fitted ``sklearn.decomposition.PCA`` with an extra attribute
        ``transformed_Logs`` holding the tangent vectors projected onto the
        principal components.
    """
    Logs = jax.vmap(lambda y, chart: Log(mean, (y, chart))[0])(ys[0], ys[1])

    pca = PCA()
    pca.fit(Logs)
    pca.transformed_Logs = pca.transform(Logs)
    return pca
35 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/Brownian_coords.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(M):
    """ Brownian motion in coordinates

    Attaches ``M.sde_Brownian_coords``, ``M.chart_update_Brownian_coords`` and
    ``M.Brownian_coords`` (Ito integration) to the manifold object ``M``.
    """

    def sde_Brownian_coords(c, y):
        """SDE coefficients for coordinate Brownian motion.

        ``c = (t, x, chart, s)`` with noise scale ``s``; ``y = (dt, dW)``.
        Returns ``(drift, stochastic increment, diffusion matrix, ds = 0)``.
        """
        t, x, chart, s = c
        dt, dW = y

        gsharpx = M.gsharp((x, chart))
        X = s * jnp.linalg.cholesky(gsharpx)
        # drift -1/2 s^2 g^{kl} Gamma^i_{kl}
        det = -.5 * (s ** 2) * jnp.einsum('kl,ikl->i', gsharpx, M.Gamma_g((x, chart)))
        sto = jnp.tensordot(X, dW, (1, 0))
        return (det, sto, X, 0.)

    def chart_update_Brownian_coords(x, chart, *ys):
        """Switch (x, chart) to a centered chart when ``M.do_chart_update`` says so."""
        if M.do_chart_update is None:
            return (x, chart, *ys)

        # do_chart_update takes a (coords, chart) pair, like its other callers
        update = M.do_chart_update((x, chart))
        new_chart = M.centered_chart((x, chart))
        new_x = M.update_coords((x, chart), new_chart)[0]

        return (jnp.where(update, new_x, x),
                jnp.where(update, new_chart, chart),
                *ys)

    M.sde_Brownian_coords = sde_Brownian_coords
    M.chart_update_Brownian_coords = chart_update_Brownian_coords
    M.Brownian_coords = jit(lambda x, dts, dWs, stdCov=1.: integrate_sde(
        sde_Brownian_coords, integrator_ito, chart_update_Brownian_coords,
        x[0], x[1], dts, dWs, stdCov)[0:3])
55 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/Brownian_development.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
def initialize(M):
    """ Brownian motion from stochastic development

    Attaches ``M.Brownian_development(x, dts, dWs) -> (ts, xs, charts)``.
    """

    def Brownian_development(x, dts, dWs):
        # frame from the Cholesky factor of the cometric at x, appended to the
        # coordinates to form a frame-bundle point
        frame = jnp.linalg.cholesky(M.gsharp(x))
        u = (jnp.concatenate((x[0], frame.flatten())), x[1])

        ts, us, charts = M.stochastic_development(u, dts, dWs)
        # keep only the base-point coordinates of the frame-bundle path
        return (ts, us[:, :M.dim], charts)

    M.Brownian_development = Brownian_development
38 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/Brownian_inv.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(G):
    """ Brownian motion with respect to left/right invariant metric

    Attaches ``G.sde_Brownian_inv`` and ``G.Brownian_inv`` (Stratonovich
    integration). Only left invariance is supported.
    """

    assert G.invariance == 'left'

    def sde_Brownian_inv(c, y):
        _t, g, _chart, sigma = c
        _dt, dW = y

        frame = jnp.tensordot(G.invpf(g, G.eiLA), sigma, (2, 0))
        # contraction of the structure constants: trace_C[i] = sum_k C[i, k, k]
        trace_C = jnp.diagonal(G.C, 0, 2).sum(1)
        drift = -.5 * jnp.tensordot(trace_C, frame, (0, 2))
        noise = jnp.tensordot(frame, dW, (2, 0))
        return (drift, noise, frame, jnp.zeros_like(sigma))

    G.sde_Brownian_inv = sde_Brownian_inv

    def Brownian_inv(g, dts, dWt, sigma=jnp.eye(G.dim)):
        return integrate_sde(G.sde_Brownian_inv, integrator_stratonovich, None,
                             g, None, dts, dWt, sigma)[0:3]

    G.Brownian_inv = Brownian_inv
39 |
40 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/Brownian_process.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(G):
    """ Brownian motion with respect to left/right invariant metric

    Driftless variant: attaches ``G.sde_Brownian_process`` and
    ``G.Brownian_process`` (Stratonovich integration). Only left invariance
    is supported.
    """

    assert G.invariance == 'left'

    def sde_Brownian_process(c, y):
        _t, g, _chart, sigma = c
        _dt, dW = y

        frame = jnp.tensordot(G.invpf(g, G.eiLA), sigma, (2, 0))
        drift = jnp.zeros_like(g)
        noise = jnp.tensordot(frame, dW, (2, 0))
        return (drift, noise, frame, 0.)

    G.sde_Brownian_process = sde_Brownian_process

    def Brownian_process(g, dts, dWt, sigma=jnp.eye(G.dim)):
        return integrate_sde(G.sde_Brownian_process, integrator_stratonovich, None,
                             g, None, dts, dWt, sigma)[0:3]

    G.Brownian_process = Brownian_process
39 |
40 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/Brownian_sR.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 | from jaxgeometry.setup import *
22 | from jaxgeometry.utils import *
23 |
def initialize(M):
    """ sub-Riemannian Brownian motion

    Attaches ``M.sde_Brownian_sR``, ``M.chart_update_Brownian_sR`` and
    ``M.Brownian_sR`` (Ito integration) to ``M``. The drift is currently zero.
    """

    def sde_Brownian_sR(c, y):
        """SDE coefficients; ``c = (t, x, chart, s)``, ``y = (dt, dW)``."""
        t, x, chart, s = c
        dt, dW = y

        D = s * M.D((x, chart))
        # D0 = \sum_{i=1}^m div_\mu(X_i) X_i) - not implemented yet
        det = jnp.zeros_like(x)  # Y^k(x)=X_0^k(x)+(1/2)\sum_{i=1}^m \langle \nabla X_i^k(x),X_i(x)\rangle
        sto = jnp.tensordot(D, dW, (1, 0))
        return (det, sto, D, 0.)

    def chart_update_Brownian_sR(x, chart, *ys):
        """Switch (x, chart) to a centered chart when ``M.do_chart_update`` says so."""
        if M.do_chart_update is None:
            return (x, chart, *ys)

        # do_chart_update takes a (coords, chart) pair, like its other callers
        update = M.do_chart_update((x, chart))
        new_chart = M.centered_chart((x, chart))
        new_x = M.update_coords((x, chart), new_chart)[0]

        return (jnp.where(update, new_x, x),
                jnp.where(update, new_chart, chart),
                *ys)

    M.sde_Brownian_sR = sde_Brownian_sR
    M.chart_update_Brownian_sR = chart_update_Brownian_sR
    M.Brownian_sR = jit(lambda x, dts, dWs, stdCov=1.: integrate_sde(
        sde_Brownian_sR, integrator_ito, chart_update_Brownian_sR,
        x[0], x[1], dts, dWs, stdCov)[0:3])
55 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/Eulerian.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
23 | ###############################################################
24 | # Eulerian / stochastic EPDiff acting on landmarks
25 | ###############################################################
def initialize(M, k=None):
    """Eulerian (stochastic EPDiff) noise acting on landmarks.

    Attaches ``M.Eulerian_qp`` and ``M.Eulerian`` (Stratonovich integration).
    Noise fields are centered at ``sigmas_x`` with amplitudes ``sigmas_a``;
    the kernel ``k`` defaults to ``M.k``.
    """
    dq = jit(grad(M.H, argnums=1))

    def _neg_dH_dq(q, p):
        return -gradx(M.H)(q, p)
    dp = jit(_neg_dH_dq)

    # noise basis: fall back to the landmark kernel
    kernel = M.k if k is None else k

    def k_q(q1, q2):
        # pairwise differences between the landmarks of q1 and q2
        lm1 = q1.reshape((-1, M.m))
        lm2 = q2.reshape((-1, M.m))
        return kernel(lm1[:, np.newaxis, :] - lm2[np.newaxis, :, :])

    def K(q1, q2):
        # kernel values times the m x m identity, laid out as a (M.dim, -1) matrix
        # (assumes kernel reduces the last axis of its argument to a scalar)
        blocks = k_q(q1, q2)[:, :, np.newaxis, np.newaxis] * jnp.eye(M.m)[np.newaxis, np.newaxis, :, :]
        return blocks.transpose((0, 2, 1, 3)).reshape((M.dim, -1))

    def sde_Eulerian(c, y):
        t, x, chart, sigmas_x, sigmas_a = c
        dt, dW = y
        q = (x[0], chart)
        dqt = dq(q, x[1])
        dpt = dp(q, x[1])

        # amplitude of each noise field scaled by its Wiener increment
        weights = (sigmas_a * dW[:, np.newaxis]).flatten()

        def noise_field(lq):
            return jnp.tensordot(K(lq, sigmas_x), weights, (1, 0)).flatten()

        sigmadWq = jnp.tensordot(K(x[0], sigmas_x), weights, (1, 0))
        # momentum noise: position-Jacobian of the noise field contracted with p
        sigmadWp = jnp.tensordot(jax.jacrev(noise_field)(x[0]), x[1], (1, 0))

        X = None  # to be implemented
        det = jnp.stack((dqt, dpt))
        sto = jnp.stack((sigmadWq, sigmadWp))
        return (det, sto, X, jnp.zeros_like(sigmas_x), jnp.zeros_like(sigmas_a))

    def chart_update_Eulerian(xp, chart, *cy):
        if M.do_chart_update is None:
            return (xp, chart, *cy)

        p = xp[1]
        x = (xp[0], chart)

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(M.F(x))
        new_x = M.update_coords(x, new_chart)[0]
        moved = jnp.stack((new_x, M.update_covector(x, new_x, new_chart, p)))

        return (jnp.where(update, moved, xp),
                jnp.where(update, new_chart, chart),
                *cy)

    def Eulerian_qp(q, p, sigmas_x, sigmas_a, dts, dWs):
        return integrate_sde(sde_Eulerian, integrator_stratonovich, chart_update_Eulerian,
                             jnp.stack((q[0], p)), q[1], dts, dWs, sigmas_x, sigmas_a)

    def Eulerian(q, p, sigmas_x, sigmas_a, dts, dWs):
        return M.Eulerian_qp(q, p, sigmas_x, sigmas_a, dts, dWs)[0:3]

    M.Eulerian_qp = Eulerian_qp
    M.Eulerian = Eulerian
77 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/Langevin.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
23 | ###############################################################
24 | # Langevin equations https://arxiv.org/abs/1605.09276
25 | ###############################################################
def initialize(M):
    """Langevin equations (https://arxiv.org/abs/1605.09276).

    Attaches ``M.Langevin_qp`` and ``M.Langevin`` (Ito integration) with
    damping ``l`` and noise scale ``s`` acting on the momentum.
    """
    dq = jit(grad(M.H, argnums=1))

    def _neg_dH_dq(q, p):
        return -gradx(M.H)(q, p)
    dp = jit(_neg_dH_dq)

    def sde_Langevin(c, y):
        t, x, chart, l, s = c
        dt, dW = y
        q = (x[0], chart)
        p = x[1]

        dH_dp = dq(q, p)
        # dq = dH/dp dt ;  dp = (-dH/dq - l dH/dp) dt + s dW
        det = jnp.stack((dH_dp, dp(q, p) - l * dH_dp))
        X = jnp.stack((jnp.zeros((M.dim, M.dim)), s * jnp.eye(M.dim)))
        sto = jnp.tensordot(X, dW, (1, 0))
        return (det, sto, X, jnp.zeros_like(l), jnp.zeros_like(s))

    def chart_update_Langevin(xp, chart, *cy):
        if M.do_chart_update is None:
            return (xp, chart, *cy)

        p = xp[1]
        x = (xp[0], chart)

        update = M.do_chart_update(x)
        new_chart = M.centered_chart(M.F(x))
        new_x = M.update_coords(x, new_chart)[0]
        moved = jnp.stack((new_x, M.update_covector(x, new_x, new_chart, p)))

        return (jnp.where(update, moved, xp),
                jnp.where(update, new_chart, chart),
                *cy)

    def Langevin_qp(q, p, l, s, dts, dWt):
        return integrate_sde(sde_Langevin, integrator_ito, chart_update_Langevin,
                             jnp.stack((q[0], p)), q[1], dts, dWt, l, s)

    def Langevin(q, p, l, s, dts, dWt):
        return M.Langevin_qp(q, p, l, s, dts, dWt)[0:3]

    M.Langevin_qp = Langevin_qp
    M.Langevin = Langevin
63 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/__init__.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 |
21 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/diagonal_conditioning.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
23 | def initialize(M,sde_product,chart_update_product,integrator=integrator_ito,T=1):
24 | """ diagonally conditioned product diffusions """
25 |
26 | def sde_diagonal(c,y):
27 | if M.do_chart_update is None:
28 | t,x,chart,T,*cy = c
29 | else:
30 | t,x,chart,T,ref_chart,*cy = c
31 | dt,dW = y
32 |
33 | (det,sto,X,*dcy) = sde_product((t,x,chart,*cy),y)
34 |
35 | if M.do_chart_update is None:
36 | xref = x
37 | else:
38 | xref = jax.vmap(lambda x,chart: M.update_coords((x,chart),ref_chart)[0],0)(x,chart)
39 | m = jnp.mean(xref,0) # mean
40 | href = jax.lax.cond(t.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
23 | #######################################################################
24 | # guided processes, Delyon/Hu 2006 #
25 | #######################################################################
26 |
27 | # hit target v at time t=Tend
28 | def get_guided(M,sde,chart_update,phi,sqrtCov=None,A=None,logdetA=None,method='DelyonHu',integration='ito'):
29 | """ guided diffusions """
30 |
31 | def sde_guided(c,y):
32 | t,x,chart,log_likelihood,log_varphi,T,v,*cy = c
33 | xchart = (x,chart)
34 | dt,dW = y
35 |
36 | (det,sto,X,*dcy) = sde((t,x,chart,*cy),y)
37 |
38 | h = jax.lax.cond(t.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(M, sde, chart_update, integrator=integrator_ito):
    """ product diffusions

    Runs ``sde`` independently on each member of a batch (leading axis) via
    ``jax.vmap``. Returns ``(product, sde_product, chart_update_product)``.
    """

    def sde_product(c, y):
        t, x, chart, *cy = c
        dt, dW = y

        def one_path(x_i, chart_i, dW_i, *cy_i):
            # batch members share t and dt but have their own state and noise
            return sde((t, x_i, chart_i, *cy_i), (dt, dW_i))

        det, sto, X, *dcy = jax.vmap(one_path, 0)(x, chart, dW, *cy)
        return (det, sto, X, *dcy)

    chart_update_product = jax.vmap(chart_update)

    def _product(x, dts, dWs, *cy):
        return integrate_sde(sde_product, integrator, chart_update_product,
                             x[0], x[1], dts, dWs, *cy)

    product = jit(_product)

    return (product, sde_product, chart_update_product)
39 |
40 | # for initializing parameters
def tile(x, N):
    """Repeat ``x`` N times along a new leading axis.

    Arrays (anything with ``ndim``) get a new leading axis of length N; other
    values are passed to ``jnp.tile(x, N)``; if that raises TypeError, ``x``
    is treated as a sequence and each element is tiled recursively into a tuple.
    """
    try:
        ndim = x.ndim
        return jnp.tile(x, (N,) + (1,) * ndim)
    except AttributeError:
        pass
    try:
        return jnp.tile(x, N)
    except TypeError:
        return tuple(map(lambda item: tile(item, N), x))
49 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/stochastic_coadjoint.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <http://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(G, Psi=None, r=None):
    """ stochastic coadjoint motion with left/right invariant metric
    see Noise and dissipation on coadjoint orbits arXiv:1601.02249 [math.DS]
    and EulerPoincare.py

    Attaches ``G.sde_stochastic_coadjoint``, ``G.stochastic_coadjoint``
    (Stratonovich integration) and ``G.stochastic_coadjointrec``.
    Only left invariance is supported.
    """

    assert G.invariance == 'left'

    # Matrix function Psi: LA -> R^r. Default (example from arXiv:1601.02249)
    # multiplies by the identity, i.e. Psi(mu) = mu and r = G.dim.
    if Psi is None:
        sigmaPsi = jnp.eye(G.dim)

        def Psi(mu):
            return jnp.dot(sigmaPsi, mu)
        r = G.dim
    assert Psi is not None and r is not None

    def sde_stochastic_coadjoint(c, y):
        _t, mu, _chart = c
        _dt, dW = y

        xi = G.invFl(mu)
        drift = -G.coad(xi, mu)
        Sigma = G.coad(mu, jax.jacrev(Psi)(mu).transpose((1, 0)))
        noise = jnp.tensordot(Sigma, dW, (1, 0))
        return (drift, noise, Sigma)

    G.sde_stochastic_coadjoint = sde_stochastic_coadjoint

    def stochastic_coadjoint(mu, dts, dWt):
        return integrate_sde(G.sde_stochastic_coadjoint, integrator_stratonovich, None,
                             mu, None, dts, dWt)

    G.stochastic_coadjoint = stochastic_coadjoint

    # reconstruction as in Euler-Poincare / Lie-Poisson reconstruction
    if not hasattr(G, 'EPrec'):
        from jaxgeometry.group import EulerPoincare
        EulerPoincare.initialize(G)
    G.stochastic_coadjointrec = G.EPrec
56 |
57 |
--------------------------------------------------------------------------------
/jaxgeometry/stochastics/stochastic_development.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <https://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.utils import *
22 |
def initialize(M):
    """Add development and stochastic development from R^d to M.

    A frame-bundle point is u = (coords, chart). The first M.dim entries of
    coords are the base point; the remaining entries are reshaped to
    (M.dim, m), and m (the number of frame vectors) selects the first m
    columns of M.Horizontal(u) as the driving vector fields.

    Sets on M:
        development(u, dgamma, dts): jitted ODE integration of
            d u = Horizontal(u)[:, :m] dgamma, with chart updates via
            M.chart_update_FM.
        sde_development(c, y): SDE right-hand side returning
            (zeros, sto, Sigma) with Sigma = Horizontal(u)[:, :m] and
            sto = Sigma . dW.
        stochastic_development(u, dts, dWs): jitted Stratonovich integration
            of sde_development with chart updates via M.chart_update_FM.
    """

    def horizontal_fields(c):
        # first m columns of M.Horizontal at u = (coords, chart); the reshape
        # also checks that the frame part has M.dim rows
        _t, coords, chart = c
        u = (coords, chart)
        m = coords[M.dim:].reshape((M.dim,-1)).shape[1]
        return M.Horizontal(u)[:,0:m]

    # Deterministic development
    def ode_development(c,y):
        dgamma, = y
        return jnp.tensordot(horizontal_fields(c),dgamma,(1,0))

    M.development = jit(lambda u,dgamma,dts: integrate(ode_development,M.chart_update_FM,u[0],u[1],dts,dgamma))

    # Stochastic development
    def sde_development(c,y):
        _dt,dW = y
        # evaluate the horizontal fields once and reuse them as Sigma
        H = horizontal_fields(c)
        sto = jnp.tensordot(H,dW,(1,0))
        return (jnp.zeros_like(sto), sto, H)

    M.sde_development = sde_development
    M.stochastic_development = jit(lambda u,dts,dWs: integrate_sde(sde_development,integrator_stratonovich,M.chart_update_FM,u[0],u[1],dts,dWs))
56 |
--------------------------------------------------------------------------------
/jaxgeometry/utils.py:
--------------------------------------------------------------------------------
1 | ## This file is part of Jax Geometry
2 | #
3 | # Copyright (C) 2021, Stefan Sommer (sommer@di.ku.dk)
4 | # https://bitbucket.org/stefansommer/jaxgeometry
5 | #
6 | # Jax Geometry is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # Jax Geometry is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Jax Geometry. If not, see <https://www.gnu.org/licenses/>.
18 | #
19 |
20 | from jaxgeometry.setup import *
21 | from jaxgeometry.params import *
22 |
23 | #######################################################################
24 | # various useful functions #
25 | #######################################################################
26 |
# Derivative helpers for functions f((x, chart), *args, **kwargs): they
# differentiate with respect to the coordinates x only, never the chart.

def _coords_only(transform, f):
    """Apply a jax transform (grad/jacfwd/jacrev) to f in the coordinates only."""
    def f_split(coords, chart, *args, **kwargs):
        return f((coords, chart), *args, **kwargs)

    def transformed(x, *args, **kwargs):
        coords, chart = x[0], x[1]
        return transform(f_split, argnums=0)(coords, chart, *args, **kwargs)

    return transformed

# jax.grad but only for the x variable of a function taking coordinates and chart
def gradx(f):
    """Return x -> gradient of f with respect to x[0]."""
    return _coords_only(jax.grad, f)

# jax.jacfwd but only for the x variable of a function taking coordinates and chart
def jacfwdx(f):
    """Return x -> forward-mode Jacobian of f with respect to x[0]."""
    return _coords_only(jax.jacfwd, f)

# jax.jacrev but only for the x variable of a function taking coordinates and chart
def jacrevx(f):
    """Return x -> reverse-mode Jacobian of f with respect to x[0]."""
    return _coords_only(jax.jacrev, f)

# hessian only for the x variable of a function taking coordinates and chart
def hessianx(f):
    """Return x -> Hessian of f with respect to x[0] (forward-over-reverse)."""
    return jacfwdx(jacrevx(f))
54 |
# evaluation with pass through derivatives
def straight_through(f,x,*ys):
    """Evaluate f(x, *ys) while passing derivatives straight through to x.

    The returned value equals f(x, *ys), but its derivative with respect to x
    is the identity: the derivative of f itself is discarded with
    jax.lax.stop_gradient.

    For a tuple x, f must return a sequence with one entry per entry of x.
    Output entry i then carries the value f(x, *ys)[i] and the identity
    derivative with respect to x[i].

    Raises:
        ValueError: if x is a tuple and f returns a different number of entries.
    """
    # x - stop_gradient(x) is exactly zero (Sterbenz lemma) but has an
    # exactly-one derivative with respect to x; adding the gradient-stopped
    # value of f gives f's value with a pass-through derivative.
    if isinstance(x, tuple):
        zeros = tuple(xi - jax.lax.stop_gradient(xi) for xi in x)
        fx = jax.lax.stop_gradient(f(x,*ys))
        if len(fx) != len(zeros):
            raise ValueError(
                "straight_through: f returned %d entries for a tuple input of length %d"
                % (len(fx), len(zeros)))
        return tuple(zi + fxi for zi, fxi in zip(zeros, fx))
    zero = x - jax.lax.stop_gradient(x)
    return zero + jax.lax.stop_gradient(f(x,*ys))
66 |
67 |
# time increments, deterministic: n_steps equal steps of size T/n_steps
def dts(T=T,n_steps=n_steps):
    """Return an array of n_steps entries, each equal to T/n_steps.

    The defaults are the module-level T and n_steps brought in by the star
    imports at the top of this file.
    """
    increment = T/n_steps
    return jnp.array(n_steps*[increment])
# standard noise realisations: Brownian increments, stochastic
seed = 42
key = jax.random.PRNGKey(seed)
def dWs(d,_dts=None,num=1):
    """Draw Brownian increments in R^d for the given time steps.

    Args:
        d: dimension of each increment.
        _dts: time increments (any array-like); defaults to dts().
        num: number of independent increment sequences.

    Returns:
        An array of shape (len(_dts), d) when num == 1, otherwise
        (num, len(_dts), d). Step i is scaled by sqrt(_dts[i]), so it has
        variance _dts[i] per component.

    Each call splits the module-level PRNG ``key`` and stores the first part
    back into it, so successive calls draw fresh noise.
    """
    global key
    keys = jax.random.split(key,num=num+1)
    key = keys[0]
    subkeys = keys[1:]
    # `is None` (not `== None`): comparing a numpy array to None is elementwise
    if _dts is None:
        _dts = dts()
    _dts = jnp.asarray(_dts)
    scale = jnp.sqrt(_dts)[:,None]

    def sample(subkey):
        return scale*jax.random.normal(subkey,(_dts.shape[0],d))

    if num == 1:
        return sample(subkeys[0])
    return jax.vmap(sample)(subkeys)
88 |
# Integrator (deterministic)
def integrator(ode_f,chart_update=None,method=default_method):
    """Build one step of a fixed-step ODE integrator for use with lax.scan.

    Args:
        ode_f: ode_f(c, ys) returns dx/dt for the carry c = (t, x, chart);
            ys are the per-step inputs that follow dt.
        chart_update: chart_update(x, chart, ys) returns the new (x, chart).
            When None, the chart is left unchanged.
        method: 'euler' or 'rk4'.

    Returns:
        step(c, y) mapping the carry and y = (dt, *ys) to (new_c, new_c) with
        new_c = (t+dt, x_new, chart_new).

    Raises:
        ValueError: if method is not 'euler' or 'rk4'.
    """
    if chart_update is None: # no chart update
        chart_update = lambda *args: args[0:2]

    # euler:
    def euler(c,y):
        t,x,chart = c
        dt,*_ = y
        ys = y[1:]
        new_c = (t+dt, *chart_update(x + dt*ode_f(c,ys), chart, ys))
        return (new_c, new_c)

    # Runge-Kutta (classical 4th order):
    def rk4(c,y):
        t,x,chart = c
        dt,*_ = y
        ys = y[1:]
        k1 = ode_f(c,ys)
        k2 = ode_f((t+dt/2,x + dt/2*k1,chart),ys)
        k3 = ode_f((t+dt/2,x + dt/2*k2,chart),ys)
        # the final stage is evaluated at the end of the step, t+dt
        k4 = ode_f((t+dt,x + dt*k3,chart),ys)
        new_c = (t+dt, *chart_update(x + dt/6*(k1 + 2*k2 + 2*k3 + k4), chart, ys))
        return (new_c, new_c)

    if method == 'euler':
        return euler
    if method == 'rk4':
        return rk4
    raise ValueError("unknown integration method %r (expected 'euler' or 'rk4')" % (method,))
116 |
# Integrate an ODE over the steps dts and return the stacked path.
def integrate(ode,chart_update,x,chart,dts,*ys):
    """Run integrator(ode, chart_update) under lax.scan from the carry (0., x, chart).

    dts are the step sizes and ys any further per-step inputs forwarded to
    the ODE. Returns the stacked per-step carries (ts, xs, charts). When
    chart_update is None, only (ts, xs) is returned.
    """
    stepper = integrator(ode, chart_update)
    _, trajectory = lax.scan(stepper, (0., x, chart), (dts, *ys))
    if chart_update is None:
        return trajectory[0:2]
    return trajectory
123 |
# sde functions should return (det,sto,Sigma) where
# det is the deterministic part, sto is the stochastic part,
# and Sigma is the stochastic generator (i.e. often sto=dot(Sigma,dW)).


def integrate_sde(sde,integrator,chart_update,x,chart,dts,dWs,*cy):
    """Run integrator(sde, chart_update) under lax.scan and return the path.

    The scan starts from the carry (0., x, chart, *cy) and consumes one
    (dt, dW) pair per step from dts and dWs. Returns the stacked per-step
    outputs of the integrator step; the final carry is discarded.
    """
    step = integrator(sde, chart_update)
    initial = (0., x, chart, *cy)
    increments = (dts, dWs)
    _final, path = lax.scan(step, initial, increments)
    return path
134 |
def integrator_stratonovich(sde_f,chart_update=None):
    """Build one Euler-Heun step for a Stratonovich SDE, for use with lax.scan.

    Args:
        sde_f: sde_f(c, y) takes the carry c = (t, x, chart, *cy) and the step
            input y = (dt, dW), and returns (det, sto, Sigma, *dcy).
        chart_update: chart_update(x, chart, *cy) returns the new
            (x, chart, *cy). When None, the values are passed through unchanged.

    Returns:
        step(c, y) -> (new_c, new_c), where new_c is
        (t+dt, *chart_update(x_new, chart, *cy_new)),
        x_new = x + dt*det + (sto(x) + sto(x + sto(x)))/2 and
        cy_new = cy + dt*dcy.
    """
    if chart_update is None: # no chart update
        chart_update = lambda xp,chart,*cy: (xp,chart,*cy)

    def euler_heun(c,y):
        t,x,chart,*cy = c
        dt,_dW = y

        detx, stox, _Sigma, *dcy = sde_f(c,y)
        # predictor: advance by the stochastic part only
        tx = x + stox
        # corrector: stochastic part re-evaluated at the predictor (time t+dt)
        stox_pred = sde_f((t+dt,tx,chart,*cy),y)[1]
        xp = x + dt*detx + 0.5*(stox + stox_pred)
        # extra carried values cy are advanced with a plain Euler step
        cy_new = tuple(cyi + dt*dcyi for cyi, dcyi in zip(cy, dcy))
        new_c = (t+dt, *chart_update(xp, chart, *cy_new))
        return (new_c, new_c)

    return euler_heun
149 |
def integrator_ito(sde_f,chart_update=None):
    """Build one Euler(-Maruyama) step for an Ito SDE, for use with lax.scan.

    Args:
        sde_f: sde_f(c, y) takes the carry c = (t, x, chart, *cy) and the step
            input y = (dt, dW), and returns (det, sto, Sigma, *dcy).
        chart_update: chart_update(x, chart, *cy) returns the new
            (x, chart, *cy). When None, the values are passed through unchanged.

    Returns:
        step(c, y) -> (new_c, new_c), where new_c is
        (t+dt, *chart_update(x + dt*det + sto, chart, *(cy + dt*dcy))).
    """
    if chart_update is None: # no chart update
        chart_update = lambda xp,chart,*cy: (xp,chart,*cy)

    def euler(c,y):
        t,x,chart,*cy = c
        dt,_dW = y

        detx, stox, _Sigma, *dcy = sde_f(c,y)
        # extra carried values cy are advanced with a plain Euler step
        cy_new = tuple(cyi + dt*dcyi for cyi, dcyi in zip(cy, dcy))
        new_c = (t+dt, *chart_update(x + dt*detx + stox, chart, *cy_new))
        return (new_c, new_c)

    return euler
163 |
164 |
def cross(a, b):
    """Return the cross product of two 3-vectors as a jnp array.

    Component i is a[i+1]*b[i+2] - a[i+2]*b[i+1], with indices taken mod 3.
    """
    components = [a[(i + 1) % 3]*b[(i + 2) % 3] - a[(i + 2) % 3]*b[(i + 1) % 3]
                  for i in range(3)]
    return jnp.array(components)
170 |
def mmT(A,C=None):
    """Return A A^T, or A C A^T when a middle factor C is given."""
    if C is None:
        return A @ A.T
    return A @ C @ A.T
173 |
def mTm(A,C=None):
    """Return A^T A, or A^T C A when a middle factor C is given."""
    At = A.T
    if C is None:
        return At @ A
    return At @ C @ A
176 |
177 | #import numpy as np
178 | #def python_scan(f, init, xs, length=None):
179 | # if xs is None:
180 | # xs = [None] * length
181 | # carry = init
182 | # ys = []
183 | # for i in range(xs[0].shape[0]):
184 | # x = (xs[0][i],)
185 | # carry, y = f(carry, x)
186 | # ys.append(y)
187 | # return carry, np.stack(ys)
188 | #
189 |
--------------------------------------------------------------------------------
/logo/stocso31s.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ComputationalEvolutionaryMorphometry/jaxgeometry/70e22f05966d5076adab37d9579a4900f486485f/logo/stocso31s.jpg
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test
2 |
3 | test:
4 | tests/test.sh
5 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "jaxdifferentialgeometry"
3 | version = "0.9.4"
4 | authors = [
5 | { name="Stefan Sommer", email="sommer@di.ku.dk" },
6 | ]
7 | description = "Differential geometry using jax"
8 | readme = "README.md"
9 | requires-python = ">=3.8"
10 | classifiers = [
11 | "Programming Language :: Python :: 3",
12 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
13 | "Operating System :: OS Independent",
14 | ]
15 | dependencies = [
16 | "jax",
17 | "jaxlib",
18 | "scikit-learn",
19 | ]
20 |
21 | [project.optional-dependencies]
22 | examples = [
23 | "jupyter",
24 | "matplotlib",
25 | ]
26 |
27 | [project.urls]
28 | Homepage = "https://github.com/computationalevolutionarymorphometry/jaxgeometry/"
29 | Issues = "https://github.com/computationalevolutionarymorphometry/jaxgeometry/issues"
30 |
31 | [build-system]
32 | requires = [
33 | "setuptools>=60",
34 | "setuptools-scm>=8.0",
35 | ]
36 | build-backend = "setuptools.build_meta"
37 |
38 | [tool.setuptools_scm]
39 | version_file = "jaxgeometry/_version.py"
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax
2 | jaxlib
3 | matplotlib
4 | jupyter
5 | scikit-learn
6 | flax
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

# Project metadata (name, version, dependencies) is declared in
# pyproject.toml; this shim only supplies the discovered package list.
setup(packages=find_packages())
7 |
--------------------------------------------------------------------------------
/tests/test.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Execute every example and paper notebook as a test (pytest + nbmake).
# The makefile's `test` target invokes this script as tests/test.sh.

# stop at the first failing command instead of running pytest after a
# failed dependency install
set -euo pipefail

# install dependencies
pip install pytest nbmake

# run tests; PYTHONPATH is the current working directory
PYTHONPATH="$(pwd)" pytest --nbmake --nbmake-timeout=12000 examples/*.ipynb papers/*.ipynb
8 |
--------------------------------------------------------------------------------