├── fig1.png ├── fig3.png ├── GoTube.jpg ├── Appendix.pdf ├── GoTube1.jpg ├── config.ini ├── logged └── .gitignore ├── rl ├── ctrnn_osc.npz ├── lds_ctrnn.npz └── pendulum_ctrnn.npz ├── saved_outputs └── .gitignore ├── .gitignore ├── requirements.txt ├── timer.py ├── compute_volume_intersection.py ├── table4.sh ├── polar_coordinates.py ├── LICENSE.md ├── performance_log.py ├── dynamics.py ├── figure4.sh ├── table2.sh ├── plot.py ├── main.py ├── README.md ├── go_tube.py ├── stochastic_reachtube.py └── benchmarks.py /fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/fig1.png -------------------------------------------------------------------------------- /fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/fig3.png -------------------------------------------------------------------------------- /GoTube.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/GoTube.jpg -------------------------------------------------------------------------------- /Appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/Appendix.pdf -------------------------------------------------------------------------------- /GoTube1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/GoTube1.jpg -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | [files] 2 | output_directory = ./saved_outputs/ 3 | output_file = _GoTube.txt 
-------------------------------------------------------------------------------- /logged/.gitignore: -------------------------------------------------------------------------------- 1 | # just to make sure that the directory logged is created 2 | *.json -------------------------------------------------------------------------------- /rl/ctrnn_osc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/rl/ctrnn_osc.npz -------------------------------------------------------------------------------- /rl/lds_ctrnn.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/rl/lds_ctrnn.npz -------------------------------------------------------------------------------- /saved_outputs/.gitignore: -------------------------------------------------------------------------------- 1 | # just to make sure that the directory saved_outputs is created 2 | *.txt -------------------------------------------------------------------------------- /rl/pendulum_ctrnn.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DatenVorsprung/GoTube/HEAD/rl/pendulum_ctrnn.npz -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCharm IDEA 2 | /.idea 3 | .project 4 | .pydevproject 5 | /__pycache__ 6 | 7 | # Mac OS 8 | /.DS_Store 9 | 10 | # Virtual environment 11 | /venv 12 | /env 13 | 14 | # Output and Log files 15 | all_prob_scores.csv 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | future==0.18.2 3 | kiwisolver==1.3.1 4 | matplotlib==3.3.3 5 | numpy==1.22.2 6 | 
Pillow==8.0.1 7 | pyparsing==2.4.7 8 | python-dateutil==2.8.1 9 | scipy==1.8.1 10 | six==1.15.0 11 | torch==1.7.1 12 | typing-extensions==4.6.3 13 | jax==0.4.12 14 | jaxlib==0.4.12 -------------------------------------------------------------------------------- /timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Timer(object): 5 | def __init__(self, name=None): 6 | self.name = name 7 | 8 | def __enter__(self): 9 | self.tstart = time.time() 10 | 11 | def __exit__(self, type, value, traceback): 12 | True 13 | # if self.name: 14 | # print('[%s]' % self.name,) 15 | # print('Elapsed: %.2f seconds' % (time.time() - self.tstart)) -------------------------------------------------------------------------------- /compute_volume_intersection.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | dim = 2 5 | volumes = [] 6 | 7 | fname1 = f"logged/stat_0037.json" 8 | with open(fname1, "r") as f: 9 | log1 = json.load(f) 10 | 11 | radius1 = log1["stats"]["radius"] 12 | semiAxes_prod1 = log1["stats"]["semiAxes_prod"] 13 | volumes1 = log1["stats"]["volume"] 14 | 15 | fname2 = f"logged/stat_0038.json" 16 | with open(fname2, "r") as f: 17 | log2 = json.load(f) 18 | 19 | radius2 = log2["stats"]["radius"] 20 | semiAxes_prod2 = log2["stats"]["semiAxes_prod"] 21 | volumes2 = log2["stats"]["volume"] 22 | 23 | for r1, s1, v1, r2, s2, v2 in zip(radius1, semiAxes_prod1, volumes1, radius2, semiAxes_prod2, volumes2): 24 | d = 2 * max(r1*s1, r2*s2) 25 | v = d ** dim 26 | print(v) 27 | v = min(v, v1) 28 | v = min(v, v2) 29 | volumes.append(v) 30 | 31 | volumes = np.array(volumes) 32 | 33 | print("average volume", volumes.mean()) -------------------------------------------------------------------------------- /table4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source venv/bin/activate 
4 | 5 | BENCHMARK=cartpoleLTC_RK 6 | TIME_STEP=0.01 7 | TIME_HORIZON_SHORT=0.35 8 | TIME_HORIZON_LONG=10 9 | INITIAL_RADIUS=1e-4 10 | # new parameters - to be defined 11 | python main.py --time_horizon $TIME_HORIZON_SHORT --benchmark $BENCHMARK --batch_size 10000 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.05 --score 12 | python main.py --time_horizon $TIME_HORIZON_LONG --benchmark $BENCHMARK --batch_size 10000 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.05 --score 13 | 14 | 15 | BENCHMARK=cartpoleCTRNN 16 | TIME_STEP=0.1 17 | TIME_HORIZON_SHORT=1 18 | TIME_HORIZON_LONG=10 19 | INITIAL_RADIUS=1e-4 20 | # new parameters - to be defined 21 | python main.py --time_horizon $TIME_HORIZON_SHORT --benchmark $BENCHMARK --batch_size 10000 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.05 --score 22 | python main.py --time_horizon $TIME_HORIZON_LONG --benchmark $BENCHMARK --batch_size 10000 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.05 --score 23 | -------------------------------------------------------------------------------- /polar_coordinates.py: -------------------------------------------------------------------------------- 1 | # transformation between polar and cartesian coordinates 2 | 3 | import numpy as np 4 | import jax.numpy as jnp 5 | from jax import jit 6 | import dynamics 7 | 8 | # initialize random polar coordinates with dimension dim 9 | _rng = np.random.RandomState(12937) 10 | 11 | 12 | def uniform(start, end, dim, fixed_seed): 13 | if fixed_seed: 14 | global _rng 15 | return _rng.uniform(start, end, dim) 16 | else: 17 | return np.random.uniform(start, end, dim) 18 | 19 | 20 | def init_random_phi(dim, samples=1, num_gpus=1, fixed_seed=False): 21 | phi = uniform(0, jnp.pi, samples * (dim - 2), fixed_seed) 22 | phi = jnp.append(phi, uniform(0, 2 * jnp.pi, samples, fixed_seed)) 23 | phi = jnp.reshape(phi, (num_gpus, samples // num_gpus, dim - 1), order="F") 24 | 25 | return phi 26 | 27 | 28 | @jit 29 | def 
polar2cart(rad, phi): 30 | return rad * polar2cart_no_rad(phi) 31 | 32 | 33 | def polar2cart_euclidean_metric(rad, phis, A0inv): 34 | return rad * jnp.matmul(A0inv, polar2cart_no_rad(phis)) 35 | 36 | 37 | def polar2cart_no_rad(phi): 38 | return dynamics.polar2cart_no_rad(phi) 39 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # Attribution-NonCommercial-ShareAlike 4.0 International 2 | This work is licensed under the 3 | [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by-nc-sa/4.0/). 4 | 5 | Creative Commons License 6 | 7 | ## Human-readable summary 8 | This is a human-readable summary of (and not a substitute for) the [License](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 9 | 10 | 11 | This license allows reusers to distribute, remix, adapt, and build upon the material in any medium or format for noncommercial purposes only, and only so long as attribution is given to the creator. If you remix, adapt, or build upon the material, you must license the modified material under identical terms. 
12 | CC BY-NC-SA includes the following elements: 13 | 14 | BY NC SA 15 | 16 | * **BY** – Credit must be given to the creator 17 | 18 | * **NC** – Only noncommercial uses of the work are permitted 19 | 20 | * **SA** – Adaptations must be shared under the same terms 21 | 22 | -------------------------------------------------------------------------------- /performance_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | registered_args = {} 5 | logged_stats = {} 6 | 7 | 8 | def log_args(args): 9 | global registered_args 10 | for k, v in dict(args).items(): 11 | registered_args[k] = v 12 | 13 | 14 | def log_stat(stat): 15 | global logged_stats 16 | for k, v in stat.items(): 17 | if not k in logged_stats.keys(): 18 | logged_stats[k] = [] 19 | logged_stats[k].append(v) 20 | 21 | 22 | def close_log(notes): 23 | global registered_args 24 | global logged_stats 25 | final_dict = {"args": registered_args, "stats": logged_stats, "notes": notes} 26 | os.makedirs("logged", exist_ok=True) 27 | for i in range(100000): 28 | fname = f"logged/stat_{i:04d}.json" 29 | if not os.path.isfile(fname): 30 | break 31 | with open(fname, "w") as f: 32 | json.dump(final_dict, f) 33 | 34 | 35 | def create_plot_file(files): 36 | for i in range(100000): 37 | fname = files['output_directory']+f"{i:04d}"+files['output_file'] 38 | if not os.path.isfile(fname): 39 | break 40 | return fname 41 | 42 | 43 | def write_plot_file(fname, mode, t, cx, rad, M1): 44 | f = open(fname, mode) 45 | f.write(str(t) + " ") 46 | f.write(' '.join(map(str, cx.reshape(-1))) + " ") 47 | f.write(str(rad) + " ") 48 | f.write(' '.join(map(str, M1.reshape(-1))) + "\n") 49 | f.close() 50 | 51 | 52 | if __name__ == "__main__": 53 | print("registered_args: ", str(registered_args)) 54 | log_args({"hello": "test"}) 55 | print("registered_args: ", str(registered_args)) 56 | 57 | -------------------------------------------------------------------------------- 
/dynamics.py: -------------------------------------------------------------------------------- 1 | # computes the jacobian and the metric for a given model 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from jax import jacfwd, jacrev, jit 6 | 7 | from scipy.linalg import eigh 8 | from numpy.linalg import inv 9 | import benchmarks as bm 10 | 11 | 12 | class FunctionDynamics: 13 | def __init__(self, model): 14 | self.model = model 15 | 16 | x = jnp.ones(self.model.dim) 17 | if jnp.sum(jnp.abs(self.model.fdyn(0.0, x) - self.model.fdyn(1.0, x))) > 1e-8: 18 | # https://github.com/google/jax/issues/47 19 | raise ValueError("Only time-invariant systems supported currently") 20 | self._cached_f_jac = jit(jacrev(lambda x: self.model.fdyn(0.0, x))) 21 | 22 | def f_jac_at(self, t, x): 23 | return jnp.array(self._cached_f_jac(x)) 24 | 25 | def metric(self, Fmid, ellipsoids): 26 | if ellipsoids: 27 | A1inv = Fmid 28 | A1 = inv(A1inv) 29 | M1 = inv(A1inv @ A1inv.T) 30 | 31 | W, v = eigh(M1) 32 | 33 | W = abs(W) # to prevent nan-errors 34 | 35 | semiAxes = 1 / np.sqrt(W) # needed to compute volume of ellipse 36 | 37 | else: 38 | A1 = np.eye(Fmid.shape[0]) 39 | M1 = np.eye(Fmid.shape[0]) 40 | semiAxes = np.array([1]) 41 | 42 | return M1, A1, semiAxes.prod() 43 | 44 | 45 | def polar2cart_no_rad(phi): 46 | sin_polar = jnp.sin(phi) 47 | cart = jnp.append(jnp.cos(phi), jnp.ones(1)) * jnp.append(jnp.ones(1), sin_polar) 48 | for i in range(1, jnp.size(phi)): 49 | cart *= jnp.append(jnp.ones(i + 1), sin_polar[:-i]) 50 | return ( 51 | cart # rad*polar2cart_no_rad(phi) is the true value of the cartesian coordinate 52 | ) 53 | 54 | _jac_polar_cached = None 55 | 56 | 57 | def jacobian_polar_at(phi): 58 | global _jac_polar_cached 59 | if _jac_polar_cached is None: 60 | _jac_polar_cached = jit(jacfwd(polar2cart_no_rad)) 61 | return _jac_polar_cached(phi) 62 | 63 | 64 | if __name__ == "__main__": 65 | fdyn = FunctionDynamics(bm.CartpoleCTRNN()) 66 | 67 | print(fdyn.f_jac_at(0, 
fdyn.model.cx)) 68 | -------------------------------------------------------------------------------- /figure4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source venv/bin/activate 4 | 5 | 6 | python3 main.py --time_horizon 2 --mu 2.0 --benchmark CTRNNosc --batch_size 256 --time_step 0.5 --radius 0.001 --gamma 0.5 --score 7 | python3 main.py --time_horizon 2 --mu 1.8 --benchmark CTRNNosc --batch_size 256 --time_step 0.5 --radius 0.001 --gamma 0.5 --score 8 | python3 main.py --time_horizon 2 --mu 1.7 --benchmark CTRNNosc --batch_size 256 --time_step 0.5 --radius 0.001 --gamma 0.5 --score 9 | python3 main.py --time_horizon 2 --mu 1.6 --benchmark CTRNNosc --batch_size 256 --time_step 0.5 --radius 0.001 --gamma 0.5 --score 10 | python3 main.py --time_horizon 2 --mu 1.5 --benchmark CTRNNosc --batch_size 1024 --time_step 0.5 --radius 0.001 --gamma 0.5 --score 11 | 12 | python main.py --time_horizon 10 --mu 2.0 --benchmark pendulumCTRNN --batch_size 256 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 13 | python main.py --time_horizon 10 --mu 1.8 --benchmark pendulumCTRNN --batch_size 256 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 14 | python main.py --time_horizon 10 --mu 1.7 --benchmark pendulumCTRNN --batch_size 256 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 15 | python main.py --time_horizon 10 --mu 1.6 --benchmark pendulumCTRNN --batch_size 256 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 16 | python main.py --time_horizon 10 --mu 1.5 --benchmark pendulumCTRNN --batch_size 1024 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 17 | python main.py --time_horizon 10 --mu 1.3 --benchmark pendulumCTRNN --batch_size 1024 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 18 | 19 | python main.py --time_horizon 10 --mu 2.0 --benchmark ldsCTRNN --batch_size 256 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 20 | python main.py --time_horizon 10 --mu 1.8 --benchmark ldsCTRNN --batch_size 
256 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 21 | python main.py --time_horizon 10 --mu 1.7 --benchmark ldsCTRNN --batch_size 256 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 22 | python main.py --time_horizon 10 --mu 1.6 --benchmark ldsCTRNN --batch_size 256 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 23 | python main.py --time_horizon 10 --mu 1.5 --benchmark ldsCTRNN --batch_size 1024 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 24 | python main.py --time_horizon 10 --mu 1.3 --benchmark ldsCTRNN --batch_size 1024 --time_step 0.5 --radius 0.01 --gamma 0.5 --score 25 | python main.py --time_horizon 10 --mu 1.1 --benchmark ldsCTRNN --batch_size 1024 --time_step 0.5 --radius 0.01 --gamma 0.5 --score -------------------------------------------------------------------------------- /table2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source venv/bin/activate 4 | 5 | #Brusselator 6 | BENCHMARK=bruss 7 | MU=1.1 8 | TIME_STEP=0.01 9 | TIME_HORIZON=9 10 | INITIAL_RADIUS=0.01 11 | # new parameters - to be defined 12 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 5 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.1 --score 13 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 5 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.01 --score 14 | 15 | # #Van der Pol 16 | BENCHMARK=vdp 17 | TIME_STEP=0.01 18 | TIME_HORIZON=40 19 | INITIAL_RADIUS=0.01 20 | # new parameters - to be defined 21 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 5 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.1 --score 22 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 5 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.01 --score 23 | 24 | # #Robotarm 25 | BENCHMARK=robot 26 | TIME_STEP=0.01 27 | TIME_HORIZON=40 
28 | INITIAL_RADIUS=0.005 29 | # new parameters - to be defined 30 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 5 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.1 --score 31 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 5 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.01 --score 32 | 33 | #Dubins car 34 | BENCHMARK=dubins 35 | TIME_STEP=0.1 36 | TIME_HORIZON=15 37 | INITIAL_RADIUS=0.01 38 | # new parameters - to be defined 39 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 10 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.1 --score 40 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 10 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.01 --score 41 | 42 | # #MS cardiac cell 43 | BENCHMARK=ms 44 | TIME_STEP=0.01 45 | TIME_HORIZON=10 46 | INITIAL_RADIUS=1e-4 47 | # new parameters - to be defined 48 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 10 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.1 --score 49 | python main.py --time_horizon $TIME_HORIZON --mu $MU --benchmark $BENCHMARK --batch_size 10 --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.01 --score 50 | 51 | # #Cartpole CTRNN 52 | BENCHMARK=cartpoleCTRNN 53 | TIME_STEP=0.02 54 | TIME_HORIZON=1 55 | INITIAL_RADIUS=1e-4 56 | # new parameters - to be defined 57 | BATCH_SIZE=10000 58 | python main.py --time_horizon $TIME_HORIZON --benchmark $BENCHMARK --batch_size $BATCH_SIZE --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.1 --score 59 | python main.py --time_horizon $TIME_HORIZON --benchmark $BENCHMARK --batch_size $BATCH_SIZE --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.01 --score 60 | 61 | # #Cartpole LTC 62 | BENCHMARK=cartpoleLTC_RK 63 | TIME_STEP=0.01 64 | TIME_HORIZON=0.35 65 | INITIAL_RADIUS=1e-4 66 | # 
new parameters - to be defined 67 | BATCH_SIZE=10000 68 | python main.py --time_horizon $TIME_HORIZON --benchmark $BENCHMARK --batch_size $BATCH_SIZE --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.1 --score 69 | python main.py --time_horizon $TIME_HORIZON --benchmark $BENCHMARK --batch_size $BATCH_SIZE --time_step $TIME_STEP --radius $INITIAL_RADIUS --gamma 0.01 --score 70 | 71 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | # plotting the outputs from GoTube (ellipse, circle, intersections etc.) 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from numpy.linalg import svd 7 | from numpy.linalg import inv 8 | 9 | import stochastic_reachtube as reach 10 | import benchmarks as bm 11 | import configparser 12 | import argparse 13 | import pickle 14 | import os 15 | 16 | 17 | def draw_ellipse(ellipse, color, alph, axis_3d): 18 | c = ellipse[1:3] 19 | r = ellipse[3] 20 | M = np.reshape(ellipse[4:8], (2, 2)) / r ** 2 21 | 22 | # "singular value decomposition" to extract the orientation and the 23 | # axes of the ellipsoid 24 | _, D, V = svd(M) 25 | 26 | plot_grid = 50 27 | 28 | # get the major and minor axes 29 | a = 1 / np.sqrt(D[0]) 30 | b = 1 / np.sqrt(D[1]) 31 | 32 | theta = np.arange(0, 2 * np.pi + 1 / plot_grid, 1 / plot_grid) 33 | 34 | # parametric equation of the ellipse 35 | state = np.zeros((2, np.size(theta))) 36 | state[0, :] = a * np.cos(theta) 37 | state[1, :] = b * np.sin(theta) 38 | 39 | # coordinate transform 40 | X = V @ state 41 | 42 | X[0] += c[0] 43 | X[1] += c[1] 44 | 45 | axis_3d.plot(xs=X[0], ys=X[1], zs=ellipse[0], color=color, alpha=alph) 46 | 47 | 48 | def plot_ellipse( 49 | time_horizon, dim, axis1, axis2, file, color, alph, axis_3d, skip_reachsets=1 50 | ): 51 | data_ellipse = np.loadtxt(file) 52 | 53 | # permutation matrix to project on axis1 and axis2 54 | P = np.eye(dim) 55 | P[:, 
[0, axis1]] = P[:, [axis1, 0]] 56 | P[:, [1, axis2]] = P[:, [axis2, 1]] 57 | 58 | count = skip_reachsets 59 | 60 | for ellipse in data_ellipse[1:]: 61 | 62 | if count != skip_reachsets: 63 | count += 1 64 | continue 65 | 66 | count = 1 67 | 68 | ellipse2 = ellipse 69 | 70 | # create ellipse plotting values for 2d projection 71 | # https://math.stackexchange.com/questions/2438495/showing-positive-definiteness-in-the-projection-of-ellipsoid 72 | # construct ellipse2 to have a 2-dimensional ellipse as an input to 73 | # ellipse_plot 74 | 75 | if dim > 2: 76 | center = ellipse[1 : dim + 1] 77 | ellipse2[1] = center[axis1] 78 | ellipse2[2] = center[axis2] 79 | radius_ellipse = ellipse[dim + 1] 80 | m1 = np.reshape(ellipse[dim + 2 :], (dim, dim)) 81 | m1 = m1 / radius_ellipse ** 2 82 | m1 = P.transpose() @ m1 @ P # permutation to project on chosen axes 83 | ellipse2[3] = 1 # because radius is already in m1 84 | 85 | # plot ellipse onto axis1-axis2 plane 86 | J = m1[0:2, 0:2] 87 | K = m1[2:, 2:] 88 | L = m1[2:, 0:2] 89 | m2 = J - L.transpose() @ inv(K) @ L 90 | ellipse2[4:8] = m2.reshape(1, -1) 91 | 92 | draw_ellipse(ellipse2[0:8], color, alph, axis_3d) 93 | 94 | if ellipse[0] >= time_horizon: 95 | break 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser(description="") 100 | parser.add_argument("--time_step", default=0.01, type=float) 101 | parser.add_argument("--time_horizon", default=0.01, type=float) 102 | parser.add_argument("--benchmark", default="bruss") 103 | parser.add_argument("--output_number", default="0000") 104 | parser.add_argument("--samples", default=100, type=int) 105 | parser.add_argument("--axis1", default=0, type=int) 106 | parser.add_argument("--axis2", default=1, type=int) 107 | # initial radius 108 | parser.add_argument("--radius", default=None, type=float) 109 | 110 | args = parser.parse_args() 111 | 112 | config = configparser.ConfigParser() 113 | config.read("config.ini") 114 | 115 | files = config["files"] 116 | 117 
| rt = reach.StochasticReachtube( 118 | model=bm.get_model(args.benchmark, args.radius), 119 | time_horizon=args.time_horizon, 120 | time_step=args.time_step, 121 | samples=args.samples, 122 | axis1=args.axis1, 123 | axis2=args.axis2, 124 | ) # reachtube 125 | 126 | fig = plt.figure() 127 | ax = fig.add_subplot(111, projection="3d") 128 | 129 | p_dict = rt.plot_traces(axis_3d=ax) 130 | p_dict["time_horizon"] = args.time_horizon 131 | p_dict["dim"] = rt.model.dim 132 | p_dict["axis1"] = rt.axis1 133 | p_dict["axis2"] = rt.axis2 134 | p_dict["data_ellipse"] = np.loadtxt( 135 | files["output_directory"] + str(args.output_number) + files["output_file"] 136 | ) 137 | 138 | plot_ellipse( 139 | args.time_horizon, 140 | rt.model.dim, 141 | rt.axis1, 142 | rt.axis2, 143 | files["output_directory"] + str(args.output_number) + files["output_file"], 144 | color="magenta", 145 | alph=0.8, 146 | axis_3d=ax, 147 | skip_reachsets=1, 148 | ) 149 | 150 | plt.show() 151 | 152 | os.makedirs("plot_obj", exist_ok=True) 153 | for i in range(1000): 154 | filename = f"plot_obj/plot_{i:03d}.pkl" 155 | if not os.path.isfile(filename): 156 | with open(filename, "wb") as f: 157 | pickle.dump(p_dict, f, protocol=pickle.DEFAULT_PROTOCOL) 158 | break 159 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.numpy as jnp 3 | 4 | import benchmarks as bm 5 | import stochastic_reachtube as reach 6 | import go_tube 7 | import configparser 8 | import time 9 | from performance_log import log_args 10 | from performance_log import close_log 11 | from performance_log import create_plot_file 12 | from performance_log import write_plot_file 13 | from performance_log import log_stat 14 | 15 | import argparse 16 | 17 | from jax import config 18 | config.update("jax_enable_x64", True) 19 | 20 | 21 | if __name__ == "__main__": 22 | 23 | start_time = 
time.time() 24 | 25 | parser = argparse.ArgumentParser(description="") 26 | parser.add_argument("--profile", action="store_true") 27 | parser.add_argument("--score", action="store_true") 28 | parser.add_argument("--benchmark", default="vdp") 29 | # starting_time, time_step and time_horizon for creating reachtubes 30 | parser.add_argument("--starting_time", default=0.0, type=float) 31 | parser.add_argument("--time_step", default=0.01, type=float) 32 | parser.add_argument("--time_horizon", default=10, type=float) 33 | # batch-size for tensorization 34 | parser.add_argument("--batch_size", default=10000, type=int) 35 | # number of GPUs for parallelization 36 | parser.add_argument("--num_gpus", default=1, type=int) 37 | # use fixed seed for random points (only for comparing different algorithms) 38 | parser.add_argument("--fixed_seed", action="store_true") 39 | # error-probability 40 | parser.add_argument("--gamma", default=0.2, type=float) 41 | # mu as maximum over-approximation 42 | parser.add_argument("--mu", default=1.5, type=float) 43 | # choose between hyperspheres and ellipsoids to describe the Reachsets 44 | parser.add_argument("--ellipsoids", action="store_true") 45 | # initial radius 46 | parser.add_argument("--radius", default=None, type=float) 47 | 48 | args = parser.parse_args() 49 | log_args(vars(args)) 50 | 51 | config = configparser.ConfigParser() 52 | config.read("./config.ini") 53 | 54 | files = config["files"] 55 | 56 | rt = reach.StochasticReachtube( 57 | model=bm.get_model(args.benchmark, args.radius), 58 | profile=args.profile, 59 | mu=args.mu, # mu as maximum over-approximation 60 | gamma=args.gamma, # error-probability 61 | batch=args.batch_size, 62 | num_gpus=args.num_gpus, 63 | fixed_seed=args.fixed_seed, 64 | radius=args.radius, 65 | ) # reachtube 66 | 67 | timeRange = np.arange(args.starting_time + args.time_step, args.time_horizon + 1e-9, args.time_step) 68 | timeRange_with_start = np.append(np.array([0]), timeRange) 69 | 70 | volume = 
jnp.array([rt.compute_volume()]) 71 | 72 | # propagate center point and compute metric 73 | ( 74 | cx_timeRange, 75 | A1_timeRange, 76 | M1_timeRange, 77 | semiAxes_prod_timeRange, 78 | ) = rt.compute_metric_and_center( 79 | timeRange_with_start, 80 | args.ellipsoids, 81 | ) 82 | 83 | fname = create_plot_file(files) 84 | 85 | write_plot_file(fname, "w", 0, rt.model.cx, rt.model.rad, M1_timeRange[0, :, :]) 86 | 87 | total_random_points = None 88 | total_gradients = None 89 | total_initial_points = jnp.zeros((rt.num_gpus, 0, rt.model.dim)) 90 | 91 | # for loop starting at Line 2 92 | for i, time_py in enumerate(timeRange): 93 | print(f"Step {i}/{timeRange.shape[0]} at {time_py:0.2f} s") 94 | 95 | rt.time_horizon = time_py 96 | rt.time_step = args.time_step 97 | rt.cur_time = rt.time_horizon 98 | 99 | rt.cur_cx = cx_timeRange[i + 1, :] 100 | rt.A1 = A1_timeRange[i + 1, :, :] 101 | 102 | # GoTube Algorithm for t = t_j 103 | # while loop from Line 5 - Line 15 104 | ( 105 | rt.cur_rad, 106 | prob, 107 | total_initial_points, 108 | total_random_points, 109 | total_gradients, 110 | ) = go_tube.optimize( 111 | rt, total_initial_points, total_random_points, total_gradients 112 | ) 113 | 114 | write_plot_file( 115 | fname, "a", time_py, rt.cur_cx, rt.cur_rad, M1_timeRange[i + 1, :, :] 116 | ) 117 | 118 | volume = jnp.append(volume, rt.compute_volume(semiAxes_prod_timeRange[i + 1])) 119 | 120 | if rt.profile: 121 | # If profiling is enabled, log some statistics about the GD optimization process 122 | volumes = { 123 | "volume": float(volume[-1]), 124 | "average_volume": float(volume.mean()), 125 | } 126 | log_stat(volumes) 127 | 128 | if args.score: 129 | with open("all_prob_scores.csv", "a") as f: 130 | # CSV with header benchmark, time-horizon, prob, runtime, volume 131 | f.write(f"{args.benchmark},") 132 | f.write(f"{args.time_horizon:0.4g},") 133 | f.write(f"{args.radius:0.4g},") 134 | f.write(f"{args.mu:0.4g},") 135 | f.write(f"{1.0-args.gamma:0.4f},") 136 | 
f.write(f"{time.time()-start_time:0.2f},") 137 | f.write(f"{total_random_points.shape[0]:d},") 138 | f.write(f"{float(volume.mean()):0.5g}") 139 | f.write("\n") 140 | if rt.profile: 141 | final_notes = { 142 | "total_time": time.time() - start_time, 143 | "samples": args.num_gpus * total_random_points.shape[1], 144 | } 145 | close_log(final_notes) 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GoTube - Scalable Stochastic Verification of Continuous-Depth Models 2 | 3 | This is the official code repository of the paper *GoTube: Scalable Stochastic Verification of Continuous-Depth Models* 4 | accepted to the Thirty-Sixth AAAI Conference on Artificial Intelligence (AAAI-22) 5 | ([arXiv link](https://arxiv.org/abs/2107.08467)). 6 | 7 | GoTube constructs stochastic reachtubes (= the set of all reachable system states) of continuous-time systems. GoTube is specifically designed for the verification of continuous-depth neural networks. 8 | ![Figure 1 of the paper](GoTube1.jpg) 9 | 10 | This document describes the general usage of GoTube. To reproduce the exact numbers reported in the paper, have a look at the scripts ```table2.sh```, ```table4.sh```, and ```figure4.sh```. 11 | 12 | ## Setup 13 | 14 | Python 3.8 or newer is required (the pinned ```numpy==1.22.2``` does not support older versions). 15 | The setup was tested on Ubuntu 16.04, Ubuntu 20.04 and recent macOS machines with Python 3.6 and 3.8. 16 | 17 | We recommend creating a virtual environment so that GoTube's dependencies do not interfere with other Python installations.
18 | 19 | ```bash 20 | python3 -m venv venv # Optional 21 | source venv/bin/activate # Optional 22 | python3 -m pip install -r requirements.txt 23 | ``` 24 | 25 | ## Benchmarks 26 | 27 | Each benchmark is encoded in a Python class exposing the three attributes ```rad``` for the initial set radius, ```cx``` for the initial set center, and ```dim``` for the number of dimensions of the dynamical system. 28 | Moreover, each benchmark class must implement a method ```fdyn``` that defines the dynamical system, and it must be registered as a case in the ``get_model`` function so that the model can be selected via the ```--benchmark``` parameter. 29 | For example, 30 | 31 | ```python 32 | # CartPole-v1 with a linear policy 33 | class CartpoleLinear: 34 | 35 | def __init__(self, radius=None): 36 | 37 | #============ adapt initial values =========== 38 | self.cx = (0, 0, 0.001, 0) #initial values 39 | if radius is not None: 40 | self.rad = radius 41 | else: 42 | self.rad = 1e-4 #initial radius 43 | #=================================================== 44 | 45 | self.cx = np.array(self.cx, dtype=float) 46 | self.dim = self.cx.size #dimension of the system 47 | 48 | def fdyn(self,t=0,x=None): 49 | 50 | if x is None: 51 | x=np.zeros(self.dim, dtype=object) 52 | 53 | #============ adapt input and system dynamics =========== 54 | dth, dx, th, x = x #input variables 55 | 56 | M = 1.0 57 | g = 9.81 58 | l = 1.0 59 | m = 0.001 60 | 61 | f = -1.1 * M * g * th - dth 62 | 63 | fdth = 1.0 / (l*(M+m*sin(th)*sin(th))) * (f * cos(th) - m*l*dth*dth*cos(th)*sin(th) + (m+M)*g*sin(th)) 64 | fdx = 1.0 / (M+m*sin(th)*sin(th)) * ( f + m*sin(th) * (-l * dth*dth + g*cos(th)) ) 65 | 66 | fx = dx 67 | 68 | fth = dth 69 | 70 | system_dynamics = [fdth, fdx, fth, fx] #has to be in the same order as the input variables 71 | #=================================================== 72 | 73 | return np.array(system_dynamics) #return as numpy array 74 | ``` 75 | And inside the ``get_model`` function: 76 | ``` 77 | elif benchmark == "cartpole": 78 |
CartpoleLinear(radius) # Benchmark to run 79 | ``` 80 | 81 | ## Running GoTube 82 | 83 | The entry point of GoTube is defined in the file ```main.py```, which accepts several arguments to specify properties of the reachtube and GoTube. 84 | The most notable arguments are: 85 | 86 | - ```--benchmark``` Name of the benchmark (=dynamical system) for which the reachtube should be constructed 87 | - ```--time_horizon``` Length in seconds of the reachtube to be constructed 88 | - ```--time_step``` Time interval between the intermediate time points at which the reachtube is constructed 89 | - ```--batch_size``` Batch size used for simulating points of the system. A larger batch size speeds up the computation but increases the memory footprint. 90 | - ```--gamma``` Error probability. For instance, a gamma of 0.1 means the reachtube will have a 90% confidence. 91 | - ```--mu``` Maximum multiplicative tolerance of over-approximation. A higher mu speeds up the computation of the bounding tube at the cost of a larger average volume. For instance, a mu of 1.5 means the reachtube will have a 1.5 times larger radius than the most distant sample point. 92 | 93 | ## Examples 94 | 95 | Create a reachtube of the Van der Pol dynamical system for 2 seconds with a batch size of 100 samples, while creating intermediate reachsets every 0.1 seconds: 96 | 97 | ```bash 98 | python main.py --time_horizon 2 --benchmark vdp --batch_size 100 --time_step 0.1 99 | ``` 100 | 101 | To create a 95% confidence reachtube for CartPole-v1 with a CT-RNN controller and a maximum multiplicative over-approximation tolerance mu of 1.5, run: 102 | 103 | ```bash 104 | python main.py --time_horizon 1 --benchmark cartpoleCTRNN --batch_size 10000 --time_step 0.02 --gamma 0.05 --mu 1.5 105 | ``` 106 | 107 | ## Citation 108 | 109 | ```tex 110 | @article{Gruenbacher2022GoTube, 111 | author={Gruenbacher, Sophie A. and Lechner, Mathias and Hasani, Ramin and Rus, Daniela and Henzinger, Thomas A. and Smolka, Scott A.
and Grosu, Radu}, 112 | title={GoTube: Scalable Statistical Verification of Continuous-Depth Models}, 113 | volume={36}, 114 | url={https://ojs.aaai.org/index.php/AAAI/article/view/20631}, 115 | DOI={10.1609/aaai.v36i6.20631}, 116 | number={6}, 117 | journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 118 | year={2022}, 119 | month={Jun.}, 120 | pages={6755-6764}, 121 | } 122 | ``` 123 | -------------------------------------------------------------------------------- /go_tube.py: -------------------------------------------------------------------------------- 1 | # Algorithms of GoTube paper for safety region, probability and stoch. optimization 2 | 3 | import jax.numpy as jnp 4 | from jax import vmap, pmap 5 | import polar_coordinates as pol 6 | from jax.numpy.linalg import svd 7 | import jax.scipy.special as sc 8 | import time 9 | from performance_log import log_stat 10 | from timer import Timer 11 | from scipy.stats import genextreme, kstest 12 | import gc 13 | 14 | 15 | # using expected difference quotient of center lipschitz constant 16 | def get_safety_region_radius(model, dist, dist_best, lip, lip_mean_diff): 17 | safety_radius = -lip + jnp.sqrt(lip ** 2 + 4 * lip_mean_diff * (model.mu * dist_best - dist)) 18 | safety_radius = safety_radius / (2 * lip_mean_diff) 19 | safety_radius = safety_radius * (dist > 0) * (lip_mean_diff > 0) 20 | 21 | safety_radius = jnp.minimum(safety_radius, 2 * model.rad_t0) 22 | 23 | return safety_radius 24 | 25 | 26 | def compute_maximum_singular_value(A1, A0inv, F): 27 | F_metric = jnp.matmul(A1, F) 28 | F_metric = jnp.matmul(F_metric, A0inv) 29 | _, sf, _ = svd(F_metric) 30 | max_sf = jnp.max(sf) 31 | 32 | return max_sf 33 | 34 | 35 | def get_angle_of_cap(model, radius): 36 | radius = jnp.minimum(radius, 2 * model.rad_t0) 37 | return 2 * jnp.arcsin(0.5 * radius / model.rad_t0) 38 | 39 | 40 | def get_probability_of_cap(model, radius): 41 | with Timer('get angle of cap'): 42 | angle = get_angle_of_cap(model, 
radius) 43 | with Timer('get probability of cap'): 44 | a = 0.5 * (model.model.dim - 1) 45 | b = 0.5 46 | x = jnp.sin(angle) ** 2 47 | betainc_angle = 0.5 * sc.betainc(a, b, x) 48 | 49 | # formula is only for the smaller cap with angle <= pi/2, sinus is symmetric => thus use 1-area otherwise 50 | betainc_angle = jnp.where(angle > 0.5 * jnp.pi, 1 - betainc_angle, betainc_angle) 51 | 52 | return betainc_angle 53 | 54 | 55 | def get_probability_not_in_cap(model, radius): 56 | return 1 - get_probability_of_cap(model, radius) 57 | 58 | 59 | def get_probability_none_in_cap(model, radius_points): 60 | return jnp.prod(get_probability_not_in_cap(model, radius_points)) 61 | 62 | 63 | # probability calculation using http://docsdrive.com/pdfs/ansinet/ajms/2011/66-70.pdf (equation 1 64 | # page 68) and the normalized incomplete Beta-Function in scipy ( 65 | # https://scipy.github.io/devdocs/generated/scipy.special.betainc.html#scipy.special.betainc) - Only use the 66 | # random sampled points for probability construction 67 | # use also the discarded points and create balls around them 68 | def get_probability(model, radius_points): 69 | return jnp.sqrt(1-model.gamma) * (1 - get_probability_none_in_cap(model, radius_points)) 70 | 71 | 72 | def compute_delta_lipschitz(y_jax, fy_jax, axis, gamma): 73 | gamma_hat = 1 - jnp.sqrt(1 - gamma) 74 | diff_quotients = get_diff_quotient_pairwise(y_jax, fy_jax, axis) 75 | sample_size = diff_quotients.size 76 | number_of_elements_for_maximum = round(sample_size ** (1 / 4)) # m in Lemma 1, Theorem 2 and throughout paper 77 | 78 | sample_size_dividable = sample_size - sample_size % number_of_elements_for_maximum 79 | 80 | diff_quotients_samples = diff_quotients[:sample_size_dividable].reshape(-1, number_of_elements_for_maximum) 81 | max_quotients = jnp.nanmax(diff_quotients_samples, axis=1) 82 | number_of_maxima = max_quotients.size # n in Lemma 1 83 | alpha = min(gamma_hat, 0.5) 84 | epsilon = jnp.sqrt(jnp.log(1 / alpha) / (2 * 
number_of_maxima)) 85 | 86 | c, loc, scale = genextreme.fit(max_quotients) 87 | rv_genextreme = genextreme(c, loc, scale) 88 | 89 | D_minus = kstest(max_quotients, rv_genextreme.cdf, alternative='less').statistic 90 | 91 | max_quantile = 0.9999 92 | 93 | # # with generalized extreme value distribution 94 | prob_quantile = min(1 - gamma_hat, max_quantile - epsilon - D_minus) 95 | delta_lipschitz = rv_genextreme.ppf([prob_quantile + epsilon + D_minus]) # transformation of Eq. (S14) 96 | prob_bound_lipschitz = (1 - gamma_hat) * prob_quantile 97 | 98 | # without generalized extreme value distribution 99 | # max_quotients = jnp.sort(max_quotients) 100 | # prob_quantile = min(1 - gamma_hat, max_quantile - epsilon) 101 | # delta_lipschitz = max_quotients[int(jnp.floor((prob_quantile + epsilon) * number_of_maxima))] # transformation of Eq. (S14) 102 | # prob_bound_lipschitz = (1 - gamma_hat) * prob_quantile 103 | 104 | return delta_lipschitz, prob_bound_lipschitz 105 | 106 | 107 | def get_diff_quotient_pairwise(x, fx, axis): 108 | # reshape to get samples as first index and remove gpu dimension 109 | x = jnp.reshape(x, (-1, x.shape[2])) 110 | fx = jnp.reshape(fx, fx.size) 111 | 112 | samples = int(jnp.floor(x.shape[0] / 2)) 113 | x1 = x[::2][:samples] 114 | x2 = x[1::2][:samples] 115 | fx1 = fx[::2][:samples] 116 | fx2 = fx[1::2][:samples] 117 | distance = jnp.linalg.norm(x1 - x2, axis=axis) 118 | diff_quotients = abs(fx1 - fx2) / distance * (distance > 0) 119 | return diff_quotients 120 | 121 | 122 | def get_diff_quotient(x, fx, y_jax, fy_jax, axis): 123 | distance = jnp.linalg.norm(x - y_jax, axis=axis) 124 | diff_quotients = abs(fx - fy_jax) / distance * (distance > 0) 125 | return diff_quotients 126 | 127 | 128 | def get_diff_quotient_vmap(x_jax, fx_jax, y_jax, fy_jax, axis): 129 | return vmap(get_diff_quotient, in_axes=(0, 0, None, None, None))(x_jax, fx_jax, y_jax, fy_jax, axis) 130 | 131 | 132 | def optimize(model, initial_points, points=None, gradients=None): 133 | 
start_time = time.time() 134 | 135 | prob = None 136 | 137 | if points is None or gradients is None: 138 | previous_samples = 0 139 | phis = pol.init_random_phi(model.model.dim, model.batch, model.num_gpus, model.fixed_seed) 140 | points, gradients, neg_dists, initial_points = model.aug_integrator_neg_dist(phis) 141 | dists = -neg_dists 142 | del neg_dists 143 | del phis 144 | gc.collect() 145 | else: 146 | previous_samples = points.shape[1] 147 | with Timer('integrate random points and gradients - one step'): 148 | points, gradients, dists = model.one_step_aug_integrator_dist( 149 | points, gradients 150 | ) 151 | 152 | first_iteration = True 153 | 154 | while prob is None or prob < 1 - model.gamma: 155 | 156 | if not first_iteration: 157 | with Timer('sample phis'): 158 | phis = pol.init_random_phi(model.model.dim, model.batch, model.num_gpus, model.fixed_seed) 159 | with Timer('compute first integration step and dist'): 160 | new_points, new_gradients, new_neg_dists, new_initial_points = model.aug_integrator_neg_dist(phis) 161 | new_dists = -new_neg_dists 162 | del new_neg_dists 163 | del phis 164 | gc.collect() 165 | 166 | with Timer('concatenate new points to tensors'): 167 | points = jnp.concatenate((points, new_points), axis=1) 168 | gradients = jnp.concatenate((gradients, new_gradients), axis=1) 169 | dists = jnp.concatenate((dists, new_dists), axis=1) 170 | initial_points = jnp.concatenate((initial_points, new_initial_points), axis=1) 171 | del new_points 172 | del new_gradients 173 | del new_dists 174 | del new_initial_points 175 | gc.collect() 176 | 177 | with Timer('compute best dist'): 178 | dist_best = dists.max() 179 | 180 | with Timer('compute lipschitz'): 181 | # compute maximum singular values of all new gradient matrices 182 | lipschitz = pmap(vmap(compute_maximum_singular_value, in_axes=(None, None, 0)), in_axes=(None, None, 0))(model.A1, model.A0inv, gradients) 183 | 184 | with Timer('compute expected local lipschitz'): 185 | sample_size = 
points.shape[0] 186 | 187 | # compute expected value of delta lipschitz 188 | dimension_axis = 1 189 | 190 | gamma_hat = 1 - jnp.sqrt(1 - model.gamma) 191 | 192 | delta_lipschitz, prob_bound_lipschitz = compute_delta_lipschitz( 193 | initial_points[:sample_size], 194 | lipschitz[:sample_size], 195 | dimension_axis, 196 | gamma_hat 197 | ) 198 | 199 | with Timer('get safety region radii'): 200 | safety_region_radii = get_safety_region_radius( 201 | model, dists, dist_best, lipschitz, delta_lipschitz 202 | ) 203 | 204 | with Timer('compute probability'): 205 | prob = get_probability(model, safety_region_radii) 206 | 207 | del delta_lipschitz 208 | del lipschitz 209 | del safety_region_radii 210 | gc.collect() 211 | 212 | print(f"Current probability coverage is {100.0 * prob:0.3f}% using {model.num_gpus * points.shape[1]} points") 213 | 214 | first_iteration = False 215 | 216 | new_samples = model.num_gpus * points.shape[1] - previous_samples 217 | 218 | print("Probability reached given value!") 219 | print( 220 | f"Visited {new_samples} new points in {time.time() - start_time:0.2f} seconds." 
221 | ) 222 | 223 | dist_with_safety_mu = model.mu * dist_best 224 | 225 | if model.profile: 226 | # If profiling is enabled, log some statistics about the optimization process 227 | stat_dict = { 228 | "loop_time": time.time() - start_time, 229 | "new_points": int(new_samples), 230 | "total_points": int(previous_samples + new_samples), 231 | "prob": float(prob), 232 | "dist_best": float(dist_best), 233 | "radius": float(dist_with_safety_mu), 234 | } 235 | log_stat(stat_dict) 236 | 237 | return dist_with_safety_mu, prob, initial_points, points, gradients 238 | -------------------------------------------------------------------------------- /stochastic_reachtube.py: -------------------------------------------------------------------------------- 1 | # optimization problem 2 | 3 | import numpy as np 4 | import jax.numpy as jnp 5 | from jax.experimental.ode import odeint 6 | from jax import vmap, jit, pmap, device_put, devices 7 | from functools import partial 8 | 9 | from scipy.special import gamma 10 | 11 | # own files 12 | import benchmarks as bm 13 | import polar_coordinates as pol 14 | import dynamics 15 | 16 | 17 | def create_aug_state_cartesian(x, F): 18 | aug_state = jnp.concatenate((jnp.array([x]), F)).reshape( 19 | -1 20 | ) # reshape to row vector 21 | 22 | return aug_state 23 | 24 | 25 | class StochasticReachtube: 26 | def __init__( 27 | self, 28 | model=bm.CartpoleCTRNN(None), 29 | time_horizon=10.0, # time_horizon until which the reachtube should be constructed 30 | profile=False, 31 | time_step=0.1, # ReachTube construction 32 | h_metric=0.05, # time_step for metric computation 33 | h_traces=0.01, # time_step for traces computation 34 | max_step_metric=0.00125, # maximum time_step for metric computation 35 | max_step_optim=0.1, # maximum time_step for optimization 36 | samples=100, # just for plotting: number of random points on the border of the initial ball 37 | batch=1, # number of initial points for vectorization 38 | num_gpus=1, # number of GPUs 
for parallel computation 39 | fixed_seed=False, # specify whether a fixed seed should be used (only for comparing different algorithms) 40 | axis1=0, # axis to project reachtube to 41 | axis2=1, 42 | atol=1e-10, # absolute tolerance of integration 43 | rtol=1e-10, # relative tolerance of integration 44 | plot_grid=50, 45 | mu=1.5, 46 | gamma=0.01, 47 | radius=False, 48 | ): 49 | 50 | self.time_step = min(time_step, time_horizon) 51 | self.profile = profile 52 | self.h_metric = min(h_metric, time_step) 53 | self.h_traces = h_traces 54 | self.max_step_metric = min(max_step_metric, self.h_metric) 55 | self.max_step_optim = min(max_step_optim, self.time_step) 56 | self.time_horizon = time_horizon 57 | self.samples = samples 58 | self.batch = batch 59 | self.num_gpus = num_gpus 60 | self.fixed_seed = fixed_seed 61 | self.axis1 = axis1 62 | self.axis2 = axis2 63 | self.atol = atol 64 | self.rtol = rtol 65 | self.plotGrid = plot_grid 66 | self.mu = mu 67 | self.gamma = gamma 68 | 69 | self.model = model 70 | self.init_model() 71 | 72 | self.metric = dynamics.FunctionDynamics(model).metric 73 | self.init_metric() 74 | 75 | self.f_jac_at = dynamics.FunctionDynamics(model).f_jac_at 76 | 77 | def init_metric(self): 78 | self.M1 = np.eye(self.model.dim) 79 | self.A1 = np.eye(self.model.dim) 80 | self.A1inv = np.eye(self.model.dim) 81 | self.A0inv = np.eye(self.model.dim) 82 | 83 | def init_model(self): 84 | self.cur_time = 0 85 | self.cur_cx = self.model.cx 86 | self.cur_rad = self.model.rad 87 | self.t0 = 0 88 | self.cx_t0 = self.model.cx 89 | self.rad_t0 = self.model.rad 90 | 91 | def compute_volume(self, semiAxes_product=None): 92 | if semiAxes_product is None: 93 | semiAxes_product = 1 94 | volC = gamma(self.model.dim / 2.0 + 1) ** -1 * jnp.pi ** ( 95 | self.model.dim / 2.0 96 | ) # volume constant for ellipse and ball 97 | return volC * self.cur_rad ** self.model.dim * semiAxes_product 98 | 99 | def plot_traces(self, axis_3d): 100 | rd_polar = 
pol.init_random_phi(self.model.dim, self.samples) 101 | # reshape to get samples as first index and remove gpu dimension 102 | rd_polar = jnp.reshape(rd_polar, (-1, rd_polar.shape[2])) 103 | rd_x = ( 104 | vmap(pol.polar2cart, in_axes=(None, 0))(self.model.rad, rd_polar) 105 | + self.model.cx 106 | ) 107 | plot_timerange = jnp.arange(0, self.time_horizon + 1e-9, self.h_traces) 108 | 109 | sol = odeint( 110 | self.fdyn_jax_no_pmap, 111 | rd_x, 112 | plot_timerange, 113 | atol=self.atol, 114 | rtol=self.rtol, 115 | ) 116 | 117 | for s in range(self.samples): 118 | axis_3d.plot( 119 | xs=sol[:, s, self.axis1], 120 | ys=sol[:, s, self.axis2], 121 | zs=plot_timerange, 122 | color="k", 123 | linewidth=1, 124 | ) 125 | 126 | p_dict = { 127 | "xs": np.array(sol[:, s, self.axis1]), 128 | "ys": np.array(sol[:, s, self.axis2]), 129 | "zs": np.array(plot_timerange), 130 | } 131 | return p_dict 132 | 133 | def propagate_center_point(self, time_range): 134 | cx_jax = self.model.cx.reshape(1, self.model.dim) 135 | F = jnp.eye(self.model.dim) 136 | # put aug_state in CPU, as it is faster for odeint than GPU 137 | aug_state = device_put(jnp.concatenate((cx_jax, F)).reshape(1, -1), device=devices("cpu")[0]) 138 | sol = odeint( 139 | self.aug_fdyn_jax_no_pmap, 140 | aug_state, 141 | time_range, 142 | atol=self.atol, 143 | rtol=self.rtol, 144 | ) 145 | cx, F = vmap(self.reshape_aug_state_to_matrix)(sol) 146 | return cx, F 147 | 148 | def compute_metric_and_center(self, time_range, ellipsoids): 149 | print(f"Propagating center point for {time_range.shape[0]-1} timesteps") 150 | cx_timeRange, F_timeRange = self.propagate_center_point(time_range) 151 | A1_timeRange = np.eye(self.model.dim).reshape(1, self.model.dim, self.model.dim) 152 | M1_timeRange = np.eye(self.model.dim).reshape(1, self.model.dim, self.model.dim) 153 | semiAxes_prod_timeRange = np.array([1]) 154 | 155 | print("Starting loop for creating metric") 156 | for idx, t in enumerate(time_range[1:]): 157 | M1_t, A1_t, 
semiAxes_prod_t = self.metric( 158 | F_timeRange[idx + 1, :, :], ellipsoids 159 | ) 160 | A1_timeRange = np.concatenate( 161 | (A1_timeRange, A1_t.reshape(1, self.model.dim, self.model.dim)), axis=0 162 | ) 163 | M1_timeRange = np.concatenate( 164 | (M1_timeRange, M1_t.reshape(1, self.model.dim, self.model.dim)), axis=0 165 | ) 166 | semiAxes_prod_timeRange = np.append( 167 | semiAxes_prod_timeRange, semiAxes_prod_t 168 | ) 169 | 170 | return cx_timeRange, A1_timeRange, M1_timeRange, semiAxes_prod_timeRange 171 | 172 | def reshape_aug_state_to_matrix(self, aug_state): 173 | aug_state = aug_state.reshape(-1, self.model.dim) # reshape to matrix 174 | x = aug_state[:1][0] 175 | F = aug_state[1:] 176 | return x, F 177 | 178 | def reshape_aug_fdyn_return_to_vector(self, fdyn_return, F_return): 179 | return jnp.concatenate((jnp.array([fdyn_return]), F_return)).reshape(-1) 180 | 181 | @partial(jit, static_argnums=(0,)) 182 | def aug_fdyn(self, t=0, aug_state=0): 183 | x, F = self.reshape_aug_state_to_matrix(aug_state) 184 | fdyn_return = self.model.fdyn(t, x) 185 | F_return = jnp.matmul(self.f_jac_at(t, x), F) 186 | return self.reshape_aug_fdyn_return_to_vector(fdyn_return, F_return) 187 | 188 | def aug_fdyn_jax_no_pmap(self, aug_state=0, t=0): 189 | return vmap(self.aug_fdyn, in_axes=(None, 0))(t, aug_state) 190 | 191 | def fdyn_jax_no_pmap(self, x=0, t=0): 192 | return vmap(self.model.fdyn, in_axes=(None, 0))(t, x) 193 | 194 | def aug_fdyn_jax(self, aug_state=0, t=0): 195 | return pmap(vmap(self.aug_fdyn, in_axes=(None, 0)), in_axes=(None, 0))(t, aug_state) 196 | 197 | def fdyn_jax(self, x=0, t=0): 198 | return pmap(vmap(self.model.fdyn, in_axes=(None, 0)), in_axes=(None, 0))(t, x) 199 | 200 | def create_aug_state(self, polar, rad_t0, cx_t0): 201 | x = jnp.array( 202 | pol.polar2cart_euclidean_metric(rad_t0, polar, self.A0inv) + cx_t0 203 | ) 204 | F = jnp.eye(self.model.dim) 205 | 206 | aug_state = jnp.concatenate((jnp.array([x]), F)).reshape( 207 | -1 208 | ) # 
reshape to row vector 209 | 210 | return aug_state, x 211 | 212 | def one_step_aug_integrator(self, x, F): 213 | aug_state = pmap(vmap(create_aug_state_cartesian))(x, F) 214 | sol = odeint( 215 | self.aug_fdyn_jax, 216 | aug_state, 217 | jnp.array([0, self.time_step]), 218 | atol=self.atol, 219 | rtol=self.rtol, 220 | ) 221 | x, F = pmap(vmap(self.reshape_aug_state_to_matrix))(sol[-1]) 222 | return x, F 223 | 224 | def aug_integrator(self, polar, step=None): 225 | if step is None: 226 | step = self.cur_time 227 | 228 | rad_t0 = self.rad_t0 229 | cx_t0 = self.cx_t0 230 | 231 | aug_state, initial_x = pmap(vmap(self.create_aug_state, in_axes=(0, None, None)), in_axes=(0, None, None))( 232 | polar, rad_t0, cx_t0 233 | ) 234 | sol = odeint( 235 | self.aug_fdyn_jax, 236 | aug_state, 237 | jnp.array([0, step]), 238 | atol=self.atol, 239 | rtol=self.rtol, 240 | ) 241 | x, F = pmap(vmap(self.reshape_aug_state_to_matrix))(sol[-1]) 242 | return x, F, initial_x 243 | 244 | def aug_integrator_neg_dist(self, polar): 245 | x, F, initial_x = self.aug_integrator(polar) 246 | neg_dist = pmap(vmap(self.neg_dist_x))(x) 247 | return x, F, neg_dist, initial_x 248 | 249 | def one_step_aug_integrator_dist(self, x, F): 250 | x, F = self.one_step_aug_integrator(x, F) 251 | neg_dist = pmap(vmap(self.neg_dist_x))(x) 252 | return x, F, -neg_dist 253 | 254 | def neg_dist_x(self, xt): 255 | dist = jnp.linalg.norm(jnp.matmul(self.A1, xt - self.cur_cx)) 256 | return -dist -------------------------------------------------------------------------------- /benchmarks.py: -------------------------------------------------------------------------------- 1 | # different classes with benchmarks 2 | 3 | import jax.numpy as np 4 | from jax.numpy import tanh 5 | from jax.numpy import sin 6 | from jax.numpy import cos 7 | from jax.numpy import exp 8 | 9 | 10 | def get_model(benchmark, radius=None): 11 | if benchmark == "bruss": 12 | return Brusselator(radius) # Benchmark to run 13 | elif benchmark == "vdp": 14 
| return VanDerPol(radius) # Benchmark to run 15 | elif benchmark == "robot": 16 | return Robotarm(radius) # Benchmark to run 17 | elif benchmark == "dubins": 18 | return DubinsCar(radius) # Benchmark to run 19 | elif benchmark == "ms": 20 | return MitchellSchaeffer(radius) # Benchmark to run 21 | elif benchmark == "cartpole": 22 | return CartpoleLinear(radius) # Benchmark to run 23 | elif benchmark == "quadcopter": 24 | return Quadcopter(radius) # Benchmark to run 25 | elif benchmark == "cartpoleCTRNN": 26 | return CartpoleCTRNN(radius) # Benchmark to run 27 | elif benchmark == "cartpoleLTC": 28 | return CartpoleLTC(radius) # Benchmark to run 29 | elif benchmark == "cartpoleLTC_RK": 30 | return CartpoleLTC_RK(radius) # Benchmark to run 31 | elif benchmark == "ldsCTRNN": 32 | return LDSwithCTRNN(radius) # Benchmark to run 33 | elif benchmark == "pendulumCTRNN": 34 | return PendulumwithCTRNN(radius) # Benchmark to run 35 | elif benchmark == "CTRNNosc": 36 | return CTRNNosc(radius) # Benchmark to run 37 | else: 38 | raise ValueError("Unknown benchmark " + benchmark) 39 | 40 | 41 | # 2-dimensional brusselator 42 | class Brusselator: 43 | def __init__(self, radius=None): 44 | # ============ adapt initial values =========== 45 | self.cx = (1, 1) 46 | if radius is not None: 47 | self.rad = radius 48 | else: 49 | self.rad = 0.01 50 | # =================================================== 51 | self.cx = np.array(self.cx, dtype=float) 52 | self.dim = self.cx.size # dimension of the system 53 | 54 | def fdyn(self, t=0, x=None): 55 | if x is None: 56 | x = np.zeros(self.dim, dtype=object) 57 | 58 | # ============ adapt input and system dynamics =========== 59 | x, y = x 60 | 61 | a = 1 62 | b = 1.5 63 | 64 | fx = a + x ** 2 * y - (b + 1) * x 65 | fy = b * x - x ** 2 * y 66 | 67 | system_dynamics = [fx, fy] # has to be in the same order as the input variables 68 | # =================================================== 69 | 70 | return np.array(system_dynamics) # return as numpy 
array 71 | 72 | 73 | # 2-dimensional van der pol 74 | class VanDerPol: 75 | def __init__(self, radius): 76 | # ============ adapt initial values =========== 77 | self.cx = (-1, -1) 78 | if radius is not None: 79 | self.rad = radius 80 | else: 81 | self.rad = 0.1 82 | # =================================================== 83 | self.cx = np.array(self.cx, dtype=float) 84 | self.dim = self.cx.size # dimension of the system 85 | 86 | def fdyn(self, t=0, x=None): 87 | if x is None: 88 | x = np.zeros(self.dim, dtype=object) 89 | 90 | # ============ adapt input and system dynamics =========== 91 | x, y = x 92 | 93 | fx = y 94 | fy = (x * x - 1) * y - x 95 | 96 | system_dynamics = [fx, fy] # has to be in the same order as the input variables 97 | # =================================================== 98 | 99 | return np.array(system_dynamics) # return as numpy array 100 | 101 | 102 | # 4-dimensional Robotarm 103 | class Robotarm: 104 | def __init__(self, radius): 105 | # ============ adapt initial values =========== 106 | self.cx = (1.505, 1.505, 0.005, 0.005) 107 | if radius is not None: 108 | self.rad = radius 109 | else: 110 | self.rad = 0.05 111 | # =================================================== 112 | self.cx = np.array(self.cx, dtype=float) 113 | self.dim = self.cx.size # dimension of the system 114 | 115 | def fdyn(self, t=0, x=None): 116 | if x is None: 117 | x = np.zeros(self.dim, dtype=object) 118 | 119 | # ============ adapt input and system dynamics =========== 120 | x1, x2, x3, x4 = x 121 | 122 | m = 1 123 | l = 3 124 | kp1 = 2 125 | kp2 = 1 126 | kd1 = 2 127 | kd2 = 1 128 | 129 | fx1 = x3 130 | fx2 = x4 131 | fx3 = (-2 * m * x2 * x3 * x4 - kp1 * x1 - kd1 * x3) / (m * x2 * x2 + l / 3) + ( 132 | kp1 * kp1 133 | ) / (m * x2 * x2 + l / 3) 134 | fx4 = x2 * x3 * x3 - kp2 * x2 / m - kd2 * x4 / m + kp2 * kp2 / m 135 | 136 | system_dynamics = [ 137 | fx1, 138 | fx2, 139 | fx3, 140 | fx4, 141 | ] # has to be in the same order as the input variables 142 | # 
=================================================== 143 | 144 | return np.array(system_dynamics) # return as numpy array 145 | 146 | 147 | # 3-dimensional dubins car 148 | class DubinsCar: 149 | def __init__(self, radius): 150 | # ============ adapt initial values =========== 151 | self.cx = (0, 0, 0.7854, 0) 152 | if radius is not None: 153 | self.rad = radius 154 | else: 155 | self.rad = 0.01 156 | # =================================================== 157 | self.cx = np.array(self.cx, dtype=float) 158 | self.dim = self.cx.size # dimension of the system 159 | 160 | def fdyn(self, t=0, x=None): 161 | if x is None: 162 | x = np.zeros(self.dim, dtype=object) 163 | 164 | # ============ adapt input and system dynamics =========== 165 | x, y, th, tt = x 166 | 167 | v = 1 168 | 169 | fx = v * cos(th) 170 | fy = v * sin(th) 171 | fth = x * sin(tt) 172 | ftt = np.array(1) # this is needed for lax_numpy to see this as an array 173 | 174 | system_dynamics = [ 175 | fx, 176 | fy, 177 | fth, 178 | ftt, 179 | ] # has to be in the same order as the input variables 180 | # =================================================== 181 | 182 | return np.array(system_dynamics) # return as numpy array 183 | 184 | 185 | # 2-dimensional Mitchell Schaeffer cardiac-cell 186 | class MitchellSchaeffer: 187 | def __init__(self, radius): 188 | # ============ adapt initial values =========== 189 | self.cx = (0.8, 0.5) 190 | if radius is not None: 191 | self.rad = radius 192 | else: 193 | self.rad = 0.1 194 | # =================================================== 195 | self.cx = np.array(self.cx, dtype=float) 196 | self.dim = self.cx.size # dimension of the system 197 | 198 | def fdyn(self, t=0, x=None): 199 | if x is None: 200 | x = np.zeros(self.dim, dtype=object) 201 | 202 | # ============ adapt input and system dynamics =========== 203 | x1, x2 = x 204 | 205 | sig_x1 = 0.5 * (1 + tanh(50 * x1 - 5)) 206 | 207 | fx1 = x2 * x1 ** 2 * (1 - x1) / 0.3 - x1 / 6 208 | fx2 = sig_x1 * (-x2 / 150) + (1 - 
sig_x1) * (1 - x2) / 20 209 | 210 | system_dynamics = [ 211 | fx1, 212 | fx2, 213 | ] # has to be in the same order as the input variables 214 | # =================================================== 215 | 216 | return np.array(system_dynamics) # return as numpy array 217 | 218 | 219 | # 4-dimensional cartpole with linear stabilizing controller 220 | class CartpoleLinear: 221 | def __init__(self, radius): 222 | # ============ adapt initial values =========== 223 | self.cx = (0, 0, 0.001, 0) # initial values 224 | if radius is not None: 225 | self.rad = radius 226 | else: 227 | self.rad = 0.05 228 | # =================================================== 229 | 230 | self.cx = np.array(self.cx, dtype=float) 231 | self.dim = self.cx.size # dimension of the system 232 | 233 | def fdyn(self, t=0, x=None): 234 | if x is None: 235 | x = np.zeros(self.dim, dtype=object) 236 | 237 | # ============ adapt input and system dynamics =========== 238 | dth, dx, th, x = x # input variables 239 | 240 | M = 1.0 241 | g = 9.81 242 | l = 1.0 243 | m = 0.001 244 | 245 | f = -1.1 * M * g * th - dth 246 | 247 | fdth = ( 248 | 1.0 249 | / (l * (M + m * sin(th) * sin(th))) 250 | * ( 251 | f * cos(th) 252 | - m * l * dth * dth * cos(th) * sin(th) 253 | + (m + M) * g * sin(th) 254 | ) 255 | ) 256 | fdx = ( 257 | 1.0 258 | / (M + m * sin(th) * sin(th)) 259 | * (f + m * sin(th) * (-l * dth * dth + g * cos(th))) 260 | ) 261 | 262 | fx = dx 263 | 264 | fth = dth 265 | 266 | system_dynamics = [ 267 | fdth, 268 | fdx, 269 | fth, 270 | fx, 271 | ] # has to be in the same order as the input variables 272 | # =================================================== 273 | 274 | return np.array(system_dynamics) # return as numpy array 275 | 276 | 277 | # 17-dimensional Quadcopter 278 | class Quadcopter: 279 | def __init__(self, radius): 280 | # ============ adapt initial values =========== 281 | self.cx = ( 282 | -0.995, 283 | -0.995, 284 | 9.005, 285 | -0.995, 286 | -0.995, 287 | -0.995, 288 | -0.995, 289 | 
-0.995, 290 | -0.995, 291 | 0, 292 | 0, 293 | 0, 294 | 1, 295 | 0, 296 | 0, 297 | 0, 298 | 0, 299 | ) 300 | if radius is not None: 301 | self.rad = radius 302 | else: 303 | self.rad = 0.005 304 | # =================================================== 305 | self.cx = np.array(self.cx, dtype=float) 306 | self.dim = self.cx.size # dimension of the system 307 | 308 | def fdyn(self, t=0, x=None): 309 | if x is None: 310 | x = np.zeros(self.dim, dtype=object) 311 | 312 | # ============ adapt input and system dynamics =========== 313 | pn, pe, h, u, v, w, p, q, r, q0, q1, q2, q3, pI, qI, rI, hI = x 314 | 315 | pr = 0 316 | qr = 0 317 | rr = 0 318 | hr = 0 319 | 320 | pn_ = ( 321 | 2 * u * (q0 * q0 + q1 * q1 - 0.5) 322 | - 2 * v * (q0 * q3 - q1 * q2) 323 | + 2 * w * (q0 * q2 + q1 * q3) 324 | ) 325 | pe_ = ( 326 | 2 * v * (q0 * q0 + q2 * q2 - 0.5) 327 | + 2 * u * (q0 * q3 + q1 * q2) 328 | - 2 * w * (q0 * q1 - q2 * q3) 329 | ) 330 | h_ = ( 331 | 2 * w * (q0 * q0 + q3 * q3 - 0.5) 332 | - 2 * u * (q0 * q2 - q1 * q3) 333 | + 2 * v * (q0 * q1 + q2 * q3) 334 | ) 335 | 336 | u_ = r * v - q * w - 11.62 * (q0 * q2 - q1 * q3) 337 | v_ = p * w - r * u + 11.62 * (q0 * q1 + q2 * q3) 338 | w_ = q * u - p * v + 11.62 * (q0 * q0 + q3 * q3 - 0.5) 339 | 340 | q0_ = -0.5 * q1 * p - 0.5 * q2 * q - 0.5 * q3 * r 341 | q1_ = 0.5 * q0 * p - 0.5 * q3 * q + 0.5 * q2 * r 342 | q2_ = 0.5 * q3 * p + 0.5 * q0 * q - 0.5 * q1 * r 343 | q3_ = 0.5 * q1 * q - 0.5 * q2 * p + 0.5 * q0 * r 344 | 345 | p_ = ( 346 | -40.00063258437631 * pI - 2.8283979829540325 * p 347 | ) - 1.133407423682400 * q * r 348 | q_ = ( 349 | -39.99980452524146 * qI - 2.8283752541008109 * q 350 | ) + 1.132078179613602 * p * r 351 | r_ = ( 352 | -39.99978909742505 * rI - 2.8284134223281210 * r 353 | ) - 0.004695219977601 * p * q 354 | 355 | pI_ = p - pr 356 | qI_ = q - qr 357 | rI_ = r - rr 358 | hI_ = h - hr 359 | 360 | system_dynamics = [ 361 | pn_, 362 | pe_, 363 | h_, 364 | u_, 365 | v_, 366 | w_, 367 | p_, 368 | q_, 369 | r_, 370 | 
q0_, 371 | q1_, 372 | q2_, 373 | q3_, 374 | pI_, 375 | qI_, 376 | rI_, 377 | hI_, 378 | ] # has to be in the same order as the input variables 379 | # =================================================== 380 | 381 | return np.array(system_dynamics) # return as numpy array 382 | 383 | 384 | # CTRNN DampedForced Pendulum example 385 | # from https://easychair.org/publications/open/K6SZ 386 | class CTRNN_DampedForcedPendulum: 387 | def __init__(self, radius): 388 | # ============ adapt initial values =========== 389 | self.cx = (0.21535, -0.58587, 0.8, 0.52323, 0.5) # initial values 390 | if radius is not None: 391 | self.rad = radius 392 | else: 393 | self.rad = 1e-08 # initial radius 394 | # =================================================== 395 | 396 | self.cx = np.array(self.cx, dtype=float) 397 | self.dim = self.cx.size # dimension of the system 398 | 399 | def fdyn(self, t=0, x=None): 400 | if x is None: 401 | x = np.zeros(self.dim, dtype=object) 402 | 403 | # ============ adapt input and system dynamics =========== 404 | x1, x2, x3, x4, x5 = x 405 | 406 | x1p = ( 407 | 5419046323626097 408 | * (exp(-2 * x3) - 1) 409 | / (4503599627370496 * (exp(-2 * x3) + 1)) 410 | - x1 / 1000000 411 | + 3601 * (exp(-2 * x4) - 1) / (50000 * (exp(-2 * x4) + 1)) 412 | + 18727 * (exp(-2 * x5) - 1) / (20000 * (exp(-2 * x5) + 1)) 413 | ) 414 | x2p = ( 415 | 30003 * (exp(-2 * x4) - 1) / (20000 * (exp(-2 * x4) + 1)) 416 | - 11881 * (exp(-2 * x3) - 1) / (10000 * (exp(-2 * x3) + 1)) 417 | - x2 / 1000000 418 | - 93519 * (exp(-2 * x5) - 1) / (100000 * (exp(-2 * x5) + 1)) 419 | ) 420 | x3p = ( 421 | 7144123746377831 422 | * (exp(-2 * x3) - 1) 423 | / (4503599627370496 * (exp(-2 * x3) + 1)) 424 | - x3 / 1000000 425 | - 5048886837752751 426 | * (exp(-2 * x4) - 1) 427 | / (72057594037927936 * (exp(-2 * x4) + 1)) 428 | + 5564385670244745 429 | * (exp(-2 * x5) - 1) 430 | / (4503599627370496 * (exp(-2 * x5) + 1)) 431 | ) 432 | x4p = ( 433 | 1348796766312415 434 | * (exp(-2 * x4) - 1) 435 | / 
(4503599627370496 * (exp(-2 * x4) + 1)) 436 | - 3086507593514335 437 | * (exp(-2 * x3) - 1) 438 | / (36028797018963968 * (exp(-2 * x3) + 1)) 439 | - x4 / 1000000 440 | - 2476184452153819 441 | * (exp(-2 * x5) - 1) 442 | / (36028797018963968 * (exp(-2 * x5) + 1)) 443 | ) 444 | x5p = ( 445 | 1523758031023695 446 | * (exp(-2 * x4) - 1) 447 | / (18014398509481984 * (exp(-2 * x4) + 1)) 448 | - 8060407538855891 449 | * (exp(-2 * x3) - 1) 450 | / (4503599627370496 * (exp(-2 * x3) + 1)) 451 | - x5 / 1000000 452 | - 3139112893264555 453 | * (exp(-2 * x5) - 1) 454 | / (2251799813685248 * (exp(-2 * x5) + 1)) 455 | ) 456 | 457 | system_dynamics = [x1p, x2p, x3p, x4p, x5p] 458 | # =================================================== 459 | 460 | return np.array(system_dynamics) # return as numpy array 461 | 462 | 463 | # 12-dimensional cartpole with CT-RNN neural network controller 464 | class CartpoleCTRNN: 465 | def __init__(self, radius): 466 | # ============ adapt initial values =========== 467 | self.cx = (0, 0, 0.001, 0, 0, 0, 0, 0, 0, 0, 0, 0) 468 | if radius is not None: 469 | self.rad = radius 470 | else: 471 | self.rad = 1e-4 472 | # =================================================== 473 | self.cx = np.array(self.cx, dtype=float) 474 | self.dim = self.cx.size # dimension of the system 475 | 476 | def fdyn(self, t=0, x=None): 477 | if x is None: 478 | x = np.zeros(self.dim, dtype=object) 479 | 480 | # ============ adapt input and system dynamics =========== 481 | x_00, x_01, x_02, x_03, h_00, h_01, h_02, h_03, h_04, h_05, h_06, h_07 = x 482 | 483 | h_00_prime = -h_00 + tanh( 484 | h_00 * -1.49394 485 | + h_01 * -0.61947 486 | + h_02 * 0.37393 487 | + h_03 * -0.63451 488 | + h_04 * -1.08420 489 | + h_05 * 2.57981 490 | + h_06 * -1.53850 491 | + h_07 * -1.64354 492 | + x_00 * -0.07445 493 | + x_01 * 0.08736 494 | + x_02 * 0.47684 495 | + x_03 * 0.56397 496 | + 1.81867 497 | ) 498 | h_01_prime = -h_01 + tanh( 499 | h_00 * 1.38295 500 | + h_01 * -1.45811 501 | + h_02 * 
1.01473 502 | + h_03 * -0.04578 503 | + h_04 * -0.13416 504 | + h_05 * -0.21970 505 | + h_06 * 0.41791 506 | + h_07 * -0.10833 507 | + x_00 * -1.70409 508 | + x_01 * 0.51560 509 | + x_02 * -0.71273 510 | + x_03 * -0.61720 511 | + 0.86009 512 | ) 513 | h_02_prime = -h_02 + tanh( 514 | h_00 * 0.57055 515 | + h_01 * 0.05941 516 | + h_02 * -0.16993 517 | + h_03 * -0.69688 518 | + h_04 * -0.30939 519 | + h_05 * -1.31558 520 | + h_06 * 0.03316 521 | + h_07 * 2.35873 522 | + x_00 * 0.21849 523 | + x_01 * 1.42990 524 | + x_02 * 1.60666 525 | + x_03 * -0.66847 526 | + -1.56112 527 | ) 528 | h_03_prime = -h_03 + tanh( 529 | h_00 * -1.53087 530 | + h_01 * -0.42779 531 | + h_02 * -0.02195 532 | + h_03 * 0.18007 533 | + h_04 * 1.54262 534 | + h_05 * -0.19275 535 | + h_06 * -0.64598 536 | + h_07 * -0.85840 537 | + x_00 * 0.10095 538 | + x_01 * 0.55012 539 | + x_02 * 1.51416 540 | + x_03 * 1.95952 541 | + 0.00020 542 | ) 543 | h_04_prime = -h_04 + tanh( 544 | h_00 * -0.54493 545 | + h_01 * 0.41139 546 | + h_02 * -0.18790 547 | + h_03 * -0.14312 548 | + h_04 * -0.10144 549 | + h_05 * 2.38792 550 | + h_06 * -0.44134 551 | + h_07 * -2.06462 552 | + x_00 * -0.01260 553 | + x_01 * 0.84681 554 | + x_02 * 0.51889 555 | + x_03 * 0.35089 556 | + -0.06552 557 | ) 558 | h_05_prime = -h_05 + tanh( 559 | h_00 * 1.08793 560 | + h_01 * -0.24559 561 | + h_02 * -0.97041 562 | + h_03 * -0.16874 563 | + h_04 * -0.27133 564 | + h_05 * 0.73513 565 | + h_06 * 0.73096 566 | + h_07 * -2.26244 567 | + x_00 * 0.80312 568 | + x_01 * -1.35635 569 | + x_02 * -0.95283 570 | + x_03 * -0.68460 571 | + 0.11646 572 | ) 573 | h_06_prime = -h_06 + tanh( 574 | h_00 * 0.08745 575 | + h_01 * 0.10104 576 | + h_02 * 0.82251 577 | + h_03 * -0.59823 578 | + h_04 * 1.16949 579 | + h_05 * 1.73511 580 | + h_06 * -0.89387 581 | + h_07 * 0.77149 582 | + x_00 * 0.22731 583 | + x_01 * -0.73872 584 | + x_02 * 2.83509 585 | + x_03 * -2.49535 586 | + 0.70343 587 | ) 588 | h_07_prime = -h_07 + tanh( 589 | h_00 * -0.89045 590 | + 
h_01 * -1.30985 591 | + h_02 * 0.48740 592 | + h_03 * -0.45750 593 | + h_04 * -0.70727 594 | + h_05 * -0.43216 595 | + h_06 * -0.29159 596 | + h_07 * 3.70331 597 | + x_00 * -0.13294 598 | + x_01 * -0.23242 599 | + x_02 * 0.33244 600 | + x_03 * -1.45838 601 | + -0.02392 602 | ) 603 | 604 | action_00 = tanh( 605 | h_00 * -0.32011 606 | + h_01 * 0.32958 607 | + h_02 * 0.07718 608 | + h_03 * 2.39431 609 | + h_04 * -1.37732 610 | + h_05 * 0.92722 611 | + h_06 * -1.07137 612 | + h_07 * -1.47286 613 | ) 614 | 615 | gravity = 9.8 616 | masscart = 1 617 | masspole = 0.1 618 | total_mass = masscart + masspole 619 | length = 0.5 620 | polemass_length = masspole * length 621 | force_mag = 10 622 | 623 | force = force_mag * action_00 624 | costheta = cos(x_02) 625 | sintheta = sin(x_02) 626 | x_dot = x_01 627 | theta_dot = x_03 628 | 629 | temp = (force + polemass_length * theta_dot * theta_dot * sintheta) / total_mass 630 | thetaacc = (gravity * sintheta - costheta * temp) / ( 631 | length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass) 632 | ) 633 | xacc = temp - polemass_length * thetaacc * costheta / total_mass 634 | 635 | x_00_prime = x_dot 636 | x_01_prime = xacc 637 | x_02_prime = theta_dot 638 | x_03_prime = thetaacc 639 | 640 | system_dynamics = [ 641 | x_00_prime, 642 | x_01_prime, 643 | x_02_prime, 644 | x_03_prime, 645 | h_00_prime, 646 | h_01_prime, 647 | h_02_prime, 648 | h_03_prime, 649 | h_04_prime, 650 | h_05_prime, 651 | h_06_prime, 652 | h_07_prime, 653 | ] # has to be in the same order as the input variables 654 | # =================================================== 655 | 656 | return np.array(system_dynamics) # return as numpy array 657 | 658 | 659 | # 12-dimensional cartpole with LTC neural network controller 660 | class CartpoleLTC: 661 | def __init__(self, radius): 662 | # ============ adapt initial values =========== 663 | self.cx = (0, 0, 0.001, 0, 0, 0, 0, 0, 0, 0, 0, 0) 664 | if radius is not None: 665 | self.rad = radius 666 | else: 667 
| self.rad = 1e-4 668 | # =================================================== 669 | self.cx = np.array(self.cx, dtype=float) 670 | self.dim = self.cx.size # dimension of the system 671 | 672 | def true_sigmoid(self, x): 673 | return 0.5 * (tanh(x * 0.5) + 1) 674 | 675 | def sigmoid(self, v_pre, mu, sigma): 676 | mues = v_pre - mu 677 | x = sigma * mues 678 | return self.true_sigmoid(x) 679 | 680 | def fdyn(self, t=0, x=None): 681 | if x is None: 682 | x = np.zeros(self.dim, dtype=object) 683 | 684 | # ============ adapt input and system dynamics =========== 685 | x_00, x_01, x_02, x_03, h_00, h_01, h_02, h_03, h_04, h_05, h_06, h_07 = x 686 | 687 | u_00 = x_00 / 0.1 688 | u_01 = x_01 / 0.2 689 | u_02 = x_02 / 0.1 690 | u_03 = x_03 / 0.1 691 | 692 | swa_00 = ( 693 | (-3.337809 - h_00) * 0.000100 * self.sigmoid(u_00, -1.822089, 4.920430) 694 | + (-3.793303 - h_00) * 0.217915 * self.sigmoid(u_01, 0.589409, 4.212314) 695 | + (1.074355 - h_00) * 0.000100 * self.sigmoid(u_02, -1.301882, 1.045275) 696 | + (-1.837527 - h_00) * 0.784721 * self.sigmoid(u_03, 0.061974, 9.342714) 697 | ) 698 | swa_01 = ( 699 | (1.922444 - h_01) * 0.048898 * self.sigmoid(u_00, 0.069332, 4.315088) 700 | + (-2.419416 - h_01) * 0.627589 * self.sigmoid(u_01, 2.746536, 3.680750) 701 | + (3.581162 - h_01) * 0.165919 * self.sigmoid(u_02, -1.045852, 3.066419) 702 | + (1.778340 - h_01) * 0.120764 * self.sigmoid(u_03, 1.837324, 3.448219) 703 | ) 704 | swa_02 = ( 705 | (2.205145 - h_02) * 0.304337 * self.sigmoid(u_00, 0.031854, 4.853640) 706 | + (-3.098629 - h_02) * 0.084116 * self.sigmoid(u_01, 0.441696, 3.296364) 707 | + (-1.226080 - h_02) * 0.177806 * self.sigmoid(u_02, 2.182991, 1.098078) 708 | + (-2.913366 - h_02) * 0.491636 * self.sigmoid(u_03, 2.224604, 1.728117) 709 | ) 710 | swa_03 = ( 711 | (0.960911 - h_03) * 0.364263 * self.sigmoid(u_00, -1.053052, 1.971431) 712 | + (-1.638120 - h_03) * 0.162840 * self.sigmoid(u_01, 7.642996, 2.541498) 713 | + (-1.138980 - h_03) * 0.746401 * self.sigmoid(u_02, 
1.725237, 0.663583) 714 | + (1.642858 - h_03) * 0.299424 * self.sigmoid(u_03, 0.687038, -0.347229) 715 | ) 716 | swa_04 = ( 717 | (-0.752243 - h_04) * 0.649805 * self.sigmoid(u_00, 5.850876, 4.392767) 718 | + (0.135855 - h_04) * 0.018294 * self.sigmoid(u_01, 3.993045, 2.950581) 719 | + (-0.291421 - h_04) * 0.404132 * self.sigmoid(u_02, -0.929036, 3.571369) 720 | + (-3.123502 - h_04) * 0.521996 * self.sigmoid(u_03, -0.641402, 7.249665) 721 | ) 722 | swa_05 = ( 723 | (-0.472946 - h_05) * 0.359657 * self.sigmoid(u_00, 1.851295, 6.572009) 724 | + (-0.733541 - h_05) * 0.382768 * self.sigmoid(u_01, 2.120606, 3.201269) 725 | + (1.819771 - h_05) * 0.223484 * self.sigmoid(u_02, 2.934255, 4.131283) 726 | + (0.080200 - h_05) * 0.108734 * self.sigmoid(u_03, -2.828882, 1.771147) 727 | ) 728 | swa_06 = ( 729 | (2.596139 - h_06) * 0.233149 * self.sigmoid(u_00, 1.404586, 5.493755) 730 | + (-0.948662 - h_06) * 0.297568 * self.sigmoid(u_01, -5.682380, 3.978102) 731 | + (4.558042 - h_06) * 0.736431 * self.sigmoid(u_02, -0.067011, 3.915714) 732 | + (1.514289 - h_06) * 0.226179 * self.sigmoid(u_03, -1.101492, 5.705000) 733 | ) 734 | swa_07 = ( 735 | (0.907580 - h_07) * 0.077591 * self.sigmoid(u_00, 0.822831, 2.474382) 736 | + (-2.282039 - h_07) * 0.161999 * self.sigmoid(u_01, -0.241190, 6.074738) 737 | + (0.294892 - h_07) * 0.202664 * self.sigmoid(u_02, -0.289602, 8.727577) 738 | + (-2.329513 - h_07) * 0.242021 * self.sigmoid(u_03, -0.960229, 4.780694) 739 | ) 740 | wa_00 = ( 741 | (0.425183 - h_00) * 0.580120 * self.sigmoid(h_00, 1.888106, 5.698736) 742 | + (-0.026397 - h_00) * 0.422325 * self.sigmoid(h_01, 0.322996, 3.287380) 743 | + (0.232258 - h_00) * 0.234096 * self.sigmoid(h_02, 1.988594, 5.637971) 744 | + (0.325219 - h_00) * 0.822936 * self.sigmoid(h_03, 1.430856, 1.847167) 745 | + (-3.234136 - h_00) * 0.143390 * self.sigmoid(h_04, 1.693733, 1.529007) 746 | + (-0.446780 - h_00) * 0.000010 * self.sigmoid(h_05, 4.226530, 1.182250) 747 | + (-1.355023 - h_00) * 0.605004 * 
self.sigmoid(h_06, 1.083694, 2.011393) 748 | + (-3.184705 - h_00) * 0.338830 * self.sigmoid(h_07, 5.826199, 2.559444) 749 | ) 750 | wa_01 = ( 751 | (-1.902008 - h_01) * 0.151651 * self.sigmoid(h_00, 1.296603, 3.500569) 752 | + (-0.865412 - h_01) * 0.042350 * self.sigmoid(h_01, 1.769154, 10.179186) 753 | + (-3.485676 - h_01) * 0.259284 * self.sigmoid(h_02, 1.534717, 3.293725) 754 | + (-1.715340 - h_01) * 0.197946 * self.sigmoid(h_03, 0.264122, 5.787919) 755 | + (-3.960100 - h_01) * 0.346343 * self.sigmoid(h_04, 4.226974, 5.074825) 756 | + (-1.888838 - h_01) * 0.167991 * self.sigmoid(h_05, -0.192584, 4.463276) 757 | + (3.343740 - h_01) * 0.407932 * self.sigmoid(h_06, -2.019538, 2.096939) 758 | + (-3.828055 - h_01) * 0.392470 * self.sigmoid(h_07, -1.217940, 3.255837) 759 | ) 760 | wa_02 = ( 761 | (-0.149339 - h_02) * 0.274537 * self.sigmoid(h_00, -0.229752, 0.488301) 762 | + (0.411455 - h_02) * 0.026122 * self.sigmoid(h_01, 4.011376, 4.030908) 763 | + (3.896320 - h_02) * 0.157079 * self.sigmoid(h_02, 2.459823, 3.892246) 764 | + (0.530434 - h_02) * 0.059498 * self.sigmoid(h_03, 1.132183, 1.906741) 765 | + (-0.194996 - h_02) * 0.087192 * self.sigmoid(h_04, 0.000571, 0.949683) 766 | + (-0.407292 - h_02) * 0.194082 * self.sigmoid(h_05, -0.997799, 1.219973) 767 | + (-6.620962 - h_02) * 0.398726 * self.sigmoid(h_06, 0.469706, 5.065333) 768 | + (1.787329 - h_02) * 1.050713 * self.sigmoid(h_07, 4.342476, -0.298296) 769 | ) 770 | wa_03 = ( 771 | (-5.372879 - h_03) * 0.027963 * self.sigmoid(h_00, 2.814907, 7.856180) 772 | + (-0.646331 - h_03) * 0.149699 * self.sigmoid(h_01, 0.987782, 2.744861) 773 | + (0.116237 - h_03) * 0.135569 * self.sigmoid(h_02, 4.085198, 3.937947) 774 | + (-2.523393 - h_03) * 0.165372 * self.sigmoid(h_03, 0.703319, 5.054082) 775 | + (0.815988 - h_03) * 0.189418 * self.sigmoid(h_04, 3.352685, 7.961621) 776 | + (-1.706797 - h_03) * 0.502507 * self.sigmoid(h_05, -0.760482, 5.458648) 777 | + (-2.023418 - h_03) * 0.026645 * self.sigmoid(h_06, 4.557816, 
2.245746) 778 | + (-0.889180 - h_03) * 0.468236 * self.sigmoid(h_07, -2.680705, 7.208436) 779 | ) 780 | wa_04 = ( 781 | (1.343493 - h_04) * 0.603259 * self.sigmoid(h_00, -0.901840, 3.872216) 782 | + (0.831619 - h_04) * 0.426593 * self.sigmoid(h_01, -3.089995, 5.196197) 783 | + (0.421409 - h_04) * 0.068839 * self.sigmoid(h_02, 3.991761, 0.784316) 784 | + (2.355011 - h_04) * 0.018658 * self.sigmoid(h_03, -0.024262, 7.865137) 785 | + (0.593414 - h_04) * 0.252146 * self.sigmoid(h_04, -1.463303, 3.132787) 786 | + (-0.637192 - h_04) * 0.118786 * self.sigmoid(h_05, -1.024096, 4.974855) 787 | + (-0.732044 - h_04) * 0.588030 * self.sigmoid(h_06, 1.033544, 6.298468) 788 | + (-0.194411 - h_04) * 0.679828 * self.sigmoid(h_07, 0.410277, 8.185476) 789 | ) 790 | wa_05 = ( 791 | (0.495678 - h_05) * 0.338934 * self.sigmoid(h_00, 1.896829, 3.323362) 792 | + (0.680690 - h_05) * 0.558593 * self.sigmoid(h_01, -0.032828, 3.472205) 793 | + (1.564447 - h_05) * 0.416106 * self.sigmoid(h_02, -0.448733, 4.371096) 794 | + (0.896225 - h_05) * 0.102958 * self.sigmoid(h_03, -1.402512, 7.281286) 795 | + (3.866403 - h_05) * 0.388992 * self.sigmoid(h_04, -1.392615, 1.017738) 796 | + (-1.880220 - h_05) * 0.075561 * self.sigmoid(h_05, 0.562915, 6.004653) 797 | + (-5.541122 - h_05) * 0.279411 * self.sigmoid(h_06, -0.679593, 2.001722) 798 | + (-0.065618 - h_05) * 0.475569 * self.sigmoid(h_07, -1.229657, 3.177203) 799 | ) 800 | wa_06 = ( 801 | (-1.465561 - h_06) * 0.122170 * self.sigmoid(h_00, 0.706757, 2.401071) 802 | + (-2.485563 - h_06) * 0.091596 * self.sigmoid(h_01, 0.335557, 3.252133) 803 | + (1.158366 - h_06) * 0.519633 * self.sigmoid(h_02, -1.169937, 3.588379) 804 | + (1.955986 - h_06) * 0.252199 * self.sigmoid(h_03, -1.974608, 2.870511) 805 | + (0.743124 - h_06) * 0.651011 * self.sigmoid(h_04, -0.210935, 7.228260) 806 | + (3.452268 - h_06) * 0.019733 * self.sigmoid(h_05, 3.515318, 3.879946) 807 | + (0.604325 - h_06) * 0.191627 * self.sigmoid(h_06, -3.070420, 4.221974) 808 | + (2.465013 - h_06) 
* 0.143574 * self.sigmoid(h_07, -0.961958, 3.002333) 809 | ) 810 | wa_07 = ( 811 | (-2.135510 - h_07) * 0.082957 * self.sigmoid(h_00, 0.857560, 4.560907) 812 | + (-0.719410 - h_07) * 0.121560 * self.sigmoid(h_01, 2.004791, 7.674825) 813 | + (-0.095160 - h_07) * 0.085265 * self.sigmoid(h_02, 0.915076, 3.384454) 814 | + (-6.402920 - h_07) * 0.017902 * self.sigmoid(h_03, 1.900790, 3.606339) 815 | + (0.911732 - h_07) * 0.018167 * self.sigmoid(h_04, 0.001252, 7.378105) 816 | + (-0.876065 - h_07) * 0.119567 * self.sigmoid(h_05, 2.130298, 7.560087) 817 | + (-1.520981 - h_07) * 0.012488 * self.sigmoid(h_06, -0.275924, 2.646271) 818 | + (0.734387 - h_07) * 0.088953 * self.sigmoid(h_07, -1.333799, 3.340692) 819 | ) 820 | h_00_prime = 10000000.000000 * (0.663909 * (4.880308 - h_00) + swa_00 + wa_00) 821 | h_01_prime = 0.897653 * (2.666445 * (-0.506661 - h_01) + swa_01 + wa_01) 822 | h_02_prime = 0.513291 * (4.721010 * (-1.314806 - h_02) + swa_02 + wa_02) 823 | h_03_prime = 0.301197 * (0.073180 * (0.044130 - h_03) + swa_03 + wa_03) 824 | h_04_prime = 0.404649 * (1.368898 * (-1.209986 - h_04) + swa_04 + wa_04) 825 | h_05_prime = 1.131324 * (0.000010 * (-2.630356 - h_05) + swa_05 + wa_05) 826 | h_06_prime = 10000000.000000 * (0.539536 * (-0.529881 - h_06) + swa_06 + wa_06) 827 | h_07_prime = 0.525138 * (0.507019 * (0.567584 - h_07) + swa_07 + wa_07) 828 | 829 | action_00 = tanh( 830 | h_00 * -0.895448 831 | + h_01 * -0.795233 832 | + h_02 * 0.209421 833 | + h_03 * 0.143533 834 | + h_04 * -0.084270 835 | + h_05 * -0.342232 836 | + h_06 * 0.333434 837 | + h_07 * -0.032306 838 | + 0.757913 839 | ) 840 | 841 | gravity = 9.8 842 | masscart = 1 843 | masspole = 0.1 844 | total_mass = masscart + masspole 845 | length = 0.5 846 | polemass_length = masspole * length 847 | force_mag = 10 848 | 849 | force = force_mag * action_00 850 | costheta = cos(x_02) 851 | sintheta = sin(x_02) 852 | x_dot = x_01 853 | theta_dot = x_03 854 | 855 | temp = (force + polemass_length * theta_dot * 
theta_dot * sintheta) / total_mass 856 | thetaacc = (gravity * sintheta - costheta * temp) / ( 857 | length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass) 858 | ) 859 | xacc = temp - polemass_length * thetaacc * costheta / total_mass 860 | 861 | x_00_prime = x_dot 862 | x_01_prime = xacc 863 | x_02_prime = theta_dot 864 | x_03_prime = thetaacc 865 | 866 | system_dynamics = [ 867 | x_00_prime, 868 | x_01_prime, 869 | x_02_prime, 870 | x_03_prime, 871 | h_00_prime, 872 | h_01_prime, 873 | h_02_prime, 874 | h_03_prime, 875 | h_04_prime, 876 | h_05_prime, 877 | h_06_prime, 878 | h_07_prime, 879 | ] # has to be in the same order as the input variables 880 | # =================================================== 881 | 882 | return np.array(system_dynamics) # return as numpy array 883 | 884 | 885 | # 12-dimensional cartpole with LTC neural network controller (trained with RK integrator) 886 | class CartpoleLTC_RK: 887 | def __init__(self, radius): 888 | # ============ adapt initial values =========== 889 | self.cx = (0, 0, 0.001, 0, 0, 0, 0, 0, 0, 0, 0, 0) 890 | if radius is not None: 891 | self.rad = radius 892 | else: 893 | self.rad = 1e-4 894 | # =================================================== 895 | self.cx = np.array(self.cx, dtype=float) 896 | self.dim = self.cx.size # dimension of the system 897 | 898 | def true_sigmoid(self, x): 899 | return 0.5 * (tanh(x * 0.5) + 1) 900 | 901 | def sigmoid(self, v_pre, mu, sigma): 902 | mues = v_pre - mu 903 | x = sigma * mues 904 | return self.true_sigmoid(x) 905 | 906 | def fdyn(self, t=0, x=None): 907 | if x is None: 908 | x = np.zeros(self.dim, dtype=object) 909 | 910 | # ============ adapt input and system dynamics =========== 911 | x_00, x_01, x_02, x_03, h_00, h_01, h_02, h_03, h_04, h_05, h_06, h_07 = x 912 | 913 | u_00 = x_00 / 0.1 914 | u_01 = x_01 / 0.2 915 | u_02 = x_02 / 0.1 916 | u_03 = x_03 / 0.1 917 | 918 | swa_00 = ( 919 | (-0.431604 - h_00) 920 | * 0.072614 921 | * 0.5 922 | * (tanh(0.5 * (u_00 - 
(0.487739)) * 5.650091) + 1) 923 | + (-1.100827 - h_00) 924 | * 0.581228 925 | * 0.5 926 | * (tanh(0.5 * (u_01 - (1.385196)) * 3.634596) + 1) 927 | + (0.174867 - h_00) 928 | * 0.111020 929 | * 0.5 930 | * (tanh(0.5 * (u_02 - (-0.993259)) * 5.331078) + 1) 931 | + (-0.510866 - h_00) 932 | * 0.189538 933 | * 0.5 934 | * (tanh(0.5 * (u_03 - (-1.174057)) * 5.951439) + 1) 935 | ) 936 | swa_01 = ( 937 | (-0.596286 - h_01) 938 | * 0.304088 939 | * 0.5 940 | * (tanh(0.5 * (u_00 - (-0.441833)) * 3.592494) + 1) 941 | + (-1.364699 - h_01) 942 | * 0.266900 943 | * 0.5 944 | * (tanh(0.5 * (u_01 - (-0.104207)) * 4.467682) + 1) 945 | + (0.723861 - h_01) 946 | * 0.282621 947 | * 0.5 948 | * (tanh(0.5 * (u_02 - (0.735169)) * 4.735680) + 1) 949 | + (-2.003241 - h_01) 950 | * 0.162848 951 | * 0.5 952 | * (tanh(0.5 * (u_03 - (1.824110)) * 4.624085) + 1) 953 | ) 954 | swa_02 = ( 955 | (-2.213958 - h_02) 956 | * 0.050379 957 | * 0.5 958 | * (tanh(0.5 * (u_00 - (1.028516)) * 4.883958) + 1) 959 | + (1.837608 - h_02) 960 | * 0.000100 961 | * 0.5 962 | * (tanh(0.5 * (u_01 - (0.476658)) * 3.920279) + 1) 963 | + (-0.869022 - h_02) 964 | * 0.000100 965 | * 0.5 966 | * (tanh(0.5 * (u_02 - (0.146460)) * 4.128851) + 1) 967 | + (0.970412 - h_02) 968 | * 0.142598 969 | * 0.5 970 | * (tanh(0.5 * (u_03 - (-0.057289)) * 4.129013) + 1) 971 | ) 972 | swa_03 = ( 973 | (3.316434 - h_03) 974 | * 0.000100 975 | * 0.5 976 | * (tanh(0.5 * (u_00 - (-0.281031)) * 4.200078) + 1) 977 | + (-0.286490 - h_03) 978 | * 0.136481 979 | * 0.5 980 | * (tanh(0.5 * (u_01 - (2.742276)) * 3.584777) + 1) 981 | + (-0.305795 - h_03) 982 | * 0.186325 983 | * 0.5 984 | * (tanh(0.5 * (u_02 - (1.382006)) * 4.178387) + 1) 985 | + (2.739162 - h_03) 986 | * 0.059435 987 | * 0.5 988 | * (tanh(0.5 * (u_03 - (-0.377228)) * 3.719664) + 1) 989 | ) 990 | swa_04 = ( 991 | (-1.856016 - h_04) 992 | * 0.031760 993 | * 0.5 994 | * (tanh(0.5 * (u_00 - (-0.689140)) * 3.121773) + 1) 995 | + (-1.060924 - h_04) 996 | * 0.233160 997 | * 0.5 998 | * 
(tanh(0.5 * (u_01 - (0.130687)) * 2.643270) + 1) 999 | + (0.459155 - h_04) 1000 | * 0.143452 1001 | * 0.5 1002 | * (tanh(0.5 * (u_02 - (0.574488)) * 1.886950) + 1) 1003 | + (-0.358857 - h_04) 1004 | * 0.100690 1005 | * 0.5 1006 | * (tanh(0.5 * (u_03 - (1.888403)) * 4.933154) + 1) 1007 | ) 1008 | swa_05 = ( 1009 | (1.980404 - h_05) 1010 | * 0.323992 1011 | * 0.5 1012 | * (tanh(0.5 * (u_00 - (1.983044)) * 4.761758) + 1) 1013 | + (-1.479298 - h_05) 1014 | * 0.058762 1015 | * 0.5 1016 | * (tanh(0.5 * (u_01 - (0.921888)) * 6.088301) + 1) 1017 | + (1.952959 - h_05) 1018 | * 0.088192 1019 | * 0.5 1020 | * (tanh(0.5 * (u_02 - (0.347450)) * 4.452416) + 1) 1021 | + (-1.889391 - h_05) 1022 | * 0.035840 1023 | * 0.5 1024 | * (tanh(0.5 * (u_03 - (-1.054210)) * 5.084546) + 1) 1025 | ) 1026 | swa_06 = ( 1027 | (-0.891564 - h_06) 1028 | * 0.238208 1029 | * 0.5 1030 | * (tanh(0.5 * (u_00 - (0.668283)) * 4.712173) + 1) 1031 | + (-0.434193 - h_06) 1032 | * 0.102071 1033 | * 0.5 1034 | * (tanh(0.5 * (u_01 - (-0.555892)) * 3.985909) + 1) 1035 | + (-1.002575 - h_06) 1036 | * 0.211428 1037 | * 0.5 1038 | * (tanh(0.5 * (u_02 - (1.759687)) * 5.538335) + 1) 1039 | + (0.961403 - h_06) 1040 | * 0.000100 1041 | * 0.5 1042 | * (tanh(0.5 * (u_03 - (-0.823010)) * 5.778478) + 1) 1043 | ) 1044 | swa_07 = ( 1045 | (-2.280897 - h_07) 1046 | * 0.055498 1047 | * 0.5 1048 | * (tanh(0.5 * (u_00 - (0.378455)) * 3.488947) + 1) 1049 | + (3.006480 - h_07) 1050 | * 0.026770 1051 | * 0.5 1052 | * (tanh(0.5 * (u_01 - (-0.096777)) * 2.703730) + 1) 1053 | + (0.428589 - h_07) 1054 | * 0.131216 1055 | * 0.5 1056 | * (tanh(0.5 * (u_02 - (-0.058215)) * 4.788692) + 1) 1057 | + (-0.848094 - h_07) 1058 | * 0.156781 1059 | * 0.5 1060 | * (tanh(0.5 * (u_03 - (0.870154)) * 5.505928) + 1) 1061 | ) 1062 | wa_00 = ( 1063 | (-2.061894 - h_00) 1064 | * 0.255103 1065 | * 0.5 1066 | * (tanh(0.5 * (h_00 - (-0.599924)) * 4.079573) + 1) 1067 | + (-2.919805 - h_00) 1068 | * 0.111550 1069 | * 0.5 1070 | * (tanh(0.5 * (h_01 - 
(-0.003773)) * 3.577720) + 1) 1071 | + (-1.169591 - h_00) 1072 | * 0.076621 1073 | * 0.5 1074 | * (tanh(0.5 * (h_02 - (0.268934)) * 3.482080) + 1) 1075 | + (2.610878 - h_00) 1076 | * 0.358292 1077 | * 0.5 1078 | * (tanh(0.5 * (h_03 - (-0.801209)) * 1.404748) + 1) 1079 | + (0.464382 - h_00) 1080 | * 0.396781 1081 | * 0.5 1082 | * (tanh(0.5 * (h_04 - (-0.353826)) * 6.078532) + 1) 1083 | + (-1.108612 - h_00) 1084 | * 0.039682 1085 | * 0.5 1086 | * (tanh(0.5 * (h_05 - (-0.083197)) * 4.031213) + 1) 1087 | + (-1.782495 - h_00) 1088 | * 0.192295 1089 | * 0.5 1090 | * (tanh(0.5 * (h_06 - (-0.135476)) * 3.405039) + 1) 1091 | + (-0.745238 - h_00) 1092 | * 0.224384 1093 | * 0.5 1094 | * (tanh(0.5 * (h_07 - (0.838400)) * 3.514007) + 1) 1095 | ) 1096 | wa_01 = ( 1097 | (0.323076 - h_01) 1098 | * 0.132399 1099 | * 0.5 1100 | * (tanh(0.5 * (h_00 - (1.430406)) * 3.536494) + 1) 1101 | + (-1.075827 - h_01) 1102 | * 0.183413 1103 | * 0.5 1104 | * (tanh(0.5 * (h_01 - (0.444819)) * 3.535979) + 1) 1105 | + (0.727482 - h_01) 1106 | * 0.179833 1107 | * 0.5 1108 | * (tanh(0.5 * (h_02 - (1.784908)) * 4.561027) + 1) 1109 | + (0.138851 - h_01) 1110 | * 0.154461 1111 | * 0.5 1112 | * (tanh(0.5 * (h_03 - (1.996656)) * 5.773379) + 1) 1113 | + (-0.711774 - h_01) 1114 | * 0.143956 1115 | * 0.5 1116 | * (tanh(0.5 * (h_04 - (-1.484266)) * 2.894240) + 1) 1117 | + (-0.365611 - h_01) 1118 | * 0.268626 1119 | * 0.5 1120 | * (tanh(0.5 * (h_05 - (0.714121)) * 1.408464) + 1) 1121 | + (1.591881 - h_01) 1122 | * 0.189146 1123 | * 0.5 1124 | * (tanh(0.5 * (h_06 - (2.912474)) * 3.563775) + 1) 1125 | + (2.669420 - h_01) 1126 | * 0.000010 1127 | * 0.5 1128 | * (tanh(0.5 * (h_07 - (0.507458)) * 3.459168) + 1) 1129 | ) 1130 | wa_02 = ( 1131 | (2.356203 - h_02) 1132 | * 0.072446 1133 | * 0.5 1134 | * (tanh(0.5 * (h_00 - (0.818119)) * 4.080008) + 1) 1135 | + (0.534163 - h_02) 1136 | * 0.645405 1137 | * 0.5 1138 | * (tanh(0.5 * (h_01 - (0.951815)) * 4.642598) + 1) 1139 | + (0.577408 - h_02) 1140 | * 0.038336 1141 | * 
0.5 1142 | * (tanh(0.5 * (h_02 - (-0.348300)) * 2.439894) + 1) 1143 | + (-1.214589 - h_02) 1144 | * 0.310216 1145 | * 0.5 1146 | * (tanh(0.5 * (h_03 - (0.429035)) * 5.454509) + 1) 1147 | + (-1.883640 - h_02) 1148 | * 0.026713 1149 | * 0.5 1150 | * (tanh(0.5 * (h_04 - (1.463999)) * 3.258402) + 1) 1151 | + (-1.458897 - h_02) 1152 | * 0.054829 1153 | * 0.5 1154 | * (tanh(0.5 * (h_05 - (2.008786)) * 1.817449) + 1) 1155 | + (0.503231 - h_02) 1156 | * 0.215241 1157 | * 0.5 1158 | * (tanh(0.5 * (h_06 - (-0.121074)) * 4.956780) + 1) 1159 | + (-0.185760 - h_02) 1160 | * 0.202772 1161 | * 0.5 1162 | * (tanh(0.5 * (h_07 - (-0.237805)) * 3.949310) + 1) 1163 | ) 1164 | wa_03 = ( 1165 | (2.190941 - h_03) 1166 | * 0.264046 1167 | * 0.5 1168 | * (tanh(0.5 * (h_00 - (2.558471)) * 6.171836) + 1) 1169 | + (0.049259 - h_03) 1170 | * 0.016450 1171 | * 0.5 1172 | * (tanh(0.5 * (h_01 - (-0.471219)) * 3.276491) + 1) 1173 | + (1.897136 - h_03) 1174 | * 0.093556 1175 | * 0.5 1176 | * (tanh(0.5 * (h_02 - (-0.834410)) * 2.282021) + 1) 1177 | + (-1.203204 - h_03) 1178 | * 0.016330 1179 | * 0.5 1180 | * (tanh(0.5 * (h_03 - (-0.229405)) * 2.725281) + 1) 1181 | + (2.328495 - h_03) 1182 | * 0.032663 1183 | * 0.5 1184 | * (tanh(0.5 * (h_04 - (2.172752)) * 1.766576) + 1) 1185 | + (1.529989 - h_03) 1186 | * 0.128894 1187 | * 0.5 1188 | * (tanh(0.5 * (h_05 - (-0.905643)) * 4.200280) + 1) 1189 | + (-0.033709 - h_03) 1190 | * 0.168944 1191 | * 0.5 1192 | * (tanh(0.5 * (h_06 - (0.228327)) * 1.674565) + 1) 1193 | + (-0.903095 - h_03) 1194 | * 0.101751 1195 | * 0.5 1196 | * (tanh(0.5 * (h_07 - (0.474075)) * 3.122658) + 1) 1197 | ) 1198 | wa_04 = ( 1199 | (1.356759 - h_04) 1200 | * 0.207902 1201 | * 0.5 1202 | * (tanh(0.5 * (h_00 - (-0.832245)) * 5.076437) + 1) 1203 | + (-1.713048 - h_04) 1204 | * 0.253930 1205 | * 0.5 1206 | * (tanh(0.5 * (h_01 - (-0.361454)) * 6.056005) + 1) 1207 | + (1.365151 - h_04) 1208 | * 0.025809 1209 | * 0.5 1210 | * (tanh(0.5 * (h_02 - (2.691158)) * 2.098762) + 1) 1211 | + 
(-1.825929 - h_04) 1212 | * 0.039653 1213 | * 0.5 1214 | * (tanh(0.5 * (h_03 - (-1.585948)) * 3.676085) + 1) 1215 | + (-0.728560 - h_04) 1216 | * 0.012267 1217 | * 0.5 1218 | * (tanh(0.5 * (h_04 - (-1.516692)) * 5.575822) + 1) 1219 | + (-1.497486 - h_04) 1220 | * 0.037167 1221 | * 0.5 1222 | * (tanh(0.5 * (h_05 - (0.021516)) * 5.903266) + 1) 1223 | + (-2.542053 - h_04) 1224 | * 0.181942 1225 | * 0.5 1226 | * (tanh(0.5 * (h_06 - (0.198382)) * 3.954360) + 1) 1227 | + (3.068608 - h_04) 1228 | * 0.095478 1229 | * 0.5 1230 | * (tanh(0.5 * (h_07 - (0.247832)) * 4.194288) + 1) 1231 | ) 1232 | wa_05 = ( 1233 | (0.924399 - h_05) 1234 | * 0.000010 1235 | * 0.5 1236 | * (tanh(0.5 * (h_00 - (1.051798)) * 4.877949) + 1) 1237 | + (0.179591 - h_05) 1238 | * 0.082066 1239 | * 0.5 1240 | * (tanh(0.5 * (h_01 - (0.472982)) * 5.943190) + 1) 1241 | + (0.422308 - h_05) 1242 | * 0.361843 1243 | * 0.5 1244 | * (tanh(0.5 * (h_02 - (0.984128)) * 4.885800) + 1) 1245 | + (-1.484202 - h_05) 1246 | * 0.660023 1247 | * 0.5 1248 | * (tanh(0.5 * (h_03 - (2.561464)) * 3.960099) + 1) 1249 | + (3.444595 - h_05) 1250 | * 0.269937 1251 | * 0.5 1252 | * (tanh(0.5 * (h_04 - (2.765615)) * 3.529421) + 1) 1253 | + (-0.460907 - h_05) 1254 | * 0.018725 1255 | * 0.5 1256 | * (tanh(0.5 * (h_05 - (0.951335)) * 7.133490) + 1) 1257 | + (-1.438883 - h_05) 1258 | * 0.135012 1259 | * 0.5 1260 | * (tanh(0.5 * (h_06 - (2.470155)) * 5.090559) + 1) 1261 | + (-0.022624 - h_05) 1262 | * 0.030127 1263 | * 0.5 1264 | * (tanh(0.5 * (h_07 - (-0.370460)) * 4.233265) + 1) 1265 | ) 1266 | wa_06 = ( 1267 | (-2.204599 - h_06) 1268 | * 0.144977 1269 | * 0.5 1270 | * (tanh(0.5 * (h_00 - (2.253059)) * 3.878229) + 1) 1271 | + (0.148456 - h_06) 1272 | * 0.270476 1273 | * 0.5 1274 | * (tanh(0.5 * (h_01 - (2.298548)) * 5.617167) + 1) 1275 | + (0.544189 - h_06) 1276 | * 0.187231 1277 | * 0.5 1278 | * (tanh(0.5 * (h_02 - (1.738353)) * 3.412084) + 1) 1279 | + (0.957022 - h_06) 1280 | * 0.091821 1281 | * 0.5 1282 | * (tanh(0.5 * (h_03 - 
(2.433656)) * 1.501055) + 1) 1283 | + (0.305039 - h_06) 1284 | * 0.000010 1285 | * 0.5 1286 | * (tanh(0.5 * (h_04 - (2.545896)) * 2.064504) + 1) 1287 | + (2.075277 - h_06) 1288 | * 0.172349 1289 | * 0.5 1290 | * (tanh(0.5 * (h_05 - (0.215808)) * 4.761319) + 1) 1291 | + (0.651300 - h_06) 1292 | * 0.392033 1293 | * 0.5 1294 | * (tanh(0.5 * (h_06 - (-2.850508)) * 5.109767) + 1) 1295 | + (0.305318 - h_06) 1296 | * 0.056583 1297 | * 0.5 1298 | * (tanh(0.5 * (h_07 - (-0.545839)) * 5.304969) + 1) 1299 | ) 1300 | wa_07 = ( 1301 | (-2.419794 - h_07) 1302 | * 0.164188 1303 | * 0.5 1304 | * (tanh(0.5 * (h_00 - (0.646132)) * 5.431365) + 1) 1305 | + (-1.410748 - h_07) 1306 | * 0.128581 1307 | * 0.5 1308 | * (tanh(0.5 * (h_01 - (1.941283)) * 2.716641) + 1) 1309 | + (0.566556 - h_07) 1310 | * 0.189314 1311 | * 0.5 1312 | * (tanh(0.5 * (h_02 - (2.011078)) * 6.978758) + 1) 1313 | + (-0.999432 - h_07) 1314 | * 0.038773 1315 | * 0.5 1316 | * (tanh(0.5 * (h_03 - (2.175668)) * 2.841975) + 1) 1317 | + (-1.244396 - h_07) 1318 | * 0.024732 1319 | * 0.5 1320 | * (tanh(0.5 * (h_04 - (-1.121634)) * 4.267510) + 1) 1321 | + (0.211258 - h_07) 1322 | * 0.069162 1323 | * 0.5 1324 | * (tanh(0.5 * (h_05 - (0.894198)) * 2.127854) + 1) 1325 | + (2.880166 - h_07) 1326 | * 0.023222 1327 | * 0.5 1328 | * (tanh(0.5 * (h_06 - (0.345695)) * 3.201268) + 1) 1329 | + (-2.649584 - h_07) 1330 | * 0.348009 1331 | * 0.5 1332 | * (tanh(0.5 * (h_07 - (0.311026)) * 3.550425) + 1) 1333 | ) 1334 | h_00_prime = 0.472549 * (0.026911 * (-0.601392 - h_00) + swa_00 + wa_00) 1335 | h_01_prime = 0.422074 * (0.001000 * (1.909540 - h_01) + swa_01 + wa_01) 1336 | h_02_prime = 0.316776 * (1.660939 * (-0.450735 - h_02) + swa_02 + wa_02) 1337 | h_03_prime = 0.484111 * (0.144001 * (1.478572 - h_03) + swa_03 + wa_03) 1338 | h_04_prime = 0.414692 * (1.029787 * (-1.105074 - h_04) + swa_04 + wa_04) 1339 | h_05_prime = 0.358391 * (0.218205 * (-1.095135 - h_05) + swa_05 + wa_05) 1340 | h_06_prime = 0.880162 * (0.268332 * (0.043464 - 
h_06) + swa_06 + wa_06) 1341 | h_07_prime = 1.671041 * (0.153414 * (0.695311 - h_07) + swa_07 + wa_07) 1342 | 1343 | action_00 = tanh( 1344 | h_00 * 0.135081 1345 | + h_01 * 0.189091 1346 | + h_02 * -0.350403 1347 | + h_03 * 0.388546 1348 | + h_04 * 0.397087 1349 | + h_05 * 0.186882 1350 | + h_06 * -0.047590 1351 | + h_07 * 0.197737 1352 | + -0.352130 1353 | ) 1354 | 1355 | gravity = 9.8 1356 | masscart = 1 1357 | masspole = 0.1 1358 | total_mass = masscart + masspole 1359 | length = 0.5 1360 | polemass_length = masspole * length 1361 | force_mag = 10 1362 | 1363 | force = force_mag * action_00 1364 | costheta = cos(x_02) 1365 | sintheta = sin(x_02) 1366 | x_dot = x_01 1367 | theta_dot = x_03 1368 | 1369 | temp = (force + polemass_length * theta_dot * theta_dot * sintheta) / total_mass 1370 | thetaacc = (gravity * sintheta - costheta * temp) / ( 1371 | length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass) 1372 | ) 1373 | xacc = temp - polemass_length * thetaacc * costheta / total_mass 1374 | 1375 | x_00_prime = x_dot 1376 | x_01_prime = xacc 1377 | x_02_prime = theta_dot 1378 | x_03_prime = thetaacc 1379 | 1380 | system_dynamics = [ 1381 | x_00_prime, 1382 | x_01_prime, 1383 | x_02_prime, 1384 | x_03_prime, 1385 | h_00_prime, 1386 | h_01_prime, 1387 | h_02_prime, 1388 | h_03_prime, 1389 | h_04_prime, 1390 | h_05_prime, 1391 | h_06_prime, 1392 | h_07_prime, 1393 | ] # has to be in the same order as the input variables 1394 | # =================================================== 1395 | 1396 | return np.array(system_dynamics) # return as numpy array 1397 | 1398 | 1399 | class TestNODE: 1400 | def __init__(self, radius): 1401 | # ============ adapt initial values =========== 1402 | self.cx = (0, 0) 1403 | if radius is not None: 1404 | self.rad = radius 1405 | else: 1406 | self.rad = 1e-4 1407 | # =================================================== 1408 | self.cx = np.array(self.cx, dtype=float) 1409 | self.dim = self.cx.size # dimension of the system 1410 
| 1411 | def fdyn(self, t=0, x=None): 1412 | if x is None: 1413 | x = np.zeros(self.dim, dtype=object) 1414 | 1415 | # ============ adapt input and system dynamics =========== 1416 | x0, x1 = x 1417 | x0 = 0.5889 * tanh(0.4256 * x0 + 0.5061 * x1 + 0.1773) - 0.1000 * x0 1418 | x1 = 0.3857 * tanh(-0.5563 * x0 + -0.1262 * x1 + -0.2136) - 0.1000 * x1 1419 | system_dynamics = [x0, x1] 1420 | # =================================================== 1421 | return np.array(system_dynamics) # return as numpy array 1422 | 1423 | 1424 | class LDSwithCTRNN: 1425 | def __init__(self, radius): 1426 | # ============ adapt initial values =========== 1427 | if radius is not None: 1428 | self.rad = radius 1429 | else: 1430 | self.rad = 0.5 1431 | # =================================================== 1432 | self.cx = np.zeros(10) 1433 | self.dim = self.cx.size # dimension of the system 1434 | arr = np.load("rl/lds_ctrnn.npz") 1435 | self.params = {k: arr[k] for k in arr.files} 1436 | 1437 | def fdyn(self, t=0, x=None): 1438 | if x is None: 1439 | x = np.zeros(self.dim, dtype=object) 1440 | 1441 | hidden = np.tanh(np.dot(x, self.params["w1"]) + self.params["b1"]) 1442 | dhdt = np.dot(hidden, self.params["w2"]) + self.params["b2"] 1443 | 1444 | action = np.tanh(np.dot(hidden, self.params["wa"]) + self.params["ba"]) 1445 | x, y = x[-2:] 1446 | 1447 | dxdt = y 1448 | dydt = 0.2 + 0.4 * action 1449 | 1450 | dxdt = np.array([dxdt]).reshape((1,)) 1451 | dydt = np.array([dydt]).reshape((1,)) 1452 | dfdt = np.concatenate([dhdt, dxdt, dydt], axis=0) 1453 | return dfdt 1454 | 1455 | 1456 | class PendulumwithCTRNN: 1457 | def __init__(self, radius): 1458 | # ============ adapt initial values =========== 1459 | if radius is not None: 1460 | self.rad = radius 1461 | else: 1462 | self.rad = 0.5 1463 | # =================================================== 1464 | self.cx = np.zeros(10) 1465 | self.dim = self.cx.size # dimension of the system 1466 | arr = np.load("rl/pendulum_ctrnn.npz") 1467 | 
self.params = {k: arr[k] for k in arr.files} 1468 | 1469 | def fdyn(self, t=0, x=None): 1470 | if x is None: 1471 | x = np.zeros(self.dim, dtype=object) 1472 | 1473 | hidden = np.tanh(np.dot(x, self.params["w1"]) + self.params["b1"]) 1474 | dhdt = np.dot(hidden, self.params["w2"]) + self.params["b2"] 1475 | 1476 | action = np.tanh(np.dot(hidden, self.params["wa"]) + self.params["ba"]) 1477 | th, thdot = x[-2:] 1478 | 1479 | max_speed = 8 1480 | g = 9.81 1481 | m = 1.0 1482 | l = 1.0 1483 | 1484 | newthdot = -3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * action 1485 | newthdot = max_speed * np.tanh(newthdot / max_speed) 1486 | newth = newthdot 1487 | 1488 | dxdt = np.array([newth]).reshape((1,)) 1489 | dydt = np.array([newthdot]).reshape((1,)) 1490 | dfdt = np.concatenate([dhdt, dxdt, dydt], axis=0) 1491 | return dfdt 1492 | 1493 | 1494 | class CTRNNosc: 1495 | def __init__(self, radius): 1496 | # ============ adapt initial values =========== 1497 | if radius is not None: 1498 | self.rad = radius 1499 | else: 1500 | self.rad = 0.1 1501 | # =================================================== 1502 | self.cx = np.zeros(16) 1503 | self.dim = self.cx.size # dimension of the system 1504 | arr = np.load("rl/ctrnn_osc.npz") 1505 | self.params = {k: arr[k] for k in arr.files} 1506 | 1507 | def fdyn(self, t=0, x=None): 1508 | if x is None: 1509 | x = np.zeros(self.dim, dtype=object) 1510 | 1511 | hidden = np.tanh(np.dot(x, self.params["w1"]) + self.params["b1"]) 1512 | dhdt = np.dot(hidden, self.params["w2"]) + self.params["b2"] 1513 | return dhdt 1514 | --------------------------------------------------------------------------------