├── .gitignore ├── LICENSE ├── README.md ├── dynamics.py ├── figures └── wind.png ├── generate_data.py ├── plots.py ├── requirements.txt ├── test.sh ├── test_all.py ├── test_results ├── seed=0_M=10.pkl ├── seed=0_M=2.pkl ├── seed=0_M=20.pkl ├── seed=0_M=30.pkl ├── seed=0_M=40.pkl ├── seed=0_M=5.pkl ├── seed=0_M=50.pkl ├── seed=1_M=10.pkl ├── seed=1_M=2.pkl ├── seed=1_M=20.pkl ├── seed=1_M=30.pkl ├── seed=1_M=40.pkl ├── seed=1_M=5.pkl ├── seed=1_M=50.pkl ├── seed=2_M=10.pkl ├── seed=2_M=2.pkl ├── seed=2_M=20.pkl ├── seed=2_M=30.pkl ├── seed=2_M=40.pkl ├── seed=2_M=5.pkl ├── seed=2_M=50.pkl ├── seed=3_M=10.pkl ├── seed=3_M=2.pkl ├── seed=3_M=20.pkl ├── seed=3_M=30.pkl ├── seed=3_M=40.pkl ├── seed=3_M=5.pkl ├── seed=3_M=50.pkl ├── seed=4_M=10.pkl ├── seed=4_M=2.pkl ├── seed=4_M=20.pkl ├── seed=4_M=30.pkl ├── seed=4_M=40.pkl ├── seed=4_M=5.pkl ├── seed=4_M=50.pkl ├── seed=5_M=10.pkl ├── seed=5_M=2.pkl ├── seed=5_M=20.pkl ├── seed=5_M=30.pkl ├── seed=5_M=40.pkl ├── seed=5_M=5.pkl ├── seed=5_M=50.pkl ├── seed=6_M=10.pkl ├── seed=6_M=2.pkl ├── seed=6_M=20.pkl ├── seed=6_M=30.pkl ├── seed=6_M=40.pkl ├── seed=6_M=5.pkl ├── seed=6_M=50.pkl ├── seed=7_M=10.pkl ├── seed=7_M=2.pkl ├── seed=7_M=20.pkl ├── seed=7_M=30.pkl ├── seed=7_M=40.pkl ├── seed=7_M=5.pkl ├── seed=7_M=50.pkl ├── seed=8_M=10.pkl ├── seed=8_M=2.pkl ├── seed=8_M=20.pkl ├── seed=8_M=30.pkl ├── seed=8_M=40.pkl ├── seed=8_M=5.pkl ├── seed=8_M=50.pkl ├── seed=9_M=10.pkl ├── seed=9_M=2.pkl ├── seed=9_M=20.pkl ├── seed=9_M=30.pkl ├── seed=9_M=40.pkl ├── seed=9_M=5.pkl └── seed=9_M=50.pkl ├── test_results_single.pkl ├── test_single.py ├── train.sh ├── train_lstsq.py ├── train_ours.py ├── train_results ├── lstsq │ ├── seed=0_M=10.pkl │ ├── seed=0_M=2.pkl │ ├── seed=0_M=20.pkl │ ├── seed=0_M=30.pkl │ ├── seed=0_M=40.pkl │ ├── seed=0_M=5.pkl │ ├── seed=0_M=50.pkl │ ├── seed=1_M=10.pkl │ ├── seed=1_M=2.pkl │ ├── seed=1_M=20.pkl │ ├── seed=1_M=30.pkl │ ├── seed=1_M=40.pkl │ ├── seed=1_M=5.pkl │ ├── seed=1_M=50.pkl │ ├── seed=2_M=10.pkl │ ├── seed=2_M=2.pkl │ ├── seed=2_M=20.pkl │ ├── seed=2_M=30.pkl │ ├── seed=2_M=40.pkl │ ├── seed=2_M=5.pkl │ ├── seed=2_M=50.pkl │ ├── seed=3_M=10.pkl │ ├── seed=3_M=2.pkl │ ├── seed=3_M=20.pkl │ ├── seed=3_M=30.pkl │ ├── seed=3_M=40.pkl │ ├── seed=3_M=5.pkl │ ├── seed=3_M=50.pkl │ ├── seed=4_M=10.pkl │ ├── seed=4_M=2.pkl │ ├── seed=4_M=20.pkl │ ├── seed=4_M=30.pkl │ ├── seed=4_M=40.pkl │ ├── seed=4_M=5.pkl │ ├── seed=4_M=50.pkl │ ├── seed=5_M=10.pkl │ ├── seed=5_M=2.pkl │ ├── seed=5_M=20.pkl │ ├── seed=5_M=30.pkl │ ├── seed=5_M=40.pkl │ ├── seed=5_M=5.pkl │ ├── seed=5_M=50.pkl │ ├── seed=6_M=10.pkl │ ├── seed=6_M=2.pkl │ ├── seed=6_M=20.pkl │ ├── seed=6_M=30.pkl │ ├── seed=6_M=40.pkl │ ├── seed=6_M=5.pkl │ ├── seed=6_M=50.pkl │ ├── seed=7_M=10.pkl │ ├── seed=7_M=2.pkl │ ├── seed=7_M=20.pkl │ ├── seed=7_M=30.pkl │ ├── seed=7_M=40.pkl │ ├── seed=7_M=5.pkl │ ├── seed=7_M=50.pkl │ ├── seed=8_M=10.pkl │ ├── seed=8_M=2.pkl │ ├── seed=8_M=20.pkl │ ├── seed=8_M=30.pkl │ ├── seed=8_M=40.pkl │ ├── seed=8_M=5.pkl │ ├── seed=8_M=50.pkl │ ├── seed=9_M=10.pkl │ ├── seed=9_M=2.pkl │ ├── seed=9_M=20.pkl │ ├── seed=9_M=30.pkl │ ├── seed=9_M=40.pkl │ ├── seed=9_M=5.pkl │ └── seed=9_M=50.pkl └── ours │ ├── seed=0_M=10.pkl │ ├── seed=0_M=2.pkl │ ├── seed=0_M=20.pkl │ ├── seed=0_M=30.pkl │ ├── seed=0_M=40.pkl │ ├── seed=0_M=5.pkl │ ├── seed=0_M=50.pkl │ ├── seed=1_M=10.pkl │ ├── seed=1_M=2.pkl │ ├── seed=1_M=20.pkl │ ├── seed=1_M=30.pkl │ ├── seed=1_M=40.pkl │ ├── seed=1_M=5.pkl │ ├── seed=1_M=50.pkl │ ├── seed=2_M=10.pkl │ ├── seed=2_M=2.pkl │ ├── seed=2_M=20.pkl │ ├── seed=2_M=30.pkl │ ├── seed=2_M=40.pkl │ ├── seed=2_M=5.pkl │ ├── seed=2_M=50.pkl │ ├── seed=3_M=10.pkl │ ├── seed=3_M=2.pkl │ ├── seed=3_M=20.pkl │ ├── seed=3_M=30.pkl │ ├── seed=3_M=40.pkl │ ├── seed=3_M=5.pkl │ ├── seed=3_M=50.pkl │ ├── seed=4_M=10.pkl │ ├── seed=4_M=2.pkl │ ├── seed=4_M=20.pkl │ ├── seed=4_M=30.pkl │ ├── seed=4_M=40.pkl │ ├── seed=4_M=5.pkl │ ├── seed=4_M=50.pkl │ ├── seed=5_M=10.pkl │ ├── seed=5_M=2.pkl │ ├── seed=5_M=20.pkl │ ├── seed=5_M=30.pkl │ ├── seed=5_M=40.pkl │ ├── seed=5_M=5.pkl │ ├── seed=5_M=50.pkl │ ├── seed=6_M=10.pkl │ ├── seed=6_M=2.pkl │ ├── seed=6_M=20.pkl │ ├── seed=6_M=30.pkl │ ├── seed=6_M=40.pkl │ ├── seed=6_M=5.pkl │ ├── seed=6_M=50.pkl │ ├── seed=7_M=10.pkl │ ├── seed=7_M=2.pkl │ ├── seed=7_M=20.pkl │ ├── seed=7_M=30.pkl │ ├── seed=7_M=40.pkl │ ├── seed=7_M=5.pkl │ ├── seed=7_M=50.pkl │ ├── seed=8_M=10.pkl │ ├── seed=8_M=2.pkl │ ├── seed=8_M=20.pkl │ ├── seed=8_M=30.pkl │ ├── seed=8_M=40.pkl │ ├── seed=8_M=5.pkl │ ├── seed=8_M=50.pkl │ ├── seed=9_M=10.pkl │ ├── seed=9_M=2.pkl │ ├── seed=9_M=20.pkl │ ├── seed=9_M=30.pkl │ ├── seed=9_M=40.pkl │ ├── seed=9_M=5.pkl │ └── seed=9_M=50.pkl ├── training_data.pkl └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.DS_Store 3 | *.vscode 4 | *.egg-info 5 | *.so 6 | .ipynb_checkpoints 7 | build/ 8 | bin/ 9 | figures/*.pdf 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2021 Spencer M. Richards 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adaptive-Control-Oriented Meta-Learning for Nonlinear Systems 2 | 3 | This repository accompanies the paper ["Adaptive-Control-Oriented Meta-Learning for Nonlinear Systems"](https://arxiv.org/abs/2103.04490) [1]. 4 | 5 | 6 | ## Getting started 7 | 8 | Ensure you are using Python 3. Clone this repository and install the packages listed in `requirements.txt`. In particular, this code uses [JAX](https://github.com/google/jax). 9 | 10 | 11 | ## Reproducing results 12 | 13 | Training data, trained parameters, and test results are all conveniently saved in this repository, since it can take a while to re-generate them. To simply produce Figures 2, 3, and 4 in [1], run the command `python plots.py`. 14 | 15 | Training data can be generated with the command `python generate_data.py`. 16 | 17 | Parameters can then be trained (for multiple training set sizes and random seeds) with the command `./train.sh`. This will take a while. 18 | 19 | Finally, test results for Figures 3 and 4 in [1] can be produced with the commands `python test_single.py` and `./test.sh`, respectively. This may also take a while. 20 | 21 | 22 | ## Citing this work 23 | 24 | Please use the following bibtex entry to cite this work. 25 | ``` 26 | @INPROCEEDINGS{RichardsAzizanEtAl2021, 27 | author = {Richards, S. M. and Azizan, N. and Slotine, J.-J. E. and Pavone, M.}, 28 | title = {Adaptive-control-oriented meta-learning for nonlinear systems}, 29 | booktitle = {Robotics: Science and Systems}, 30 | year = {2021}, 31 | note = {In press. Available at \url{https://arxiv.org/abs/2103.04490}}, 32 | } 33 | ``` 34 | 35 | 36 | ## References 37 | [1] S. M. Richards, N. Azizan, J.-J. E. Slotine, and M. Pavone. Adaptive-control-oriented meta-learning for nonlinear systems. In *Robotics: Science and Systems*, 2021. In press. Available at . 38 | -------------------------------------------------------------------------------- /dynamics.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO description. 3 | 4 | Author: Spencer M. Richards 5 | Autonomous Systems Lab (ASL), Stanford 6 | (GitHub: spenrich) 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | 12 | # System constants 13 | g_acc = 9.81 # gravitational acceleration 14 | β = (0.1, 1.) # drag coefficients 15 | 16 | 17 | def prior(q, dq, g_acc=g_acc): 18 | """TODO: docstring.""" 19 | nq = 3 20 | sinϕ, cosϕ = jnp.sin(q[2]), jnp.cos(q[2]) 21 | H = jnp.eye(nq) 22 | C = jnp.zeros((nq, nq)) 23 | g = jnp.array([0., g_acc, 0.]) 24 | # R = jnp.array([ 25 | # [cosϕ, -sinϕ, 0], 26 | # [sinϕ, cosϕ, 0], 27 | # [0, 0, 1], 28 | # ]) 29 | B = jnp.array([ 30 | [-sinϕ, 0, cosϕ], 31 | [cosϕ, 0, sinϕ], 32 | [0, 1, 0], 33 | ]) 34 | return H, C, g, B 35 | 36 | 37 | def plant(q, dq, u, f_ext, prior=prior): 38 | """TODO: docstring.""" 39 | H, C, g, B = prior(q, dq) 40 | ddq = jax.scipy.linalg.solve(H, f_ext + B@u - C@dq - g, sym_pos=True) 41 | return ddq 42 | 43 | 44 | def disturbance(q, dq, w, β=β): 45 | """TODO: docstring.""" 46 | β = jnp.asarray(β) 47 | ϕ, dx, dy = q[2], dq[0], dq[1] 48 | sinϕ, cosϕ = jnp.sin(ϕ), jnp.cos(ϕ) 49 | R = jnp.array([ 50 | [cosϕ, -sinϕ], 51 | [sinϕ, cosϕ] 52 | ]) 53 | v = R.T @ jnp.array([dx - w, dy]) 54 | f_ext = - jnp.array([*(R @ (β * v * jnp.abs(v))), 0.]) 55 | return f_ext 56 | -------------------------------------------------------------------------------- /figures/wind.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/figures/wind.png -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO description. 3 | 4 | Author: Spencer M. Richards 5 | Autonomous Systems Lab (ASL), Stanford 6 | (GitHub: spenrich) 7 | """ 8 | 9 | if __name__ == "__main__": 10 | import pickle 11 | import jax 12 | import jax.numpy as jnp 13 | from jax.experimental.ode import odeint 14 | from utils import spline, random_ragged_spline 15 | from dynamics import prior, plant, disturbance 16 | 17 | # Seed random numbers 18 | seed = 0 19 | key = jax.random.PRNGKey(seed) 20 | 21 | # Generate smooth trajectories 22 | num_traj = 500 23 | T = 30 24 | num_knots = 6 25 | poly_orders = (9, 9, 6) 26 | deriv_orders = (4, 4, 2) 27 | min_step = jnp.array([-2., -2., -jnp.pi/6]) 28 | max_step = jnp.array([2., 2., jnp.pi/6]) 29 | min_knot = jnp.array([-jnp.inf, -jnp.inf, -jnp.pi/3]) 30 | max_knot = jnp.array([jnp.inf, jnp.inf, jnp.pi/3]) 31 | 32 | key, *subkeys = jax.random.split(key, 1 + num_traj) 33 | subkeys = jnp.vstack(subkeys) 34 | in_axes = (0, None, None, None, None, None, None, None, None) 35 | t_knots, knots, coefs = jax.vmap(random_ragged_spline, in_axes)( 36 | subkeys, T, num_knots, poly_orders, deriv_orders, 37 | min_step, max_step, min_knot, max_knot 38 | ) 39 | # x_coefs, y_coefs, ϕ_coefs = coefs 40 | r_knots = jnp.dstack(knots) 41 | 42 | # Sampled-time simulator 43 | @jax.partial(jax.vmap, in_axes=(None, 0, 0, 0)) 44 | def simulate(ts, w, t_knots, coefs, 45 | plant=plant, prior=prior, disturbance=disturbance): 46 | """TODO: docstring.""" 47 | # Construct spline reference trajectory 48 | def reference(t): 49 | x_coefs, y_coefs, ϕ_coefs = coefs 50 | x = spline(t, t_knots, x_coefs) 51 | y = spline(t, t_knots, y_coefs) 52 | ϕ = spline(t, t_knots, ϕ_coefs) 53 | ϕ = jnp.clip(ϕ, -jnp.pi/3, jnp.pi/3) 54 | r = jnp.array([x, y, ϕ]) 55 | return r 56 | 57 | # Required derivatives of the reference trajectory 58 | def ref_derivatives(t): 59 | ref_vel = jax.jacfwd(reference) 60 | ref_acc = jax.jacfwd(ref_vel) 61 | r = reference(t) 62 | dr = ref_vel(t) 63 | ddr = ref_acc(t) 64 | return r, dr, ddr 65 | 66 | # Feedback linearizing PD controller 67 | def controller(q, dq, r, dr, ddr): 68 | kp, kd = 10., 0.1 69 | e, de = q - r, dq - dr 70 | dv = ddr - kp*e - kd*de 71 | H, C, g, B = prior(q, dq) 72 | τ = H@dv + C@dq + g 73 | u = jnp.linalg.solve(B, τ) 74 | return u, τ 75 | 76 | # Closed-loop ODE for `x = (q, dq)`, with a zero-order hold on 77 | # the controller 78 | def ode(x, t, u, w=w): 79 | q, dq = x 80 | f_ext = disturbance(q, dq, w) 81 | ddq = plant(q, dq, u, f_ext) 82 | dx = (dq, ddq) 83 | return dx 84 | 85 | # Simulation loop 86 | def loop(carry, input_slice): 87 | t_prev, q_prev, dq_prev, u_prev = carry 88 | t = input_slice 89 | qs, dqs = odeint(ode, (q_prev, dq_prev), jnp.array([t_prev, t]), 90 | u_prev) 91 | q, dq = qs[-1], dqs[-1] 92 | r, dr, ddr = ref_derivatives(t) 93 | u, τ = controller(q, dq, r, dr, ddr) 94 | carry = (t, q, dq, u) 95 | output_slice = (q, dq, u, τ, r, dr) 96 | return carry, output_slice 97 | 98 | # Initial conditions 99 | t0 = ts[0] 100 | r0, dr0, ddr0 = ref_derivatives(t0) 101 | q0, dq0 = r0, dr0 102 | u0, τ0 = controller(q0, dq0, r0, dr0, ddr0) 103 | 104 | # Run simulation loop 105 | carry = (t0, q0, dq0, u0) 106 | carry, output = jax.lax.scan(loop, carry, ts[1:]) 107 | q, dq, u, τ, r, dr = output 108 | 109 | # Prepend initial conditions 110 | q = jnp.vstack((q0, q)) 111 | dq = jnp.vstack((dq0, dq)) 112 | u = jnp.vstack((u0, u)) 113 | τ = jnp.vstack((τ0, τ)) 114 | r = jnp.vstack((r0, r)) 115 | dr = jnp.vstack((dr0, dr)) 116 | 117 | return q, dq, u, τ, r, dr 118 | 119 | # Sample wind velocities from the training distribution 120 | w_min = 0. # minimum wind velocity in inertial `x`-direction 121 | w_max = 6. # maximum wind velocity in inertial `x`-direction 122 | a = 5. # shape parameter `a` for beta distribution 123 | b = 9. # shape parameter `b` for beta distribution 124 | key, subkey = jax.random.split(key, 2) 125 | w = w_min + (w_max - w_min)*jax.random.beta(subkey, a, b, (num_traj,)) 126 | 127 | # Simulate tracking for each `w` 128 | dt = 0.01 129 | t = jnp.arange(0, T + dt, dt) # same times for each trajectory 130 | q, dq, u, τ, r, dr = simulate(t, w, t_knots, coefs) 131 | 132 | data = { 133 | 'seed': seed, 'prng_key': key, 134 | 't': t, 'q': q, 'dq': dq, 135 | 'u': u, 'r': r, 'dr': dr, 136 | 't_knots': t_knots, 'r_knots': r_knots, 137 | 'w': w, 'w_min': w_min, 'w_max': w_max, 138 | 'beta_params': (a, b), 139 | } 140 | 141 | with open('training_data.pkl', 'wb') as file: 142 | pickle.dump(data, file) 143 | -------------------------------------------------------------------------------- /plots.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO description. 3 | 4 | Author: Spencer M. Richards 5 | Autonomous Systems Lab (ASL), Stanford 6 | (GitHub: spenrich) 7 | """ 8 | 9 | if __name__ == "__main__": 10 | import numpy as np 11 | from scipy.stats import beta 12 | import pickle 13 | import os 14 | import matplotlib.pyplot as plt 15 | from matplotlib.patches import Patch 16 | from matplotlib.lines import Line2D 17 | import itertools 18 | 19 | plt.rcParams.update({ 20 | 'font.family': 'serif', 21 | 'font.serif': ['Times', 'Times New Roman'], 22 | 'mathtext.fontset': 'cm', 23 | 'font.size': 16, 24 | 'legend.fontsize': 'medium', 25 | 'axes.titlesize': 'medium', 26 | 'lines.linewidth': 2, 27 | 'lines.markersize': 10, 28 | 'errorbar.capsize': 6, 29 | }) 30 | 31 | # FIGURE 2 ############################################################### 32 | with open('training_data.pkl', 'rb') as file: 33 | raw = pickle.load(file) 34 | w_train = raw['w'] 35 | w_min, w_max = raw['w_min'], raw['w_max'] 36 | a, b = raw['beta_params'] 37 | x = np.linspace(0, 1) 38 | w_train_pdf = w_min + (w_max - w_min)*x 39 | p_train = beta.pdf(x, a, b) / (w_max - w_min) 40 | 41 | with open(os.path.join('test_results', 'seed=2_M=10.pkl'), 'rb') as file: 42 | results = pickle.load(file) 43 | gains = tuple(itertools.product( 44 | results['gains']['Λ'], results['gains']['K'], results['gains']['P'] 45 | )) 46 | num_gains = len(gains) 47 | w_test = results['w'] 48 | w_min, w_max = results['w_min'], results['w_max'] 49 | a, b = results['beta_params'] 50 | w_test_pdf = w_min + (w_max - w_min)*x 51 | p_test = beta.pdf(x, a, b) / (w_max - w_min) 52 | 53 | _, bins = np.histogram(np.hstack([w_train, w_test]), bins=15) 54 | fig, ax = plt.subplots(1, 1, dpi=100, figsize=(8, 4)) 55 | ax.plot(w_train_pdf, p_train, 56 | label=r'$p_\mathrm{train}(w)$', color='tab:blue') 57 | ax.hist(w_train, density=True, alpha=0.5, bins=bins, color='tab:blue') 58 | ax.plot(w_test_pdf, p_test, 59 | label=r'$p_\mathrm{test}(w)$', color='tab:orange') 60 | ax.hist(w_test, density=True, alpha=0.5, bins=bins, color='tab:orange') 61 | ax.set_xlabel(r'$w~\mathrm{[m/s]}$') 62 | ax.set_ylabel(r'sampling probability') 63 | ax.legend() 64 | fig.tight_layout() 65 | fig.savefig(os.path.join('figures', 'fig2.pdf'), bbox_inches='tight') 66 | plt.show() 67 | 68 | # FIGURE 3 ############################################################### 69 | with open('test_results_single.pkl', 'rb') as file: 70 | results = pickle.load(file) 71 | 72 | fig = plt.figure(dpi=100, figsize=(8, 7.5)) 73 | grid = plt.GridSpec(2, 2, width_ratios=[1.5, 1], height_ratios=[1, 1], 74 | hspace=0.05, wspace=0.4) 75 | axes = (plt.subplot(grid[:, 0]), 76 | plt.subplot(grid[0, 1]), 77 | plt.subplot(grid[1, 1])) 78 | 79 | t = results['ours_meta']['t'] 80 | r = results['ours_meta']['r'] 81 | xr, yr, ϕr = r[:, 0], r[:, 1], r[:, 2] 82 | axes[0].plot(xr, yr, '--', color='tab:red', lw=4.) 83 | 84 | methods = ('pid', 'lstsq', 'ours', 'ours_meta') 85 | colors = ('tab:pink', 'tab:orange', 'tab:green', 'tab:blue') 86 | labels = ('PID', 'ACMRR', 'ours', 87 | r'ours, $(\Lambda, K, \Gamma) = (' 88 | r'\Lambda_\mathrm{meta},K_\mathrm{meta},\Gamma_\mathrm{meta})$') 89 | 90 | for method, color in zip(methods, colors): 91 | q = results[method]['q'] 92 | x, y, ϕ = q[:, 0], q[:, 1], q[:, 2] 93 | axes[0].plot(x, y, color=color) 94 | 95 | e_norm = np.linalg.norm(results[method]['e'], axis=1) 96 | axes[1].plot(t, e_norm, color=color) 97 | 98 | u_norm = np.linalg.norm(results[method]['u'], axis=1) 99 | axes[2].plot(t, u_norm, color=color) 100 | 101 | axes[0].set_xlabel(r'$x~\mathrm{[m]}$') 102 | axes[0].set_ylabel(r'$y~\mathrm{[m]}$') 103 | axes[1].get_xaxis().set_ticklabels([]) 104 | axes[1].set_ylabel(r'$\sqrt{\|\tilde{q}\|_2^2+\|\dot{\tilde{q}}\|_2^2}$') 105 | axes[2].set_xlabel(r'$t~\mathrm{[s]}$') 106 | axes[2].set_ylabel(r'$\|u\|_2$') 107 | 108 | im_height = 1. 109 | im_width = 1.5 110 | im_x0, im_y0 = -0.8, 4.7 111 | pad = 0.15 112 | axes[0].text(im_x0, im_y0 + im_height + pad, 113 | r'$w = {:.1f}'.format(results['w']) + r'~\mathrm{m/s}$') 114 | axes[0].imshow( 115 | plt.imread(os.path.join('figures', 'wind.png')), 116 | aspect='auto', 117 | interpolation='none', 118 | extent=(im_x0, im_x0 + im_width, im_y0, im_y0 + im_height) 119 | ) 120 | axes[0].set_xlim([-1., 4.2]) 121 | axes[0].set_ylim([-0.2, 6.2]) 122 | 123 | handles = [Line2D([0], [0], color=color, label=label) 124 | for color, label in zip(colors, labels)] 125 | handles = [Line2D([0], [0], color='tab:red', label='reference', 126 | linestyle='--', lw=4.)] + handles 127 | fig.legend(handles=handles, loc='lower center', ncol=2) 128 | fig.subplots_adjust(bottom=0.24) 129 | fig.savefig(os.path.join('figures', 'fig3.pdf'), bbox_inches='tight') 130 | plt.show() 131 | 132 | # FIGURE 4 ############################################################### 133 | seeds = np.arange(10) 134 | Ms = np.array([2, 5, 10, 20, 30, 40, 50]) 135 | methods = ('pid', 'lstsq', 'ours') 136 | colors = ('tab:pink', 'tab:orange', 'tab:green', 'tab:blue') 137 | labels = ('PID', 'ACMRR', 'ours', 138 | r'ours, $(\Lambda,K,\Gamma) = (' 139 | r'\Lambda_\mathrm{meta},K_\mathrm{meta},\Gamma_\mathrm{meta})$') 140 | metrics = ( 141 | r'$\dfrac{1}{N_\mathrm{test}}' 142 | r'\sum_{i=1}^{N_\mathrm{test}}\,\mathrm{RMS}(x_i{-}r_i)$', 143 | r'$\dfrac{1}{N_\mathrm{test}}' 144 | r'\sum_{i=1}^{N_\mathrm{test}}\,\mathrm{RMS}(u_i)$', 145 | ) 146 | 147 | rms_error = { 148 | 'pid': np.zeros((num_gains, seeds.size, Ms.size)), 149 | 'lstsq': np.zeros((num_gains, seeds.size, Ms.size)), 150 | 'ours': np.zeros((num_gains, seeds.size, Ms.size)), 151 | 'ours_meta': np.zeros((seeds.size, Ms.size)), 152 | } 153 | rms_ctrl = { 154 | 'pid': np.zeros((num_gains, seeds.size, Ms.size)), 155 | 'lstsq': np.zeros((num_gains, seeds.size, Ms.size)), 156 | 'ours': np.zeros((num_gains, seeds.size, Ms.size)), 157 | 'ours_meta': np.zeros((seeds.size, Ms.size)), 158 | } 159 | 160 | for j, seed in enumerate(seeds): 161 | for m, M in enumerate(Ms): 162 | filename = os.path.join('test_results', 163 | 'seed={}_M={}.pkl'.format(seed, M)) 164 | with open(filename, 'rb') as file: 165 | results = pickle.load(file) 166 | for i, _ in enumerate(gains): 167 | for method in methods: 168 | rms_error[method][i, j, m] = np.mean( 169 | results[method].ravel()[i]['rms_error'] 170 | ) 171 | rms_ctrl[method][i, j, m] = np.mean( 172 | results[method].ravel()[i]['rms_ctrl'] 173 | ) 174 | rms_error['ours_meta'][j, m] = np.mean( 175 | results['ours_meta']['rms_error'] 176 | ) 177 | rms_ctrl['ours_meta'][j, m] = np.mean( 178 | results['ours_meta']['rms_ctrl'] 179 | ) 180 | 181 | fig, axes = plt.subplots(2, num_gains, 182 | dpi=100, figsize=(18, 6), sharex=True) 183 | axes[0, 0].set_ylabel(metrics[0]) 184 | axes[1, 0].set_ylabel(metrics[1]) 185 | for j, (λ, k, p) in enumerate(gains): 186 | for method, color in zip(methods, colors): 187 | axes[0, j].errorbar(Ms, np.mean(rms_error[method][j], axis=0), 188 | np.std(rms_error[method][j], axis=0), 189 | fmt='-o', color=color) 190 | axes[0, j].errorbar(Ms, np.mean(rms_error['ours_meta'], axis=0), 191 | np.std(rms_error['ours_meta'], axis=0), 192 | fmt='-o', color=colors[-1]) 193 | axes[0, j].set_title( 194 | r'$(\Lambda,K,\Gamma) = ({})$'.format( 195 | r','.join([r'{}I' if g == 1 else r'{:g}I'.format(g) 196 | for g in (λ, k, p)]) 197 | ), pad=7 198 | ) 199 | axes[0, j].set_xticks(np.arange(0, Ms[-1] + 1, 10)) 200 | axes[0, j].set_xticks(np.arange(0, Ms[-1] + 1, 5), minor=True) 201 | 202 | for method, color in zip(methods, colors): 203 | axes[1, j].errorbar(Ms, np.mean(rms_ctrl[method][j], axis=0), 204 | np.std(rms_ctrl[method][j], axis=0), 205 | fmt='-o', color=color) 206 | axes[1, j].errorbar(Ms, np.mean(rms_ctrl['ours_meta'], axis=0), 207 | np.std(rms_ctrl['ours_meta'], axis=0), 208 | fmt='-o', color=colors[-1]) 209 | axes[1, j].set_ylim([10.3, 12.7]) 210 | axes[1, j].set_yticks([10.5, 11., 11.5, 12., 12.5]) 211 | axes[1, j].set_xlabel(r'$M$') 212 | axes[0, j].set_xticks(np.arange(0, Ms[-1] + 1, 10)) 213 | axes[0, j].set_xticks(np.arange(0, Ms[-1] + 1, 5), minor=True) 214 | 215 | axes[0, 0].set_ylim([-0.05, 2.1]) 216 | axes[0, 1].set_ylim([-0.05, 1.6]) 217 | axes[0, 2].set_ylim([-0.05, 0.46]) 218 | axes[0, 3].set_ylim([-0.05, 0.46]) 219 | 220 | axes[0, 0].set_yticks([0., 0.5, 1., 1.5, 2.]) 221 | axes[0, 1].set_yticks([0., 0.5, 1., 1.5]) 222 | axes[0, 2].set_yticks([0., 0.1, 0.2, 0.3, 0.4]) 223 | axes[0, 3].set_yticks([0., 0.1, 0.2, 0.3, 0.4]) 224 | 225 | handles = [Patch(color=color, label=label) 226 | for color, label in zip(colors, labels)] 227 | fig.legend(handles=handles, loc='lower center', ncol=len(handles)) 228 | fig.subplots_adjust(bottom=0.19, hspace=0.1) 229 | fig.savefig(os.path.join('figures', 'fig4.pdf'), bbox_inches='tight') 230 | plt.show() 231 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | matplotlib 4 | jax >=0.2.8 5 | tqdm 6 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # TODO description. 4 | 5 | # Author: Spencer M. Richards 6 | # Autonomous Systems Lab (ASL), Stanford 7 | # (GitHub: spenrich) 8 | 9 | for seed in {0..9} 10 | do 11 | for M in 2 5 10 20 30 40 50 12 | do 13 | echo "seed = $seed, M = $M" 14 | python test_all.py $seed $M 15 | done 16 | done 17 | -------------------------------------------------------------------------------- /test_all.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO description. 3 | 4 | Author: Spencer M. Richards 5 | Autonomous Systems Lab (ASL), Stanford 6 | (GitHub: spenrich) 7 | """ 8 | 9 | import pickle 10 | from math import pi, inf 11 | import os 12 | import argparse 13 | import time 14 | import numpy as np 15 | from itertools import product 16 | from tqdm.auto import tqdm 17 | 18 | 19 | def enumerated_product(*args): 20 | """TODO: docstring.""" 21 | yield from zip( 22 | product(*(range(len(x)) for x in args)), 23 | product(*args) 24 | ) 25 | 26 | 27 | # Parse command line arguments 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('seed', help='seed for pseudo-random number generation', 30 | type=int) 31 | parser.add_argument('M', help='number of trajectories to sub-sample', 32 | type=int) 33 | parser.add_argument('--use_x64', help='use 64-bit precision', 34 | action='store_true') 35 | args = parser.parse_args() 36 | 37 | # Set precision 38 | if args.use_x64: 39 | os.environ['JAX_ENABLE_X64'] = 'True' 40 | 41 | import jax # noqa: E402 42 | import jax.numpy as jnp # noqa: E402 43 | from jax.experimental.ode import odeint # noqa: E402 44 | from utils import random_ragged_spline, spline, params_to_posdef # noqa: E402 45 | from dynamics import prior, plant, disturbance # noqa: E402 46 | 47 | 48 | # Initialize PRNG key (with offset from original seed to make sure we do not 49 | # sample the same reference trajectories in the training set) 50 | key = jax.random.PRNGKey(args.seed + 20) 51 | 52 | hparams = { 53 | 'seed': args.seed, # 54 | 'use_x64': args.use_x64, # 55 | 'num_subtraj': args.M, # number of trajectories sub-sampled 56 | 57 | 'w_min': 0., # minimum wind velocity in inertial `x`-direction 58 | 'w_max': 10., # maximum wind velocity in inertial `x`-direction 59 | 'a': 5., # shape parameter `a` for beta distribution 60 | 'b': 7., # shape parameter `b` for beta distribution 61 | 62 | # Reference trajectory generation 63 | 'T': 10., # time horizon for each reference 64 | 'dt': 1e-2, # numerical integration time step 65 | 'num_refs': 200, # reference trajectories to generate 66 | 'num_knots': 6, # knot points per reference spline 67 | 'poly_orders': (9, 9, 6), # spline orders for each DOF 68 | 'deriv_orders': (4, 4, 2), # smoothness objective for each DOF 69 | 'min_step': (-2., -2., -pi/6), # 70 | 'max_step': (2., 2., pi/6), # 71 | 'min_ref': (-inf, -inf, -pi/3), # 72 | 'max_ref': (inf, inf, pi/3), # 73 | } 74 | 75 | 76 | if __name__ == "__main__": 77 | print('Testing ... ', flush=True) 78 | start = time.time() 79 | 80 | # Generate reference trajectories 81 | key, *subkeys = jax.random.split(key, 1 + hparams['num_refs']) 82 | subkeys = jnp.vstack(subkeys) 83 | in_axes = (0, None, None, None, None, None, None, None, None) 84 | min_ref = jnp.asarray(hparams['min_ref']) 85 | max_ref = jnp.asarray(hparams['max_ref']) 86 | t_knots, knots, coefs = jax.vmap(random_ragged_spline, in_axes)( 87 | subkeys, 88 | hparams['T'], 89 | hparams['num_knots'], 90 | hparams['poly_orders'], 91 | hparams['deriv_orders'], 92 | jnp.asarray(hparams['min_step']), 93 | jnp.asarray(hparams['max_step']), 94 | 0.7*min_ref, 95 | 0.7*max_ref, 96 | ) 97 | r_knots = jnp.dstack(knots) 98 | num_dof = 3 99 | 100 | # Sampled-time simulator 101 | @jax.jit 102 | @jax.partial(jax.vmap, in_axes=(None, 0, 0, 0, None)) 103 | def simulate(ts, w, t_knots, coefs, params, 104 | min_ref=min_ref, max_ref=max_ref, 105 | plant=plant, prior=prior, disturbance=disturbance): 106 | """TODO: docstring.""" 107 | # Construct spline reference trajectory 108 | def reference(t): 109 | r = jnp.array([spline(t, t_knots, c) for c in coefs]) 110 | r = jnp.clip(r, min_ref, max_ref) 111 | return r 112 | 113 | # Required derivatives of the reference trajectory 114 | def ref_derivatives(t): 115 | ref_vel = jax.jacfwd(reference) 116 | ref_acc = jax.jacfwd(ref_vel) 117 | r = reference(t) 118 | dr = ref_vel(t) 119 | ddr = ref_acc(t) 120 | return r, dr, ddr 121 | 122 | # Adaptation law 123 | def adaptation_law(q, dq, r, dr, params=params): 124 | # Regressor features 125 | y = jnp.concatenate((q, dq)) 126 | for W, b in zip(params['W'], params['b']): 127 | y = jnp.tanh(W@y + b) 128 | 129 | # Auxiliary signals 130 | Λ, P = params['Λ'], params['P'] 131 | e, de = q - r, dq - dr 132 | s = de + Λ@e 133 | 134 | dA = P @ jnp.outer(s, y) 135 | return dA, y 136 | 137 | # Controller 138 | def controller(q, dq, r, dr, ddr, f_hat, params=params): 139 | # Auxiliary signals 140 | Λ, K = params['Λ'], params['K'] 141 | e, de = q - r, dq - dr 142 | s = de + Λ@e 143 | v, dv = dr - Λ@e, ddr - Λ@de 144 | 145 | # Control input and adaptation law 146 | H, C, g, B = prior(q, dq) 147 | τ = H@dv + C@v + g - f_hat - K@s 148 | u = jnp.linalg.solve(B, τ) 149 | return u, τ 150 | 151 | # Closed-loop ODE for `x = (q, dq)`, with a zero-order hold on 152 | # the controller 153 | def ode(x, t, u, w=w): 154 | q, dq = x 155 | f_ext = disturbance(q, dq, w) 156 | ddq = plant(q, dq, u, f_ext) 157 | dx = (dq, ddq) 158 | return dx 159 | 160 | # Simulation loop 161 | def loop(carry, input_slice): 162 | t_prev, q_prev, dq_prev, u_prev, A_prev, dA_prev = carry 163 | t = input_slice 164 | qs, dqs = odeint(ode, (q_prev, dq_prev), jnp.array([t_prev, t]), 165 | u_prev) 166 | q, dq = qs[-1], dqs[-1] 167 | r, dr, ddr = ref_derivatives(t) 168 | 169 | # Integrate adaptation law via trapezoidal rule 170 | dA, y = adaptation_law(q, dq, r, dr) 171 | A = A_prev + (t - t_prev)*(dA_prev + dA)/2 172 | 173 | # Compute force estimate and control input 174 | f_hat = A @ y 175 | u, τ = controller(q, dq, r, dr, ddr, f_hat) 176 | 177 | carry = (t, q, dq, u, A, dA) 178 | output_slice = (q, dq, u, τ, r, dr) 179 | return carry, output_slice 180 | 181 | # Initial conditions 182 | t0 = ts[0] 183 | r0, dr0, ddr0 = ref_derivatives(t0) 184 | q0, dq0 = r0, dr0 185 | dA0, y0 = adaptation_law(q0, dq0, r0, dr0) 186 | A0 = jnp.zeros((q0.size, y0.size)) 187 | f0 = A0 @ y0 188 | u0, τ0 = controller(q0, dq0, r0, dr0, ddr0, f0) 189 | 190 | # Run simulation loop 191 | carry = (t0, q0, dq0, u0, A0, dA0) 192 | carry, output = jax.lax.scan(loop, carry, ts[1:]) 193 | q, dq, u, τ, r, dr = output 194 | 195 | # Prepend initial conditions 196 | q = jnp.vstack((q0, q)) 197 | dq = jnp.vstack((dq0, dq)) 198 | u = jnp.vstack((u0, u)) 199 | τ = jnp.vstack((τ0, τ)) 200 | r = jnp.vstack((r0, r)) 201 | dr = jnp.vstack((dr0, dr)) 202 | 203 | return q, dq, u, τ, r, dr 204 | 205 | # Sample wind velocities from the test distribution 206 | w_min = 0. # minimum wind velocity in inertial `x`-direction 207 | w_max = 10. # maximum wind velocity in inertial `x`-direction 208 | a = 5. # shape parameter `a` for beta distribution 209 | b = 7. # shape parameter `b` for beta distribution 210 | key, subkey = jax.random.split(key, 2) 211 | w = w_min + (w_max - w_min)*jax.random.beta(subkey, a, b, 212 | (hparams['num_refs'],)) 213 | 214 | # Simulate tracking for each `w` 215 | T, dt = hparams['T'], hparams['dt'] 216 | ts = jnp.arange(0, T + dt, dt) # same times for each trajectory 217 | 218 | # Try out different gains 219 | test_results = { 220 | 'w': w, 'w_min': w_min, 'w_max': w_max, 221 | 'beta_params': (a, b), 222 | 'gains': { 223 | 'Λ': (1.,), 224 | 'K': (1., 10.), 225 | 'P': (1., 10.), 226 | } 227 | } 228 | grid_shape = (len(test_results['gains']['Λ']), 229 | len(test_results['gains']['K']), 230 | len(test_results['gains']['P'])) 231 | 232 | # Our method with meta-learned gains 233 | print(' ours (meta) ...', flush=True) 234 | filename = os.path.join('train_results', 'ours', 235 | 'seed={}_M={}.pkl'.format(hparams['seed'], 236 | hparams['num_subtraj'])) 237 | with open(filename, 'rb') as file: 238 | train_results = pickle.load(file) 239 | params = { 240 | 'W': train_results['model']['W'], 241 | 'b': train_results['model']['b'], 242 | 'Λ': params_to_posdef(train_results['controller']['Λ']), 243 | 'K': params_to_posdef(train_results['controller']['K']), 244 | 'P': params_to_posdef(train_results['controller']['P']), 245 | } 246 | q, dq, u, τ, r, dr = simulate(ts, w, t_knots, coefs, params) 247 | e = np.concatenate((q - r, dq - dr), axis=-1) 248 | rms_e = np.sqrt(np.mean(np.sum(e**2, axis=-1), axis=-1)) 249 | rms_u = np.sqrt(np.mean(np.sum(u**2, axis=-1), axis=-1)) 250 | test_results['ours_meta'] = { 251 | 'params': params, 252 | 'rms_error': rms_e, 253 | 'rms_ctrl': rms_u, 254 | } 255 | 256 | for method in ('pid', 'lstsq', 'ours'): 257 | test_results[method] = np.empty(grid_shape, dtype=object) 258 | print(' {} ...'.format(method), flush=True) 259 | if method == 'pid': 260 | params = { 261 | 'W': [jnp.zeros((1, 2*num_dof)), ], 262 | 'b': [jnp.inf * jnp.ones((1,)), ], 263 | } 264 | else: 265 | filename = os.path.join( 266 | 'train_results', method, 267 | 'seed={}_M={}.pkl'.format(hparams['seed'], 268 | hparams['num_subtraj']) 269 | ) 270 | with open(filename, 'rb') as file: 271 | train_results = pickle.load(file) 272 | params = { 273 | 'W': train_results['model']['W'], 274 | 'b': train_results['model']['b'], 275 | } 276 | 277 | for (i, j, l), (λ, k, p) in tqdm(enumerated_product( 278 | test_results['gains']['Λ'], 279 | test_results['gains']['K'], 280 | test_results['gains']['P']), total=np.prod(grid_shape) 281 | ): 282 | params['Λ'] = λ * jnp.eye(num_dof) 283 | params['K'] = k * jnp.eye(num_dof) 284 | params['P'] = p * jnp.eye(num_dof) 285 | q, dq, u, τ, r, dr = simulate(ts, w, t_knots, coefs, params) 286 | e = np.concatenate((q - r, dq - dr), axis=-1) 287 | rms_e = np.sqrt(np.mean(np.sum(e**2, axis=-1), axis=-1)) 288 | rms_u = np.sqrt(np.mean(np.sum(u**2, axis=-1), axis=-1)) 289 | test_results[method][i, j, l] = { 290 | 'params': params, 291 | 'rms_error': rms_e, 292 | 'rms_ctrl': rms_u, 293 | } 294 | 295 | # Save 296 | output_filename = os.path.join( 297 | 'test_results', "seed={:d}_M={:d}.pkl".format(hparams['seed'], 298 | hparams['num_subtraj']) 299 | ) 300 | with open(output_filename, 'wb') as file: 301 | pickle.dump(test_results, file) 302 | 303 | end = time.time() 304 | print('done! ({:.2f} s)'.format(end - start)) 305 | -------------------------------------------------------------------------------- /test_results/seed=0_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=0_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=0_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=0_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=0_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=0_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=0_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=0_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=0_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=0_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=0_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=0_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=0_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=0_M=50.pkl -------------------------------------------------------------------------------- /test_results/seed=1_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=1_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=1_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=1_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=1_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=1_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=1_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=1_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=1_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=1_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=1_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=1_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=1_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=1_M=50.pkl -------------------------------------------------------------------------------- /test_results/seed=2_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=2_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=2_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=2_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=2_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=2_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=2_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=2_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=2_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=2_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=2_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=2_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=2_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=2_M=50.pkl -------------------------------------------------------------------------------- /test_results/seed=3_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=3_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=3_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=3_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=3_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=3_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=3_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=3_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=3_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=3_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=3_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=3_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=3_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=3_M=50.pkl -------------------------------------------------------------------------------- /test_results/seed=4_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=4_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=4_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=4_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=4_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=4_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=4_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=4_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=4_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=4_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=4_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=4_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=4_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=4_M=50.pkl -------------------------------------------------------------------------------- /test_results/seed=5_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=5_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=5_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=5_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=5_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=5_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=5_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=5_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=5_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=5_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=5_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=5_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=5_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=5_M=50.pkl -------------------------------------------------------------------------------- /test_results/seed=6_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=6_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=6_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=6_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=6_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=6_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=6_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=6_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=6_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=6_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=6_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=6_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=6_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=6_M=50.pkl -------------------------------------------------------------------------------- /test_results/seed=7_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=7_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=7_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=7_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=7_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=7_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=7_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=7_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=7_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=7_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=7_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=7_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=7_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=7_M=50.pkl -------------------------------------------------------------------------------- /test_results/seed=8_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=8_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=8_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=8_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=8_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=8_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=8_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=8_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=8_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=8_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=8_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=8_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=8_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=8_M=50.pkl -------------------------------------------------------------------------------- /test_results/seed=9_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=9_M=10.pkl -------------------------------------------------------------------------------- /test_results/seed=9_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=9_M=2.pkl -------------------------------------------------------------------------------- /test_results/seed=9_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=9_M=20.pkl -------------------------------------------------------------------------------- /test_results/seed=9_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=9_M=30.pkl -------------------------------------------------------------------------------- /test_results/seed=9_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=9_M=40.pkl -------------------------------------------------------------------------------- /test_results/seed=9_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=9_M=5.pkl -------------------------------------------------------------------------------- /test_results/seed=9_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results/seed=9_M=50.pkl -------------------------------------------------------------------------------- /test_results_single.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/test_results_single.pkl -------------------------------------------------------------------------------- /test_single.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO description. 3 | 4 | Author: Spencer M. Richards 5 | Autonomous Systems Lab (ASL), Stanford 6 | (GitHub: spenrich) 7 | """ 8 | 9 | import pickle 10 | import os 11 | import argparse 12 | import time 13 | import numpy as np 14 | 15 | # Parse command line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--use_x64', help='use 64-bit precision', 18 | action='store_true') 19 | args = parser.parse_args() 20 | 21 | # Set precision 22 | if args.use_x64: 23 | os.environ['JAX_ENABLE_X64'] = 'True' 24 | 25 | import jax # noqa: E402 26 | import jax.numpy as jnp # noqa: E402 27 | from jax.experimental.ode import odeint # noqa: E402 28 | from utils import params_to_posdef # noqa: E402 29 | from dynamics import prior, plant, disturbance # noqa: E402 30 | 31 | # Uncomment this line to force using the CPU 32 | jax.config.update('jax_platform_name', 'cpu') # TODO: keep or remove? 33 | 34 | if __name__ == "__main__": 35 | print('Testing ... ', flush=True) 36 | start = time.time() 37 | seed, M = 1, 10 38 | 39 | # Sampled-time simulator 40 | @jax.partial(jax.jit, static_argnums=(3,)) 41 | def simulate(ts, w, params, reference, 42 | plant=plant, prior=prior, disturbance=disturbance): 43 | """TODO: docstring.""" 44 | # Required derivatives of the reference trajectory 45 | def ref_derivatives(t): 46 | ref_vel = jax.jacfwd(reference) 47 | ref_acc = jax.jacfwd(ref_vel) 48 | r = reference(t) 49 | dr = ref_vel(t) 50 | ddr = ref_acc(t) 51 | return r, dr, ddr 52 | 53 | # Adaptation law 54 | def adaptation_law(q, dq, r, dr, params=params): 55 | # Regressor features 56 | y = jnp.concatenate((q, dq)) 57 | for W, b in zip(params['W'], params['b']): 58 | y = jnp.tanh(W@y + b) 59 | 60 | # Auxiliary signals 61 | Λ, P = params['Λ'], params['P'] 62 | e, de = q - r, dq - dr 63 | s = de + Λ@e 64 | 65 | dA = P @ jnp.outer(s, y) 66 | return dA, y 67 | 68 | # Controller 69 | def controller(q, dq, r, dr, ddr, f_hat, params=params): 70 | # Auxiliary signals 71 | Λ, K = params['Λ'], params['K'] 72 | e, de = q - r, dq - dr 73 | s = de + Λ@e 74 | v, dv = dr - Λ@e, ddr - Λ@de 75 | 76 | # Control input and adaptation law 77 | H, C, g, B = prior(q, dq) 78 | τ = H@dv + C@v + g - f_hat - K@s 79 | u = jnp.linalg.solve(B, τ) 80 | return u, τ 81 | 82 | # Closed-loop ODE for `x = (q, dq)`, with a zero-order hold on 83 | # the controller 84 | def ode(x, t, u, w=w): 85 | q, dq = x 86 | f_ext = disturbance(q, dq, w) 87 | ddq = plant(q, dq, u, f_ext) 88 | dx = (dq, ddq) 89 | return dx 90 | 91 | # Simulation loop 92 | def loop(carry, input_slice): 93 | t_prev, q_prev, dq_prev, u_prev, A_prev, dA_prev = carry 94 | t = input_slice 95 | qs, dqs = odeint(ode, (q_prev, dq_prev), jnp.array([t_prev, t]), 96 | u_prev) 97 | q, dq = qs[-1], dqs[-1] 98 | r, dr, ddr = ref_derivatives(t) 99 | 100 | # Integrate adaptation law via trapezoidal rule 101 | dA, y = adaptation_law(q, dq, r, dr) 102 | A = A_prev + (t - t_prev)*(dA_prev + dA)/2 103 | 104 | # Compute force estimate and control input 105 | f_hat = A @ y 106 | u, τ = controller(q, dq, r, dr, ddr, f_hat) 107 | 108 | carry = (t, q, dq, u, A, dA) 109 | output_slice = (q, dq, u, τ, r, dr) 110 | return carry, output_slice 111 | 112 | # Initial conditions 113 | t0 = ts[0] 114 | r0, dr0, ddr0 = ref_derivatives(t0) 115 | q0, dq0 = r0, dr0 116 | dA0, y0 = adaptation_law(q0, dq0, r0, dr0) 117 | A0 = jnp.zeros((q0.size, y0.size)) 118 | f0 = A0 @ y0 119 | u0, τ0 = controller(q0, dq0, r0, dr0, ddr0, f0) 120 | 121 | # Run simulation loop 122 | carry = (t0, q0, dq0, u0, A0, dA0) 123 | carry, output = jax.lax.scan(loop, carry, ts[1:]) 124 | q, dq, u, τ, r, dr = output 125 | 126 | # Prepend initial conditions 127 | q = jnp.vstack((q0, q)) 128 | dq = jnp.vstack((dq0, dq)) 129 | u = jnp.vstack((u0, u)) 130 | τ = jnp.vstack((τ0, τ)) 131 | r = jnp.vstack((r0, r)) 132 | dr = jnp.vstack((dr0, dr)) 133 | 134 | return q, dq, u, τ, r, dr 135 | 136 | # Construct a trajectory 137 | def reference(t): 138 | """TODO: docstring.""" 139 | T = 10. # loop period 140 | d = 4. # displacement along `x` from `t=0` to `t=T` 141 | w = 4. # loop width 142 | h = 6. # loop height 143 | ϕ_max = jnp.pi/3 # maximum roll angle (achieved at top of loop) 144 | 145 | x = (w/2)*jnp.sin(2*jnp.pi * t/T) + d*(t/T) 146 | y = (h/2)*(1 - jnp.cos(2*jnp.pi * t/T)) 147 | ϕ = 4*ϕ_max*(t/T)*(1-t/T) 148 | r = jnp.array([x, y, ϕ]) 149 | return r 150 | 151 | # Choose a wind velocity, fixed control gains, and simulation times 152 | num_dof = 3 153 | w = 6.5 154 | λ, k, p = 1., 10., 10. 155 | T, dt = 10., 0.01 156 | ts = jnp.arange(0, T + dt, dt) 157 | 158 | # Simulate tracking for each method 159 | test_results = { 160 | 'w': w, 161 | 'gains': (λ, k, p), 162 | } 163 | 164 | # Our method with meta-learned gains 165 | print(' ours (meta) ...', flush=True) 166 | filename = os.path.join('train_results', 'ours', 167 | 'seed={}_M={}.pkl'.format(seed, M)) 168 | with open(filename, 'rb') as file: 169 | train_results = pickle.load(file) 170 | params = { 171 | 'W': train_results['model']['W'], 172 | 'b': train_results['model']['b'], 173 | 'Λ': params_to_posdef(train_results['controller']['Λ']), 174 | 'K': params_to_posdef(train_results['controller']['K']), 175 | 'P': params_to_posdef(train_results['controller']['P']), 176 | } 177 | q, dq, u, τ, r, dr = simulate(ts, w, params, reference) 178 | e = np.concatenate((q - r, dq - dr), axis=-1) 179 | test_results['ours_meta'] = { 180 | 'params': params, 181 | 't': ts, 'q': q, 'dq': dq, 'r': r, 'dr': dr, 182 | 'u': u, 'τ': τ, 'e': e, 183 | } 184 | 185 | for method in ('pid', 'lstsq', 'ours'): 186 | print(' {} ...'.format(method), flush=True) 187 | if method == 'pid': 188 | params = { 189 | 'W': [jnp.zeros((1, 2*num_dof)), ], 190 | 'b': [jnp.inf * jnp.ones((1,)), ], 191 | } 192 | else: 193 | filename = os.path.join( 194 | 'train_results', method, 195 | 'seed={}_M={}.pkl'.format(seed, M) 196 | ) 197 | with open(filename, 'rb') as file: 198 | train_results = pickle.load(file) 199 | params = { 200 | 'W': train_results['model']['W'], 201 | 'b': train_results['model']['b'], 202 | } 203 | params['Λ'] = λ * jnp.eye(num_dof) 204 | params['K'] = k * jnp.eye(num_dof) 205 | params['P'] = p * jnp.eye(num_dof) 206 | q, dq, u, τ, r, dr = simulate(ts, w, params, reference) 207 | e = np.concatenate((q - r, dq - dr), axis=-1) 208 | test_results[method] = { 209 | 'params': params, 210 | 't': ts, 'q': q, 'dq': dq, 'r': r, 'dr': dr, 211 | 'u': u, 'τ': τ, 'e': e, 212 | } 213 | 214 | # Save 215 | with open('test_results_single.pkl', 'wb') as file: 216 | pickle.dump(test_results, file) 217 | 218 | end = time.time() 219 | print('done! ({:.2f} s)'.format(end - start)) 220 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # TODO description. 4 | 5 | # Author: Spencer M. Richards 6 | # Autonomous Systems Lab (ASL), Stanford 7 | # (GitHub: spenrich) 8 | 9 | for seed in {0..9} 10 | do 11 | for M in 2 5 10 20 30 40 50 12 | do 13 | echo "seed = $seed, M = $M" 14 | 15 | echo "Meta-ridge-regression:" 16 | python train_lstsq.py $seed $M 17 | 18 | echo "Ours:" 19 | python train_ours.py $seed $M 20 | done 21 | done 22 | -------------------------------------------------------------------------------- /train_lstsq.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO description. 3 | 4 | Author: Spencer M. Richards 5 | Autonomous Systems Lab (ASL), Stanford 6 | (GitHub: spenrich) 7 | """ 8 | 9 | from tqdm.auto import tqdm 10 | import pickle 11 | import time 12 | import warnings 13 | import os 14 | import argparse 15 | 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('seed', help='seed for pseudo-random number generation', 19 | type=int) 20 | parser.add_argument('M', help='number of trajectories to sub-sample', 21 | type=int) 22 | parser.add_argument('--use_x64', help='use 64-bit precision', 23 | action='store_true') 24 | args = parser.parse_args() 25 | 26 | # Set precision 27 | if args.use_x64: 28 | os.environ['JAX_ENABLE_X64'] = 'True' 29 | 30 | import jax # noqa: E402 31 | import jax.numpy as jnp # noqa: E402 32 | from jax.experimental import optimizers # noqa: E402 33 | from utils import tree_normsq # noqa: E402 34 | from dynamics import prior # noqa: E402 35 | 36 | 37 | # Initialize PRNG key 38 | key = jax.random.PRNGKey(args.seed) 39 | 40 | # Hyperparameters 41 | hparams = { 42 | 'seed': args.seed, # 43 | 'use_x64': args.use_x64, # 44 | 'num_subtraj': args.M, # number of trajectories to sub-sample 45 | 'num_hlayers': 2, # number of hidden layers 46 | 'hdim': 32, # number of hidden units per layer 47 | 'train_frac': 0.75, # fraction per trajectory for training 48 | 'ridge_frac': 0.25, # (fraction of samples used in the ridge 49 | # regression solution per trajectory) 50 | 'regularizer_l2': 1e-4, # coefficient for L2-regularization 51 | 'regularizer_ridge': 1e-6, # (coefficient for L2-regularization of 52 | # least-squares solution) 53 | 'learning_rate': 1e-2, # step size for gradient optimization 54 | 'num_steps': 5000, # number of epochs 55 | } 56 | 57 | 58 | if __name__ == "__main__": 59 | # DATA PROCESSING ######################################################## 60 | # Load raw data and arrange in samples of the form 61 | # `(t, x, u, t_next, x_next)` for each trajectory, where `x := (q,dq)` 62 | with open('training_data.pkl', 'rb') as file: 63 | raw = pickle.load(file) 64 | num_dof = raw['q'].shape[-1] # number of degrees of freedom 65 | num_traj = raw['q'].shape[0] # total number of raw trajectories 66 | num_samples = raw['t'].size - 1 # number of transitions per trajectory 67 | t = jnp.tile(raw['t'][:-1], (num_traj, 1)) 68 | t_next = jnp.tile(raw['t'][1:], (num_traj, 1)) 69 | x = jnp.concatenate((raw['q'][:, :-1], raw['dq'][:, :-1]), axis=-1) 70 | x_next = jnp.concatenate((raw['q'][:, 1:], raw['dq'][:, 1:]), axis=-1) 71 | u = raw['u'][:, :-1] 72 | 73 | data = {'t': t, 'x': x, 'u': u, 't_next': t_next, 'x_next': x_next} 74 | 75 | # Shuffle and sub-sample trajectories 76 | if hparams['num_subtraj'] > num_traj: 77 | warnings.warn('Cannot sub-sample {:d} trajectories! ' 78 | 'Capping at {:d}.'.format(hparams['num_subtraj'], 79 | num_traj)) 80 | hparams['num_subtraj'] = num_traj 81 | 82 | key, subkey = jax.random.split(key, 2) 83 | shuffled_idx = jax.random.permutation(subkey, num_traj) 84 | hparams['subtraj_idx'] = shuffled_idx[:hparams['num_subtraj']] 85 | data = jax.tree_util.tree_map( 86 | lambda a: jnp.take(a, hparams['subtraj_idx'], axis=0), 87 | data 88 | ) 89 | 90 | # META-TRAIN MODEL ####################################################### 91 | # Map over time index 92 | @jax.partial(jax.vmap, in_axes=(None, 0, 0, 0, 0, 0)) 93 | def lstsq_coefs(params, t, x, u, t_next, x_next, prior=prior): 94 | """TODO: docstring.""" 95 | num_dof = x.size // 2 96 | q, dq = x[:num_dof], x[num_dof:] 97 | dq_next = x_next[num_dof:] 98 | H, C, g, B = prior(q, dq) 99 | 100 | # Regressor 101 | phi = x 102 | for W, b in zip(params['W'], params['b']): 103 | phi = jnp.tanh(W@phi + b) 104 | 105 | # Euler integration of the dynamics from `t` to `t_next` yields 106 | # a linear equation of the form `A@z = b`, where `A` is the last 107 | # layer applied to our regressor 108 | dt = t_next - t 109 | z = dt*phi 110 | b = H@(dq_next - dq) + dt*(C@dq + g - B@u) 111 | return z, b 112 | 113 | # Map over trajectory index 114 | @jax.partial(jax.vmap, in_axes=(None, 0, None, None, 0, 0, 0, 0, 0)) 115 | def trajectory_loss(params, key, num_ridge_samples, regularizer_ridge, 116 | t, x, u, t_next, x_next): 117 | """TODO: docstring.""" 118 | # Compute least-squares coefficients and shuffle them 119 | Z, B = lstsq_coefs(params, t, x, u, t_next, x_next) 120 | num_samples, num_features = Z.shape 121 | idx = jax.random.permutation(key, num_samples) 122 | Z = Z[idx] 123 | B = B[idx] 124 | 125 | # Solve for the last layer as the least-squares solution 126 | # on a subset of the data 127 | Z_ls = Z[:num_ridge_samples] 128 | B_ls = B[:num_ridge_samples] 129 | ZTZ_λI = (Z_ls.T@Z_ls).at[jnp.diag_indices(num_features)].add( 130 | regularizer_ridge 131 | ) 132 | ZTB = Z_ls.T@B_ls 133 | AT = jax.scipy.linalg.solve(ZTZ_λI/num_ridge_samples, 134 | ZTB/num_ridge_samples, sym_pos=True) 135 | 136 | # Compute loss on ALL of the data 137 | loss = jnp.sum((Z@AT - B)**2) 138 | return loss 139 | 140 | @jax.partial(jax.jit, static_argnums=(3,)) 141 | def loss(params, regularizer_l2, keys, num_ridge_samples, 142 | regularizer_ridge, t, x, u, t_next, x_next): 143 | """TODO: docstring.""" 144 | num_traj, num_samples = t.shape 145 | normalizer = num_traj * num_samples 146 | traj_losses = trajectory_loss(params, keys, num_ridge_samples, 147 | regularizer_ridge, 148 | t, x, u, t_next, x_next) 149 | loss = (jnp.sum(traj_losses) 150 | + regularizer_l2*tree_normsq(params)) / normalizer 151 | return loss 152 | 153 | # Initialize model parameters 154 | num_hlayers = hparams['num_hlayers'] 155 | hdim = hparams['hdim'] 156 | if num_hlayers >= 1: 157 | shapes = [(hdim, 2*num_dof), ] + (num_hlayers-1)*[(hdim, hdim), ] 158 | else: 159 | shapes = [] 160 | key, *subkeys = jax.random.split(key, 1 + 2*num_hlayers) 161 | keys_W = subkeys[:num_hlayers] 162 | keys_b = subkeys[num_hlayers:] 163 | params = { 164 | # hidden layer weights 165 | 'W': [0.1*jax.random.normal(keys_W[i], shapes[i]) 166 | for i in range(num_hlayers)], 167 | # hidden layer biases 168 | 'b': [0.1*jax.random.normal(keys_b[i], (shapes[i][0],)) 169 | for i in range(num_hlayers)], 170 | } 171 | 172 | # Shuffle samples in time along each trajectory, then split each 173 | # trajectory into training and validation sets 174 | key, *subkeys = jax.random.split(key, 1 + hparams['num_subtraj']) 175 | subkeys = jnp.asarray(subkeys) 176 | shuffled_data = jax.tree_util.tree_map( 177 | lambda a: jax.vmap(jax.random.permutation)(subkeys, a), 178 | data 179 | ) 180 | num_train_samples = int(hparams['train_frac'] * num_samples) 181 | num_valid_samples = num_samples - num_train_samples 182 | train_data = jax.tree_util.tree_map(lambda a: a[:, :num_train_samples], 183 | shuffled_data) 184 | valid_data = jax.tree_util.tree_map(lambda a: a[:, num_train_samples:], 185 | shuffled_data) 186 | 187 | # Initialize gradient-based optimizer (ADAM) 188 | num_ridge_samples = int(hparams['ridge_frac']*num_train_samples) 189 | learning_rate = hparams['learning_rate'] 190 | init_opt, update_opt, get_params = optimizers.adam(learning_rate) 191 | opt_state = init_opt(params) 192 | step_idx = 0 193 | best_idx = 0 194 | best_loss = jnp.inf 195 | best_params = params 196 | 197 | @jax.partial(jax.jit, static_argnums=(4,)) 198 | def step(idx, opt_state, regularizer_l2, keys, num_ridge_samples, 199 | regularizer_ridge, batch): 200 | """TODO: docstring.""" 201 | params = get_params(opt_state) 202 | grads = jax.grad(loss, argnums=0)(params, regularizer_l2, keys, 203 | num_ridge_samples, 204 | regularizer_ridge, **batch) 205 | opt_state = update_opt(idx, grads, opt_state) 206 | return opt_state 207 | 208 | # Pre-compile before training 209 | print('MODEL META-TRAINING: Pre-compiling ... ', end='', flush=True) 210 | start = time.time() 211 | _ = step(step_idx, opt_state, hparams['regularizer_l2'], 212 | subkeys, num_ridge_samples, 213 | hparams['regularizer_ridge'], train_data) 214 | _ = loss(params, 0., subkeys, num_valid_samples, 215 | hparams['regularizer_ridge'], **valid_data) 216 | end = time.time() 217 | print('done ({:.2f} s)! Now training ...'.format(end - start)) 218 | start = time.time() 219 | 220 | # Do gradient descent 221 | for _ in tqdm(range(hparams['num_steps'])): 222 | key, *subkeys = jax.random.split(key, 1 + hparams['num_subtraj']) 223 | subkeys = jnp.asarray(subkeys) 224 | opt_state = step(step_idx, opt_state, hparams['regularizer_l2'], 225 | subkeys, num_ridge_samples, 226 | hparams['regularizer_ridge'], train_data) 227 | new_params = get_params(opt_state) 228 | new_loss = loss(new_params, 0., subkeys, num_valid_samples, 229 | hparams['regularizer_ridge'], **valid_data) 230 | step_idx += 1 231 | if new_loss < best_loss: 232 | best_idx = step_idx 233 | best_loss = new_loss 234 | best_params = new_params 235 | 236 | # Save hyperparameters and model 237 | results = { 238 | 'best_step_idx': best_idx, 239 | 'hparams': hparams, 240 | 'model': best_params 241 | } 242 | output_name = "seed={:d}_M={:d}".format( 243 | hparams['seed'], hparams['num_subtraj'] 244 | ) 245 | output_path = os.path.join('train_results', 'lstsq', output_name + '.pkl') 246 | with open(output_path, 'wb') as file: 247 | pickle.dump(results, file) 248 | 249 | end = time.time() 250 | print('done ({:.2f} s)! Best step index: {}'.format(end - start, 251 | best_idx)) 252 | -------------------------------------------------------------------------------- /train_ours.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO description. 3 | 4 | Author: Spencer M. Richards 5 | Autonomous Systems Lab (ASL), Stanford 6 | (GitHub: spenrich) 7 | """ 8 | 9 | from tqdm.auto import tqdm 10 | import pickle 11 | import time 12 | import warnings 13 | from math import pi, inf 14 | import os 15 | import argparse 16 | 17 | # Parse command line arguments 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('seed', help='seed for pseudo-random number generation', 20 | type=int) 21 | parser.add_argument('M', help='number of trajectories to sub-sample', 22 | type=int) 23 | parser.add_argument('--use_x64', help='use 64-bit precision', 24 | action='store_true') 25 | args = parser.parse_args() 26 | 27 | # Set precision 28 | if args.use_x64: 29 | os.environ['JAX_ENABLE_X64'] = 'True' 30 | 31 | import jax # noqa: E402 32 | import jax.numpy as jnp # noqa: E402 33 | from jax.experimental import optimizers # noqa: E402 34 | from dynamics import prior # noqa: E402 35 | from utils import (tree_normsq, rk38_step, epoch, # noqa: E402 36 | odeint_fixed_step, random_ragged_spline, spline, 37 | params_to_cholesky, params_to_posdef) 38 | 39 | 40 | # Initialize PRNG key 41 | key = jax.random.PRNGKey(args.seed) 42 | 43 | # Hyperparameters 44 | hparams = { 45 | 'seed': args.seed, # 46 | 'use_x64': args.use_x64, # 47 | 'num_subtraj': args.M, # number of trajectories to sub-sample 48 | 49 | # For training the model ensemble 50 | 'ensemble': { 51 | 'num_hlayers': 2, # number of hidden layers in each model 52 | 'hdim': 32, # number of hidden units per layer 53 | 'train_frac': 0.75, # fraction of each trajectory for training 54 | 'batch_frac': 0.25, # fraction of training data per batch 55 | 'regularizer_l2': 1e-4, # coefficient for L2-regularization 56 | 'learning_rate': 1e-2, # step size for gradient optimization 57 | 'num_epochs': 1000, # number of epochs 58 | }, 59 | # For meta-training 60 | 'meta': { 61 | 'num_hlayers': 2, # number of hidden layers 62 | 'hdim': 32, # number of hidden units per layer 63 | 'train_frac': 0.75, # 64 | 'learning_rate': 1e-2, # step size for gradient optimization 65 | 'num_steps': 500, # maximum number of gradient steps 66 | 'regularizer_l2': 1e-4, # coefficient for L2-regularization 67 | 'regularizer_ctrl': 1e-3, # 68 | 'regularizer_error': 0., # 69 | 'T': 5., # time horizon for each reference 70 | 'dt': 1e-2, # time step for numerical integration 71 | 'num_refs': 10, # reference trajectories to generate 72 | 'num_knots': 6, # knot points per reference spline 73 | 'poly_orders': (9, 9, 6), # spline orders for each DOF 74 | 'deriv_orders': (4, 4, 2), # smoothness objective for each DOF 75 | 'min_step': (-2., -2., -pi/6), # 76 | 'max_step': (2., 2., pi/6), # 77 | 'min_ref': (-inf, -inf, -pi/3), # 78 | 'max_ref': (inf, inf, pi/3), # 79 | }, 80 | } 81 | 82 | if __name__ == "__main__": 83 | # DATA PROCESSING ######################################################## 84 | # Load raw data and arrange in samples of the form 85 | # `(t, x, u, t_next, x_next)` for each trajectory, where `x := (q,dq)` 86 | with open('training_data.pkl', 'rb') as file: 87 | raw = pickle.load(file) 88 | num_dof = raw['q'].shape[-1] # number of degrees of freedom 89 | num_traj = raw['q'].shape[0] # total number of raw trajectories 90 | num_samples = raw['t'].size - 1 # number of transitions per trajectory 91 | t = jnp.tile(raw['t'][:-1], (num_traj, 1)) 92 | t_next = jnp.tile(raw['t'][1:], (num_traj, 1)) 93 | x = jnp.concatenate((raw['q'][:, :-1], raw['dq'][:, :-1]), axis=-1) 94 | x_next = jnp.concatenate((raw['q'][:, 1:], raw['dq'][:, 1:]), axis=-1) 95 | u = raw['u'][:, :-1] 96 | data = {'t': t, 'x': x, 'u': u, 't_next': t_next, 'x_next': x_next} 97 | 98 | # Shuffle and sub-sample trajectories 99 | if hparams['num_subtraj'] > num_traj: 100 | warnings.warn('Cannot sub-sample {:d} trajectories! ' 101 | 'Capping at {:d}.'.format(hparams['num_subtraj'], 102 | num_traj)) 103 | hparams['num_subtraj'] = num_traj 104 | 105 | key, subkey = jax.random.split(key, 2) 106 | shuffled_idx = jax.random.permutation(subkey, num_traj) 107 | hparams['subtraj_idx'] = shuffled_idx[:hparams['num_subtraj']] 108 | data = jax.tree_util.tree_map( 109 | lambda a: jnp.take(a, hparams['subtraj_idx'], axis=0), 110 | data 111 | ) 112 | 113 | # MODEL ENSEMBLE TRAINING ################################################ 114 | # Loss function along a trajectory 115 | def ode(x, t, u, params, prior=prior): 116 | """TODO: docstring.""" 117 | num_dof = x.size // 2 118 | q, dq = x[:num_dof], x[num_dof:] 119 | H, C, g, B = prior(q, dq) 120 | 121 | # Each model in the ensemble is a feed-forward neural network 122 | # with zero output bias 123 | f = x 124 | for W, b in zip(params['W'], params['b']): 125 | f = jnp.tanh(W@f + b) 126 | f = params['A'] @ f 127 | ddq = jax.scipy.linalg.solve(H, B@u + f - C@dq - g, sym_pos=True) 128 | dx = jnp.concatenate((dq, ddq)) 129 | return dx 130 | 131 | def loss(params, regularizer, t, x, u, t_next, x_next, ode=ode): 132 | """TODO: docstring.""" 133 | num_samples = t.size 134 | dt = t_next - t 135 | x_next_est = jax.vmap(rk38_step, (None, 0, 0, 0, 0, None))( 136 | ode, dt, x, t, u, params 137 | ) 138 | loss = (jnp.sum((x_next_est - x_next)**2) 139 | + regularizer*tree_normsq(params)) / num_samples 140 | return loss 141 | 142 | # Parallel updates for each model in the ensemble 143 | @jax.partial(jax.jit, static_argnums=(4, 5)) 144 | @jax.partial(jax.vmap, in_axes=(None, 0, None, 0, None, None)) 145 | def step(idx, opt_state, regularizer, batch, get_params, update_opt, 146 | loss=loss): 147 | """TODO: docstring.""" 148 | params = get_params(opt_state) 149 | grads = jax.grad(loss, argnums=0)(params, regularizer, **batch) 150 | opt_state = update_opt(idx, grads, opt_state) 151 | return opt_state 152 | 153 | @jax.jit 154 | @jax.vmap 155 | def update_best_ensemble(old_params, old_loss, new_params, batch): 156 | """TODO: docstring.""" 157 | new_loss = loss(new_params, 0., **batch) # do not regularize 158 | best_params = jax.tree_util.tree_multimap( 159 | lambda x, y: jnp.where(new_loss < old_loss, x, y), 160 | new_params, 161 | old_params 162 | ) 163 | best_loss = jnp.where(new_loss < old_loss, new_loss, old_loss) 164 | return best_params, best_loss, new_loss 165 | 166 | # Initialize model parameters 167 | num_models = hparams['num_subtraj'] # one model per trajectory 168 | num_hlayers = hparams['ensemble']['num_hlayers'] 169 | hdim = hparams['ensemble']['hdim'] 170 | if num_hlayers >= 1: 171 | shapes = [(hdim, 2*num_dof), ] + (num_hlayers-1)*[(hdim, hdim), ] 172 | else: 173 | shapes = [] 174 | key, *subkeys = jax.random.split(key, 1 + 2*num_hlayers + 1) 175 | keys_W = subkeys[:num_hlayers] 176 | keys_b = subkeys[num_hlayers:-1] 177 | key_A = subkeys[-1] 178 | ensemble = { 179 | # hidden layer weights 180 | 'W': [0.1*jax.random.normal(keys_W[i], (num_models, *shapes[i])) 181 | for i in range(num_hlayers)], 182 | # hidden layer biases 183 | 'b': [0.1*jax.random.normal(keys_b[i], (num_models, shapes[i][0])) 184 | for i in range(num_hlayers)], 185 | # last layer weights 186 | 'A': 0.1*jax.random.normal(key_A, (num_models, num_dof, hdim)) 187 | } 188 | 189 | # Shuffle samples in time along each trajectory, then split each 190 | # trajectory into training and validation sets (i.e., for each model) 191 | key, *subkeys = jax.random.split(key, 1 + num_models) 192 | subkeys = jnp.asarray(subkeys) 193 | shuffled_data = jax.tree_util.tree_map( 194 | lambda a: jax.vmap(jax.random.permutation)(subkeys, a), 195 | data 196 | ) 197 | num_train_samples = int(hparams['ensemble']['train_frac'] * num_samples) 198 | ensemble_train_data = jax.tree_util.tree_map( 199 | lambda a: a[:, :num_train_samples], 200 | shuffled_data 201 | ) 202 | ensemble_valid_data = jax.tree_util.tree_map( 203 | lambda a: a[:, num_train_samples:], 204 | shuffled_data 205 | ) 206 | 207 | # Initialize gradient-based optimizer (ADAM) 208 | learning_rate = hparams['ensemble']['learning_rate'] 209 | batch_size = int(hparams['ensemble']['batch_frac'] * num_train_samples) 210 | num_batches = num_train_samples // batch_size 211 | init_opt, update_opt, get_params = optimizers.adam(learning_rate) 212 | opt_states = jax.vmap(init_opt)(ensemble) 213 | get_ensemble = jax.jit(jax.vmap(get_params)) 214 | step_idx = 0 215 | best_idx = jnp.zeros(num_models) 216 | 217 | # Pre-compile before training 218 | print('ENSEMBLE TRAINING: Pre-compiling ... ', end='', flush=True) 219 | start = time.time() 220 | batch = next(epoch(key, ensemble_train_data, batch_size, 221 | batch_axis=1, ragged=False)) 222 | _ = step(step_idx, opt_states, hparams['ensemble']['regularizer_l2'], 223 | batch, get_params, update_opt) 224 | inf_losses = jnp.broadcast_to(jnp.inf, (num_models,)) 225 | best_ensemble, best_losses, _ = update_best_ensemble(ensemble, 226 | inf_losses, 227 | ensemble, 228 | ensemble_valid_data) 229 | _ = get_ensemble(opt_states) 230 | end = time.time() 231 | print('done ({:.2f} s)!'.format(end - start)) 232 | 233 | # Do gradient descent 234 | for _ in tqdm(range(hparams['ensemble']['num_epochs'])): 235 | key, subkey = jax.random.split(key, 2) 236 | for batch in epoch(subkey, ensemble_train_data, batch_size, 237 | batch_axis=1, ragged=False): 238 | opt_states = step(step_idx, opt_states, 239 | hparams['ensemble']['regularizer_l2'], 240 | batch, get_params, update_opt) 241 | new_ensemble = get_ensemble(opt_states) 242 | old_losses = best_losses 243 | best_ensemble, best_losses, valid_losses = update_best_ensemble( 244 | best_ensemble, best_losses, new_ensemble, batch 245 | ) 246 | step_idx += 1 247 | best_idx = jnp.where(old_losses == best_losses, 248 | best_idx, step_idx) 249 | 250 | # META-TRAINING ########################################################## 251 | def ode(z, t, meta_params, params, reference, prior=prior): 252 | """TODO: docstring.""" 253 | x, A, c = z 254 | num_dof = x.size // 2 255 | q, dq = x[:num_dof], x[num_dof:] 256 | r = reference(t) 257 | dr = jax.jacfwd(reference)(t) 258 | ddr = jax.jacfwd(jax.jacfwd(reference))(t) 259 | 260 | # Regressor features 261 | y = x 262 | for W, b in zip(meta_params['W'], meta_params['b']): 263 | y = jnp.tanh(W@y + b) 264 | 265 | # Parameterized control and adaptation gains 266 | gains = jax.tree_util.tree_map( 267 | lambda x: params_to_posdef(x), 268 | meta_params['gains'] 269 | ) 270 | Λ, K, P = gains['Λ'], gains['K'], gains['P'] 271 | 272 | # Auxiliary signals 273 | e, de = q - r, dq - dr 274 | v, dv = dr - Λ@e, ddr - Λ@de 275 | s = de + Λ@e 276 | 277 | # Controller and adaptation law 278 | H, C, g, B = prior(q, dq) 279 | f_hat = A@y 280 | τ = H@dv + C@v + g - f_hat - K@s 281 | u = jnp.linalg.solve(B, τ) 282 | dA = P @ jnp.outer(s, y) 283 | 284 | # Apply control to "true" dynamics 285 | f = x 286 | for W, b in zip(params['W'], params['b']): 287 | f = jnp.tanh(W@f + b) 288 | f = params['A'] @ f 289 | ddq = jax.scipy.linalg.solve(H, τ + f - C@dq - g, sym_pos=True) 290 | dx = jnp.concatenate((dq, ddq)) 291 | 292 | # Estimation loss 293 | # chol_P = params_to_cholesky(meta_params['gains']['P']) 294 | # f_error = f_hat - f 295 | # loss_est = f_error@jax.scipy.linalg.cho_solve((chol_P, True), 296 | # f_error) 297 | 298 | # Integrated cost terms 299 | dc = jnp.array([ 300 | e@e + de@de, # tracking loss 301 | u@u, # control loss 302 | (f_hat - f)@(f_hat - f), # estimation loss 303 | ]) 304 | 305 | # Assemble derivatives 306 | dz = (dx, dA, dc) 307 | return dz 308 | 309 | # Simulate adaptive control loop on each model in the ensemble 310 | def ensemble_sim(meta_params, ensemble_params, reference, T, dt, ode=ode): 311 | """TODO: docstring.""" 312 | # Initial conditions 313 | r0 = reference(0.) 314 | dr0 = jax.jacfwd(reference)(0.) 315 | num_dof = r0.size 316 | num_features = meta_params['W'][-1].shape[0] 317 | x0 = jnp.concatenate((r0, dr0)) 318 | A0 = jnp.zeros((num_dof, num_features)) 319 | c0 = jnp.zeros(3) 320 | z0 = (x0, A0, c0) 321 | 322 | # Integrate the adaptive control loop using the meta-model 323 | # and EACH model in the ensemble along the same reference 324 | in_axes = (None, None, None, None, None, None, 0) 325 | ode = jax.partial(ode, reference=reference) 326 | z, t = jax.vmap(odeint_fixed_step, in_axes)(ode, z0, 0., T, dt, 327 | meta_params, 328 | ensemble_params) 329 | x, A, c = z 330 | return t, x, A, c 331 | 332 | # Initialize meta-model parameters 333 | num_hlayers = hparams['meta']['num_hlayers'] 334 | hdim = hparams['meta']['hdim'] 335 | if num_hlayers >= 1: 336 | shapes = [(hdim, 2*num_dof), ] + (num_hlayers-1)*[(hdim, hdim), ] 337 | else: 338 | shapes = [] 339 | key, *subkeys = jax.random.split(key, 1 + 2*num_hlayers + 3) 340 | subkeys_W = subkeys[:num_hlayers] 341 | subkeys_b = subkeys[num_hlayers:-3] 342 | subkeys_gains = subkeys[-3:] 343 | meta_params = { 344 | # hidden layer weights 345 | 'W': [0.1*jax.random.normal(subkeys_W[i], shapes[i]) 346 | for i in range(num_hlayers)], 347 | # hidden layer biases 348 | 'b': [0.1*jax.random.normal(subkeys_b[i], (shapes[i][0],)) 349 | for i in range(num_hlayers)], 350 | 'gains': { # vectorized control and adaptation gains 351 | 'Λ': 0.1*jax.random.normal(subkeys_gains[0], 352 | ((num_dof*(num_dof + 1)) // 2,)), 353 | 'K': 0.1*jax.random.normal(subkeys_gains[1], 354 | ((num_dof*(num_dof + 1)) // 2,)), 355 | 'P': 0.1*jax.random.normal(subkeys_gains[2], 356 | ((num_dof*(num_dof + 1)) // 2,)), 357 | } 358 | } 359 | 360 | # Initialize spline coefficients for each reference trajectory 361 | num_refs = hparams['meta']['num_refs'] 362 | key, *subkeys = jax.random.split(key, 1 + num_refs) 363 | subkeys = jnp.vstack(subkeys) 364 | in_axes = (0, None, None, None, None, None, None, None, None) 365 | min_ref = jnp.asarray(hparams['meta']['min_ref']) 366 | max_ref = jnp.asarray(hparams['meta']['max_ref']) 367 | t_knots, knots, coefs = jax.vmap(random_ragged_spline, in_axes)( 368 | subkeys, 369 | hparams['meta']['T'], 370 | hparams['meta']['num_knots'], 371 | hparams['meta']['poly_orders'], 372 | hparams['meta']['deriv_orders'], 373 | jnp.asarray(hparams['meta']['min_step']), 374 | jnp.asarray(hparams['meta']['max_step']), 375 | 0.7*min_ref, 376 | 0.7*max_ref, 377 | ) 378 | # x_coefs, y_coefs, θ_coefs = coefs 379 | # x_knots, y_knots, θ_knots = knots 380 | r_knots = jnp.dstack(knots) 381 | 382 | # Simulate the adaptive control loop for each model in the ensemble and 383 | # each reference trajectory (i.e., spline coefficients) 384 | @jax.partial(jax.vmap, in_axes=(None, None, 0, 0, None, None)) 385 | def simulate(meta_params, ensemble_params, t_knots, coefs, T, dt, 386 | min_ref=min_ref, max_ref=max_ref): 387 | """TODO: docstring.""" 388 | # Define a reference trajectory in terms of spline coefficients 389 | def reference(t): 390 | r = jnp.array([spline(t, t_knots, c) for c in coefs]) 391 | r = jnp.clip(r, min_ref, max_ref) 392 | return r 393 | t, x, A, c = ensemble_sim(meta_params, ensemble_params, 394 | reference, T, dt) 395 | return t, x, A, c 396 | 397 | @jax.partial(jax.jit, static_argnums=(4, 5)) 398 | def loss(meta_params, ensemble_params, t_knots, coefs, T, dt, 399 | regularizer_l2, regularizer_ctrl, regularizer_error): 400 | """TODO: docstring.""" 401 | # Simulate on each model for each reference trajectory 402 | t, x, A, c = simulate(meta_params, ensemble_params, t_knots, 403 | coefs, T, dt) 404 | 405 | # Sum final costs over reference trajectories and ensemble models 406 | # Note `c` has shape (`num_refs`, `num_models`, `T // dt`, 3) 407 | c_final = jnp.sum(c[:, :, -1, :], axis=(0, 1)) 408 | 409 | # Form a composite loss by weighting the different cost integrals, 410 | # and normalizing by the number of models, number of reference 411 | # trajectories, and time horizon 412 | num_refs = c.shape[0] 413 | num_models = c.shape[1] 414 | normalizer = T * num_refs * num_models 415 | tracking_loss, control_loss, estimation_loss = c_final 416 | l2_penalty = tree_normsq((meta_params['W'], meta_params['b'])) 417 | loss = (tracking_loss 418 | + regularizer_ctrl*control_loss 419 | + regularizer_error*estimation_loss 420 | + regularizer_l2*l2_penalty) / normalizer 421 | aux = { 422 | # for each model in ensemble 423 | 'tracking_loss': jnp.sum(c[:, :, -1, 0], axis=0) / num_refs, 424 | 'control_loss': jnp.sum(c[:, :, -1, 1], axis=0) / num_refs, 425 | 'estimation_loss': jnp.sum(c[:, :, -1, 2], axis=0) / num_refs, 426 | 'l2_penalty': l2_penalty, 427 | 'eigs_Λ': 428 | jnp.diag(params_to_cholesky(meta_params['gains']['Λ']))**2, 429 | 'eigs_K': 430 | jnp.diag(params_to_cholesky(meta_params['gains']['K']))**2, 431 | 'eigs_P': 432 | jnp.diag(params_to_cholesky(meta_params['gains']['P']))**2, 433 | } 434 | return loss, aux 435 | 436 | # Shuffle and split ensemble into training and validation sets 437 | train_frac = hparams['meta']['train_frac'] 438 | num_train_models = int(train_frac * num_models) 439 | key, subkey = jax.random.split(key, 2) 440 | model_idx = jax.random.permutation(subkey, num_models) 441 | train_model_idx = model_idx[:num_train_models] 442 | valid_model_idx = model_idx[num_train_models:] 443 | train_ensemble = jax.tree_util.tree_map(lambda x: x[train_model_idx], 444 | best_ensemble) 445 | valid_ensemble = jax.tree_util.tree_map(lambda x: x[valid_model_idx], 446 | best_ensemble) 447 | 448 | # Split reference trajectories into training and validation sets 449 | num_train_refs = int(train_frac * num_refs) 450 | train_t_knots = jax.tree_util.tree_map(lambda a: a[:num_train_refs], 451 | t_knots) 452 | train_coefs = jax.tree_util.tree_map(lambda a: a[:num_train_refs], coefs) 453 | valid_t_knots = jax.tree_util.tree_map(lambda a: a[num_train_refs:], 454 | t_knots) 455 | valid_coefs = jax.tree_util.tree_map(lambda a: a[num_train_refs:], coefs) 456 | 457 | # Initialize gradient-based optimizer (ADAM) 458 | learning_rate = hparams['meta']['learning_rate'] 459 | init_opt, update_opt, get_params = optimizers.adam(learning_rate) 460 | opt_state = init_opt(meta_params) 461 | step_idx = 0 462 | best_idx = 0 463 | best_loss = jnp.inf 464 | best_meta_params = meta_params 465 | 466 | @jax.partial(jax.jit, static_argnums=(5, 6)) 467 | def step(idx, opt_state, ensemble_params, t_knots, coefs, T, dt, 468 | regularizer_l2, regularizer_ctrl, regularizer_error): 469 | """TODO: docstring.""" 470 | meta_params = get_params(opt_state) 471 | grads, aux = jax.grad(loss, argnums=0, has_aux=True)( 472 | meta_params, ensemble_params, t_knots, coefs, T, dt, 473 | regularizer_l2, regularizer_ctrl, regularizer_error 474 | ) 475 | opt_state = update_opt(idx, grads, opt_state) 476 | return opt_state, aux 477 | 478 | # Pre-compile before training 479 | print('META-TRAINING: Pre-compiling ... ', end='', flush=True) 480 | dt = hparams['meta']['dt'] 481 | T = hparams['meta']['T'] 482 | regularizer_l2 = hparams['meta']['regularizer_l2'] 483 | regularizer_ctrl = hparams['meta']['regularizer_ctrl'] 484 | regularizer_error = hparams['meta']['regularizer_error'] 485 | start = time.time() 486 | _ = step(0, opt_state, train_ensemble, train_t_knots, train_coefs, T, dt, 487 | regularizer_l2, regularizer_ctrl, regularizer_error) 488 | _ = loss(meta_params, valid_ensemble, valid_t_knots, valid_coefs, T, dt, 489 | 0., 0., 0.) 490 | end = time.time() 491 | print('done ({:.2f} s)! Now training ...'.format( 492 | end - start)) 493 | start = time.time() 494 | 495 | # Do gradient descent 496 | for _ in tqdm(range(hparams['meta']['num_steps'])): 497 | opt_state, train_aux = step( 498 | step_idx, opt_state, train_ensemble, train_t_knots, train_coefs, 499 | T, dt, regularizer_l2, regularizer_ctrl, regularizer_error 500 | ) 501 | new_meta_params = get_params(opt_state) 502 | valid_loss, valid_aux = loss( 503 | new_meta_params, valid_ensemble, valid_t_knots, valid_coefs, 504 | T, dt, 0., 0., 0. 505 | ) 506 | if valid_loss < best_loss: 507 | best_meta_params = new_meta_params 508 | best_loss = valid_loss 509 | best_idx = step_idx 510 | step_idx += 1 511 | 512 | # Save hyperparameters, ensemble, model, and controller 513 | output_name = "seed={:d}_M={:d}".format(hparams['seed'], num_models) 514 | results = { 515 | 'best_step_idx': best_idx, 516 | 'hparams': hparams, 517 | 'ensemble': best_ensemble, 518 | 'model': { 519 | 'W': best_meta_params['W'], 520 | 'b': best_meta_params['b'], 521 | }, 522 | 'controller': best_meta_params['gains'], 523 | } 524 | output_path = os.path.join('train_results', 'ours', output_name + '.pkl') 525 | with open(output_path, 'wb') as file: 526 | pickle.dump(results, file) 527 | 528 | end = time.time() 529 | print('done ({:.2f} s)! Best step index: {}'.format(end - start, 530 | best_idx)) 531 | -------------------------------------------------------------------------------- /train_results/lstsq/seed=0_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=0_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=0_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=0_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=0_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=0_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=0_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=0_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=0_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=0_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=0_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=0_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=0_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=0_M=50.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=1_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=1_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=1_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=1_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=1_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=1_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=1_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=1_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=1_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=1_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=1_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=1_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=1_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=1_M=50.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=2_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=2_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=2_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=2_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=2_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=2_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=2_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=2_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=2_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=2_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=2_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=2_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=2_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=2_M=50.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=3_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=3_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=3_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=3_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=3_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=3_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=3_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=3_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=3_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=3_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=3_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=3_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=3_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=3_M=50.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=4_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=4_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=4_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=4_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=4_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=4_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=4_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=4_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=4_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=4_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=4_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=4_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=4_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=4_M=50.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=5_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=5_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=5_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=5_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=5_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=5_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=5_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=5_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=5_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=5_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=5_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=5_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=5_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=5_M=50.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=6_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=6_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=6_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=6_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=6_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=6_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=6_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=6_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=6_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=6_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=6_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=6_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=6_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=6_M=50.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=7_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=7_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=7_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=7_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=7_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=7_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=7_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=7_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=7_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=7_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=7_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=7_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=7_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=7_M=50.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=8_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=8_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=8_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=8_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=8_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=8_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=8_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=8_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=8_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=8_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=8_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=8_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=8_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=8_M=50.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=9_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=9_M=10.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=9_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=9_M=2.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=9_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=9_M=20.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=9_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=9_M=30.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=9_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=9_M=40.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=9_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=9_M=5.pkl -------------------------------------------------------------------------------- /train_results/lstsq/seed=9_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/lstsq/seed=9_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=0_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=0_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=0_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=0_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=0_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=0_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=0_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=0_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=0_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=0_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=0_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=0_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=0_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=0_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=1_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=1_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=1_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=1_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=1_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=1_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=1_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=1_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=1_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=1_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=1_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=1_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=1_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=1_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=2_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=2_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=2_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=2_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=2_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=2_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=2_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=2_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=2_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=2_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=2_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=2_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=2_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=2_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=3_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=3_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=3_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=3_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=3_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=3_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=3_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=3_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=3_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=3_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=3_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=3_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=3_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=3_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=4_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=4_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=4_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=4_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=4_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=4_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=4_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=4_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=4_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=4_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=4_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=4_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=4_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=4_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=5_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=5_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=5_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=5_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=5_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=5_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=5_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=5_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=5_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=5_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=5_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=5_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=5_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=5_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=6_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=6_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=6_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=6_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=6_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=6_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=6_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=6_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=6_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=6_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=6_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=6_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=6_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=6_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=7_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=7_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=7_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=7_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=7_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=7_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=7_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=7_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=7_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=7_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=7_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=7_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=7_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=7_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=8_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=8_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=8_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=8_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=8_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=8_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=8_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=8_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=8_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=8_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=8_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=8_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=8_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=8_M=50.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=9_M=10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=9_M=10.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=9_M=2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=9_M=2.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=9_M=20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=9_M=20.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=9_M=30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=9_M=30.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=9_M=40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=9_M=40.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=9_M=5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=9_M=5.pkl -------------------------------------------------------------------------------- /train_results/ours/seed=9_M=50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/train_results/ours/seed=9_M=50.pkl -------------------------------------------------------------------------------- /training_data.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordASL/Adaptive-Control-Oriented-Meta-Learning/093d2764314bbfccc3a804fb9e737a10d08a1eb5/training_data.pkl -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO description. 3 | 4 | Author: Spencer M. Richards 5 | Autonomous Systems Lab (ASL), Stanford 6 | (GitHub: spenrich) 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | from jax.scipy.linalg import block_diag 12 | import numpy as np 13 | from functools import partial 14 | from jax.flatten_util import ravel_pytree 15 | 16 | 17 | def mat_to_svec_dim(n): 18 | """Compute the number of unique entries in a symmetric matrix.""" 19 | d = (n * (n + 1)) // 2 20 | return d 21 | 22 | 23 | def svec_to_mat_dim(d): 24 | """Compute the symmetric matrix dimension with `d` unique elements.""" 25 | n = (int(np.sqrt(8 * d + 1)) - 1) // 2 26 | if d != mat_to_svec_dim(n): 27 | raise ValueError('Invalid vector length `d = %d` for filling the ' 28 | 'triangular of a symmetric matrix!' % d) 29 | return n 30 | 31 | 32 | def svec_diag_indices(n): 33 | """Compute indices of `svec(A)` corresponding to diagonal elements. 34 | 35 | Example for `n = 3`: 36 | [ 0 ] 37 | [ 1 3 ] => [0, 3, 5] 38 | [ 2 4 5 ] 39 | 40 | For general `n`, indices of `svec` corresponding to the diagonal are: 41 | [0, n, n + (n-1), ..., n*(n+1)/2 - 1] 42 | = n*(n+1)/2 - [n*(n+1)/2, (n-1)*n/2, ..., 1] 43 | """ 44 | d = mat_to_svec_dim(n) 45 | idx = d - mat_to_svec_dim(np.arange(1, n+1)[::-1]) 46 | return idx 47 | 48 | 49 | def svec(X, scale=True): 50 | """Compute the symmetric vectorization of symmetric matrix `X`.""" 51 | shape = jnp.shape(X) 52 | if len(shape) < 2: 53 | raise ValueError('Argument `X` must be at least 2D!') 54 | if shape[-2] != shape[-1]: 55 | raise ValueError('Last two dimensions of `X` must be equal!') 56 | n = shape[-1] 57 | 58 | if scale: 59 | # Scale elements corresponding to the off-diagonal, lower-triangular 60 | # part of `X` by `sqrt(2)` to preserve the inner product 61 | rows, cols = jnp.tril_indices(n, -1) 62 | X = X.at[..., rows, cols].mul(jnp.sqrt(2)) 63 | 64 | # Vectorize the lower-triangular part of `X` in row-major order 65 | rows, cols = jnp.tril_indices(n) 66 | svec_X = X[..., rows, cols] 67 | return svec_X 68 | 69 | 70 | def smat(svec_X, scale=True): 71 | """Compute the symmetric matrix `X` given `svec(X)`.""" 72 | svec_X = jnp.atleast_1d(svec_X) 73 | d = svec_X.shape[-1] 74 | n = svec_to_mat_dim(d) # corresponding symmetric matrix dimension 75 | 76 | # Fill the lower triangular of `X` in row-major order with the elements 77 | # of `svec_X` 78 | rows, cols = jnp.tril_indices(n) 79 | X = jnp.zeros((*svec_X.shape[:-1], n, n)) 80 | X = X.at[..., rows, cols].set(svec_X) 81 | if scale: 82 | # Scale elements corresponding to the off-diagonal, lower-triangular 83 | # elements of `X` by `1 / sqrt(2)` to preserve the inner product 84 | rows, cols = jnp.tril_indices(n, -1) 85 | X = X.at[..., rows, cols].mul(1 / jnp.sqrt(2)) 86 | 87 | # Make `X` symmetric 88 | rows, cols = jnp.triu_indices(n, 1) 89 | X = X.at[..., rows, cols].set(X[..., cols, rows]) 90 | return X 91 | 92 | 93 | def cholesky_to_params(L): 94 | """Uniquely parameterize a positive-definite Cholesky factor.""" 95 | shape = jnp.shape(L) 96 | if len(shape) < 2: 97 | raise ValueError('Argument `L` must be at least 2D!') 98 | if shape[-2] != shape[-1]: 99 | raise ValueError('Last two dimensions of `L` must be equal!') 100 | n = shape[-1] 101 | rows, cols = jnp.diag_indices(n) 102 | log_L = L.at[..., rows, cols].set(jnp.log(L[..., rows, cols])) 103 | params = svec(log_L, scale=False) 104 | return params 105 | 106 | 107 | def params_to_cholesky(params): 108 | """TODO: docstring.""" 109 | params = jnp.atleast_1d(params) 110 | d = params.shape[-1] 111 | n = svec_to_mat_dim(d) # corresponding symmetric matrix dimension 112 | rows, cols = jnp.tril_indices(n) 113 | log_L = jnp.zeros((*params.shape[:-1], n, n)).at[..., 114 | rows, cols].set(params) 115 | rows, cols = jnp.diag_indices(n) 116 | L = log_L.at[..., rows, cols].set(jnp.exp(log_L[..., rows, cols])) 117 | return L 118 | 119 | 120 | def params_to_posdef(params): 121 | """TODO: docstring.""" 122 | L = params_to_cholesky(params) 123 | LT = jnp.swapaxes(L, -2, -1) 124 | X = L @ LT 125 | return X 126 | 127 | 128 | def uniform_random_walk(key, num_steps, shape=(), min_step=0., max_step=1.): 129 | """TODO: docstring.""" 130 | minvals = jnp.broadcast_to(min_step, shape) 131 | maxvals = jnp.broadcast_to(max_step, shape) 132 | noise = minvals + (maxvals - minvals)*jax.random.uniform(key, (num_steps, 133 | *shape)) 134 | points = jnp.concatenate((jnp.zeros((1, *shape)), 135 | jnp.cumsum(noise, axis=0)), axis=0) 136 | return points 137 | 138 | 139 | def random_spline(key, T_total, num_knots, poly_order, deriv_order, 140 | shape=(), min_step=0., max_step=1.): 141 | """TODO: docstring.""" 142 | knots = uniform_random_walk(key, num_knots - 1, shape, min_step, max_step) 143 | flat_knots = jnp.reshape(knots, (num_knots, -1)) 144 | diffs = jnp.linalg.norm(jnp.diff(flat_knots, axis=0), axis=1) 145 | T = T_total * (diffs / jnp.sum(diffs)) 146 | t_knots = jnp.concatenate((jnp.array([0., ]), 147 | jnp.cumsum(T))).at[-1].set(T_total) 148 | coefs = smooth_trajectory(knots, t_knots, poly_order, deriv_order) 149 | return knots, t_knots, coefs 150 | 151 | 152 | def random_ragged_spline(key, T_total, num_knots, poly_orders, deriv_orders, 153 | min_step, max_step, min_knot, max_knot): 154 | """TODO: docstring.""" 155 | poly_orders = np.array(poly_orders).ravel().astype(int) 156 | deriv_orders = np.array(deriv_orders).ravel().astype(int) 157 | num_dims = poly_orders.size 158 | assert deriv_orders.size == num_dims 159 | shape = (num_dims,) 160 | knots = uniform_random_walk(key, num_knots - 1, shape, min_step, max_step) 161 | knots = jnp.clip(knots, min_knot, max_knot) 162 | flat_knots = jnp.reshape(knots, (num_knots, -1)) 163 | diffs = jnp.linalg.norm(jnp.diff(flat_knots, axis=0), axis=1) 164 | T = T_total * (diffs / jnp.sum(diffs)) 165 | t_knots = jnp.concatenate((jnp.array([0., ]), 166 | jnp.cumsum(T))).at[-1].set(T_total) 167 | coefs = [] 168 | for i, (p, d) in enumerate(zip(poly_orders, deriv_orders)): 169 | coefs.append(smooth_trajectory(knots[:, i], t_knots, p, d)) 170 | coefs = tuple(coefs) 171 | knots = tuple(knots[:, i] for i in range(num_dims)) 172 | return t_knots, knots, coefs 173 | 174 | 175 | def epoch(key, data, batch_size, batch_axis=0, ragged=False): 176 | """TODO: docstring.""" 177 | # Check for consistent dimensions along `batch_axis` 178 | flat_data, _ = jax.tree_util.tree_flatten(data) 179 | num_samples = jnp.array(jax.tree_util.tree_map( 180 | lambda x: jnp.shape(x)[batch_axis], 181 | flat_data 182 | )) 183 | if not jnp.all(num_samples == num_samples[0]): 184 | raise ValueError('Batch dimensions not equal!') 185 | num_samples = num_samples[0] 186 | 187 | # Compute the number of batches 188 | if ragged: 189 | num_batches = -(-num_samples // batch_size) # ceiling division 190 | else: 191 | num_batches = num_samples // batch_size # floor division 192 | 193 | # Loop through batches (with pre-shuffling) 194 | shuffled_idx = jax.random.permutation(key, num_samples) 195 | for i in range(num_batches): 196 | batch_idx = shuffled_idx[i*batch_size:(i+1)*batch_size] 197 | batch = jax.tree_util.tree_map( 198 | lambda x: jnp.take(x, batch_idx, batch_axis), 199 | data 200 | ) 201 | yield batch 202 | 203 | 204 | class Dataloader(object): 205 | """TODO: docstring.""" 206 | 207 | def __init__(self, data, key, batch_axis=0, ragged=False, **aux_data): 208 | """TODO: docstring.""" 209 | flat_data, _ = jax.tree_util.tree_flatten(data) 210 | num_samples = jnp.array(jax.tree_util.tree_map( 211 | lambda x: jnp.shape(x)[batch_axis], 212 | flat_data 213 | )) 214 | if not jnp.all(num_samples == num_samples[0]): 215 | raise ValueError('Batch dimensions not equal!') 216 | self.num_samples = num_samples[0] 217 | self.batch_axis = batch_axis 218 | self.data = data 219 | self.aux = aux_data 220 | self.shuffled_idx = jnp.arange(self.num_samples) 221 | self.key = key 222 | self.ragged = ragged 223 | 224 | @property 225 | def shuffled_data(self): 226 | """TODO: docstring.""" 227 | shuffled_data = jax.tree_util.tree_map( 228 | lambda x: jnp.take(x, self.shuffled_idx, self.batch_axis), 229 | self.data 230 | ) 231 | return shuffled_data 232 | 233 | def shuffle(self): 234 | """TODO: docstring.""" 235 | self.key, subkey = jax.random.split(self.key, 2) 236 | self.shuffled_idx = jax.random.permutation(subkey, self.shuffled_idx) 237 | 238 | def batches_per_epoch(self, batch_size): 239 | """TODO: docstring.""" 240 | if self.ragged: 241 | # ceiling division 242 | num_batches = -(-self.num_samples // batch_size) 243 | else: 244 | # floor division 245 | num_batches = self.num_samples // batch_size 246 | return num_batches 247 | 248 | def get_batch(self, batch_size, idx): 249 | """TODO: docstring.""" 250 | if idx*batch_size >= self.num_samples: 251 | raise IndexError("Batch index out of range!") 252 | idx = self.shuffled_idx[idx*batch_size:(idx+1)*batch_size] 253 | batch = jax.tree_util.tree_map( 254 | lambda x: jnp.take(x, idx, self.batch_axis), 255 | self.data 256 | ) 257 | return batch 258 | 259 | def epoch(self, batch_size, shuffle=True): 260 | """TODO: docstring.""" 261 | idx = 0 262 | while idx < self.num_samples: 263 | indices = self.shuffled_idx[idx:idx + batch_size] 264 | if indices.size < batch_size and not self.ragged: 265 | break 266 | batch = jax.tree_util.tree_map( 267 | lambda x: jnp.take(x, indices, self.batch_axis), 268 | self.data 269 | ) 270 | idx += batch_size 271 | yield batch 272 | if shuffle: 273 | self.shuffle() 274 | 275 | 276 | @partial(jax.jit, static_argnums=(0,)) 277 | def rk38_step(func, h, x, t, *args): 278 | """TODO: docstring.""" 279 | # RK38 Butcher tableau 280 | s = 4 281 | A = jnp.array([ 282 | [0, 0, 0, 0], 283 | [1/3, 0, 0, 0], 284 | [-1/3, 1, 0, 0], 285 | [1, -1, 1, 0], 286 | ]) 287 | b = jnp.array([1/8, 3/8, 3/8, 1/8]) 288 | c = jnp.array([0, 1/3, 2/3, 1]) 289 | 290 | def scan_fun(carry, cut): 291 | i, ai, bi, ci = cut 292 | x, t, h, K, *args = carry 293 | ti = t + h*ci 294 | xi = x + h*(K.T @ ai) 295 | ki = func(xi, ti, *args) 296 | K = K.at[i].set(ki) 297 | carry = (x, t, h, K, *args) 298 | return carry, ki 299 | 300 | init_carry = (x, t, h, jnp.squeeze(jnp.zeros((s, x.size))), *args) 301 | carry, K = jax.lax.scan(scan_fun, init_carry, (jnp.arange(s), A, b, c)) 302 | xf = x + h*(K.T @ b) 303 | return xf 304 | 305 | 306 | @partial(jax.jit, static_argnums=(0,)) 307 | def _odeint_ckpt(func, x0, ts, *args): 308 | 309 | def scan_fun(carry, t1): 310 | x0, t0, *args = carry 311 | x1 = rk38_step(func, t1 - t0, x0, t0, *args) 312 | carry = (x1, t1, *args) 313 | return carry, x1 314 | 315 | ts = jnp.atleast_1d(ts) 316 | init_carry = (x0, ts[0], *args) # dummy state at same time as `t0` 317 | carry, xs = jax.lax.scan(scan_fun, init_carry, ts) 318 | return xs 319 | 320 | 321 | @partial(jax.jit, static_argnums=(0,)) 322 | def odeint_ckpt(func, x0, ts, *args): 323 | """TODO: docstring.""" 324 | flat_x0, unravel = ravel_pytree(x0) 325 | 326 | def flat_func(flat_x, t, *args): 327 | x = unravel(flat_x) 328 | dx = func(x, t, *args) 329 | flat_dx, _ = ravel_pytree(dx) 330 | return flat_dx 331 | 332 | # Solve in flat form 333 | flat_xs = _odeint_ckpt(flat_func, flat_x0, ts, *args) 334 | xs = jax.vmap(unravel)(flat_xs) 335 | return xs 336 | 337 | 338 | @partial(jax.jit, static_argnums=(0, 2, 3, 4)) 339 | def odeint_fixed_step(func, x0, t0, t1, step_size, *args): 340 | """TODO: docstring.""" 341 | # Use `numpy` for purely static operations on static arguments 342 | # (see: https://github.com/google/jax/issues/5208) 343 | num_steps = int(np.maximum(np.abs((t1 - t0)/step_size), 1)) 344 | 345 | ts = jnp.linspace(t0, t1, num_steps + 1) 346 | xs = odeint_ckpt(func, x0, ts, *args) 347 | return xs, ts 348 | 349 | 350 | # Some utilities for dealing with PyTrees of parameters 351 | def tree_scale(x_tree, a): 352 | """Scale the children of a PyTree by the scalar `a`.""" 353 | return jax.tree_util.tree_map(lambda x: a * x, x_tree) 354 | 355 | 356 | def tree_add(x_tree, y_tree): 357 | """Add pairwise the children of two PyTrees.""" 358 | return jax.tree_util.tree_multimap(lambda x, y: x + y, x_tree, y_tree) 359 | 360 | 361 | def tree_index(x_tree, i): 362 | """Index child arrays in PyTree.""" 363 | return jax.tree_util.tree_map(lambda x: x[i], x_tree) 364 | 365 | 366 | def tree_index_update(x_tree, i, y_tree): 367 | """Update indices of child arrays in PyTree with new values.""" 368 | return jax.tree_util.tree_multimap(lambda x, y: 369 | jax.ops.index_update(x, i, y), 370 | x_tree, y_tree) 371 | 372 | 373 | def tree_axpy(a, x_tree, y_tree): 374 | """Compute `a*x + y` for two PyTrees `(x, y)` and a scalar `a`.""" 375 | ax = tree_scale(x_tree, a) 376 | axpy = jax.tree_util.tree_multimap(lambda x, y: x + y, ax, y_tree) 377 | return axpy 378 | 379 | 380 | def tree_dot(x_tree, y_tree): 381 | """Compute the dot products between children of two PyTrees.""" 382 | xy = jax.tree_util.tree_multimap(lambda x, y: jnp.sum(x*y), x_tree, y_tree) 383 | return xy 384 | 385 | 386 | def tree_normsq(x_tree): 387 | """Compute sum of squared norms across a PyTree.""" 388 | normsq = jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(y**2), 389 | x_tree, 0.) 390 | return normsq 391 | 392 | 393 | def tree_anynan(tree): 394 | """Check if there are any NAN elements in the PyTree.""" 395 | any_isnan_tree = jax.tree_util.tree_map(lambda a: jnp.any(jnp.isnan(a)), 396 | tree) 397 | any_isnan = jax.tree_util.tree_reduce(lambda x, y: jnp.logical_or(x, y), 398 | any_isnan_tree, False) 399 | return any_isnan 400 | 401 | 402 | @partial(jax.jit, static_argnums=(2, 3)) 403 | def _scalar_smooth_trajectory(x_knots, t_knots, poly_order, deriv_order): 404 | """Construct a smooth trajectory through given points. 405 | 406 | Arguments 407 | --------- 408 | x_knots : jax.numpy.ndarray 409 | TODO. 410 | t_knots : jax.numpy.ndarray 411 | TODO. 412 | poly_order : int 413 | TODO. 414 | deriv_order : int 415 | TODO. 416 | 417 | Returns 418 | ------- 419 | coefs : jax.numpy.ndarray 420 | TODO. 421 | 422 | References 423 | ---------- 424 | .. [1] Charles Richter, Adam Bry, and Nicholas Roy, 425 | "Polynomial trajectory planning for aggressive quadrotor flight in 426 | dense indoor environments", ISRR 2013. 427 | .. [2] Daniel Mellinger and Vijay Kumar, 428 | "Minimum snap trajectory generation and control for quadrotors", 429 | ICRA 2011. 430 | .. [3] Declan Burke, Airlie Chapman, and Iman Shames, 431 | "Generating minimum-snap quadrotor trajectories really fast", 432 | IROS 2020. 433 | """ 434 | num_coefs = poly_order + 1 # number of coefficients per polynomial 435 | num_knots = x_knots.size # number of interpolating points 436 | num_polys = num_knots - 1 # number of polynomials 437 | primal_dim = num_coefs * num_polys # number of unknown coefficients 438 | 439 | T = jnp.diff(t_knots) # polynomial lengths in time 440 | powers = jnp.arange(poly_order + 1) # exponents defining each monomial 441 | D = jnp.diag(powers[1:], -1) # maps monomials to their derivatives 442 | 443 | c0 = jnp.zeros((deriv_order + 1, num_coefs)).at[0, 0].set(1.) 444 | c1 = jnp.zeros((deriv_order + 1, num_coefs)).at[0, :].set(1.) 445 | for n in range(1, deriv_order + 1): 446 | c0 = c0.at[n].set(D @ c0[n - 1]) 447 | c1 = c1.at[n].set(D @ c1[n - 1]) 448 | 449 | # Assemble constraints in the form `A @ x = b`, where `x` is the vector of 450 | # stacked polynomial coefficients 451 | 452 | # Knots 453 | b_knots = jnp.concatenate((x_knots[:-1], x_knots[1:])) 454 | A_knots = jnp.vstack([ 455 | block_diag(*jnp.tile(c0[0], (num_polys, 1))), 456 | block_diag(*jnp.tile(c1[0], (num_polys, 1))) 457 | ]) 458 | 459 | # Zero initial conditions (velocity, acceleration, jerk) 460 | b_init = jnp.zeros(deriv_order - 1) 461 | A_init = jnp.zeros((deriv_order - 1, primal_dim)) 462 | A_init = A_init.at[:deriv_order - 1, :num_coefs].set(c0[1:deriv_order]) 463 | 464 | # Zero final conditions (velocity, acceleration, jerk) 465 | b_fin = jnp.zeros(deriv_order - 1) 466 | A_fin = jnp.zeros((deriv_order - 1, primal_dim)) 467 | A_fin = A_fin.at[:deriv_order - 1, -num_coefs:].set(c1[1:deriv_order]) 468 | 469 | # Continuity (velocity, acceleration, jerk, snap) 470 | b_cont = jnp.zeros(deriv_order * (num_polys - 1)) 471 | As = [] 472 | zero_pad = jnp.zeros((num_polys - 1, num_coefs)) 473 | Tn = jnp.ones_like(T) 474 | for n in range(1, deriv_order + 1): 475 | Tn = T * Tn 476 | diag_c0 = block_diag(*(c0[n] / Tn[1:].reshape([-1, 1]))) 477 | diag_c1 = block_diag(*(c1[n] / Tn[:-1].reshape([-1, 1]))) 478 | As.append(jnp.hstack((diag_c1, zero_pad)) 479 | - jnp.hstack((zero_pad, diag_c0))) 480 | A_cont = jnp.vstack(As) 481 | 482 | # Assemble 483 | A = jnp.vstack((A_knots, A_init, A_fin, A_cont)) 484 | b = jnp.concatenate((b_knots, b_init, b_fin, b_cont)) 485 | dual_dim = b.size 486 | 487 | # Compute the cost Hessian `Q(T)` as a function of the length `T` for each 488 | # polynomial, and stack them into the full block-diagonal Hessian 489 | ij_1 = powers.reshape([-1, 1]) + powers + 1 490 | D_snap = jnp.linalg.matrix_power(D, deriv_order) 491 | Q_snap = D_snap @ (1 / ij_1) @ D_snap.T 492 | Q_poly = lambda T: Q_snap / (T**(2*deriv_order - 1)) # noqa: E731 493 | Q = block_diag(*jax.vmap(Q_poly)(T)) 494 | 495 | # Assemble KKT system and solve for coefficients 496 | K = jnp.block([ 497 | [Q, A.T], 498 | [A, jnp.zeros((dual_dim, dual_dim))] 499 | ]) 500 | soln = jnp.linalg.solve(K, jnp.concatenate((jnp.zeros(primal_dim), b))) 501 | primal, dual = soln[:primal_dim], soln[-dual_dim:] 502 | coefs = primal.reshape((num_polys, -1)) 503 | r_primal = A@primal - b 504 | r_dual = Q@primal + A.T@dual 505 | return coefs, r_primal, r_dual 506 | 507 | 508 | @partial(jax.jit, static_argnums=(2, 3)) 509 | def smooth_trajectory(x_knots, t_knots, poly_order, deriv_order): 510 | """TODO: docstring.""" 511 | # TODO: shape checking 512 | num_knots = x_knots.shape[0] 513 | knot_shape = x_knots.shape[1:] 514 | flat_x_knots = jnp.reshape(x_knots, (num_knots, -1)) 515 | in_axes = (1, None, None, None) 516 | out_axes = (2, 1, 1) 517 | flat_coefs, _, _ = jax.vmap(_scalar_smooth_trajectory, 518 | in_axes, out_axes)(flat_x_knots, t_knots, 519 | poly_order, deriv_order) 520 | num_polys = num_knots - 1 521 | coefs = jnp.reshape(flat_coefs, (num_polys, poly_order + 1, *knot_shape)) 522 | return coefs 523 | 524 | 525 | @jax.jit 526 | def spline(t, t_knots, coefs): 527 | """Compute the value of a polynomial spline at time `t`.""" 528 | num_polys = coefs.shape[0] 529 | poly_order = coefs.shape[1] - 1 530 | powers = jnp.arange(poly_order + 1) 531 | i = jnp.clip(jnp.searchsorted(t_knots, t, side='left') - 1, 532 | 0, num_polys - 1) 533 | tau = (t - t_knots[i]) / (t_knots[i+1] - t_knots[i]) 534 | x = jnp.tensordot(coefs[i], tau**powers, axes=(0, 0)) 535 | return x 536 | --------------------------------------------------------------------------------