├── fig1.png
├── fig3.png
├── GoTube.jpg
├── Appendix.pdf
├── GoTube1.jpg
├── config.ini
├── logged
└── .gitignore
├── rl
├── ctrnn_osc.npz
├── lds_ctrnn.npz
└── pendulum_ctrnn.npz
├── saved_outputs
└── .gitignore
├── .gitignore
├── requirements.txt
├── timer.py
├── compute_volume_intersection.py
├── table4.sh
├── polar_coordinates.py
├── LICENSE.md
├── performance_log.py
├── dynamics.py
├── figure4.sh
├── table2.sh
├── plot.py
├── main.py
├── README.md
├── go_tube.py
├── stochastic_reachtube.py
└── benchmarks.py
/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/fig1.png
--------------------------------------------------------------------------------
/fig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/fig3.png
--------------------------------------------------------------------------------
/GoTube.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/GoTube.jpg
--------------------------------------------------------------------------------
/Appendix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/Appendix.pdf
--------------------------------------------------------------------------------
/GoTube1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/GoTube1.jpg
--------------------------------------------------------------------------------
/config.ini:
--------------------------------------------------------------------------------
1 | [files]
2 | output_directory = ./saved_outputs/
3 | output_file = _GoTube.txt
--------------------------------------------------------------------------------
/logged/.gitignore:
--------------------------------------------------------------------------------
1 | # just to make sure that the directory logged is created
2 | *.json
--------------------------------------------------------------------------------
/rl/ctrnn_osc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/rl/ctrnn_osc.npz
--------------------------------------------------------------------------------
/rl/lds_ctrnn.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/rl/lds_ctrnn.npz
--------------------------------------------------------------------------------
/saved_outputs/.gitignore:
--------------------------------------------------------------------------------
1 | # just to make sure that the directory saved_outputs is created
2 | *.txt
--------------------------------------------------------------------------------
/rl/pendulum_ctrnn.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/rl/pendulum_ctrnn.npz
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # PyCharm IDEA
2 | /.idea
3 | .project
4 | .pydevproject
5 | /__pycache__
6 |
7 | # Mac OS
8 | /.DS_Store
9 |
10 | # Virtual environment
11 | /venv
12 | /env
13 |
14 | # Output and Log files
15 | all_prob_scores.csv
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | future==0.18.2
3 | kiwisolver==1.3.1
4 | matplotlib==3.3.3
5 | numpy==1.22.2
6 | Pillow==8.0.1
7 | pyparsing==2.4.7
8 | python-dateutil==2.8.1
9 | scipy==1.8.1
10 | six==1.15.0
11 | torch==1.7.1
12 | typing-extensions==4.6.3
13 | jax==0.4.12
14 | jaxlib==0.4.12
--------------------------------------------------------------------------------
/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
class Timer(object):
    """Context manager that measures the wall-clock time spent in its body.

    Usage::

        with Timer("simulation") as timer:
            run()
        print(timer.elapsed)

    Attributes:
        name: Optional label given at construction time.
        tstart: ``time.time()`` value taken when the block was entered
            (``None`` before the first use).
        elapsed: Seconds spent inside the most recent ``with`` block
            (``None`` until a block has finished).
    """

    def __init__(self, name=None):
        self.name = name
        self.tstart = None
        self.elapsed = None

    def __enter__(self):
        self.tstart = time.time()
        # Return the timer itself so that ``with Timer() as t`` binds a usable object.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Store the measurement instead of printing it; callers decide how to report it.
        self.elapsed = time.time() - self.tstart
        # A falsy return value lets exceptions raised inside the block propagate.
        return False
--------------------------------------------------------------------------------
/compute_volume_intersection.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 |
# Script: combines the per-step statistics of two logged GoTube runs.
# For every time step it computes (2 * max(radius * semiAxes_prod)) ** dim over
# both runs, prints that value, and keeps the minimum of it and the two logged
# volumes. Finally the mean of the kept values is printed.
dim = 2

fname1 = "logged/stat_0037.json"
with open(fname1, "r") as f:
    log1 = json.load(f)

stats1 = log1["stats"]
radius1 = stats1["radius"]
semiAxes_prod1 = stats1["semiAxes_prod"]
volumes1 = stats1["volume"]

fname2 = "logged/stat_0038.json"
with open(fname2, "r") as f:
    log2 = json.load(f)

stats2 = log2["stats"]
radius2 = stats2["radius"]
semiAxes_prod2 = stats2["semiAxes_prod"]
volumes2 = stats2["volume"]

volumes = []
per_step = zip(radius1, semiAxes_prod1, volumes1, radius2, semiAxes_prod2, volumes2)
for r1, s1, v1, r2, s2, v2 in per_step:
    # Edge length taken from the larger of the two scaled radii.
    edge = 2 * max(r1 * s1, r2 * s2)
    box_volume = edge ** dim
    print(box_volume)
    # Never report more than either run's own logged volume.
    volumes.append(min(box_volume, v1, v2))

volumes = np.array(volumes)

print("average volume", volumes.mean())
--------------------------------------------------------------------------------
/table4.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs main.py for two cartpole benchmarks, each with a short and a long
# time horizon, a batch size of 10000 and gamma = 0.05.

source venv/bin/activate

# Run main.py once per time horizon using the benchmark settings that are
# currently stored in BENCHMARK, TIME_STEP and INITIAL_RADIUS.
run_both_horizons() {
    local horizon
    for horizon in "$TIME_HORIZON_SHORT" "$TIME_HORIZON_LONG"; do
        python main.py --time_horizon "$horizon" --benchmark "$BENCHMARK" --batch_size 10000 --time_step "$TIME_STEP" --radius "$INITIAL_RADIUS" --gamma 0.05 --score
    done
}

BENCHMARK=cartpoleLTC_RK
TIME_STEP=0.01
TIME_HORIZON_SHORT=0.35
TIME_HORIZON_LONG=10
INITIAL_RADIUS=1e-4
# new parameters - to be defined
run_both_horizons


BENCHMARK=cartpoleCTRNN
TIME_STEP=0.1
TIME_HORIZON_SHORT=1
TIME_HORIZON_LONG=10
INITIAL_RADIUS=1e-4
# new parameters - to be defined
run_both_horizons
23 |
--------------------------------------------------------------------------------
/polar_coordinates.py:
--------------------------------------------------------------------------------
1 | # transformation between polar and cartesian coordinates
2 |
3 | import numpy as np
4 | import jax.numpy as jnp
5 | from jax import jit
6 | import dynamics
7 |
8 | # initialize random polar coordinates with dimension dim
# Dedicated generator with a hard-coded seed, used when reproducible samples are requested.
_rng = np.random.RandomState(12937)


def uniform(start, end, dim, fixed_seed):
    """Draw ``dim`` samples uniformly from ``[start, end)``.

    With ``fixed_seed`` truthy the samples come from the module's seeded
    generator (reproducible across runs); otherwise numpy's global random
    state is used.
    """
    sampler = _rng.uniform if fixed_seed else np.random.uniform
    return sampler(start, end, dim)
18 |
19 |
def init_random_phi(dim, samples=1, num_gpus=1, fixed_seed=False):
    """Sample random angle vectors for ``samples`` points in ``dim`` dimensions.

    ``samples * (dim - 2)`` angles are drawn from ``[0, pi)`` and ``samples``
    angles from ``[0, 2*pi)``. They are concatenated and reshaped
    (Fortran order) to ``(num_gpus, samples // num_gpus, dim - 1)``.
    """
    # The two draws must stay in this order so that seeded runs reproduce.
    half_turn_angles = uniform(0, jnp.pi, samples * (dim - 2), fixed_seed)
    full_turn_angles = uniform(0, 2 * jnp.pi, samples, fixed_seed)
    flat_angles = jnp.append(half_turn_angles, full_turn_angles)
    target_shape = (num_gpus, samples // num_gpus, dim - 1)
    return jnp.reshape(flat_angles, target_shape, order="F")
26 |
27 |
@jit
def polar2cart(rad, phi):
    """Return ``rad`` times the cartesian point that ``polar2cart_no_rad`` yields for ``phi``."""
    unit_point = polar2cart_no_rad(phi)
    return rad * unit_point
31 |
32 |
def polar2cart_euclidean_metric(rad, phis, A0inv):
    """Convert angles to cartesian coordinates, map them through ``A0inv`` and scale by ``rad``."""
    unit_points = polar2cart_no_rad(phis)
    return rad * jnp.matmul(A0inv, unit_points)
35 |
36 |
def polar2cart_no_rad(phi):
    """Delegate to ``dynamics.polar2cart_no_rad`` (angles to cartesian, without radius scaling)."""
    cartesian = dynamics.polar2cart_no_rad(phi)
    return cartesian
39 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | # Attribution-NonCommercial-ShareAlike 4.0 International
2 | This work is licensed under the
3 | [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by-nc-sa/4.0/).
4 |
5 |
6 |
7 | ## Human-readable summary
8 | This is a human-readable summary of (and not a substitute for) the [License](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
9 |
10 |
11 | This license allows reusers to distribute, remix, adapt, and build upon the material in any medium or format for noncommercial purposes only, and only so long as attribution is given to the creator. If you remix, adapt, or build upon the material, you must license the modified material under identical terms.
12 | CC BY-NC-SA includes the following elements:
13 |
14 |
15 |
16 | * **BY** – Credit must be given to the creator
17 |
18 | * **NC** – Only noncommercial uses of the work are permitted
19 |
20 | * **SA** – Adaptations must be shared under the same terms
21 |
22 |
--------------------------------------------------------------------------------
/performance_log.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
# Module-level accumulators filled by log_args / log_stat and dumped by close_log.
registered_args = {}
logged_stats = {}


def log_args(args):
    """Merge ``args`` (anything ``dict()`` accepts) into ``registered_args``.

    Keys that were already registered are overwritten.
    """
    registered_args.update(dict(args))
12 |
13 |
def log_stat(stat):
    """Append every value of the mapping ``stat`` to the list kept for its key in ``logged_stats``.

    A key seen for the first time starts a new list.
    """
    for key, value in stat.items():
        logged_stats.setdefault(key, []).append(value)
20 |
21 |
def close_log(notes):
    """Write the registered arguments, logged statistics and ``notes`` to a JSON file.

    The file goes to ``logged/stat_NNNN.json`` using the lowest index in
    ``0..99999`` that is not taken yet; if every index is taken, the file
    with index 99999 is overwritten.
    """
    payload = {"args": registered_args, "stats": logged_stats, "notes": notes}
    os.makedirs("logged", exist_ok=True)

    last_index = 100000 - 1
    index = 0
    fname = f"logged/stat_{index:04d}.json"
    while os.path.isfile(fname) and index < last_index:
        index += 1
        fname = f"logged/stat_{index:04d}.json"

    with open(fname, "w") as f:
        json.dump(payload, f)
33 |
34 |
def create_plot_file(files):
    """Return the first unused output path of the form ``<output_directory>NNNN<output_file>``.

    ``files`` is a mapping with the keys ``'output_directory'`` and
    ``'output_file'``. Indices ``0..99999`` are probed in order; if all of
    them exist, the path for the last index is returned. No file is created.
    """
    directory = files['output_directory']
    suffix = files['output_file']
    for index in range(100000):
        candidate = directory + f"{index:04d}" + suffix
        if not os.path.isfile(candidate):
            break
    return candidate
41 |
42 |
def write_plot_file(fname, mode, t, cx, rad, M1):
    """Write one reach-set record as a single space-separated line.

    The line contains, in this order: ``t``, the flattened entries of ``cx``,
    ``rad``, and the flattened entries of ``M1``, followed by a newline.

    Args:
        fname: Path of the file to write to.
        mode: Mode passed to ``open`` (e.g. ``"w"`` to start a file, ``"a"`` to append).
        t: Value written as the first field.
        cx: Array-like with a ``reshape`` method; written flattened.
        rad: Value written after ``cx``.
        M1: Array-like with a ``reshape`` method; written flattened.
    """
    # Build the complete line first so a failure while formatting leaves the file untouched.
    fields = [
        str(t),
        ' '.join(map(str, cx.reshape(-1))),
        str(rad),
        ' '.join(map(str, M1.reshape(-1))),
    ]
    line = " ".join(fields) + "\n"
    # The context manager closes the handle even if the write fails.
    with open(fname, mode) as f:
        f.write(line)
50 |
51 |
if __name__ == "__main__":
    # Small manual check: show the registry before and after adding one entry.
    label = "registered_args: "
    print(label, str(registered_args))
    log_args({"hello": "test"})
    print(label, str(registered_args))
56 |
57 |
--------------------------------------------------------------------------------
/dynamics.py:
--------------------------------------------------------------------------------
1 | # computes the jacobian and the metric for a given model
2 |
3 | import jax.numpy as jnp
4 | import numpy as np
5 | from jax import jacfwd, jacrev, jit
6 |
7 | from scipy.linalg import eigh
8 | from numpy.linalg import inv
9 | import benchmarks as bm
10 |
11 |
class FunctionDynamics:
    """Jacobian and metric helpers for a model exposing ``dim`` and ``fdyn(t, x)``."""

    def __init__(self, model):
        """Store ``model`` and prepare a jitted Jacobian of ``x -> model.fdyn(0.0, x)``.

        Raises:
            ValueError: if ``fdyn`` evaluated at the all-ones state differs
                between ``t = 0`` and ``t = 1`` by more than ``1e-8`` (summed
                absolute difference).
        """
        self.model = model

        probe = jnp.ones(self.model.dim)
        at_zero = self.model.fdyn(0.0, probe)
        at_one = self.model.fdyn(1.0, probe)
        time_dependence = jnp.sum(jnp.abs(at_zero - at_one))
        if time_dependence > 1e-8:
            # https://github.com/google/jax/issues/47
            raise ValueError("Only time-invariant systems supported currently")

        def rhs_at_time_zero(x):
            return self.model.fdyn(0.0, x)

        self._cached_f_jac = jit(jacrev(rhs_at_time_zero))

    def f_jac_at(self, t, x):
        """Return the Jacobian of the dynamics at state ``x`` (``t`` is not used)."""
        jacobian = self._cached_f_jac(x)
        return jnp.array(jacobian)

    def metric(self, Fmid, ellipsoids):
        """Return ``(M1, A1, semi_axes_product)`` for the matrix ``Fmid``.

        With ``ellipsoids`` truthy: ``A1 = inv(Fmid)``,
        ``M1 = inv(Fmid @ Fmid.T)`` and the third value is the product of
        ``1 / sqrt(|eigenvalue|)`` over the eigenvalues of ``M1``.
        Otherwise ``M1`` and ``A1`` are identity matrices of ``Fmid``'s size
        and the product is 1.
        """
        if not ellipsoids:
            size = Fmid.shape[0]
            return np.eye(size), np.eye(size), np.array([1]).prod()

        A1inv = Fmid
        A1 = inv(A1inv)
        M1 = inv(A1inv @ A1inv.T)

        eigenvalues, _ = eigh(M1)
        # absolute value to prevent nan-errors from slightly negative eigenvalues
        eigenvalues = abs(eigenvalues)
        # semi-axes lengths, needed to compute the volume of the ellipse
        semi_axes = 1 / np.sqrt(eigenvalues)

        return M1, A1, semi_axes.prod()
43 |
44 |
def polar2cart_no_rad(phi):
    """Map a vector of ``n`` angles to a cartesian vector with ``n + 1`` entries.

    Entry ``k`` is ``cos(phi[k])`` (1 for the last entry) multiplied by the
    sines of all preceding angles. Multiplying the result by a radius gives
    the true cartesian coordinate.
    """
    sines = jnp.sin(phi)
    unit = jnp.ones(1)
    cart = jnp.append(jnp.cos(phi), unit) * jnp.append(unit, sines)

    angle_count = jnp.size(phi)
    shift = 1
    while shift < angle_count:
        # Shift the sine factors further right on every pass, padding with ones.
        factors = jnp.append(jnp.ones(shift + 1), sines[:-shift])
        cart = cart * factors
        shift += 1
    return cart
53 |
# Jitted Jacobian of polar2cart_no_rad; built on first use by jacobian_polar_at.
_jac_polar_cached = None


def jacobian_polar_at(phi):
    """Return the forward-mode Jacobian of ``polar2cart_no_rad`` evaluated at ``phi``.

    The jitted Jacobian function is created on the first call and reused
    afterwards.
    """
    global _jac_polar_cached
    jacobian_fn = _jac_polar_cached
    if jacobian_fn is None:
        jacobian_fn = jit(jacfwd(polar2cart_no_rad))
        _jac_polar_cached = jacobian_fn
    return jacobian_fn(phi)
62 |
63 |
if __name__ == "__main__":
    # Manual check: print the Jacobian of the CartpoleCTRNN benchmark at its initial center.
    fdyn = FunctionDynamics(bm.CartpoleCTRNN())
    initial_center = fdyn.model.cx

    print(fdyn.f_jac_at(0, initial_center))
68 |
--------------------------------------------------------------------------------
/figure4.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs main.py for the CTRNNosc, pendulumCTRNN and ldsCTRNN benchmarks over a
# range of mu values (time step 0.5, gamma 0.5, --score enabled).

source venv/bin/activate

# Usage: run_mu_sweep BENCHMARK TIME_HORIZON RADIUS BATCH_SIZE MU [MU...]
# Runs main.py once per MU value, in the order given. The same interpreter
# name is used for every run (the original mixed `python3` and `python`).
run_mu_sweep() {
    local benchmark=$1
    local horizon=$2
    local radius=$3
    local batch_size=$4
    shift 4

    local mu
    for mu in "$@"; do
        python main.py --time_horizon "$horizon" --mu "$mu" --benchmark "$benchmark" --batch_size "$batch_size" --time_step 0.5 --radius "$radius" --gamma 0.5 --score
    done
}

# The smaller mu values are run with the larger batch size.
run_mu_sweep CTRNNosc 2 0.001 256 2.0 1.8 1.7 1.6
run_mu_sweep CTRNNosc 2 0.001 1024 1.5

run_mu_sweep pendulumCTRNN 10 0.01 256 2.0 1.8 1.7 1.6
run_mu_sweep pendulumCTRNN 10 0.01 1024 1.5 1.3

run_mu_sweep ldsCTRNN 10 0.01 256 2.0 1.8 1.7 1.6
run_mu_sweep ldsCTRNN 10 0.01 1024 1.5 1.3 1.1
--------------------------------------------------------------------------------
/table2.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs main.py for every benchmark below, once with gamma 0.1 and once with
# gamma 0.01, always with --score enabled.

source venv/bin/activate

# Usage: run_gamma_pair MAIN_PY_ARGS...
# Runs main.py twice with the given arguments, appending
# "--gamma 0.1 --score" and then "--gamma 0.01 --score".
run_gamma_pair() {
    local gamma
    for gamma in 0.1 0.01; do
        python main.py "$@" --gamma "$gamma" --score
    done
}

# Brusselator
BENCHMARK=bruss
MU=1.1   # set once here and reused by every run below that passes --mu
TIME_STEP=0.01
TIME_HORIZON=9
INITIAL_RADIUS=0.01
# new parameters - to be defined
run_gamma_pair --time_horizon "$TIME_HORIZON" --mu "$MU" --benchmark "$BENCHMARK" --batch_size 5 --time_step "$TIME_STEP" --radius "$INITIAL_RADIUS"

# Van der Pol
BENCHMARK=vdp
TIME_STEP=0.01
TIME_HORIZON=40
INITIAL_RADIUS=0.01
# new parameters - to be defined
run_gamma_pair --time_horizon "$TIME_HORIZON" --mu "$MU" --benchmark "$BENCHMARK" --batch_size 5 --time_step "$TIME_STEP" --radius "$INITIAL_RADIUS"

# Robotarm
BENCHMARK=robot
TIME_STEP=0.01
TIME_HORIZON=40
INITIAL_RADIUS=0.005
# new parameters - to be defined
run_gamma_pair --time_horizon "$TIME_HORIZON" --mu "$MU" --benchmark "$BENCHMARK" --batch_size 5 --time_step "$TIME_STEP" --radius "$INITIAL_RADIUS"

# Dubins car
BENCHMARK=dubins
TIME_STEP=0.1
TIME_HORIZON=15
INITIAL_RADIUS=0.01
# new parameters - to be defined
run_gamma_pair --time_horizon "$TIME_HORIZON" --mu "$MU" --benchmark "$BENCHMARK" --batch_size 10 --time_step "$TIME_STEP" --radius "$INITIAL_RADIUS"

# MS cardiac cell
BENCHMARK=ms
TIME_STEP=0.01
TIME_HORIZON=10
INITIAL_RADIUS=1e-4
# new parameters - to be defined
run_gamma_pair --time_horizon "$TIME_HORIZON" --mu "$MU" --benchmark "$BENCHMARK" --batch_size 10 --time_step "$TIME_STEP" --radius "$INITIAL_RADIUS"

# Cartpole CTRNN (no --mu passed, so main.py's default applies)
BENCHMARK=cartpoleCTRNN
TIME_STEP=0.02
TIME_HORIZON=1
INITIAL_RADIUS=1e-4
# new parameters - to be defined
BATCH_SIZE=10000
run_gamma_pair --time_horizon "$TIME_HORIZON" --benchmark "$BENCHMARK" --batch_size "$BATCH_SIZE" --time_step "$TIME_STEP" --radius "$INITIAL_RADIUS"

# Cartpole LTC (no --mu passed, so main.py's default applies)
BENCHMARK=cartpoleLTC_RK
TIME_STEP=0.01
TIME_HORIZON=0.35
INITIAL_RADIUS=1e-4
# new parameters - to be defined
BATCH_SIZE=10000
run_gamma_pair --time_horizon "$TIME_HORIZON" --benchmark "$BENCHMARK" --batch_size "$BATCH_SIZE" --time_step "$TIME_STEP" --radius "$INITIAL_RADIUS"
70 |
71 |
--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
1 | # plotting the outputs from GoTube (ellipse, circle, intersections etc.)
2 |
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 | from numpy.linalg import svd
7 | from numpy.linalg import inv
8 |
9 | import stochastic_reachtube as reach
10 | import benchmarks as bm
11 | import configparser
12 | import argparse
13 | import pickle
14 | import os
15 |
16 |
def draw_ellipse(ellipse, color, alph, axis_3d):
    """Draw one 2-D ellipse as a closed curve on a 3-D axis.

    Args:
        ellipse: Sequence of at least 8 numbers laid out as
            ``[z, cx, cy, r, m00, m01, m10, m11]``. ``z`` is used as the
            third plot coordinate, ``(cx, cy)`` is the center, and the drawn
            curve is the set of points ``p`` with
            ``(p - c)^T (M / r**2) (p - c) = 1`` where ``M`` is the 2x2 matrix.
        color: Line color forwarded to ``axis_3d.plot``.
        alph: Alpha (opacity) forwarded to ``axis_3d.plot``.
        axis_3d: Object with a ``plot(xs=, ys=, zs=, color=, alpha=)`` method.
    """
    center = ellipse[1:3]
    radius = ellipse[3]
    M = np.reshape(ellipse[4:8], (2, 2)) / radius ** 2

    # Singular value decomposition to extract the orientation and the axes of
    # the ellipse. numpy returns (U, S, Vh) with M = U @ diag(S) @ Vh, so the
    # third value is the *transposed* right factor; rotating with it would
    # mirror the ellipse. The columns of U are the principal directions.
    U, D, _ = svd(M)

    plot_grid = 50

    # semi-axes along the principal directions
    a = 1 / np.sqrt(D[0])
    b = 1 / np.sqrt(D[1])

    theta = np.arange(0, 2 * np.pi + 1 / plot_grid, 1 / plot_grid)

    # parametric equation of the axis-aligned ellipse
    state = np.zeros((2, np.size(theta)))
    state[0, :] = a * np.cos(theta)
    state[1, :] = b * np.sin(theta)

    # rotate into the original coordinates and move to the center
    X = U @ state
    X[0] += center[0]
    X[1] += center[1]

    axis_3d.plot(xs=X[0], ys=X[1], zs=ellipse[0], color=color, alpha=alph)
46 |
47 |
def plot_ellipse(
    time_horizon, dim, axis1, axis2, file, color, alph, axis_3d, skip_reachsets=1
):
    """Plot the reach sets stored in ``file``, projected onto two state axes.

    Every row of the file is read as ``[t, center (dim values), radius,
    matrix (dim*dim values)]``. The first row is skipped; after that every
    ``skip_reachsets``-th row is drawn with ``draw_ellipse`` until a drawn
    row has ``t >= time_horizon``.

    Args:
        time_horizon: Stop after drawing a row whose time reaches this value.
        dim: Dimension of the system (fixes the row layout).
        axis1, axis2: Indices of the state axes to project onto; only used
            when ``dim > 2``.
        file: Path readable by ``numpy.loadtxt``.
        color, alph, axis_3d: Forwarded to ``draw_ellipse``.
        skip_reachsets: Draw only every n-th row.

    Raises:
        ValueError: if ``dim > 2`` and ``axis1 == axis2`` (no 2-D projection).
    """
    if dim > 2 and axis1 == axis2:
        raise ValueError("axis1 and axis2 must be different to define a 2-D projection")

    # ndmin=2 keeps a file with a single row two-dimensional.
    data_ellipse = np.loadtxt(file, ndmin=2)

    # Permutation matrix that moves axis1 and axis2 to positions 0 and 1.
    # Building it from the full target order (instead of two successive column
    # swaps) also gives the right result when axis2 == 0 or axis1 == 1.
    remaining = [k for k in range(dim) if k not in (axis1, axis2)]
    P = np.eye(dim)[:, [axis1, axis2] + remaining]

    count = skip_reachsets

    for ellipse in data_ellipse[1:]:

        if count != skip_reachsets:
            count += 1
            continue

        count = 1

        # Work on a copy: writing into the row itself would overwrite the
        # center entries before they have all been read.
        ellipse2 = ellipse.copy()

        # create ellipse plotting values for 2d projection
        # https://math.stackexchange.com/questions/2438495/showing-positive-definiteness-in-the-projection-of-ellipsoid
        # construct ellipse2 to have a 2-dimensional ellipse as an input to
        # draw_ellipse

        if dim > 2:
            center = ellipse[1 : dim + 1]
            ellipse2[1] = center[axis1]
            ellipse2[2] = center[axis2]
            radius_ellipse = ellipse[dim + 1]
            m1 = np.reshape(ellipse[dim + 2 :], (dim, dim))
            m1 = m1 / radius_ellipse ** 2
            m1 = P.transpose() @ m1 @ P  # permutation to project on chosen axes
            ellipse2[3] = 1  # because radius is already in m1

            # Schur complement: matrix of the ellipsoid's projection onto the
            # axis1-axis2 plane
            J = m1[0:2, 0:2]
            K = m1[2:, 2:]
            L = m1[2:, 0:2]
            m2 = J - L.transpose() @ inv(K) @ L
            ellipse2[4:8] = m2.reshape(1, -1)

        draw_ellipse(ellipse2[0:8], color, alph, axis_3d)

        if ellipse[0] >= time_horizon:
            break
96 |
97 |
if __name__ == "__main__":
    # Plot sampled traces together with the reach sets written by main.py,
    # then pickle the collected plot data into plot_obj/.
    cli = argparse.ArgumentParser(description="")
    for flag, options in (
        ("--time_step", {"default": 0.01, "type": float}),
        ("--time_horizon", {"default": 0.01, "type": float}),
        ("--benchmark", {"default": "bruss"}),
        ("--output_number", {"default": "0000"}),
        ("--samples", {"default": 100, "type": int}),
        ("--axis1", {"default": 0, "type": int}),
        ("--axis2", {"default": 1, "type": int}),
        # initial radius
        ("--radius", {"default": None, "type": float}),
    ):
        cli.add_argument(flag, **options)

    args = cli.parse_args()

    settings = configparser.ConfigParser()
    settings.read("config.ini")

    files = settings["files"]

    benchmark_model = bm.get_model(args.benchmark, args.radius)
    rt = reach.StochasticReachtube(
        model=benchmark_model,
        time_horizon=args.time_horizon,
        time_step=args.time_step,
        samples=args.samples,
        axis1=args.axis1,
        axis2=args.axis2,
    )  # reachtube

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    p_dict = rt.plot_traces(axis_3d=ax)
    p_dict["time_horizon"] = args.time_horizon
    p_dict["dim"] = rt.model.dim
    p_dict["axis1"] = rt.axis1
    p_dict["axis2"] = rt.axis2

    # File with the reach sets of the selected run.
    tube_file = (
        files["output_directory"] + str(args.output_number) + files["output_file"]
    )
    p_dict["data_ellipse"] = np.loadtxt(tube_file)

    plot_ellipse(
        args.time_horizon,
        rt.model.dim,
        rt.axis1,
        rt.axis2,
        tube_file,
        color="magenta",
        alph=0.8,
        axis_3d=ax,
        skip_reachsets=1,
    )

    plt.show()

    # Store the plot data under the first unused plot_obj/plot_NNN.pkl name;
    # nothing is written when all 1000 names are taken.
    os.makedirs("plot_obj", exist_ok=True)
    candidate_names = (f"plot_obj/plot_{i:03d}.pkl" for i in range(1000))
    filename = next(
        (name for name in candidate_names if not os.path.isfile(name)), None
    )
    if filename is not None:
        with open(filename, "wb") as f:
            pickle.dump(p_dict, f, protocol=pickle.DEFAULT_PROTOCOL)
159 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax.numpy as jnp
3 |
4 | import benchmarks as bm
5 | import stochastic_reachtube as reach
6 | import go_tube
7 | import configparser
8 | import time
9 | from performance_log import log_args
10 | from performance_log import close_log
11 | from performance_log import create_plot_file
12 | from performance_log import write_plot_file
13 | from performance_log import log_stat
14 |
15 | import argparse
16 |
17 | from jax import config
18 | config.update("jax_enable_x64", True)
19 |
20 |
if __name__ == "__main__":
    # Entry point: builds a stochastic reachtube for the selected benchmark,
    # writes one reach set per time step to the plot file and optionally
    # appends a summary line to all_prob_scores.csv (--score) and/or a JSON
    # log (--profile).

    start_time = time.time()

    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--score", action="store_true")
    parser.add_argument("--benchmark", default="vdp")
    # starting_time, time_step and time_horizon for creating reachtubes
    parser.add_argument("--starting_time", default=0.0, type=float)
    parser.add_argument("--time_step", default=0.01, type=float)
    parser.add_argument("--time_horizon", default=10, type=float)
    # batch-size for tensorization
    parser.add_argument("--batch_size", default=10000, type=int)
    # number of GPUs for parallelization
    parser.add_argument("--num_gpus", default=1, type=int)
    # use fixed seed for random points (only for comparing different algorithms)
    parser.add_argument("--fixed_seed", action="store_true")
    # error-probability
    parser.add_argument("--gamma", default=0.2, type=float)
    # mu as maximum over-approximation
    parser.add_argument("--mu", default=1.5, type=float)
    # choose between hyperspheres and ellipsoids to describe the Reachsets
    parser.add_argument("--ellipsoids", action="store_true")
    # initial radius
    parser.add_argument("--radius", default=None, type=float)

    args = parser.parse_args()
    log_args(vars(args))

    # Named differently from jax's `config` (imported above) so that it is
    # not shadowed.
    ini_config = configparser.ConfigParser()
    ini_config.read("./config.ini")

    files = ini_config["files"]

    rt = reach.StochasticReachtube(
        model=bm.get_model(args.benchmark, args.radius),
        profile=args.profile,
        mu=args.mu,  # mu as maximum over-approximation
        gamma=args.gamma,  # error-probability
        batch=args.batch_size,
        num_gpus=args.num_gpus,
        fixed_seed=args.fixed_seed,
        radius=args.radius,
    )  # reachtube

    # Time points at which reach sets are computed; 1e-9 keeps the horizon
    # itself inside the range despite floating-point rounding.
    timeRange = np.arange(args.starting_time + args.time_step, args.time_horizon + 1e-9, args.time_step)
    # NOTE(review): the initial time is hard-coded to 0 here and in the first
    # write_plot_file call, independent of --starting_time - confirm intent.
    timeRange_with_start = np.append(np.array([0]), timeRange)

    volume = jnp.array([rt.compute_volume()])

    # propagate center point and compute metric
    (
        cx_timeRange,
        A1_timeRange,
        M1_timeRange,
        semiAxes_prod_timeRange,
    ) = rt.compute_metric_and_center(
        timeRange_with_start,
        args.ellipsoids,
    )

    fname = create_plot_file(files)

    write_plot_file(fname, "w", 0, rt.model.cx, rt.model.rad, M1_timeRange[0, :, :])

    total_random_points = None
    total_gradients = None
    total_initial_points = jnp.zeros((rt.num_gpus, 0, rt.model.dim))

    # for loop starting at Line 2
    for i, time_py in enumerate(timeRange):
        print(f"Step {i}/{timeRange.shape[0]} at {time_py:0.2f} s")

        rt.time_horizon = time_py
        rt.time_step = args.time_step
        rt.cur_time = rt.time_horizon

        # index 0 of the *_timeRange arrays belongs to the initial time
        rt.cur_cx = cx_timeRange[i + 1, :]
        rt.A1 = A1_timeRange[i + 1, :, :]

        # GoTube Algorithm for t = t_j
        # while loop from Line 5 - Line 15
        (
            rt.cur_rad,
            prob,
            total_initial_points,
            total_random_points,
            total_gradients,
        ) = go_tube.optimize(
            rt, total_initial_points, total_random_points, total_gradients
        )

        write_plot_file(
            fname, "a", time_py, rt.cur_cx, rt.cur_rad, M1_timeRange[i + 1, :, :]
        )

        volume = jnp.append(volume, rt.compute_volume(semiAxes_prod_timeRange[i + 1]))

        if rt.profile:
            # If profiling is enabled, log some statistics about the GD optimization process
            volumes = {
                "volume": float(volume[-1]),
                "average_volume": float(volume.mean()),
            }
            log_stat(volumes)

    # total_random_points is still None when the time range was empty (no
    # optimization step ran); report zero points in that case.
    if total_random_points is not None:
        points_shape = total_random_points.shape
    else:
        points_shape = (0, 0)

    if args.score:
        # --radius is optional; fall back to the radius the model actually
        # uses, because formatting None with a float spec raises TypeError.
        radius = args.radius if args.radius is not None else rt.model.rad
        # CSV columns: benchmark, time horizon, initial radius, mu,
        # 1 - gamma, runtime in seconds, first dimension of
        # total_random_points, average volume.
        # NOTE(review): the profiling branch below counts samples as
        # num_gpus * shape[1]; confirm that shape[0] is the intended value here.
        score_fields = [
            f"{args.benchmark}",
            f"{args.time_horizon:0.4g}",
            f"{radius:0.4g}",
            f"{args.mu:0.4g}",
            f"{1.0-args.gamma:0.4f}",
            f"{time.time()-start_time:0.2f}",
            f"{points_shape[0]:d}",
            f"{float(volume.mean()):0.5g}",
        ]
        with open("all_prob_scores.csv", "a") as f:
            f.write(",".join(score_fields) + "\n")
    if rt.profile:
        final_notes = {
            "total_time": time.time() - start_time,
            "samples": args.num_gpus * points_shape[1],
        }
        close_log(final_notes)
146 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GoTube - Scalable Stochastic Verification of Continuous-Depth Models
2 |
3 | This is the official code repository of the paper *GoTube: Scalable Stochastic Verification of Continuous-Depth Models*
4 | accepted to the Thirty-Sixth AAAI Conference on Artificial Intelligence (AAAI-22)
5 | ([arXiv link](https://arxiv.org/abs/2107.08467)).
6 |
7 | GoTube constructs stochastic reachtubes (= the set of all reachable system states) of continuous-time systems. GoTube is made deliberately for the verification of continuous-depth neural networks.
8 | 
9 |
10 | This document describes the general usage of GoTube. To reproduce the exact numbers reported in the paper, have a look at the scripts ```table2.sh```, ```table4.sh```, and ```figure4.sh```.
11 |
12 | ## Setup
13 |
14 | Requirement is Python3.6 or newer.
15 | The setup was tested on Ubuntu 16.04, Ubuntu 20.04 and the recent MacOS machines with Python 3.6 and 3.8.
16 |
17 | We recommend creating a virtual environment, for GoTube's dependencies to not interfere with other python installations.
18 |
19 | ```bash
20 | python3 -m venv venv # Optional
21 | source venv/bin/activate # Optional
22 | python3 -m pip install -r requirements.txt
23 | ```
24 |
25 | ## Benchmarks
26 |
27 | Each benchmark is encoded in a Python class exposing the three attributes ```rad``` for the initial set radius, ```cx``` for the initial set center, and ```dim``` for the number of dimensions of the dynamical system.
28 | Moreover, each benchmark class must implement a method ```fdyn``` that defines the dynamical system and a case in the ``get_model`` method for choosing that model via a parameter.
29 | For example,
30 |
31 | ```python
32 | # CartPole-v1 with a linear policy
33 | class CartpoleLinear:
34 |
35 | def __init__(self, radius=None):
36 |
37 | #============ adapt initial values ===========
38 | self.cx = (0, 0, 0.001, 0) #initial values
39 | if radius is not None:
40 | self.rad = radius
41 | else:
42 | self.rad = 1e-4 #initial radius
43 | #===================================================
44 |
45 | self.cx = np.array(self.cx, dtype=float)
46 | self.dim = self.cx.size #dimension of the system
47 |
48 | def fdyn(self,t=0,x=None):
49 |
50 | if x is None:
51 | x=np.zeros(self.dim, dtype=object)
52 |
53 | #============ adapt input and system dynamics ===========
54 | dth, dx, th, x = x #input variables
55 |
56 | M = 1.0
57 | g = 9.81
58 | l = 1.0
59 | m = 0.001
60 |
61 | f = -1.1 * M * g * th - dth
62 |
63 | fdth = 1.0 / (l*(M+m*sin(th)*sin(th))) * (f * cos(th) - m*l*dth*dth*cos(th)*sin(th) + (m+M)*g*sin(th))
64 | fdx = 1.0 / (M+m*sin(th)*sin(th)) * ( f + m*sin(th) * (-l * dth*dth + g*cos(th)) )
65 |
66 | fx = dx
67 |
68 | fth = dth
69 |
70 | system_dynamics = [fdth, fdx, fth, fx] #has to be in the same order as the input variables
71 | #===================================================
72 |
73 | return np.array(system_dynamics) #return as numpy array
74 | ```
75 | And inside the ``get_model`` method:
76 | ```
77 | elif benchmark == "cartpole":
78 | return CartpoleLinear(radius) # Benchmark to run
79 | ```
80 |
81 | ## Running GoTube
82 |
83 | The entry point of GoTube is defined in the file ```main.py```, which accepts several arguments to specify properties of the reachtube and GoTube.
84 | The most notable arguments are:
85 |
86 | - ```--benchmark``` Name of the benchmark (=dynamical system) for which the reachtube should be constructed
87 | - ```--time_horizon``` Length in seconds of the reachtube to be constructed
88 | - ```--time_step``` Intermediate time-points for which reachtube should be constructed
89 | - ```--batch_size``` Batch size used for simulating points of the system. A large batch size speeds up the computation but adds additional memory footprints.
90 | - ```--gamma``` Error probability. For instance, a gamma of 0.1 means the reachtube will have a 90% confidence.
91 | - ```--mu``` Maximum multiplicative tolerance of over-approximation. A higher mu speeds up the computation of the bounding tube at the cost of a larger average volume. For instance, a mu of 1.5 means the radius of the reachtube may be up to 1.5 times the distance of the most distant sample point.
92 |
93 | ## Examples
94 |
95 | To create a reachtube for the Van der Pol dynamical system over 2 seconds with a batch size of 100 samples and intermediate reachsets every 0.1 seconds, run
96 |
97 | ```bash
98 | python main.py --time_horizon 2 --benchmark vdp --batch_size 100 --time_step 0.1
99 | ```
100 |
101 | To create a 95% confidence reachtube for the CartPole-v1 with a CT-RNN and a maximum multiplicative tolerance of over-approximation (mu) of 1.5, run
102 |
103 | ```bash
104 | python main.py --time_horizon 1 --benchmark cartpoleCTRNN --batch_size 10000 --time_step 0.02 --gamma 0.05 --mu 1.5
105 | ```
106 |
107 | ## Citation
108 |
109 | ```tex
110 | @article{Gruenbacher2022GoTube,
111 | author={Gruenbacher, Sophie A. and Lechner, Mathias and Hasani, Ramin and Rus, Daniela and Henzinger, Thomas A. and Smolka, Scott A. and Grosu, Radu},
112 | title={GoTube: Scalable Statistical Verification of Continuous-Depth Models},
113 | volume={36},
114 | url={https://ojs.aaai.org/index.php/AAAI/article/view/20631},
115 | DOI={10.1609/aaai.v36i6.20631},
116 | number={6},
117 | journal={Proceedings of the AAAI Conference on Artificial Intelligence},
118 | year={2022},
119 | month={Jun.},
120 | pages={6755-6764},
121 | }
122 | ```
123 |
--------------------------------------------------------------------------------
/go_tube.py:
--------------------------------------------------------------------------------
1 | # Algorithms of GoTube paper for safety region, probability and stoch. optimization
2 |
3 | import jax.numpy as jnp
4 | from jax import vmap, pmap
5 | import polar_coordinates as pol
6 | from jax.numpy.linalg import svd
7 | import jax.scipy.special as sc
8 | import time
9 | from performance_log import log_stat
10 | from timer import Timer
11 | from scipy.stats import genextreme, kstest
12 | import gc
13 |
14 |
15 | # using expected difference quotient of center lipschitz constant
def get_safety_region_radius(model, dist, dist_best, lip, lip_mean_diff):
    """Return the safety-region radius for each sample.

    The radius is the positive root ``r`` of
    ``lip_mean_diff * r**2 + lip * r = model.mu * dist_best - dist``,
    capped at ``2 * model.rad_t0``.

    Args:
        model: object exposing ``mu`` and ``rad_t0``.
        dist: distance value(s) of the samples.
        dist_best: largest distance found so far.
        lip: Lipschitz value(s) at the samples.
        lip_mean_diff: quadratic coefficient (change of the Lipschitz value).

    Returns:
        Array of radii; entries are 0 wherever ``dist <= 0`` or
        ``lip_mean_diff <= 0``.
    """
    has_curvature = lip_mean_diff > 0
    valid = (dist > 0) & has_curvature

    # Divide by a harmless placeholder where the coefficient is not positive:
    # multiplying 0/0 (NaN) or x/0 (inf) by a zero mask would still give NaN.
    safe_denominator = jnp.where(has_curvature, 2 * lip_mean_diff, 1.0)

    discriminant = lip ** 2 + 4 * lip_mean_diff * (model.mu * dist_best - dist)
    safety_radius = (-lip + jnp.sqrt(discriminant)) / safe_denominator
    safety_radius = jnp.where(valid, safety_radius, 0.0)

    # the radius can never exceed the diameter of the initial ball
    return jnp.minimum(safety_radius, 2 * model.rad_t0)
24 |
25 |
def compute_maximum_singular_value(A1, A0inv, F):
    """Return the largest singular value of the product ``A1 @ F @ A0inv``."""
    transformed = jnp.matmul(jnp.matmul(A1, F), A0inv)
    singular_values = svd(transformed)[1]
    return jnp.max(singular_values)
33 |
34 |
def get_angle_of_cap(model, radius):
    """Return ``2 * arcsin(radius / (2 * model.rad_t0))``.

    ``radius`` is first clipped to ``2 * model.rad_t0`` so the arcsin
    argument never exceeds 1.
    """
    largest_radius = 2 * model.rad_t0
    clipped_radius = jnp.minimum(radius, largest_radius)
    return 2 * jnp.arcsin(0.5 * clipped_radius / model.rad_t0)
38 |
39 |
def get_probability_of_cap(model, radius):
    """Return the cap probability for the given radius/radii.

    The cap angle comes from ``get_angle_of_cap``; the probability is half
    of the regularized incomplete beta function evaluated at
    ``sin(angle) ** 2`` with parameters ``(dim - 1) / 2`` and ``1 / 2``.
    """
    with Timer('get angle of cap'):
        cap_angle = get_angle_of_cap(model, radius)

    with Timer('get probability of cap'):
        small_cap = 0.5 * sc.betainc(
            0.5 * (model.model.dim - 1),
            0.5,
            jnp.sin(cap_angle) ** 2,
        )

        # The formula only covers caps with angle <= pi/2; sine is symmetric
        # around pi/2, so larger caps use the complement instead.
        cap_probability = jnp.where(
            cap_angle > 0.5 * jnp.pi, 1 - small_cap, small_cap
        )

    return cap_probability
53 |
54 |
def get_probability_not_in_cap(model, radius):
    """Return the complement ``1 - get_probability_of_cap(model, radius)``."""
    inside_cap = get_probability_of_cap(model, radius)
    return 1 - inside_cap
57 |
58 |
def get_probability_none_in_cap(model, radius_points):
    """Return the product of the not-in-cap probabilities over all radii."""
    outside_each_cap = get_probability_not_in_cap(model, radius_points)
    return jnp.prod(outside_each_cap)
61 |
62 |
63 | # Probability calculation using http://docsdrive.com/pdfs/ansinet/ajms/2011/66-70.pdf (equation 1,
64 | # page 68) and the regularized incomplete beta function in scipy (
65 | # https://scipy.github.io/devdocs/generated/scipy.special.betainc.html#scipy.special.betainc).
66 | # The probability is constructed from the randomly sampled points,
67 | # including the discarded points, by creating balls around them.
def get_probability(model, radius_points):
    """Return ``sqrt(1 - model.gamma)`` times the probability that at least one cap is hit.

    The second factor is ``1 - get_probability_none_in_cap(model, radius_points)``.
    """
    any_cap_hit = 1 - get_probability_none_in_cap(model, radius_points)
    return jnp.sqrt(1 - model.gamma) * any_cap_hit
70 |
71 |
def compute_delta_lipschitz(y_jax, fy_jax, axis, gamma):
    """Estimate a statistical upper bound on the pairwise difference quotients.

    The pairwise difference quotients of ``fy_jax`` over ``y_jax`` are split
    into groups of ``m`` elements, the maximum of every group is taken, and a
    generalized extreme value distribution is fitted to those maxima. The
    returned bound is a quantile of the fitted distribution, shifted by
    ``epsilon`` and the one-sided Kolmogorov-Smirnov statistic.

    Args:
        y_jax: sample points, forwarded to ``get_diff_quotient_pairwise``.
        fy_jax: function values at the sample points.
        axis: axis over which the point distance norm is taken.
        gamma: error probability; transformed to ``1 - sqrt(1 - gamma)`` below.

    Returns:
        Tuple ``(delta_lipschitz, prob_bound_lipschitz)``; ``delta_lipschitz``
        is the result of ``ppf`` called on a one-element list.
    """
    # NOTE(review): the caller in `optimize` already passes
    # 1 - sqrt(1 - model.gamma), so this transform is applied a second time
    # there -- confirm that this is intended.
    gamma_hat = 1 - jnp.sqrt(1 - gamma)
    diff_quotients = get_diff_quotient_pairwise(y_jax, fy_jax, axis)
    sample_size = diff_quotients.size
    number_of_elements_for_maximum = round(sample_size ** (1 / 4))  # m in Lemma 1, Theorem 2 and throughout paper

    # drop the trailing quotients that do not fill a complete group of m elements
    sample_size_dividable = sample_size - sample_size % number_of_elements_for_maximum

    diff_quotients_samples = diff_quotients[:sample_size_dividable].reshape(-1, number_of_elements_for_maximum)
    # nanmax ignores NaN quotients inside a group
    max_quotients = jnp.nanmax(diff_quotients_samples, axis=1)
    number_of_maxima = max_quotients.size  # n in Lemma 1
    alpha = min(gamma_hat, 0.5)
    epsilon = jnp.sqrt(jnp.log(1 / alpha) / (2 * number_of_maxima))

    # fit a generalized extreme value distribution to the group maxima
    c, loc, scale = genextreme.fit(max_quotients)
    rv_genextreme = genextreme(c, loc, scale)

    # one-sided Kolmogorov-Smirnov statistic between the maxima and the fitted cdf
    D_minus = kstest(max_quotients, rv_genextreme.cdf, alternative='less').statistic

    max_quantile = 0.9999

    # # with generalized extreme value distribution
    # prob_quantile is capped so that the quantile passed to ppf stays <= max_quantile
    prob_quantile = min(1 - gamma_hat, max_quantile - epsilon - D_minus)
    delta_lipschitz = rv_genextreme.ppf([prob_quantile + epsilon + D_minus])  # transformation of Eq. (S14)
    prob_bound_lipschitz = (1 - gamma_hat) * prob_quantile

    # without generalized extreme value distribution
    # max_quotients = jnp.sort(max_quotients)
    # prob_quantile = min(1 - gamma_hat, max_quantile - epsilon)
    # delta_lipschitz = max_quotients[int(jnp.floor((prob_quantile + epsilon) * number_of_maxima))] # transformation of Eq. (S14)
    # prob_bound_lipschitz = (1 - gamma_hat) * prob_quantile

    return delta_lipschitz, prob_bound_lipschitz
105 |
106 |
def get_diff_quotient_pairwise(x, fx, axis):
    """Return difference quotients of consecutive sample pairs.

    Samples are paired as (0, 1), (2, 3), ...; a trailing unpaired sample is
    ignored. For every pair the quotient is ``|fx1 - fx2| / ||x1 - x2||``.

    Args:
        x: points; reshaped to ``(-1, x.shape[2])``, so a 3-D array is expected.
        fx: function values; flattened to one value per point.
        axis: axis passed to the norm of the point differences.

    Returns:
        One quotient per pair; pairs of coincident points yield 0.
    """
    # reshape to get samples as first index and remove gpu dimension
    x = jnp.reshape(x, (-1, x.shape[2]))
    fx = jnp.reshape(fx, fx.size)

    samples = x.shape[0] // 2
    x1 = x[::2][:samples]
    x2 = x[1::2][:samples]
    fx1 = fx[::2][:samples]
    fx2 = fx[1::2][:samples]

    distance = jnp.linalg.norm(x1 - x2, axis=axis)
    distinct = distance > 0
    # Divide by a placeholder for coincident points: multiplying 0/0 (NaN)
    # by a zero mask would still give NaN instead of the intended 0.
    safe_distance = jnp.where(distinct, distance, 1.0)
    return jnp.where(distinct, jnp.abs(fx1 - fx2) / safe_distance, 0.0)
120 |
121 |
def get_diff_quotient(x, fx, y_jax, fy_jax, axis):
    """Return ``|fx - fy_jax| / ||x - y_jax||`` for one point against many.

    Args:
        x: reference point.
        fx: function value at ``x``.
        y_jax: points to compare against.
        fy_jax: function values at ``y_jax``.
        axis: axis passed to the norm of ``x - y_jax``.

    Returns:
        The difference quotients; entries where the distance is 0 are 0.
    """
    distance = jnp.linalg.norm(x - y_jax, axis=axis)
    distinct = distance > 0
    # Divide by a placeholder for coincident points: multiplying 0/0 (NaN)
    # by a zero mask would still give NaN instead of the intended 0.
    safe_distance = jnp.where(distinct, distance, 1.0)
    return jnp.where(distinct, jnp.abs(fx - fy_jax) / safe_distance, 0.0)
126 |
127 |
def get_diff_quotient_vmap(x_jax, fx_jax, y_jax, fy_jax, axis):
    """Apply ``get_diff_quotient`` to every ``(x, fx)`` pair via ``vmap``.

    ``x_jax`` and ``fx_jax`` are mapped over their leading axis, while
    ``y_jax``, ``fy_jax`` and ``axis`` are passed unchanged to every call.
    """
    batched_quotient = vmap(get_diff_quotient, in_axes=(0, 0, None, None, None))
    return batched_quotient(x_jax, fx_jax, y_jax, fy_jax, axis)
130 |
131 |
def optimize(model, initial_points, points=None, gradients=None):
    """Sample points until the probability bound reaches ``1 - model.gamma``.

    If ``points`` or ``gradients`` is missing, a first batch is sampled and
    integrated from scratch; otherwise the given points and gradients are
    advanced by one integration step. New batches are then added until the
    probability returned by ``get_probability`` reaches ``1 - model.gamma``.

    Args:
        model: reachtube object providing the integrators, ``gamma``, ``mu``,
            ``batch``, ``num_gpus``, ``fixed_seed``, ``A1``, ``A0inv`` and
            ``profile``.
        initial_points: initial points belonging to ``points``; replaced when
            a first batch is sampled here.
        points: previously integrated points, or None.
        gradients: gradients belonging to ``points``, or None.

    Returns:
        Tuple ``(dist_with_safety_mu, prob, initial_points, points, gradients)``
        where ``dist_with_safety_mu`` is ``model.mu`` times the largest
        distance found.
    """
    start_time = time.time()

    prob = None

    if points is None or gradients is None:
        previous_samples = 0
        phis = pol.init_random_phi(model.model.dim, model.batch, model.num_gpus, model.fixed_seed)
        points, gradients, neg_dists, initial_points = model.aug_integrator_neg_dist(phis)
        dists = -neg_dists
        del neg_dists
        del phis
        gc.collect()
    else:
        # Count the points of all devices so that this matches the totals
        # below, which are also multiplied by num_gpus; otherwise the number
        # of new points is overcounted when more than one device is used.
        previous_samples = model.num_gpus * points.shape[1]
        with Timer('integrate random points and gradients - one step'):
            points, gradients, dists = model.one_step_aug_integrator_dist(
                points, gradients
            )

    first_iteration = True

    while prob is None or prob < 1 - model.gamma:

        # the first iteration evaluates the points obtained above; every
        # further iteration samples an additional batch and appends it
        if not first_iteration:
            with Timer('sample phis'):
                phis = pol.init_random_phi(model.model.dim, model.batch, model.num_gpus, model.fixed_seed)
            with Timer('compute first integration step and dist'):
                new_points, new_gradients, new_neg_dists, new_initial_points = model.aug_integrator_neg_dist(phis)
                new_dists = -new_neg_dists

            del new_neg_dists
            del phis
            gc.collect()

            with Timer('concatenate new points to tensors'):
                points = jnp.concatenate((points, new_points), axis=1)
                gradients = jnp.concatenate((gradients, new_gradients), axis=1)
                dists = jnp.concatenate((dists, new_dists), axis=1)
                initial_points = jnp.concatenate((initial_points, new_initial_points), axis=1)

            del new_points
            del new_gradients
            del new_dists
            del new_initial_points
            gc.collect()

        with Timer('compute best dist'):
            dist_best = dists.max()

        with Timer('compute lipschitz'):
            # compute maximum singular values of all new gradient matrices
            lipschitz = pmap(vmap(compute_maximum_singular_value, in_axes=(None, None, 0)), in_axes=(None, None, 0))(model.A1, model.A0inv, gradients)

        with Timer('compute expected local lipschitz'):
            # NOTE(review): shape[0] is the leading axis of `points`, so these
            # slices keep every entry -- confirm this is intended.
            sample_size = points.shape[0]

            # compute expected value of delta lipschitz
            dimension_axis = 1

            gamma_hat = 1 - jnp.sqrt(1 - model.gamma)

            delta_lipschitz, prob_bound_lipschitz = compute_delta_lipschitz(
                initial_points[:sample_size],
                lipschitz[:sample_size],
                dimension_axis,
                gamma_hat
            )

        with Timer('get safety region radii'):
            safety_region_radii = get_safety_region_radius(
                model, dists, dist_best, lipschitz, delta_lipschitz
            )

        with Timer('compute probability'):
            prob = get_probability(model, safety_region_radii)

        # free the large intermediates before the next batch is allocated
        del delta_lipschitz
        del lipschitz
        del safety_region_radii
        gc.collect()

        print(f"Current probability coverage is {100.0 * prob:0.3f}% using {model.num_gpus * points.shape[1]} points")

        first_iteration = False

    new_samples = model.num_gpus * points.shape[1] - previous_samples

    print("Probability reached given value!")
    print(
        f"Visited {new_samples} new points in {time.time() - start_time:0.2f} seconds."
    )

    dist_with_safety_mu = model.mu * dist_best

    if model.profile:
        # If profiling is enabled, log some statistics about the optimization process
        stat_dict = {
            "loop_time": time.time() - start_time,
            "new_points": int(new_samples),
            "total_points": int(previous_samples + new_samples),
            "prob": float(prob),
            "dist_best": float(dist_best),
            "radius": float(dist_with_safety_mu),
        }
        log_stat(stat_dict)

    return dist_with_safety_mu, prob, initial_points, points, gradients
238 |
--------------------------------------------------------------------------------
/stochastic_reachtube.py:
--------------------------------------------------------------------------------
1 | # optimization problem
2 |
3 | import numpy as np
4 | import jax.numpy as jnp
5 | from jax.experimental.ode import odeint
6 | from jax import vmap, jit, pmap, device_put, devices
7 | from functools import partial
8 |
9 | from scipy.special import gamma
10 |
11 | # own files
12 | import benchmarks as bm
13 | import polar_coordinates as pol
14 | import dynamics
15 |
16 |
def create_aug_state_cartesian(x, F):
    """Stack ``x`` as a row on top of ``F`` and flatten to a 1-D vector."""
    stacked = jnp.concatenate((jnp.array([x]), F))
    return stacked.reshape(-1)  # reshape to row vector
23 |
24 |
class StochasticReachtube:
    """Configuration and integration helpers for building a reachtube.

    The object stores the benchmark model and the integration settings, and
    offers methods to integrate the system dynamics together with the
    variational equation (the augmented state is the state ``x`` stacked on a
    ``dim x dim`` matrix ``F``).
    """

    def __init__(
        self,
        model=None,  # benchmark model; a CartpoleCTRNN is created if None
        time_horizon=10.0,  # time_horizon until which the reachtube should be constructed
        profile=False,
        time_step=0.1,  # ReachTube construction
        h_metric=0.05,  # time_step for metric computation
        h_traces=0.01,  # time_step for traces computation
        max_step_metric=0.00125,  # maximum time_step for metric computation
        max_step_optim=0.1,  # maximum time_step for optimization
        samples=100,  # just for plotting: number of random points on the border of the initial ball
        batch=1,  # number of initial points for vectorization
        num_gpus=1,  # number of GPUs for parallel computation
        fixed_seed=False,  # specify whether a fixed seed should be used (only for comparing different algorithms)
        axis1=0,  # axis to project reachtube to
        axis2=1,
        atol=1e-10,  # absolute tolerance of integration
        rtol=1e-10,  # relative tolerance of integration
        plot_grid=50,
        mu=1.5,
        gamma=0.01,
        radius=False,  # accepted for compatibility; not used in this class
    ):
        # Build the default model lazily: a default argument expression is
        # evaluated once when the class is defined, which would construct the
        # benchmark even if it is never used and share it between instances.
        if model is None:
            model = bm.CartpoleCTRNN(None)

        self.time_step = min(time_step, time_horizon)
        self.profile = profile
        self.h_metric = min(h_metric, time_step)
        self.h_traces = h_traces
        self.max_step_metric = min(max_step_metric, self.h_metric)
        self.max_step_optim = min(max_step_optim, self.time_step)
        self.time_horizon = time_horizon
        self.samples = samples
        self.batch = batch
        self.num_gpus = num_gpus
        self.fixed_seed = fixed_seed
        self.axis1 = axis1
        self.axis2 = axis2
        self.atol = atol
        self.rtol = rtol
        self.plotGrid = plot_grid
        self.mu = mu
        self.gamma = gamma

        self.model = model
        self.init_model()

        self.metric = dynamics.FunctionDynamics(model).metric
        self.init_metric()

        self.f_jac_at = dynamics.FunctionDynamics(model).f_jac_at

    def init_metric(self):
        """Reset all metric matrices to the identity."""
        self.M1 = np.eye(self.model.dim)
        self.A1 = np.eye(self.model.dim)
        self.A1inv = np.eye(self.model.dim)
        self.A0inv = np.eye(self.model.dim)

    def init_model(self):
        """Reset time, center and radius to the model's initial values."""
        self.cur_time = 0
        self.cur_cx = self.model.cx
        self.cur_rad = self.model.rad
        self.t0 = 0
        self.cx_t0 = self.model.cx
        self.rad_t0 = self.model.rad

    def compute_volume(self, semiAxes_product=None):
        """Return the ball volume constant times ``cur_rad ** dim`` times ``semiAxes_product``.

        ``semiAxes_product`` defaults to 1.
        """
        if semiAxes_product is None:
            semiAxes_product = 1
        volC = gamma(self.model.dim / 2.0 + 1) ** -1 * jnp.pi ** (
            self.model.dim / 2.0
        )  # volume constant for ellipse and ball
        return volC * self.cur_rad ** self.model.dim * semiAxes_product

    def plot_traces(self, axis_3d):
        """Integrate random initial points and plot their traces on ``axis_3d``.

        Returns:
            Dict with keys ``"xs"``, ``"ys"``, ``"zs"`` holding the trace of
            the last plotted sample only.
        """
        rd_polar = pol.init_random_phi(self.model.dim, self.samples)
        # reshape to get samples as first index and remove gpu dimension
        rd_polar = jnp.reshape(rd_polar, (-1, rd_polar.shape[2]))
        rd_x = (
            vmap(pol.polar2cart, in_axes=(None, 0))(self.model.rad, rd_polar)
            + self.model.cx
        )
        # the small offset makes sure time_horizon itself is part of the range
        plot_timerange = jnp.arange(0, self.time_horizon + 1e-9, self.h_traces)

        sol = odeint(
            self.fdyn_jax_no_pmap,
            rd_x,
            plot_timerange,
            atol=self.atol,
            rtol=self.rtol,
        )

        for s in range(self.samples):
            axis_3d.plot(
                xs=sol[:, s, self.axis1],
                ys=sol[:, s, self.axis2],
                zs=plot_timerange,
                color="k",
                linewidth=1,
            )

        # NOTE(review): `s` is the last loop index, so only the final trace
        # is returned (and samples == 0 raises NameError) -- confirm intent.
        p_dict = {
            "xs": np.array(sol[:, s, self.axis1]),
            "ys": np.array(sol[:, s, self.axis2]),
            "zs": np.array(plot_timerange),
        }
        return p_dict

    def propagate_center_point(self, time_range):
        """Integrate the center point and its matrix ``F`` over ``time_range``.

        Returns:
            Tuple ``(cx, F)`` with one entry per time point.
        """
        cx_jax = self.model.cx.reshape(1, self.model.dim)
        F = jnp.eye(self.model.dim)
        # put aug_state in CPU, as it is faster for odeint than GPU
        aug_state = device_put(jnp.concatenate((cx_jax, F)).reshape(1, -1), device=devices("cpu")[0])
        sol = odeint(
            self.aug_fdyn_jax_no_pmap,
            aug_state,
            time_range,
            atol=self.atol,
            rtol=self.rtol,
        )
        cx, F = vmap(self.reshape_aug_state_to_matrix)(sol)
        return cx, F

    def compute_metric_and_center(self, time_range, ellipsoids):
        """Propagate the center and compute the metric for every time point.

        The entries for the first time point are the identity (and 1 for the
        semi-axes product); the remaining ones come from ``self.metric``.

        Returns:
            Tuple ``(cx_timeRange, A1_timeRange, M1_timeRange,
            semiAxes_prod_timeRange)``.
        """
        print(f"Propagating center point for {time_range.shape[0]-1} timesteps")
        cx_timeRange, F_timeRange = self.propagate_center_point(time_range)
        A1_timeRange = np.eye(self.model.dim).reshape(1, self.model.dim, self.model.dim)
        M1_timeRange = np.eye(self.model.dim).reshape(1, self.model.dim, self.model.dim)
        semiAxes_prod_timeRange = np.array([1])

        print("Starting loop for creating metric")
        for idx, t in enumerate(time_range[1:]):
            # idx starts at 0 for the second time point, hence idx + 1
            M1_t, A1_t, semiAxes_prod_t = self.metric(
                F_timeRange[idx + 1, :, :], ellipsoids
            )
            A1_timeRange = np.concatenate(
                (A1_timeRange, A1_t.reshape(1, self.model.dim, self.model.dim)), axis=0
            )
            M1_timeRange = np.concatenate(
                (M1_timeRange, M1_t.reshape(1, self.model.dim, self.model.dim)), axis=0
            )
            semiAxes_prod_timeRange = np.append(
                semiAxes_prod_timeRange, semiAxes_prod_t
            )

        return cx_timeRange, A1_timeRange, M1_timeRange, semiAxes_prod_timeRange

    def reshape_aug_state_to_matrix(self, aug_state):
        """Split a flat augmented state into ``(x, F)``.

        The state is reshaped to rows of length ``dim``; the first row is
        ``x`` and the remaining rows form ``F``.
        """
        aug_state = aug_state.reshape(-1, self.model.dim)  # reshape to matrix
        x = aug_state[:1][0]
        F = aug_state[1:]
        return x, F

    def reshape_aug_fdyn_return_to_vector(self, fdyn_return, F_return):
        """Stack ``fdyn_return`` on top of ``F_return`` and flatten to 1-D."""
        return jnp.concatenate((jnp.array([fdyn_return]), F_return)).reshape(-1)

    @partial(jit, static_argnums=(0,))
    def aug_fdyn(self, t=0, aug_state=0):
        """Return the time derivative of a flat augmented state.

        The derivative of ``x`` is ``model.fdyn(t, x)`` and the derivative of
        ``F`` is ``f_jac_at(t, x) @ F``.
        """
        x, F = self.reshape_aug_state_to_matrix(aug_state)
        fdyn_return = self.model.fdyn(t, x)
        F_return = jnp.matmul(self.f_jac_at(t, x), F)
        return self.reshape_aug_fdyn_return_to_vector(fdyn_return, F_return)

    def aug_fdyn_jax_no_pmap(self, aug_state=0, t=0):
        """Apply ``aug_fdyn`` to a batch of augmented states with ``vmap``."""
        return vmap(self.aug_fdyn, in_axes=(None, 0))(t, aug_state)

    def fdyn_jax_no_pmap(self, x=0, t=0):
        """Apply ``model.fdyn`` to a batch of states with ``vmap``."""
        return vmap(self.model.fdyn, in_axes=(None, 0))(t, x)

    def aug_fdyn_jax(self, aug_state=0, t=0):
        """Apply ``aug_fdyn`` over a device axis (``pmap``) and a batch axis (``vmap``)."""
        return pmap(vmap(self.aug_fdyn, in_axes=(None, 0)), in_axes=(None, 0))(t, aug_state)

    def fdyn_jax(self, x=0, t=0):
        """Apply ``model.fdyn`` over a device axis (``pmap``) and a batch axis (``vmap``)."""
        return pmap(vmap(self.model.fdyn, in_axes=(None, 0)), in_axes=(None, 0))(t, x)

    def create_aug_state(self, polar, rad_t0, cx_t0):
        """Build the augmented start state for one polar coordinate.

        Returns:
            Tuple ``(aug_state, x)``: the flat augmented state (``x`` stacked
            on the identity matrix) and the cartesian start point ``x``.
        """
        x = jnp.array(
            pol.polar2cart_euclidean_metric(rad_t0, polar, self.A0inv) + cx_t0
        )
        F = jnp.eye(self.model.dim)

        aug_state = jnp.concatenate((jnp.array([x]), F)).reshape(
            -1
        )  # reshape to row vector

        return aug_state, x

    def one_step_aug_integrator(self, x, F):
        """Integrate the given states and matrices over one ``time_step``.

        Returns:
            Tuple ``(x, F)`` at the end of the step.
        """
        aug_state = pmap(vmap(create_aug_state_cartesian))(x, F)
        sol = odeint(
            self.aug_fdyn_jax,
            aug_state,
            jnp.array([0, self.time_step]),
            atol=self.atol,
            rtol=self.rtol,
        )
        # sol[-1] is the solution at the final time point
        x, F = pmap(vmap(self.reshape_aug_state_to_matrix))(sol[-1])
        return x, F

    def aug_integrator(self, polar, step=None):
        """Integrate start points given in polar coordinates from 0 to ``step``.

        ``step`` defaults to ``self.cur_time``.

        Returns:
            Tuple ``(x, F, initial_x)``: final states, final matrices and the
            cartesian start points.
        """
        if step is None:
            step = self.cur_time

        rad_t0 = self.rad_t0
        cx_t0 = self.cx_t0

        aug_state, initial_x = pmap(vmap(self.create_aug_state, in_axes=(0, None, None)), in_axes=(0, None, None))(
            polar, rad_t0, cx_t0
        )
        sol = odeint(
            self.aug_fdyn_jax,
            aug_state,
            jnp.array([0, step]),
            atol=self.atol,
            rtol=self.rtol,
        )
        # sol[-1] is the solution at the final time point
        x, F = pmap(vmap(self.reshape_aug_state_to_matrix))(sol[-1])
        return x, F, initial_x

    def aug_integrator_neg_dist(self, polar):
        """Run ``aug_integrator`` and also return the negative distances.

        Returns:
            Tuple ``(x, F, neg_dist, initial_x)``.
        """
        x, F, initial_x = self.aug_integrator(polar)
        neg_dist = pmap(vmap(self.neg_dist_x))(x)
        return x, F, neg_dist, initial_x

    def one_step_aug_integrator_dist(self, x, F):
        """Run ``one_step_aug_integrator`` and also return the (positive) distances.

        Returns:
            Tuple ``(x, F, dist)``.
        """
        x, F = self.one_step_aug_integrator(x, F)
        neg_dist = pmap(vmap(self.neg_dist_x))(x)
        return x, F, -neg_dist

    def neg_dist_x(self, xt):
        """Return the negated norm of ``A1 @ (xt - cur_cx)``."""
        dist = jnp.linalg.norm(jnp.matmul(self.A1, xt - self.cur_cx))
        return -dist
--------------------------------------------------------------------------------
/benchmarks.py:
--------------------------------------------------------------------------------
1 | # different classes with benchmarks
2 |
3 | import jax.numpy as np
4 | from jax.numpy import tanh
5 | from jax.numpy import sin
6 | from jax.numpy import cos
7 | from jax.numpy import exp
8 |
9 |
def get_model(benchmark, radius=None):
    """Create the benchmark model registered under the name ``benchmark``.

    Args:
        benchmark: name of the benchmark, e.g. ``"vdp"`` or ``"cartpole"``.
        radius: initial radius passed on to the benchmark constructor.

    Raises:
        ValueError: if ``benchmark`` is not a known name.
    """
    # The factories are lazy so that only the selected class is looked up.
    factories = {
        "bruss": lambda: Brusselator(radius),
        "vdp": lambda: VanDerPol(radius),
        "robot": lambda: Robotarm(radius),
        "dubins": lambda: DubinsCar(radius),
        "ms": lambda: MitchellSchaeffer(radius),
        "cartpole": lambda: CartpoleLinear(radius),
        "quadcopter": lambda: Quadcopter(radius),
        "cartpoleCTRNN": lambda: CartpoleCTRNN(radius),
        "cartpoleLTC": lambda: CartpoleLTC(radius),
        "cartpoleLTC_RK": lambda: CartpoleLTC_RK(radius),
        "ldsCTRNN": lambda: LDSwithCTRNN(radius),
        "pendulumCTRNN": lambda: PendulumwithCTRNN(radius),
        "CTRNNosc": lambda: CTRNNosc(radius),
    }

    create = factories.get(benchmark)
    if create is None:
        raise ValueError("Unknown benchmark " + benchmark)
    return create()
39 |
40 |
41 | # 2-dimensional brusselator
class Brusselator:
    """Brusselator benchmark with two state variables."""

    def __init__(self, radius=None):
        """Set the initial center ``cx``, radius ``rad`` and dimension ``dim``.

        ``radius`` overrides the default initial radius of 0.01.
        """
        # ============ adapt initial values ===========
        self.cx = (1, 1)
        if radius is not None:
            self.rad = radius
        else:
            self.rad = 0.01
        # ===================================================
        self.cx = np.array(self.cx, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def fdyn(self, t=0, x=None):
        """Return the time derivative of the state ``x`` as an array.

        ``t`` is not used by the dynamics; ``x`` defaults to the zero state.
        """
        if x is None:
            # jax.numpy does not support dtype=object, so use plain zeros
            x = np.zeros(self.dim)

        # ============ adapt input and system dynamics ===========
        x, y = x

        a = 1
        b = 1.5

        fx = a + x ** 2 * y - (b + 1) * x
        fy = b * x - x ** 2 * y

        system_dynamics = [fx, fy]  # has to be in the same order as the input variables
        # ===================================================

        return np.array(system_dynamics)  # return as numpy array
71 |
72 |
73 | # 2-dimensional van der pol
class VanDerPol:
    """Van der Pol benchmark with two state variables."""

    def __init__(self, radius=None):
        """Set the initial center ``cx``, radius ``rad`` and dimension ``dim``.

        ``radius`` overrides the default initial radius of 0.1.
        """
        # ============ adapt initial values ===========
        self.cx = (-1, -1)
        if radius is not None:
            self.rad = radius
        else:
            self.rad = 0.1
        # ===================================================
        self.cx = np.array(self.cx, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def fdyn(self, t=0, x=None):
        """Return the time derivative of the state ``x`` as an array.

        ``t`` is not used by the dynamics; ``x`` defaults to the zero state.
        """
        if x is None:
            # jax.numpy does not support dtype=object, so use plain zeros
            x = np.zeros(self.dim)

        # ============ adapt input and system dynamics ===========
        x, y = x

        fx = y
        fy = (x * x - 1) * y - x

        system_dynamics = [fx, fy]  # has to be in the same order as the input variables
        # ===================================================

        return np.array(system_dynamics)  # return as numpy array
100 |
101 |
102 | # 4-dimensional Robotarm
class Robotarm:
    """Robot arm benchmark with four state variables."""

    def __init__(self, radius=None):
        """Set the initial center ``cx``, radius ``rad`` and dimension ``dim``.

        ``radius`` overrides the default initial radius of 0.05.
        """
        # ============ adapt initial values ===========
        self.cx = (1.505, 1.505, 0.005, 0.005)
        if radius is not None:
            self.rad = radius
        else:
            self.rad = 0.05
        # ===================================================
        self.cx = np.array(self.cx, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def fdyn(self, t=0, x=None):
        """Return the time derivative of the state ``x`` as an array.

        ``t`` is not used by the dynamics; ``x`` defaults to the zero state.
        """
        if x is None:
            # jax.numpy does not support dtype=object, so use plain zeros
            x = np.zeros(self.dim)

        # ============ adapt input and system dynamics ===========
        x1, x2, x3, x4 = x

        m = 1
        l = 3
        kp1 = 2
        kp2 = 1
        kd1 = 2
        kd2 = 1

        fx1 = x3
        fx2 = x4
        fx3 = (-2 * m * x2 * x3 * x4 - kp1 * x1 - kd1 * x3) / (m * x2 * x2 + l / 3) + (
            kp1 * kp1
        ) / (m * x2 * x2 + l / 3)
        fx4 = x2 * x3 * x3 - kp2 * x2 / m - kd2 * x4 / m + kp2 * kp2 / m

        system_dynamics = [
            fx1,
            fx2,
            fx3,
            fx4,
        ]  # has to be in the same order as the input variables
        # ===================================================

        return np.array(system_dynamics)  # return as numpy array
145 |
146 |
147 | # Dubins car with four state variables (x, y, th and the clock variable tt)
class DubinsCar:
    """Dubins car benchmark with four state variables ``(x, y, th, tt)``."""

    def __init__(self, radius=None):
        """Set the initial center ``cx``, radius ``rad`` and dimension ``dim``.

        ``radius`` overrides the default initial radius of 0.01.
        """
        # ============ adapt initial values ===========
        self.cx = (0, 0, 0.7854, 0)
        if radius is not None:
            self.rad = radius
        else:
            self.rad = 0.01
        # ===================================================
        self.cx = np.array(self.cx, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def fdyn(self, t=0, x=None):
        """Return the time derivative of the state ``x`` as an array.

        ``t`` is not used by the dynamics; ``x`` defaults to the zero state.
        """
        if x is None:
            # jax.numpy does not support dtype=object, so use plain zeros
            x = np.zeros(self.dim)

        # ============ adapt input and system dynamics ===========
        x, y, th, tt = x

        v = 1

        fx = v * cos(th)
        fy = v * sin(th)
        fth = x * sin(tt)
        ftt = np.array(1)  # this is needed for lax_numpy to see this as an array

        system_dynamics = [
            fx,
            fy,
            fth,
            ftt,
        ]  # has to be in the same order as the input variables
        # ===================================================

        return np.array(system_dynamics)  # return as numpy array
183 |
184 |
185 | # 2-dimensional Mitchell Schaeffer cardiac-cell
class MitchellSchaeffer:
    """Mitchell-Schaeffer cardiac-cell benchmark with two state variables."""

    def __init__(self, radius=None):
        """Set the initial center ``cx``, radius ``rad`` and dimension ``dim``.

        ``radius`` overrides the default initial radius of 0.1.
        """
        # ============ adapt initial values ===========
        self.cx = (0.8, 0.5)
        if radius is not None:
            self.rad = radius
        else:
            self.rad = 0.1
        # ===================================================
        self.cx = np.array(self.cx, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def fdyn(self, t=0, x=None):
        """Return the time derivative of the state ``x`` as an array.

        ``t`` is not used by the dynamics; ``x`` defaults to the zero state.
        """
        if x is None:
            # jax.numpy does not support dtype=object, so use plain zeros
            x = np.zeros(self.dim)

        # ============ adapt input and system dynamics ===========
        x1, x2 = x

        # smooth switch between 0 and 1 built from tanh
        sig_x1 = 0.5 * (1 + tanh(50 * x1 - 5))

        fx1 = x2 * x1 ** 2 * (1 - x1) / 0.3 - x1 / 6
        fx2 = sig_x1 * (-x2 / 150) + (1 - sig_x1) * (1 - x2) / 20

        system_dynamics = [
            fx1,
            fx2,
        ]  # has to be in the same order as the input variables
        # ===================================================

        return np.array(system_dynamics)  # return as numpy array
217 |
218 |
219 | # 4-dimensional cartpole with linear stabilizing controller
class CartpoleLinear:
    """Cartpole benchmark with a linear controller and four state variables."""

    def __init__(self, radius=None):
        """Set the initial center ``cx``, radius ``rad`` and dimension ``dim``.

        ``radius`` overrides the default initial radius of 0.05.
        """
        # ============ adapt initial values ===========
        self.cx = (0, 0, 0.001, 0)  # initial values
        if radius is not None:
            self.rad = radius
        else:
            self.rad = 0.05
        # ===================================================

        self.cx = np.array(self.cx, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def fdyn(self, t=0, x=None):
        """Return the time derivative of the state ``x`` as an array.

        The state order is ``(dth, dx, th, x)``. ``t`` is not used by the
        dynamics; ``x`` defaults to the zero state.
        """
        if x is None:
            # jax.numpy does not support dtype=object, so use plain zeros
            x = np.zeros(self.dim)

        # ============ adapt input and system dynamics ===========
        dth, dx, th, x = x  # input variables

        M = 1.0
        g = 9.81
        l = 1.0
        m = 0.001

        # linear feedback on th and dth
        f = -1.1 * M * g * th - dth

        fdth = (
            1.0
            / (l * (M + m * sin(th) * sin(th)))
            * (
                f * cos(th)
                - m * l * dth * dth * cos(th) * sin(th)
                + (m + M) * g * sin(th)
            )
        )
        fdx = (
            1.0
            / (M + m * sin(th) * sin(th))
            * (f + m * sin(th) * (-l * dth * dth + g * cos(th)))
        )

        fx = dx

        fth = dth

        system_dynamics = [
            fdth,
            fdx,
            fth,
            fx,
        ]  # has to be in the same order as the input variables
        # ===================================================

        return np.array(system_dynamics)  # return as numpy array
275 |
276 |
277 | # 17-dimensional Quadcopter
class Quadcopter:
    """Quadcopter benchmark with 17 state variables."""

    def __init__(self, radius=None):
        """Set the initial center ``cx``, radius ``rad`` and dimension ``dim``.

        ``radius`` overrides the default initial radius of 0.005.
        """
        # ============ adapt initial values ===========
        self.cx = (
            -0.995,
            -0.995,
            9.005,
            -0.995,
            -0.995,
            -0.995,
            -0.995,
            -0.995,
            -0.995,
            0,
            0,
            0,
            1,
            0,
            0,
            0,
            0,
        )
        if radius is not None:
            self.rad = radius
        else:
            self.rad = 0.005
        # ===================================================
        self.cx = np.array(self.cx, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def fdyn(self, t=0, x=None):
        """Return the time derivative of the state ``x`` as an array.

        ``t`` is not used by the dynamics; ``x`` defaults to the zero state.
        """
        if x is None:
            # jax.numpy does not support dtype=object, so use plain zeros
            x = np.zeros(self.dim)

        # ============ adapt input and system dynamics ===========
        pn, pe, h, u, v, w, p, q, r, q0, q1, q2, q3, pI, qI, rI, hI = x

        # reference values subtracted in the integrator states below
        pr = 0
        qr = 0
        rr = 0
        hr = 0

        pn_ = (
            2 * u * (q0 * q0 + q1 * q1 - 0.5)
            - 2 * v * (q0 * q3 - q1 * q2)
            + 2 * w * (q0 * q2 + q1 * q3)
        )
        pe_ = (
            2 * v * (q0 * q0 + q2 * q2 - 0.5)
            + 2 * u * (q0 * q3 + q1 * q2)
            - 2 * w * (q0 * q1 - q2 * q3)
        )
        h_ = (
            2 * w * (q0 * q0 + q3 * q3 - 0.5)
            - 2 * u * (q0 * q2 - q1 * q3)
            + 2 * v * (q0 * q1 + q2 * q3)
        )

        u_ = r * v - q * w - 11.62 * (q0 * q2 - q1 * q3)
        v_ = p * w - r * u + 11.62 * (q0 * q1 + q2 * q3)
        w_ = q * u - p * v + 11.62 * (q0 * q0 + q3 * q3 - 0.5)

        q0_ = -0.5 * q1 * p - 0.5 * q2 * q - 0.5 * q3 * r
        q1_ = 0.5 * q0 * p - 0.5 * q3 * q + 0.5 * q2 * r
        q2_ = 0.5 * q3 * p + 0.5 * q0 * q - 0.5 * q1 * r
        q3_ = 0.5 * q1 * q - 0.5 * q2 * p + 0.5 * q0 * r

        p_ = (
            -40.00063258437631 * pI - 2.8283979829540325 * p
        ) - 1.133407423682400 * q * r
        q_ = (
            -39.99980452524146 * qI - 2.8283752541008109 * q
        ) + 1.132078179613602 * p * r
        r_ = (
            -39.99978909742505 * rI - 2.8284134223281210 * r
        ) - 0.004695219977601 * p * q

        pI_ = p - pr
        qI_ = q - qr
        rI_ = r - rr
        hI_ = h - hr

        system_dynamics = [
            pn_,
            pe_,
            h_,
            u_,
            v_,
            w_,
            p_,
            q_,
            r_,
            q0_,
            q1_,
            q2_,
            q3_,
            pI_,
            qI_,
            rI_,
            hI_,
        ]  # has to be in the same order as the input variables
        # ===================================================

        return np.array(system_dynamics)  # return as numpy array
382 |
383 |
384 | # CTRNN DampedForced Pendulum example
385 | # from https://easychair.org/publications/open/K6SZ
class CTRNN_DampedForcedPendulum:
    """Five-state CTRNN damped forced pendulum benchmark.

    Attributes
    ----------
    cx : numpy float array with the centre of the initial set.
    rad : radius of the initial set (``1e-08`` unless one is supplied).
    dim : number of state variables, taken from ``cx``.
    """

    def __init__(self, radius):
        """Store the initial centre and radius; ``radius=None`` selects 1e-08."""
        # ============ adapt initial values ===========
        center = (0.21535, -0.58587, 0.8, 0.52323, 0.5)  # initial values
        self.rad = 1e-08 if radius is None else radius  # initial radius
        # ===================================================

        self.cx = np.array(center, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def fdyn(self, t=0, x=None):
        """Return the derivative of the state ``x`` as a numpy array.

        ``t`` is not used by the equations. When ``x`` is ``None`` an all-zero
        object array of length ``self.dim`` is evaluated instead.
        """
        state = np.zeros(self.dim, dtype=object) if x is None else x

        # ============ adapt input and system dynamics ===========
        x1, x2, x3, x4, x5 = state

        # Every nonlinear term has the form c * (e - 1) / (d * (e + 1)) with
        # e = exp(-2 * x_k); compute numerator and denominator once per state.
        e3 = exp(-2 * x3)
        e4 = exp(-2 * x4)
        e5 = exp(-2 * x5)
        n3, d3 = e3 - 1, e3 + 1
        n4, d4 = e4 - 1, e4 + 1
        n5, d5 = e5 - 1, e5 + 1

        dx1 = (
            5419046323626097 * n3 / (4503599627370496 * d3)
            - x1 / 1000000
            + 3601 * n4 / (50000 * d4)
            + 18727 * n5 / (20000 * d5)
        )
        dx2 = (
            30003 * n4 / (20000 * d4)
            - 11881 * n3 / (10000 * d3)
            - x2 / 1000000
            - 93519 * n5 / (100000 * d5)
        )
        dx3 = (
            7144123746377831 * n3 / (4503599627370496 * d3)
            - x3 / 1000000
            - 5048886837752751 * n4 / (72057594037927936 * d4)
            + 5564385670244745 * n5 / (4503599627370496 * d5)
        )
        dx4 = (
            1348796766312415 * n4 / (4503599627370496 * d4)
            - 3086507593514335 * n3 / (36028797018963968 * d3)
            - x4 / 1000000
            - 2476184452153819 * n5 / (36028797018963968 * d5)
        )
        dx5 = (
            1523758031023695 * n4 / (18014398509481984 * d4)
            - 8060407538855891 * n3 / (4503599627370496 * d3)
            - x5 / 1000000
            - 3139112893264555 * n5 / (2251799813685248 * d5)
        )
        # ===================================================

        return np.array([dx1, dx2, dx3, dx4, dx5])  # return as numpy array
461 |
462 |
463 | # 12-dimensional cartpole with CT-RNN neural network controller
class CartpoleCTRNN:
    """12-dimensional cartpole driven by a CT-RNN controller.

    The state is ``(x_00..x_03, h_00..h_07)``: four cartpole variables followed
    by eight hidden units of the controller.
    """

    # One row per hidden unit; columns are the weights applied to
    # h_00..h_07 followed by x_00..x_03 (the summation order used in fdyn).
    _HIDDEN_WEIGHTS = (
        (-1.49394, -0.61947, 0.37393, -0.63451, -1.08420, 2.57981, -1.53850, -1.64354,
         -0.07445, 0.08736, 0.47684, 0.56397),
        (1.38295, -1.45811, 1.01473, -0.04578, -0.13416, -0.21970, 0.41791, -0.10833,
         -1.70409, 0.51560, -0.71273, -0.61720),
        (0.57055, 0.05941, -0.16993, -0.69688, -0.30939, -1.31558, 0.03316, 2.35873,
         0.21849, 1.42990, 1.60666, -0.66847),
        (-1.53087, -0.42779, -0.02195, 0.18007, 1.54262, -0.19275, -0.64598, -0.85840,
         0.10095, 0.55012, 1.51416, 1.95952),
        (-0.54493, 0.41139, -0.18790, -0.14312, -0.10144, 2.38792, -0.44134, -2.06462,
         -0.01260, 0.84681, 0.51889, 0.35089),
        (1.08793, -0.24559, -0.97041, -0.16874, -0.27133, 0.73513, 0.73096, -2.26244,
         0.80312, -1.35635, -0.95283, -0.68460),
        (0.08745, 0.10104, 0.82251, -0.59823, 1.16949, 1.73511, -0.89387, 0.77149,
         0.22731, -0.73872, 2.83509, -2.49535),
        (-0.89045, -1.30985, 0.48740, -0.45750, -0.70727, -0.43216, -0.29159, 3.70331,
         -0.13294, -0.23242, 0.33244, -1.45838),
    )
    # Additive term inside each hidden unit's tanh.
    _HIDDEN_BIAS = (1.81867, 0.86009, -1.56112, 0.00020, -0.06552, 0.11646, 0.70343, -0.02392)
    # Weights applied to h_00..h_07 for the (bias-free) action output.
    _ACTION_WEIGHTS = (-0.32011, 0.32958, 0.07718, 2.39431, -1.37732, 0.92722, -1.07137, -1.47286)

    def __init__(self, radius):
        """Store the initial centre and radius; ``radius=None`` selects 1e-4."""
        # ============ adapt initial values ===========
        center = (0, 0, 0.001, 0, 0, 0, 0, 0, 0, 0, 0, 0)
        self.rad = 1e-4 if radius is None else radius
        # ===================================================
        self.cx = np.array(center, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    @staticmethod
    def _weighted_sum(signals, weights):
        """Sum ``signal * weight`` pairs from left to right."""
        pairs = zip(signals, weights)
        first_signal, first_weight = next(pairs)
        total = first_signal * first_weight
        for signal, weight in pairs:
            total = total + signal * weight
        return total

    def fdyn(self, t=0, x=None):
        """Return the derivative of the 12-component state as a numpy array.

        ``t`` is not used. When ``x`` is ``None`` an all-zero object array of
        length ``self.dim`` is evaluated instead.
        """
        state = np.zeros(self.dim, dtype=object) if x is None else x

        # ============ adapt input and system dynamics ===========
        x_00, x_01, x_02, x_03, h_00, h_01, h_02, h_03, h_04, h_05, h_06, h_07 = state
        hidden = (h_00, h_01, h_02, h_03, h_04, h_05, h_06, h_07)
        signals = hidden + (x_00, x_01, x_02, x_03)

        # hidden units: h' = -h + tanh(weights . signals + bias)
        hidden_rates = [
            -unit + tanh(self._weighted_sum(signals, row) + bias)
            for unit, row, bias in zip(hidden, self._HIDDEN_WEIGHTS, self._HIDDEN_BIAS)
        ]

        action_00 = tanh(self._weighted_sum(hidden, self._ACTION_WEIGHTS))

        # cartpole constants
        gravity = 9.8
        masscart = 1
        masspole = 0.1
        length = 0.5
        force_mag = 10
        total_mass = masscart + masspole
        polemass_length = masspole * length

        force = force_mag * action_00
        costheta = cos(x_02)
        sintheta = sin(x_02)
        x_dot = x_01
        theta_dot = x_03

        temp = (force + polemass_length * theta_dot * theta_dot * sintheta) / total_mass
        denominator = length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass)
        thetaacc = (gravity * sintheta - costheta * temp) / denominator
        xacc = temp - polemass_length * thetaacc * costheta / total_mass
        # ===================================================

        # has to be in the same order as the input variables
        return np.array([x_dot, xacc, theta_dot, thetaacc, *hidden_rates])
657 |
658 |
659 | # 12-dimensional cartpole with LTC neural network controller
class CartpoleLTC:
    """12-dimensional cartpole driven by an LTC neural network controller.

    The state is ``(x_00..x_03, h_00..h_07)``: four cartpole variables followed
    by eight hidden units of the controller.
    """

    # Divisors applied to x_00..x_03 before they enter the sensory gates.
    _INPUT_SCALE = (0.1, 0.2, 0.1, 0.1)

    # _SENSORY[i][j] = (target, gain, mu, sigma) of the term coupling scaled
    # input j into hidden unit i: (target - h_i) * gain * sigmoid(u_j, mu, sigma).
    _SENSORY = (
        (
            (-3.337809, 0.000100, -1.822089, 4.920430),
            (-3.793303, 0.217915, 0.589409, 4.212314),
            (1.074355, 0.000100, -1.301882, 1.045275),
            (-1.837527, 0.784721, 0.061974, 9.342714),
        ),
        (
            (1.922444, 0.048898, 0.069332, 4.315088),
            (-2.419416, 0.627589, 2.746536, 3.680750),
            (3.581162, 0.165919, -1.045852, 3.066419),
            (1.778340, 0.120764, 1.837324, 3.448219),
        ),
        (
            (2.205145, 0.304337, 0.031854, 4.853640),
            (-3.098629, 0.084116, 0.441696, 3.296364),
            (-1.226080, 0.177806, 2.182991, 1.098078),
            (-2.913366, 0.491636, 2.224604, 1.728117),
        ),
        (
            (0.960911, 0.364263, -1.053052, 1.971431),
            (-1.638120, 0.162840, 7.642996, 2.541498),
            (-1.138980, 0.746401, 1.725237, 0.663583),
            (1.642858, 0.299424, 0.687038, -0.347229),
        ),
        (
            (-0.752243, 0.649805, 5.850876, 4.392767),
            (0.135855, 0.018294, 3.993045, 2.950581),
            (-0.291421, 0.404132, -0.929036, 3.571369),
            (-3.123502, 0.521996, -0.641402, 7.249665),
        ),
        (
            (-0.472946, 0.359657, 1.851295, 6.572009),
            (-0.733541, 0.382768, 2.120606, 3.201269),
            (1.819771, 0.223484, 2.934255, 4.131283),
            (0.080200, 0.108734, -2.828882, 1.771147),
        ),
        (
            (2.596139, 0.233149, 1.404586, 5.493755),
            (-0.948662, 0.297568, -5.682380, 3.978102),
            (4.558042, 0.736431, -0.067011, 3.915714),
            (1.514289, 0.226179, -1.101492, 5.705000),
        ),
        (
            (0.907580, 0.077591, 0.822831, 2.474382),
            (-2.282039, 0.161999, -0.241190, 6.074738),
            (0.294892, 0.202664, -0.289602, 8.727577),
            (-2.329513, 0.242021, -0.960229, 4.780694),
        ),
    )

    # _RECURRENT[i][j] = (target, gain, mu, sigma) of the term coupling hidden
    # unit j into hidden unit i.
    _RECURRENT = (
        (
            (0.425183, 0.580120, 1.888106, 5.698736),
            (-0.026397, 0.422325, 0.322996, 3.287380),
            (0.232258, 0.234096, 1.988594, 5.637971),
            (0.325219, 0.822936, 1.430856, 1.847167),
            (-3.234136, 0.143390, 1.693733, 1.529007),
            (-0.446780, 0.000010, 4.226530, 1.182250),
            (-1.355023, 0.605004, 1.083694, 2.011393),
            (-3.184705, 0.338830, 5.826199, 2.559444),
        ),
        (
            (-1.902008, 0.151651, 1.296603, 3.500569),
            (-0.865412, 0.042350, 1.769154, 10.179186),
            (-3.485676, 0.259284, 1.534717, 3.293725),
            (-1.715340, 0.197946, 0.264122, 5.787919),
            (-3.960100, 0.346343, 4.226974, 5.074825),
            (-1.888838, 0.167991, -0.192584, 4.463276),
            (3.343740, 0.407932, -2.019538, 2.096939),
            (-3.828055, 0.392470, -1.217940, 3.255837),
        ),
        (
            (-0.149339, 0.274537, -0.229752, 0.488301),
            (0.411455, 0.026122, 4.011376, 4.030908),
            (3.896320, 0.157079, 2.459823, 3.892246),
            (0.530434, 0.059498, 1.132183, 1.906741),
            (-0.194996, 0.087192, 0.000571, 0.949683),
            (-0.407292, 0.194082, -0.997799, 1.219973),
            (-6.620962, 0.398726, 0.469706, 5.065333),
            (1.787329, 1.050713, 4.342476, -0.298296),
        ),
        (
            (-5.372879, 0.027963, 2.814907, 7.856180),
            (-0.646331, 0.149699, 0.987782, 2.744861),
            (0.116237, 0.135569, 4.085198, 3.937947),
            (-2.523393, 0.165372, 0.703319, 5.054082),
            (0.815988, 0.189418, 3.352685, 7.961621),
            (-1.706797, 0.502507, -0.760482, 5.458648),
            (-2.023418, 0.026645, 4.557816, 2.245746),
            (-0.889180, 0.468236, -2.680705, 7.208436),
        ),
        (
            (1.343493, 0.603259, -0.901840, 3.872216),
            (0.831619, 0.426593, -3.089995, 5.196197),
            (0.421409, 0.068839, 3.991761, 0.784316),
            (2.355011, 0.018658, -0.024262, 7.865137),
            (0.593414, 0.252146, -1.463303, 3.132787),
            (-0.637192, 0.118786, -1.024096, 4.974855),
            (-0.732044, 0.588030, 1.033544, 6.298468),
            (-0.194411, 0.679828, 0.410277, 8.185476),
        ),
        (
            (0.495678, 0.338934, 1.896829, 3.323362),
            (0.680690, 0.558593, -0.032828, 3.472205),
            (1.564447, 0.416106, -0.448733, 4.371096),
            (0.896225, 0.102958, -1.402512, 7.281286),
            (3.866403, 0.388992, -1.392615, 1.017738),
            (-1.880220, 0.075561, 0.562915, 6.004653),
            (-5.541122, 0.279411, -0.679593, 2.001722),
            (-0.065618, 0.475569, -1.229657, 3.177203),
        ),
        (
            (-1.465561, 0.122170, 0.706757, 2.401071),
            (-2.485563, 0.091596, 0.335557, 3.252133),
            (1.158366, 0.519633, -1.169937, 3.588379),
            (1.955986, 0.252199, -1.974608, 2.870511),
            (0.743124, 0.651011, -0.210935, 7.228260),
            (3.452268, 0.019733, 3.515318, 3.879946),
            (0.604325, 0.191627, -3.070420, 4.221974),
            (2.465013, 0.143574, -0.961958, 3.002333),
        ),
        (
            (-2.135510, 0.082957, 0.857560, 4.560907),
            (-0.719410, 0.121560, 2.004791, 7.674825),
            (-0.095160, 0.085265, 0.915076, 3.384454),
            (-6.402920, 0.017902, 1.900790, 3.606339),
            (0.911732, 0.018167, 0.001252, 7.378105),
            (-0.876065, 0.119567, 2.130298, 7.560087),
            (-1.520981, 0.012488, -0.275924, 2.646271),
            (0.734387, 0.088953, -1.333799, 3.340692),
        ),
    )

    # _LEAK[i] = (rate, leak_gain, leak_target) for hidden unit i:
    # h_i' = rate * (leak_gain * (leak_target - h_i) + sensory_i + recurrent_i)
    _LEAK = (
        (10000000.000000, 0.663909, 4.880308),
        (0.897653, 2.666445, -0.506661),
        (0.513291, 4.721010, -1.314806),
        (0.301197, 0.073180, 0.044130),
        (0.404649, 1.368898, -1.209986),
        (1.131324, 0.000010, -2.630356),
        (10000000.000000, 0.539536, -0.529881),
        (0.525138, 0.507019, 0.567584),
    )

    # Output layer: action = tanh(weights . h + bias)
    _ACTION_WEIGHTS = (
        -0.895448, -0.795233, 0.209421, 0.143533,
        -0.084270, -0.342232, 0.333434, -0.032306,
    )
    _ACTION_BIAS = 0.757913

    def __init__(self, radius):
        """Store the initial centre and radius; ``radius=None`` selects 1e-4."""
        # ============ adapt initial values ===========
        center = (0, 0, 0.001, 0, 0, 0, 0, 0, 0, 0, 0, 0)
        self.rad = 1e-4 if radius is None else radius
        # ===================================================
        self.cx = np.array(center, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def true_sigmoid(self, x):
        """Logistic function written in terms of ``tanh``."""
        return (tanh(x * 0.5) + 1) * 0.5

    def sigmoid(self, v_pre, mu, sigma):
        """Gate value ``true_sigmoid(sigma * (v_pre - mu))``."""
        return self.true_sigmoid(sigma * (v_pre - mu))

    def _coupling(self, unit, sources, synapses):
        """Sum ``(target - unit) * gain * sigmoid(source, mu, sigma)`` left to right."""
        total = None
        for source, (target, gain, mu, sigma) in zip(sources, synapses):
            term = (target - unit) * gain * self.sigmoid(source, mu, sigma)
            total = term if total is None else total + term
        return total

    def fdyn(self, t=0, x=None):
        """Return the derivative of the 12-component state as a numpy array.

        ``t`` is not used. When ``x`` is ``None`` an all-zero object array of
        length ``self.dim`` is evaluated instead.
        """
        state = np.zeros(self.dim, dtype=object) if x is None else x

        # ============ adapt input and system dynamics ===========
        x_00, x_01, x_02, x_03, h_00, h_01, h_02, h_03, h_04, h_05, h_06, h_07 = state
        hidden = (h_00, h_01, h_02, h_03, h_04, h_05, h_06, h_07)

        # scaled cartpole observations feeding the sensory terms
        scale_0, scale_1, scale_2, scale_3 = self._INPUT_SCALE
        inputs = (x_00 / scale_0, x_01 / scale_1, x_02 / scale_2, x_03 / scale_3)

        hidden_rates = []
        for unit, sensory, recurrent, (rate, leak_gain, leak_target) in zip(
            hidden, self._SENSORY, self._RECURRENT, self._LEAK
        ):
            swa = self._coupling(unit, inputs, sensory)
            wa = self._coupling(unit, hidden, recurrent)
            hidden_rates.append(rate * (leak_gain * (leak_target - unit) + swa + wa))

        weighted = None
        for unit, weight in zip(hidden, self._ACTION_WEIGHTS):
            term = unit * weight
            weighted = term if weighted is None else weighted + term
        action_00 = tanh(weighted + self._ACTION_BIAS)

        # cartpole constants
        gravity = 9.8
        masscart = 1
        masspole = 0.1
        length = 0.5
        force_mag = 10
        total_mass = masscart + masspole
        polemass_length = masspole * length

        force = force_mag * action_00
        costheta = cos(x_02)
        sintheta = sin(x_02)
        x_dot = x_01
        theta_dot = x_03

        temp = (force + polemass_length * theta_dot * theta_dot * sintheta) / total_mass
        denominator = length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass)
        thetaacc = (gravity * sintheta - costheta * temp) / denominator
        xacc = temp - polemass_length * thetaacc * costheta / total_mass
        # ===================================================

        # has to be in the same order as the input variables
        return np.array([x_dot, xacc, theta_dot, thetaacc, *hidden_rates])
883 |
884 |
885 | # 12-dimensional cartpole with LTC neural network controller (trained with RK integrator)
class CartpoleLTC_RK:
    """12-dimensional cartpole with an LTC controller (trained with RK integrator).

    The state is ``(x_00..x_03, h_00..h_07)``: four cartpole variables followed
    by eight hidden units of the controller.
    """

    # Divisors applied to x_00..x_03 before they enter the sensory gates.
    _INPUT_SCALE = (0.1, 0.2, 0.1, 0.1)

    # _SENSORY[i][j] = (target, gain, mu, sigma) of the term coupling scaled
    # input j into hidden unit i.
    _SENSORY = (
        (
            (-0.431604, 0.072614, 0.487739, 5.650091),
            (-1.100827, 0.581228, 1.385196, 3.634596),
            (0.174867, 0.111020, -0.993259, 5.331078),
            (-0.510866, 0.189538, -1.174057, 5.951439),
        ),
        (
            (-0.596286, 0.304088, -0.441833, 3.592494),
            (-1.364699, 0.266900, -0.104207, 4.467682),
            (0.723861, 0.282621, 0.735169, 4.735680),
            (-2.003241, 0.162848, 1.824110, 4.624085),
        ),
        (
            (-2.213958, 0.050379, 1.028516, 4.883958),
            (1.837608, 0.000100, 0.476658, 3.920279),
            (-0.869022, 0.000100, 0.146460, 4.128851),
            (0.970412, 0.142598, -0.057289, 4.129013),
        ),
        (
            (3.316434, 0.000100, -0.281031, 4.200078),
            (-0.286490, 0.136481, 2.742276, 3.584777),
            (-0.305795, 0.186325, 1.382006, 4.178387),
            (2.739162, 0.059435, -0.377228, 3.719664),
        ),
        (
            (-1.856016, 0.031760, -0.689140, 3.121773),
            (-1.060924, 0.233160, 0.130687, 2.643270),
            (0.459155, 0.143452, 0.574488, 1.886950),
            (-0.358857, 0.100690, 1.888403, 4.933154),
        ),
        (
            (1.980404, 0.323992, 1.983044, 4.761758),
            (-1.479298, 0.058762, 0.921888, 6.088301),
            (1.952959, 0.088192, 0.347450, 4.452416),
            (-1.889391, 0.035840, -1.054210, 5.084546),
        ),
        (
            (-0.891564, 0.238208, 0.668283, 4.712173),
            (-0.434193, 0.102071, -0.555892, 3.985909),
            (-1.002575, 0.211428, 1.759687, 5.538335),
            (0.961403, 0.000100, -0.823010, 5.778478),
        ),
        (
            (-2.280897, 0.055498, 0.378455, 3.488947),
            (3.006480, 0.026770, -0.096777, 2.703730),
            (0.428589, 0.131216, -0.058215, 4.788692),
            (-0.848094, 0.156781, 0.870154, 5.505928),
        ),
    )

    # _RECURRENT[i][j] = (target, gain, mu, sigma) of the term coupling hidden
    # unit j into hidden unit i.
    _RECURRENT = (
        (
            (-2.061894, 0.255103, -0.599924, 4.079573),
            (-2.919805, 0.111550, -0.003773, 3.577720),
            (-1.169591, 0.076621, 0.268934, 3.482080),
            (2.610878, 0.358292, -0.801209, 1.404748),
            (0.464382, 0.396781, -0.353826, 6.078532),
            (-1.108612, 0.039682, -0.083197, 4.031213),
            (-1.782495, 0.192295, -0.135476, 3.405039),
            (-0.745238, 0.224384, 0.838400, 3.514007),
        ),
        (
            (0.323076, 0.132399, 1.430406, 3.536494),
            (-1.075827, 0.183413, 0.444819, 3.535979),
            (0.727482, 0.179833, 1.784908, 4.561027),
            (0.138851, 0.154461, 1.996656, 5.773379),
            (-0.711774, 0.143956, -1.484266, 2.894240),
            (-0.365611, 0.268626, 0.714121, 1.408464),
            (1.591881, 0.189146, 2.912474, 3.563775),
            (2.669420, 0.000010, 0.507458, 3.459168),
        ),
        (
            (2.356203, 0.072446, 0.818119, 4.080008),
            (0.534163, 0.645405, 0.951815, 4.642598),
            (0.577408, 0.038336, -0.348300, 2.439894),
            (-1.214589, 0.310216, 0.429035, 5.454509),
            (-1.883640, 0.026713, 1.463999, 3.258402),
            (-1.458897, 0.054829, 2.008786, 1.817449),
            (0.503231, 0.215241, -0.121074, 4.956780),
            (-0.185760, 0.202772, -0.237805, 3.949310),
        ),
        (
            (2.190941, 0.264046, 2.558471, 6.171836),
            (0.049259, 0.016450, -0.471219, 3.276491),
            (1.897136, 0.093556, -0.834410, 2.282021),
            (-1.203204, 0.016330, -0.229405, 2.725281),
            (2.328495, 0.032663, 2.172752, 1.766576),
            (1.529989, 0.128894, -0.905643, 4.200280),
            (-0.033709, 0.168944, 0.228327, 1.674565),
            (-0.903095, 0.101751, 0.474075, 3.122658),
        ),
        (
            (1.356759, 0.207902, -0.832245, 5.076437),
            (-1.713048, 0.253930, -0.361454, 6.056005),
            (1.365151, 0.025809, 2.691158, 2.098762),
            (-1.825929, 0.039653, -1.585948, 3.676085),
            (-0.728560, 0.012267, -1.516692, 5.575822),
            (-1.497486, 0.037167, 0.021516, 5.903266),
            (-2.542053, 0.181942, 0.198382, 3.954360),
            (3.068608, 0.095478, 0.247832, 4.194288),
        ),
        (
            (0.924399, 0.000010, 1.051798, 4.877949),
            (0.179591, 0.082066, 0.472982, 5.943190),
            (0.422308, 0.361843, 0.984128, 4.885800),
            (-1.484202, 0.660023, 2.561464, 3.960099),
            (3.444595, 0.269937, 2.765615, 3.529421),
            (-0.460907, 0.018725, 0.951335, 7.133490),
            (-1.438883, 0.135012, 2.470155, 5.090559),
            (-0.022624, 0.030127, -0.370460, 4.233265),
        ),
        (
            (-2.204599, 0.144977, 2.253059, 3.878229),
            (0.148456, 0.270476, 2.298548, 5.617167),
            (0.544189, 0.187231, 1.738353, 3.412084),
            (0.957022, 0.091821, 2.433656, 1.501055),
            (0.305039, 0.000010, 2.545896, 2.064504),
            (2.075277, 0.172349, 0.215808, 4.761319),
            (0.651300, 0.392033, -2.850508, 5.109767),
            (0.305318, 0.056583, -0.545839, 5.304969),
        ),
        (
            (-2.419794, 0.164188, 0.646132, 5.431365),
            (-1.410748, 0.128581, 1.941283, 2.716641),
            (0.566556, 0.189314, 2.011078, 6.978758),
            (-0.999432, 0.038773, 2.175668, 2.841975),
            (-1.244396, 0.024732, -1.121634, 4.267510),
            (0.211258, 0.069162, 0.894198, 2.127854),
            (2.880166, 0.023222, 0.345695, 3.201268),
            (-2.649584, 0.348009, 0.311026, 3.550425),
        ),
    )

    # _LEAK[i] = (rate, leak_gain, leak_target) for hidden unit i:
    # h_i' = rate * (leak_gain * (leak_target - h_i) + sensory_i + recurrent_i)
    _LEAK = (
        (0.472549, 0.026911, -0.601392),
        (0.422074, 0.001000, 1.909540),
        (0.316776, 1.660939, -0.450735),
        (0.484111, 0.144001, 1.478572),
        (0.414692, 1.029787, -1.105074),
        (0.358391, 0.218205, -1.095135),
        (0.880162, 0.268332, 0.043464),
        (1.671041, 0.153414, 0.695311),
    )

    # Output layer: action = tanh(weights . h + bias)
    _ACTION_WEIGHTS = (
        0.135081, 0.189091, -0.350403, 0.388546,
        0.397087, 0.186882, -0.047590, 0.197737,
    )
    _ACTION_BIAS = -0.352130

    def __init__(self, radius):
        """Store the initial centre and radius; ``radius=None`` selects 1e-4."""
        # ============ adapt initial values ===========
        center = (0, 0, 0.001, 0, 0, 0, 0, 0, 0, 0, 0, 0)
        self.rad = 1e-4 if radius is None else radius
        # ===================================================
        self.cx = np.array(center, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def true_sigmoid(self, x):
        """Logistic function written in terms of ``tanh``."""
        return (tanh(x * 0.5) + 1) * 0.5

    def sigmoid(self, v_pre, mu, sigma):
        """Gate value ``true_sigmoid(sigma * (v_pre - mu))``."""
        return self.true_sigmoid(sigma * (v_pre - mu))

    @staticmethod
    def _coupling(unit, sources, synapses):
        """Sum the gated coupling terms acting on ``unit`` from left to right.

        The gate is written out inline, in the same multiplication order as
        the generated equations, instead of going through ``sigmoid``.
        """
        total = None
        for source, (target, gain, mu, sigma) in zip(sources, synapses):
            term = (target - unit) * gain * 0.5 * (tanh(0.5 * (source - mu) * sigma) + 1)
            total = term if total is None else total + term
        return total

    def fdyn(self, t=0, x=None):
        """Return the derivative of the 12-component state as a numpy array.

        ``t`` is not used. When ``x`` is ``None`` an all-zero object array of
        length ``self.dim`` is evaluated instead.
        """
        state = np.zeros(self.dim, dtype=object) if x is None else x

        # ============ adapt input and system dynamics ===========
        x_00, x_01, x_02, x_03, h_00, h_01, h_02, h_03, h_04, h_05, h_06, h_07 = state
        hidden = (h_00, h_01, h_02, h_03, h_04, h_05, h_06, h_07)

        # scaled cartpole observations feeding the sensory terms
        scale_0, scale_1, scale_2, scale_3 = self._INPUT_SCALE
        inputs = (x_00 / scale_0, x_01 / scale_1, x_02 / scale_2, x_03 / scale_3)

        hidden_rates = []
        for unit, sensory, recurrent, (rate, leak_gain, leak_target) in zip(
            hidden, self._SENSORY, self._RECURRENT, self._LEAK
        ):
            swa = self._coupling(unit, inputs, sensory)
            wa = self._coupling(unit, hidden, recurrent)
            hidden_rates.append(rate * (leak_gain * (leak_target - unit) + swa + wa))

        weighted = None
        for unit, weight in zip(hidden, self._ACTION_WEIGHTS):
            term = unit * weight
            weighted = term if weighted is None else weighted + term
        action_00 = tanh(weighted + self._ACTION_BIAS)

        # cartpole constants
        gravity = 9.8
        masscart = 1
        masspole = 0.1
        length = 0.5
        force_mag = 10
        total_mass = masscart + masspole
        polemass_length = masspole * length

        force = force_mag * action_00
        costheta = cos(x_02)
        sintheta = sin(x_02)
        x_dot = x_01
        theta_dot = x_03

        temp = (force + polemass_length * theta_dot * theta_dot * sintheta) / total_mass
        denominator = length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass)
        thetaacc = (gravity * sintheta - costheta * temp) / denominator
        xacc = temp - polemass_length * thetaacc * costheta / total_mass
        # ===================================================

        # has to be in the same order as the input variables
        return np.array([x_dot, xacc, theta_dot, thetaacc, *hidden_rates])
1397 |
1398 |
class TestNODE:
    """Two-dimensional test system with a tanh right-hand side.

    Attributes set by ``__init__``:
        cx: Float array ``[0., 0.]`` (presumably the center of the initial
            set -- confirm against the caller).
        rad: The ``radius`` argument, or ``1e-4`` when it is ``None``.
        dim: Number of state variables (``cx.size``, i.e. 2).
    """

    def __init__(self, radius):
        """Store the initial point, radius and dimension.

        Args:
            radius: Value stored in ``self.rad``; ``None`` selects ``1e-4``.
        """
        # ============ adapt initial values ===========
        self.cx = (0, 0)
        if radius is not None:
            self.rad = radius
        else:
            self.rad = 1e-4
        # ===================================================
        self.cx = np.array(self.cx, dtype=float)
        self.dim = self.cx.size  # dimension of the system

    def fdyn(self, t=0, x=None):
        """Evaluate the right-hand side of the system at state ``x``.

        Args:
            t: Not used by the computation.
            x: Sequence of two state values ``(x0, x1)``. ``None`` is replaced
                by an object-dtype zero vector of length ``self.dim``.

        Returns:
            Array ``[dx0/dt, dx1/dt]``.
        """
        if x is None:
            x = np.zeros(self.dim, dtype=object)

        # ============ adapt input and system dynamics ===========
        x0, x1 = x
        # Keep the derivatives in their own names. Writing the first result
        # back into x0 would make the second equation read d(x0)/dt instead of
        # the state x0.
        x0_prime = 0.5889 * tanh(0.4256 * x0 + 0.5061 * x1 + 0.1773) - 0.1000 * x0
        x1_prime = 0.3857 * tanh(-0.5563 * x0 + -0.1262 * x1 + -0.2136) - 0.1000 * x1
        system_dynamics = [
            x0_prime,
            x1_prime,
        ]  # has to be in the same order as the input variables
        # ===================================================
        return np.array(system_dynamics)  # return as numpy array
1422 |
1423 |
class LDSwithCTRNN:
    """Ten-dimensional system whose network weights come from ``rl/lds_ctrnn.npz``.

    The state vector is ordered as the network states followed by the last
    two entries, which ``fdyn`` treats as ``(x, y)`` of the controlled system.

    Attributes set by ``__init__``:
        rad: The ``radius`` argument, or ``0.5`` when it is ``None``.
        cx: Zero vector of length 10.
        dim: Number of state variables (``cx.size``).
        params: Dict mapping every array name in the archive to its array.
            ``fdyn`` reads the keys ``w1``, ``b1``, ``w2``, ``b2``, ``wa``
            and ``ba``.
    """

    def __init__(self, radius):
        """Store the radius and load the network parameters from disk.

        Args:
            radius: Value stored in ``self.rad``; ``None`` selects ``0.5``.
        """
        # ============ adapt initial values ===========
        self.rad = radius if radius is not None else 0.5
        # ===================================================
        self.cx = np.zeros(10)
        self.dim = self.cx.size  # dimension of the system
        # An .npz archive keeps its file open until closed, so read every
        # array eagerly and release the handle right away.
        with np.load("rl/lds_ctrnn.npz") as arr:
            self.params = {k: arr[k] for k in arr.files}

    def fdyn(self, t=0, x=None):
        """Evaluate the right-hand side of the closed-loop system.

        Args:
            t: Not used by the computation.
            x: State vector; ``None`` is replaced by a float zero vector of
                length ``self.dim``.

        Returns:
            1-D array: network-state derivatives, then the derivatives of
            the last two state entries.
        """
        if x is None:
            # Float zeros: an object-dtype array cannot go through np.tanh.
            x = np.zeros(self.dim)

        params = self.params
        hidden = np.tanh(np.dot(x, params["w1"]) + params["b1"])
        dhdt = np.dot(hidden, params["w2"]) + params["b2"]

        action = np.tanh(np.dot(hidden, params["wa"]) + params["ba"])
        # Only the last state entry is needed; unpacking both keeps the
        # requirement that x[-2:] has exactly two elements.
        _, y = x[-2:]

        dxdt = y
        dydt = 0.2 + 0.4 * action

        # Assumes `action` holds a single element so both reshape to (1,)
        # -- TODO confirm against the shapes stored in the archive.
        dxdt = np.array([dxdt]).reshape((1,))
        dydt = np.array([dydt]).reshape((1,))
        return np.concatenate([dhdt, dxdt, dydt], axis=0)
1454 |
1455 |
class PendulumwithCTRNN:
    """Ten-dimensional system whose network weights come from ``rl/pendulum_ctrnn.npz``.

    The state vector is ordered as the network states followed by the last
    two entries, which ``fdyn`` unpacks as ``(th, thdot)``.

    Attributes set by ``__init__``:
        rad: The ``radius`` argument, or ``0.5`` when it is ``None``.
        cx: Zero vector of length 10.
        dim: Number of state variables (``cx.size``).
        params: Dict mapping every array name in the archive to its array.
            ``fdyn`` reads the keys ``w1``, ``b1``, ``w2``, ``b2``, ``wa``
            and ``ba``.
    """

    def __init__(self, radius):
        """Store the radius and load the network parameters from disk.

        Args:
            radius: Value stored in ``self.rad``; ``None`` selects ``0.5``.
        """
        # ============ adapt initial values ===========
        self.rad = radius if radius is not None else 0.5
        # ===================================================
        self.cx = np.zeros(10)
        self.dim = self.cx.size  # dimension of the system
        # An .npz archive keeps its file open until closed, so read every
        # array eagerly and release the handle right away.
        with np.load("rl/pendulum_ctrnn.npz") as arr:
            self.params = {k: arr[k] for k in arr.files}

    def fdyn(self, t=0, x=None):
        """Evaluate the right-hand side of the closed-loop system.

        Args:
            t: Not used by the computation.
            x: State vector; ``None`` is replaced by a float zero vector of
                length ``self.dim``.

        Returns:
            1-D array: network-state derivatives followed by two entries
            that both hold the saturated value ``newthdot``.
        """
        if x is None:
            # Float zeros: an object-dtype array cannot go through np.tanh.
            x = np.zeros(self.dim)

        params = self.params
        hidden = np.tanh(np.dot(x, params["w1"]) + params["b1"])
        dhdt = np.dot(hidden, params["w2"]) + params["b2"]

        action = np.tanh(np.dot(hidden, params["wa"]) + params["ba"])
        th, _thdot = x[-2:]

        # Model constants; by the usual naming these are presumably gravity,
        # mass and rod length -- confirm units with the training setup.
        max_speed = 8
        g = 9.81
        m = 1.0
        length = 1.0

        newthdot = (
            -3 * g / (2 * length) * np.sin(th + np.pi)
            + 3.0 / (m * length ** 2) * action
        )
        # Smooth saturation of the value to the open interval
        # (-max_speed, max_speed).
        newthdot = max_speed * np.tanh(newthdot / max_speed)
        # NOTE(review): thdot is never used and d(th)/dt is set to the same
        # saturated value as d(thdot)/dt. A textbook pendulum would use
        # d(th)/dt = thdot. Left unchanged because the stored controller
        # weights may have been obtained against exactly these dynamics
        # -- confirm before altering.
        newth = newthdot

        # Assumes `action` holds a single element so both reshape to (1,)
        # -- TODO confirm against the shapes stored in the archive.
        dxdt = np.array([newth]).reshape((1,))
        dydt = np.array([newthdot]).reshape((1,))
        return np.concatenate([dhdt, dxdt, dydt], axis=0)
1492 |
1493 |
class CTRNNosc:
    """Sixteen-dimensional network system with weights from ``rl/ctrnn_osc.npz``.

    Attributes set by ``__init__``:
        rad: The ``radius`` argument, or ``0.1`` when it is ``None``.
        cx: Zero vector of length 16.
        dim: Number of state variables (``cx.size``).
        params: Dict mapping every array name in the archive to its array.
            ``fdyn`` reads the keys ``w1``, ``b1``, ``w2`` and ``b2``.
    """

    def __init__(self, radius):
        """Store the radius and load the network parameters from disk.

        Args:
            radius: Value stored in ``self.rad``; ``None`` selects ``0.1``.
        """
        # ============ adapt initial values ===========
        self.rad = radius if radius is not None else 0.1
        # ===================================================
        self.cx = np.zeros(16)
        self.dim = self.cx.size  # dimension of the system
        # An .npz archive keeps its file open until closed, so read every
        # array eagerly and release the handle right away.
        with np.load("rl/ctrnn_osc.npz") as arr:
            self.params = {k: arr[k] for k in arr.files}

    def fdyn(self, t=0, x=None):
        """Evaluate the right-hand side ``tanh(x @ w1 + b1) @ w2 + b2``.

        Args:
            t: Not used by the computation.
            x: State vector; ``None`` is replaced by a float zero vector of
                length ``self.dim``.

        Returns:
            Array of state derivatives.
        """
        if x is None:
            # Float zeros: an object-dtype array cannot go through np.tanh.
            x = np.zeros(self.dim)

        params = self.params
        hidden = np.tanh(np.dot(x, params["w1"]) + params["b1"])
        return np.dot(hidden, params["w2"]) + params["b2"]
1514 |
--------------------------------------------------------------------------------